Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 190 additions & 1 deletion llvm/include/llvm/IR/IntrinsicsNVVM.td
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
string gft = Geom#":"#Frag#":"#ptx_elt_type;
string gf = Geom#":"#Frag;
string ft = frag#":"#ptx_elt_type;
list<LLVMType> regs = !if(!eq(IsSparse, true),
bit isSparse = IsSparse;
list<LLVMType> regs = !if(!eq(isSparse, true),
!cond(
// mma sparse ops use other fragments for some arguments
!eq(gft,"m16n8k16:a:bf16") : !listsplat(llvm_i32_ty, 2),
Expand Down Expand Up @@ -277,6 +278,10 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
!eq(gft,"m16n8k32:d:f16") : !listsplat(llvm_v2f16_ty, 2),
!eq(gft,"m16n8k32:d:f32") : !listsplat(llvm_float_ty, 4),

// mma.block_scale e2m1 (mxf4, mxf4nvf4) -> f32 @ m16n8k64
!eq(gft,"m16n8k64:c:f32") : !listsplat(llvm_float_ty, 4),
!eq(gft,"m16n8k64:d:f32") : !listsplat(llvm_float_ty, 4),

// wmma fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16
// All other supported geometries use the same fragment format for f32 and
// f16, so we only need to consider {fragment, type}.
Expand Down Expand Up @@ -520,6 +525,18 @@ class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op, strin
# signature;
}

// Constructs the intrinsic record name for an mma block_scale operation
// from the kind, scale-factor type, scale vector size and fragment set.
// The "row_col" layout pair is hard-coded: it is the only layout
// combination these records are generated for (see the foreach below).
// A leading '.' in ScaleVecSize (e.g. ".scale_2x") is rewritten to '_'
// so the name stays a valid identifier.
class MMA_BLOCK_SCALE_NAME<string Kind, string SType, string ScaleVecSize,
                           WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
  string signature = MMA_SIGNATURE<A, B, C, D>.ret;
  string record_name = !strconcat("int_nvvm_mma_block_scale_", A.geom,
                                  "_row_col_", Kind,
                                  !subst(".", "_", ScaleVecSize),
                                  signature, "_", SType);
}

class MMA_SP_NAME<string Metadata, string Kind, int Satfinite,
WMMA_REGS A, WMMA_REGS B,
WMMA_REGS C, WMMA_REGS D> {
Expand All @@ -533,6 +550,19 @@ class MMA_SP_NAME<string Metadata, string Kind, int Satfinite,
# signature;
}

// Constructs the intrinsic record name for a sparse (ordered-metadata)
// mma block_scale operation from the kind, scale-factor type, scale vector
// size and fragment set. As with the dense variant, "row_col" is the only
// layout combination generated. A leading '.' in ScaleVecSize is rewritten
// to '_' so the name stays a valid identifier.
class MMA_SP_BLOCK_SCALE_NAME<string Kind, string SType, string ScaleVecSize,
                              WMMA_REGS A, WMMA_REGS B,
                              WMMA_REGS C, WMMA_REGS D> {
  string signature = MMA_SIGNATURE<A, B, C, D>.ret;
  string record_name = !strconcat("int_nvvm_mma_sp_ordered_metadata_block_scale_",
                                  A.geom, "_row_col_", Kind,
                                  !subst(".", "_", ScaleVecSize),
                                  signature, "_", SType);
}

// Helper class that takes an intrinsic name and construct a record name.
// Additionally, sets `intr_name` to be non-empty if the default name assigned
// to this intrinsic will not match the name given.
Expand Down Expand Up @@ -683,6 +713,18 @@ class NVVM_MMA_OPS {
fp_mma_ops, fp8_mma_ops, f8f6f4_mma_ops,
int_mma_ops, subint_mma_ops, bit_mma_ops);

list<list<WMMA_REGS>> mxf4_mma_ops = MMA_OPS<
["m16n8k64"], ["e2m1"], ["e2m1"], ["f32"], ["f32"]
>.ret;

list<list<WMMA_REGS>> mxf8f6f4_mma_ops = MMA_OPS<
["m16n8k32"], ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"],
["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"], ["f32"], ["f32"]
>.ret;

list<list<WMMA_REGS>> all_mma_block_scale_ops = !listconcat(
mxf4_mma_ops, mxf8f6f4_mma_ops);

list<list<WMMA_REGS>> bf16_mma_sp_ops = MMA_OPS<
["m16n8k16", "m16n8k32"],
["bf16"], [], ["f32"], [], true>.ret;
Expand All @@ -707,6 +749,18 @@ class NVVM_MMA_OPS {
bf16_mma_sp_ops, tf32_mma_sp_ops, fp_mma_sp_ops, fp8_mma_sp_ops,
subint_mma_sp_ops, int_mma_sp_ops);

// combines available geoms and types for mxf4 and mxf4nvf4 kinds
list<list<WMMA_REGS>> mxf4xx_mma_sp_ops = MMA_OPS<
["m16n8k128"],
["e2m1"], ["e2m1"], ["f32"], [], true>.ret;
list<list<WMMA_REGS>> mxf8f6f4_mma_sp_ops = MMA_OPS<
["m16n8k64"],
["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
["f32"], [], true>.ret;
list<list<WMMA_REGS>> all_mma_sp_block_scale_ops = !listconcat(
mxf4xx_mma_sp_ops, mxf8f6f4_mma_sp_ops);

list<WMMA_REGS> ldst_ab_ops = MMA_LDST_OPS<
["m16n16k16", "m32n8k16", "m8n32k16"],
["a", "b"], ["f16", "u8", "s8", "bf16"]>.ret;
Expand Down Expand Up @@ -900,6 +954,32 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
);
}

