Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 20 additions & 43 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Dialect/LLVMIR/LLVMTypes.td"
include "mlir/IR/CommonAttrConstraints.td"

def LLVM_PointerGeneric : LLVM_PointerInAddressSpace<0>;
def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
Expand Down Expand Up @@ -1258,18 +1259,6 @@ def NVVM_ConvertFloatToTF32Op : NVVM_Op<"convert.float.to.tf32"> {
}];
}

def ConvertFP6E2M3 : I32EnumAttrCase<"E2M3", 0, "e2m3">;
def ConvertFP6E3M2 : I32EnumAttrCase<"E3M2", 1, "e3m2">;

def ConvertFP6Type : I32EnumAttr<"ConvertFP6Type", "NVVM ConvertFP6Type kind",
[ConvertFP6E2M3, ConvertFP6E3M2]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::NVVM";
}
def ConvertFP6TypeAttr : EnumAttr<NVVM_Dialect, ConvertFP6Type, "convert_fp6_type"> {
let assemblyFormat = "`<` $value `>`";
}

def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
let summary = "Convert a pair of float inputs to f6x2";
let description = [{
Expand All @@ -1290,19 +1279,20 @@ def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {

let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
let arguments = (ins
ConvertFP6TypeAttr:$type,
F32:$a,
F32:$b,
DefaultValuedAttr<BoolAttr, "false">:$relu);
let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)";
DefaultValuedAttr<BoolAttr, "false">:$relu,
TypeAttr:$dstTy);
let assemblyFormat = "$a `,` $b attr-dict `:` type($dst) `(` $dstTy `)`";
let hasVerifier = 1;

let extraClassDeclaration = [{
static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP6Type,
static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy,
bool hasRelu);
}];

string llvmBuilder = [{
auto intId = NVVM::ConvertF32x2ToF6x2Op::getIntrinsicID($type, $relu);
auto intId = NVVM::ConvertF32x2ToF6x2Op::getIntrinsicID($dstTy, $relu);
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
if(op.getDst().getType().isInteger(16))
$dst = packedI16;
Expand All @@ -1312,19 +1302,6 @@ def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
}];
}

def ConvertFP8E4M3 : I32EnumAttrCase<"E4M3", 0, "e4m3">;
def ConvertFP8E5M2 : I32EnumAttrCase<"E5M2", 1, "e5m2">;
def ConvertFP8UE8M0 : I32EnumAttrCase<"UE8M0", 2, "ue8m0">;

def ConvertFP8Type : I32EnumAttr<"ConvertFP8Type", "NVVM ConvertFP8Type kind",
[ConvertFP8E4M3, ConvertFP8E5M2, ConvertFP8UE8M0]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::NVVM";
}
def ConvertFP8TypeAttr : EnumAttr<NVVM_Dialect, ConvertFP8Type, "convert_fp8_type"> {
let assemblyFormat = "`<` $value `>`";
}

