Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,73 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
bool isUnsigned;
};

template <typename OpTy>
struct IntToFpConversion final : OpRewritePattern<OpTy> {
IntToFpConversion(MLIRContext *context, SymbolOpInterface symTable,
bool isUnsigned, PatternBenefit benefit = 1)
: OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
isUnsigned(isUnsigned) {}

LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
if (op.getIn().getType().getIntOrFloatBitWidth() > 64) {
return rewriter.notifyMatchFailure(
loc, "integer bitwidth > 64 is not supported");
}

// Get APFloat function from runtime library.
auto i1Type = IntegerType::get(symTable->getContext(), 1);
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
FailureOr<FuncOp> fn =
lookupOrCreateApFloatFn(rewriter, symTable, "convert_from_int",
{i32Type, i32Type, i1Type, i64Type});
if (failed(fn))
return fn;

rewriter.setInsertionPoint(op);
// Cast operands to 64-bit integers.
auto inIntTy = cast<IntegerType>(op.getOperand().getType());
Value operandBits = op.getOperand();
if (operandBits.getType().getIntOrFloatBitWidth() < 64) {
if (isUnsigned) {
operandBits =
arith::ExtUIOp::create(rewriter, loc, i64Type, operandBits);
} else {
operandBits =
arith::ExtSIOp::create(rewriter, loc, i64Type, operandBits);
}
}

// Call APFloat function.
auto outFloatTy = cast<FloatType>(op.getType());
Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
Value inWidthValue = arith::ConstantOp::create(
rewriter, loc, i32Type,
rewriter.getIntegerAttr(i32Type, inIntTy.getWidth()));
Value isUnsignedValue = arith::ConstantOp::create(
rewriter, loc, i1Type, rewriter.getIntegerAttr(i1Type, isUnsigned));
SmallVector<Value> params = {outSemValue, inWidthValue, isUnsignedValue,
operandBits};
auto resultOp =
func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
SymbolRefAttr::get(*fn), params);

// Truncate result to the original width.
auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth());
Value truncatedBits = arith::TruncIOp::create(rewriter, loc, outIntWType,
resultOp->getResult(0));
Value result =
arith::BitcastOp::create(rewriter, loc, outFloatTy, truncatedBits);
rewriter.replaceOp(op, result);
return success();
}

SymbolOpInterface symTable;
bool isUnsigned;
};

namespace {
struct ArithToAPFloatConversionPass final
: impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
Expand Down Expand Up @@ -269,6 +336,10 @@ void ArithToAPFloatConversionPass::runOnOperation() {
/*isUnsigned=*/false);
patterns.add<FpToIntConversion<arith::FPToUIOp>>(context, getOperation(),
/*isUnsigned=*/true);
patterns.add<IntToFpConversion<arith::SIToFPOp>>(context, getOperation(),
/*isUnsigned=*/false);
patterns.add<IntToFpConversion<arith::UIToFPOp>>(context, getOperation(),
/*isUnsigned=*/true);
LogicalResult result = success();
ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
if (diag.getSeverity() == DiagnosticSeverity::Error) {
Expand Down
12 changes: 12 additions & 0 deletions mlir/lib/ExecutionEngine/APFloatWrappers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,4 +119,16 @@ MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_convert_to_int(
// result to the desired result width.
return result.getZExtValue();
}

MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_convert_from_int(
int32_t semantics, int32_t inputWidth, bool isUnsigned, uint64_t a) {
llvm::APInt val(inputWidth, a, /*isSigned=*/!isUnsigned);
const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
static_cast<llvm::APFloatBase::Semantics>(semantics));
llvm::APFloat result(sem);
// TODO: Custom rounding modes are not supported yet.
result.convertFromAPInt(val, /*IsSigned=*/!isUnsigned,
llvm::RoundingMode::NearestTiesToEven);
return result.bitcastToAPInt().getZExtValue();
}
}
24 changes: 24 additions & 0 deletions mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,27 @@ func.func @fptoui(%arg0: f16) {
%0 = arith.fptoui %arg0 : f16 to i4
return
}

// -----

// CHECK: func.func private @_mlir_apfloat_convert_from_int(i32, i32, i1, i64) -> i64
// CHECK: %[[sem_out:.*]] = arith.constant 18 : i32
// CHECK: %[[in_width:.*]] = arith.constant 32 : i32
// CHECK: %[[is_unsigned:.*]] = arith.constant false
// CHECK: %[[res:.*]] = call @_mlir_apfloat_convert_from_int(%[[sem_out]], %[[in_width]], %[[is_unsigned]], %{{.*}}) : (i32, i32, i1, i64) -> i64
func.func @sitofp(%arg0: i32) {
%0 = arith.sitofp %arg0 : i32 to f4E2M1FN
return
}

// -----

// CHECK: func.func private @_mlir_apfloat_convert_from_int(i32, i32, i1, i64) -> i64
// CHECK: %[[sem_out:.*]] = arith.constant 18 : i32
// CHECK: %[[in_width:.*]] = arith.constant 32 : i32
// CHECK: %[[is_unsigned:.*]] = arith.constant true
// CHECK: %[[res:.*]] = call @_mlir_apfloat_convert_from_int(%[[sem_out]], %[[in_width]], %[[is_unsigned]], %{{.*}}) : (i32, i32, i1, i64) -> i64
func.func @uitofp(%arg0: i32) {
%0 = arith.uitofp %arg0 : i32 to f4E2M1FN
return
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,18 @@ func.func @entry() {
%cvt_int_unsigned = arith.fptoui %cvt : f8E4M3FN to i2
vector.print %cvt_int_unsigned : i2

// CHECK-NEXT: -6
// Bit pattern: 1...11110111, interpreted as signed: -9
// Closest f4E2M1FN value: -6.0
%c9 = arith.constant -9 : i16
%cvt_from_signed_int = arith.sitofp %c9 : i16 to f4E2M1FN
vector.print %cvt_from_signed_int : f4E2M1FN

// CHECK-NEXT: 6
// Bit pattern: 1...11110111, interpreted as unsigned: 65527
// Closest f4E2M1FN value: 6.0
%cvt_from_unsigned_int = arith.uitofp %c9 : i16 to f4E2M1FN
vector.print %cvt_from_unsigned_int : f4E2M1FN

return
}