Skip to content

Commit

Permalink
[mlir][Math] Fix NaN handling in ExpM1 approximation.
Browse files Browse the repository at this point in the history
Differential Revision: https://reviews.llvm.org/D119822
  • Loading branch information
akuegel committed Feb 15, 2022
1 parent f35af77 commit 87de451
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1033,8 +1033,8 @@ ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
Value cstNegOne = bcast(f32Cst(builder, -1.0f));
Value x = op.getOperand();
Value u = builder.create<math::ExpOp>(x);
Value uEqOne =
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne);
Value uEqOneOrNaN =
builder.create<arith::CmpFOp>(arith::CmpFPredicate::UEQ, u, cstOne);
Value uMinusOne = builder.create<arith::SubFOp>(u, cstOne);
Value uMinusOneEqNegOne = builder.create<arith::CmpFOp>(
arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
Expand All @@ -1050,7 +1050,7 @@ ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
uMinusOne, builder.create<arith::DivFOp>(x, logU));
expm1 = builder.create<arith::SelectOp>(isInf, u, expm1);
Value approximation = builder.create<arith::SelectOp>(
uEqOne, x,
uEqOneOrNaN, x,
builder.create<arith::SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1));
rewriter.replaceOp(op, approximation);
return success();
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/Math/polynomial-approximation.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
// CHECK-NOT: exp
// CHECK-COUNT-3: select
// CHECK: %[[EXP_X:.*]] = arith.select
// CHECK: %[[VAL_58:.*]] = arith.cmpf oeq, %[[EXP_X]], %[[CST_ONE]] : f32
// CHECK: %[[IS_ONE_OR_NAN:.*]] = arith.cmpf ueq, %[[EXP_X]], %[[CST_ONE]] : f32
// CHECK: %[[VAL_59:.*]] = arith.subf %[[EXP_X]], %[[CST_ONE]] : f32
// CHECK: %[[VAL_60:.*]] = arith.cmpf oeq, %[[VAL_59]], %[[CST_MINUSONE]] : f32
// CHECK-NOT: log
Expand All @@ -174,7 +174,7 @@ func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
// CHECK: %[[VAL_106:.*]] = arith.mulf %[[VAL_59]], %[[VAL_105]] : f32
// CHECK: %[[VAL_107:.*]] = arith.select %[[VAL_104]], %[[EXP_X]], %[[VAL_106]] : f32
// CHECK: %[[VAL_108:.*]] = arith.select %[[VAL_60]], %[[CST_MINUSONE]], %[[VAL_107]] : f32
// CHECK: %[[VAL_109:.*]] = arith.select %[[VAL_58]], %[[X]], %[[VAL_108]] : f32
// CHECK: %[[VAL_109:.*]] = arith.select %[[IS_ONE_OR_NAN]], %[[X]], %[[VAL_108]] : f32
// CHECK: return %[[VAL_109]] : f32
// CHECK: }
func @expm1_scalar(%arg0: f32) -> f32 {
Expand Down

0 comments on commit 87de451

Please sign in to comment.