diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp index cd68039d0d964..e9f4811aae3fe 100644 --- a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp @@ -232,6 +232,37 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) { ImplicitLocOpBuilder b(op->getLoc(), rewriter); Value operand = op.getOperand(); Type opType = operand.getType(); + + auto operandETy = getElementTypeOrSelf(opType); + unsigned bitWidth = operandETy.getIntOrFloatBitWidth(); + unsigned mantissaWidth = + llvm::cast(operandETy).getFPMantissaWidth() - 1; + unsigned exponentWidth = bitWidth - mantissaWidth - 1; + + Type iTy = rewriter.getIntegerType(bitWidth); + if (auto shapedTy = dyn_cast(opType)) + iTy = shapedTy.clone(iTy); + + Value cMantissaWidth = createIntConst(op->getLoc(), iTy, mantissaWidth, b); + Value cBias = + createIntConst(op->getLoc(), iTy, (1ull << (exponentWidth - 1)) - 1, b); + Value cExpMask = + createIntConst(op->getLoc(), iTy, (1ull << exponentWidth) - 1, b); + + // Any floating-point value with an unbiased exponent ≥ `mantissaWidth` + // falls into one of these categories: + // - a large finite value (|x| ≥ 2^mantissaWidth), where all representable + // numbers are already integral, or + // - a special value (NaN or ±Inf), which also satisfies this exponent + // condition. + // For all such cases, `ceilf(x)` is defined to return `x` directly. + Value operandBitcast = arith::BitcastOp::create(b, iTy, operand); + Value operandExp = arith::AndIOp::create( + b, arith::ShRUIOp::create(b, operandBitcast, cMantissaWidth), cExpMask); + Value operandBiasedExp = arith::SubIOp::create(b, operandExp, cBias); + Value isSpecialValOrLargeVal = arith::CmpIOp::create( + b, arith::CmpIPredicate::sge, operandBiasedExp, cMantissaWidth); + Value fpFixedConvert = createTruncatedFPValue(operand, b); // Creating constants for later use. @@ -243,7 +274,8 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) { Value incrValue = arith::SelectOp::create(b, op->getLoc(), gtCheck, one, zero); - Value ret = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue); + Value add = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue); + Value ret = arith::SelectOp::create(b, isSpecialValOrLargeVal, operand, add); rewriter.replaceOp(op, ret); return success(); } diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir index 615c607efc3c3..75f8e65b334a2 100644 --- a/mlir/test/Dialect/Math/expand-math.mlir +++ b/mlir/test/Dialect/Math/expand-math.mlir @@ -145,13 +145,22 @@ func.func @fmaf_func(%a: f64, %b: f64, %c: f64) -> f64 { func.func @ceilf_func(%a: f64) -> f64 { // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000 // CHECK-DAG: [[CST_0:%.+]] = arith.constant 1.000 + // CHECK-DAG: [[C52:%.*]] = arith.constant 52 + // CHECK-DAG: [[C1023:%.*]] = arith.constant 1023 + // CHECK-DAG: [[EXP_MASK:%.*]] = arith.constant 2047 + // CHECK-NEXT: [[ARG_BITCAST:%.*]] = arith.bitcast [[ARG0]] : f64 to i64 + // CHECK-NEXT: [[ARG_BITCAST_SHIFTED:%.*]] = arith.shrui [[ARG_BITCAST]], [[C52]] + // CHECK-NEXT: [[ARG_EXP:%.*]] = arith.andi [[ARG_BITCAST_SHIFTED]], [[EXP_MASK]] + // CHECK-NEXT: [[ARG_BIASED_EXP:%.*]] = arith.subi [[ARG_EXP]], [[C1023]] + // CHECK-NEXT: [[IS_SPECIAL_VAL:%.*]] = arith.cmpi sge, [[ARG_BIASED_EXP]], [[C52]] // CHECK-NEXT: [[CVTI:%.+]] = arith.fptosi [[ARG0]] // CHECK-NEXT: [[CVTF:%.+]] = arith.sitofp [[CVTI]] // CHECK-NEXT: [[COPYSIGN:%.+]] = math.copysign [[CVTF]], [[ARG0]] // CHECK-NEXT: [[COMP:%.+]] = arith.cmpf ogt, [[ARG0]], [[COPYSIGN]] // CHECK-NEXT: [[INCR:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST]] // CHECK-NEXT: [[ADDF:%.+]] = arith.addf [[COPYSIGN]], [[INCR]] - // CHECK-NEXT: return [[ADDF]] + // CHECK-NEXT: [[RESULT:%.*]] = arith.select [[IS_SPECIAL_VAL]], [[ARG0]], [[ADDF]] + // CHECK-NEXT: return [[RESULT]] // CHECK-FILTER: math.ceil %ret = math.ceil %a : f64 return %ret : f64