// Returns true if this combination of geometry/kind/stype/scale_vec_size
// for mma block_scale ops is supported; false otherwise.
// E.g.
// if NVVM_MMA_BLOCK_SCALE_SUPPORTED<...>.ret then
//   def : FOO<>; // The record will only be defined for supported ops.
//
class NVVM_MMA_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind,
                                     string stype, string scale_vec_size> {
  string geom = frags[0].geom;

  // mxf4 @ m16n8k64 takes ue8m0 scales; an empty scale_vec_size is
  // accepted alongside the explicit ".scale_2x" spelling.
  bit is_mxf4 = !and(!eq(geom, "m16n8k64"),
                     !eq(kind, "mxf4"),
                     !or(!eq(scale_vec_size, ""),
                         !eq(scale_vec_size, ".scale_2x")),
                     !eq(stype, "ue8m0"));
  // mxf4nvf4 @ m16n8k64 pairs ue8m0 scales with an explicit 2X vector size.
  bit is_mxf4nvf4_2x = !and(!eq(geom, "m16n8k64"),
                            !eq(kind, "mxf4nvf4"),
                            !eq(scale_vec_size, ".scale_2x"),
                            !eq(stype, "ue8m0"));
  // mxf4nvf4 @ m16n8k64 pairs ue4m3 scales with an explicit 4X vector size.
  bit is_mxf4nvf4_4x = !and(!eq(geom, "m16n8k64"),
                            !eq(kind, "mxf4nvf4"),
                            !eq(scale_vec_size, ".scale_4x"),
                            !eq(stype, "ue4m3"));
  // mxf8f6f4 @ m16n8k32 takes ue8m0 scales; an empty scale_vec_size is
  // accepted alongside the explicit ".scale_1x" spelling.
  bit is_mxf8f6f4 = !and(!eq(geom, "m16n8k32"),
                         !eq(kind, "mxf8f6f4"),
                         !or(!eq(scale_vec_size, ""),
                             !eq(scale_vec_size, ".scale_1x")),
                         !eq(stype, "ue8m0"));

  bit ret = !or(is_mxf4, is_mxf4nvf4_2x, is_mxf4nvf4_4x, is_mxf8f6f4);
}

