Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1655,6 +1655,40 @@ def NVVM_ConvertFloatToTF32Op : NVVM_Op<"convert.float.to.tf32"> {
}];
}

def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> {
let summary = "Convert a pair of float inputs to f4x2";
let description = [{
This Op converts each of the given float inputs to the specified fp4 type.
The result `dst` is returned as an i8 type where the converted values are
packed such that the value converted from `a` is stored in the upper 4 bits
of `dst` and the value converted from `b` is stored in the lower 4 bits of
`dst`.
The `relu` attribute, when set, lowers to the '.relu' variant of
the cvt instruction.

[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
}];

let results = (outs I8:$dst);
let arguments = (ins F32:$a, F32:$b,
DefaultValuedAttr<BoolAttr, "false">:$relu,
TypeAttr:$dstTy);
let assemblyFormat = "$a `,` $b attr-dict `:` type($dst) `(` $dstTy `)`";
let hasVerifier = 1;

let extraClassDeclaration = [{
static mlir::NVVM::IDArgPair
getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
}];

string llvmBuilder = [{
auto [intId, args] = NVVM::ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(op, moduleTranslation, builder);
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, args);
$dst = builder.CreateTruncOrBitCast(packedI16, llvm::Type::getInt8Ty(builder.getContext()));
}];
}

def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
let summary = "Convert a pair of float inputs to f6x2";
let description = [{
Expand Down
28 changes: 28 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,17 @@ LogicalResult ConvertBF16x2ToF8x2Op::verify() {
return success();
}

LogicalResult ConvertF32x2ToF4x2Op::verify() {
mlir::MLIRContext *ctx = getContext();

if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
return emitOpError("Only ")
<< mlir::Float4E2M1FNType::get(ctx)
<< " type is supported for conversions from f32x2 to f4x2.";

return success();
}

LogicalResult BulkStoreOp::verify() {
if (getInitVal() != 0)
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
Expand Down Expand Up @@ -2014,6 +2025,23 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
}
}

NVVM::IDArgPair
ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
LLVM::ModuleTranslation &mt,
llvm::IRBuilderBase &builder) {
llvm::SmallVector<llvm::Value *> args;
args.push_back(mt.lookupValue(op.getA()));
args.push_back(mt.lookupValue(op.getB()));

bool hasRelu = op.getRelu();

llvm::Intrinsic::ID intId =
hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
: llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;

return {intId, std::move(args)};
}

#define GET_F32x2_TO_F6x2_ID(type, has_relu) \
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
: llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
Expand Down
12 changes: 12 additions & 0 deletions mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s

// CHECK-LABEL: @convert_f32x2_to_f4x2_e2m1
llvm.func @convert_f32x2_to_f4x2_e2m1(%srcA : f32, %srcB : f32) {
// CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float %{{.*}}, float %{{.*}})
// CHECK-NEXT: %{{.*}} = trunc i16 %[[res1]] to i8
%res1 = nvvm.convert.f32x2.to.f4x2 %srcA, %srcB : i8 (f4E2M1FN)
// CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
// CHECK-NEXT: %{{.*}} = trunc i16 %[[res2]] to i8
%res2 = nvvm.convert.f32x2.to.f4x2 %srcA, %srcB {relu = true} : i8 (f4E2M1FN)
llvm.return
}
8 changes: 8 additions & 0 deletions mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,14 @@ llvm.func @nvvm_cvt_f32x2_to_f6x2_invalid_type(%a : f32, %b : f32) {

// -----

llvm.func @nvvm_cvt_f32x2_to_f4x2_invalid_type(%a : f32, %b : f32) {
// expected-error @below {{Only 'f4E2M1FN' type is supported for conversions from f32x2 to f4x2.}}
%res = nvvm.convert.f32x2.to.f4x2 %a, %b : i8 (f8E4M3FN)
llvm.return
}

// -----

llvm.func @nvvm_prefetch_L1_with_evict_priority(%global_ptr: !llvm.ptr<1>) {
// expected-error @below {{cache eviction priority supported only for cache level L2}}
nvvm.prefetch level = L1, evict_priority = evict_last, %global_ptr : !llvm.ptr<1>
Expand Down