@@ -178,7 +178,8 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
178178 string gft = Geom#":"#Frag#":"#ptx_elt_type;
179179 string gf = Geom#":"#Frag;
180180 string ft = frag#":"#ptx_elt_type;
181- list<LLVMType> regs = !if(!eq(IsSparse, true),
181+ bit isSparse = IsSparse;
182+ list<LLVMType> regs = !if(!eq(isSparse, true),
182183 !cond(
183184 // mma sparse ops use other fragments for some arguments
184185 !eq(gft,"m16n8k16:a:bf16") : !listsplat(llvm_i32_ty, 2),
@@ -277,6 +278,10 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
277278 !eq(gft,"m16n8k32:d:f16") : !listsplat(llvm_v2f16_ty, 2),
278279 !eq(gft,"m16n8k32:d:f32") : !listsplat(llvm_float_ty, 4),
279280
281+ // mma.block_scale e2m1 (mxf4, mxf4nvf4) -> f32 @ m16n8k64
282+ !eq(gft,"m16n8k64:c:f32") : !listsplat(llvm_float_ty, 4),
283+ !eq(gft,"m16n8k64:d:f32") : !listsplat(llvm_float_ty, 4),
284+
280285 // wmma fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16
281286 // All other supported geometries use the same fragment format for f32 and
282287 // f16, so we only need to consider {fragment, type}.
@@ -520,6 +525,18 @@ class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op, strin
520525 # signature;
521526}
522527
528+ class MMA_BLOCK_SCALE_NAME<string Kind, string SType, string ScaleVecSize,
529+ WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
530+ string signature = MMA_SIGNATURE<A, B, C, D>.ret;
531+ string record_name = "int_nvvm_mma_block_scale"
532+ # "_" # A.geom
533+ # "_row_col"
534+ # "_" # Kind
535+ # !subst(".", "_", ScaleVecSize)
536+ # signature
537+ # "_" # SType;
538+ }
539+
523540class MMA_SP_NAME<string Metadata, string Kind, int Satfinite,
524541 WMMA_REGS A, WMMA_REGS B,
525542 WMMA_REGS C, WMMA_REGS D> {
@@ -533,6 +550,19 @@ class MMA_SP_NAME<string Metadata, string Kind, int Satfinite,
533550 # signature;
534551}
535552
553+ class MMA_SP_BLOCK_SCALE_NAME<string Kind, string SType, string ScaleVecSize,
554+ WMMA_REGS A, WMMA_REGS B,
555+ WMMA_REGS C, WMMA_REGS D> {
556+ string signature = MMA_SIGNATURE<A, B, C, D>.ret;
557+ string record_name = "int_nvvm_mma_sp_ordered_metadata_block_scale"
558+ # "_" # A.geom
559+ # "_row_col"
560+ # "_" # Kind
561+ # !subst(".", "_", ScaleVecSize)
562+ # signature
563+ # "_" # SType;
564+ }
565+
536566// Helper class that takes an intrinsic name and construct a record name.
537567// Additionally, sets `intr_name` to be non-empty if the default name assigned
538568// to this intrinsic will not match the name given.
@@ -683,6 +713,18 @@ class NVVM_MMA_OPS {
683713 fp_mma_ops, fp8_mma_ops, f8f6f4_mma_ops,
684714 int_mma_ops, subint_mma_ops, bit_mma_ops);
685715
716+ list<list<WMMA_REGS>> mxf4_mma_ops = MMA_OPS<
717+ ["m16n8k64"], ["e2m1"], ["e2m1"], ["f32"], ["f32"]
718+ >.ret;
719+
720+ list<list<WMMA_REGS>> mxf8f6f4_mma_ops = MMA_OPS<
721+ ["m16n8k32"], ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"],
722+ ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"], ["f32"], ["f32"]
723+ >.ret;
724+
725+ list<list<WMMA_REGS>> all_mma_block_scale_ops = !listconcat(
726+ mxf4_mma_ops, mxf8f6f4_mma_ops);
727+
686728 list<list<WMMA_REGS>> bf16_mma_sp_ops = MMA_OPS<
687729 ["m16n8k16", "m16n8k32"],
688730 ["bf16"], [], ["f32"], [], true>.ret;
@@ -707,6 +749,18 @@ class NVVM_MMA_OPS {
707749 bf16_mma_sp_ops, tf32_mma_sp_ops, fp_mma_sp_ops, fp8_mma_sp_ops,
708750 subint_mma_sp_ops, int_mma_sp_ops);
709751
752+ // combines available geoms and types for mxf4 and mxf4nvf4 kinds
753+ list<list<WMMA_REGS>> mxf4xx_mma_sp_ops = MMA_OPS<
754+ ["m16n8k128"],
755+ ["e2m1"], ["e2m1"], ["f32"], [], true>.ret;
756+ list<list<WMMA_REGS>> mxf8f6f4_mma_sp_ops = MMA_OPS<
757+ ["m16n8k64"],
758+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
759+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
760+ ["f32"], [], true>.ret;
761+ list<list<WMMA_REGS>> all_mma_sp_block_scale_ops = !listconcat(
762+ mxf4xx_mma_sp_ops, mxf8f6f4_mma_sp_ops);
763+
710764 list<WMMA_REGS> ldst_ab_ops = MMA_LDST_OPS<
711765 ["m16n16k16", "m32n8k16", "m8n32k16"],
712766 ["a", "b"], ["f16", "u8", "s8", "bf16"]>.ret;
@@ -900,6 +954,32 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
900954 );
901955}
902956
957+ class NVVM_MMA_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind, string stype, string scale_vec_size> {
958+ string geom = frags[0].geom;
959+
960+ bit ret = !cond(
961+ !and(!eq(geom, "m16n8k64"),
962+ !eq(kind, "mxf4"),
963+ !or(!eq(scale_vec_size, ""),
964+ !eq(scale_vec_size, ".scale_2x")),
965+ !eq(stype, "ue8m0")) : true,
966+ !and(!eq(geom, "m16n8k64"),
967+ !eq(kind, "mxf4nvf4"),
968+ !eq(scale_vec_size, ".scale_2x"),
969+ !eq(stype, "ue8m0")) : true,
970+ !and(!eq(geom, "m16n8k64"),
971+ !eq(kind, "mxf4nvf4"),
972+ !eq(scale_vec_size, ".scale_4x"),
973+ !eq(stype, "ue4m3")) : true,
974+ !and(!eq(geom, "m16n8k32"),
975+ !eq(kind, "mxf8f6f4"),
976+ !or(!eq(scale_vec_size, ""),
977+ !eq(scale_vec_size, ".scale_1x")),
978+ !eq(stype, "ue8m0")) : true,
979+ true: false
980+ );
981+ }
982+
903983// Returns true if the fragment is valid for ldmatrix ops is supported;
904984// false otherwise.
905985// E.g.
@@ -998,6 +1078,51 @@ class NVVM_MMA_SP_SUPPORTED<list<WMMA_REGS> frags, string metadata,
9981078}
9991079
10001080
1081+ // Returns true if this combination of kind/scale_vec_size/stype
1082+ // for MMA.SP ops is supported;
1083+ // false otherwise.
1084+ // E.g.
1085+ // if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<...>.ret then
1086+ // def : FOO<>; // The record will only be defined for supported ops.
1087+ //
1088+ class NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind,
1089+ string stype, string scale_vec_size> {
1090+ // MMA.SP ops check both layouts.
1091+ string a_type = frags[0].ptx_elt_type;
1092+ string b_type = frags[1].ptx_elt_type;
1093+ string c_type = frags[2].ptx_elt_type;
1094+ string d_type = frags[3].ptx_elt_type;
1095+ string geom = frags[0].geom;
1096+
1097+ bit ret = !cond(
1098+ !and(!eq(geom, "m16n8k128"),
1099+ !eq(kind, "mxf4"),
1100+ !eq(stype, "ue8m0"),
1101+ !or(!eq(scale_vec_size, ""),
1102+ !eq(scale_vec_size, ".scale_2x"))): true,
1103+
1104+ !and(!eq(geom, "m16n8k128"),
1105+ !eq(kind, "mxf4nvf4"),
1106+ !eq(stype, "ue8m0"),
1107+ !eq(scale_vec_size, ".scale_2x")): true,
1108+
1109+ !and(!eq(geom, "m16n8k128"),
1110+ !eq(kind, "mxf4nvf4"),
1111+ !eq(stype, "ue4m3"),
1112+ !eq(scale_vec_size, ".scale_4x")): true,
1113+
1114+ !and(!eq(geom, "m16n8k64"),
1115+ !eq(kind, "mxf8f6f4"),
1116+ !eq(stype, "ue8m0"),
1117+ !or(!eq(scale_vec_size, ""),
1118+ !eq(scale_vec_size, ".scale_1x"))): true,
1119+
1120+ // All other are NOT OK.
1121+ true: false
1122+ );
1123+ }
1124+
1125+
10011126class SHFL_INFO<bit sync, string mode, string type, bit return_pred> {
10021127 string Suffix = !if(sync, "sync_", "")
10031128 # mode # "_"
@@ -2452,6 +2577,31 @@ foreach layout_a = ["row", "col"] in {
24522577 } // layout_b
24532578} // layout_a
24542579
2580+ class NVVM_MMA_BLOCK_SCALE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
2581+ : Intrinsic<D.regs,
2582+ !listconcat(A.regs, B.regs, C.regs,
2583+ [
2584+ llvm_i32_ty, // scale-a-data
2585+ llvm_i16_ty, llvm_i16_ty, // byte-id-a, thread-id-a
2586+ llvm_i32_ty, // scale-b-data,
2587+ llvm_i16_ty, llvm_i16_ty, // byte-id-b, thread-id-b
2588+ ]),
2589+ [IntrNoMem, IntrNoCallback]>;
2590+
2591+ foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
2592+ foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
2593+ foreach stype = ["ue8m0", "ue4m3"] in {
2594+ foreach op = NVVM_MMA_OPS.all_mma_block_scale_ops in {
2595+ if NVVM_MMA_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
2596+ def MMA_BLOCK_SCALE_NAME<kind, stype, scale_vec_size,
2597+ op[0], op[1], op[2], op[3]>.record_name
2598+ : NVVM_MMA_BLOCK_SCALE<op[0], op[1], op[2], op[3]>;
2599+ }
2600+ } // op
2601+ } // stype
2602+ } // scale_vec_size
2603+ } // kind
2604+
24552605// MMA.SP
24562606class NVVM_MMA_SP<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
24572607 : Intrinsic<D.regs,
@@ -2499,6 +2649,45 @@ foreach metadata = ["sp", "sp::ordered_metadata"] in {
24992649 } // kind
25002650} // metadata
25012651
2652+ // MMA.SP BLOCK SCALE
2653+ class NVVM_MMA_SP_BLOCK_SCALE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
2654+ : Intrinsic<D.regs,
2655+ !listconcat(A.regs, B.regs, C.regs,
2656+ [
2657+ llvm_i32_ty, // metadata
2658+ llvm_i32_ty, // sparsity selector
2659+ llvm_i32_ty, // scale-a-data
2660+ llvm_i16_ty, llvm_i16_ty, // byte-id-a, thread-id-a
2661+ llvm_i32_ty, // scale-b-data
2662+ llvm_i16_ty, llvm_i16_ty, // byte-id-b, thread-id-b
2663+ ])> {
2664+ int pos = !size(!listconcat(A.regs, B.regs, C.regs, [llvm_i32_ty]));
2665+
2666+ // The range [0;num_threads) is for the sparsity selector that indicates the threads
2667+ // which contribute metadata.
2668+ // According to PTX ISA 9.0, the sparsity selector is always 0
2669+ // for sparse MMA block scale instructions
2670+ int num_threads = 1;
2671+ let IntrProperties = [IntrNoMem, IntrNoCallback, ImmArg<ArgIndex<pos>>,
2672+ Range<ArgIndex<pos>, 0, num_threads>];
2673+ }
2674+
2675+ // According to PTX ISA 9.0
2676+ // a_layout = ["row"], b_layout = ["col"], spvariant = ["sp::ordered_metadata"]
2677+ foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
2678+ foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
2679+ foreach stype = ["ue8m0", "ue4m3"] in {
2680+ foreach op = NVVM_MMA_OPS.all_mma_sp_block_scale_ops in {
2681+ if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
2682+ def MMA_SP_BLOCK_SCALE_NAME<kind, stype, scale_vec_size,
2683+ op[0], op[1], op[2], op[3]>.record_name
2684+ : NVVM_MMA_SP_BLOCK_SCALE<op[0], op[1], op[2], op[3]>;
2685+ }
2686+ } // op
2687+ } // stype
2688+ } // scale_vec_size
2689+ } // kind
2690+
25022691// LDMATRIX
25032692class NVVM_LDMATRIX<WMMA_REGS Frag, int Transposed>
25042693 : Intrinsic<Frag.regs, [llvm_anyptr_ty],
0 commit comments