diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td index cd7a0bc9c4b48..130fa27e4f870 100644 --- a/llvm/include/llvm/IR/IntrinsicsNVVM.td +++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -2161,6 +2161,7 @@ class NVVM_MMA_SP // The range [0;num_threads) is for the sparsity selector that indicates the threads // which contribute metadata. int num_threads = !if(!or(!and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "bf16")), + !and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "f16")), !and(!eq(A.geom, "m16n8k16"), !eq(A.ptx_elt_type, "tf32")), !and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "u8")), !and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "s8")), @@ -2175,7 +2176,11 @@ class NVVM_MMA_SP !eq(A.ptx_elt_type, "e3m2"), !eq(A.ptx_elt_type, "e2m3"), !eq(A.ptx_elt_type, "e2m1"))), - 1, 4)); + 1, + !if(!and(!eq(A.geom, "m16n8k128"), + !or(!eq(A.ptx_elt_type, "s4"), + !eq(A.ptx_elt_type, "u4"))), + 1, 4))); let IntrProperties = [IntrNoMem, IntrNoCallback, ImmArg>, Range, 0, num_threads>]; } diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py index f4f166c4018d0..6d73bce46da7c 100644 --- a/llvm/test/CodeGen/NVPTX/wmma.py +++ b/llvm/test/CodeGen/NVPTX/wmma.py @@ -1135,6 +1135,7 @@ def sp_selector_gen(op): # (geom, type) -> allowed selector range range_01 = { ("m16n8k32", "bf16"), + ("m16n8k32", "f16"), ("m16n8k16", "tf32"), ("m16n8k32", "u8"), ("m16n8k32", "s8"), @@ -1154,6 +1155,11 @@ def sp_selector_gen(op): "e2m1", ]: return range(1) + if op.a.geom == "m16n8k128" and op.a.mma_type.ptx_type in [ + "u4", + "s4", + ]: + return range(1) return range(4)