Skip to content

Conversation

@nurmukhametov
Copy link
Contributor

This PR adds a pow(2^n, y) to exp2(y) transformation to the algebraic simplifications of the math dialect.

@llvmbot
Copy link
Member

llvmbot commented Nov 3, 2025

@llvm/pr-subscribers-mlir-math

@llvm/pr-subscribers-mlir

Author: Aleksei Nurmukhametov (nurmukhametov)

Changes

This PR adds a pow(2^n, y) to exp2(y) transformation to the algebraic simplifications of the math dialect.


Full diff: https://github.com/llvm/llvm-project/pull/166183.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp (+27-6)
  • (modified) mlir/test/Dialect/Math/algebraic-simplification.mlir (+60)
diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index 77b10cec48d8e..03e8a0d020919 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -42,20 +42,24 @@ LogicalResult
 PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
                                        PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
+  // pow(x, y)
   Value x = op.getLhs();
+  Value y = op.getRhs();
 
-  FloatAttr scalarExponent;
-  DenseFPElementsAttr vectorExponent;
+  FloatAttr scalarBase, scalarExponent;
+  DenseFPElementsAttr vectorBase, vectorExponent;
 
-  bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent));
-  bool isVector = matchPattern(op.getRhs(), m_Constant(&vectorExponent));
+  bool isScalarBase = matchPattern(x, m_Constant(&scalarBase));
+  bool isVectorBase = matchPattern(x, m_Constant(&vectorBase));
+  bool isScalarExponent = matchPattern(y, m_Constant(&scalarExponent));
+  bool isVectorExponent = matchPattern(y, m_Constant(&vectorExponent));
 
   // Returns true if exponent is a constant equal to `value`.
   auto isExponentValue = [&](double value) -> bool {
-    if (isScalar)
+    if (isScalarExponent)
       return scalarExponent.getValue().isExactlyValue(value);
 
-    if (isVector && vectorExponent.isSplat())
+    if (isVectorExponent && vectorExponent.isSplat())
       return vectorExponent.getSplatValue<FloatAttr>()
           .getValue()
           .isExactlyValue(value);
@@ -120,6 +124,23 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
     return success();
   }
 
+  // Replace `pow(2.0^n, y)` with `exp2(n * y)`
+  if (isScalarBase || (isVectorBase && vectorBase.isSplat())) {
+    APFloat baseValue = isScalarBase
+                            ? scalarBase.getValue()
+                            : vectorBase.getSplatValue<FloatAttr>().getValue();
+    // Check if base is an exact power of 2
+    int n = baseValue.getExactLog2();
+    if (n != INT_MIN) {
+      Value nValue = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), n));
+      Value nTimesY =
+          rewriter.create<arith::MulFOp>(loc, ValueRange({bcast(nValue), y}));
+      rewriter.replaceOpWithNewOp<math::Exp2Op>(op, nTimesY);
+      return success();
+    }
+  }
+
   return failure();
 }
 
diff --git a/mlir/test/Dialect/Math/algebraic-simplification.mlir b/mlir/test/Dialect/Math/algebraic-simplification.mlir
index e0e2b9853a2a1..239be5eeeb6ac 100644
--- a/mlir/test/Dialect/Math/algebraic-simplification.mlir
+++ b/mlir/test/Dialect/Math/algebraic-simplification.mlir
@@ -90,6 +90,66 @@ func.func @pow_0_75(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
   return %0, %1 : f32, vector<4xf32>
 }
 
+// CHECK-LABEL: @pow_of_two
+func.func @pow_of_two(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK: %[[SCALAR:.*]] = math.exp2 %arg0
+  // CHECK: %[[VECTOR:.*]] = math.exp2 %arg1
+  // CHECK: return %[[SCALAR]], %[[VECTOR]]
+  %c = arith.constant 2.0 : f32
+  %v = arith.constant dense <2.0> : vector<4xf32>
+  %0 = math.powf %c, %arg0 : f32
+  %1 = math.powf %v, %arg1 : vector<4xf32>
+  return %0, %1 : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @pow_of_four
+func.func @pow_of_four(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK-DAG: %[[CST_V:.*]] = arith.constant dense<2.000000e+00> : vector<4xf32>
+  // CHECK-DAG: %[[CST_S:.*]] = arith.constant 2.000000e+00 : f32
+  // CHECK: %[[MUL_S:.*]] = arith.mulf %arg0, %[[CST_S]]
+  // CHECK: %[[SCALAR:.*]] = math.exp2 %[[MUL_S]]
+  // CHECK: %[[MUL_V:.*]] = arith.mulf %arg1, %[[CST_V]]
+  // CHECK: %[[VECTOR:.*]] = math.exp2 %[[MUL_V]]
+  // CHECK: return %[[SCALAR]], %[[VECTOR]]
+  %c = arith.constant 4.0 : f32
+  %v = arith.constant dense <4.0> : vector<4xf32>
+  %0 = math.powf %c, %arg0 : f32
+  %1 = math.powf %v, %arg1 : vector<4xf32>
+  return %0, %1 : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @pow_of_half
+func.func @pow_of_half(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK-DAG: %[[CST_V:.*]] = arith.constant dense<-1.000000e+00> : vector<4xf32>
+  // CHECK-DAG: %[[CST_S:.*]] = arith.constant -1.000000e+00 : f32
+  // CHECK: %[[MUL_S:.*]] = arith.mulf %arg0, %[[CST_S]]
+  // CHECK: %[[SCALAR:.*]] = math.exp2 %[[MUL_S]]
+  // CHECK: %[[MUL_V:.*]] = arith.mulf %arg1, %[[CST_V]]
+  // CHECK: %[[VECTOR:.*]] = math.exp2 %[[MUL_V]]
+  // CHECK: return %[[SCALAR]], %[[VECTOR]]
+  %c = arith.constant 0.5 : f32
+  %v = arith.constant dense <0.5> : vector<4xf32>
+  %0 = math.powf %c, %arg0 : f32
+  %1 = math.powf %v, %arg1 : vector<4xf32>
+  return %0, %1 : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @pow_of_quarter
+func.func @pow_of_quarter(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK-DAG: %[[CST_V:.*]] = arith.constant dense<-2.000000e+00> : vector<4xf32>
+  // CHECK-DAG: %[[CST_S:.*]] = arith.constant -2.000000e+00 : f32
+  // CHECK: %[[MUL_S:.*]] = arith.mulf %arg0, %[[CST_S]]
+  // CHECK: %[[SCALAR:.*]] = math.exp2 %[[MUL_S]]
+  // CHECK: %[[MUL_V:.*]] = arith.mulf %arg1, %[[CST_V]]
+  // CHECK: %[[VECTOR:.*]] = math.exp2 %[[MUL_V]]
+  // CHECK: return %[[SCALAR]], %[[VECTOR]]
+  %c = arith.constant 0.25 : f32
+  %v = arith.constant dense <0.25> : vector<4xf32>
+  %0 = math.powf %c, %arg0 : f32
+  %1 = math.powf %v, %arg1 : vector<4xf32>
+  return %0, %1 : f32, vector<4xf32>
+}
+
 // CHECK-LABEL: @ipowi_zero_exp(
 // CHECK-SAME: %[[ARG0:.+]]: i32
 // CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants