diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp index 1de67d0fd184c..81fbdb1611deb 100644 --- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp +++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp @@ -241,6 +241,73 @@ struct FpToIntConversion final : OpRewritePattern { bool isUnsigned; }; +template +struct IntToFpConversion final : OpRewritePattern { + IntToFpConversion(MLIRContext *context, SymbolOpInterface symTable, + bool isUnsigned, PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), symTable(symTable), + isUnsigned(isUnsigned) {} + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + if (op.getIn().getType().getIntOrFloatBitWidth() > 64) { + return rewriter.notifyMatchFailure( + loc, "integer bitwidth > 64 is not supported"); + } + + // Get APFloat function from runtime library. + auto i1Type = IntegerType::get(symTable->getContext(), 1); + auto i32Type = IntegerType::get(symTable->getContext(), 32); + auto i64Type = IntegerType::get(symTable->getContext(), 64); + FailureOr fn = + lookupOrCreateApFloatFn(rewriter, symTable, "convert_from_int", + {i32Type, i32Type, i1Type, i64Type}); + if (failed(fn)) + return fn; + + rewriter.setInsertionPoint(op); + // Cast operands to 64-bit integers. + auto inIntTy = cast(op.getOperand().getType()); + Value operandBits = op.getOperand(); + if (operandBits.getType().getIntOrFloatBitWidth() < 64) { + if (isUnsigned) { + operandBits = + arith::ExtUIOp::create(rewriter, loc, i64Type, operandBits); + } else { + operandBits = + arith::ExtSIOp::create(rewriter, loc, i64Type, operandBits); + } + } + + // Call APFloat function. + auto outFloatTy = cast(op.getType()); + Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy); + Value inWidthValue = arith::ConstantOp::create( + rewriter, loc, i32Type, + rewriter.getIntegerAttr(i32Type, inIntTy.getWidth())); + Value isUnsignedValue = arith::ConstantOp::create( + rewriter, loc, i1Type, rewriter.getIntegerAttr(i1Type, isUnsigned)); + SmallVector params = {outSemValue, inWidthValue, isUnsignedValue, + operandBits}; + auto resultOp = + func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()), + SymbolRefAttr::get(*fn), params); + + // Truncate result to the original width. + auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth()); + Value truncatedBits = arith::TruncIOp::create(rewriter, loc, outIntWType, + resultOp->getResult(0)); + Value result = + arith::BitcastOp::create(rewriter, loc, outFloatTy, truncatedBits); + rewriter.replaceOp(op, result); + return success(); + } + + SymbolOpInterface symTable; + bool isUnsigned; +}; + namespace { struct ArithToAPFloatConversionPass final : impl::ArithToAPFloatConversionPassBase { @@ -269,6 +336,10 @@ void ArithToAPFloatConversionPass::runOnOperation() { /*isUnsigned=*/false); patterns.add>(context, getOperation(), /*isUnsigned=*/true); + patterns.add>(context, getOperation(), + /*isUnsigned=*/false); + patterns.add>(context, getOperation(), + /*isUnsigned=*/true); LogicalResult result = success(); ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) { if (diag.getSeverity() == DiagnosticSeverity::Error) { diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp index 8b89c43446765..44980ccd77491 100644 --- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp +++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp @@ -119,4 +119,16 @@ MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_convert_to_int( // result to the desired result width. return result.getZExtValue(); } + +MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_convert_from_int( + int32_t semantics, int32_t inputWidth, bool isUnsigned, uint64_t a) { + llvm::APInt val(inputWidth, a, /*isSigned=*/!isUnsigned); + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( + static_cast(semantics)); + llvm::APFloat result(sem); + // TODO: Custom rounding modes are not supported yet. + result.convertFromAPInt(val, /*IsSigned=*/!isUnsigned, + llvm::RoundingMode::NearestTiesToEven); + return result.bitcastToAPInt().getZExtValue(); +} } diff --git a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir index f1acfd5e5618a..d71d81dddcd4f 100644 --- a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir +++ b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir @@ -174,3 +174,27 @@ func.func @fptoui(%arg0: f16) { %0 = arith.fptoui %arg0 : f16 to i4 return } + +// ----- + +// CHECK: func.func private @_mlir_apfloat_convert_from_int(i32, i32, i1, i64) -> i64 +// CHECK: %[[sem_out:.*]] = arith.constant 18 : i32 +// CHECK: %[[in_width:.*]] = arith.constant 32 : i32 +// CHECK: %[[is_unsigned:.*]] = arith.constant false +// CHECK: %[[res:.*]] = call @_mlir_apfloat_convert_from_int(%[[sem_out]], %[[in_width]], %[[is_unsigned]], %{{.*}}) : (i32, i32, i1, i64) -> i64 +func.func @sitofp(%arg0: i32) { + %0 = arith.sitofp %arg0 : i32 to f4E2M1FN + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_convert_from_int(i32, i32, i1, i64) -> i64 +// CHECK: %[[sem_out:.*]] = arith.constant 18 : i32 +// CHECK: %[[in_width:.*]] = arith.constant 32 : i32 +// CHECK: %[[is_unsigned:.*]] = arith.constant true +// CHECK: %[[res:.*]] = call @_mlir_apfloat_convert_from_int(%[[sem_out]], %[[in_width]], %[[is_unsigned]], %{{.*}}) : (i32, i32, i1, i64) -> i64 +func.func @uitofp(%arg0: i32) { + %0 = arith.uitofp %arg0 : i32 to f4E2M1FN + return +} diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir index 5e93945c3eb60..8046610d479a8 100644 --- a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir +++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir @@ -53,5 +53,18 @@ func.func @entry() { %cvt_int_unsigned = arith.fptoui %cvt : f8E4M3FN to i2 vector.print %cvt_int_unsigned : i2 + // CHECK-NEXT: -6 + // Bit pattern: 1...11110111, interpreted as signed: -9 + // Closest f4E2M1FN value: -6.0 + %c9 = arith.constant -9 : i16 + %cvt_from_signed_int = arith.sitofp %c9 : i16 to f4E2M1FN + vector.print %cvt_from_signed_int : f4E2M1FN + + // CHECK-NEXT: 6 + // Bit pattern: 1...11110111, interpreted as unsigned: 65527 + // Closest f4E2M1FN value: 6.0 + %cvt_from_unsigned_int = arith.uitofp %c9 : i16 to f4E2M1FN + vector.print %cvt_from_unsigned_int : f4E2M1FN + return }