def NVVM_ConvertF32x2ToF8x2Op : NVVM_Op<"convert.f32x2.to.f8x2"> {
let summary = "Convert a pair of float inputs to f8x2";
let description = [{
Expand All @@ -1346,23 +1323,23 @@ def NVVM_ConvertF32x2ToF8x2Op : NVVM_Op<"convert.f32x2.to.f8x2"> {
let hasVerifier = 1;
let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
let arguments = (ins
ConvertFP8TypeAttr:$type,
F32:$a,
F32:$b,
DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
DefaultValuedAttr<BoolAttr, "false">:$relu);
let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)";
DefaultValuedAttr<BoolAttr, "false">:$relu,
TypeAttr:$dstTy);
let assemblyFormat = "$a `,` $b attr-dict `:` type($dst) `(` $dstTy `)`";

let extraClassDeclaration = [{
static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP8Type to,
static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy,
NVVM::FPRoundingMode rnd,
NVVM::SaturationMode sat,
bool hasRelu);
}];

string llvmBuilder = [{
auto intId = NVVM::ConvertF32x2ToF8x2Op::getIntrinsicID($type, $rnd, $sat, $relu);
auto intId = NVVM::ConvertF32x2ToF8x2Op::getIntrinsicID($dstTy, $rnd, $sat, $relu);
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
if(op.getDst().getType().isInteger(16))
$dst = packedI16;
Expand Down Expand Up @@ -1394,18 +1371,18 @@ def NVVM_ConvertF16x2ToF8x2Op : NVVM_Op<"convert.f16x2.to.f8x2"> {
let hasVerifier = 1;
let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
let arguments = (ins
ConvertFP8TypeAttr:$type,
VectorOfLengthAndType<[2], [F16]>:$a,
DefaultValuedAttr<BoolAttr, "false">:$relu);
let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)";
DefaultValuedAttr<BoolAttr, "false">:$relu,
TypeAttr:$dstTy);
let assemblyFormat = "$a attr-dict `:` type($a) `->` type($dst) `(` $dstTy `)`";

let extraClassDeclaration = [{
static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP8Type to,
static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy,
bool hasRelu);
}];

string llvmBuilder = [{
auto intId = NVVM::ConvertF16x2ToF8x2Op::getIntrinsicID($type, $relu);
auto intId = NVVM::ConvertF16x2ToF8x2Op::getIntrinsicID($dstTy, $relu);
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
if(op.getDst().getType().isInteger(16))
$dst = packedI16;
Expand Down Expand Up @@ -1437,11 +1414,11 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> {
let hasVerifier = 1;
let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
let arguments = (ins
ConvertFP8TypeAttr:$type,
VectorOfLengthAndType<[2], [BF16]>:$a,
DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat);
let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)";
DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
TypeAttr:$dstTy);
let assemblyFormat = "$a attr-dict `:` type($a) `->` type($dst) `(` $dstTy `)`";

let extraClassDeclaration = [{
static llvm::Intrinsic::ID getIntrinsicID(NVVM::FPRoundingMode rnd,
Expand Down
167 changes: 109 additions & 58 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,18 @@ LogicalResult ConvertFloatToTF32Op::verify() {
return success();
}

LogicalResult ConvertF32x2ToF6x2Op::verify() {
mlir::MLIRContext *ctx = getContext();

if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
return emitOpError("Only ")
<< mlir::Float6E2M3FNType::get(ctx) << " and "
<< mlir::Float6E3M2FNType::get(ctx)
<< " types are supported for conversions from f32x2 to f6x2.";
Copy link
Contributor

@durga4github durga4github Sep 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Not for this PR]:
I am wondering whether there is a way to specify this type constraint (line 195) on the Op itself (in TableGen).
That way, the default-case checks could all happen within the .td file.

}
return success();
}

LogicalResult ConvertF32x2ToF8x2Op::verify() {
using RndMode = NVVM::FPRoundingMode;
using SatMode = NVVM::SaturationMode;
Expand All @@ -200,41 +212,67 @@ LogicalResult ConvertF32x2ToF8x2Op::verify() {

bool hasRelu = getRelu();

switch (getType()) {
case ConvertFP8Type::E4M3:
case ConvertFP8Type::E5M2:
if (!isRoundingModeRN)
return emitOpError("Only RN rounding mode is supported for conversions "
"from f32x2 to .e4m3x2 or .e5m2x2 types");
if (!isSatFinite)
return emitOpError("Only SATFINITE saturation mode is supported for "
"conversions from f32x2 to .e4m3x2 or .e5m2x2 types");
break;
case ConvertFP8Type::UE8M0:
if (!(isRoundingModeRZ || isRoundingModeRP))
return emitOpError("Only RZ or RP rounding modes are supported for "
"conversions from f32x2 to .ue8m0x2 type");
if (hasRelu)
return emitOpError("relu not supported for conversions to .ue8m0x2 type");
break;
}
return success();
mlir::MLIRContext *ctx = getContext();

return llvm::TypeSwitch<mlir::Type, LogicalResult>(getDstTy())
.Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
[&](mlir::Type) -> LogicalResult {
if (!isRoundingModeRN) {
return emitOpError("Only RN rounding mode is supported for "
"conversions from f32x2 to ")
<< mlir::Float8E4M3FNType::get(ctx) << " and "
<< mlir::Float8E5M2Type::get(ctx) << " types";
}
if (!isSatFinite) {
return emitOpError("Only SATFINITE saturation mode is supported "
"for conversions "
"from f32x2 to ")
<< mlir::Float8E4M3FNType::get(ctx) << " and "
<< mlir::Float8E5M2Type::get(ctx) << " types";
}
return success();
})
.Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
if (!(isRoundingModeRZ || isRoundingModeRP)) {
return emitOpError("Only RZ and RP rounding modes are supported for "
"conversions from f32x2 to ")
<< mlir::Float8E8M0FNUType::get(ctx) << " type";
}
if (hasRelu) {
return emitOpError("relu not supported for conversions to ")
<< mlir::Float8E8M0FNUType::get(ctx) << " type";
}
return success();
})
.Default([&](mlir::Type) {
return emitOpError("Only ")
<< mlir::Float8E4M3FNType::get(ctx) << ", "
<< mlir::Float8E5M2Type::get(ctx) << ", and "
<< mlir::Float8E8M0FNUType::get(ctx)
<< " types are "
"supported for conversions from f32x2 to f8x2";
});
}

LogicalResult ConvertF16x2ToF8x2Op::verify() {
if (getType() == ConvertFP8Type::UE8M0)
return emitOpError("Only .e4m3 or .e5m2 types are supported for "
"conversions from f16x2 to f8x2.");
mlir::MLIRContext *ctx = getContext();

if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) {
return emitOpError("Only ")
<< mlir::Float8E4M3FNType::get(ctx) << " and "
<< mlir::Float8E5M2Type::get(ctx)
<< " types are supported for conversions from f16x2 to f8x2.";
}
return success();
}

