-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR][NVVM] Add support for converting fp4/6/8 to fp16x2 #162439
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1855,6 +1855,55 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> { | |
}]; | ||
} | ||
|
||
// Base class for the NVVM ops that convert a pair of narrow floating-point | ||
// inputs to a 2-element f16 or bf16 vector. | ||
//   srcType    - source format name ("f4", "f6" or "f8"); used in the op | ||
//                mnemonic, in the summary/description and, upper-cased, in the | ||
//                name of the C++ op class that `llvmBuilder` dispatches to. | ||
//   srcArgType - type constraint of the `$src` operand. | ||
//   dstType    - name of the destination element Type record ("F16" or | ||
//                "BF16"); only "F16" adds the optional `relu` attribute. | ||
class NVVM_ConvertToFP16x2Op_Base <string srcType, Type srcArgType, string dstType> | ||
: NVVM_Op<"convert." # srcType # "x2.to." # !tolower(dstType) # "x2"> { | ||
let summary = "Convert a pair of " # srcType # " inputs to " # !tolower(dstType) # "x2"; | ||
let description = [{ | ||
This Op converts the given }] # srcType # [{ inputs in }] # | ||
!if(!eq(srcType, "f4"), "a packed i8", "an i8x2 vector") # [{ to }] # | ||
!tolower(dstType) # [{. | ||
|
||
The result `dst` is represented as a vector of }] # !tolower(dstType) # [{ elements. | ||
}] # | ||
!if(!eq(dstType, "F16"), | ||
[{The `relu` attribute, when set, lowers to the '.relu' variant of | ||
the cvt instruction.}], "") # [{ | ||
|
||
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) | ||
}]; | ||
let results = (outs VectorOfLengthAndType<[2], [!cast<Type>(dstType)]>:$dst); | ||
// The `relu` attribute is declared only when the destination is "F16". | ||
let arguments = !if(!eq(dstType, "F16"), | ||
(ins srcArgType:$src, | ||
DefaultValuedAttr<BoolAttr, "false">:$relu, | ||
TypeAttr:$srcType), | ||
(ins srcArgType:$src, | ||
TypeAttr:$srcType)); | ||
let assemblyFormat = "$src attr-dict `:` type($src) `(` $srcType `)` `->` type($dst)"; | ||
let hasVerifier = 1; | ||
|
||
let extraClassDeclaration = [{ | ||
static IDArgPair | ||
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, | ||
llvm::IRBuilderBase &builder); | ||
}]; | ||
|
||
// Translation asks the concrete op class for the intrinsic ID and its | ||
// arguments, then emits a single intrinsic call whose result becomes `$dst`. | ||
string llvmBuilder = [{ | ||
auto [intId, args] = | ||
NVVM::Convert}] # !toupper(srcType) # [{x2To}] # dstType # | ||
[{x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder); | ||
$dst = createIntrinsicCall(builder, intId, args); | ||
}]; | ||
} | ||
|
||
// f8x2 -> f16x2: the source is a 2-element i8 vector. The destination is | ||
// "F16", so the base class also declares the optional `relu` attribute. | ||
def NVVM_ConvertF8x2ToF16x2Op : | ||
NVVM_ConvertToFP16x2Op_Base<"f8", VectorOfLengthAndType<[2], [I8]>, "F16">; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Just curious, There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, we'd have to do some string truncation in that case, which I think would not be as clean as it currently is. |
||
// f8x2 -> bf16x2: the source is a 2-element i8 vector. The destination is | ||
// "BF16", so the base class declares no `relu` attribute for this op. | ||
def NVVM_ConvertF8x2ToBF16x2Op : | ||
NVVM_ConvertToFP16x2Op_Base<"f8", VectorOfLengthAndType<[2], [I8]>, "BF16">; | ||
// f6x2 -> f16x2: the source is a 2-element i8 vector; `relu` is available. | ||
def NVVM_ConvertF6x2ToF16x2Op : | ||
NVVM_ConvertToFP16x2Op_Base<"f6", VectorOfLengthAndType<[2], [I8]>, "F16">; | ||
// f4x2 -> f16x2: the source is a single i8 (described by the base class as a | ||
// "packed i8"); `relu` is available. | ||
def NVVM_ConvertF4x2ToF16x2Op : | ||
NVVM_ConvertToFP16x2Op_Base<"f4", I8, "F16">; | ||
|
||
//===----------------------------------------------------------------------===// | ||
// NVVM MMA Ops | ||
//===----------------------------------------------------------------------===// | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s | ||
|
||
// ----- | ||
|
||
// Translation test for nvvm.convert.f4x2.to.f16x2 with an f4E2M1FN source | ||
// type, covering both the default form and the form with `relu = true`. | ||
// CHECK-LABEL: @convert_f4x2_to_f16x2 | ||
llvm.func @convert_f4x2_to_f16x2(%src : i8) { | ||
// Default form (no relu attribute given): the i8 source is expected to be | ||
// zero-extended to i16 and passed to the e2m1x2.to.f16x2.rn intrinsic. | ||
// CHECK: %[[res1:.*]] = zext i8 %{{.*}} to i16 | ||
// CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn(i16 %[[res1]]) | ||
%res1 = nvvm.convert.f4x2.to.f16x2 %src : i8 (f4E2M1FN)-> vector<2xf16> | ||
// With relu = true: the same zero-extension is expected, followed by a call | ||
// to the .rn.relu variant of the intrinsic. | ||
// CHECK: %[[res2:.*]] = zext i8 %{{.*}} to i16 | ||
// CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn.relu(i16 %[[res2]]) | ||
%res2 = nvvm.convert.f4x2.to.f16x2 %src {relu = true} : i8 (f4E2M1FN)-> vector<2xf16> | ||
llvm.return | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am wondering whether we need `TypeAttr`. Isn't that clear from the name of the op?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this case, this `TypeAttr` is for the source type, since it can be either `e4m3` (`f8E4M3FN`) or `e5m2` (`f8E5M2`).