diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index 77b10cec48d8e..2e5b48ebbb1eb 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -42,20 +42,24 @@ LogicalResult
 PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
                                        PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
+  // Operands of `pow(x, y)`: `x` is the base, `y` is the exponent.
   Value x = op.getLhs();
+  Value y = op.getRhs();
 
-  FloatAttr scalarExponent;
-  DenseFPElementsAttr vectorExponent;
+  FloatAttr scalarBase, scalarExponent;
+  DenseFPElementsAttr vectorBase, vectorExponent;
 
-  bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent));
-  bool isVector = matchPattern(op.getRhs(), m_Constant(&vectorExponent));
+  bool isScalarBase = matchPattern(x, m_Constant(&scalarBase));
+  bool isVectorBase = matchPattern(x, m_Constant(&vectorBase));
+  bool isScalarExponent = matchPattern(y, m_Constant(&scalarExponent));
+  bool isVectorExponent = matchPattern(y, m_Constant(&vectorExponent));
 
   // Returns true if exponent is a constant equal to `value`.
   auto isExponentValue = [&](double value) -> bool {
-    if (isScalar)
+    if (isScalarExponent)
       return scalarExponent.getValue().isExactlyValue(value);
 
-    if (isVector && vectorExponent.isSplat())
+    if (isVectorExponent && vectorExponent.isSplat())
       return vectorExponent.getSplatValue<FloatAttr>()
           .getValue()
           .isExactlyValue(value);
@@ -120,6 +124,24 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
     return success();
   }
 
+  // Replace `pow(2.0^n, y)` with `exp2(n * y)`, since (2^n)^y == 2^(n * y).
+  if (isScalarBase || (isVectorBase && vectorBase.isSplat())) {
+    APFloat baseValue = isScalarBase
+                            ? scalarBase.getValue()
+                            : vectorBase.getSplatValue<FloatAttr>().getValue();
+    // getExactLog2() yields INT_MIN unless base is an exact power of 2.
+    int n = baseValue.getExactLog2();
+    if (n != INT_MIN) {
+      Type opType = getElementTypeOrSelf(op.getType());
+      Value nValue = arith::ConstantOp::create(
+          rewriter, loc, rewriter.getFloatAttr(opType, n));
+      Value nTimesY =
+          arith::MulFOp::create(rewriter, loc, ValueRange({bcast(nValue), y}));
+      rewriter.replaceOpWithNewOp<math::Exp2Op>(op, nTimesY);
+      return success();
+    }
+  }
+
   return failure();
 }
 
diff --git a/mlir/test/Dialect/Math/algebraic-simplification.mlir b/mlir/test/Dialect/Math/algebraic-simplification.mlir
index e0e2b9853a2a1..239be5eeeb6ac 100644
--- a/mlir/test/Dialect/Math/algebraic-simplification.mlir
+++ b/mlir/test/Dialect/Math/algebraic-simplification.mlir
@@ -90,6 +90,66 @@ func.func @pow_0_75(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
   return %0, %1 : f32, vector<4xf32>
 }
 
+// CHECK-LABEL: @pow_of_two
+func.func @pow_of_two(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK: %[[SCALAR:.*]] = math.exp2 %arg0
+  // CHECK: %[[VECTOR:.*]] = math.exp2 %arg1
+  // CHECK: return %[[SCALAR]], %[[VECTOR]]
+  %c = arith.constant 2.0 : f32
+  %v = arith.constant dense <2.0> : vector<4xf32>
+  %0 = math.powf %c, %arg0 : f32
+  %1 = math.powf %v, %arg1 : vector<4xf32>
+  return %0, %1 : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @pow_of_four
+func.func @pow_of_four(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK-DAG: %[[CST_V:.*]] = arith.constant dense<2.000000e+00> : vector<4xf32>
+  // CHECK-DAG: %[[CST_S:.*]] = arith.constant 2.000000e+00 : f32
+  // CHECK: %[[MUL_S:.*]] = arith.mulf %arg0, %[[CST_S]]
+  // CHECK: %[[SCALAR:.*]] = math.exp2 %[[MUL_S]]
+  // CHECK: %[[MUL_V:.*]] = arith.mulf %arg1, %[[CST_V]]
+  // CHECK: %[[VECTOR:.*]] = math.exp2 %[[MUL_V]]
+  // CHECK: return %[[SCALAR]], %[[VECTOR]]
+  %c = arith.constant 4.0 : f32
+  %v = arith.constant dense <4.0> : vector<4xf32>
+  %0 = math.powf %c, %arg0 : f32
+  %1 = math.powf %v, %arg1 : vector<4xf32>
+  return %0, %1 : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @pow_of_half
+func.func @pow_of_half(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK-DAG: %[[CST_V:.*]] = arith.constant dense<-1.000000e+00> : vector<4xf32>
+  // CHECK-DAG: %[[CST_S:.*]] = arith.constant -1.000000e+00 : f32
+  // CHECK: %[[MUL_S:.*]] = arith.mulf %arg0, %[[CST_S]]
+  // CHECK: %[[SCALAR:.*]] = math.exp2 %[[MUL_S]]
+  // CHECK: %[[MUL_V:.*]] = arith.mulf %arg1, %[[CST_V]]
+  // CHECK: %[[VECTOR:.*]] = math.exp2 %[[MUL_V]]
+  // CHECK: return %[[SCALAR]], %[[VECTOR]]
+  %c = arith.constant 0.5 : f32
+  %v = arith.constant dense <0.5> : vector<4xf32>
+  %0 = math.powf %c, %arg0 : f32
+  %1 = math.powf %v, %arg1 : vector<4xf32>
+  return %0, %1 : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @pow_of_quarter
+func.func @pow_of_quarter(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK-DAG: %[[CST_V:.*]] = arith.constant dense<-2.000000e+00> : vector<4xf32>
+  // CHECK-DAG: %[[CST_S:.*]] = arith.constant -2.000000e+00 : f32
+  // CHECK: %[[MUL_S:.*]] = arith.mulf %arg0, %[[CST_S]]
+  // CHECK: %[[SCALAR:.*]] = math.exp2 %[[MUL_S]]
+  // CHECK: %[[MUL_V:.*]] = arith.mulf %arg1, %[[CST_V]]
+  // CHECK: %[[VECTOR:.*]] = math.exp2 %[[MUL_V]]
+  // CHECK: return %[[SCALAR]], %[[VECTOR]]
+  %c = arith.constant 0.25 : f32
+  %v = arith.constant dense <0.25> : vector<4xf32>
+  %0 = math.powf %c, %arg0 : f32
+  %1 = math.powf %v, %arg1 : vector<4xf32>
+  return %0, %1 : f32, vector<4xf32>
+}
+
 // CHECK-LABEL: @ipowi_zero_exp(
 // CHECK-SAME: %[[ARG0:.+]]: i32
 // CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>