LogicalResult ConvertBF16x2ToF8x2Op::verify() {
using RndMode = NVVM::FPRoundingMode;

if (getType() != ConvertFP8Type::UE8M0)
return emitOpError(
"Only .ue8m0 type is supported for conversions from bf16x2 to f8x2.");
if (!llvm::isa<mlir::Float8E8M0FNUType>(getDstTy()))
return emitOpError("Only ") << mlir::Float8E8M0FNUType::get(getContext())
<< " type is supported for conversions from "
"bf16x2 to f8x2.";

auto rnd = getRnd();
if (!(rnd == RndMode::RZ || rnd == RndMode::RP))
Expand Down Expand Up @@ -1714,15 +1752,19 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
: llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite

llvm::Intrinsic::ID
ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) {
switch (type) {
case NVVM::ConvertFP6Type::E2M3:
return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
case NVVM::ConvertFP6Type::E3M2:
return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
}
llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
bool hasRelu) {
return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
.Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
})
.Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
})
.Default([](mlir::Type) {
llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
return llvm::Intrinsic::not_intrinsic;
});
}

#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
Expand All @@ -1734,41 +1776,50 @@ ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) {
: llvm::Intrinsic::nvvm_ff_to_##type##_rn

llvm::Intrinsic::ID
ConvertF32x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type,
NVVM::FPRoundingMode rnd,
ConvertF32x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd,
NVVM::SaturationMode sat, bool hasRelu) {
bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);

switch (type) {
case NVVM::ConvertFP8Type::E4M3:
return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
case NVVM::ConvertFP8Type::E5M2:
return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
case NVVM::ConvertFP8Type::UE8M0:
if (hasRoundingModeRZ)
return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
else if (hasRoundingModeRP)
return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
}
llvm_unreachable("Invalid conversion in CvtFloatToF8x2Op");
return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
.Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
})
.Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
})
.Case<mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
if (hasRoundingModeRZ)
return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
else if (hasRoundingModeRP)
return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);

llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
})
.Default([](mlir::Type) {
llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
return llvm::Intrinsic::not_intrinsic;
});
}

#define GET_F16x2_TO_F8X2_ID(type, has_relu) \
has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
: llvm::Intrinsic::nvvm_f16x2_to_##type##_rn

llvm::Intrinsic::ID
ConvertF16x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type, bool hasRelu) {
switch (type) {
case NVVM::ConvertFP8Type::E4M3:
return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
case NVVM::ConvertFP8Type::E5M2:
return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
default:
llvm_unreachable("Invalid ConvertFP8Type for CvtF16x2ToF8x2Op");
}
llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
bool hasRelu) {
return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
.Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
})
.Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
})
.Default([](mlir::Type) {
llvm_unreachable("Invalid conversion in ConvertF16x2ToF8x2Op");
return llvm::Intrinsic::not_intrinsic;
});
}

#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \
Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@
// CHECK-LABEL: @convert_f32x2_to_fp6x2_packed
llvm.func @convert_f32x2_to_fp6x2_packed(%srcA : f32, %srcB : f32) {
//CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
%res1 = nvvm.convert.f32x2.to.f6x2 <e2m3> %srcA, %srcB : i16
%res1 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 (f6E2M3FN)
//CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
%res2 = nvvm.convert.f32x2.to.f6x2 <e3m2> %srcA, %srcB : i16
%res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 (f6E3M2FN)
llvm.return
}

// CHECK-LABEL: @convert_f32x2_to_fp6x2_vector
llvm.func @convert_f32x2_to_fp6x2_vector(%srcA : f32, %srcB : f32) {
//CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
//CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
%res1 = nvvm.convert.f32x2.to.f6x2 <e2m3> %srcA, %srcB : vector<2xi8>
%res1 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : vector<2xi8> (f6E2M3FN)
//CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
//CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
%res2 = nvvm.convert.f32x2.to.f6x2 <e3m2> %srcA, %srcB : vector<2xi8>
%res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : vector<2xi8> (f6E3M2FN)
llvm.return
}
Loading