From 1837100b7b12a85cbadd8f69a197484b57737533 Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Mon, 15 Sep 2025 15:11:16 +0530 Subject: [PATCH 1/4] [MLIR][NVVM] Update convert Ops to use builtin types This change updates the `convert.f32x2.to.f6x2`, `convert.f32x2.to.f8x2`, `convert.f16x2.to.f8x2`, and `convert.bf16x2.to.f8x2` Ops to use builtin types for the destination types as a `TypeAttr` instead of custom enums. The corresponding tests are updated to reflect the changes in the assembly format. --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 63 +++----- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 148 +++++++++++------- .../Target/LLVMIR/nvvm/convert_fp6x2.mlir | 8 +- .../Target/LLVMIR/nvvm/convert_fp8x2.mlir | 44 +++--- mlir/test/Target/LLVMIR/nvvmir-invalid.mlir | 34 ++-- 5 files changed, 153 insertions(+), 144 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 8537c7030aa8f..c540c5ccf50bf 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -21,6 +21,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td" include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/Dialect/LLVMIR/LLVMTypes.td" +include "mlir/IR/CommonAttrConstraints.td" def LLVM_PointerGeneric : LLVM_PointerInAddressSpace<0>; def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>; @@ -1258,18 +1259,6 @@ def NVVM_ConvertFloatToTF32Op : NVVM_Op<"convert.float.to.tf32"> { }]; } -def ConvertFP6E2M3 : I32EnumAttrCase<"E2M3", 0, "e2m3">; -def ConvertFP6E3M2 : I32EnumAttrCase<"E3M2", 1, "e3m2">; - -def ConvertFP6Type : I32EnumAttr<"ConvertFP6Type", "NVVM ConvertFP6Type kind", - [ConvertFP6E2M3, ConvertFP6E3M2]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::NVVM"; -} -def ConvertFP6TypeAttr : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> { let summary = "Convert a pair of float inputs to f6x2"; let description = [{ @@ -1290,19 +1279,20 @@ def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> { let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst); let arguments = (ins - ConvertFP6TypeAttr:$type, F32:$a, F32:$b, - DefaultValuedAttr:$relu); - let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)"; + DefaultValuedAttr:$relu, + TypeAttr:$dstTy); + let assemblyFormat = "$a `,` $b attr-dict `:` type($dst) `(` $dstTy `)`"; + let hasVerifier = 1; let extraClassDeclaration = [{ - static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP6Type, + static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy, bool hasRelu); }]; string llvmBuilder = [{ - auto intId = NVVM::ConvertF32x2ToF6x2Op::getIntrinsicID($type, $relu); + auto intId = NVVM::ConvertF32x2ToF6x2Op::getIntrinsicID($dstTy, $relu); llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b}); if(op.getDst().getType().isInteger(16)) $dst = packedI16; @@ -1312,19 +1302,6 @@ def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> { }]; } -def ConvertFP8E4M3 : I32EnumAttrCase<"E4M3", 0, "e4m3">; -def ConvertFP8E5M2 : I32EnumAttrCase<"E5M2", 1, "e5m2">; -def ConvertFP8UE8M0 : I32EnumAttrCase<"UE8M0", 2, "ue8m0">; - -def ConvertFP8Type : I32EnumAttr<"ConvertFP8Type", "NVVM ConvertFP8Type kind", - [ConvertFP8E4M3, ConvertFP8E5M2, ConvertFP8UE8M0]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::NVVM"; -} -def ConvertFP8TypeAttr : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - def NVVM_ConvertF32x2ToF8x2Op : NVVM_Op<"convert.f32x2.to.f8x2"> { let summary = "Convert a pair of float inputs to f8x2"; let description = [{ @@ -1346,23 +1323,23 @@ def NVVM_ConvertF32x2ToF8x2Op : NVVM_Op<"convert.f32x2.to.f8x2"> { let hasVerifier = 1; let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst); let arguments = (ins - ConvertFP8TypeAttr:$type, F32:$a, F32:$b, DefaultValuedAttr:$rnd, DefaultValuedAttr:$sat, - DefaultValuedAttr:$relu); - let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)"; + DefaultValuedAttr:$relu, + TypeAttr:$dstTy); + let assemblyFormat = "$a `,` $b attr-dict `:` type($dst) `(` $dstTy `)`"; let extraClassDeclaration = [{ - static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP8Type to, + static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd, NVVM::SaturationMode sat, bool hasRelu); }]; string llvmBuilder = [{ - auto intId = NVVM::ConvertF32x2ToF8x2Op::getIntrinsicID($type, $rnd, $sat, $relu); + auto intId = NVVM::ConvertF32x2ToF8x2Op::getIntrinsicID($dstTy, $rnd, $sat, $relu); llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b}); if(op.getDst().getType().isInteger(16)) $dst = packedI16; @@ -1394,18 +1371,18 @@ def NVVM_ConvertF16x2ToF8x2Op : NVVM_Op<"convert.f16x2.to.f8x2"> { let hasVerifier = 1; let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst); let arguments = (ins - ConvertFP8TypeAttr:$type, VectorOfLengthAndType<[2], [F16]>:$a, - DefaultValuedAttr:$relu); - let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)"; + DefaultValuedAttr:$relu, + TypeAttr:$dstTy); + let assemblyFormat = "$a attr-dict `:` type($a) `->` type($dst) `(` $dstTy `)`"; let extraClassDeclaration = [{ - static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP8Type to, + static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy, bool hasRelu); }]; string llvmBuilder = [{ - auto intId = NVVM::ConvertF16x2ToF8x2Op::getIntrinsicID($type, $relu); + auto intId = NVVM::ConvertF16x2ToF8x2Op::getIntrinsicID($dstTy, $relu); llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a}); if(op.getDst().getType().isInteger(16)) $dst = packedI16; @@ -1437,11 +1414,11 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> { let hasVerifier = 1; let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst); let arguments = (ins - ConvertFP8TypeAttr:$type, VectorOfLengthAndType<[2], [BF16]>:$a, DefaultValuedAttr:$rnd, - DefaultValuedAttr:$sat); - let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)"; + DefaultValuedAttr:$sat, + TypeAttr:$dstTy); + let assemblyFormat = "$a attr-dict `:` type($a) `->` type($dst) `(` $dstTy `)`"; let extraClassDeclaration = [{ static llvm::Intrinsic::ID getIntrinsicID(NVVM::FPRoundingMode rnd, diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 77ec1ebde3109..28fa3f2a098e0 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -189,6 +189,14 @@ LogicalResult ConvertFloatToTF32Op::verify() { return success(); } +LogicalResult ConvertF32x2ToF6x2Op::verify() { + if (!llvm::isa(getDstTy())) { + return emitError("Only f6E2M3FN and f6E3M2FN types are supported for " + "ConvertF32x2ToF6x2Op."); + } + return success(); +} + LogicalResult ConvertF32x2ToF8x2Op::verify() { using RndMode = NVVM::FPRoundingMode; using SatMode = NVVM::SaturationMode; @@ -200,41 +208,52 @@ LogicalResult ConvertF32x2ToF8x2Op::verify() { bool hasRelu = getRelu(); - switch (getType()) { - case ConvertFP8Type::E4M3: - case ConvertFP8Type::E5M2: - if (!isRoundingModeRN) - return emitOpError("Only RN rounding mode is supported for conversions " - "from f32x2 to .e4m3x2 or .e5m2x2 types"); - if (!isSatFinite) - return emitOpError("Only SATFINITE saturation mode is supported for " - "conversions from f32x2 to .e4m3x2 or .e5m2x2 types"); - break; - case ConvertFP8Type::UE8M0: - if (!(isRoundingModeRZ || isRoundingModeRP)) - return emitOpError("Only RZ or RP rounding modes are supported for " - "conversions from f32x2 to .ue8m0x2 type"); - if (hasRelu) - return emitOpError("relu not supported for conversions to .ue8m0x2 type"); - break; - } - return success(); + return llvm::TypeSwitch(getDstTy()) + .Case( + [&](mlir::Type) -> LogicalResult { + if (!isRoundingModeRN) { + return emitOpError( + "Only RN rounding mode is supported for conversions from " + "f32x2 to f8E4M3FNx2 or f8E5M2x2 types"); + } + if (!isSatFinite) { + return emitOpError( + "Only SATFINITE saturation mode is supported for conversions " + "from f32x2 to f8E4M3FNx2 or f8E5M2x2 types"); + } + return success(); + }) + .Case([&](mlir::Type) -> LogicalResult { + if (!(isRoundingModeRZ || isRoundingModeRP)) { + return emitOpError("Only RZ or RP rounding modes are supported for " + "conversions from f32x2 to f8E8M0FNUx2 type"); + } + if (hasRelu) { + return emitOpError( + "relu not supported for conversions to f8E8M0FNUx2 type"); + } + return success(); + }) + .Default([this](mlir::Type) { + return emitOpError("Only f8e4m3fn, f8e5m2, and f8e8m0fnu types are " + "supported for conversions from f32x2 to f8x2"); + }); } LogicalResult ConvertF16x2ToF8x2Op::verify() { - if (getType() == ConvertFP8Type::UE8M0) - return emitOpError("Only .e4m3 or .e5m2 types are supported for " + if (!llvm::isa(getDstTy())) { + return emitOpError("Only f8E4M3FN or f8E5M2 types are supported for " "conversions from f16x2 to f8x2."); - + } return success(); } LogicalResult ConvertBF16x2ToF8x2Op::verify() { using RndMode = NVVM::FPRoundingMode; - if (getType() != ConvertFP8Type::UE8M0) - return emitOpError( - "Only .ue8m0 type is supported for conversions from bf16x2 to f8x2."); + if (!llvm::isa(getDstTy())) + return emitOpError("Only f8E8M0FNU type is supported for conversions from " + "bf16x2 to f8x2."); auto rnd = getRnd(); if (!(rnd == RndMode::RZ || rnd == RndMode::RP)) @@ -1714,15 +1733,19 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd, has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \ : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite -llvm::Intrinsic::ID -ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) { - switch (type) { - case NVVM::ConvertFP6Type::E2M3: - return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu); - case NVVM::ConvertFP6Type::E3M2: - return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu); - } - llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op"); +llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy, + bool hasRelu) { + return llvm::TypeSwitch(dstTy) + .Case([&](mlir::Float6E2M3FNType) { + return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu); + }) + .Case([&](mlir::Float6E3M2FNType) { + return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu); + }) + .Default([](mlir::Type) { + llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op"); + return llvm::Intrinsic::not_intrinsic; + }); } #define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \ @@ -1734,41 +1757,50 @@ ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) { : llvm::Intrinsic::nvvm_ff_to_##type##_rn llvm::Intrinsic::ID -ConvertF32x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type, - NVVM::FPRoundingMode rnd, +ConvertF32x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd, NVVM::SaturationMode sat, bool hasRelu) { bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE); bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ); bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP); - switch (type) { - case NVVM::ConvertFP8Type::E4M3: - return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu); - case NVVM::ConvertFP8Type::E5M2: - return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu); - case NVVM::ConvertFP8Type::UE8M0: - if (hasRoundingModeRZ) - return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite); - else if (hasRoundingModeRP) - return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite); - } - llvm_unreachable("Invalid conversion in CvtFloatToF8x2Op"); + return llvm::TypeSwitch(dstTy) + .Case([&](mlir::Float8E4M3FNType) { + return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu); + }) + .Case([&](mlir::Float8E5M2Type) { + return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu); + }) + .Case([&](mlir::Float8E8M0FNUType) { + if (hasRoundingModeRZ) + return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite); + else if (hasRoundingModeRP) + return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite); + + llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op"); + }) + .Default([](mlir::Type) { + llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op"); + return llvm::Intrinsic::not_intrinsic; + }); } #define GET_F16x2_TO_F8X2_ID(type, has_relu) \ has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \ : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn -llvm::Intrinsic::ID -ConvertF16x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type, bool hasRelu) { - switch (type) { - case NVVM::ConvertFP8Type::E4M3: - return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu); - case NVVM::ConvertFP8Type::E5M2: - return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu); - default: - llvm_unreachable("Invalid ConvertFP8Type for CvtF16x2ToF8x2Op"); - } +llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, + bool hasRelu) { + return llvm::TypeSwitch(dstTy) + .Case([&](mlir::Float8E4M3FNType) { + return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu); + }) + .Case([&](mlir::Float8E5M2Type) { + return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu); + }) + .Default([](mlir::Type) { + llvm_unreachable("Invalid conversion in ConvertF16x2ToF8x2Op"); + return llvm::Intrinsic::not_intrinsic; + }); } #define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \ diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir index 04163b578aa02..99289923b58b1 100644 --- a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir +++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir @@ -3,9 +3,9 @@ // CHECK-LABEL: @convert_f32x2_to_fp6x2_packed llvm.func @convert_f32x2_to_fp6x2_packed(%srcA : f32, %srcB : f32) { //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}}) - %res1 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 + %res1 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 (f6E2M3FN) //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}}) - %res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 + %res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 (f6E3M2FN) llvm.return } @@ -13,9 +13,9 @@ llvm.func @convert_f32x2_to_fp6x2_packed(%srcA : f32, %srcB : f32) { llvm.func @convert_f32x2_to_fp6x2_vector(%srcA : f32, %srcB : f32) { //CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}}) //CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8> - %res1 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : vector<2xi8> + %res1 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : vector<2xi8> (f6E2M3FN) //CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}}) //CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8> - %res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : vector<2xi8> + %res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : vector<2xi8> (f6E3M2FN) llvm.return } diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir index 4a15efb9e805c..de21826445afb 100644 --- a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir +++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir @@ -5,31 +5,31 @@ // CHECK-LABEL: @convert_f32x2_to_f8x2_e4m3 llvm.func @convert_f32x2_to_f8x2_e4m3(%srcA : f32, %srcB : f32) { // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e4m3x2.rn(float %{{.*}}, float %{{.*}}) - %res1 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 + %res1 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 (f8E4M3FN) // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e4m3x2.rn.relu(float %{{.*}}, float %{{.*}}) - %res2 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 + %res2 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 (f8E4M3FN) llvm.return } // CHECK-LABEL: @convert_f32x2_to_f8x2_e5m2 llvm.func @convert_f32x2_to_f8x2_e5m2(%srcA : f32, %srcB : f32) { // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e5m2x2.rn(float %{{.*}}, float %{{.*}}) - %res1 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 + %res1 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 (f8E5M2) // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e5m2x2.rn.relu(float %{{.*}}, float %{{.*}}) - %res2 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 + %res2 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 (f8E5M2) llvm.return } // CHECK-LABEL: @convert_f32x2_to_f8x2_ue8m0 llvm.func @convert_f32x2_to_f8x2_ue8m0(%srcA : f32, %srcB : f32) { // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rz(float %{{.*}}, float %{{.*}}) - %res1 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode} : i16 + %res1 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode} : i16 (f8E8M0FNU) // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rp(float %{{.*}}, float %{{.*}}) - %res2 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode} : i16 + %res2 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode} : i16 (f8E8M0FNU) // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rz.satfinite(float %{{.*}}, float %{{.*}}) - %res3 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 + %res3 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 (f8E8M0FNU) // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rp.satfinite(float %{{.*}}, float %{{.*}}) - %res4 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 + %res4 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 (f8E8M0FNU) llvm.return } @@ -37,10 +37,10 @@ llvm.func @convert_f32x2_to_f8x2_ue8m0(%srcA : f32, %srcB : f32) { llvm.func @convert_f32x2_to_f8x2_vector_return(%srcA : f32, %srcB : f32) { // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e4m3x2.rn(float %{{.*}}, float %{{.*}}) // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8> - %res1 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xi8> + %res1 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xi8> (f8E4M3FN) // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.ff.to.e4m3x2.rn.relu(float %{{.*}}, float %{{.*}}) // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8> - %res2 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xi8> + %res2 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xi8> (f8E4M3FN) llvm.return } @@ -49,18 +49,18 @@ llvm.func @convert_f32x2_to_f8x2_vector_return(%srcA : f32, %srcB : f32) { // CHECK-LABEL: @convert_f16x2_to_f8x2_e4m3 llvm.func @convert_f16x2_to_f8x2_e4m3(%src : vector<2xf16>) { // CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn(<2 x half> %{{.*}}) - %res1 = nvvm.convert.f16x2.to.f8x2 %src : vector<2xf16> -> i16 + %res1 = nvvm.convert.f16x2.to.f8x2 %src : vector<2xf16> -> i16 (f8E4M3FN) // CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn.relu(<2 x half> %{{.*}}) - %res2 = nvvm.convert.f16x2.to.f8x2 %src {relu = true} : vector<2xf16> -> i16 + %res2 = nvvm.convert.f16x2.to.f8x2 %src {relu = true} : vector<2xf16> -> i16 (f8E4M3FN) llvm.return } // CHECK-LABEL: @convert_f16x2_to_f8x2_e5m2 llvm.func @convert_f16x2_to_f8x2_e5m2(%src : vector<2xf16>) { // CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn(<2 x half> %{{.*}}) - %res1 = nvvm.convert.f16x2.to.f8x2 %src : vector<2xf16> -> i16 + %res1 = nvvm.convert.f16x2.to.f8x2 %src : vector<2xf16> -> i16 (f8E5M2) // CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn.relu(<2 x half> %{{.*}}) - %res2 = nvvm.convert.f16x2.to.f8x2 %src {relu = true} : vector<2xf16> -> i16 + %res2 = nvvm.convert.f16x2.to.f8x2 %src {relu = true} : vector<2xf16> -> i16 (f8E5M2) llvm.return } @@ -68,10 +68,10 @@ llvm.func @convert_f16x2_to_f8x2_e5m2(%src : vector<2xf16>) { llvm.func @convert_f16x2_to_f8x2_vector_return(%src : vector<2xf16>) { // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn(<2 x half> %{{.*}}) // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8> - %res1 = nvvm.convert.f16x2.to.f8x2 %src : vector<2xf16> -> vector<2xi8> + %res1 = nvvm.convert.f16x2.to.f8x2 %src : vector<2xf16> -> vector<2xi8> (f8E4M3FN) // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn(<2 x half> %{{.*}}) // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8> - %res2 = nvvm.convert.f16x2.to.f8x2 %src : vector<2xf16> -> vector<2xi8> + %res2 = nvvm.convert.f16x2.to.f8x2 %src : vector<2xf16> -> vector<2xi8> (f8E5M2) llvm.return } @@ -80,13 +80,13 @@ llvm.func @convert_f16x2_to_f8x2_vector_return(%src : vector<2xf16>) { // CHECK-LABEL: @convert_bf16x2_to_f8x2_ue8m0 llvm.func @convert_bf16x2_to_f8x2_ue8m0(%src : vector<2xbf16>) { // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rz(<2 x bfloat> %{{.*}}) - %res1 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> i16 + %res1 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> i16 (f8E8M0FNU) // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rp(<2 x bfloat> %{{.*}}) - %res2 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> i16 + %res2 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> i16 (f8E8M0FNU) // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rz.satfinite(<2 x bfloat> %{{.*}}) - %res3 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xbf16> -> i16 + %res3 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xbf16> -> i16 (f8E8M0FNU) // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rp.satfinite(<2 x bfloat> %{{.*}}) - %res4 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xbf16> -> i16 + %res4 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xbf16> -> i16 (f8E8M0FNU) llvm.return } @@ -94,9 +94,9 @@ llvm.func @convert_bf16x2_to_f8x2_ue8m0(%src : vector<2xbf16>) { llvm.func @convert_bf16x2_to_f8x2_vector_return(%src : vector<2xbf16>) { // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rz(<2 x bfloat> %{{.*}}) // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8> - %res1 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> vector<2xi8> + %res1 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> vector<2xi8> (f8E8M0FNU) // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rp.satfinite(<2 x bfloat> %{{.*}}) // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8> - %res2 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xbf16> -> vector<2xi8> + %res2 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xbf16> -> vector<2xi8> (f8E8M0FNU) llvm.return } diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir index b35a6dbcca286..8d4a32095c396 100644 --- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir @@ -175,64 +175,64 @@ llvm.func @nvvm_match_sync_any(%val32: i32, %thread_mask: i32) { // ----- llvm.func @nvvm_cvt_float_to_f8x2_invalid_rounding_e4m3(%a : f32, %b : f32) { - // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to .e4m3x2 or .e5m2x2 types}} - %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 + // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to f8E4M3FNx2 or f8E5M2x2 types}} + %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 (f8E4M3FN) llvm.return } // ----- llvm.func @nvvm_cvt_float_to_f8x2_invalid_rounding_e5m2(%a : f32, %b : f32) { - // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to .e4m3x2 or .e5m2x2 types}} - %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 + // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to f8E4M3FNx2 or f8E5M2x2 types}} + %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 (f8E5M2) llvm.return } // ----- llvm.func @nvvm_cvt_float_to_f8x2_invalid_rounding_ue8m0(%a : f32, %b : f32) { - // expected-error @below {{Only RZ or RP rounding modes are supported for conversions from f32x2 to .ue8m0x2 type}} - %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode} : i16 + // expected-error @below {{Only RZ or RP rounding modes are supported for conversions from f32x2 to f8E8M0FNUx2 type}} + %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode} : i16 (f8E8M0FNU) llvm.return } // ----- llvm.func @nvvm_cvt_float_to_f8x2_invalid_saturation_e4m3(%a : f32, %b : f32) { - // expected-error @below {{Only SATFINITE saturation mode is supported for conversions from f32x2 to .e4m3x2 or .e5m2x2 types}} - %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 + // expected-error @below {{Only SATFINITE saturation mode is supported for conversions from f32x2 to f8E4M3FNx2 or f8E5M2x2 types}} + %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 (f8E4M3FN) llvm.return } // ----- llvm.func @nvvm_cvt_float_to_f8x2_invalid_saturation_e5m2(%a : f32, %b : f32) { - // expected-error @below {{Only SATFINITE saturation mode is supported for conversions from f32x2 to .e4m3x2 or .e5m2x2 types}} - %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 + // expected-error @below {{Only SATFINITE saturation mode is supported for conversions from f32x2 to f8E4M3FNx2 or f8E5M2x2 types}} + %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 (f8E5M2) llvm.return } // ----- llvm.func @nvvm_cvt_float_to_f8x2_relu_not_supported_ue8m0(%a : f32, %b : f32) { - // expected-error @below {{relu not supported for conversions to .ue8m0x2 type}} - %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, relu = true} : i16 + // expected-error @below {{relu not supported for conversions to f8E8M0FNUx2 type}} + %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, relu = true} : i16 (f8E8M0FNU) llvm.return } // ----- llvm.func @nvvm_cvt_f16x2_to_f8x2_invalid_type(%src : vector<2xf16>) { - // expected-error @below {{Only .e4m3 or .e5m2 types are supported for conversions from f16x2 to f8x2.}} - %res = nvvm.convert.f16x2.to.f8x2 %src : vector<2xf16> -> i16 + // expected-error @below {{Only f8E4M3FN or f8E5M2 types are supported for conversions from f16x2 to f8x2.}} + %res = nvvm.convert.f16x2.to.f8x2 %src : vector<2xf16> -> i16 (f8E8M0FNU) llvm.return } // ----- llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_type(%src : vector<2xbf16>) { - // expected-error @below {{Only .ue8m0 type is supported for conversions from bf16x2 to f8x2.}} - %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> i16 + // expected-error @below {{Only f8E8M0FNU type is supported for conversions from bf16x2 to f8x2.}} + %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> i16 (f8E4M3FN) llvm.return } @@ -240,7 +240,7 @@ llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_type(%src : vector<2xbf16>) { llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding(%src : vector<2xbf16>) { // expected-error @below {{Only RZ and RP rounding modes are supported for conversions from bf16x2 to f8x2.}} - %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> i16 + %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> i16 (f8E8M0FNU) llvm.return } From c35df910f5f9f32e6f3075785beac42aab562654 Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Tue, 23 Sep 2025 13:17:36 +0530 Subject: [PATCH 2/4] fix error messages and use get methods for type names --- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 49 +++++++++++++-------- mlir/test/Target/LLVMIR/nvvmir-invalid.mlir | 24 ++++++---- 2 files changed, 47 insertions(+), 26 deletions(-) diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 28fa3f2a098e0..bedc4e8e40e50 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -191,8 +191,10 @@ LogicalResult ConvertFloatToTF32Op::verify() { LogicalResult ConvertF32x2ToF6x2Op::verify() { if (!llvm::isa(getDstTy())) { - return emitError("Only f6E2M3FN and f6E3M2FN types are supported for " - "ConvertF32x2ToF6x2Op."); + return emitOpError("Only ") + << mlir::Float6E2M3FNType::get(getContext()) << " and " + << mlir::Float6E3M2FNType::get(getContext()) + << " types are supported for conversions from f32x2 to f6x2."; } return success(); } @@ -212,38 +214,48 @@ LogicalResult ConvertF32x2ToF8x2Op::verify() { .Case( [&](mlir::Type) -> LogicalResult { if (!isRoundingModeRN) { - return emitOpError( - "Only RN rounding mode is supported for conversions from " - "f32x2 to f8E4M3FNx2 or f8E5M2x2 types"); + return emitOpError("Only RN rounding mode is supported for " + "conversions from f32x2 to ") + << mlir::Float8E4M3FNType::get(getContext()) << " and " + << mlir::Float8E5M2Type::get(getContext()) << " types"; } if (!isSatFinite) { - return emitOpError( - "Only SATFINITE saturation mode is supported for conversions " - "from f32x2 to f8E4M3FNx2 or f8E5M2x2 types"); + return emitOpError("Only SATFINITE saturation mode is supported " + "for conversions " + "from f32x2 to ") + << mlir::Float8E4M3FNType::get(getContext()) << " and " + << mlir::Float8E5M2Type::get(getContext()) << " types"; } return success(); }) .Case([&](mlir::Type) -> LogicalResult { if (!(isRoundingModeRZ || isRoundingModeRP)) { - return emitOpError("Only RZ or RP rounding modes are supported for " - "conversions from f32x2 to f8E8M0FNUx2 type"); + return emitOpError("Only RZ and RP rounding modes are supported for " + "conversions from f32x2 to ") + << mlir::Float8E8M0FNUType::get(getContext()) << " type"; } if (hasRelu) { - return emitOpError( - "relu not supported for conversions to f8E8M0FNUx2 type"); + return emitOpError("relu not supported for conversions to ") + << mlir::Float8E8M0FNUType::get(getContext()) << " type"; } return success(); }) .Default([this](mlir::Type) { - return emitOpError("Only f8e4m3fn, f8e5m2, and f8e8m0fnu types are " - "supported for conversions from f32x2 to f8x2"); + return emitOpError("Only ") + << mlir::Float8E4M3FNType::get(getContext()) << ", " + << mlir::Float8E5M2Type::get(getContext()) << ", and " + << mlir::Float8E8M0FNUType::get(getContext()) + << " types are " + "supported for conversions from f32x2 to f8x2"; }); } LogicalResult ConvertF16x2ToF8x2Op::verify() { if (!llvm::isa(getDstTy())) { - return emitOpError("Only f8E4M3FN or f8E5M2 types are supported for " - "conversions from f16x2 to f8x2."); + return emitOpError("Only ") + << mlir::Float8E4M3FNType::get(getContext()) << " and " + << mlir::Float8E5M2Type::get(getContext()) + << " types are supported for conversions from f16x2 to f8x2."; } return success(); } @@ -252,8 +264,9 @@ LogicalResult ConvertBF16x2ToF8x2Op::verify() { using RndMode = NVVM::FPRoundingMode; if (!llvm::isa(getDstTy())) - return emitOpError("Only f8E8M0FNU type is supported for conversions from " - "bf16x2 to f8x2."); + return emitOpError("Only ") << mlir::Float8E8M0FNUType::get(getContext()) + << " type is supported for conversions from " + "bf16x2 to f8x2."; auto rnd = getRnd(); if (!(rnd == RndMode::RZ || rnd == RndMode::RP)) diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir index 8d4a32095c396..15ab66d6c511e 100644 --- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir @@ -175,7 +175,7 @@ llvm.func @nvvm_match_sync_any(%val32: i32, %thread_mask: i32) { // ----- llvm.func @nvvm_cvt_float_to_f8x2_invalid_rounding_e4m3(%a : f32, %b : f32) { - // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to f8E4M3FNx2 or f8E5M2x2 types}} + // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to 'f8E4M3FN' and 'f8E5M2' types}} %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 (f8E4M3FN) llvm.return } @@ -183,7 +183,7 @@ llvm.func @nvvm_cvt_float_to_f8x2_invalid_rounding_e4m3(%a : f32, %b : f32) { // ----- llvm.func @nvvm_cvt_float_to_f8x2_invalid_rounding_e5m2(%a : f32, %b : f32) { - // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to f8E4M3FNx2 or f8E5M2x2 types}} + // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to 'f8E4M3FN' and 'f8E5M2' types}} %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 (f8E5M2) llvm.return } @@ -191,7 +191,7 @@ llvm.func @nvvm_cvt_float_to_f8x2_invalid_rounding_e5m2(%a : f32, %b : f32) { // ----- llvm.func @nvvm_cvt_float_to_f8x2_invalid_rounding_ue8m0(%a : f32, %b : f32) { - // expected-error @below {{Only RZ or RP rounding modes are supported for conversions from f32x2 to f8E8M0FNUx2 type}} + // expected-error @below {{Only RZ and RP rounding modes are supported for conversions from f32x2 to 'f8E8M0FNU' type}} %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode} : i16 (f8E8M0FNU) llvm.return } @@ -199,7 +199,7 @@ llvm.func @nvvm_cvt_float_to_f8x2_invalid_rounding_ue8m0(%a : f32, %b : f32) { // ----- llvm.func @nvvm_cvt_float_to_f8x2_invalid_saturation_e4m3(%a : f32, %b : f32) { - // expected-error @below {{Only SATFINITE saturation mode is supported for conversions from f32x2 to f8E4M3FNx2 or f8E5M2x2 types}} + // expected-error @below {{Only SATFINITE saturation mode is supported for conversions from f32x2 to 'f8E4M3FN' and 'f8E5M2' types}} %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 (f8E4M3FN) llvm.return } @@ -207,7 +207,7 @@ llvm.func @nvvm_cvt_float_to_f8x2_invalid_saturation_e4m3(%a : f32, %b : f32) { // ----- llvm.func @nvvm_cvt_float_to_f8x2_invalid_saturation_e5m2(%a : f32, %b : f32) { - // expected-error @below {{Only SATFINITE saturation mode is supported for conversions from f32x2 to f8E4M3FNx2 or f8E5M2x2 types}} + // expected-error @below {{Only SATFINITE saturation mode is supported for conversions from f32x2 to 'f8E4M3FN' and 'f8E5M2' types}} %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 (f8E5M2) llvm.return } @@ -215,7 +215,7 @@ llvm.func @nvvm_cvt_float_to_f8x2_invalid_saturation_e5m2(%a : f32, %b : f32) { // ----- llvm.func @nvvm_cvt_float_to_f8x2_relu_not_supported_ue8m0(%a : f32, %b : f32) { - // expected-error @below {{relu not supported for conversions to f8E8M0FNUx2 type}} + // expected-error @below {{relu not supported for conversions to 'f8E8M0FNU' type}} %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, relu = true} : i16 (f8E8M0FNU) llvm.return } @@ -223,7 +223,7 @@ llvm.func @nvvm_cvt_float_to_f8x2_relu_not_supported_ue8m0(%a : f32, %b : f32) { // ----- llvm.func @nvvm_cvt_f16x2_to_f8x2_invalid_type(%src : vector<2xf16>) { - // expected-error @below {{Only f8E4M3FN or f8E5M2 types are supported for conversions from f16x2 to f8x2.}} + // expected-error @below {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f16x2 to f8x2.}} %res = nvvm.convert.f16x2.to.f8x2 %src : vector<2xf16> -> i16 (f8E8M0FNU) llvm.return } @@ -231,7 +231,7 @@ llvm.func @nvvm_cvt_f16x2_to_f8x2_invalid_type(%src : vector<2xf16>) { // ----- llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_type(%src : vector<2xbf16>) { - // expected-error @below {{Only f8E8M0FNU type is supported for conversions from bf16x2 to f8x2.}} + // expected-error @below {{Only 'f8E8M0FNU' type is supported for conversions from bf16x2 to f8x2.}} %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> i16 (f8E4M3FN) llvm.return } @@ -246,6 +246,14 @@ llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding(%src : vector<2xbf16>) { // ----- +llvm.func @nvvm_cvt_f32x2_to_f6x2_invalid_type(%a : f32, %b : f32) { + // expected-error @below {{Only 'f6E2M3FN' and 'f6E3M2FN' types are supported for conversions from f32x2 to f6x2.}} + %res = nvvm.convert.f32x2.to.f6x2 %a, %b : i16 (f8E8M0FNU) + llvm.return +} + +// ----- + llvm.func @nvvm_prefetch_L1_with_evict_priority(%global_ptr: !llvm.ptr<1>) { // expected-error @below {{cache eviction priority supported only for cache level L2}} nvvm.prefetch level = L1, evict_priority = evict_last, %global_ptr : !llvm.ptr<1> From 04aef78f52a9fcabd43ce47e4e386e3a96da3999 Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Tue, 23 Sep 2025 13:32:33 +0530 Subject: [PATCH 3/4] clean up --- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 32 +++++++++++++--------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index bedc4e8e40e50..85f6a2d6c19e4 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -190,10 +190,12 @@ LogicalResult ConvertFloatToTF32Op::verify() { } LogicalResult ConvertF32x2ToF6x2Op::verify() { + llvm::LLVMContext &ctx = getContext(); + if (!llvm::isa(getDstTy())) { return emitOpError("Only ") - << mlir::Float6E2M3FNType::get(getContext()) << " and " - << mlir::Float6E3M2FNType::get(getContext()) + << mlir::Float6E2M3FNType::get(ctx) << " and " + << mlir::Float6E3M2FNType::get(ctx) << " types are supported for conversions from f32x2 to f6x2."; } return success(); @@ -210,21 +212,23 @@ LogicalResult ConvertF32x2ToF8x2Op::verify() { bool hasRelu = getRelu(); + llvm::LLVMContext &ctx = getContext(); + return llvm::TypeSwitch(getDstTy()) .Case( [&](mlir::Type) -> LogicalResult { if (!isRoundingModeRN) { return emitOpError("Only RN rounding mode is supported for " "conversions from f32x2 to ") - << mlir::Float8E4M3FNType::get(getContext()) << " and " - << mlir::Float8E5M2Type::get(getContext()) << " types"; + << mlir::Float8E4M3FNType::get(ctx) << " and " + << mlir::Float8E5M2Type::get(ctx) << " types"; } if (!isSatFinite) { return emitOpError("Only SATFINITE saturation mode is supported " "for conversions " "from f32x2 to ") - << mlir::Float8E4M3FNType::get(getContext()) << " and " - << mlir::Float8E5M2Type::get(getContext()) << " types"; + << mlir::Float8E4M3FNType::get(ctx) << " and " + << mlir::Float8E5M2Type::get(ctx) << " types"; } return success(); }) @@ -232,29 +236,31 @@ LogicalResult ConvertF32x2ToF8x2Op::verify() { if (!(isRoundingModeRZ || isRoundingModeRP)) { return emitOpError("Only RZ and RP rounding modes are supported for " "conversions from f32x2 to ") - << mlir::Float8E8M0FNUType::get(getContext()) << " type"; + << mlir::Float8E8M0FNUType::get(ctx) << " type"; } if (hasRelu) { return emitOpError("relu not supported for conversions to ") - << mlir::Float8E8M0FNUType::get(getContext()) << " type"; + << mlir::Float8E8M0FNUType::get(ctx) << " type"; } return success(); }) .Default([this](mlir::Type) { return emitOpError("Only ") - << mlir::Float8E4M3FNType::get(getContext()) << ", " - << mlir::Float8E5M2Type::get(getContext()) << ", and " - << mlir::Float8E8M0FNUType::get(getContext()) + << mlir::Float8E4M3FNType::get(ctx) << ", " + << mlir::Float8E5M2Type::get(ctx) << ", and " + << mlir::Float8E8M0FNUType::get(ctx) << " types are " "supported for conversions from f32x2 to f8x2"; }); } LogicalResult ConvertF16x2ToF8x2Op::verify() { + llvm::LLVMContext &ctx = getContext(); + if (!llvm::isa(getDstTy())) { return emitOpError("Only ") - << mlir::Float8E4M3FNType::get(getContext()) << " and " - << mlir::Float8E5M2Type::get(getContext()) + << mlir::Float8E4M3FNType::get(ctx) << " and " + << mlir::Float8E5M2Type::get(ctx) << " types are supported for conversions from f16x2 to f8x2."; } return success(); From 74aef83a744497e6feb92993e6db0d991e491c41 Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Tue, 23 Sep 2025 13:39:56 +0530 Subject: [PATCH 4/4] fix errors --- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 85f6a2d6c19e4..a04741e0b5ab2 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -190,7 +190,7 @@ LogicalResult ConvertFloatToTF32Op::verify() { } LogicalResult ConvertF32x2ToF6x2Op::verify() { - llvm::LLVMContext &ctx = getContext(); + mlir::MLIRContext *ctx = getContext(); if (!llvm::isa(getDstTy())) { return emitOpError("Only ") @@ -212,7 +212,7 @@ LogicalResult ConvertF32x2ToF8x2Op::verify() { bool hasRelu = getRelu(); - llvm::LLVMContext &ctx = getContext(); + mlir::MLIRContext *ctx = getContext(); return llvm::TypeSwitch(getDstTy()) .Case( @@ -244,7 +244,7 @@ LogicalResult ConvertF32x2ToF8x2Op::verify() { } return success(); }) - .Default([this](mlir::Type) { + .Default([&](mlir::Type) { return emitOpError("Only ") << mlir::Float8E4M3FNType::get(ctx) << ", " << mlir::Float8E5M2Type::get(ctx) << ", and " @@ -255,7 +255,7 @@ LogicalResult ConvertF32x2ToF8x2Op::verify() { } LogicalResult ConvertF16x2ToF8x2Op::verify() { - llvm::LLVMContext &ctx = getContext(); + mlir::MLIRContext *ctx = getContext(); if (!llvm::isa(getDstTy())) { return emitOpError("Only ")