diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td index 4b830c05bf585..8f0755c8a8144 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -1425,6 +1425,40 @@ def Arith_TruncFOp : attr-dict `:` type($in) `to` type($out) }]; } +//===----------------------------------------------------------------------===// +// ConvertFOp +//===----------------------------------------------------------------------===// + +def Arith_ConvertFOp : + Arith_Op<"convertf", + [Pure, SameOperandsAndResultShape, SameInputOutputTensorDims, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]>, + Arguments<(ins FloatLike:$in, + OptionalAttr:$roundingmode, + OptionalAttr:$fastmath)>, + Results<(outs FloatLike:$out)> { + let summary = "cast between floating-point types of the same bitwidth"; + let description = [{ + Cast a floating-point value to a different floating-point type of the same + bitwidth. This operation handles conversions between types that have the + same bitwidth but different semantics (e.g., f16 to bf16), which cannot + be represented by `arith.extf` or `arith.truncf`. + + The source and destination element types must be different and must have + the same bitwidth. If the value cannot be exactly represented, it is + rounded using the provided rounding mode or the default one if no rounding + mode is provided. When operating on vectors, casts elementwise. + }]; + + let hasFolder = 1; + let hasVerifier = 1; + let assemblyFormat = [{ $in ($roundingmode^)? + (`fastmath` `` $fastmath^)? + attr-dict `:` type($in) `to` type($out) }]; +} + //===----------------------------------------------------------------------===// // Scaling TruncFOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index e7f561e8a4d67..a0346ec6f4fb6 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -262,6 +262,67 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override; }; +/// Lower arith.convertf (same-bitwidth FP cast) to LLVM. +/// +/// Extends to f32 via llvm.fpext, then truncates to the target type via +/// llvm.fptrunc. This handles bf16 <-> f16, which is the only same-bitwidth +/// pair of LLVM-supported FP types. +struct ConvertFOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(arith::ConvertFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (LLVM::detail::opHasUnsupportedFloatingPointTypes(op, + *getTypeConverter())) + return rewriter.notifyMatchFailure(op, "unsupported floating point type"); + + // Only bf16 <-> f16 conversions are supported. There is currently no other + // pair of FP types that are valid LLVM types. + auto srcType = getElementTypeOrSelf(op.getIn().getType()); + auto dstType = getElementTypeOrSelf(op.getType()); + assert((srcType.isBF16() && dstType.isF16()) || + (srcType.isF16() && dstType.isBF16()) && + "only bf16 <-> f16 conversions are supported"); + + Type convertedType = getTypeConverter()->convertType(op.getType()); + if (!convertedType) + return rewriter.notifyMatchFailure(op, "failed to convert result type"); + + Value input = adaptor.getIn(); + Location loc = op.getLoc(); + + if (!isa(input.getType())) { + rewriter.replaceOp(op, + emitConversion(rewriter, loc, input, convertedType)); + return success(); + } + + if (!isa(op.getType())) + return rewriter.notifyMatchFailure(op, "expected vector result type"); + + return LLVM::detail::handleMultidimensionalVectors( + op.getOperation(), adaptor.getOperands(), *getTypeConverter(), + [&](Type llvm1DVectorTy, ValueRange operands) -> Value { + return emitConversion(rewriter, loc, operands.front(), + llvm1DVectorTy); + }, + rewriter); + } + +private: + static Value emitConversion(ConversionPatternRewriter &rewriter, Location loc, + Value input, Type targetType) { + Type f32Scalar = Float32Type::get(rewriter.getContext()); + Type f32Ty = f32Scalar; + if (auto vecTy = dyn_cast(targetType)) + f32Ty = VectorType::get(vecTy.getShape(), f32Scalar); + + Value ext = LLVM::FPExtOp::create(rewriter, loc, f32Ty, input); + return LLVM::FPTruncOp::create(rewriter, loc, targetType, ext); + } +}; + struct SelectOpOneToNLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; using Adaptor = ConvertOpToLLVMPattern::OneToNOpAdaptor; @@ -642,6 +703,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns( ExtFOpLowering, ExtSIOpLowering, ExtUIOpLowering, + ConvertFOpLowering, FPToSIOpLowering, FPToUIOpLowering, IndexCastOpSILowering, diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 155edc5070a9d..6999e9153bb9a 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -1699,6 +1699,55 @@ LogicalResult arith::TruncFOp::verify() { return verifyTruncateOp(*this); } +//===----------------------------------------------------------------------===// +// ConvertFOp +//===----------------------------------------------------------------------===// + +OpFoldResult arith::ConvertFOp::fold(FoldAdaptor adaptor) { + auto resElemType = cast(getElementTypeOrSelf(getType())); + const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics(); + return constFoldCastOp( + adaptor.getOperands(), getType(), + [this, &targetSemantics](const APFloat &a, bool &castStatus) { + RoundingMode roundingMode = + getRoundingmode().value_or(RoundingMode::to_nearest_even); + llvm::RoundingMode llvmRoundingMode = + convertArithRoundingModeToLLVMIR(roundingMode); + FailureOr result = + convertFloatValue(a, targetSemantics, llvmRoundingMode); + if (failed(result)) { + castStatus = false; + return a; + } + return *result; + }); +} + +bool arith::ConvertFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (!areValidCastInputsAndOutputs(inputs, outputs)) + return false; + auto srcType = getTypeIfLike(inputs.front()); + auto dstType = getTypeIfLike(outputs.front()); + if (!srcType || !dstType) + return false; + return srcType != dstType && + srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth(); +} + +LogicalResult arith::ConvertFOp::verify() { + auto srcType = cast(getElementTypeOrSelf(getIn().getType())); + auto dstType = cast(getElementTypeOrSelf(getType())); + if (srcType == dstType) + return emitError("result element type ") + << dstType << " must be different from operand element type " + << srcType; + if (srcType.getWidth() != dstType.getWidth()) + return emitError("result element type ") + << dstType << " must have the same bitwidth as operand element type " + << srcType; + return success(); +} + //===----------------------------------------------------------------------===// // ScalingTruncFOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir index 47069906fa110..6a6016c4f5b16 100644 --- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir +++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir @@ -377,6 +377,39 @@ func.func @experimental_constrained_fptrunc(%arg0 : f64) { // ----- +// CHECK-LABEL: @convertf_f16_to_bf16 +func.func @convertf_f16_to_bf16(%arg0 : f16) -> bf16 { +// CHECK-NEXT: %[[EXT:.*]] = llvm.fpext %arg0 : f16 to f32 +// CHECK-NEXT: %[[TRUNC:.*]] = llvm.fptrunc %[[EXT]] : f32 to bf16 + %0 = arith.convertf %arg0 : f16 to bf16 +// CHECK-NEXT: return %[[TRUNC]] + return %0 : bf16 +} + +// ----- + +// CHECK-LABEL: @convertf_bf16_to_f16 +func.func @convertf_bf16_to_f16(%arg0 : bf16) -> f16 { +// CHECK-NEXT: %[[EXT:.*]] = llvm.fpext %arg0 : bf16 to f32 +// CHECK-NEXT: %[[TRUNC:.*]] = llvm.fptrunc %[[EXT]] : f32 to f16 + %0 = arith.convertf %arg0 : bf16 to f16 +// CHECK-NEXT: return %[[TRUNC]] + return %0 : f16 +} + +// ----- + +// CHECK-LABEL: @convertf_vector +func.func @convertf_vector(%arg0 : vector<2xf16>) -> vector<2xbf16> { +// CHECK-NEXT: %[[EXT:.*]] = llvm.fpext %arg0 : vector<2xf16> to vector<2xf32> +// CHECK-NEXT: %[[TRUNC:.*]] = llvm.fptrunc %[[EXT]] : vector<2xf32> to vector<2xbf16> + %0 = arith.convertf %arg0 : vector<2xf16> to vector<2xbf16> +// CHECK-NEXT: return %[[TRUNC]] + return %0 : vector<2xbf16> +} + +// ----- + // Check sign and zero extension and truncation of integers. // CHECK-LABEL: @integer_extension_and_truncation func.func @integer_extension_and_truncation(%arg0 : i3) { @@ -838,3 +871,4 @@ func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8 %3 = arith.cmpf oeq, %arg0, %arg3 : f32 return } + diff --git a/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir b/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir index 86f0be81ce99e..bf1e8580a5b76 100644 --- a/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir +++ b/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir @@ -125,6 +125,21 @@ func.func @fptrunc_vector(%arg0 : vector<1x2x3xf64>) -> vector<1x2x3xf16> { return %0 : vector<1x2x3xf16> } +// CHECK-LABEL: @convertf +func.func @convertf_vector(%arg0 : vector<1x2x3xf16>) -> vector<1x2x3xbf16> { + // CHECK: llvm.mlir.poison : !llvm.array<1 x array<2 x vector<3xbf16>>> + // CHECK: llvm.extractvalue %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xf16>>> + // CHECK: llvm.fpext %{{.*}} : vector<3xf16> to vector<3xf32> + // CHECK: llvm.fptrunc %{{.*}} : vector<3xf32> to vector<3xbf16> + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xbf16>>> + // CHECK: llvm.extractvalue %{{.*}}[0, 1] : !llvm.array<1 x array<2 x vector<3xf16>>> + // CHECK: llvm.fpext %{{.*}} : vector<3xf16> to vector<3xf32> + // CHECK: llvm.fptrunc %{{.*}} : vector<3xf32> to vector<3xbf16> + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0, 1] : !llvm.array<1 x array<2 x vector<3xbf16>>> + %0 = arith.convertf %arg0: vector<1x2x3xf16> to vector<1x2x3xbf16> + return %0 : vector<1x2x3xbf16> +} + // CHECK-LABEL: @trunci func.func @trunci_vector(%arg0 : vector<1x2x3xi64>) -> vector<1x2x3xi16> { // CHECK: llvm.mlir.poison : !llvm.array<1 x array<2 x vector<3xi16>>> diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index 643e4e076e7c6..18665e2eb6f4a 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -3553,3 +3553,14 @@ func.func @truncf_neg_inf_to_finite_only_no_fold() -> f4E2M1FN { return %result : f4E2M1FN } +// ----- + +// CHECK-LABEL: @convertf_fold_f8 +// CHECK: %[[C:.*]] = arith.constant 2.000000e+00 : f8E5M2 +// CHECK: return %[[C]] +func.func @convertf_fold_f8() -> f8E5M2 { + %c = arith.constant 2.0 : f8E4M3FN + %result = arith.convertf %c : f8E4M3FN to f8E5M2 + return %result : f8E5M2 +} + diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir index 0ea614e0d4b97..96013a4fadde5 100644 --- a/mlir/test/Dialect/Arith/invalid.mlir +++ b/mlir/test/Dialect/Arith/invalid.mlir @@ -535,6 +535,54 @@ func.func @fptrunc_vec_f32_to_i32(%arg0 : vector<2xf32>) { // ----- +func.func @convertf_same_type(%arg0 : f16) { + // expected-error@+1 {{are cast incompatible}} + %0 = arith.convertf %arg0 : f16 to f16 + return +} + +// ----- + +func.func @convertf_different_bitwidth(%arg0 : f16) { + // expected-error@+1 {{are cast incompatible}} + %0 = arith.convertf %arg0 : f16 to f32 + return +} + +// ----- + +func.func @convertf_different_bitwidth_trunc(%arg0 : f32) { + // expected-error@+1 {{are cast incompatible}} + %0 = arith.convertf %arg0 : f32 to f16 + return +} + +// ----- + +func.func @convertf_vec_same_type(%arg0 : vector<2xf16>) { + // expected-error@+1 {{are cast incompatible}} + %0 = arith.convertf %arg0 : vector<2xf16> to vector<2xf16> + return +} + +// ----- + +func.func @convertf_vec_different_bitwidth(%arg0 : vector<2xf16>) { + // expected-error@+1 {{are cast incompatible}} + %0 = arith.convertf %arg0 : vector<2xf16> to vector<2xf32> + return +} + +// ----- + +func.func @convertf_shape_mismatch(%arg0 : vector<2xf16>) { + // expected-error@+1 {{op requires the same shape for all operands and results}} + %0 = arith.convertf %arg0 : vector<2xf16> to vector<3xbf16> + return +} + +// ----- + func.func @sexti_index_as_operand(%arg0 : index) { // expected-error@+1 {{op operand #0 must be signless-fixed-width-integer-like, but got 'index'}} %0 = arith.extsi %arg0 : index to i128 @@ -1016,3 +1064,43 @@ func.func @index_castui_i0(%a: i0) -> index { %0 = arith.index_castui %a : i0 to index return %0 : index } + +// ----- + +func.func @convertf_same_type(%arg0 : f32) { + // expected-error @+1 {{are cast incompatible}} + %0 = arith.convertf %arg0 : f32 to f32 + return +} + +// ----- + +func.func @convertf_same_type_vec(%arg0 : vector<2xf16>) { + // expected-error @+1 {{are cast incompatible}} + %0 = arith.convertf %arg0 : vector<2xf16> to vector<2xf16> + return +} + +// ----- + +func.func @convertf_shape_mismatch(%arg0 : vector<2xf16>) { + // expected-error @+1 {{op requires the same shape for all operands and results}} + %0 = arith.convertf %arg0 : vector<2xf16> to vector<3xf32> + return +} + +// ----- + +func.func @convertf_int_input(%arg0 : i32) { + // expected-error @+1 {{op operand #0 must be floating-point-like, but got 'i32'}} + %0 = arith.convertf %arg0 : i32 to f32 + return +} + +// ----- + +func.func @convertf_int_output(%arg0 : f32) { + // expected-error @+1 {{op result #0 must be floating-point-like, but got 'i32'}} + %0 = arith.convertf %arg0 : f32 to i32 + return +} diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir index 9765db69d6dd5..2c5371de9ff24 100644 --- a/mlir/test/Dialect/Arith/ops.mlir +++ b/mlir/test/Dialect/Arith/ops.mlir @@ -751,6 +751,41 @@ func.func @test_truncf_rounding_mode(%arg0 : f64) -> (f32, f32, f32, f32, f32) { return %0, %1, %2, %3, %4 : f32, f32, f32, f32, f32 } +// CHECK-LABEL: test_convertf +func.func @test_convertf(%arg0 : f16) -> bf16 { + // CHECK: arith.convertf %arg0 : f16 to bf16 + %0 = arith.convertf %arg0 : f16 to bf16 + return %0 : bf16 +} + +// CHECK-LABEL: test_convertf_vector +func.func @test_convertf_vector(%arg0 : vector<8xf16>) -> vector<8xbf16> { + // CHECK: arith.convertf %arg0 : vector<8xf16> to vector<8xbf16> + %0 = arith.convertf %arg0 : vector<8xf16> to vector<8xbf16> + return %0 : vector<8xbf16> +} + +// CHECK-LABEL: test_convertf_scalable_vector +func.func @test_convertf_scalable_vector(%arg0 : vector<[8]xbf16>) -> vector<[8]xf16> { + // CHECK: arith.convertf %arg0 : vector<[8]xbf16> to vector<[8]xf16> + %0 = arith.convertf %arg0 : vector<[8]xbf16> to vector<[8]xf16> + return %0 : vector<[8]xf16> +} + +// CHECK-LABEL: test_convertf_tensor +func.func @test_convertf_tensor(%arg0 : tensor<8x8xf16>) -> tensor<8x8xbf16> { + // CHECK: arith.convertf %arg0 : tensor<8x8xf16> to tensor<8x8xbf16> + %0 = arith.convertf %arg0 : tensor<8x8xf16> to tensor<8x8xbf16> + return %0 : tensor<8x8xbf16> +} + +// CHECK-LABEL: test_convertf_rounding_mode +func.func @test_convertf_rounding_mode(%arg0 : bf16) -> f16 { + // CHECK: arith.convertf %arg0 to_nearest_even : bf16 to f16 + %0 = arith.convertf %arg0 to_nearest_even : bf16 to f16 + return %0 : f16 +} + // CHECK-LABEL: test_uitofp func.func @test_uitofp(%arg0 : i32) -> f32 { %0 = arith.uitofp %arg0 : i32 to f32 @@ -1228,3 +1263,4 @@ func.func @intflags_func(%arg0: i64, %arg1: i64) { %4 = arith.trunci %arg0 overflow : i64 to i32 return } +