-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[MLIR][NVVM] Add support for f32x2 to f4x2 conversion #162273
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][NVVM] Add support for f32x2 to f4x2 conversion #162273
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
This change adds the `convert.f32x2.to.f4x2` op to the NVVM Dialect for converting a pair of f32 values to an f4x2 (`e2m1x2`) value. PTX reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt
0c8e787
to
6cf5a89
Compare
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: Srinivasa Ravi (Wolfram70) ChangesThis change adds the PTX reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt Full diff: https://github.com/llvm/llvm-project/pull/162273.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index e2a0331542742..3a65555204c36 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1655,6 +1655,40 @@ def NVVM_ConvertFloatToTF32Op : NVVM_Op<"convert.float.to.tf32"> {
}];
}
+def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> {
+ let summary = "Convert a pair of float inputs to f4x2";
+ let description = [{
+ This Op converts each of the given float inputs to the specified fp4 type.
+ The result `dst` is returned as an i8 type where the converted values are
+ packed such that the value converted from `a` is stored in the upper 4 bits
+ of `dst` and the value converted from `b` is stored in the lower 4 bits of
+ `dst`.
+ The `relu` attribute, when set, lowers to the '.relu' variant of
+ the cvt instruction.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+ }];
+
+ let results = (outs I8:$dst);
+ let arguments = (ins F32:$a, F32:$b,
+ DefaultValuedAttr<BoolAttr, "false">:$relu,
+ TypeAttr:$dstTy);
+ let assemblyFormat = "$a `,` $b attr-dict `:` type($dst) `(` $dstTy `)`";
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ static mlir::NVVM::IDArgPair
+ getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
+ LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
+ }];
+
+ string llvmBuilder = [{
+ auto [intId, args] = NVVM::ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(op, moduleTranslation, builder);
+ llvm::Value *packedI16 = createIntrinsicCall(builder, intId, args);
+ $dst = builder.CreateTruncOrBitCast(packedI16, llvm::Type::getInt8Ty(builder.getContext()));
+ }];
+}
+
def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
let summary = "Convert a pair of float inputs to f6x2";
let description = [{
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 7f419a062201d..37b4168386da8 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -309,6 +309,17 @@ LogicalResult ConvertBF16x2ToF8x2Op::verify() {
return success();
}
+LogicalResult ConvertF32x2ToF4x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
+ return emitOpError("Only ")
+ << mlir::Float4E2M1FNType::get(ctx)
+ << " type is supported for conversions from f32x2 to f4x2.";
+
+ return success();
+}
+
LogicalResult BulkStoreOp::verify() {
if (getInitVal() != 0)
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -2014,6 +2025,23 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
}
}
+NVVM::IDArgPair
+ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
+ LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(op.getA()));
+ args.push_back(mt.lookupValue(op.getB()));
+
+ bool hasRelu = op.getRelu();
+
+ llvm::Intrinsic::ID intId =
+ hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
+ : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
+
+ return {intId, std::move(args)};
+}
+
#define GET_F32x2_TO_F6x2_ID(type, has_relu) \
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
: llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
new file mode 100644
index 0000000000000..04e2ddff802a9
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @convert_f32x2_to_f4x2_e2m1
+llvm.func @convert_f32x2_to_f4x2_e2m1(%srcA : f32, %srcB : f32) {
+ // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+ // CHECK-NEXT: %{{.*}} = trunc i16 %[[res1]] to i8
+ %res1 = nvvm.convert.f32x2.to.f4x2 %srcA, %srcB : i8 (f4E2M1FN)
+ // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
+ // CHECK-NEXT: %{{.*}} = trunc i16 %[[res2]] to i8
+ %res2 = nvvm.convert.f32x2.to.f4x2 %srcA, %srcB {relu = true} : i8 (f4E2M1FN)
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 0b3615487716d..78e1e659ed85d 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -254,6 +254,14 @@ llvm.func @nvvm_cvt_f32x2_to_f6x2_invalid_type(%a : f32, %b : f32) {
// -----
+llvm.func @nvvm_cvt_f32x2_to_f4x2_invalid_type(%a : f32, %b : f32) {
+ // expected-error @below {{Only 'f4E2M1FN' type is supported for conversions from f32x2 to f4x2.}}
+ %res = nvvm.convert.f32x2.to.f4x2 %a, %b : i8 (f8E4M3FN)
+ llvm.return
+}
+
+// -----
+
llvm.func @nvvm_prefetch_L1_with_evict_priority(%global_ptr: !llvm.ptr<1>) {
// expected-error @below {{cache eviction priority supported only for cache level L2}}
nvvm.prefetch level = L1, evict_priority = evict_last, %global_ptr : !llvm.ptr<1>
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice thanks
This change adds the
convert.f32x2.to.f4x2
op to the NVVM Dialectfor converting a pair of f32 values to an f4x2 (
e2m1x2
) value.PTX reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt