Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1855,6 +1855,55 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> {
}];
}

class NVVM_ConvertToFP16x2Op_Base <string srcType, Type srcArgType, string dstType>
: NVVM_Op<"convert." # srcType # "x2.to." # !tolower(dstType) # "x2"> {
let summary = "Convert a pair of " # srcType # " inputs to " # !tolower(dstType) # "x2";
let description = [{
This Op converts the given }] # srcType # [{ inputs in a }] #
!if(!eq(srcType, "f4"), "packed i8", "i8x2 vector") # [{ to }] #
!tolower(dstType) # [{.

The result `dst` is represented as a vector of }] # !tolower(dstType) # [{ elements.
}] #
!if(!eq(dstType, "F16"),
[{The `relu` attribute, when set, lowers to the '.relu' variant of
the cvt instruction."}], "") # [{

[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
}];
let results = (outs VectorOfLengthAndType<[2], [!cast<Type>(dstType)]>:$dst);
let arguments = !if(!eq(dstType, "F16"),
(ins srcArgType:$src,
DefaultValuedAttr<BoolAttr, "false">:$relu,
TypeAttr:$srcType),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am wondering do we need TypeAttr. Isn't that clear from the name of the OP?

Copy link
Contributor Author

@Wolfram70 Wolfram70 Oct 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case, this TypeAttr is for the source type since it can be either e4m3 (f8E4M3FN) or e5m2 (f8E5M2).

(ins srcArgType:$src,
TypeAttr:$srcType));
let assemblyFormat = "$src attr-dict `:` type($src) `(` $srcType `)` `->` type($dst)";
let hasVerifier = 1;

let extraClassDeclaration = [{
static IDArgPair
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
llvm::IRBuilderBase &builder);
}];

string llvmBuilder = [{
auto [intId, args] =
NVVM::Convert}] # !toupper(srcType) # [{x2To}] # dstType #
[{x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
$dst = createIntrinsicCall(builder, intId, args);
}];
}

def NVVM_ConvertF8x2ToF16x2Op :
NVVM_ConvertToFP16x2Op_Base<"f8", VectorOfLengthAndType<[2], [I8]>, "F16">;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious,
would it look better if we move the "x2" semantics here itself? like "f8x2" and "F16x2"..
I believe it may not be possible due to cast<Type> on the dstType?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, we'd have to do some string truncation in that case which I think will not be as clean as it currently.

def NVVM_ConvertF8x2ToBF16x2Op :
NVVM_ConvertToFP16x2Op_Base<"f8", VectorOfLengthAndType<[2], [I8]>, "BF16">;
def NVVM_ConvertF6x2ToF16x2Op :
NVVM_ConvertToFP16x2Op_Base<"f6", VectorOfLengthAndType<[2], [I8]>, "F16">;
def NVVM_ConvertF4x2ToF16x2Op :
NVVM_ConvertToFP16x2Op_Base<"f4", I8, "F16">;

//===----------------------------------------------------------------------===//
// NVVM MMA Ops
//===----------------------------------------------------------------------===//
Expand Down
137 changes: 137 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,51 @@ LogicalResult ConvertBF16x2ToF8x2Op::verify() {
return success();
}

LogicalResult ConvertF8x2ToF16x2Op::verify() {
mlir::MLIRContext *ctx = getContext();

if (!llvm::isa<Float8E4M3FNType, Float8E5M2Type>(getSrcType()))
return emitOpError("Only ")
<< mlir::Float8E4M3FNType::get(ctx) << " and "
<< mlir::Float8E5M2Type::get(ctx)
<< " types are supported for conversions from f8x2 to f16x2.";

return success();
}

LogicalResult ConvertF8x2ToBF16x2Op::verify() {
mlir::MLIRContext *ctx = getContext();
if (!llvm::isa<Float8E8M0FNUType>(getSrcType()))
return emitOpError("Only ")
<< mlir::Float8E8M0FNUType::get(ctx)
<< " type is supported for conversions from f8x2 to bf16x2.";

return success();
}

LogicalResult ConvertF6x2ToF16x2Op::verify() {
mlir::MLIRContext *ctx = getContext();

if (!llvm::isa<Float6E2M3FNType, Float6E3M2FNType>(getSrcType()))
return emitOpError("Only ")
<< mlir::Float6E2M3FNType::get(ctx) << " and "
<< mlir::Float6E3M2FNType::get(ctx)
<< " types are supported for conversions from f6x2 to f16x2.";

return success();
}

LogicalResult ConvertF4x2ToF16x2Op::verify() {
mlir::MLIRContext *ctx = getContext();

if (!llvm::isa<Float4E2M1FNType>(getSrcType()))
return emitOpError("Only ")
<< mlir::Float4E2M1FNType::get(ctx)
<< " type is supported for conversions from f4x2 to f16x2.";

return success();
}

LogicalResult BulkStoreOp::verify() {
if (getInitVal() != 0)
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
Expand Down Expand Up @@ -2055,6 +2100,98 @@ ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
}
}

NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto curOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op);

bool hasRelu = curOp.getRelu();

llvm::Intrinsic::ID intId =
llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
.Case<Float8E4M3FNType>([&](Float8E4M3FNType type) {
return hasRelu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu
: llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn;
})
.Case<Float8E5M2Type>([&](Float8E5M2Type type) {
return hasRelu ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu
: llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn;
})
.Default([](mlir::Type type) {
llvm_unreachable("Invalid type for ConvertF8x2ToF16x2Op");
return llvm::Intrinsic::not_intrinsic;
});

llvm::Value *packedI16 =
builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
llvm::Type::getInt16Ty(builder.getContext()));

return {intId, {packedI16}};
}

NVVM::IDArgPair ConvertF8x2ToBF16x2Op::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto curOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op);

llvm::Intrinsic::ID intId = llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2;
llvm::Value *packedI16 =
builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
llvm::Type::getInt16Ty(builder.getContext()));

return {intId, {packedI16}};
}

NVVM::IDArgPair ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto curOp = cast<NVVM::ConvertF6x2ToF16x2Op>(op);

bool hasRelu = curOp.getRelu();

llvm::Intrinsic::ID intId =
llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
.Case<Float6E2M3FNType>([&](Float6E2M3FNType type) {
return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu
: llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn;
})
.Case<Float6E3M2FNType>([&](Float6E3M2FNType type) {
return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu
: llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn;
})
.Default([](mlir::Type type) {
llvm_unreachable("Invalid type for ConvertF6x2ToF16x2Op");
return llvm::Intrinsic::not_intrinsic;
});

llvm::Value *packedI16 =
builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
llvm::Type::getInt16Ty(builder.getContext()));

return {intId, {packedI16}};
}

NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto curOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op);

bool hasRelu = curOp.getRelu();

llvm::Intrinsic::ID intId =
llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
.Case<Float4E2M1FNType>([&](Float4E2M1FNType type) {
return hasRelu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu
: llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn;
})
.Default([](mlir::Type type) {
llvm_unreachable("Invalid type for ConvertF4x2ToF16x2Op");
return llvm::Intrinsic::not_intrinsic;
});

llvm::Value *extendedI16 =
builder.CreateZExt(mt.lookupValue(curOp.getSrc()),
llvm::Type::getInt16Ty(builder.getContext()));

return {intId, {extendedI16}};
}

llvm::Intrinsic::ID
Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
LLVM::ModuleTranslation &mt,
Expand Down
14 changes: 14 additions & 0 deletions mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s

// -----

// CHECK-LABEL: @convert_f4x2_to_f16x2
llvm.func @convert_f4x2_to_f16x2(%src : i8) {
// CHECK: %[[res1:.*]] = zext i8 %{{.*}} to i16
// CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn(i16 %[[res1]])
%res1 = nvvm.convert.f4x2.to.f16x2 %src : i8 (f4E2M1FN)-> vector<2xf16>
// CHECK: %[[res2:.*]] = zext i8 %{{.*}} to i16
// CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn.relu(i16 %[[res2]])
%res2 = nvvm.convert.f4x2.to.f16x2 %src {relu = true} : i8 (f4E2M1FN)-> vector<2xf16>
llvm.return
}
24 changes: 24 additions & 0 deletions mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,27 @@ llvm.func @convert_f32x2_to_fp6x2_vector(%srcA : f32, %srcB : f32) {
%res2 = nvvm.convert.f32x2.to.f6x2 <e3m2> %srcA, %srcB : vector<2xi8>
llvm.return
}

// -----

// CHECK-LABEL: @convert_f6x2_to_f16x2_e2m3
llvm.func @convert_f6x2_to_f16x2_e2m3(%src : vector<2xi8>) {
// CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
// CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e2m3x2.to.f16x2.rn(i16 %[[res1]])
%res1 = nvvm.convert.f6x2.to.f16x2 %src : vector<2xi8> (f6E2M3FN)-> vector<2xf16>
// CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16
// CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e2m3x2.to.f16x2.rn.relu(i16 %[[res2]])
%res2 = nvvm.convert.f6x2.to.f16x2 %src {relu = true} : vector<2xi8> (f6E2M3FN)-> vector<2xf16>
llvm.return
}

