-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[mlir][arith] Add support for sitofp, uitofp to ArithToAPFloat
#169284
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][arith] Add support for sitofp, uitofp to ArithToAPFloat
#169284
Conversation
|
@llvm/pr-subscribers-mlir-arith @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesAdd support for Full diff: https://github.com/llvm/llvm-project/pull/169284.diff 4 Files Affected:
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
index 1fe698f1c8902..0a37688246537 100644
--- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
@@ -239,6 +239,72 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
bool isUnsigned;
};
+template <typename OpTy>
+struct IntToFpConversion final : OpRewritePattern<OpTy> {
+ IntToFpConversion(MLIRContext *context, SymbolOpInterface symTable,
+ bool isUnsigned, PatternBenefit benefit = 1)
+ : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
+ isUnsigned(isUnsigned){};
+
+ LogicalResult matchAndRewrite(OpTy op,
+ PatternRewriter &rewriter) const override {
+ // Get APFloat function from runtime library.
+ auto i1Type = IntegerType::get(symTable->getContext(), 1);
+ auto i32Type = IntegerType::get(symTable->getContext(), 32);
+ auto i64Type = IntegerType::get(symTable->getContext(), 64);
+ FailureOr<FuncOp> fn =
+ lookupOrCreateApFloatFn(rewriter, symTable, "convert_from_int",
+ {i32Type, i32Type, i1Type, i64Type});
+ if (failed(fn))
+ return fn;
+
+ rewriter.setInsertionPoint(op);
+ // Cast operands to 64-bit integers.
+ Location loc = op.getLoc();
+ auto inIntTy = cast<IntegerType>(op.getOperand().getType());
+ auto int64Type = rewriter.getI64Type();
+ Value operandBits = op.getOperand();
+ if (operandBits.getType().getIntOrFloatBitWidth() < 64) {
+ if (isUnsigned) {
+ operandBits =
+ arith::ExtUIOp::create(rewriter, loc, int64Type, operandBits);
+ } else {
+ operandBits =
+ arith::ExtSIOp::create(rewriter, loc, int64Type, operandBits);
+ }
+ } else if (operandBits.getType().getIntOrFloatBitWidth() > 64) {
+ return rewriter.notifyMatchFailure(
+ loc, "integer bitwidth > 64 is not supported");
+ }
+
+ // Call APFloat function.
+ auto outFloatTy = cast<FloatType>(op.getType());
+ Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
+ Value inWidthValue = arith::ConstantOp::create(
+ rewriter, loc, i32Type,
+ rewriter.getIntegerAttr(i32Type, inIntTy.getWidth()));
+ Value isUnsignedValue = arith::ConstantOp::create(
+ rewriter, loc, i1Type, rewriter.getIntegerAttr(i1Type, isUnsigned));
+ SmallVector<Value> params = {outSemValue, inWidthValue, isUnsignedValue,
+ operandBits};
+ auto resultOp =
+ func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
+ SymbolRefAttr::get(*fn), params);
+
+ // Truncate result to the original width.
+ auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth());
+ Value truncatedBits = arith::TruncIOp::create(rewriter, loc, outIntWType,
+ resultOp->getResult(0));
+ Value result =
+ arith::BitcastOp::create(rewriter, loc, outFloatTy, truncatedBits);
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+
+ SymbolOpInterface symTable;
+ bool isUnsigned;
+};
+
namespace {
struct ArithToAPFloatConversionPass final
: impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
@@ -266,6 +332,10 @@ void ArithToAPFloatConversionPass::runOnOperation() {
/*isUnsigned=*/false);
patterns.add<FpToIntConversion<arith::FPToUIOp>>(context, getOperation(),
/*isUnsigned=*/true);
+ patterns.add<IntToFpConversion<arith::SIToFPOp>>(context, getOperation(),
+ /*isUnsigned=*/false);
+ patterns.add<IntToFpConversion<arith::UIToFPOp>>(context, getOperation(),
+ /*isUnsigned=*/true);
LogicalResult result = success();
ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
if (diag.getSeverity() == DiagnosticSeverity::Error) {
diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
index 632fe9cf2269d..2fbcc26200540 100644
--- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
@@ -115,4 +115,16 @@ MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_convert_to_int(
val.convertToInteger(result, llvm::RoundingMode::NearestTiesToEven, &isExact);
return result.getZExtValue();
}
+
+MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_convert_from_int(
+ int32_t semantics, int32_t inputWidth, bool isUnsigned, uint64_t a) {
+ llvm::APInt val(inputWidth, a, /*isSigned=*/!isUnsigned);
+ const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
+ static_cast<llvm::APFloatBase::Semantics>(semantics));
+ llvm::APFloat result(sem);
+ // TODO: Custom rounding modes are not supported yet.
+ result.convertFromAPInt(val, /*IsSigned=*/!isUnsigned,
+ llvm::RoundingMode::NearestTiesToEven);
+ return result.bitcastToAPInt().getZExtValue();
+}
}
diff --git a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
index f1acfd5e5618a..d71d81dddcd4f 100644
--- a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
+++ b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
@@ -174,3 +174,27 @@ func.func @fptoui(%arg0: f16) {
%0 = arith.fptoui %arg0 : f16 to i4
return
}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_convert_from_int(i32, i32, i1, i64) -> i64
+// CHECK: %[[sem_out:.*]] = arith.constant 18 : i32
+// CHECK: %[[in_width:.*]] = arith.constant 32 : i32
+// CHECK: %[[is_unsigned:.*]] = arith.constant false
+// CHECK: %[[res:.*]] = call @_mlir_apfloat_convert_from_int(%[[sem_out]], %[[in_width]], %[[is_unsigned]], %{{.*}}) : (i32, i32, i1, i64) -> i64
+func.func @sitofp(%arg0: i32) {
+ %0 = arith.sitofp %arg0 : i32 to f4E2M1FN
+ return
+}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_convert_from_int(i32, i32, i1, i64) -> i64
+// CHECK: %[[sem_out:.*]] = arith.constant 18 : i32
+// CHECK: %[[in_width:.*]] = arith.constant 32 : i32
+// CHECK: %[[is_unsigned:.*]] = arith.constant true
+// CHECK: %[[res:.*]] = call @_mlir_apfloat_convert_from_int(%[[sem_out]], %[[in_width]], %[[is_unsigned]], %{{.*}}) : (i32, i32, i1, i64) -> i64
+func.func @uitofp(%arg0: i32) {
+ %0 = arith.uitofp %arg0 : i32 to f4E2M1FN
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
index 5e93945c3eb60..8046610d479a8 100644
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
@@ -53,5 +53,18 @@ func.func @entry() {
%cvt_int_unsigned = arith.fptoui %cvt : f8E4M3FN to i2
vector.print %cvt_int_unsigned : i2
+ // CHECK-NEXT: -6
+ // Bit pattern: 1...11110111, interpreted as signed: -9
+ // Closest f4E2M1FN value: -6.0
+ %c9 = arith.constant -9 : i16
+ %cvt_from_signed_int = arith.sitofp %c9 : i16 to f4E2M1FN
+ vector.print %cvt_from_signed_int : f4E2M1FN
+
+ // CHECK-NEXT: 6
+ // Bit pattern: 1...11110111, interpreted as unsigned: 65527
+ // Closest f4E2M1FN value: 6.0
+ %cvt_from_unsigned_int = arith.uitofp %c9 : i16 to f4E2M1FN
+ vector.print %cvt_from_unsigned_int : f4E2M1FN
+
return
}
|
|
@llvm/pr-subscribers-mlir-execution-engine Author: Matthias Springer (matthias-springer) ChangesAdd support for Full diff: https://github.com/llvm/llvm-project/pull/169284.diff 4 Files Affected:
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
index 1fe698f1c8902..0a37688246537 100644
--- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
@@ -239,6 +239,72 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
bool isUnsigned;
};
+template <typename OpTy>
+struct IntToFpConversion final : OpRewritePattern<OpTy> {
+ IntToFpConversion(MLIRContext *context, SymbolOpInterface symTable,
+ bool isUnsigned, PatternBenefit benefit = 1)
+ : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
+ isUnsigned(isUnsigned){};
+
+ LogicalResult matchAndRewrite(OpTy op,
+ PatternRewriter &rewriter) const override {
+ // Get APFloat function from runtime library.
+ auto i1Type = IntegerType::get(symTable->getContext(), 1);
+ auto i32Type = IntegerType::get(symTable->getContext(), 32);
+ auto i64Type = IntegerType::get(symTable->getContext(), 64);
+ FailureOr<FuncOp> fn =
+ lookupOrCreateApFloatFn(rewriter, symTable, "convert_from_int",
+ {i32Type, i32Type, i1Type, i64Type});
+ if (failed(fn))
+ return fn;
+
+ rewriter.setInsertionPoint(op);
+ // Cast operands to 64-bit integers.
+ Location loc = op.getLoc();
+ auto inIntTy = cast<IntegerType>(op.getOperand().getType());
+ auto int64Type = rewriter.getI64Type();
+ Value operandBits = op.getOperand();
+ if (operandBits.getType().getIntOrFloatBitWidth() < 64) {
+ if (isUnsigned) {
+ operandBits =
+ arith::ExtUIOp::create(rewriter, loc, int64Type, operandBits);
+ } else {
+ operandBits =
+ arith::ExtSIOp::create(rewriter, loc, int64Type, operandBits);
+ }
+ } else if (operandBits.getType().getIntOrFloatBitWidth() > 64) {
+ return rewriter.notifyMatchFailure(
+ loc, "integer bitwidth > 64 is not supported");
+ }
+
+ // Call APFloat function.
+ auto outFloatTy = cast<FloatType>(op.getType());
+ Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
+ Value inWidthValue = arith::ConstantOp::create(
+ rewriter, loc, i32Type,
+ rewriter.getIntegerAttr(i32Type, inIntTy.getWidth()));
+ Value isUnsignedValue = arith::ConstantOp::create(
+ rewriter, loc, i1Type, rewriter.getIntegerAttr(i1Type, isUnsigned));
+ SmallVector<Value> params = {outSemValue, inWidthValue, isUnsignedValue,
+ operandBits};
+ auto resultOp =
+ func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
+ SymbolRefAttr::get(*fn), params);
+
+ // Truncate result to the original width.
+ auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth());
+ Value truncatedBits = arith::TruncIOp::create(rewriter, loc, outIntWType,
+ resultOp->getResult(0));
+ Value result =
+ arith::BitcastOp::create(rewriter, loc, outFloatTy, truncatedBits);
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+
+ SymbolOpInterface symTable;
+ bool isUnsigned;
+};
+
namespace {
struct ArithToAPFloatConversionPass final
: impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
@@ -266,6 +332,10 @@ void ArithToAPFloatConversionPass::runOnOperation() {
/*isUnsigned=*/false);
patterns.add<FpToIntConversion<arith::FPToUIOp>>(context, getOperation(),
/*isUnsigned=*/true);
+ patterns.add<IntToFpConversion<arith::SIToFPOp>>(context, getOperation(),
+ /*isUnsigned=*/false);
+ patterns.add<IntToFpConversion<arith::UIToFPOp>>(context, getOperation(),
+ /*isUnsigned=*/true);
LogicalResult result = success();
ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
if (diag.getSeverity() == DiagnosticSeverity::Error) {
diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
index 632fe9cf2269d..2fbcc26200540 100644
--- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
@@ -115,4 +115,16 @@ MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_convert_to_int(
val.convertToInteger(result, llvm::RoundingMode::NearestTiesToEven, &isExact);
return result.getZExtValue();
}
+
+MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_convert_from_int(
+ int32_t semantics, int32_t inputWidth, bool isUnsigned, uint64_t a) {
+ llvm::APInt val(inputWidth, a, /*isSigned=*/!isUnsigned);
+ const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
+ static_cast<llvm::APFloatBase::Semantics>(semantics));
+ llvm::APFloat result(sem);
+ // TODO: Custom rounding modes are not supported yet.
+ result.convertFromAPInt(val, /*IsSigned=*/!isUnsigned,
+ llvm::RoundingMode::NearestTiesToEven);
+ return result.bitcastToAPInt().getZExtValue();
+}
}
diff --git a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
index f1acfd5e5618a..d71d81dddcd4f 100644
--- a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
+++ b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
@@ -174,3 +174,27 @@ func.func @fptoui(%arg0: f16) {
%0 = arith.fptoui %arg0 : f16 to i4
return
}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_convert_from_int(i32, i32, i1, i64) -> i64
+// CHECK: %[[sem_out:.*]] = arith.constant 18 : i32
+// CHECK: %[[in_width:.*]] = arith.constant 32 : i32
+// CHECK: %[[is_unsigned:.*]] = arith.constant false
+// CHECK: %[[res:.*]] = call @_mlir_apfloat_convert_from_int(%[[sem_out]], %[[in_width]], %[[is_unsigned]], %{{.*}}) : (i32, i32, i1, i64) -> i64
+func.func @sitofp(%arg0: i32) {
+ %0 = arith.sitofp %arg0 : i32 to f4E2M1FN
+ return
+}
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_convert_from_int(i32, i32, i1, i64) -> i64
+// CHECK: %[[sem_out:.*]] = arith.constant 18 : i32
+// CHECK: %[[in_width:.*]] = arith.constant 32 : i32
+// CHECK: %[[is_unsigned:.*]] = arith.constant true
+// CHECK: %[[res:.*]] = call @_mlir_apfloat_convert_from_int(%[[sem_out]], %[[in_width]], %[[is_unsigned]], %{{.*}}) : (i32, i32, i1, i64) -> i64
+func.func @uitofp(%arg0: i32) {
+ %0 = arith.uitofp %arg0 : i32 to f4E2M1FN
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
index 5e93945c3eb60..8046610d479a8 100644
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
@@ -53,5 +53,18 @@ func.func @entry() {
%cvt_int_unsigned = arith.fptoui %cvt : f8E4M3FN to i2
vector.print %cvt_int_unsigned : i2
+ // CHECK-NEXT: -6
+ // Bit pattern: 1...11110111, interpreted as signed: -9
+ // Closest f4E2M1FN value: -6.0
+ %c9 = arith.constant -9 : i16
+ %cvt_from_signed_int = arith.sitofp %c9 : i16 to f4E2M1FN
+ vector.print %cvt_from_signed_int : f4E2M1FN
+
+ // CHECK-NEXT: 6
+ // Bit pattern: 1...11110111, interpreted as unsigned: 65527
+ // Closest f4E2M1FN value: 6.0
+ %cvt_from_unsigned_int = arith.uitofp %c9 : i16 to f4E2M1FN
+ vector.print %cvt_from_unsigned_int : f4E2M1FN
+
return
}
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
makslevental
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
eb5759f to
790e4bc
Compare
3ab9d96 to
727c97c
Compare
727c97c to
0737431
Compare
…lvm#169284) Add support for `arith.sitofp` and `arith.uitofp`.
Add support for
arith.sitofpandarith.uitofp.