Skip to content

Commit

Permalink
Fix XDLOPS code selection logic for fp32.
Browse files Browse the repository at this point in the history
  • Loading branch information
whchung committed Sep 11, 2020
1 parent f8b0090 commit e15b6ff
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions mlir/include/mlir/Dialect/MIOpen/XdlopsCodeSelection.h
Expand Up @@ -52,7 +52,7 @@ struct XdlopsCodeSelection {

if (dataType == b.getF32Type()) {
if (MPerWave == 128 && NPerWave == 64) {
mfmaInstr = "mfma_f32_32x32x1xf32";
mfmaInstr = "mfma_f32_32x32x1f32";
MPerXdlops = 64;
NPerXdlops = 64;
MRepeats = 2;
Expand All @@ -64,7 +64,7 @@ struct XdlopsCodeSelection {
imms.push_back({ 1, 0, 0 });
imms.push_back({ 1, 1, 0 });
} else if (MPerWave == 64 && NPerWave == 128) {
mfmaInstr = "mfma_f32_32x32x1xf32";
mfmaInstr = "mfma_f32_32x32x1f32";
MPerXdlops = 64;
NPerXdlops = 64;
MRepeats = 1;
Expand All @@ -76,7 +76,7 @@ struct XdlopsCodeSelection {
imms.push_back({ 1, 0, 0 });
imms.push_back({ 1, 1, 0 });
} else if (MPerWave == 64 && NPerWave == 64) {
mfmaInstr = "mfma_f32_32x32x1xf32";
mfmaInstr = "mfma_f32_32x32x1f32";
MPerXdlops = 64;
NPerXdlops = 64;
MRepeats = 1;
Expand All @@ -86,7 +86,7 @@ struct XdlopsCodeSelection {
imms.push_back({ 1, 0, 0 });
imms.push_back({ 1, 1, 0 });
} else if (MPerWave == 64 && NPerWave == 32) {
mfmaInstr = "mfma_f32_32x32x1xf32";
mfmaInstr = "mfma_f32_32x32x1f32";
MPerXdlops = 64;
NPerXdlops = 32;
MRepeats = 1;
Expand All @@ -95,7 +95,7 @@ struct XdlopsCodeSelection {
vectorNumber = 1;
imms.push_back({ 0, 0, 1 });
} else if (MPerWave == 32 && NPerWave == 64) {
mfmaInstr = "mfma_f32_32x32x1xf32";
mfmaInstr = "mfma_f32_32x32x1f32";
MPerXdlops = 32;
NPerXdlops = 64;
MRepeats = 1;
Expand All @@ -104,7 +104,7 @@ struct XdlopsCodeSelection {
vectorNumber = 1;
imms.push_back({ 1, 0, 0 });
} else if (MPerWave == 64 && NPerWave == 16) {
mfmaInstr = "mfma_f32_16x16x1xf32";
mfmaInstr = "mfma_f32_16x16x1f32";
MPerXdlops = 64;
NPerXdlops = 16;
MRepeats = 1;
Expand All @@ -113,7 +113,7 @@ struct XdlopsCodeSelection {
vectorNumber = 1;
imms.push_back({ 0, 0, 4 });
} else if (MPerWave == 16 && NPerWave == 64) {
mfmaInstr = "mfma_f32_16x16x1xf32";
mfmaInstr = "mfma_f32_16x16x1f32";
MPerXdlops = 16;
NPerXdlops = 64;
MRepeats = 1;
Expand All @@ -122,7 +122,7 @@ struct XdlopsCodeSelection {
vectorNumber = 1;
imms.push_back({ 2, 0, 0 });
} else if (MPerWave == 8 && NPerWave == 64) {
mfmaInstr = "mfma_f32_4x4x1xf32";
mfmaInstr = "mfma_f32_4x4x1f32";
MPerXdlops = 8;
NPerXdlops = 64;
MRepeats = 1;
Expand All @@ -132,7 +132,7 @@ struct XdlopsCodeSelection {
imms.push_back({ 4, 0, 0 });
imms.push_back({ 4, 1, 0 });
} else if (MPerWave == 4 && NPerWave == 64) {
mfmaInstr = "mfma_f32_4x4x1xf32";
mfmaInstr = "mfma_f32_4x4x1f32";
MPerXdlops = 4;
NPerXdlops = 64;
MRepeats = 1;
Expand All @@ -141,7 +141,7 @@ struct XdlopsCodeSelection {
vectorNumber = 1;
imms.push_back({ 4, 0, 0 });
} else if (MPerWave == 32 && NPerWave == 32) {
mfmaInstr = "mfma_f32_32x32x2xf32";
mfmaInstr = "mfma_f32_32x32x2f32";
MPerXdlops = 32;
NPerXdlops = 32;
MRepeats = 1;
Expand All @@ -150,7 +150,7 @@ struct XdlopsCodeSelection {
vectorNumber = 1;
imms.push_back({ 0, 0, 0 });
} else if (MPerWave == 16 && NPerWave == 16) {
mfmaInstr = "mfma_f32_16x16x4xf32";
mfmaInstr = "mfma_f32_16x16x4f32";
MPerXdlops = 16;
NPerXdlops = 16;
MRepeats = 1;
Expand Down Expand Up @@ -396,7 +396,7 @@ struct XdlopsCodeSelection {

// Obtain properties of MFMA instructions.
int64_t group_size, num_groups_blk, num_regs_blk, num_threads_blk, wave_size, num_input_blks, num_output_blks, num_regs_xdlops, m, n, k, cycles, k_base;
if (mfmaInstr == "mfma_f32_32x32x1xf32") {
if (mfmaInstr == "mfma_f32_32x32x1f32") {
group_size = 4;
num_groups_blk = 4;
num_regs_blk = group_size * num_groups_blk;
Expand All @@ -410,7 +410,7 @@ struct XdlopsCodeSelection {
k = 1;
cycles = 64;
k_base = 1;
} else if (mfmaInstr == "mfma_f32_32x32x2xf32") {
} else if (mfmaInstr == "mfma_f32_32x32x2f32") {
group_size = 4;
num_groups_blk = 4;
num_regs_blk = group_size * num_groups_blk;
Expand All @@ -424,7 +424,7 @@ struct XdlopsCodeSelection {
k = 2;
cycles = 64;
k_base = 1;
} else if (mfmaInstr == "mfma_f32_16x16x4xf32") {
} else if (mfmaInstr == "mfma_f32_16x16x4f32") {
group_size = 4;
num_groups_blk = 1;
num_regs_blk = group_size * num_groups_blk;
Expand All @@ -438,7 +438,7 @@ struct XdlopsCodeSelection {
k = 4;
cycles = 32;
k_base = 1;
} else if (mfmaInstr == "mfma_f32_16x16x1xf32") {
} else if (mfmaInstr == "mfma_f32_16x16x1f32") {
group_size = 4;
num_groups_blk = 1;
num_regs_blk = group_size * num_groups_blk;
Expand All @@ -452,7 +452,7 @@ struct XdlopsCodeSelection {
k = 1;
cycles = 32;
k_base = 1;
} else if (mfmaInstr == "mfma_f32_4x4x1xf32") {
} else if (mfmaInstr == "mfma_f32_4x4x1f32") {
group_size = 4;
num_groups_blk = 1;
num_regs_blk = group_size * num_groups_blk;
Expand Down

0 comments on commit e15b6ff

Please sign in to comment.