// CHECK-LABEL: @convert_f6x2_to_f16x2_e3m2
llvm.func @convert_f6x2_to_f16x2_e3m2(%src : vector<2xi8>) {
// CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
// CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e3m2x2.to.f16x2.rn(i16 %[[res1]])
%res1 = nvvm.convert.f6x2.to.f16x2 %src : vector<2xi8> (f6E3M2FN)-> vector<2xf16>
// CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16
// CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e3m2x2.to.f16x2.rn.relu(i16 %[[res2]])
%res2 = nvvm.convert.f6x2.to.f16x2 %src {relu = true} : vector<2xi8> (f6E3M2FN)-> vector<2xf16>
llvm.return
}
34 changes: 34 additions & 0 deletions mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,37 @@ llvm.func @convert_bf16x2_to_f8x2_vector_return(%src : vector<2xbf16>) {
%res2 = nvvm.convert.bf16x2.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> vector<2xi8>
llvm.return
}

// -----

// CHECK-LABEL: @convert_f8x2_to_f16x2
llvm.func @convert_f8x2_to_f16x2_e4m3(%src : vector<2xi8>) {
// CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
// CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e4m3x2.to.f16x2.rn(i16 %[[res1]])
%res1 = nvvm.convert.f8x2.to.f16x2 %src : vector<2xi8> (f8E4M3FN)-> vector<2xf16>
// CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16
// CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e4m3x2.to.f16x2.rn.relu(i16 %[[res2]])
%res2 = nvvm.convert.f8x2.to.f16x2 %src {relu = true} : vector<2xi8> (f8E4M3FN)-> vector<2xf16>
llvm.return
}

// CHECK-LABEL: @convert_f8x2_to_f16x2_e5m2
llvm.func @convert_f8x2_to_f16x2_e5m2(%src : vector<2xi8>) {
// CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
// CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn(i16 %[[res1]])
%res1 = nvvm.convert.f8x2.to.f16x2 %src : vector<2xi8> (f8E5M2)-> vector<2xf16>
// CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16
// CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn.relu(i16 %[[res2]])
%res2 = nvvm.convert.f8x2.to.f16x2 %src {relu = true} : vector<2xi8> (f8E5M2)-> vector<2xf16>
llvm.return
}

// -----

// CHECK-LABEL: @convert_f8x2_to_bf16x2_ue8m0
llvm.func @convert_f8x2_to_bf16x2_ue8m0(%src : vector<2xi8>) {
// CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
// CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ue8m0x2.to.bf16x2(i16 %[[res1]])
%res1 = nvvm.convert.f8x2.to.bf16x2 %src : vector<2xi8> (f8E8M0FNU)-> vector<2xbf16>
llvm.return
}
32 changes: 32 additions & 0 deletions mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,38 @@ llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding(%src : vector<2xbf16>) {

// -----

llvm.func @nvvm_cvt_f8x2_to_f16x2_invalid_type(%src : vector<2xi8>) {
// expected-error @below {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f8x2 to f16x2.}}
%res = nvvm.convert.f8x2.to.f16x2 %src : vector<2xi8> (f8E4M3) -> vector<2xf16>
llvm.return
}

// -----

llvm.func @nvvm_cvt_f8x2_to_bf16x2_invalid_type(%src : vector<2xi8>) {
// expected-error @below {{Only 'f8E8M0FNU' type is supported for conversions from f8x2 to bf16x2.}}
%res = nvvm.convert.f8x2.to.bf16x2 %src : vector<2xi8> (f8E4M3FN) -> vector<2xbf16>
llvm.return
}

// -----

llvm.func @nvvm_cvt_f6x2_to_f16x2_invalid_type(%src : vector<2xi8>) {
// expected-error @below {{Only 'f6E2M3FN' and 'f6E3M2FN' types are supported for conversions from f6x2 to f16x2.}}
%res = nvvm.convert.f6x2.to.f16x2 %src : vector<2xi8> (f8E4M3FN) -> vector<2xf16>
llvm.return
}

// -----

llvm.func @nvvm_cvt_f4x2_to_f16x2_invalid_type(%src : i8) {
// expected-error @below {{Only 'f4E2M1FN' type is supported for conversions from f4x2 to f16x2.}}
%res = nvvm.convert.f4x2.to.f16x2 %src : i8 (f6E2M3FN) -> vector<2xf16>
llvm.return
}

// -----

llvm.func @nvvm_prefetch_L1_with_evict_priority(%global_ptr: !llvm.ptr<1>) {
// expected-error @below {{cache eviction priority supported only for cache level L2}}
nvvm.prefetch level = L1, evict_priority = evict_last, %global_ptr : !llvm.ptr<1>
Expand Down