Skip to content

Conversation

Wolfram70
Copy link
Contributor

This change adds the convert.f32x2.to.f4x2 op to the NVVM Dialect
for converting a pair of f32 values to an f4x2 (e2m1x2) value.

PTX reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt

@Wolfram70 Wolfram70 requested a review from durga4github October 7, 2025 12:22
@Wolfram70 Wolfram70 self-assigned this Oct 7, 2025
@Wolfram70 Wolfram70 requested a review from grypp as a code owner October 7, 2025 12:22
Copy link

github-actions bot commented Oct 7, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Contributor

@durga4github durga4github left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@Wolfram70 Wolfram70 force-pushed the dev/Wolfram70/mlir-nvvm-cvt-f32-f4x2 branch from 0c8e787 to 6cf5a89 Compare October 10, 2025 05:06
@llvmbot
Copy link
Member

llvmbot commented Oct 10, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Srinivasa Ravi (Wolfram70)

Changes

This change adds the convert.f32x2.to.f4x2 op to the NVVM Dialect
for converting a pair of f32 values to an f4x2 (e2m1x2) value.

PTX reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt


Full diff: https://github.com/llvm/llvm-project/pull/162273.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+34)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+28)
  • (added) mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir (+12)
  • (modified) mlir/test/Target/LLVMIR/nvvmir-invalid.mlir (+8)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index e2a0331542742..3a65555204c36 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1655,6 +1655,40 @@ def NVVM_ConvertFloatToTF32Op : NVVM_Op<"convert.float.to.tf32"> {
   }];
 }
 
+def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> {
+  let summary = "Convert a pair of float inputs to f4x2";
+  let description = [{
+    This Op converts each of the given float inputs to the specified fp4 type.
+    The result `dst` is returned as an i8 type where the converted values are 
+    packed such that the value converted from `a` is stored in the upper 4 bits 
+    of `dst` and the value converted from `b` is stored in the lower 4 bits of 
+    `dst`.
+    The `relu` attribute, when set, lowers to the '.relu' variant of
+    the cvt instruction.
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+  }];
+
+  let results = (outs I8:$dst);
+  let arguments = (ins F32:$a, F32:$b,
+                       DefaultValuedAttr<BoolAttr, "false">:$relu,
+                       TypeAttr:$dstTy);
+  let assemblyFormat = "$a `,` $b attr-dict `:` type($dst) `(` $dstTy `)`";
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    static mlir::NVVM::IDArgPair
+    getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op, 
+      LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
+  }];
+
+  string llvmBuilder = [{
+    auto [intId, args] = NVVM::ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(op, moduleTranslation, builder);
+    llvm::Value *packedI16 = createIntrinsicCall(builder, intId, args);
+    $dst = builder.CreateTruncOrBitCast(packedI16, llvm::Type::getInt8Ty(builder.getContext()));
+  }];
+}
+
 def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
   let summary = "Convert a pair of float inputs to f6x2";
   let description = [{
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 7f419a062201d..37b4168386da8 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -309,6 +309,17 @@ LogicalResult ConvertBF16x2ToF8x2Op::verify() {
   return success();
 }
 
+LogicalResult ConvertF32x2ToF4x2Op::verify() {
+  mlir::MLIRContext *ctx = getContext();
+
+  if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
+    return emitOpError("Only ")
+           << mlir::Float4E2M1FNType::get(ctx)
+           << " type is supported for conversions from f32x2 to f4x2.";
+
+  return success();
+}
+
 LogicalResult BulkStoreOp::verify() {
   if (getInitVal() != 0)
     return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -2014,6 +2025,23 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
   }
 }
 
+NVVM::IDArgPair
+ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
+                                            LLVM::ModuleTranslation &mt,
+                                            llvm::IRBuilderBase &builder) {
+  llvm::SmallVector<llvm::Value *> args;
+  args.push_back(mt.lookupValue(op.getA()));
+  args.push_back(mt.lookupValue(op.getB()));
+  
+  bool hasRelu = op.getRelu();
+  
+  llvm::Intrinsic::ID intId =
+      hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
+              : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
+              
+  return {intId, std::move(args)};
+}
+
 #define GET_F32x2_TO_F6x2_ID(type, has_relu)                                   \
   has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite            \
            : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
new file mode 100644
index 0000000000000..04e2ddff802a9
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @convert_f32x2_to_f4x2_e2m1
+llvm.func @convert_f32x2_to_f4x2_e2m1(%srcA : f32, %srcB : f32) {
+  // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+  // CHECK-NEXT: %{{.*}} = trunc i16 %[[res1]] to i8
+  %res1 = nvvm.convert.f32x2.to.f4x2 %srcA, %srcB : i8 (f4E2M1FN)
+  // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
+  // CHECK-NEXT: %{{.*}} = trunc i16 %[[res2]] to i8
+  %res2 = nvvm.convert.f32x2.to.f4x2 %srcA, %srcB {relu = true} : i8 (f4E2M1FN)
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 0b3615487716d..78e1e659ed85d 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -254,6 +254,14 @@ llvm.func @nvvm_cvt_f32x2_to_f6x2_invalid_type(%a : f32, %b : f32) {
 
 // -----
 
+llvm.func @nvvm_cvt_f32x2_to_f4x2_invalid_type(%a : f32, %b : f32) {
+  // expected-error @below {{Only 'f4E2M1FN' type is supported for conversions from f32x2 to f4x2.}}
+  %res = nvvm.convert.f32x2.to.f4x2 %a, %b : i8 (f8E4M3FN)
+  llvm.return
+}
+
+// -----
+
 llvm.func @nvvm_prefetch_L1_with_evict_priority(%global_ptr: !llvm.ptr<1>) {
   // expected-error @below {{cache eviction priority supported only for cache level L2}}
   nvvm.prefetch level = L1, evict_priority = evict_last, %global_ptr : !llvm.ptr<1>

Copy link
Member

@grypp grypp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice thanks

@Wolfram70 Wolfram70 merged commit e98de2e into llvm:main Oct 15, 2025
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants