diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index e2a0331542742..3a65555204c36 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1655,6 +1655,40 @@ def NVVM_ConvertFloatToTF32Op : NVVM_Op<"convert.float.to.tf32"> { }]; } +def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> { + let summary = "Convert a pair of float inputs to f4x2"; + let description = [{ + This Op converts each of the given float inputs to the specified fp4 type. + The result `dst` is returned as an i8 type where the converted values are + packed such that the value converted from `a` is stored in the upper 4 bits + of `dst` and the value converted from `b` is stored in the lower 4 bits of + `dst`. + The `relu` attribute, when set, lowers to the '.relu' variant of + the cvt instruction. + + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) + }]; + + let results = (outs I8:$dst); + let arguments = (ins F32:$a, F32:$b, + DefaultValuedAttr:$relu, + TypeAttr:$dstTy); + let assemblyFormat = "$a `,` $b attr-dict `:` type($dst) `(` $dstTy `)`"; + let hasVerifier = 1; + + let extraClassDeclaration = [{ + static mlir::NVVM::IDArgPair + getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op, + LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder); + }]; + + string llvmBuilder = [{ + auto [intId, args] = NVVM::ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(op, moduleTranslation, builder); + llvm::Value *packedI16 = createIntrinsicCall(builder, intId, args); + $dst = builder.CreateTruncOrBitCast(packedI16, llvm::Type::getInt8Ty(builder.getContext())); + }]; +} + def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> { let summary = "Convert a pair of float inputs to f6x2"; let description = [{ diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 7f419a062201d..3e4c7cd6826fc 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -309,6 +309,17 @@ LogicalResult ConvertBF16x2ToF8x2Op::verify() { return success(); } +LogicalResult ConvertF32x2ToF4x2Op::verify() { + mlir::MLIRContext *ctx = getContext(); + + if (!llvm::isa(getDstTy())) + return emitOpError("Only ") + << mlir::Float4E2M1FNType::get(ctx) + << " type is supported for conversions from f32x2 to f4x2."; + + return success(); +} + LogicalResult BulkStoreOp::verify() { if (getInitVal() != 0) return emitOpError("only 0 is supported for initVal, got ") << getInitVal(); @@ -2014,6 +2025,23 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd, } } +NVVM::IDArgPair +ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op, + LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + llvm::SmallVector args; + args.push_back(mt.lookupValue(op.getA())); + args.push_back(mt.lookupValue(op.getB())); + + bool hasRelu = op.getRelu(); + + llvm::Intrinsic::ID intId = + hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite + : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite; + + return {intId, std::move(args)}; +} + #define GET_F32x2_TO_F6x2_ID(type, has_relu) \ has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \ : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir new file mode 100644 index 0000000000000..04e2ddff802a9 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir @@ -0,0 +1,12 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @convert_f32x2_to_f4x2_e2m1 +llvm.func @convert_f32x2_to_f4x2_e2m1(%srcA : f32, %srcB : f32) { + // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float %{{.*}}, float %{{.*}}) + // CHECK-NEXT: %{{.*}} = trunc i16 %[[res1]] to i8 + %res1 = nvvm.convert.f32x2.to.f4x2 %srcA, %srcB : i8 (f4E2M1FN) + // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}}) + // CHECK-NEXT: %{{.*}} = trunc i16 %[[res2]] to i8 + %res2 = nvvm.convert.f32x2.to.f4x2 %srcA, %srcB {relu = true} : i8 (f4E2M1FN) + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir index 0b3615487716d..78e1e659ed85d 100644 --- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir @@ -254,6 +254,14 @@ llvm.func @nvvm_cvt_f32x2_to_f6x2_invalid_type(%a : f32, %b : f32) { // ----- +llvm.func @nvvm_cvt_f32x2_to_f4x2_invalid_type(%a : f32, %b : f32) { + // expected-error @below {{Only 'f4E2M1FN' type is supported for conversions from f32x2 to f4x2.}} + %res = nvvm.convert.f32x2.to.f4x2 %a, %b : i8 (f8E4M3FN) + llvm.return +} + +// ----- + llvm.func @nvvm_prefetch_L1_with_evict_priority(%global_ptr: !llvm.ptr<1>) { // expected-error @below {{cache eviction priority supported only for cache level L2}} nvvm.prefetch level = L1, evict_priority = evict_last, %global_ptr : !llvm.ptr<1>