Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 90 additions & 12 deletions llvm/include/llvm/IR/IntrinsicsNVVM.td
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,10 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
!eq(gft,"m16n8k16:d:f32") : !listsplat(llvm_float_ty, 4),
!eq(gft,"m16n8k4:c:f32") : !listsplat(llvm_float_ty, 4),
!eq(gft,"m16n8k4:d:f32") : !listsplat(llvm_float_ty, 4),
!eq(gft,"m16n8k32:c:f16") : !listsplat(llvm_v2f16_ty, 2),
!eq(gft,"m16n8k32:c:f32") : !listsplat(llvm_float_ty, 4),
!eq(gft,"m16n8k32:d:f16") : !listsplat(llvm_v2f16_ty, 2),
!eq(gft,"m16n8k32:d:f32") : !listsplat(llvm_float_ty, 4),

// wmma fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16
// All other supported geometries use the same fragment format for f32 and
Expand All @@ -298,6 +302,21 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
!eq(gft,"m8n8k4:c:f64") : !listsplat(llvm_double_ty, 2),
!eq(gft,"m8n8k4:d:f64") : !listsplat(llvm_double_ty, 2),

!eq(gft,"m16n8k4:a:f64") : !listsplat(llvm_double_ty, 2),
!eq(gft,"m16n8k4:b:f64") : [llvm_double_ty],
!eq(gft,"m16n8k4:c:f64") : !listsplat(llvm_double_ty, 4),
!eq(gft,"m16n8k4:d:f64") : !listsplat(llvm_double_ty, 4),

!eq(gft,"m16n8k8:a:f64") : !listsplat(llvm_double_ty, 4),
!eq(gft,"m16n8k8:b:f64") : !listsplat(llvm_double_ty, 2),
!eq(gft,"m16n8k8:c:f64") : !listsplat(llvm_double_ty, 4),
!eq(gft,"m16n8k8:d:f64") : !listsplat(llvm_double_ty, 4),

!eq(gft,"m16n8k16:a:f64") : !listsplat(llvm_double_ty, 8),
!eq(gft,"m16n8k16:b:f64") : !listsplat(llvm_double_ty, 4),
!eq(gft,"m16n8k16:c:f64") : !listsplat(llvm_double_ty, 4),
!eq(gft,"m16n8k16:d:f64") : !listsplat(llvm_double_ty, 4),

// wmma bf16 -> s32 @ m16n16k16/m8n32k16/m32n8k16
!eq(gft,"m16n16k16:a:bf16") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n16k16:b:bf16") : !listsplat(llvm_i32_ty, 4),
Expand Down Expand Up @@ -378,6 +397,26 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
!eq(gft,"m16n8k64:c:s32") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n8k64:d:s32") : !listsplat(llvm_i32_ty, 4),

// mma e4m3/e5m2 -> f16/f32 @ m16n8k16
!eq(gft,"m16n8k16:a:e4m3") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n8k16:a:e5m2") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n8k16:b:e4m3") : [llvm_i32_ty],
!eq(gft,"m16n8k16:b:e5m2") : [llvm_i32_ty],
// mma e4m3/e5m2/e3m2/e2m3/e2m1 -> f32 @ m16n8k32
!eq(gft,"m16n8k32:a:e4m3") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n8k32:a:e5m2") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n8k32:a:e3m2") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n8k32:a:e2m3") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n8k32:a:e2m1") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n8k32:b:e4m3") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n8k32:b:e5m2") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n8k32:b:e3m2") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n8k32:b:e2m3") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n8k32:b:e2m1") : !listsplat(llvm_i32_ty, 2),
// mma e2m1 -> f32 @m16n8k64
!eq(gft,"m16n8k64:a:e2m1") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n8k64:b:e2m1") : !listsplat(llvm_i32_ty, 2),

// wmma/mma b1 -> s32 @ m8n8k128(b1)
!eq(gft,"m8n8k128:a:b1") : [llvm_i32_ty],
!eq(gft,"m8n8k128:b:b1") : [llvm_i32_ty],
Expand Down Expand Up @@ -468,14 +507,15 @@ class WMMA_NAME<string ALayout, string BLayout, int Satfinite, string Rnd, strin
# !if(Satfinite, "_satfinite", "");
}

