From c07effbd977e685e66548bfe63b6f8dad426df0a Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Thu, 20 Nov 2025 05:04:42 +0000 Subject: [PATCH 1/7] [MLIR][NVVM] Add missing rounding modes in fp16x2 conversions This change adds the `RN` and `RZ` rounding modes to the `convert.f32x2.to.f16x2` and `convert.f32x2.to.bf16x2` Ops. Tests are added `convert_fp16x2.mlir` and `nvvmir-invalid.mlir`. Tests with these Ops in `convert_stochastic_rounding.mlir` and `invalid-convert-stochastic-rounding.mlir` have been removed or modified. PTX spec reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 42 +++--- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 138 ++++++++++++++---- .../invalid-convert-stochastic-rounding.mlir | 26 +--- .../Target/LLVMIR/nvvm/convert_fp16x2.mlir | 87 +++++++++++ .../nvvm/convert_stochastic_rounding.mlir | 68 +-------- mlir/test/Target/LLVMIR/nvvmir-invalid.mlir | 46 ++++++ 6 files changed, 275 insertions(+), 132 deletions(-) create mode 100644 mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 6e3a92b5bde42..7a2cfb1fee5eb 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1912,45 +1912,51 @@ def NVVM_ConvertF4x2ToF16x2Op : // Base class for conversions from F32x2 to FPx2 formats // (F16x2, BF16x2) -// TODO: In separate PR, add .rn and .rz rounding variants for this conversion -// as currently only support .rs rounding mode class NVVM_ConvertF32x2ToFPx2OpBase : - NVVM_Op]>, + NVVM_Op, Results<(outs dstType:$dst)>, - Arguments<(ins F32:$src_hi, F32:$src_lo, I32:$rbits, - DefaultValuedAttr:$rnd, + Arguments<(ins F32:$src_hi, F32:$src_lo, Optional:$rbits, + DefaultValuedAttr:$rnd, DefaultValuedAttr:$sat, DefaultValuedAttr:$relu)> { - let summary = "Convert two F32 values to packed " # dstFormat # " with stochastic rounding (.rs)"; + let summary = "Convert two F32 values to packed " # !tolower(dstFormat) # "."; let description = [{ - Converts two F32 values to packed }] # dstFormat # [{ format using stochastic - rounding (.rs) mode with randomness provided by the `rbits` parameter. The - `relu` attribute clamps negative results to 0. The `sat` attribute determines - saturation behavior. The `src_hi` and `src_lo` parameters correspond to operands - `a` and `b` in the PTX ISA, respectively. + Converts two F32 values to packed }] # !tolower(dstFormat) # [{ format with + the specified rounding mode. The `src_hi` and `src_lo` parameters + correspond to operands `a` and `b` in the PTX ISA, respectively. + + The `rbits` parameter is required for stochastic rounding. + + The `relu` attribute clamps negative results to 0. + + The `sat` attribute determines saturation behavior. [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) }]; - let assemblyFormat = "$src_hi `,` $src_lo `,` $rbits attr-dict `:` type($dst)"; + let assemblyFormat = "$src_hi `,` $src_lo (`,` $rbits^)? attr-dict `:` type($dst)"; let hasVerifier = 1; let extraClassDeclaration = [{ - llvm::Intrinsic::ID getIntrinsicID(); + static NVVM::IDArgPair + getIntrinsicIDAndArgs( + NVVM::ConvertF32x2To}] # dstFormat # [{Op &op, + LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder); }]; string llvmBuilder = [{ - auto intId = op.getIntrinsicID(); - $dst = createIntrinsicCall(builder, intId, {$src_hi, $src_lo, $rbits}); + auto [intId, args] = mlir::NVVM::ConvertF32x2To}] # dstFormat # + [{Op::getIntrinsicIDAndArgs(op, moduleTranslation, builder); + $dst = createIntrinsicCall(builder, intId, args); }]; - } +} // F32x2 -> F16x2 with stochastic rounding -def NVVM_ConvertF32x2ToF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"f16x2", "convert.f32x2.to.f16x2", VectorOfLengthAndType<[2], [F16]>>; +def NVVM_ConvertF32x2ToF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"F16x2", "convert.f32x2.to.f16x2", VectorOfLengthAndType<[2], [F16]>>; // F32x2 -> BF16x2 with stochastic rounding -def NVVM_ConvertF32x2ToBF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"bf16x2", "convert.f32x2.to.bf16x2", VectorOfLengthAndType<[2], [BF16]>>; +def NVVM_ConvertF32x2ToBF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"BF16x2", "convert.f32x2.to.bf16x2", VectorOfLengthAndType<[2], [BF16]>>; // Base class for stochastic rounding conversions from F32x4 to FPx4 formats // (E4M3x4, E5M2x4, E2M3x4, E3M2x4, E2M1x4) diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 7ac427dbe3941..85118a190c14b 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -391,16 +391,40 @@ LogicalResult ConvertF4x2ToF16x2Op::verify() { //===----------------------------------------------------------------------===// LogicalResult ConvertF32x2ToF16x2Op::verify() { - if (getRnd() != FPRoundingMode::RS) - return emitOpError("Only RS rounding mode is supported for " + switch (getRnd()) { + case FPRoundingMode::RN: + case FPRoundingMode::RZ: + if (getRbits()) + return emitOpError("rbits not supported for RN and RZ rounding modes."); + break; + case FPRoundingMode::RS: + if (!getRbits()) + return emitOpError("rbits is required for RS rounding mode."); + break; + default: + return emitOpError("Only RN, RZ, and RS rounding modes are supported for " "conversions from f32x2 to f16x2."); + } + return success(); } LogicalResult ConvertF32x2ToBF16x2Op::verify() { - if (getRnd() != FPRoundingMode::RS) - return emitOpError("Only RS rounding mode is supported for " + switch (getRnd()) { + case FPRoundingMode::RN: + case FPRoundingMode::RZ: + if (getRbits()) + return emitOpError("rbits not supported for RN and RZ rounding modes."); + break; + case FPRoundingMode::RS: + if (!getRbits()) + return emitOpError("rbits is required for RS rounding mode."); + break; + default: + return emitOpError("Only RN, RZ, and RS rounding modes are supported for " "conversions from f32x2 to bf16x2."); + } + return success(); } @@ -2727,30 +2751,94 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op, return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \ }() -llvm::Intrinsic::ID ConvertF32x2ToF16x2Op::getIntrinsicID() { - bool hasRelu = getRelu(); - bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE); +NVVM::IDArgPair +ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op, + LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + static const llvm::Intrinsic::ID rndRNIds[] = { + llvm::Intrinsic::nvvm_ff2f16x2_rn, + llvm::Intrinsic::nvvm_ff2f16x2_rn_relu, + llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite, + llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite, + }; + static const llvm::Intrinsic::ID rndRZIds[] = { + llvm::Intrinsic::nvvm_ff2f16x2_rz, + llvm::Intrinsic::nvvm_ff2f16x2_rz_relu, + llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite, + llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite, + }; + static const llvm::Intrinsic::ID rndRSIds[] = { + llvm::Intrinsic::nvvm_ff2f16x2_rs, + llvm::Intrinsic::nvvm_ff2f16x2_rs_relu, + llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite, + llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite, + }; - if (hasRelu && hasSatFinite) - return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite; - if (hasRelu) - return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu; - if (hasSatFinite) - return llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite; - return llvm::Intrinsic::nvvm_ff2f16x2_rs; + bool hasRelu = op.getRelu(); + bool hasSatFinite = (op.getSat() == NVVM::SaturationMode::SATFINITE); + unsigned idx = hasRelu | (hasSatFinite << 1); + + llvm::SmallVector args; + args.push_back(mt.lookupValue(op.getSrcHi())); + args.push_back(mt.lookupValue(op.getSrcLo())); + if (op.getRbits()) + args.push_back(mt.lookupValue(op.getRbits())); + + switch (op.getRnd()) { + case FPRoundingMode::RN: + return {rndRNIds[idx], std::move(args)}; + case FPRoundingMode::RZ: + return {rndRZIds[idx], std::move(args)}; + case FPRoundingMode::RS: + return {rndRSIds[idx], std::move(args)}; + default: + llvm_unreachable("Invalid rounding mode for ConvertF32x2ToF16x2Op"); + } } -llvm::Intrinsic::ID ConvertF32x2ToBF16x2Op::getIntrinsicID() { - bool hasRelu = getRelu(); - bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE); - - if (hasRelu && hasSatFinite) - return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite; - if (hasRelu) - return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu; - if (hasSatFinite) - return llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite; - return llvm::Intrinsic::nvvm_ff2bf16x2_rs; +NVVM::IDArgPair +ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op, + LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + static const llvm::Intrinsic::ID rndRNIds[] = { + llvm::Intrinsic::nvvm_ff2bf16x2_rn, + llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu, + llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite, + llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite, + }; + static const llvm::Intrinsic::ID rndRZIds[] = { + llvm::Intrinsic::nvvm_ff2bf16x2_rz, + llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu, + llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite, + llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite, + }; + static const llvm::Intrinsic::ID rndRSIds[] = { + llvm::Intrinsic::nvvm_ff2bf16x2_rs, + llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu, + llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite, + llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite, + }; + + bool hasRelu = op.getRelu(); + bool hasSatFinite = (op.getSat() == NVVM::SaturationMode::SATFINITE); + unsigned idx = hasRelu | (hasSatFinite << 1); + + llvm::SmallVector args; + args.push_back(mt.lookupValue(op.getSrcHi())); + args.push_back(mt.lookupValue(op.getSrcLo())); + if (op.getRbits()) + args.push_back(mt.lookupValue(op.getRbits())); + + switch (op.getRnd()) { + case FPRoundingMode::RN: + return {rndRNIds[idx], std::move(args)}; + case FPRoundingMode::RZ: + return {rndRZIds[idx], std::move(args)}; + case FPRoundingMode::RS: + return {rndRSIds[idx], std::move(args)}; + default: + llvm_unreachable("Invalid rounding mode for ConvertF32x2ToBF16x2Op"); + } } llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() { diff --git a/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir b/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir index 35f5e1b3c8ba2..506b81e1e7048 100644 --- a/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir +++ b/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir @@ -2,35 +2,15 @@ // Test invalid target architecture (sm_100 instead of sm_100a) gpu.module @invalid_arch_sm_100 [#nvvm.target] { - func.func @convert_rs() { - %f1 = llvm.mlir.constant(1.0 : f32) : f32 - %f2 = llvm.mlir.constant(2.0 : f32) : f32 - %rbits = llvm.mlir.constant(0x12345678 : i32) : i32 - // expected-error@+1 {{'nvvm.convert.f32x2.to.f16x2' op is not supported on sm_100}} - %res = nvvm.convert.f32x2.to.f16x2 %f1, %f2, %rbits : vector<2xf16> + func.func @convert_rs(%src : vector<4xf32>, %rbits : i32) { + // expected-error@+1 {{'nvvm.convert.f32x4.to.f8x4' op is not supported on sm_100}} + %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E4M3FN) return } } // ----- -// Test that operations require stochastic rounding mode -llvm.func @invalid_rnd_mode_f16x2(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> { - // expected-error@+1 {{Only RS rounding mode is supported for conversions from f32x2 to f16x2.}} - %res = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode} : vector<2xf16> - llvm.return %res : vector<2xf16> -} - -// ----- - -llvm.func @invalid_rnd_mode_bf16x2(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xbf16> { - // expected-error@+1 {{Only RS rounding mode is supported for conversions from f32x2 to bf16x2.}} - %res = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> - llvm.return %res : vector<2xbf16> -} - -// ----- - // Test invalid destination types for f8x4 (should only accept f8E4M3FN, f8E5M2) llvm.func @invalid_dst_type_f8x4_e3m4(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> { // expected-error@+1 {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f32x4 to f8x4.}} diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir new file mode 100644 index 0000000000000..a4bece83f832a --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir @@ -0,0 +1,87 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @convert_f32x2_to_f16x2_rn +llvm.func @convert_f32x2_to_f16x2_rn(%srcA : f32, %srcB : f32) { + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn(float %{{.*}}, float %{{.*}}) + %res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode} : vector<2xf16> + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.satfinite(float %{{.*}}, float %{{.*}}) + %res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xf16> + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float %{{.*}}, float %{{.*}}) + %res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, relu = true} : vector<2xf16> + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}}) + %res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, relu = true, sat = #nvvm.sat_mode} : vector<2xf16> + + llvm.return +} + +// CHECK-LABEL: @convert_f32x2_to_f16x2_rz +llvm.func @convert_f32x2_to_f16x2_rz(%srcA : f32, %srcB : f32) { + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz(float %{{.*}}, float %{{.*}}) + %res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode} : vector<2xf16> + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.satfinite(float %{{.*}}, float %{{.*}}) + %res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xf16> + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float %{{.*}}, float %{{.*}}) + %res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, relu = true} : vector<2xf16> + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu.satfinite(float %{{.*}}, float %{{.*}}) + %res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, relu = true, sat = #nvvm.sat_mode} : vector<2xf16> + + llvm.return +} + +// CHECK-LABEL: @convert_f32x2_to_f16x2_rs_stochastic +llvm.func @convert_f32x2_to_f16x2_rs_stochastic(%srcA : f32, %srcB : f32, %rbits : i32) { + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) + %res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode} : vector<2xf16> + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) + %res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode} : vector<2xf16> + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) + %res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xf16> + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) + %res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xf16> + + llvm.return +} + +// ----- + +// CHECK-LABEL: @convert_f32x2_to_bf16x2_rn +llvm.func @convert_f32x2_to_bf16x2_rn(%srcA : f32, %srcB : f32) { + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn(float %{{.*}}, float %{{.*}}) + %res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.satfinite(float %{{.*}}, float %{{.*}}) + %res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xbf16> + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu(float %{{.*}}, float %{{.*}}) + %res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, relu = true} : vector<2xbf16> + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}}) + %res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, relu = true, sat = #nvvm.sat_mode} : vector<2xbf16> + + llvm.return +} + +// CHECK-LABEL: @convert_f32x2_to_bf16x2_rz +llvm.func @convert_f32x2_to_bf16x2_rz(%srcA : f32, %srcB : f32) { + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz(float %{{.*}}, float %{{.*}}) + %res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.satfinite(float %{{.*}}, float %{{.*}}) + %res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xbf16> + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float %{{.*}}, float %{{.*}}) + %res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, relu = true} : vector<2xbf16> + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu.satfinite(float %{{.*}}, float %{{.*}}) + %res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, relu = true, sat = #nvvm.sat_mode} : vector<2xbf16> + + llvm.return +} + +// CHECK-LABEL: @convert_f32x2_to_bf16x2_rs_stochastic +llvm.func @convert_f32x2_to_bf16x2_rs_stochastic(%srcA : f32, %srcB : f32, %rbits : i32) { + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) + %res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) + %res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) + %res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xbf16> + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) + %res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xbf16> + + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir index b5bb22350dcd7..03abcddd96cb0 100644 --- a/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir +++ b/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir @@ -10,7 +10,7 @@ gpu.module @valid_f16x2_rs_sm_100a [#nvvm.target] { %f1 = llvm.mlir.constant(1.0 : f32) : f32 %f2 = llvm.mlir.constant(2.0 : f32) : f32 %rbits = llvm.mlir.constant(0x12345678 : i32) : i32 - %res = nvvm.convert.f32x2.to.f16x2 %f1, %f2, %rbits : vector<2xf16> + %res = nvvm.convert.f32x2.to.f16x2 %f1, %f2, %rbits {rnd = #nvvm.fp_rnd_mode} : vector<2xf16> return } } @@ -21,77 +21,13 @@ gpu.module @valid_bf16x2_rs_sm_103a [#nvvm.target] { %f1 = llvm.mlir.constant(1.0 : f32) : f32 %f2 = llvm.mlir.constant(2.0 : f32) : f32 %rbits = llvm.mlir.constant(0 : i32) : i32 - %res = nvvm.convert.f32x2.to.bf16x2 %f1, %f2, %rbits : vector<2xbf16> + %res = nvvm.convert.f32x2.to.bf16x2 %f1, %f2, %rbits {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> return } } // ----- -// Test F32x2 -> F16x2 with stochastic rounding (.rs) - -// CHECK-LABEL: @convert_f32x2_to_f16x2_rs -llvm.func @convert_f32x2_to_f16x2_rs(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> { - // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) - %res = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits : vector<2xf16> - llvm.return %res : vector<2xf16> -} - -// CHECK-LABEL: @convert_f32x2_to_f16x2_rs_satfinite -llvm.func @convert_f32x2_to_f16x2_rs_satfinite(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> { - // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) - %res = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {sat = #nvvm.sat_mode} : vector<2xf16> - llvm.return %res : vector<2xf16> -} - -// CHECK-LABEL: @convert_f32x2_to_f16x2_rs_relu -llvm.func @convert_f32x2_to_f16x2_rs_relu(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> { - // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) - %res = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true} : vector<2xf16> - llvm.return %res : vector<2xf16> -} - -// CHECK-LABEL: @convert_f32x2_to_f16x2_rs_relu_satfinite -llvm.func @convert_f32x2_to_f16x2_rs_relu_satfinite(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> { - // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) - %res = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true, sat = #nvvm.sat_mode} : vector<2xf16> - llvm.return %res : vector<2xf16> -} - -// ----- - -// Test F32x2 -> BF16x2 with stochastic rounding (.rs) - -// CHECK-LABEL: @convert_f32x2_to_bf16x2_rs -llvm.func @convert_f32x2_to_bf16x2_rs(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xbf16> { - // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) - %res = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits : vector<2xbf16> - llvm.return %res : vector<2xbf16> -} - -// CHECK-LABEL: @convert_f32x2_to_bf16x2_rs_satfinite -llvm.func @convert_f32x2_to_bf16x2_rs_satfinite(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xbf16> { - // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) - %res = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {sat = #nvvm.sat_mode} : vector<2xbf16> - llvm.return %res : vector<2xbf16> -} - -// CHECK-LABEL: @convert_f32x2_to_bf16x2_rs_relu -llvm.func @convert_f32x2_to_bf16x2_rs_relu(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xbf16> { - // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) - %res = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true} : vector<2xbf16> - llvm.return %res : vector<2xbf16> -} - -// CHECK-LABEL: @convert_f32x2_to_bf16x2_rs_relu_satfinite -llvm.func @convert_f32x2_to_bf16x2_rs_relu_satfinite(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xbf16> { - // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) - %res = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true, sat = #nvvm.sat_mode} : vector<2xbf16> - llvm.return %res : vector<2xbf16> -} - -// ----- - // Test F32x4 -> F8x4 (E4M3) with stochastic rounding (.rs) // CHECK-LABEL: @convert_f32x4_to_f8x4_e4m3_rs diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir index d5868ee73cc50..b6e175ee9789f 100644 --- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir @@ -174,6 +174,52 @@ llvm.func @nvvm_match_sync_any(%val32: i32, %thread_mask: i32) { // ----- +llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rounding(%a : f32, %b : f32) { + // expected-error @below {{Only RN, RZ, and RS rounding modes are supported for conversions from f32x2 to f16x2.}} + %res = nvvm.convert.f32x2.to.f16x2 %a, %b {rnd = #nvvm.fp_rnd_mode} : vector<2xf16> + llvm.return +} + +// ----- + +llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rbits_1(%a : f32, %b : f32) { + // expected-error @below {{rbits is required for RS rounding mode.}} + %res = nvvm.convert.f32x2.to.f16x2 %a, %b {rnd = #nvvm.fp_rnd_mode} : vector<2xf16> + llvm.return +} + +// ----- + +llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rbits_2(%a : f32, %b : f32, %rbits : i32) { + // expected-error @below {{rbits not supported for RN and RZ rounding modes.}} + %res = nvvm.convert.f32x2.to.f16x2 %a, %b, %rbits {rnd = #nvvm.fp_rnd_mode} : vector<2xf16> + llvm.return +} + +// ----- + +llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rounding(%a : f32, %b : f32) { + // expected-error @below {{Only RN, RZ, and RS rounding modes are supported for conversions from f32x2 to bf16x2.}} + %res = nvvm.convert.f32x2.to.bf16x2 %a, %b {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> + llvm.return +} + +// ----- + +llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rbits_1(%a : f32, %b : f32) { + // expected-error @below {{rbits is required for RS rounding mode.}} + %res = nvvm.convert.f32x2.to.bf16x2 %a, %b {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> + llvm.return +} + +llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rbits_2(%a : f32, %b : f32, %rbits : i32) { + // expected-error @below {{rbits not supported for RN and RZ rounding modes.}} + %res = nvvm.convert.f32x2.to.bf16x2 %a, %b, %rbits {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> + llvm.return +} + +// ----- + llvm.func @nvvm_cvt_float_to_f8x2_invalid_rounding_e4m3(%a : f32, %b : f32) { // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to 'f8E4M3FN' and 'f8E5M2' types}} %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 (f8E4M3FN) From 447acafce02d8b89fe12d35259f60ad87bcc7bd3 Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Fri, 21 Nov 2025 06:43:23 +0000 Subject: [PATCH 2/7] address comments --- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 20 ++++---- .../LLVMIR/nvvm/invalid_convert_fp16x2.mlir | 47 +++++++++++++++++++ mlir/test/Target/LLVMIR/nvvmir-invalid.mlir | 46 ------------------ 3 files changed, 59 insertions(+), 54 deletions(-) create mode 100644 mlir/test/Target/LLVMIR/nvvm/invalid_convert_fp16x2.mlir diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 85118a190c14b..4654ed49a0ca1 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -2755,19 +2755,19 @@ NVVM::IDArgPair ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { - static const llvm::Intrinsic::ID rndRNIds[] = { + static constexpr llvm::Intrinsic::ID rndRNIds[] = { llvm::Intrinsic::nvvm_ff2f16x2_rn, llvm::Intrinsic::nvvm_ff2f16x2_rn_relu, llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite, llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite, }; - static const llvm::Intrinsic::ID rndRZIds[] = { + static constexpr llvm::Intrinsic::ID rndRZIds[] = { llvm::Intrinsic::nvvm_ff2f16x2_rz, llvm::Intrinsic::nvvm_ff2f16x2_rz_relu, llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite, llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite, }; - static const llvm::Intrinsic::ID rndRSIds[] = { + static constexpr llvm::Intrinsic::ID rndRSIds[] = { llvm::Intrinsic::nvvm_ff2f16x2_rs, llvm::Intrinsic::nvvm_ff2f16x2_rs_relu, llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite, @@ -2776,7 +2776,9 @@ ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op, bool hasRelu = op.getRelu(); bool hasSatFinite = (op.getSat() == NVVM::SaturationMode::SATFINITE); - unsigned idx = hasRelu | (hasSatFinite << 1); + // idx: bit-0 - relu + // bit-1 - satfinite + unsigned idx = (hasSatFinite << 1) | hasRelu; llvm::SmallVector args; args.push_back(mt.lookupValue(op.getSrcHi())); @@ -2800,19 +2802,19 @@ NVVM::IDArgPair ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { - static const llvm::Intrinsic::ID rndRNIds[] = { + static constexpr llvm::Intrinsic::ID rndRNIds[] = { llvm::Intrinsic::nvvm_ff2bf16x2_rn, llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu, llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite, llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite, }; - static const llvm::Intrinsic::ID rndRZIds[] = { + static constexpr llvm::Intrinsic::ID rndRZIds[] = { llvm::Intrinsic::nvvm_ff2bf16x2_rz, llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu, llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite, llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite, }; - static const llvm::Intrinsic::ID rndRSIds[] = { + static constexpr llvm::Intrinsic::ID rndRSIds[] = { llvm::Intrinsic::nvvm_ff2bf16x2_rs, llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu, llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite, @@ -2821,7 +2823,9 @@ ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op, bool hasRelu = op.getRelu(); bool hasSatFinite = (op.getSat() == NVVM::SaturationMode::SATFINITE); - unsigned idx = hasRelu | (hasSatFinite << 1); + // idx: bit-0 - relu + // bit-1 - satfinite + unsigned idx = (hasSatFinite << 1) | hasRelu; llvm::SmallVector args; args.push_back(mt.lookupValue(op.getSrcHi())); diff --git a/mlir/test/Target/LLVMIR/nvvm/invalid_convert_fp16x2.mlir b/mlir/test/Target/LLVMIR/nvvm/invalid_convert_fp16x2.mlir new file mode 100644 index 0000000000000..60571944e7bf8 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/invalid_convert_fp16x2.mlir @@ -0,0 +1,47 @@ +// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s + +// ----- + +llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rounding(%a : f32, %b : f32) { + // expected-error @below {{Only RN, RZ, and RS rounding modes are supported for conversions from f32x2 to f16x2.}} + %res = nvvm.convert.f32x2.to.f16x2 %a, %b {rnd = #nvvm.fp_rnd_mode} : vector<2xf16> + llvm.return +} + +// ----- + +llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rbits_1(%a : f32, %b : f32) { + // expected-error @below {{rbits is required for RS rounding mode.}} + %res = nvvm.convert.f32x2.to.f16x2 %a, %b {rnd = #nvvm.fp_rnd_mode} : vector<2xf16> + llvm.return +} + +// ----- + +llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rbits_2(%a : f32, %b : f32, %rbits : i32) { + // expected-error @below {{rbits not supported for RN and RZ rounding modes.}} + %res = nvvm.convert.f32x2.to.f16x2 %a, %b, %rbits {rnd = #nvvm.fp_rnd_mode} : vector<2xf16> + llvm.return +} + +// ----- + +llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rounding(%a : f32, %b : f32) { + // expected-error @below {{Only RN, RZ, and RS rounding modes are supported for conversions from f32x2 to bf16x2.}} + %res = nvvm.convert.f32x2.to.bf16x2 %a, %b {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> + llvm.return +} + +// ----- + +llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rbits_1(%a : f32, %b : f32) { + // expected-error @below {{rbits is required for RS rounding mode.}} + %res = nvvm.convert.f32x2.to.bf16x2 %a, %b {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> + llvm.return +} + +llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rbits_2(%a : f32, %b : f32, %rbits : i32) { + // expected-error @below {{rbits not supported for RN and RZ rounding modes.}} + %res = nvvm.convert.f32x2.to.bf16x2 %a, %b, %rbits {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir index b6e175ee9789f..d5868ee73cc50 100644 --- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir @@ -174,52 +174,6 @@ llvm.func @nvvm_match_sync_any(%val32: i32, %thread_mask: i32) { // ----- -llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rounding(%a : f32, %b : f32) { - // expected-error @below {{Only RN, RZ, and RS rounding modes are supported for conversions from f32x2 to f16x2.}} - %res = nvvm.convert.f32x2.to.f16x2 %a, %b {rnd = #nvvm.fp_rnd_mode} : vector<2xf16> - llvm.return -} - -// ----- - -llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rbits_1(%a : f32, %b : f32) { - // expected-error @below {{rbits is required for RS rounding mode.}} - %res = nvvm.convert.f32x2.to.f16x2 %a, %b {rnd = #nvvm.fp_rnd_mode} : vector<2xf16> - llvm.return -} - -// ----- - -llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rbits_2(%a : f32, %b : f32, %rbits : i32) { - // expected-error @below {{rbits not supported for RN and RZ rounding modes.}} - %res = nvvm.convert.f32x2.to.f16x2 %a, %b, %rbits {rnd = #nvvm.fp_rnd_mode} : vector<2xf16> - llvm.return -} - -// ----- - -llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rounding(%a : f32, %b : f32) { - // expected-error @below {{Only RN, RZ, and RS rounding modes are supported for conversions from f32x2 to bf16x2.}} - %res = nvvm.convert.f32x2.to.bf16x2 %a, %b {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> - llvm.return -} - -// ----- - -llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rbits_1(%a : f32, %b : f32) { - // expected-error @below {{rbits is required for RS rounding mode.}} - %res = nvvm.convert.f32x2.to.bf16x2 %a, %b {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> - llvm.return -} - -llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rbits_2(%a : f32, %b : f32, %rbits : i32) { - // expected-error @below {{rbits not supported for RN and RZ rounding modes.}} - %res = nvvm.convert.f32x2.to.bf16x2 %a, %b, %rbits {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> - llvm.return -} - -// ----- - llvm.func @nvvm_cvt_float_to_f8x2_invalid_rounding_e4m3(%a : f32, %b : f32) { // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to 'f8E4M3FN' and 'f8E5M2' types}} %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 (f8E4M3FN) From 04d38d05524bf795750ddd2b9ec8a2fa024d84fb Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Fri, 21 Nov 2025 11:25:18 +0000 Subject: [PATCH 3/7] address comments --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 12 +++-- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 54 ++++++++++----------- 2 files changed, 34 insertions(+), 32 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 7a2cfb1fee5eb..7f42122587a1e 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1915,7 +1915,8 @@ def NVVM_ConvertF4x2ToF16x2Op : class NVVM_ConvertF32x2ToFPx2OpBase : NVVM_Op, Results<(outs dstType:$dst)>, - Arguments<(ins F32:$src_hi, F32:$src_lo, Optional:$rbits, + Arguments<(ins F32:$src_hi, F32:$src_lo, + Optional:$random_bits, DefaultValuedAttr:$rnd, DefaultValuedAttr:$sat, DefaultValuedAttr:$relu)> { @@ -1925,7 +1926,12 @@ class NVVM_ConvertF32x2ToFPx2OpBaseemitOpError( + "random_bits not supported for RN and RZ rounding modes."); break; case FPRoundingMode::RS: - if (!getRbits()) - return emitOpError("rbits is required for RS rounding mode."); + if (!hasRandomBits) + return op->emitOpError("random_bits is required for RS rounding mode."); break; default: - return emitOpError("Only RN, RZ, and RS rounding modes are supported for " - "conversions from f32x2 to f16x2."); + return op->emitOpError( + "Only RN, RZ, and RS rounding modes are supported for " + "conversions from f32x2 to ") + << dstType << "."; } - return success(); } -LogicalResult ConvertF32x2ToBF16x2Op::verify() { - switch (getRnd()) { - case FPRoundingMode::RN: - case FPRoundingMode::RZ: - if (getRbits()) - return emitOpError("rbits not supported for RN and RZ rounding modes."); - break; - case FPRoundingMode::RS: - if (!getRbits()) - return emitOpError("rbits is required for RS rounding mode."); - break; - default: - return emitOpError("Only RN, RZ, and RS rounding modes are supported for " - "conversions from f32x2 to bf16x2."); - } +LogicalResult ConvertF32x2ToF16x2Op::verify() { + return verifyConvertF32x2ToFPx2Op("f16x2", getRnd(), getRandomBits(), *this); +} - return success(); +LogicalResult ConvertF32x2ToBF16x2Op::verify() { + return verifyConvertF32x2ToFPx2Op("bf16x2", getRnd(), getRandomBits(), *this); } LogicalResult ConvertF32x4ToF8x4Op::verify() { @@ -2774,8 +2768,9 @@ ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op, llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite, }; - bool hasRelu = op.getRelu(); - bool hasSatFinite = (op.getSat() == NVVM::SaturationMode::SATFINITE); + unsigned hasRelu = op.getRelu() ? 1 : 0; + unsigned hasSatFinite = + (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0; // idx: bit-0 - relu // bit-1 - satfinite unsigned idx = (hasSatFinite << 1) | hasRelu; @@ -2821,8 +2816,9 @@ ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op, llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite, }; - bool hasRelu = op.getRelu(); - bool hasSatFinite = (op.getSat() == NVVM::SaturationMode::SATFINITE); + unsigned hasRelu = op.getRelu() ? 1 : 0; + unsigned hasSatFinite = + (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0; // idx: bit-0 - relu // bit-1 - satfinite unsigned idx = (hasSatFinite << 1) | hasRelu; From 4f8cb48a702e31e17471ca8f0f00feeb41a99793 Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Fri, 21 Nov 2025 11:27:38 +0000 Subject: [PATCH 4/7] fix comment --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 7f42122587a1e..6d919a743d15b 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1958,10 +1958,10 @@ class NVVM_ConvertF32x2ToFPx2OpBase F16x2 with stochastic rounding +// F32x2 -> F16x2 def NVVM_ConvertF32x2ToF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"F16x2", "convert.f32x2.to.f16x2", VectorOfLengthAndType<[2], [F16]>>; -// F32x2 -> BF16x2 with stochastic rounding +// F32x2 -> BF16x2 def NVVM_ConvertF32x2ToBF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"BF16x2", "convert.f32x2.to.bf16x2", VectorOfLengthAndType<[2], [BF16]>>; // Base class for stochastic rounding conversions from F32x4 to FPx4 formats From 9b6595946693240010169617de91d7eed87d6a47 Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Fri, 21 Nov 2025 12:11:34 +0000 Subject: [PATCH 5/7] fix errors --- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index a458219a2c634..47db013ac5659 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -390,10 +390,10 @@ LogicalResult ConvertF4x2ToF16x2Op::verify() { // Stochastic Rounding Conversion Ops //===----------------------------------------------------------------------===// -static LogicalResult verifyConvertF32x2ToFPx2Op(Twine dstType, - FPRoundingMode rnd, - bool hasRandomBits, - Operation *op) { +static LogicalResult verifyConvertF32x2ToFP16x2Op(Twine dstType, + FPRoundingMode rnd, + bool hasRandomBits, + Operation *op) { switch (rnd) { case FPRoundingMode::RN: case FPRoundingMode::RZ: @@ -415,11 +415,13 @@ static LogicalResult verifyConvertF32x2ToFPx2Op(Twine dstType, } LogicalResult ConvertF32x2ToF16x2Op::verify() { - return verifyConvertF32x2ToFPx2Op("f16x2", getRnd(), getRandomBits(), *this); + return verifyConvertF32x2ToFP16x2Op("f16x2", getRnd(), + getRandomBits() ? true : false, *this); } LogicalResult ConvertF32x2ToBF16x2Op::verify() { - return verifyConvertF32x2ToFPx2Op("bf16x2", getRnd(), getRandomBits(), *this); + return verifyConvertF32x2ToFP16x2Op("bf16x2", getRnd(), + getRandomBits() ? true : false, *this); } LogicalResult ConvertF32x4ToF8x4Op::verify() { @@ -2778,8 +2780,8 @@ ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op, llvm::SmallVector args; args.push_back(mt.lookupValue(op.getSrcHi())); args.push_back(mt.lookupValue(op.getSrcLo())); - if (op.getRbits()) - args.push_back(mt.lookupValue(op.getRbits())); + if (op.getRandomBits()) + args.push_back(mt.lookupValue(op.getRandomBits())); switch (op.getRnd()) { case FPRoundingMode::RN: @@ -2826,8 +2828,8 @@ ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op, llvm::SmallVector args; args.push_back(mt.lookupValue(op.getSrcHi())); args.push_back(mt.lookupValue(op.getSrcLo())); - if (op.getRbits()) - args.push_back(mt.lookupValue(op.getRbits())); + if (op.getRandomBits()) + args.push_back(mt.lookupValue(op.getRandomBits())); switch (op.getRnd()) { case FPRoundingMode::RN: From fb0d052bac37ca80f6211e4c7ff29c2820e722a2 Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Fri, 21 Nov 2025 12:24:08 +0000 Subject: [PATCH 6/7] update test --- mlir/test/Target/LLVMIR/nvvm/invalid_convert_fp16x2.mlir | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/test/Target/LLVMIR/nvvm/invalid_convert_fp16x2.mlir b/mlir/test/Target/LLVMIR/nvvm/invalid_convert_fp16x2.mlir index 60571944e7bf8..37756c8737f11 100644 --- a/mlir/test/Target/LLVMIR/nvvm/invalid_convert_fp16x2.mlir +++ b/mlir/test/Target/LLVMIR/nvvm/invalid_convert_fp16x2.mlir @@ -11,7 +11,7 @@ llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rounding(%a : f32, %b : f32) { // ----- llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rbits_1(%a : f32, %b : f32) { - // expected-error @below {{rbits is required for RS rounding mode.}} + // expected-error @below {{random_bits is required for RS rounding mode.}} %res = nvvm.convert.f32x2.to.f16x2 %a, %b {rnd = #nvvm.fp_rnd_mode} : vector<2xf16> llvm.return } @@ -19,7 +19,7 @@ llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rbits_1(%a : f32, %b : f32) { // ----- llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rbits_2(%a : f32, %b : f32, %rbits : i32) { - // expected-error @below {{rbits not supported for RN and RZ rounding modes.}} + // expected-error @below {{random_bits not supported for RN and RZ rounding modes.}} %res = nvvm.convert.f32x2.to.f16x2 %a, %b, %rbits {rnd = #nvvm.fp_rnd_mode} : vector<2xf16> llvm.return } @@ -35,13 +35,13 @@ llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rounding(%a : f32, %b : f32) { // ----- llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rbits_1(%a : f32, %b : f32) { - // expected-error @below {{rbits is required for RS rounding mode.}} + // expected-error @below {{random_bits is required for RS rounding mode.}} %res = nvvm.convert.f32x2.to.bf16x2 %a, %b {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> llvm.return } llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rbits_2(%a : f32, %b : f32, %rbits : i32) { - // expected-error @below {{rbits not supported for RN and RZ rounding modes.}} + // expected-error @below {{random_bits not supported for RN and RZ rounding modes.}} %res = nvvm.convert.f32x2.to.bf16x2 %a, %b, %rbits {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> llvm.return } From 1139e5a802a5de9a216ad64f9ebbdef4da2df97a Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Fri, 21 Nov 2025 13:41:43 +0000 Subject: [PATCH 7/7] update verifier --- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 28 ++++++++++++---------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 47db013ac5659..9cce91da654f7 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -394,23 +394,27 @@ static LogicalResult verifyConvertF32x2ToFP16x2Op(Twine dstType, FPRoundingMode rnd, bool hasRandomBits, Operation *op) { - switch (rnd) { - case FPRoundingMode::RN: - case FPRoundingMode::RZ: - if (hasRandomBits) - return op->emitOpError( - "random_bits not supported for RN and RZ rounding modes."); - break; - case FPRoundingMode::RS: - if (!hasRandomBits) - return op->emitOpError("random_bits is required for RS rounding mode."); - break; - default: + static constexpr FPRoundingMode validRndModes[] = { + FPRoundingMode::RN, FPRoundingMode::RZ, FPRoundingMode::RS}; + + if (!llvm::is_contained(validRndModes, rnd)) { return op->emitOpError( "Only RN, RZ, and RS rounding modes are supported for " "conversions from f32x2 to ") << dstType << "."; } + + if (rnd == FPRoundingMode::RS) { + if (!hasRandomBits) { + return op->emitOpError("random_bits is required for RS rounding mode."); + } + } else { + if (hasRandomBits) { + return op->emitOpError( + "random_bits not supported for RN and RZ rounding modes."); + } + } + return success(); }