Skip to content

Commit 2f627c1

Browse files
authored
[NVPTX] Support for dense and sparse MMA intrinsics with block scaling. (#163561)
This change adds dense and sparse MMA intrinsics with block scaling. The implementation is based on [PTX ISA version 9.0](https://docs.nvidia.com/cuda/parallel-thread-execution/). Tests for new intrinsics are added for PTX 8.7 and SM 120a and are generated by `llvm/test/CodeGen/NVPTX/wmma-ptx87-sm120a.py`. The tests have been verified with ptxas from CUDA-13.0 release. Dense MMA intrinsics with block scaling were supported by @schwarzschild-radius.
1 parent 18d3db4 commit 2f627c1

File tree

4 files changed

+672
-7
lines changed

4 files changed

+672
-7
lines changed

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 190 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,8 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
178178
string gft = Geom#":"#Frag#":"#ptx_elt_type;
179179
string gf = Geom#":"#Frag;
180180
string ft = frag#":"#ptx_elt_type;
181-
list<LLVMType> regs = !if(!eq(IsSparse, true),
181+
bit isSparse = IsSparse;
182+
list<LLVMType> regs = !if(!eq(isSparse, true),
182183
!cond(
183184
// mma sparse ops use other fragments for some arguments
184185
!eq(gft,"m16n8k16:a:bf16") : !listsplat(llvm_i32_ty, 2),
@@ -277,6 +278,10 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
277278
!eq(gft,"m16n8k32:d:f16") : !listsplat(llvm_v2f16_ty, 2),
278279
!eq(gft,"m16n8k32:d:f32") : !listsplat(llvm_float_ty, 4),
279280

281+
// mma.block_scale e2m1 (mxf4, mxf4nvf4) -> f32 @ m16n8k64
282+
!eq(gft,"m16n8k64:c:f32") : !listsplat(llvm_float_ty, 4),
283+
!eq(gft,"m16n8k64:d:f32") : !listsplat(llvm_float_ty, 4),
284+
280285
// wmma fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16
281286
// All other supported geometries use the same fragment format for f32 and
282287
// f16, so we only need to consider {fragment, type}.
@@ -520,6 +525,18 @@ class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op, strin
520525
# signature;
521526
}
522527

528+
class MMA_BLOCK_SCALE_NAME<string Kind, string SType, string ScaleVecSize,
529+
WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
530+
string signature = MMA_SIGNATURE<A, B, C, D>.ret;
531+
string record_name = "int_nvvm_mma_block_scale"
532+
# "_" # A.geom
533+
# "_row_col"
534+
# "_" # Kind
535+
# !subst(".", "_", ScaleVecSize)
536+
# signature
537+
# "_" # SType;
538+
}
539+
523540
class MMA_SP_NAME<string Metadata, string Kind, int Satfinite,
524541
WMMA_REGS A, WMMA_REGS B,
525542
WMMA_REGS C, WMMA_REGS D> {
@@ -533,6 +550,19 @@ class MMA_SP_NAME<string Metadata, string Kind, int Satfinite,
533550
# signature;
534551
}
535552

553+
class MMA_SP_BLOCK_SCALE_NAME<string Kind, string SType, string ScaleVecSize,
554+
WMMA_REGS A, WMMA_REGS B,
555+
WMMA_REGS C, WMMA_REGS D> {
556+
string signature = MMA_SIGNATURE<A, B, C, D>.ret;
557+
string record_name = "int_nvvm_mma_sp_ordered_metadata_block_scale"
558+
# "_" # A.geom
559+
# "_row_col"
560+
# "_" # Kind
561+
# !subst(".", "_", ScaleVecSize)
562+
# signature
563+
# "_" # SType;
564+
}
565+
536566
// Helper class that takes an intrinsic name and construct a record name.
537567
// Additionally, sets `intr_name` to be non-empty if the default name assigned
538568
// to this intrinsic will not match the name given.
@@ -683,6 +713,18 @@ class NVVM_MMA_OPS {
683713
fp_mma_ops, fp8_mma_ops, f8f6f4_mma_ops,
684714
int_mma_ops, subint_mma_ops, bit_mma_ops);
685715

716+
list<list<WMMA_REGS>> mxf4_mma_ops = MMA_OPS<
717+
["m16n8k64"], ["e2m1"], ["e2m1"], ["f32"], ["f32"]
718+
>.ret;
719+
720+
list<list<WMMA_REGS>> mxf8f6f4_mma_ops = MMA_OPS<
721+
["m16n8k32"], ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"],
722+
["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"], ["f32"], ["f32"]
723+
>.ret;
724+
725+
list<list<WMMA_REGS>> all_mma_block_scale_ops = !listconcat(
726+
mxf4_mma_ops, mxf8f6f4_mma_ops);
727+
686728
list<list<WMMA_REGS>> bf16_mma_sp_ops = MMA_OPS<
687729
["m16n8k16", "m16n8k32"],
688730
["bf16"], [], ["f32"], [], true>.ret;
@@ -707,6 +749,18 @@ class NVVM_MMA_OPS {
707749
bf16_mma_sp_ops, tf32_mma_sp_ops, fp_mma_sp_ops, fp8_mma_sp_ops,
708750
subint_mma_sp_ops, int_mma_sp_ops);
709751

752+
// combines available geoms and types for mxf4 and mxf4nvf4 kinds
753+
list<list<WMMA_REGS>> mxf4xx_mma_sp_ops = MMA_OPS<
754+
["m16n8k128"],
755+
["e2m1"], ["e2m1"], ["f32"], [], true>.ret;
756+
list<list<WMMA_REGS>> mxf8f6f4_mma_sp_ops = MMA_OPS<
757+
["m16n8k64"],
758+
["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
759+
["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
760+
["f32"], [], true>.ret;
761+
list<list<WMMA_REGS>> all_mma_sp_block_scale_ops = !listconcat(
762+
mxf4xx_mma_sp_ops, mxf8f6f4_mma_sp_ops);
763+
710764
list<WMMA_REGS> ldst_ab_ops = MMA_LDST_OPS<
711765
["m16n16k16", "m32n8k16", "m8n32k16"],
712766
["a", "b"], ["f16", "u8", "s8", "bf16"]>.ret;
@@ -900,6 +954,32 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
900954
);
901955
}
902956

957+
class NVVM_MMA_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind, string stype, string scale_vec_size> {
958+
string geom = frags[0].geom;
959+
960+
bit ret = !cond(
961+
!and(!eq(geom, "m16n8k64"),
962+
!eq(kind, "mxf4"),
963+
!or(!eq(scale_vec_size, ""),
964+
!eq(scale_vec_size, ".scale_2x")),
965+
!eq(stype, "ue8m0")) : true,
966+
!and(!eq(geom, "m16n8k64"),
967+
!eq(kind, "mxf4nvf4"),
968+
!eq(scale_vec_size, ".scale_2x"),
969+
!eq(stype, "ue8m0")) : true,
970+
!and(!eq(geom, "m16n8k64"),
971+
!eq(kind, "mxf4nvf4"),
972+
!eq(scale_vec_size, ".scale_4x"),
973+
!eq(stype, "ue4m3")) : true,
974+
!and(!eq(geom, "m16n8k32"),
975+
!eq(kind, "mxf8f6f4"),
976+
!or(!eq(scale_vec_size, ""),
977+
!eq(scale_vec_size, ".scale_1x")),
978+
!eq(stype, "ue8m0")) : true,
979+
true: false
980+
);
981+
}
982+
903983
// Returns true if the fragment is valid for ldmatrix ops is supported;
904984
// false otherwise.
905985
// E.g.
@@ -998,6 +1078,51 @@ class NVVM_MMA_SP_SUPPORTED<list<WMMA_REGS> frags, string metadata,
9981078
}
9991079

10001080

1081+
// Returns true if this combination of kind/scale_vec_size/stype
1082+
// for MMA.SP ops is supported;
1083+
// false otherwise.
1084+
// E.g.
1085+
// if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<...>.ret then
1086+
// def : FOO<>; // The record will only be defined for supported ops.
1087+
//
1088+
class NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind,
1089+
string stype, string scale_vec_size> {
1090+
// MMA.SP ops check both layouts.
1091+
string a_type = frags[0].ptx_elt_type;
1092+
string b_type = frags[1].ptx_elt_type;
1093+
string c_type = frags[2].ptx_elt_type;
1094+
string d_type = frags[3].ptx_elt_type;
1095+
string geom = frags[0].geom;
1096+
1097+
bit ret = !cond(
1098+
!and(!eq(geom, "m16n8k128"),
1099+
!eq(kind, "mxf4"),
1100+
!eq(stype, "ue8m0"),
1101+
!or(!eq(scale_vec_size, ""),
1102+
!eq(scale_vec_size, ".scale_2x"))): true,
1103+
1104+
!and(!eq(geom, "m16n8k128"),
1105+
!eq(kind, "mxf4nvf4"),
1106+
!eq(stype, "ue8m0"),
1107+
!eq(scale_vec_size, ".scale_2x")): true,
1108+
1109+
!and(!eq(geom, "m16n8k128"),
1110+
!eq(kind, "mxf4nvf4"),
1111+
!eq(stype, "ue4m3"),
1112+
!eq(scale_vec_size, ".scale_4x")): true,
1113+
1114+
!and(!eq(geom, "m16n8k64"),
1115+
!eq(kind, "mxf8f6f4"),
1116+
!eq(stype, "ue8m0"),
1117+
!or(!eq(scale_vec_size, ""),
1118+
!eq(scale_vec_size, ".scale_1x"))): true,
1119+
1120+
// All other are NOT OK.
1121+
true: false
1122+
);
1123+
}
1124+
1125+
10011126
class SHFL_INFO<bit sync, string mode, string type, bit return_pred> {
10021127
string Suffix = !if(sync, "sync_", "")
10031128
# mode # "_"
@@ -2452,6 +2577,31 @@ foreach layout_a = ["row", "col"] in {
24522577
} // layout_b
24532578
} // layout_a
24542579

2580+
class NVVM_MMA_BLOCK_SCALE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
2581+
: Intrinsic<D.regs,
2582+
!listconcat(A.regs, B.regs, C.regs,
2583+
[
2584+
llvm_i32_ty, // scale-a-data
2585+
llvm_i16_ty, llvm_i16_ty, // byte-id-a, thread-id-a
2586+
llvm_i32_ty, // scale-b-data,
2587+
llvm_i16_ty, llvm_i16_ty, // byte-id-b, thread-id-b
2588+
]),
2589+
[IntrNoMem, IntrNoCallback]>;
2590+
2591+
foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
2592+
foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
2593+
foreach stype = ["ue8m0", "ue4m3"] in {
2594+
foreach op = NVVM_MMA_OPS.all_mma_block_scale_ops in {
2595+
if NVVM_MMA_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
2596+
def MMA_BLOCK_SCALE_NAME<kind, stype, scale_vec_size,
2597+
op[0], op[1], op[2], op[3]>.record_name
2598+
: NVVM_MMA_BLOCK_SCALE<op[0], op[1], op[2], op[3]>;
2599+
}
2600+
} // op
2601+
} // stype
2602+
} // scale_vec_size
2603+
} // kind
2604+
24552605
// MMA.SP
24562606
class NVVM_MMA_SP<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
24572607
: Intrinsic<D.regs,
@@ -2499,6 +2649,45 @@ foreach metadata = ["sp", "sp::ordered_metadata"] in {
24992649
} // kind
25002650
} // metadata
25012651

