diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td index 3af1750ffcf3f..6256baa50a1c6 100644 --- a/llvm/include/llvm/IR/IntrinsicsNVVM.td +++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -277,6 +277,10 @@ class WMMA_REGS f32 @ m16n8k64 + !eq(gft,"m16n8k64:c:f32") : !listsplat(llvm_float_ty, 4), + !eq(gft,"m16n8k64:d:f32") : !listsplat(llvm_float_ty, 4), + // wmma fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16 // All other supported geometries use the same fragment format for f32 and // f16, so we only need to consider {fragment, type}. @@ -520,6 +524,18 @@ class MMA_NAME { + string signature = MMA_SIGNATURE.ret; + string record = "int_nvvm_mma_block_scale" + # "_" # A.geom + # "_row_col" + # "_" # Kind + # !subst(".", "_", ScaleVecSize) + # signature + # "_" # SType; +} + class MMA_SP_NAME { @@ -533,6 +549,19 @@ class MMA_SP_NAME { + string signature = MMA_SIGNATURE.ret; + string record = "int_nvvm_mma_sp_ordered_metadata_block_scale" + # "_" # A.geom + # "_row_col" + # "_" # Kind + # !subst(".", "_", ScaleVecSize) + # signature + # "_" # SType; +} + class LDMATRIX_NAME { string intr = "llvm.nvvm.ldmatrix.sync.aligned" # "." # Frag.geom @@ -672,6 +701,18 @@ class NVVM_MMA_OPS { fp_mma_ops, fp8_mma_ops, f8f6f4_mma_ops, int_mma_ops, subint_mma_ops, bit_mma_ops); + list> mxf4_mma_ops = MMA_OPS< + ["m16n8k64"], ["e2m1"], ["e2m1"], ["f32"], ["f32"] + >.ret; + + list> mxf8f6f4_mma_ops = MMA_OPS< + ["m16n8k32"], ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"], + ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"], ["f32"], ["f32"] + >.ret; + + list> all_mma_block_scale_ops = !listconcat( + mxf4_mma_ops, mxf8f6f4_mma_ops); + list> bf16_mma_sp_ops = MMA_OPS< ["m16n8k16", "m16n8k32"], ["bf16"], [], ["f32"], [], true>.ret; @@ -696,6 +737,18 @@ class NVVM_MMA_OPS { bf16_mma_sp_ops, tf32_mma_sp_ops, fp_mma_sp_ops, fp8_mma_sp_ops, subint_mma_sp_ops, int_mma_sp_ops); + // combines available geoms and types for mxf4 and mxf4nvf4 kinds + list> mxf4xx_mma_sp_ops = MMA_OPS< + ["m16n8k128"], + ["e2m1"], ["e2m1"], ["f32"], [], true>.ret; + list> mxf8f6f4_mma_sp_ops = MMA_OPS< + ["m16n8k64"], + ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"], + ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"], + ["f32"], [], true>.ret; + list> all_mma_sp_block_scale_ops = !listconcat( + mxf4xx_mma_sp_ops, mxf8f6f4_mma_sp_ops); + list ldst_ab_ops = MMA_LDST_OPS< ["m16n16k16", "m32n8k16", "m8n32k16"], ["a", "b"], ["f16", "u8", "s8", "bf16"]>.ret; @@ -889,6 +942,32 @@ class NVVM_MMA_SUPPORTED frags, string layout_a, string layout_b ); } +class NVVM_MMA_BLOCK_SCALE_SUPPORTED frags, string kind, string stype, string scale_vec_size> { + string geom = frags[0].geom; + + bit ret = !cond( + !and(!eq(geom, "m16n8k64"), + !eq(kind, "mxf4"), + !or(!eq(scale_vec_size, ""), + !eq(scale_vec_size, ".scale_2x")), + !eq(stype, "ue8m0")) : true, + !and(!eq(geom, "m16n8k64"), + !eq(kind, "mxf4nvf4"), + !eq(scale_vec_size, ".scale_2x"), + !eq(stype, "ue8m0")) : true, + !and(!eq(geom, "m16n8k64"), + !eq(kind, "mxf4nvf4"), + !eq(scale_vec_size, ".scale_4x"), + !eq(stype, "ue4m3")) : true, + !and(!eq(geom, "m16n8k32"), + !eq(kind, "mxf8f6f4"), + !or(!eq(scale_vec_size, ""), + !eq(scale_vec_size, ".scale_1x")), + !eq(stype, "ue8m0")) : true, + true: false + ); +} + // Returns true if the fragment is valid for ldmatrix ops is supported; // false otherwise. // E.g. 
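For reference, here is a minimal Python sketch (not part of the patch) of how the pieces in MMA_BLOCK_SCALE_NAME above concatenate into an intrinsic record name. The helper name and the example combination are hypothetical; the signature handling only mirrors the non-f16 cases of MMA_SIGNATURE that block-scale ops actually hit.

```python
def block_scale_record_name(geom, kind, scale_vec_size, a_type, b_type, stype):
    # Mirrors MMA_BLOCK_SCALE_NAME: geom, fixed row/col layout, kind,
    # scale vector size with '.' turned into '_', type signature, scale type.
    # Block-scale A/B types are never f16, so MMA_SIGNATURE reduces to the
    # input types: "_<a>" when A and B match, "_<a>_<b>" otherwise.
    signature = "_" + a_type if a_type == b_type else f"_{a_type}_{b_type}"
    return ("int_nvvm_mma_block_scale"
            + "_" + geom
            + "_row_col"
            + "_" + kind
            + scale_vec_size.replace(".", "_")
            + signature
            + "_" + stype)

# One combination accepted by NVVM_MMA_BLOCK_SCALE_SUPPORTED:
print(block_scale_record_name("m16n8k64", "mxf4nvf4", ".scale_4x",
                              "e2m1", "e2m1", "ue4m3"))
# int_nvvm_mma_block_scale_m16n8k64_row_col_mxf4nvf4_scale_4x_e2m1_ue4m3
# The sparse records built by MMA_SP_BLOCK_SCALE_NAME differ only in the
# "int_nvvm_mma_sp_ordered_metadata_block_scale" prefix.
```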
@@ -987,6 +1066,51 @@ class NVVM_MMA_SP_SUPPORTED frags, string metadata, } +// Returns true if this combination of kind/scale_vec_size/stype +// for MMA.SP ops is supported; +// false otherwise. +// E.g. +// if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<...>.ret then +// def : FOO<>; // The record will only be defined for supported ops. +// +class NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED frags, string kind, + string stype, string scale_vec_size> { + // MMA.SP ops check both layouts. + string a_type = frags[0].ptx_elt_type; + string b_type = frags[1].ptx_elt_type; + string c_type = frags[2].ptx_elt_type; + string d_type = frags[3].ptx_elt_type; + string geom = frags[0].geom; + + bit ret = !cond( + !and(!eq(geom, "m16n8k128"), + !eq(kind, "mxf4"), + !eq(stype, "ue8m0"), + !or(!eq(scale_vec_size, ""), + !eq(scale_vec_size, ".scale_2x"))): true, + + !and(!eq(geom, "m16n8k128"), + !eq(kind, "mxf4nvf4"), + !eq(stype, "ue8m0"), + !eq(scale_vec_size, ".scale_2x")): true, + + !and(!eq(geom, "m16n8k128"), + !eq(kind, "mxf4nvf4"), + !eq(stype, "ue4m3"), + !eq(scale_vec_size, ".scale_4x")): true, + + !and(!eq(geom, "m16n8k64"), + !eq(kind, "mxf8f6f4"), + !eq(stype, "ue8m0"), + !or(!eq(scale_vec_size, ""), + !eq(scale_vec_size, ".scale_1x"))): true, + + // All other are NOT OK. + true: false + ); +} + + class SHFL_INFO { string Suffix = !if(sync, "sync_", "") # mode # "_" @@ -2340,6 +2464,31 @@ foreach layout_a = ["row", "col"] in { } // layout_b } // layout_a +class NVVM_MMA_BLOCK_SCALE + : Intrinsic; + +foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in { + foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in { + foreach stype = ["ue8m0", "ue4m3"] in { + foreach op = NVVM_MMA_OPS.all_mma_block_scale_ops in { + if NVVM_MMA_BLOCK_SCALE_SUPPORTED.ret then { + def MMA_BLOCK_SCALE_NAME.record + : NVVM_MMA_BLOCK_SCALE; + } + } // op + } // stype + } // scale_vec_size +} // kind + // MMA.SP class NVVM_MMA_SP : Intrinsic + : Intrinsic { + int pos = !size(!listconcat(A.regs, B.regs, C.regs, [llvm_i32_ty])); + + // The range [0;num_threads) is for the sparsity selector that indicates the threads + // which contribute metadata. + // According to PTX ISA 9.0, the sparsity selector is always 0 + // for sparse MMA block scale instructions + int num_threads = 1; + let IntrProperties = [IntrNoMem, IntrNoCallback, ImmArg>, + Range, 0, num_threads>]; +} + +// According to PTX ISA 9.0 +// a_layout = ["row"], b_layout = ["col"], spvariant = ["sp::ordered_metadata"] +foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in { + foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in { + foreach stype = ["ue8m0", "ue4m3"] in { + foreach op = NVVM_MMA_OPS.all_mma_sp_block_scale_ops in { + if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED.ret then { + def MMA_SP_BLOCK_SCALE_NAME.record + : NVVM_MMA_SP_BLOCK_SCALE; + } + } // op + } // stype + } // scale_vec_size +} // kind + // LDMATRIX class NVVM_LDMATRIX : Intrinsic MMAs = { } // defset } +// MMA.block_scale +class MMA_BLOCK_SCALE + : WMMA_INSTR.record, + [FragA.Ins, FragB.Ins, FragC.Ins, + (ins B32:$scale_a, B16:$byte_id_a, + B16:$thread_id_a, B32:$scale_b, + B16:$byte_id_b, B16:$thread_id_b)]>, + // Requires does not seem to have effect on Instruction w/o Patterns. + // We set it here anyways and propagate to the Pat<> we construct below. 
+ Requires { + let OutOperandList = FragD.Outs; + let InOperandList = !con(Args, (ins MmaCode:$ptx)); + string TypeList = !interleave([FragD.ptx_elt_type, + FragA.ptx_elt_type, + FragB.ptx_elt_type, + FragC.ptx_elt_type], "."); + string ScaleVecSizeStr = !cond( + !eq(ScaleVecSize, "") : "", + !eq(ScaleVecSize, ".scale_1x") : ".scale_vec::1X", + !eq(ScaleVecSize, ".scale_2x") : ".scale_vec::2X", + !eq(ScaleVecSize, ".scale_4x") : ".scale_vec::4X" + ); + let AsmString = "mma.sync.aligned." + # FragA.geom + # ".row.col" + # ".kind::" # Kind + # ".block_scale" + # ScaleVecSizeStr + # "." # TypeList + # "." # SType # " \n\t\t" + # FragD.regstring # ",\n\t\t" + # FragA.regstring # ",\n\t\t" + # FragB.regstring # ",\n\t\t" + # FragC.regstring # ",\n\t\t" + # "$scale_a, {{$byte_id_a, $thread_id_a}}" # ",\n\t\t" + # "$scale_b, {{$byte_id_b, $thread_id_b}};"; +} + +let isConvergent = true in { +defset list MMA_BLOCK_SCALEs = { + foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in { + foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in { + foreach stype = ["ue8m0", "ue4m3"] in { + foreach op = NVVM_MMA_OPS.all_mma_block_scale_ops in { + if NVVM_MMA_BLOCK_SCALE_SUPPORTED.ret then { + def : MMA_BLOCK_SCALE, + WMMA_REGINFO, + WMMA_REGINFO, + WMMA_REGINFO, + kind, stype, scale_vec_size>; + } + } // op + } // stype + } // scale_vec_size + } // kind +} // defset +} + // MMA SP class MMA_SP MMA_SPs = { } // defset } +// MMA SP BLOCK SCALE +class MMA_SP_BLOCK_SCALE + : WMMA_INSTR.record, + [FragA.Ins, FragB.Ins, FragC.Ins, + (ins B32:$metadata, i32imm:$selector, + B32:$scale_a, B16:$byte_id_a, B16:$thread_id_a, + B32:$scale_b, B16:$byte_id_b, B16:$thread_id_b)]>, + // Requires does not seem to have effect on Instruction w/o Patterns. + // We set it here anyways and propagate to the Pat<> we construct below. + Requires { + let OutOperandList = FragD.Outs; + let InOperandList = !con(Args, (ins MmaCode:$ptx)); + string TypeList = "." # FragD.ptx_elt_type + # "." # FragA.ptx_elt_type + # "." # FragB.ptx_elt_type + # "." # FragC.ptx_elt_type; + string ScaleVecSizeStr = !cond( + !eq(ScaleVecSize, "") : "", + !eq(ScaleVecSize, ".scale_1x") : ".scale_vec::1X", + !eq(ScaleVecSize, ".scale_2x") : ".scale_vec::2X", + !eq(ScaleVecSize, ".scale_4x") : ".scale_vec::4X" + ); + let AsmString = "mma.sp::ordered_metadata.sync.aligned." + # FragA.geom + # ".row.col" + # ".kind::" # Kind + # ".block_scale" + # ScaleVecSizeStr + # TypeList + # "." # SType # "\n\t\t" + # FragD.regstring # ",\n\t\t" + # FragA.regstring # ",\n\t\t" + # FragB.regstring # ",\n\t\t" + # FragC.regstring # ",\n\t\t" + # "$metadata" # ",\n\t\t" + # "$selector" # ",\n\t\t" + # "$scale_a, {{$byte_id_a, $thread_id_a}}" # ",\n\t\t" + # "$scale_b, {{$byte_id_b, $thread_id_b}};"; +} + +let isConvergent = true in { +defset list MMA_SP_BLOCK_SCALEs = { + foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in { + foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in { + foreach stype = ["ue8m0", "ue4m3"] in { + foreach op = NVVM_MMA_OPS.all_mma_sp_block_scale_ops in { + if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED.ret then { + def : MMA_SP_BLOCK_SCALE, + WMMA_REGINFO, + WMMA_REGINFO, + WMMA_REGINFO, + kind, stype, scale_vec_size>; + } + } // op + } // stype + } // scale_vec_size + } // kind +} // defset +} + // // ldmatrix.sync.aligned.m8n8[|.trans][|.shared].b16 // @@ -5023,7 +5150,8 @@ class MMA_PAT Requires; // Build intrinsic->instruction patterns for all MMA instructions. 
-foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs, MMA_SPs) in +foreach mma = !listconcat(MMAs, MMA_BLOCK_SCALEs, WMMAs, MMA_LDSTs, LDMATRIXs, + STMATRIXs, MMA_SPs, MMA_SP_BLOCK_SCALEs) in def : MMA_PAT; multiclass MAPA { diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py index 8427ae4ad72da..c2e69dd48da99 100644 --- a/llvm/test/CodeGen/NVPTX/wmma.py +++ b/llvm/test/CodeGen/NVPTX/wmma.py @@ -131,7 +131,7 @@ def __init__(self, geom, frag, ptx_elt_type, is_mma_sparse=False): "m16n8k64:b:e5m2": 4, "m16n8k64:b:e3m2": 4, "m16n8k64:b:e2m3": 4, - "m16n8k64:b:e2m1": 4, + "m16n8k64:b:e2m1": 4 if is_mma_sparse else 2, "m16n8k64:c:f16": 2, "m16n8k64:c:f32": 4, "m16n8k64:d:f16": 2, @@ -1131,6 +1131,160 @@ def gen_mma_tests(): return generated_items +def get_mma_block_scale_ops(): + return make_mma_ops(["m16n8k64"], ["e2m1"], [], ["f32"], []) + make_mma_ops( + ["m16n8k32"], + ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"], + ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"], + ["f32"], + [], + ) + + +def is_mma_block_scale_geom_supported(geom): + # geometries for FP. + if geom in [ + "m16n8k32", + "m16n8k64", + ]: + return True + raise ValueError(f"Unexpected MMA block scale geometry: {geom}") + + +def is_mma_block_scale_variant_supported(op, kind, scale_vec_size, stype): + if not ( + is_type_supported(op.a.mma_type.ptx_type) + and is_mma_block_scale_geom_supported(op.a.geom) + ): + return False + + if ( + op.a.geom == "m16n8k64" + and kind == "mxf4" + and stype == "ue8m0" + and scale_vec_size in ["", ".scale_vec::2X"] + ): + return True + + if ( + op.a.geom == "m16n8k64" + and kind == "mxf4nvf4" + and stype == "ue8m0" + and scale_vec_size == ".scale_vec::2X" + ): + return True + + if ( + op.a.geom == "m16n8k64" + and kind == "mxf4nvf4" + and stype == "ue4m3" + and scale_vec_size == ".scale_vec::4X" + ): + return True + + if ( + op.a.geom == "m16n8k32" + and kind == "mxf8f6f4" + and stype == "ue8m0" + and scale_vec_size in ["", ".scale_vec::1X"] + ): + return True + + return False + + +def common_mma_block_scale_test_gen( + params, op, intrinsic_template, instruction_template +): + mma_block_scale_template = """ +declare ${ret_ty} @${intrinsic}( + ${args}); + +; CHECK-LABEL: .func {{.*}}test_${function}( +define ${ret_ty} @test_${function}( + ${args}) { +; CHECK: ${instruction} +; CHECK-NEXT: ${check_d} +; CHECK-NEXT: ${check_a} +; CHECK-NEXT: ${check_b} +; CHECK-NEXT: ${check_c} +; CHECK-NEXT: ${check_scale_a_data} +; CHECK-NEXT: ${check_byte_id_a} +; CHECK-NEXT: ${check_thread_id_a} +; CHECK-NEXT: ${check_scale_b_data} +; CHECK-NEXT: ${check_byte_id_b} +; CHECK-NEXT: ${check_thread_id_b} + %r = call ${ret_ty} @${intrinsic}( + ${args}); + ret ${ret_ty} %r; +} +""" + + test_params = params + test_params["intrinsic"] = Template(intrinsic_template).substitute(params) + test_params["function"] = test_params["intrinsic"].replace(".", "_") + test_params["instruction"] = Template(instruction_template).substitute(params) + test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d) + test_params["check_a"] = check_pattern(op.a) + test_params["check_b"] = check_pattern(op.b) + test_params["check_c"] = check_pattern(op.c) + test_params["check_d"] = check_pattern(op.d) + test_params["check_scale_a_data"] = "{{%r[0-9]+}}" + test_params["check_byte_id_a"] = "{{%r[0-9]+}}" + test_params["check_thread_id_a"] = "{{%r[0-9]+}}" + test_params["check_scale_b_data"] = "{{%r[0-9]+}}" + test_params["check_byte_id_b"] = "{{%r[0-9]+}}" + test_params["check_thread_id_b"] = "{{%r[0-9]+}}" + args = ",\n 
".join( + list(make_wmma_slice_args(frag) for frag in (op.a, op.b, op.c)) + + ["i32 %scale_a_data", "i16 %byte_id_a, i16 %thread_id_a"] + + ["i32 %scale_b_data", "i16 %byte_id_b, i16 %thread_id_b"] + ) + test_params["args"] = args + print(Template(mma_block_scale_template).substitute(test_params)) + return (test_params["intrinsic"], test_params["instruction"]) + + +def gen_mma_block_scale_tests(): + if not (ptx_version >= 87 and gpu_arch >= 120 and aa): + return [] + + mma_block_scale_intrinsic_template = "llvm.nvvm.mma.block.scale.${geom}.row.col.${kind}${scale}.${intrinsic_signature}.${stype}" + mma_block_scale_instruction_template = "mma.sync.aligned.${geom}.row.col.kind::${kind}.block_scale${scale_vec_size}.${ptx_signature}.${stype}" + + generated_items = [] + + for op, kind, scale_vec_size, stype in product( + get_mma_block_scale_ops(), + ["mxf4", "mxf4nvf4", "mxf8f6f4"], + ["", ".scale_vec::1X", ".scale_vec::2X", ".scale_vec::4X"], + ["ue8m0", "ue4m3"], + ): + if not is_mma_block_scale_variant_supported(op, kind, scale_vec_size, stype): + continue + + params = { + "intrinsic_signature": mma_signature(op), + "ptx_signature": mma_ptx_signature(op), + "geom": op.a.geom, + "kind": kind, + "scale_vec_size": scale_vec_size, + "scale": scale_vec_size.replace("_vec::", ".").lower(), + "stype": stype, + } + + intrinsic_template = mma_block_scale_intrinsic_template + instruction_template = mma_block_scale_instruction_template + + generated_items.append( + common_mma_block_scale_test_gen( + params, op, intrinsic_template, instruction_template + ) + ) + + return generated_items + + def get_mma_sp_ops(): return ( make_mma_ops(["m16n8k16", "m16n8k32"], ["bf16"], [], ["f32"], [], True) @@ -1224,7 +1378,11 @@ def is_mma_sp_variant_supported(op, metadata, kind, satf): return True -def sp_selector_gen(op): +def sp_selector_gen(op, block_scale=False): + if block_scale: + # PTX ISA 9.0 has the sparsity selector equal to 0 only + return range(1) + # (geom, type) -> allowed selector range range_01 = { ("m16n8k32", "bf16"), @@ -1355,6 +1513,178 @@ def gen_mma_sp_tests(): return generated_items +def get_mma_sp_block_scale_ops(): + return make_mma_ops(["m16n8k128"], ["e2m1"], [], ["f32"], [], True) + make_mma_ops( + ["m16n8k64"], + ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"], + ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"], + ["f32"], + [], + True, + ) + + +def is_mma_sp_block_scale_geom_supported(geom): + # geometries for FP. 
+ if geom in [ + "m16n8k64", + "m16n8k128", + ]: + return True + raise ValueError(f"Unexpected sparse MMA block scale geometry: {geom}") + + +def is_mma_sp_block_scale_variant_supported(op, kind, scale_vec_size, stype): + if not ( + is_type_supported(op.a.mma_type.ptx_type) + and is_mma_sp_block_scale_geom_supported(op.a.geom) + ): + return False + + if ( + op.a.geom == "m16n8k128" + and kind == "mxf4" + and stype == "ue8m0" + and scale_vec_size in ["", ".scale_vec::2X"] + ): + return True + + if ( + op.a.geom == "m16n8k128" + and kind == "mxf4nvf4" + and stype == "ue8m0" + and scale_vec_size == ".scale_vec::2X" + ): + return True + + if ( + op.a.geom == "m16n8k128" + and kind == "mxf4nvf4" + and stype == "ue4m3" + and scale_vec_size == ".scale_vec::4X" + ): + return True + + if ( + op.a.geom == "m16n8k64" + and kind == "mxf8f6f4" + and stype == "ue8m0" + and scale_vec_size in ["", ".scale_vec::1X"] + ): + return True + + return False + + +def common_mma_sp_block_scale_test_gen( + params, op, intrinsic_template, instruction_template +): + mma_sp_block_scale_decl_template = """ +declare ${ret_ty} @${intrinsic}( + ${args}); +""" + + mma_sp_block_scale_test_template = """ +; CHECK-LABEL: .func {{.*}}test_${function}_${selector}( +define ${ret_ty} @test_${function}_${selector}( + ${args}) { +; CHECK: ${instruction} +; CHECK-NEXT: ${check_d} +; CHECK-NEXT: ${check_a} +; CHECK-NEXT: ${check_b} +; CHECK-NEXT: ${check_c} +; CHECK-NEXT: ${check_metadata} +; CHECK-NEXT: ${check_selector} +; CHECK-NEXT: ${check_scale_a_data} +; CHECK-NEXT: ${check_byte_id_a} +; CHECK-NEXT: ${check_thread_id_a} +; CHECK-NEXT: ${check_scale_b_data} +; CHECK-NEXT: ${check_byte_id_b} +; CHECK-NEXT: ${check_thread_id_b} + %r = call ${ret_ty} @${intrinsic}( + ${call_args}); + ret ${ret_ty} %r; +} +""" + + test_params = params + test_params["intrinsic"] = Template(intrinsic_template).substitute(params) + test_params["function"] = test_params["intrinsic"].replace(".", "_") + test_params["instruction"] = Template(instruction_template).substitute(params) + test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d) + test_params["check_a"] = check_pattern(op.a) + test_params["check_b"] = check_pattern(op.b) + test_params["check_c"] = check_pattern(op.c) + test_params["check_d"] = check_pattern(op.d) + test_params["check_metadata"] = "{{%r[0-9]+}}" + test_params["check_scale_a_data"] = "{{%r[0-9]+}}" + test_params["check_byte_id_a"] = "{{%r[0-9]+}}" + test_params["check_thread_id_a"] = "{{%r[0-9]+}}" + test_params["check_scale_b_data"] = "{{%r[0-9]+}}" + test_params["check_byte_id_b"] = "{{%r[0-9]+}}" + test_params["check_thread_id_b"] = "{{%r[0-9]+}}" + args = ",\n ".join( + list(make_wmma_slice_args(frag) for frag in (op.a, op.b, op.c)) + + ["i32 %metadata", "i32 %selector"] + + ["i32 %scale_a_data", "i16 %byte_id_a, i16 %thread_id_a"] + + ["i32 %scale_b_data", "i16 %byte_id_b, i16 %thread_id_b"] + ) + test_params["args"] = args + + print(Template(mma_sp_block_scale_decl_template).substitute(test_params)) + + for selector in [str(r) for r in sp_selector_gen(op, True)]: + test_params["selector"] = selector + test_params["check_selector"] = "{{" + test_params["selector"] + "}}" + test_params["call_args"] = test_params["args"].replace( + "%selector", test_params["selector"] + ) + + print(Template(mma_sp_block_scale_test_template).substitute(test_params)) + + return (test_params["intrinsic"], test_params["instruction"]) + + +def gen_mma_sp_block_scale_tests(): + if not (ptx_version >= 87 and gpu_arch >= 120 and aa): + return [] + + 
mma_sp_block_scale_intrinsic_template = "llvm.nvvm.mma.sp.ordered.metadata.block.scale.${geom}.row.col.${kind}${scale}.${intrinsic_signature}.${stype}" + mma_sp_block_scale_instruction_template = "mma.sp::ordered_metadata.sync.aligned.${geom}.row.col.kind::${kind}.block_scale${scale_vec_size}.${ptx_signature}.${stype}" + + generated_items = [] + + for op, kind, scale_vec_size, stype in product( + get_mma_sp_block_scale_ops(), + ["mxf4", "mxf4nvf4", "mxf8f6f4"], + ["", ".scale_vec::1X", ".scale_vec::2X", ".scale_vec::4X"], + ["ue8m0", "ue4m3"], + ): + if not is_mma_sp_block_scale_variant_supported(op, kind, scale_vec_size, stype): + continue + + params = { + "intrinsic_signature": mma_signature(op), + "ptx_signature": mma_ptx_signature(op), + "geom": op.a.geom, + "kind": kind, + "scale_vec_size": scale_vec_size, + "scale": scale_vec_size.replace("_vec::", ".").lower(), + "stype": stype, + } + + intrinsic_template = mma_sp_block_scale_intrinsic_template + instruction_template = mma_sp_block_scale_instruction_template + + generated_items.append( + common_mma_sp_block_scale_test_gen( + params, op, intrinsic_template, instruction_template + ) + ) + + return generated_items + + # Append complete list of intrinsics and instructions we've generated tests for. # Generate set of checks to verify that that we did generate sensible set of # tests for the given combination of PTX and SM variants. @@ -1545,7 +1875,9 @@ def gen_tests(): items += gen_stmatrix_tests() items += gen_wmma_mma_tests() items += gen_mma_tests() + items += gen_mma_block_scale_tests() items += gen_mma_sp_tests() + items += gen_mma_sp_block_scale_tests() gen_check_unsupported_ops(items)
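As a closing illustration (not part of wmma.py), the snippet below renders one supported sparse block-scale combination through the same template strings used by gen_mma_sp_block_scale_tests(). The literal "intrinsic_signature" and "ptx_signature" values are hand-derived assumptions rather than calls into mma_signature()/mma_ptx_signature().

```python
from string import Template

# Same template strings as gen_mma_sp_block_scale_tests() above.
intrinsic_template = (
    "llvm.nvvm.mma.sp.ordered.metadata.block.scale."
    "${geom}.row.col.${kind}${scale}.${intrinsic_signature}.${stype}"
)
instruction_template = (
    "mma.sp::ordered_metadata.sync.aligned.${geom}.row.col."
    "kind::${kind}.block_scale${scale_vec_size}.${ptx_signature}.${stype}"
)

scale_vec_size = ".scale_vec::4X"
params = {
    "geom": "m16n8k128",
    "kind": "mxf4nvf4",
    "scale_vec_size": scale_vec_size,
    # ".scale_vec::4X" -> ".scale.4x", the spelling used in intrinsic names
    "scale": scale_vec_size.replace("_vec::", ".").lower(),
    "intrinsic_signature": "e2m1",         # A == B, so only one type appears
    "ptx_signature": "f32.e2m1.e2m1.f32",  # D.A.B.C element types
    "stype": "ue4m3",
}

print(Template(intrinsic_template).substitute(params))
# llvm.nvvm.mma.sp.ordered.metadata.block.scale.m16n8k128.row.col.mxf4nvf4.scale.4x.e2m1.ue4m3
print(Template(instruction_template).substitute(params))
# mma.sp::ordered_metadata.sync.aligned.m16n8k128.row.col.kind::mxf4nvf4.block_scale.scale_vec::4X.f32.e2m1.e2m1.f32.ue4m3
```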