// Returns true if the fragment is valid for ldmatrix ops is supported;
// false otherwise.
// E.g.
Expand Down Expand Up @@ -998,6 +1078,51 @@ class NVVM_MMA_SP_SUPPORTED<list<WMMA_REGS> frags, string metadata,
}


// Returns true if this combination of geometry/kind/stype/scale_vec_size
// for sparse (ordered-metadata) MMA block_scale ops is supported;
// false otherwise.
// E.g.
// if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<...>.ret then
//   def : FOO<>; // The record will only be defined for supported ops.
//
class NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind,
                                        string stype, string scale_vec_size> {
  // Element types of the A/B/C/D fragments.
  // NOTE(review): these four fields are not referenced by the check below,
  // which keys only on the geometry (element-type filtering happens in the
  // MMA_OPS lists) — confirm whether they were meant to be checked here.
  string a_type = frags[0].ptx_elt_type;
  string b_type = frags[1].ptx_elt_type;
  string c_type = frags[2].ptx_elt_type;
  string d_type = frags[3].ptx_elt_type;
  string geom = frags[0].geom;

  bit ret = !cond(
    // mxf4 @ m16n8k128: ue8m0 scales; an empty scale_vec_size is accepted
    // alongside the explicit ".scale_2x" spelling.
    !and(!eq(geom, "m16n8k128"),
         !eq(kind, "mxf4"),
         !eq(stype, "ue8m0"),
         !or(!eq(scale_vec_size, ""),
             !eq(scale_vec_size, ".scale_2x"))): true,

    // mxf4nvf4 @ m16n8k128: ue8m0 scales require an explicit 2X vector size.
    !and(!eq(geom, "m16n8k128"),
         !eq(kind, "mxf4nvf4"),
         !eq(stype, "ue8m0"),
         !eq(scale_vec_size, ".scale_2x")): true,

    // mxf4nvf4 @ m16n8k128: ue4m3 scales require an explicit 4X vector size.
    !and(!eq(geom, "m16n8k128"),
         !eq(kind, "mxf4nvf4"),
         !eq(stype, "ue4m3"),
         !eq(scale_vec_size, ".scale_4x")): true,

    // mxf8f6f4 @ m16n8k64: ue8m0 scales; an empty scale_vec_size is accepted
    // alongside the explicit ".scale_1x" spelling.
    !and(!eq(geom, "m16n8k64"),
         !eq(kind, "mxf8f6f4"),
         !eq(stype, "ue8m0"),
         !or(!eq(scale_vec_size, ""),
             !eq(scale_vec_size, ".scale_1x"))): true,

    // All other combinations are NOT OK.
    true: false
  );
}


class SHFL_INFO<bit sync, string mode, string type, bit return_pred> {
string Suffix = !if(sync, "sync_", "")
# mode # "_"
Expand Down Expand Up @@ -2452,6 +2577,31 @@ foreach layout_a = ["row", "col"] in {
} // layout_b
} // layout_a

// MMA with block scaling: beyond the usual A/B/C fragment registers, the
// intrinsic takes packed scale-factor data plus (byte-id, thread-id)
// selector operands for each of the A and B matrices, and returns the
// D fragment registers.
class NVVM_MMA_BLOCK_SCALE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
  : Intrinsic<D.regs,
      !listconcat(A.regs, B.regs, C.regs,
      [
        llvm_i32_ty, // scale-a-data
        llvm_i16_ty, llvm_i16_ty, // byte-id-a, thread-id-a
        llvm_i32_ty, // scale-b-data
        llvm_i16_ty, llvm_i16_ty, // byte-id-b, thread-id-b
      ]),
      [IntrNoMem, IntrNoCallback]>;

