Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,37 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operand = op.getOperand();
Type opType = operand.getType();

auto operandETy = getElementTypeOrSelf(opType);
unsigned bitWidth = operandETy.getIntOrFloatBitWidth();
unsigned mantissaWidth =
llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
unsigned exponentWidth = bitWidth - mantissaWidth - 1;

Type iTy = rewriter.getIntegerType(bitWidth);
if (auto shapedTy = dyn_cast<ShapedType>(opType))
iTy = shapedTy.clone(iTy);

Value cMantissaWidth = createIntConst(op->getLoc(), iTy, mantissaWidth, b);
Value cBias =
createIntConst(op->getLoc(), iTy, (1ull << (exponentWidth - 1)) - 1, b);
Value cExpMask =
createIntConst(op->getLoc(), iTy, (1ull << exponentWidth) - 1, b);

// Any floating-point value with an unbiased exponent ≥ `mantissaWidth`
// falls into one of these categories:
// - a large finite value (|x| ≥ 2^mantissaWidth), where all representable
// numbers are already integral, or
// - a special value (NaN or ±Inf), which also satisfies this exponent
// condition.
// For all such cases, `ceilf(x)` is defined to return `x` directly.
Value operandBitcast = arith::BitcastOp::create(b, iTy, operand);
Value operandExp = arith::AndIOp::create(
b, arith::ShRUIOp::create(b, operandBitcast, cMantissaWidth), cExpMask);
Value operandBiasedExp = arith::SubIOp::create(b, operandExp, cBias);
Value isSpecialValOrLargeVal = arith::CmpIOp::create(
b, arith::CmpIPredicate::sge, operandBiasedExp, cMantissaWidth);

Value fpFixedConvert = createTruncatedFPValue(operand, b);

// Creating constants for later use.
Expand All @@ -243,7 +274,8 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
Value incrValue =
arith::SelectOp::create(b, op->getLoc(), gtCheck, one, zero);

Value ret = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
Value add = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
Value ret = arith::SelectOp::create(b, isSpecialValOrLargeVal, operand, add);
rewriter.replaceOp(op, ret);
return success();
}
Expand Down
11 changes: 10 additions & 1 deletion mlir/test/Dialect/Math/expand-math.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,22 @@ func.func @fmaf_func(%a: f64, %b: f64, %c: f64) -> f64 {
func.func @ceilf_func(%a: f64) -> f64 {
// CHECK-DAG: [[CST:%.+]] = arith.constant 0.000
// CHECK-DAG: [[CST_0:%.+]] = arith.constant 1.000
// CHECK-DAG: [[C52:%.*]] = arith.constant 52
// CHECK-DAG: [[C1023:%.*]] = arith.constant 1023
// CHECK-DAG: [[EXP_MASK:%.*]] = arith.constant 2047
// CHECK-NEXT: [[ARG_BITCAST:%.*]] = arith.bitcast [[ARG0]] : f64 to i64
// CHECK-NEXT: [[ARG_BITCAST_SHIFTED:%.*]] = arith.shrui [[ARG_BITCAST]], [[C52]]
// CHECK-NEXT: [[ARG_EXP:%.*]] = arith.andi [[ARG_BITCAST_SHIFTED]], [[EXP_MASK]]
// CHECK-NEXT: [[ARG_BIASED_EXP:%.*]] = arith.subi [[ARG_EXP]], [[C1023]]
// CHECK-NEXT: [[IS_SPECIAL_VAL:%.*]] = arith.cmpi sge, [[ARG_BIASED_EXP]], [[C52]]
// CHECK-NEXT: [[CVTI:%.+]] = arith.fptosi [[ARG0]]
// CHECK-NEXT: [[CVTF:%.+]] = arith.sitofp [[CVTI]]
// CHECK-NEXT: [[COPYSIGN:%.+]] = math.copysign [[CVTF]], [[ARG0]]
// CHECK-NEXT: [[COMP:%.+]] = arith.cmpf ogt, [[ARG0]], [[COPYSIGN]]
// CHECK-NEXT: [[INCR:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST]]
// CHECK-NEXT: [[ADDF:%.+]] = arith.addf [[COPYSIGN]], [[INCR]]
// CHECK-NEXT: return [[ADDF]]
// CHECK-NEXT: [[RESULT:%.*]] = arith.select [[IS_SPECIAL_VAL]], [[ARG0]], [[ADDF]]
// CHECK-NEXT: return [[RESULT]]
// CHECK-FILTER: math.ceil
%ret = math.ceil %a : f64
return %ret : f64
Expand Down