Skip to content

Commit

Permalink
[AMDGPU] Add GFX12 WMMA and SWMMAC instructions (#77795)
Browse files Browse the repository at this point in the history
Co-authored-by: Petar Avramovic <Petar.Avramovic@amd.com>
Co-authored-by: Piotr Sobczak <piotr.sobczak@amd.com>
  • Loading branch information
3 people committed Jan 24, 2024
1 parent 27cfe7a commit 7fdf608
Show file tree
Hide file tree
Showing 65 changed files with 17,701 additions and 111 deletions.
62 changes: 62 additions & 0 deletions clang/include/clang/Basic/BuiltinsAMDGPU.def
Original file line number Diff line number Diff line change
Expand Up @@ -436,5 +436,67 @@ TARGET_BUILTIN(__builtin_amdgcn_global_load_tr_i32, "ii*1", "nc", "gfx12-insts,w
TARGET_BUILTIN(__builtin_amdgcn_global_load_tr_v4i16, "V4sV4s*1", "nc", "gfx12-insts,wavefrontsize64")
TARGET_BUILTIN(__builtin_amdgcn_global_load_tr_v4f16, "V4hV4h*1", "nc", "gfx12-insts,wavefrontsize64")

//===----------------------------------------------------------------------===//
// WMMA builtins.
// Postfix w32 indicates the builtin requires wavefront size of 32.
// Postfix w64 indicates the builtin requires wavefront size of 64.
//
// Some of these are very similar to their GFX11 counterparts, but they don't
// require replication of the A,B matrices, so they use fewer vector elements.
// Therefore, we add an "_gfx12" suffix to distinguish them from the existing
// builtins.
//===----------------------------------------------------------------------===//
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12, "V8fV8hV8hV8f", "nc", "gfx12-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12, "V8fV8sV8sV8f", "nc", "gfx12-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12, "V8hV8hV8hV8h", "nc", "gfx12-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12, "V8sV8sV8sV8s", "nc", "gfx12-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12, "V8iIbV2iIbV2iV8iIb", "nc", "gfx12-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32_gfx12, "V8iIbiIbiV8iIb", "nc", "gfx12-insts,wavefrontsize32")
// These are gfx12-only, but for consistency with the other WMMA variants we're
// keeping the "_gfx12" suffix.
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12, "V8fV2iV2iV8f", "nc", "gfx12-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12, "V8fV2iV2iV8f", "nc", "gfx12-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12, "V8fV2iV2iV8f", "nc", "gfx12-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12, "V8fV2iV2iV8f", "nc", "gfx12-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_wmma_i32_16x16x32_iu4_w32_gfx12, "V8iIbV2iIbV2iV8iIb", "nc", "gfx12-insts,wavefrontsize32")

TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12, "V4fV4hV4hV4f", "nc", "gfx12-insts,wavefrontsize64")
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12, "V4fV4sV4sV4f", "nc", "gfx12-insts,wavefrontsize64")
TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12, "V4hV4hV4hV4h", "nc", "gfx12-insts,wavefrontsize64")
TARGET_BUILTIN(__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64_gfx12, "V4sV4sV4sV4s", "nc", "gfx12-insts,wavefrontsize64")
TARGET_BUILTIN(__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64_gfx12, "V4iIbiIbiV4iIb", "nc", "gfx12-insts,wavefrontsize64")
TARGET_BUILTIN(__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64_gfx12, "V4iIbiIbiV4iIb", "nc", "gfx12-insts,wavefrontsize64")
// These are gfx12-only, but for consistency with the other WMMA variants we're
// keeping the "_gfx12" suffix.
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w64_gfx12, "V4fiiV4f", "nc", "gfx12-insts,wavefrontsize64")
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w64_gfx12, "V4fiiV4f", "nc", "gfx12-insts,wavefrontsize64")
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w64_gfx12, "V4fiiV4f", "nc", "gfx12-insts,wavefrontsize64")
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w64_gfx12, "V4fiiV4f", "nc", "gfx12-insts,wavefrontsize64")
TARGET_BUILTIN(__builtin_amdgcn_wmma_i32_16x16x32_iu4_w64_gfx12, "V4iIbiIbiV4iIb", "nc", "gfx12-insts,wavefrontsize64")

TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32, "V8fV8hV16hV8fs", "nc", "gfx12-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32, "V8fV8sV16sV8fs", "nc", "gfx12-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32, "V8hV8hV16hV8hs", "nc", "gfx12-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32, "V8sV8sV16sV8ss", "nc", "gfx12-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32, "V8iIbV2iIbV4iV8isIb", "nc", "gfx12-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w32, "V8iIbiIbV2iV8isIb", "nc", "gfx12-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w32, "V8iIbV2iIbV4iV8isIb", "nc", "gfx12-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32, "V8fV2iV4iV8fs", "nc", "gfx12-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32, "V8fV2iV4iV8fs", "nc", "gfx12-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32, "V8fV2iV4iV8fs", "nc", "gfx12-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32, "V8fV2iV4iV8fs", "nc", "gfx12-insts,wavefrontsize32")

TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_f16_w64, "V4fV4hV8hV4fs", "nc", "gfx12-insts,wavefrontsize64")
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w64, "V4fV4sV8sV4fs", "nc", "gfx12-insts,wavefrontsize64")
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f16_16x16x32_f16_w64, "V4hV4hV8hV4hs", "nc", "gfx12-insts,wavefrontsize64")
TARGET_BUILTIN(__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w64, "V4sV4sV8sV4ss", "nc", "gfx12-insts,wavefrontsize64")
TARGET_BUILTIN(__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w64, "V4iIbiIbV2iV4isIb", "nc", "gfx12-insts,wavefrontsize64")
TARGET_BUILTIN(__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w64, "V4iIbiIbiV4isIb", "nc", "gfx12-insts,wavefrontsize64")
TARGET_BUILTIN(__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w64, "V4iIbiIbV2iV4isIb", "nc", "gfx12-insts,wavefrontsize64")
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w64, "V4fiV2iV4fs", "nc", "gfx12-insts,wavefrontsize64")
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w64, "V4fiV2iV4fs", "nc", "gfx12-insts,wavefrontsize64")
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w64, "V4fiV2iV4fs", "nc", "gfx12-insts,wavefrontsize64")
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w64, "V4fiV2iV4fs", "nc", "gfx12-insts,wavefrontsize64")

