From d2ebae85ca5e44728b8117d1f47c5d512ce960b1 Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Wed, 8 Oct 2025 06:26:01 +0000 Subject: [PATCH 1/6] [MLIR][NVVM] Add support for converting fp4/6/8 to fp16x2 This change adds the following NVVM dialect Ops for converting fp4/6/8 to fp16x2: - convert.f4x2.to.f16x2 - convert.f6x2.to.f16x2 - convert.f8x2.to.f16x2 - convert.f8x2.to.bf16x2 Tests are added in `convert_fp4x2.mlir`, `convert_fp6x2.mlir`, and `convert_fp8x2.mlir`. PTX Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 107 ++++++++++++++ mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 139 ++++++++++++++++++ .../Target/LLVMIR/nvvm/convert_fp4x2.mlir | 14 ++ .../Target/LLVMIR/nvvm/convert_fp6x2.mlir | 24 +++ .../Target/LLVMIR/nvvm/convert_fp8x2.mlir | 34 +++++ 5 files changed, 318 insertions(+) create mode 100644 mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index e2a0331542742..5020af3992173 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1832,6 +1832,113 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> { }]; } +class NVVM_ConvertF8x2ToFP16x2Op_Base +: NVVM_Op<"convert.f8x2.to." # !tolower(dstType) # "x2"> { + let summary = "Convert a pair of f8 inputs to " # !tolower(dstType) # "x2"; + let description = [{ + This Op converts the given f8 inputs in a i8x2 vector to }] # !tolower(dstType) # [{. + + The result `dst` is represented as a vector of }] # !tolower(dstType) # [{ elements. + + The `relu` attribute, when set, lowers to the '.relu' variant of + the cvt instruction. + + For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) + }]; + let results = (outs VectorOfLengthAndType<[2], [!cast(dstType)]>:$dst); + let arguments = !if(!eq(dstType, "F16"), + (ins VectorOfLengthAndType<[2], [I8]>:$src, + DefaultValuedAttr:$relu, + TypeAttr:$srcType), + (ins VectorOfLengthAndType<[2], [I8]>:$src, + TypeAttr:$srcType)); + let assemblyFormat = "$src attr-dict `:` type($src) `(` $srcType `)` `->` type($dst)"; + let hasVerifier = 1; + + let extraClassDeclaration = [{ + static IDArgPair + getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder); + }]; + + string llvmBuilder = [{ + auto [intId, args] = + NVVM::ConvertF8x2To}] # dstType # [{x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder); + $dst = createIntrinsicCall(builder, intId, args); + }]; +} + +def NVVM_ConvertF8x2ToF16x2Op : NVVM_ConvertF8x2ToFP16x2Op_Base<"F16">; +def NVVM_ConvertF8x2ToBF16x2Op : NVVM_ConvertF8x2ToFP16x2Op_Base<"BF16">; + + +def NVVM_ConvertF6x2ToF16x2Op : NVVM_Op<"convert.f6x2.to.f16x2"> { + let summary = "Convert a pair of f6 inputs to f16x2"; + let description = [{ + This Op converts the given f6 inputs in a i8x2 vector to f16x2. + + The result `dst` is represented as a vector of f16 elements. + + The `relu` attribute, when set, lowers to the '.relu' variant of + the cvt instruction. + + For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) + }]; + let results = (outs VectorOfLengthAndType<[2], [F16]>:$dst); + let arguments = (ins VectorOfLengthAndType<[2], [I8]>:$src, + DefaultValuedAttr:$relu, + TypeAttr:$srcType); + let assemblyFormat = "$src attr-dict `:` type($src) `(` $srcType `)` `->` type($dst)"; + let hasVerifier = 1; + + let extraClassDeclaration = [{ + static IDArgPair + getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder); + }]; + + string llvmBuilder = [{ + auto [intId, args] = + NVVM::ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder); + $dst = createIntrinsicCall(builder, intId, args); + }]; +} + +def NVVM_ConvertF4x2ToF16x2Op : NVVM_Op<"convert.f4x2.to.f16x2"> { + let summary = "Convert a pair of f4 inputs to f16x2"; + let description = [{ + This Op converts the given f4 inputs packed in an i8 to f16x2. + + The result `dst` is represented as a vector of f16 elements. The value + converted from the lower 4 bits of `src` is stored in the first element of + `dst` and the value converted from the upper 4 bits of `src` is stored in + the second element of `dst`. + + The `relu` attribute, when set, lowers to the '.relu' variant of + the cvt instruction. + + For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) + }]; + let results = (outs VectorOfLengthAndType<[2], [F16]>:$dst); + let arguments = (ins I8:$src, + DefaultValuedAttr:$relu, + TypeAttr:$srcType); + let assemblyFormat = "$src attr-dict `:` type($src) `(` $srcType `)` `->` type($dst)"; + let hasVerifier = 1; + + let extraClassDeclaration = [{ + static IDArgPair + getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder); + }]; + + string llvmBuilder = [{ + auto [intId, args] = + NVVM::ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder); + $dst = createIntrinsicCall(builder, intId, args); + }]; +} + //===----------------------------------------------------------------------===// // NVVM MMA Ops //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 7f419a062201d..bd38db52179a6 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -309,6 +309,51 @@ LogicalResult ConvertBF16x2ToF8x2Op::verify() { return success(); } +LogicalResult ConvertF8x2ToF16x2Op::verify() { + mlir::MLIRContext *ctx = getContext(); + + if (!llvm::isa(getSrcType())) + return emitOpError("Only ") + << mlir::Float8E4M3FNType::get(ctx) << " and " + << mlir::Float8E5M2Type::get(ctx) + << " types are supported for conversions from f8x2 to f16x2."; + + return success(); +} + +LogicalResult ConvertF8x2ToBF16x2Op::verify() { + mlir::MLIRContext *ctx = getContext(); + if (!llvm::isa(getSrcType())) + return emitOpError("Only ") + << mlir::Float8E8M0FNUType::get(ctx) + << " type is supported for conversions from f8x2 to bf16x2."; + + return success(); +} + +LogicalResult ConvertF6x2ToF16x2Op::verify() { + mlir::MLIRContext *ctx = getContext(); + + if (!llvm::isa(getSrcType())) + return emitOpError("Only ") + << mlir::Float6E2M3FNType::get(ctx) << " and " + << mlir::Float6E3M2FNType::get(ctx) + << " types are supported for conversions from f6x2 to f16x2."; + + return success(); +} + +LogicalResult ConvertF4x2ToF16x2Op::verify() { + mlir::MLIRContext *ctx = getContext(); + + if (!llvm::isa(getSrcType())) + return emitOpError("Only ") + << mlir::Float4E2M1FNType::get(ctx) + << " type is supported for conversions from f4x2 to f16x2."; + + return success(); +} + LogicalResult BulkStoreOp::verify() { if (getInitVal() != 0) return emitOpError("only 0 is supported for initVal, got ") << getInitVal(); @@ -2106,6 +2151,100 @@ ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd, } } +NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto curOp = cast(op); + + bool hasRelu = curOp.getRelu(); + + llvm::Intrinsic::ID intId = + llvm::TypeSwitch(curOp.getSrcType()) + .Case([&](Float8E4M3FNType type) { + return hasRelu + ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu + : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn; + }) + .Case([&](Float8E5M2Type type) { + return hasRelu + ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu + : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn; + }) + .Default([](mlir::Type type) { + llvm_unreachable("Invalid type for ConvertF8x2ToF16x2Op"); + return llvm::Intrinsic::not_intrinsic; + }); + + llvm::Value *packedI16 = + builder.CreateBitCast(mt.lookupValue(curOp.getSrc()), + llvm::Type::getInt16Ty(builder.getContext())); + + return {intId, {packedI16}}; +} + +NVVM::IDArgPair ConvertF8x2ToBF16x2Op::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto curOp = cast(op); + + llvm::Intrinsic::ID intId = llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2; + llvm::Value *packedI16 = + builder.CreateBitCast(mt.lookupValue(curOp.getSrc()), + llvm::Type::getInt16Ty(builder.getContext())); + + return {intId, {packedI16}}; +} + +NVVM::IDArgPair ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto curOp = cast(op); + + bool hasRelu = curOp.getRelu(); + + llvm::Intrinsic::ID intId = + llvm::TypeSwitch(curOp.getSrcType()) + .Case([&](Float6E2M3FNType type) { + return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu + : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn; + }) + .Case([&](Float6E3M2FNType type) { + return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu + : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn; + }) + .Default([](mlir::Type type) { + llvm_unreachable("Invalid type for ConvertF6x2ToF16x2Op"); + return llvm::Intrinsic::not_intrinsic; + }); + + llvm::Value *packedI16 = + builder.CreateBitCast(mt.lookupValue(curOp.getSrc()), + llvm::Type::getInt16Ty(builder.getContext())); + + return {intId, {packedI16}}; +} + +NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto curOp = cast(op); + + bool hasRelu = curOp.getRelu(); + + llvm::Intrinsic::ID intId = + llvm::TypeSwitch(curOp.getSrcType()) + .Case([&](Float4E2M1FNType type) { + return hasRelu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu + : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn; + }) + .Default([](mlir::Type type) { + llvm_unreachable("Invalid type for ConvertF4x2ToF16x2Op"); + return llvm::Intrinsic::not_intrinsic; + }); + + llvm::Value *extendedI16 = + builder.CreateZExt(mt.lookupValue(curOp.getSrc()), + llvm::Type::getInt16Ty(builder.getContext())); + + return {intId, {extendedI16}}; +} + llvm::Intrinsic::ID Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir new file mode 100644 index 0000000000000..e43dea4065c08 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +// ----- + +// CHECK-LABEL: @convert_f4x2_to_f16x2 +llvm.func @convert_f4x2_to_f16x2(%src : i8) { + // CHECK: %[[res1:.*]] = zext i8 %{{.*}} to i16 + // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn(i16 %[[res1]]) + %res1 = nvvm.convert.f4x2.to.f16x2 %src : i8 (f4E2M1FN)-> vector<2xf16> + // CHECK: %[[res2:.*]] = zext i8 %{{.*}} to i16 + // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn.relu(i16 %[[res2]]) + %res2 = nvvm.convert.f4x2.to.f16x2 %src {relu = true} : i8 (f4E2M1FN)-> vector<2xf16> + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir index 99289923b58b1..61a7a48f40d54 100644 --- a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir +++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir @@ -19,3 +19,27 @@ llvm.func @convert_f32x2_to_fp6x2_vector(%srcA : f32, %srcB : f32) { %res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : vector<2xi8> (f6E3M2FN) llvm.return } + +// ----- + +// CHECK-LABEL: @convert_f6x2_to_f16x2_e2m3 +llvm.func @convert_f6x2_to_f16x2_e2m3(%src : vector<2xi8>) { + // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16 + // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e2m3x2.to.f16x2.rn(i16 %[[res1]]) + %res1 = nvvm.convert.f6x2.to.f16x2 %src : vector<2xi8> (f6E2M3FN)-> vector<2xf16> + // CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16 + // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e2m3x2.to.f16x2.rn.relu(i16 %[[res2]]) + %res2 = nvvm.convert.f6x2.to.f16x2 %src {relu = true} : vector<2xi8> (f6E2M3FN)-> vector<2xf16> + llvm.return +} + +// CHECK-LABEL: @convert_f6x2_to_f16x2_e3m2 +llvm.func @convert_f6x2_to_f16x2_e3m2(%src : vector<2xi8>) { + // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16 + // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e3m2x2.to.f16x2.rn(i16 %[[res1]]) + %res1 = nvvm.convert.f6x2.to.f16x2 %src : vector<2xi8> (f6E3M2FN)-> vector<2xf16> + // CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16 + // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e3m2x2.to.f16x2.rn.relu(i16 %[[res2]]) + %res2 = nvvm.convert.f6x2.to.f16x2 %src {relu = true} : vector<2xi8> (f6E3M2FN)-> vector<2xf16> + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir index de21826445afb..4afe901bc08e9 100644 --- a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir +++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir @@ -100,3 +100,37 @@ llvm.func @convert_bf16x2_to_f8x2_vector_return(%src : vector<2xbf16>) { %res2 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xbf16> -> vector<2xi8> (f8E8M0FNU) llvm.return } + +// ----- + +// CHECK-LABEL: @convert_f8x2_to_f16x2 +llvm.func @convert_f8x2_to_f16x2_e4m3(%src : vector<2xi8>) { + // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16 + // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e4m3x2.to.f16x2.rn(i16 %[[res1]]) + %res1 = nvvm.convert.f8x2.to.f16x2 %src : vector<2xi8> (f8E4M3FN)-> vector<2xf16> + // CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16 + // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e4m3x2.to.f16x2.rn.relu(i16 %[[res2]]) + %res2 = nvvm.convert.f8x2.to.f16x2 %src {relu = true} : vector<2xi8> (f8E4M3FN)-> vector<2xf16> + llvm.return +} + +// CHECK-LABEL: @convert_f8x2_to_f16x2_e5m2 +llvm.func @convert_f8x2_to_f16x2_e5m2(%src : vector<2xi8>) { + // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16 + // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn(i16 %[[res1]]) + %res1 = nvvm.convert.f8x2.to.f16x2 %src : vector<2xi8> (f8E5M2)-> vector<2xf16> + // CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16 + // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn.relu(i16 %[[res2]]) + %res2 = nvvm.convert.f8x2.to.f16x2 %src {relu = true} : vector<2xi8> (f8E5M2)-> vector<2xf16> + llvm.return +} + +// ----- + +// CHECK-LABEL: @convert_f8x2_to_bf16x2_ue8m0 +llvm.func @convert_f8x2_to_bf16x2_ue8m0(%src : vector<2xi8>) { + // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16 + // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ue8m0x2.to.bf16x2(i16 %[[res1]]) + %res1 = nvvm.convert.f8x2.to.bf16x2 %src : vector<2xi8> (f8E8M0FNU)-> vector<2xbf16> + llvm.return +} From 75a3137bc0dc4a52ef642dc17336f5e8f2101dd8 Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Wed, 8 Oct 2025 06:42:13 +0000 Subject: [PATCH 2/6] add invalid test cases --- mlir/test/Target/LLVMIR/nvvmir-invalid.mlir | 32 +++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir index 0b3615487716d..c5f71cfeaba8b 100644 --- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir @@ -254,6 +254,38 @@ llvm.func @nvvm_cvt_f32x2_to_f6x2_invalid_type(%a : f32, %b : f32) { // ----- +llvm.func @nvvm_cvt_f8x2_to_f16x2_invalid_type(%src : vector<2xi8>) { + // expected-error @below {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f8x2 to f16x2.}} + %res = nvvm.convert.f8x2.to.f16x2 %src : vector<2xi8> (f8E4M3) -> vector<2xf16> + llvm.return +} + +// ----- + +llvm.func @nvvm_cvt_f8x2_to_bf16x2_invalid_type(%src : vector<2xi8>) { + // expected-error @below {{Only 'f8E8M0FNU' type is supported for conversions from f8x2 to bf16x2.}} + %res = nvvm.convert.f8x2.to.bf16x2 %src : vector<2xi8> (f8E4M3FN) -> vector<2xbf16> + llvm.return +} + +// ----- + +llvm.func @nvvm_cvt_f6x2_to_f16x2_invalid_type(%src : vector<2xi8>) { + // expected-error @below {{Only 'f6E2M3FN' and 'f6E3M2FN' types are supported for conversions from f6x2 to f16x2.}} + %res = nvvm.convert.f6x2.to.f16x2 %src : vector<2xi8> (f8E4M3FN) -> vector<2xf16> + llvm.return +} + +// ----- + +llvm.func @nvvm_cvt_f4x2_to_f16x2_invalid_type(%src : i8) { + // expected-error @below {{Only 'f4E2M1FN' type is supported for conversions from f4x2 to f16x2.}} + %res = nvvm.convert.f4x2.to.f16x2 %src : i8 (f6E2M3FN) -> vector<2xf16> + llvm.return +} + +// ----- + llvm.func @nvvm_prefetch_L1_with_evict_priority(%global_ptr: !llvm.ptr<1>) { // expected-error @below {{cache eviction priority supported only for cache level L2}} nvvm.prefetch level = L1, evict_priority = evict_last, %global_ptr : !llvm.ptr<1> From 36d41ee036f7c3c8717bc52e372c80d91dd5f0b8 Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Wed, 8 Oct 2025 06:44:25 +0000 Subject: [PATCH 3/6] fix formatting --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 8 +++----- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 5020af3992173..e6cfa24ccc3bb 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1843,7 +1843,7 @@ class NVVM_ConvertF8x2ToFP16x2Op_Base The `relu` attribute, when set, lowers to the '.relu' variant of the cvt instruction. - For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) }]; let results = (outs VectorOfLengthAndType<[2], [!cast(dstType)]>:$dst); let arguments = !if(!eq(dstType, "F16"), @@ -1867,11 +1867,9 @@ class NVVM_ConvertF8x2ToFP16x2Op_Base $dst = createIntrinsicCall(builder, intId, args); }]; } - def NVVM_ConvertF8x2ToF16x2Op : NVVM_ConvertF8x2ToFP16x2Op_Base<"F16">; def NVVM_ConvertF8x2ToBF16x2Op : NVVM_ConvertF8x2ToFP16x2Op_Base<"BF16">; - def NVVM_ConvertF6x2ToF16x2Op : NVVM_Op<"convert.f6x2.to.f16x2"> { let summary = "Convert a pair of f6 inputs to f16x2"; let description = [{ @@ -1882,7 +1880,7 @@ def NVVM_ConvertF6x2ToF16x2Op : NVVM_Op<"convert.f6x2.to.f16x2"> { The `relu` attribute, when set, lowers to the '.relu' variant of the cvt instruction. - For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) }]; let results = (outs VectorOfLengthAndType<[2], [F16]>:$dst); let arguments = (ins VectorOfLengthAndType<[2], [I8]>:$src, @@ -1917,7 +1915,7 @@ def NVVM_ConvertF4x2ToF16x2Op : NVVM_Op<"convert.f4x2.to.f16x2"> { The `relu` attribute, when set, lowers to the '.relu' variant of the cvt instruction. - For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) }]; let results = (outs VectorOfLengthAndType<[2], [F16]>:$dst); let arguments = (ins I8:$src, diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index bd38db52179a6..de7b9c31f7623 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -2154,7 +2154,7 @@ ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd, NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs( Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { auto curOp = cast(op); - + bool hasRelu = curOp.getRelu(); llvm::Intrinsic::ID intId = From 23afb336e4a66bf49116ff67e03a6b0451bed7cf Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Wed, 8 Oct 2025 08:48:10 +0000 Subject: [PATCH 4/6] fix formatting --- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index de7b9c31f7623..229a1f3a4ad66 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -2160,14 +2160,12 @@ NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs( llvm::Intrinsic::ID intId = llvm::TypeSwitch(curOp.getSrcType()) .Case([&](Float8E4M3FNType type) { - return hasRelu - ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu - : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn; + return hasRelu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu + : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn; }) .Case([&](Float8E5M2Type type) { - return hasRelu - ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu - : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn; + return hasRelu ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu + : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn; }) .Default([](mlir::Type type) { llvm_unreachable("Invalid type for ConvertF8x2ToF16x2Op"); From 41e2da8173dec1d64e03bbb7d30fe1bb35b94018 Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Wed, 8 Oct 2025 10:27:02 +0000 Subject: [PATCH 5/6] refactor tablegen --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 100 +++++--------------- 1 file changed, 22 insertions(+), 78 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index e6cfa24ccc3bb..ed1308fba5578 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1832,25 +1832,28 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> { }]; } -class NVVM_ConvertF8x2ToFP16x2Op_Base -: NVVM_Op<"convert.f8x2.to." # !tolower(dstType) # "x2"> { - let summary = "Convert a pair of f8 inputs to " # !tolower(dstType) # "x2"; +class NVVM_ConvertToFP16x2Op_Base +: NVVM_Op<"convert." # srcType # "x2.to." # !tolower(dstType) # "x2"> { + let summary = "Convert a pair of " # srcType # " inputs to " # !tolower(dstType) # "x2"; let description = [{ - This Op converts the given f8 inputs in a i8x2 vector to }] # !tolower(dstType) # [{. + This Op converts the given }] # srcType # [{ inputs in a }] # + !if(!eq(srcType, "f4"), "packed i8", "i8x2 vector") # [{ to }] # + !tolower(dstType) # [{. The result `dst` is represented as a vector of }] # !tolower(dstType) # [{ elements. - - The `relu` attribute, when set, lowers to the '.relu' variant of - the cvt instruction. + }] # + !if(!eq(dstType, "F16"), + [{The `relu` attribute, when set, lowers to the '.relu' variant of + the cvt instruction."}], "") # [{ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) }]; let results = (outs VectorOfLengthAndType<[2], [!cast(dstType)]>:$dst); let arguments = !if(!eq(dstType, "F16"), - (ins VectorOfLengthAndType<[2], [I8]>:$src, + (ins srcArgType:$src, DefaultValuedAttr:$relu, TypeAttr:$srcType), - (ins VectorOfLengthAndType<[2], [I8]>:$src, + (ins srcArgType:$src, TypeAttr:$srcType)); let assemblyFormat = "$src attr-dict `:` type($src) `(` $srcType `)` `->` type($dst)"; let hasVerifier = 1; @@ -1863,79 +1866,20 @@ class NVVM_ConvertF8x2ToFP16x2Op_Base string llvmBuilder = [{ auto [intId, args] = - NVVM::ConvertF8x2To}] # dstType # [{x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder); - $dst = createIntrinsicCall(builder, intId, args); - }]; -} -def NVVM_ConvertF8x2ToF16x2Op : NVVM_ConvertF8x2ToFP16x2Op_Base<"F16">; -def NVVM_ConvertF8x2ToBF16x2Op : NVVM_ConvertF8x2ToFP16x2Op_Base<"BF16">; - -def NVVM_ConvertF6x2ToF16x2Op : NVVM_Op<"convert.f6x2.to.f16x2"> { - let summary = "Convert a pair of f6 inputs to f16x2"; - let description = [{ - This Op converts the given f6 inputs in a i8x2 vector to f16x2. - - The result `dst` is represented as a vector of f16 elements. - - The `relu` attribute, when set, lowers to the '.relu' variant of - the cvt instruction. - - [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) - }]; - let results = (outs VectorOfLengthAndType<[2], [F16]>:$dst); - let arguments = (ins VectorOfLengthAndType<[2], [I8]>:$src, - DefaultValuedAttr:$relu, - TypeAttr:$srcType); - let assemblyFormat = "$src attr-dict `:` type($src) `(` $srcType `)` `->` type($dst)"; - let hasVerifier = 1; - - let extraClassDeclaration = [{ - static IDArgPair - getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, - llvm::IRBuilderBase &builder); - }]; - - string llvmBuilder = [{ - auto [intId, args] = - NVVM::ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder); + NVVM::Convert}] # !toupper(srcType) # [{x2To}] # dstType # + [{x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder); $dst = createIntrinsicCall(builder, intId, args); }]; } -def NVVM_ConvertF4x2ToF16x2Op : NVVM_Op<"convert.f4x2.to.f16x2"> { - let summary = "Convert a pair of f4 inputs to f16x2"; - let description = [{ - This Op converts the given f4 inputs packed in an i8 to f16x2. - - The result `dst` is represented as a vector of f16 elements. The value - converted from the lower 4 bits of `src` is stored in the first element of - `dst` and the value converted from the upper 4 bits of `src` is stored in - the second element of `dst`. - - The `relu` attribute, when set, lowers to the '.relu' variant of - the cvt instruction. - - [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) - }]; - let results = (outs VectorOfLengthAndType<[2], [F16]>:$dst); - let arguments = (ins I8:$src, - DefaultValuedAttr:$relu, - TypeAttr:$srcType); - let assemblyFormat = "$src attr-dict `:` type($src) `(` $srcType `)` `->` type($dst)"; - let hasVerifier = 1; - - let extraClassDeclaration = [{ - static IDArgPair - getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, - llvm::IRBuilderBase &builder); - }]; - - string llvmBuilder = [{ - auto [intId, args] = - NVVM::ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder); - $dst = createIntrinsicCall(builder, intId, args); - }]; -} +def NVVM_ConvertF8x2ToF16x2Op : + NVVM_ConvertToFP16x2Op_Base<"f8", VectorOfLengthAndType<[2], [I8]>, "F16">; +def NVVM_ConvertF8x2ToBF16x2Op : + NVVM_ConvertToFP16x2Op_Base<"f8", VectorOfLengthAndType<[2], [I8]>, "BF16">; +def NVVM_ConvertF6x2ToF16x2Op : + NVVM_ConvertToFP16x2Op_Base<"f6", VectorOfLengthAndType<[2], [I8]>, "F16">; +def NVVM_ConvertF4x2ToF16x2Op : + NVVM_ConvertToFP16x2Op_Base<"f4", I8, "F16">; //===----------------------------------------------------------------------===// // NVVM MMA Ops From a349a4f0da142436f176bc4fb89615822ddeef04 Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Fri, 10 Oct 2025 06:44:53 +0000 Subject: [PATCH 6/6] match argument cases --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index ed1308fba5578..0287a20e7ed10 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1833,11 +1833,11 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> { } class NVVM_ConvertToFP16x2Op_Base -: NVVM_Op<"convert." # srcType # "x2.to." # !tolower(dstType) # "x2"> { - let summary = "Convert a pair of " # srcType # " inputs to " # !tolower(dstType) # "x2"; +: NVVM_Op<"convert." # !tolower(srcType) # "x2.to." # !tolower(dstType) # "x2"> { + let summary = "Convert a pair of " # !tolower(srcType) # " inputs to " # !tolower(dstType) # "x2"; let description = [{ - This Op converts the given }] # srcType # [{ inputs in a }] # - !if(!eq(srcType, "f4"), "packed i8", "i8x2 vector") # [{ to }] # + This Op converts the given }] # !tolower(srcType) # [{ inputs in a }] # + !if(!eq(srcType, "F4"), "packed i8", "i8x2 vector") # [{ to }] # !tolower(dstType) # [{. The result `dst` is represented as a vector of }] # !tolower(dstType) # [{ elements. @@ -1866,20 +1866,20 @@ class NVVM_ConvertToFP16x2Op_Base , "F16">; + NVVM_ConvertToFP16x2Op_Base<"F8", VectorOfLengthAndType<[2], [I8]>, "F16">; def NVVM_ConvertF8x2ToBF16x2Op : - NVVM_ConvertToFP16x2Op_Base<"f8", VectorOfLengthAndType<[2], [I8]>, "BF16">; + NVVM_ConvertToFP16x2Op_Base<"F8", VectorOfLengthAndType<[2], [I8]>, "BF16">; def NVVM_ConvertF6x2ToF16x2Op : - NVVM_ConvertToFP16x2Op_Base<"f6", VectorOfLengthAndType<[2], [I8]>, "F16">; + NVVM_ConvertToFP16x2Op_Base<"F6", VectorOfLengthAndType<[2], [I8]>, "F16">; def NVVM_ConvertF4x2ToF16x2Op : - NVVM_ConvertToFP16x2Op_Base<"f4", I8, "F16">; + NVVM_ConvertToFP16x2Op_Base<"F4", I8, "F16">; //===----------------------------------------------------------------------===// // NVVM MMA Ops