diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp index 77b10cec48d8e..ff5f7f685903f 100644 --- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp +++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp @@ -43,6 +43,7 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op, PatternRewriter &rewriter) const { Location loc = op.getLoc(); Value x = op.getLhs(); + arith::FastMathFlags fmf = op.getFastmathAttr().getValue(); FloatAttr scalarExponent; DenseFPElementsAttr vectorExponent; @@ -66,7 +67,7 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op, // Maybe broadcasts scalar value into vector type compatible with `op`. auto bcast = [&](Value value) -> Value { if (auto vec = dyn_cast(op.getType())) - return vector::BroadcastOp::create(rewriter, op.getLoc(), vec, value); + return vector::BroadcastOp::create(rewriter, loc, vec, value); return value; }; @@ -78,15 +79,14 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op, // Replace `pow(x, 2.0)` with `x * x`. if (isExponentValue(2.0)) { - rewriter.replaceOpWithNewOp(op, ValueRange({x, x})); + rewriter.replaceOpWithNewOp(op, x, x, fmf); return success(); } // Replace `pow(x, 3.0)` with `x * x * x`. if (isExponentValue(3.0)) { - Value square = - arith::MulFOp::create(rewriter, op.getLoc(), ValueRange({x, x})); - rewriter.replaceOpWithNewOp(op, ValueRange({x, square})); + Value square = arith::MulFOp::create(rewriter, loc, x, x, fmf); + rewriter.replaceOpWithNewOp(op, x, square, fmf); return success(); } @@ -95,28 +95,27 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op, Value one = arith::ConstantOp::create( rewriter, loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0)); - rewriter.replaceOpWithNewOp(op, ValueRange({bcast(one), x})); + rewriter.replaceOpWithNewOp(op, bcast(one), x, fmf); return success(); } // Replace `pow(x, 0.5)` with `sqrt(x)`. if (isExponentValue(0.5)) { - rewriter.replaceOpWithNewOp(op, x); + rewriter.replaceOpWithNewOp(op, x, fmf); return success(); } // Replace `pow(x, -0.5)` with `rsqrt(x)`. if (isExponentValue(-0.5)) { - rewriter.replaceOpWithNewOp(op, x); + rewriter.replaceOpWithNewOp(op, x, fmf); return success(); } // Replace `pow(x, 0.75)` with `sqrt(sqrt(x)) * sqrt(x)`. if (isExponentValue(0.75)) { - Value powHalf = math::SqrtOp::create(rewriter, op.getLoc(), x); - Value powQuarter = math::SqrtOp::create(rewriter, op.getLoc(), powHalf); - rewriter.replaceOpWithNewOp(op, - ValueRange{powHalf, powQuarter}); + Value powHalf = math::SqrtOp::create(rewriter, loc, x, fmf); + Value powQuarter = math::SqrtOp::create(rewriter, loc, powHalf, fmf); + rewriter.replaceOpWithNewOp(op, powHalf, powQuarter, fmf); return success(); } diff --git a/mlir/test/Dialect/Math/algebraic-simplification.mlir b/mlir/test/Dialect/Math/algebraic-simplification.mlir index e0e2b9853a2a1..7342600748967 100644 --- a/mlir/test/Dialect/Math/algebraic-simplification.mlir +++ b/mlir/test/Dialect/Math/algebraic-simplification.mlir @@ -22,6 +22,18 @@ func.func @pow_square(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) return %0, %1 : f32, vector<4xf32> } +// CHECK-LABEL: @pow_square_fast +func.func @pow_square_fast(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) { + // CHECK: %[[SCALAR:.*]] = arith.mulf %arg0, %arg0 fastmath + // CHECK: %[[VECTOR:.*]] = arith.mulf %arg1, %arg1 fastmath + // CHECK: return %[[SCALAR]], %[[VECTOR]] + %c = arith.constant 2.0 : f32 + %v = arith.constant dense <2.0> : vector<4xf32> + %0 = math.powf %arg0, %c fastmath : f32 + %1 = math.powf %arg1, %v fastmath : vector<4xf32> + return %0, %1 : f32, vector<4xf32> +} + // CHECK-LABEL: @pow_cube func.func @pow_cube(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) { // CHECK: %[[TMP_S:.*]] = arith.mulf %arg0, %arg0 @@ -36,6 +48,20 @@ func.func @pow_cube(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) { return %0, %1 : f32, vector<4xf32> } +// CHECK-LABEL: @pow_cube_fast +func.func @pow_cube_fast(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) { + // CHECK: %[[TMP_S:.*]] = arith.mulf %arg0, %arg0 fastmath + // CHECK: %[[SCALAR:.*]] = arith.mulf %arg0, %[[TMP_S]] fastmath + // CHECK: %[[TMP_V:.*]] = arith.mulf %arg1, %arg1 fastmath + // CHECK: %[[VECTOR:.*]] = arith.mulf %arg1, %[[TMP_V]] fastmath + // CHECK: return %[[SCALAR]], %[[VECTOR]] + %c = arith.constant 3.0 : f32 + %v = arith.constant dense <3.0> : vector<4xf32> + %0 = math.powf %arg0, %c fastmath : f32 + %1 = math.powf %arg1, %v fastmath : vector<4xf32> + return %0, %1 : f32, vector<4xf32> +} + // CHECK-LABEL: @pow_recip func.func @pow_recip(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) { // CHECK-DAG: %[[CST_S:.*]] = arith.constant 1.0{{.*}} : f32 @@ -50,6 +76,20 @@ func.func @pow_recip(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) return %0, %1 : f32, vector<4xf32> } +// CHECK-LABEL: @pow_recip_fast +func.func @pow_recip_fast(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) { + // CHECK-DAG: %[[CST_S:.*]] = arith.constant 1.0{{.*}} : f32 + // CHECK-DAG: %[[CST_V:.*]] = arith.constant dense<1.0{{.*}}> : vector<4xf32> + // CHECK: %[[SCALAR:.*]] = arith.divf %[[CST_S]], %arg0 fastmath + // CHECK: %[[VECTOR:.*]] = arith.divf %[[CST_V]], %arg1 fastmath + // CHECK: return %[[SCALAR]], %[[VECTOR]] + %c = arith.constant -1.0 : f32 + %v = arith.constant dense <-1.0> : vector<4xf32> + %0 = math.powf %arg0, %c fastmath : f32 + %1 = math.powf %arg1, %v fastmath : vector<4xf32> + return %0, %1 : f32, vector<4xf32> +} + // CHECK-LABEL: @pow_sqrt func.func @pow_sqrt(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) { // CHECK: %[[SCALAR:.*]] = math.sqrt %arg0 @@ -62,6 +102,18 @@ func.func @pow_sqrt(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) { return %0, %1 : f32, vector<4xf32> } +// CHECK-LABEL: @pow_sqrt_fast +func.func @pow_sqrt_fast(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) { + // CHECK: %[[SCALAR:.*]] = math.sqrt %arg0 fastmath + // CHECK: %[[VECTOR:.*]] = math.sqrt %arg1 fastmath + // CHECK: return %[[SCALAR]], %[[VECTOR]] + %c = arith.constant 0.5 : f32 + %v = arith.constant dense <0.5> : vector<4xf32> + %0 = math.powf %arg0, %c fastmath : f32 + %1 = math.powf %arg1, %v fastmath : vector<4xf32> + return %0, %1 : f32, vector<4xf32> +} + // CHECK-LABEL: @pow_rsqrt func.func @pow_rsqrt(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) { // CHECK: %[[SCALAR:.*]] = math.rsqrt %arg0 @@ -74,6 +126,18 @@ func.func @pow_rsqrt(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) return %0, %1 : f32, vector<4xf32> } +// CHECK-LABEL: @pow_rsqrt_fast +func.func @pow_rsqrt_fast(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) { + // CHECK: %[[SCALAR:.*]] = math.rsqrt %arg0 fastmath + // CHECK: %[[VECTOR:.*]] = math.rsqrt %arg1 fastmath + // CHECK: return %[[SCALAR]], %[[VECTOR]] + %c = arith.constant -0.5 : f32 + %v = arith.constant dense <-0.5> : vector<4xf32> + %0 = math.powf %arg0, %c fastmath : f32 + %1 = math.powf %arg1, %v fastmath : vector<4xf32> + return %0, %1 : f32, vector<4xf32> +} + // CHECK-LABEL: @pow_0_75 func.func @pow_0_75(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) { // CHECK: %[[SQRT1S:.*]] = math.sqrt %arg0 @@ -90,6 +154,22 @@ func.func @pow_0_75(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) { return %0, %1 : f32, vector<4xf32> } +// CHECK-LABEL: @pow_0_75_fast +func.func @pow_0_75_fast(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) { + // CHECK: %[[SQRT1S:.*]] = math.sqrt %arg0 fastmath + // CHECK: %[[SQRT2S:.*]] = math.sqrt %[[SQRT1S]] fastmath + // CHECK: %[[SCALAR:.*]] = arith.mulf %[[SQRT1S]], %[[SQRT2S]] fastmath + // CHECK: %[[SQRT1V:.*]] = math.sqrt %arg1 fastmath + // CHECK: %[[SQRT2V:.*]] = math.sqrt %[[SQRT1V]] fastmath + // CHECK: %[[VECTOR:.*]] = arith.mulf %[[SQRT1V]], %[[SQRT2V]] fastmath + // CHECK: return %[[SCALAR]], %[[VECTOR]] + %c = arith.constant 0.75 : f32 + %v = arith.constant dense <0.75> : vector<4xf32> + %0 = math.powf %arg0, %c fastmath : f32 + %1 = math.powf %arg1, %v fastmath : vector<4xf32> + return %0, %1 : f32, vector<4xf32> +} + // CHECK-LABEL: @ipowi_zero_exp( // CHECK-SAME: %[[ARG0:.+]]: i32 // CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>