diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td index a46c5119296bf..b6ec45ced3512 100644 --- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td @@ -781,6 +781,7 @@ def Arith_DivFOp : Arith_FloatBinaryOp<"divf"> { def Arith_RemFOp : Arith_FloatBinaryOp<"remf"> { let summary = "floating point division remainder operation"; + let hasFolder = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp index 04a41acc041bc..f1caac9943e64 100644 --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -756,6 +756,19 @@ OpFoldResult arith::DivFOp::fold(ArrayRef operands) { operands, [](const APFloat &a, const APFloat &b) { return a / b; }); } +//===----------------------------------------------------------------------===// +// RemFOp +//===----------------------------------------------------------------------===// + +OpFoldResult arith::RemFOp::fold(ArrayRef operands) { + return constFoldBinaryOp(operands, + [](const APFloat &a, const APFloat &b) { + APFloat Result(a); + (void)Result.remainder(b); + return Result; + }); +} + //===----------------------------------------------------------------------===// // Utility functions for verifying cast ops //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir index 1a78bf7ea0161..a9472c4264257 100644 --- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir +++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir @@ -1374,3 +1374,25 @@ func.func @test_remsi_1(%arg : vector<4xi32>) -> (vector<4xi32>) { %0 = arith.remsi %arg, %v : vector<4xi32> return %0 : vector<4xi32> } + +// ----- + +// CHECK-LABEL: @test_remf( +// CHECK: %[[res:.+]] = arith.constant -1.000000e+00 : f32 +// CHECK: return %[[res]] +func.func @test_remf() -> (f32) { + %v1 = arith.constant 3.0 : f32 + %v2 = arith.constant 2.0 : f32 + %0 = arith.remf %v1, %v2 : f32 + return %0 : f32 +} + +// CHECK-LABEL: @test_remf_vec( +// CHECK: %[[res:.+]] = arith.constant dense<[1.000000e+00, 0.000000e+00, -1.000000e+00, 0.000000e+00]> : vector<4xf32> +// CHECK: return %[[res]] +func.func @test_remf_vec() -> (vector<4xf32>) { + %v1 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf32> + %v2 = arith.constant dense<[2.0, 2.0, 2.0, 2.0]> : vector<4xf32> + %0 = arith.remf %v1, %v2 : vector<4xf32> + return %0 : vector<4xf32> +}