diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp index 5c68236526b7d..4776ba0f49b94 100644 --- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp +++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp @@ -102,6 +102,10 @@ struct BinaryArithOpToAPFloatConversion final : OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { + if (op.getType().getIntOrFloatBitWidth() > 64) + return rewriter.notifyMatchFailure(op, + "bitwidth > 64 bits is not supported"); + // Get APFloat function from runtime library. FailureOr fn = lookupOrCreateBinaryFn(rewriter, symTable, APFloatName); @@ -148,6 +152,11 @@ struct FpToFpConversion final : OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { + if (op.getType().getIntOrFloatBitWidth() > 64 || + op.getOperand().getType().getIntOrFloatBitWidth() > 64) + return rewriter.notifyMatchFailure(op, + "bitwidth > 64 bits is not supported"); + // Get APFloat function from runtime library. auto i32Type = IntegerType::get(symTable->getContext(), 32); auto i64Type = IntegerType::get(symTable->getContext(), 64); @@ -195,9 +204,10 @@ struct FpToIntConversion final : OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - if (op.getType().getIntOrFloatBitWidth() > 64) - return rewriter.notifyMatchFailure( - op, "result type > 64 bits is not supported"); + if (op.getType().getIntOrFloatBitWidth() > 64 || + op.getOperand().getType().getIntOrFloatBitWidth() > 64) + return rewriter.notifyMatchFailure(op, + "bitwidth > 64 bits is not supported"); // Get APFloat function from runtime library. auto i1Type = IntegerType::get(symTable->getContext(), 1); @@ -252,11 +262,10 @@ struct IntToFpConversion final : OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - if (op.getIn().getType().getIntOrFloatBitWidth() > 64) { - return rewriter.notifyMatchFailure( - loc, "integer bitwidth > 64 is not supported"); - } + if (op.getType().getIntOrFloatBitWidth() > 64 || + op.getOperand().getType().getIntOrFloatBitWidth() > 64) + return rewriter.notifyMatchFailure(op, + "bitwidth > 64 bits is not supported"); // Get APFloat function from runtime library. auto i1Type = IntegerType::get(symTable->getContext(), 1); @@ -270,6 +279,7 @@ struct IntToFpConversion final : OpRewritePattern { rewriter.setInsertionPoint(op); // Cast operands to 64-bit integers. + Location loc = op.getLoc(); auto inIntTy = cast(op.getOperand().getType()); Value operandBits = op.getOperand(); if (operandBits.getType().getIntOrFloatBitWidth() < 64) { @@ -317,6 +327,10 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern { LogicalResult matchAndRewrite(arith::CmpFOp op, PatternRewriter &rewriter) const override { + if (op.getLhs().getType().getIntOrFloatBitWidth() > 64) + return rewriter.notifyMatchFailure(op, + "bitwidth > 64 bits is not supported"); + // Get APFloat function from runtime library. auto i1Type = IntegerType::get(symTable->getContext(), 1); auto i8Type = IntegerType::get(symTable->getContext(), 8); @@ -456,6 +470,10 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern { LogicalResult matchAndRewrite(arith::NegFOp op, PatternRewriter &rewriter) const override { + if (op.getOperand().getType().getIntOrFloatBitWidth() > 64) + return rewriter.notifyMatchFailure(op, + "bitwidth > 64 bits is not supported"); + // Get APFloat function from runtime library. auto i32Type = IntegerType::get(symTable->getContext(), 32); auto i64Type = IntegerType::get(symTable->getContext(), 64); diff --git a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir index 950d2cecefa95..ab05edebec71d 100644 --- a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir +++ b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir @@ -263,3 +263,28 @@ func.func @maxnumf(%arg0: f32, %arg1: f32) { %0 = arith.maxnumf %arg0, %arg1 : f32 return } + +// ----- + +// CHECK-LABEL: func.func @unsupported_bitwidth +// CHECK: arith.addf {{.*}} : f128 +// CHECK: arith.negf {{.*}} : f128 +// CHECK: arith.cmpf {{.*}} : f128 +// CHECK: arith.extf {{.*}} : f32 to f128 +// CHECK: arith.truncf {{.*}} : f128 to f32 +// CHECK: arith.fptosi {{.*}} : f128 to i32 +// CHECK: arith.fptosi {{.*}} : f32 to i92 +// CHECK: arith.sitofp {{.*}} : i1 to f128 +// CHECK: arith.sitofp {{.*}} : i92 to f32 +func.func @unsupported_bitwidth(%arg0: f128, %arg1: f128, %arg2: f32) { + %0 = arith.addf %arg0, %arg1 : f128 + %1 = arith.negf %arg0 : f128 + %2 = arith.cmpf "ult", %arg0, %arg1 : f128 + %3 = arith.extf %arg2 : f32 to f128 + %4 = arith.truncf %arg0 : f128 to f32 + %5 = arith.fptosi %arg0 : f128 to i32 + %6 = arith.fptosi %arg2 : f32 to i92 + %7 = arith.sitofp %2 : i1 to f128 + %8 = arith.sitofp %6 : i92 to f32 + return +}