2652+
// MMA.SP BLOCK SCALE
2653+
class NVVM_MMA_SP_BLOCK_SCALE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
2654+
: Intrinsic<D.regs,
2655+
!listconcat(A.regs, B.regs, C.regs,
2656+
[
2657+
llvm_i32_ty, // metadata
2658+
llvm_i32_ty, // sparsity selector
2659+
llvm_i32_ty, // scale-a-data
2660+
llvm_i16_ty, llvm_i16_ty, // byte-id-a, thread-id-a
2661+
llvm_i32_ty, // scale-b-data
2662+
llvm_i16_ty, llvm_i16_ty, // byte-id-b, thread-id-b
2663+
])> {
2664+
int pos = !size(!listconcat(A.regs, B.regs, C.regs, [llvm_i32_ty]));
2665+
2666+
// The range [0;num_threads) is for the sparsity selector that indicates the threads
2667+
// which contribute metadata.
2668+
// According to PTX ISA 9.0, the sparsity selector is always 0
2669+
// for sparse MMA block scale instructions
2670+
int num_threads = 1;
2671+
let IntrProperties = [IntrNoMem, IntrNoCallback, ImmArg<ArgIndex<pos>>,
2672+
Range<ArgIndex<pos>, 0, num_threads>];
2673+
}
2674+
2675+
// According to PTX ISA 9.0
2676+
// a_layout = ["row"], b_layout = ["col"], spvariant = ["sp::ordered_metadata"]
2677+
foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
2678+
foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
2679+
foreach stype = ["ue8m0", "ue4m3"] in {
2680+
foreach op = NVVM_MMA_OPS.all_mma_sp_block_scale_ops in {
2681+
if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
2682+
def MMA_SP_BLOCK_SCALE_NAME<kind, stype, scale_vec_size,
2683+
op[0], op[1], op[2], op[3]>.record_name
2684+
: NVVM_MMA_SP_BLOCK_SCALE<op[0], op[1], op[2], op[3]>;
2685+
}
2686+
} // op
2687+
} // stype
2688+
} // scale_vec_size
2689+
} // kind
2690+
25022691
// LDMATRIX
25032692
class NVVM_LDMATRIX<WMMA_REGS Frag, int Transposed>
25042693
: Intrinsic<Frag.regs, [llvm_anyptr_ty],

0 commit comments

Comments
 (0)