Skip to content

Commit

Permalink
[LLVMGPU][ROCm] Add MFMA_F32_16x16x4_F32 instruction
Browse files Browse the repository at this point in the history
  • Loading branch information
pashu123 committed Jul 10, 2024
1 parent 9ac1015 commit 9412a5e
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 9 deletions.
38 changes: 38 additions & 0 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,9 @@ static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context,
Type i32 = IntegerType::get(context, 32);

switch (type) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
return OpaqueMmaLayout{16, 16, 4, f32, f32, f32};
}
case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
return OpaqueMmaLayout{16, 16, 16, f16, f16, f32};
}
Expand Down Expand Up @@ -251,6 +254,23 @@ static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context,
LayoutDimensionAttr::get(context, LayoutDimension::VECTORZ);
(void)laneZ, (void)vectorZ;
switch (type) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
// #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 1]>
// #layout_a = #iree_vector_ext.layout<#outer, #inner>
// #layout_b = #iree_vector_ext.layout<#inner, #outer>

auto outer = PerDimLayoutAttr::get(context, {laneX}, {16});
auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {4, 1});
auto aMLayout = outer;
auto aKLayout = inner;
auto bKLayout = inner;
auto bNLayout = outer;
auto cMLayout = PerDimLayoutAttr::get(context, {laneY, vectorX}, {4, 4});
auto cNLayout = outer;
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
// #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 4]>
Expand Down Expand Up @@ -404,6 +424,12 @@ MMAAttr::getABCVectorTypes() const {
// amd_matrix_instruction_calculator tells us about the number of 32-bit
// registers. So need to adjust accordingly. All vectors should be 1-D.
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
auto aType = VectorType::get({4}, getAType());
auto bType = VectorType::get({4}, getBType());
auto cType = VectorType::get({16}, getCType());
return std::make_tuple(aType, bType, cType);
}
case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
auto aType = VectorType::get({4}, getAType());
auto bType = VectorType::get({4}, getBType());
Expand Down Expand Up @@ -450,6 +476,7 @@ MMAAttr::getContractionLayout(vector::ContractionOp contract) const {

int64_t MMAAttr::getBlockSize() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32:
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F16_32x32x8_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32:
Expand All @@ -465,6 +492,7 @@ int64_t MMAAttr::getBlockSize() const {

int64_t MMAAttr::getSubgroupSize() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32:
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F16_32x32x8_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32:
Expand All @@ -482,6 +510,10 @@ int64_t MMAAttr::getSubgroupSize() const {

MMAAttr::SingleSubgroupLayout MMAAttr::getASingleSubgroupLayout() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*strides=*/{1, 16},
/*element=*/{1, 4}};
}
case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*strides=*/{1, 16},
/*element=*/{1, 4}};
Expand Down Expand Up @@ -509,6 +541,10 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getASingleSubgroupLayout() const {

MMAAttr::SingleSubgroupLayout MMAAttr::getBSingleSubgroupLayout() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
/*element=*/{4, 1}};
}
case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
/*element=*/{4, 1}};
Expand Down Expand Up @@ -536,6 +572,7 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getBSingleSubgroupLayout() const {

MMAAttr::SingleSubgroupLayout MMAAttr::getCSingleSubgroupLayout() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32:
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
Expand Down Expand Up @@ -571,6 +608,7 @@ FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
return failure();
}
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32:
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F16_32x32x8_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,18 @@ class IREEGPU_I32MmaEnumAttr<string name, string summary, list<I32EnumAttrCase>
}