// Instantiate one intrinsic record per (kind, scale_vec_size, stype, op)
// combination; unsupported combinations are filtered out by
// NVVM_MMA_BLOCK_SCALE_SUPPORTED. Only row/col layouts exist for these
// ops, so layouts are baked into MMA_BLOCK_SCALE_NAME rather than iterated.
foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
  foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
    foreach stype = ["ue8m0", "ue4m3"] in {
      foreach op = NVVM_MMA_OPS.all_mma_block_scale_ops in {
        if NVVM_MMA_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
          def MMA_BLOCK_SCALE_NAME<kind, stype, scale_vec_size,
                                   op[0], op[1], op[2], op[3]>.record_name
            : NVVM_MMA_BLOCK_SCALE<op[0], op[1], op[2], op[3]>;
        }
      } // op
    } // stype
  } // scale_vec_size
} // kind

// MMA.SP
class NVVM_MMA_SP<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
: Intrinsic<D.regs,
Expand Down Expand Up @@ -2499,6 +2649,45 @@ foreach metadata = ["sp", "sp::ordered_metadata"] in {
} // kind
} // metadata

// MMA.SP BLOCK SCALE: sparse MMA with block scaling. Beyond the A/B/C
// fragment registers, the intrinsic takes the sparsity metadata word, the
// sparsity selector, and per-matrix scale-factor data with (byte-id,
// thread-id) selector operands; it returns the D fragment registers.
class NVVM_MMA_SP_BLOCK_SCALE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
  : Intrinsic<D.regs,
      !listconcat(A.regs, B.regs, C.regs,
      [
        llvm_i32_ty, // metadata
        llvm_i32_ty, // sparsity selector
        llvm_i32_ty, // scale-a-data
        llvm_i16_ty, llvm_i16_ty, // byte-id-a, thread-id-a
        llvm_i32_ty, // scale-b-data
        llvm_i16_ty, llvm_i16_ty, // byte-id-b, thread-id-b
      ])> {
  // 0-based operand index of the sparsity selector: it follows the A/B/C
  // fragment registers and the metadata i32, i.e. |A|+|B|+|C|+1. The extra
  // [llvm_i32_ty] element below stands in for the metadata operand so that
  // !size yields that index directly.
  int pos = !size(!listconcat(A.regs, B.regs, C.regs, [llvm_i32_ty]));

  // The range [0;num_threads) is for the sparsity selector that indicates the threads
  // which contribute metadata.
  // According to PTX ISA 9.0, the sparsity selector is always 0
  // for sparse MMA block scale instructions, hence num_threads = 1 and the
  // selector is constrained to the single immediate value 0.
  int num_threads = 1;
  let IntrProperties = [IntrNoMem, IntrNoCallback, ImmArg<ArgIndex<pos>>,
                        Range<ArgIndex<pos>, 0, num_threads>];
}

// Instantiate the sparse block-scale MMA intrinsic records.
// According to PTX ISA 9.0
// a_layout = ["row"], b_layout = ["col"], spvariant = ["sp::ordered_metadata"]
// are the only supported values, so they are baked into
// MMA_SP_BLOCK_SCALE_NAME rather than iterated over; unsupported
// combinations are filtered out by NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED.
foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
  foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
    foreach stype = ["ue8m0", "ue4m3"] in {
      foreach op = NVVM_MMA_OPS.all_mma_sp_block_scale_ops in {
        if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
          def MMA_SP_BLOCK_SCALE_NAME<kind, stype, scale_vec_size,
                                      op[0], op[1], op[2], op[3]>.record_name
            : NVVM_MMA_SP_BLOCK_SCALE<op[0], op[1], op[2], op[3]>;
        }
      } // op
    } // stype
  } // scale_vec_size
} // kind

// LDMATRIX
class NVVM_LDMATRIX<WMMA_REGS Frag, int Transposed>
: Intrinsic<Frag.regs, [llvm_anyptr_ty],
Expand Down
Loading