Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,20 +42,24 @@ LogicalResult
PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
PatternRewriter &rewriter) const {
Location loc = op.getLoc();
// pow(x, y)
Value x = op.getLhs();
Value y = op.getRhs();

FloatAttr scalarExponent;
DenseFPElementsAttr vectorExponent;
FloatAttr scalarBase, scalarExponent;
DenseFPElementsAttr vectorBase, vectorExponent;

bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent));
bool isVector = matchPattern(op.getRhs(), m_Constant(&vectorExponent));
bool isScalarBase = matchPattern(x, m_Constant(&scalarBase));
bool isVectorBase = matchPattern(x, m_Constant(&vectorBase));
bool isScalarExponent = matchPattern(y, m_Constant(&scalarExponent));
bool isVectorExponent = matchPattern(y, m_Constant(&vectorExponent));

// Returns true if exponent is a constant equal to `value`.
auto isExponentValue = [&](double value) -> bool {
if (isScalar)
if (isScalarExponent)
return scalarExponent.getValue().isExactlyValue(value);

if (isVector && vectorExponent.isSplat())
if (isVectorExponent && vectorExponent.isSplat())
return vectorExponent.getSplatValue<FloatAttr>()
.getValue()
.isExactlyValue(value);
Expand Down Expand Up @@ -120,6 +124,24 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
return success();
}

// Replace `pow(2.0^n, y)` with `exp2(n * y)`
if (isScalarBase || (isVectorBase && vectorBase.isSplat())) {
APFloat baseValue = isScalarBase
? scalarBase.getValue()
: vectorBase.getSplatValue<FloatAttr>().getValue();
// Check if base is an exact power of 2
int n = baseValue.getExactLog2();
if (n != INT_MIN) {
Type opType = getElementTypeOrSelf(op.getType());
Value nValue = arith::ConstantOp::create(
rewriter, loc, rewriter.getFloatAttr(opType, n));
Value nTimesY =
arith::MulFOp::create(rewriter, loc, ValueRange({bcast(nValue), y}));
rewriter.replaceOpWithNewOp<math::Exp2Op>(op, nTimesY);
return success();
}
}

return failure();
}

Expand Down
60 changes: 60 additions & 0 deletions mlir/test/Dialect/Math/algebraic-simplification.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,66 @@ func.func @pow_0_75(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
return %0, %1 : f32, vector<4xf32>
}

// CHECK-LABEL: @pow_of_two
func.func @pow_of_two(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
// CHECK: %[[SCALAR:.*]] = math.exp2 %arg0
// CHECK: %[[VECTOR:.*]] = math.exp2 %arg1
// CHECK: return %[[SCALAR]], %[[VECTOR]]
%c = arith.constant 2.0 : f32
%v = arith.constant dense <2.0> : vector<4xf32>
%0 = math.powf %c, %arg0 : f32
%1 = math.powf %v, %arg1 : vector<4xf32>
return %0, %1 : f32, vector<4xf32>
}

// CHECK-LABEL: @pow_of_four
func.func @pow_of_four(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
// CHECK-DAG: %[[CST_V:.*]] = arith.constant dense<2.000000e+00> : vector<4xf32>
// CHECK-DAG: %[[CST_S:.*]] = arith.constant 2.000000e+00 : f32
// CHECK: %[[MUL_S:.*]] = arith.mulf %arg0, %[[CST_S]]
// CHECK: %[[SCALAR:.*]] = math.exp2 %[[MUL_S]]
// CHECK: %[[MUL_V:.*]] = arith.mulf %arg1, %[[CST_V]]
// CHECK: %[[VECTOR:.*]] = math.exp2 %[[MUL_V]]
// CHECK: return %[[SCALAR]], %[[VECTOR]]
%c = arith.constant 4.0 : f32
%v = arith.constant dense <4.0> : vector<4xf32>
%0 = math.powf %c, %arg0 : f32
%1 = math.powf %v, %arg1 : vector<4xf32>
return %0, %1 : f32, vector<4xf32>
}

// CHECK-LABEL: @pow_of_half
func.func @pow_of_half(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
// CHECK-DAG: %[[CST_V:.*]] = arith.constant dense<-1.000000e+00> : vector<4xf32>
// CHECK-DAG: %[[CST_S:.*]] = arith.constant -1.000000e+00 : f32
// CHECK: %[[MUL_S:.*]] = arith.mulf %arg0, %[[CST_S]]
// CHECK: %[[SCALAR:.*]] = math.exp2 %[[MUL_S]]
// CHECK: %[[MUL_V:.*]] = arith.mulf %arg1, %[[CST_V]]
// CHECK: %[[VECTOR:.*]] = math.exp2 %[[MUL_V]]
// CHECK: return %[[SCALAR]], %[[VECTOR]]
%c = arith.constant 0.5 : f32
%v = arith.constant dense <0.5> : vector<4xf32>
%0 = math.powf %c, %arg0 : f32
%1 = math.powf %v, %arg1 : vector<4xf32>
return %0, %1 : f32, vector<4xf32>
}

// CHECK-LABEL: @pow_of_quarter
func.func @pow_of_quarter(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
// CHECK-DAG: %[[CST_V:.*]] = arith.constant dense<-2.000000e+00> : vector<4xf32>
// CHECK-DAG: %[[CST_S:.*]] = arith.constant -2.000000e+00 : f32
// CHECK: %[[MUL_S:.*]] = arith.mulf %arg0, %[[CST_S]]
// CHECK: %[[SCALAR:.*]] = math.exp2 %[[MUL_S]]
// CHECK: %[[MUL_V:.*]] = arith.mulf %arg1, %[[CST_V]]
// CHECK: %[[VECTOR:.*]] = math.exp2 %[[MUL_V]]
// CHECK: return %[[SCALAR]], %[[VECTOR]]
%c = arith.constant 0.25 : f32
%v = arith.constant dense <0.25> : vector<4xf32>
%0 = math.powf %c, %arg0 : f32
%1 = math.powf %v, %arg1 : vector<4xf32>
return %0, %1 : f32, vector<4xf32>
}

// CHECK-LABEL: @ipowi_zero_exp(
// CHECK-SAME: %[[ARG0:.+]]: i32
// CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>
Expand Down
Loading