class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op,
class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op, string Kind,
WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
string signature = MMA_SIGNATURE<A, B, C, D>.ret;
string record = "int_nvvm_mma"
# !subst(".", "_", b1op)
# "_" # A.geom
# "_" # ALayout
# "_" # BLayout
# !if(!ne(Kind, ""), !strconcat("_", !subst("::", "_", Kind)), "")
# !if(Satfinite, "_satfinite", "")
# signature;
}
Expand Down Expand Up @@ -601,14 +641,26 @@ class NVVM_MMA_OPS {
["m16n8k16", "m16n8k8"],
["bf16"], [], ["f32"], []>.ret;
list<list<WMMA_REGS>> f64_mma_ops = MMA_OPS<
["m8n8k4"],
["m8n8k4", "m16n8k4", "m16n8k8", "m16n8k16"],
["f64"], [], ["f64"], []>.ret;
list<list<WMMA_REGS>> fp_mma_ops = MMA_OPS<
["m8n8k4", "m16n8k8", "m16n8k16"],
["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
list<list<WMMA_REGS>> int_mma_ops = MMA_OPS<
["m8n8k16", "m16n8k16", "m16n8k32"],
["s8", "u8"], ["s8", "u8"], ["s32"], []>.ret;
// m16n8k32 fp8 variants are intersected with f8f6f4 variants
// and processed there
list<list<WMMA_REGS>> fp8_mma_ops = MMA_OPS<
["m16n8k16"],
["e4m3", "e5m2"], ["e4m3", "e5m2"],
["f16", "f32"], ["f16", "f32"]>.ret;
// it also contains e4m3/e5m2 from fp8 variants
list<list<WMMA_REGS>> f8f6f4_mma_ops = MMA_OPS<
["m16n8k32"],
["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
["f16", "f32"], ["f16", "f32"]>.ret;
list<list<WMMA_REGS>> subint_mma_ops = MMA_OPS<
["m8n8k32", "m16n8k32", "m16n8k64"],
["s4", "u4"], ["s4", "u4"], ["s32"], []>.ret;
Expand All @@ -617,7 +669,8 @@ class NVVM_MMA_OPS {
["b1"], [], ["s32"], []>.ret;
list<list<WMMA_REGS>> all_mma_ops = !listconcat(
tf32_mma_ops, bf16_mma_ops, f64_mma_ops,
fp_mma_ops, int_mma_ops, subint_mma_ops, bit_mma_ops);
fp_mma_ops, fp8_mma_ops, f8f6f4_mma_ops,
int_mma_ops, subint_mma_ops, bit_mma_ops);

list<list<WMMA_REGS>> bf16_mma_sp_ops = MMA_OPS<
["m16n8k16", "m16n8k32"],
Expand Down Expand Up @@ -770,7 +823,8 @@ class NVVM_MMA_B1OPS<list<WMMA_REGS> frags> {
// if NVVM_MMA_SUPPORTED<...>.ret then
// def : FOO<>; // The record will only be defined for supported ops.
//
class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b, int satf> {
class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b,
string kind, int satf> {
// MMA ops check both layouts.
string layout = layout_a # ":" # layout_b;
string a_type = frags[0].ptx_elt_type;
Expand Down Expand Up @@ -805,10 +859,31 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
!or(!ne(a_type, b_type),
!ne(c_type, d_type))): false,

// m16n8k8 requires C and D to be the same type.
!and(!eq(geom, "m16n8k8"),
// m16n8k16/m16n8k32 requires C and D to be the same type
!and(!or(!eq(geom, "m16n8k16"),
!eq(geom, "m16n8k32")),
!ne(c_type, d_type)): false,

// Limit kind to valid types and geometries
!and(!ne(kind, ""),
!or(!ne(geom, "m16n8k32"),
!and(!ne(a_type, "e4m3"),
!ne(a_type, "e5m2"),
!ne(a_type, "e3m2"),
!ne(a_type, "e2m3"),
!ne(a_type, "e2m1")))): false,

// Limit m16n8k16/m16n8k32 with no kind to valid types
!and(!eq(kind, ""),
!or(!eq(geom, "m16n8k16"),
!eq(geom, "m16n8k32")),
!or(!eq(a_type, "e3m2"),
!eq(a_type, "e2m3"),
!eq(a_type, "e2m1"),
!eq(b_type, "e3m2"),
!eq(b_type, "e2m3"),
!eq(b_type, "e2m1"))): false,

// All other are OK.
true: true
);
Expand Down Expand Up @@ -882,9 +957,10 @@ class NVVM_MMA_SP_SUPPORTED<list<WMMA_REGS> frags, string metadata,
!eq(a_type, "tf32")),
!ne(a_type, b_type)): false,

// m16n8k16 and m16n8k32 requires C and D to be the same type.
// m16n8k16, m16n8k32 and m16n8k64 requires C and D to be the same type.
!and(!or(!eq(geom, "m16n8k16"),
!eq(geom, "m16n8k32")),
!eq(geom, "m16n8k32"),
!eq(geom, "m16n8k64")),
!ne(c_type, d_type)): false,

!and(!eq(kind, ""),
Expand Down Expand Up @@ -2215,10 +2291,12 @@ foreach layout_a = ["row", "col"] in {
foreach satf = [0, 1] in {
foreach op = NVVM_MMA_OPS.all_mma_ops in {
foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
def MMA_NAME<layout_a, layout_b, satf, b1op, op[0], op[1], op[2], op[3]>.record
: NVVM_MMA<op[0], op[1], op[2], op[3]>;
}
foreach kind = ["", "kind::f8f6f4"] in {
if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, kind, satf>.ret then {
def MMA_NAME<layout_a, layout_b, satf, b1op, kind, op[0], op[1], op[2], op[3]>.record
: NVVM_MMA<op[0], op[1], op[2], op[3]>;
}
} // kind
} // b1op
} // op
} // satf
Expand Down
30 changes: 21 additions & 9 deletions llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
Original file line number Diff line number Diff line change
Expand Up @@ -4461,6 +4461,10 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "
!eq(ptx_elt_type, "e2m1"),
!ne(kind, "")) : [hasSM120a, hasPTX<87>],

!and(!or(!eq(ptx_elt_type,"e4m3"),
!eq(ptx_elt_type,"e5m2")),
!eq(geom, "m16n8k16")) : [hasSM<89>, hasPTX<87>],

!or(!eq(ptx_elt_type, "e4m3"),
!eq(ptx_elt_type, "e5m2")) : [hasSM<89>, hasPTX<84>],

Expand All @@ -4476,6 +4480,11 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "
!and(!eq(geom, "m8n8k4"),
!eq(ptx_elt_type, "f64")) : [hasSM<80>, hasPTX<70>],

!and(!or(!eq(geom, "m16n8k4"),
!eq(geom, "m16n8k8"),
!eq(geom, "m16n8k16")),
!eq(ptx_elt_type, "f64")) : [hasSM<90>, hasPTX<78>],

// fp16 -> fp16/fp32 @ m8n32k16/m32n8k16
!and(!or(!eq(geom, "m8n32k16"),
!eq(geom, "m32n8k16")),
Expand Down Expand Up @@ -4760,8 +4769,8 @@ defset list<WMMA_INSTR> WMMAs = {
// MMA
class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
WMMA_REGINFO FragC, WMMA_REGINFO FragD,
string ALayout, string BLayout, int Satfinite, string b1op>
: WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, b1op, FragA, FragB, FragC, FragD>.record,
string ALayout, string BLayout, int Satfinite, string b1op, string Kind>
: WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, b1op, Kind, FragA, FragB, FragC, FragD>.record,
[FragA.Ins, FragB.Ins, FragC.Ins]>,
// Requires does not seem to have effect on Instruction w/o Patterns.
// We set it here anyways and propagate to the Pat<> we construct below.
Expand All @@ -4776,6 +4785,7 @@ class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
# FragA.geom
# "." # ALayout
# "." # BLayout
# !if(!ne(Kind, ""), "." # Kind, "")
# !if(Satfinite, ".satfinite", "")
# TypeList
# b1op # "\n\t\t"
Expand All @@ -4792,13 +4802,15 @@ defset list<WMMA_INSTR> MMAs = {
foreach satf = [0, 1] in {
foreach op = NVVM_MMA_OPS.all_mma_ops in {
foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
def : MMA<WMMA_REGINFO<op[0], "mma">,
WMMA_REGINFO<op[1], "mma">,
WMMA_REGINFO<op[2], "mma">,
WMMA_REGINFO<op[3], "mma">,
layout_a, layout_b, satf, b1op>;
}
foreach kind = ["", "kind::f8f6f4"] in {
if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, kind, satf>.ret then {
def : MMA<WMMA_REGINFO<op[0], "mma", "", kind>,
WMMA_REGINFO<op[1], "mma", "", kind>,
WMMA_REGINFO<op[2], "mma", "", kind>,
WMMA_REGINFO<op[3], "mma", "", kind>,
layout_a, layout_b, satf, b1op, kind>;
}
} // kind
} // b1op
} // op
} // satf
Expand Down
2 changes: 1 addition & 1 deletion llvm/test/CodeGen/NVPTX/wmma-ptx87-sm120a.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# RUN: %python %s --ptx=87 --gpu-arch=120 --aa > %t-ptx87-sm_120a.ll
# RUN: llc < %t-ptx87-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx87 \
# RUN: | FileCheck %t-ptx87-sm_120a.ll
# RUN: %if ptxas-12.7 %{ \
# RUN: %if ptxas-sm_120a && ptxas-isa-8.7 %{ \
# RUN: llc < %t-ptx87-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx87 \
# RUN: | %ptxas-verify -arch=sm_120a \
# RUN: %}
Expand Down
Loading