#undef BUILTIN
#undef TARGET_BUILTIN
177 changes: 164 additions & 13 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18279,65 +18279,216 @@ Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32:
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64:
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32:
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64: {
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64:
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12:
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64_gfx12:
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12:
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12:
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12:
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12:
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12:
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12:
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32_gfx12:
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64_gfx12:
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12:
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64_gfx12:
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12:
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w64_gfx12:
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12:
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w64_gfx12:
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12:
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w64_gfx12:
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12:
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w64_gfx12:
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w32_gfx12:
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w64_gfx12:
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32:
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w64:
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32:
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w64:
case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32:
case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w64:
case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32:
case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w64:
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32:
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w64:
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w32:
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w64:
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w32:
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w64:
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32:
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w64:
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32:
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w64:
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32:
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w64:
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32:
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w64: {

// These operations perform a matrix multiplication and accumulation of
// the form:
// D = A * B + C
// The return type always matches the type of matrix C.
unsigned ArgForMatchingRetType;
// We need to specify one type for matrices AB and one for matrices CD.
// Sparse matrix operations can have different types for A and B as well as
// an additional type for sparsity index.
// Destination type should be put before types used for source operands.
SmallVector<unsigned, 2> ArgsForMatchingMatrixTypes;
// On GFX12, the intrinsics with 16-bit accumulator use a packed layout.
// There is no need for the variable opsel argument, so always set it to
// "false".
bool AppendFalseForOpselArg = false;
unsigned BuiltinWMMAOp;

switch (BuiltinID) {
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w32:
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w64:
ArgForMatchingRetType = 2;
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12:
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12:
ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_f16;
break;
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32:
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64:
ArgForMatchingRetType = 2;
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12:
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12:
ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_bf16;
break;
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12:
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12:
AppendFalseForOpselArg = true;
LLVM_FALLTHROUGH;
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w32:
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w64:
ArgForMatchingRetType = 2;
ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f16_16x16x16_f16;
break;
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12:
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64_gfx12:
AppendFalseForOpselArg = true;
LLVM_FALLTHROUGH;
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32:
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64:
ArgForMatchingRetType = 2;
ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_bf16_16x16x16_bf16;
break;
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_tied_w32:
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_tied_w64:
ArgForMatchingRetType = 2;
ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f16_16x16x16_f16_tied;
break;
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_tied_w32:
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_tied_w64:
ArgForMatchingRetType = 2;
ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_bf16_16x16x16_bf16_tied;
break;
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32:
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64:
ArgForMatchingRetType = 4;
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12:
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64_gfx12:
ArgsForMatchingMatrixTypes = {4, 1}; // CD, AB
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x16_iu8;
break;
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32:
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64:
ArgForMatchingRetType = 4;
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32_gfx12:
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64_gfx12:
ArgsForMatchingMatrixTypes = {4, 1}; // CD, AB
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x16_iu4;
break;
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12:
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w64_gfx12:
ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_fp8_fp8;
break;
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12:
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w64_gfx12:
ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_fp8_bf8;
break;
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12:
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w64_gfx12:
ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_bf8_fp8;
break;
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12:
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w64_gfx12:
ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_bf8_bf8;
break;
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w32_gfx12:
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w64_gfx12:
ArgsForMatchingMatrixTypes = {4, 1}; // CD, AB
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x32_iu4;
break;
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32:
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w64:
ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_f16;
break;
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32:
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w64:
ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_bf16;
break;
case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32:
case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w64:
ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f16_16x16x32_f16;
break;
case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32:
case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w64:
ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_bf16_16x16x32_bf16;
break;
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32:
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w64:
ArgsForMatchingMatrixTypes = {4, 1, 3, 5}; // CD, A, B, Index
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_i32_16x16x32_iu8;
break;
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w32:
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w64:
ArgsForMatchingMatrixTypes = {4, 1, 3, 5}; // CD, A, B, Index
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_i32_16x16x32_iu4;
break;
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w32:
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w64:
ArgsForMatchingMatrixTypes = {4, 1, 3, 5}; // CD, A, B, Index
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_i32_16x16x64_iu4;
break;
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32:
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w64:
ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_fp8_fp8;
break;
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32:
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w64:
ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_fp8_bf8;
break;
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32:
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w64:
ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_bf8_fp8;
break;
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32:
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w64:
ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_bf8_bf8;
break;
}

SmallVector<Value *, 6> Args;
for (int i = 0, e = E->getNumArgs(); i != e; ++i)
Args.push_back(EmitScalarExpr(E->getArg(i)));
if (AppendFalseForOpselArg)
Args.push_back(Builder.getFalse());

Function *F = CGM.getIntrinsic(BuiltinWMMAOp,
{Args[ArgForMatchingRetType]->getType()});
SmallVector<llvm::Type *, 6> ArgTypes;
for (auto ArgIdx : ArgsForMatchingMatrixTypes)
ArgTypes.push_back(Args[ArgIdx]->getType());

Function *F = CGM.getIntrinsic(BuiltinWMMAOp, ArgTypes);
return Builder.CreateCall(F, Args);
}

Expand Down

0 comments on commit 7fdf608

Please sign in to comment.