// Format: <kind>_<input-type>_<M>x<N>x<K>_<output-type>
def MFMA_F16_16x16x16_F32 : I32EnumAttrCase<"MFMA_F16_16x16x16_F32", 0>;
def MFMA_F16_32x32x8_F32 : I32EnumAttrCase<"MFMA_F16_32x32x8_F32", 1>;
def MFMA_I8_16x16x32_I32 : I32EnumAttrCase<"MFMA_I8_16x16x32_I32", 2>;
def MFMA_I8_32x32x16_I32 : I32EnumAttrCase<"MFMA_I8_32x32x16_I32", 3>;
def MFMA_F32_16x16x4_F32 : I32EnumAttrCase<"MFMA_F32_16x16x4_F32", 0>;
def MFMA_F16_16x16x16_F32 : I32EnumAttrCase<"MFMA_F16_16x16x16_F32", 1>;
def MFMA_F16_32x32x8_F32 : I32EnumAttrCase<"MFMA_F16_32x32x8_F32", 2>;
def MFMA_I8_16x16x32_I32 : I32EnumAttrCase<"MFMA_I8_16x16x32_I32", 3>;
def MFMA_I8_32x32x16_I32 : I32EnumAttrCase<"MFMA_I8_32x32x16_I32", 4>;
// TODO: Create separate WMMA ops for AMD and NVIDIA GPUs
def WMMA_F16_16x16x16_F32 : I32EnumAttrCase<"WMMA_F16_16x16x16_F32", 4>;
def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 5>;
def WMMA_F16_16x16x16_F32 : I32EnumAttrCase<"WMMA_F16_16x16x16_F32", 5>;
def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 6>;

def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic",
"Descriptor for different MMA intrinsics", [
MFMA_F32_16x16x4_F32,
MFMA_F16_16x16x16_F32,
MFMA_F16_32x32x8_F32,
MFMA_I8_16x16x32_I32,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,8 @@ TargetAttr createTargetAttr(const TargetDetails &details, StringRef arch,

const WgpDetails *getCDNA3WgpDetails() {
static const MMAIntrinsic cdna3MMAOps[] = {
MMAIntrinsic::MFMA_F16_16x16x16_F32,
MMAIntrinsic::MFMA_F16_32x32x8_F32,
MMAIntrinsic::MFMA_I8_16x16x32_I32,
MMAIntrinsic::MFMA_F32_16x16x4_F32, MMAIntrinsic::MFMA_F16_16x16x16_F32,
MMAIntrinsic::MFMA_F16_32x32x8_F32, MMAIntrinsic::MFMA_I8_16x16x32_I32,
MMAIntrinsic::MFMA_I8_32x32x16_I32,
};
static const WgpDetails cdna3Wgp = {
Expand Down
28 changes: 28 additions & 0 deletions tests/e2e/matmul/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2193,6 +2193,34 @@ iree_generated_e2e_runner_test(
"requires-gpu-cdna3"
)

iree_generated_e2e_runner_test(
NAME
e2e_matmul_rocm_f32_large_cdna3_mfma
TEST_TYPE
matmul
GENERATOR
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f32"
"--acc_type=f32"
"--shapes=gpu_large_aligned"
"--compilation_info=LLVMGPUVectorDistributeMFMA"
TEST_RUNNER
iree_tools_testing_e2e_iree-e2e-matmul-test
TARGET_BACKENDS
"rocm"
DRIVERS
"hip"
COMPILER_FLAGS
${IREE_HIP_TEST_COMPILER_FLAGS}
LABELS
"noasan"
"nomsan"
"notsan"
"noubsan"
"requires-gpu-cdna3"
)

iree_generated_e2e_runner_test(
NAME
e2e_matmul_rocm_f16_large_cdna3_mfma_tb
Expand Down
5 changes: 5 additions & 0 deletions tests/e2e/matmul/generate_e2e_matmul_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ def get_rocm_test_compilation_infos(
schedules = []
if intrinsic == "MFMA":
schedules = [
MMASchedule("MFMA_F32_16x16x4_F32", 1, 1, 1, 1, 1),
MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 1, 1),
MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 1, 2),
MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 2, 1),
Expand Down Expand Up @@ -302,6 +303,10 @@ def get_rocm_test_compilation_infos(
if lhs_rhs_type.value.upper() not in schedule.intrinsic:
continue

if schedule.intrinsic == "MFMA_F32_16x16x4_F32":
wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
wg_tile_k = schedule.k_tile_count * 4
if schedule.intrinsic == "MFMA_F16_16x16x16_F32":
wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
Expand Down

0 comments on commit 9412a5e

Please sign in to comment.