Skip to content

Commit

Permalink
XXX WIP tuning miopen.xdlops_gemm_v2 lowering logic.
Browse files Browse the repository at this point in the history
  • Loading branch information
whchung committed Sep 11, 2020
1 parent 12e4d12 commit f8b0090
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
16 changes: 12 additions & 4 deletions mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h
Expand Up @@ -5835,7 +5835,9 @@ struct BlockwiseGemmV2RewritePattern
xdlopsGemmV2Op0.setAttr("m", op.getAttr("m"));
xdlopsGemmV2Op0.setAttr("n", op.getAttr("n"));
xdlopsGemmV2Op0.setAttr("k", op.getAttr("k"));
xdlopsGemmV2Op0.setAttr("m_per_wave", op.getAttr("m_per_wave"));
// TBD. hard-coded as 64 for now.
xdlopsGemmV2Op0.setAttr("m_per_wave", b.getI32IntegerAttr(64));
//xdlopsGemmV2Op0.setAttr("m_per_wave", op.getAttr("m_per_wave"));
xdlopsGemmV2Op0.setAttr("n_per_wave", op.getAttr("n_per_wave"));
xdlopsGemmV2Op0.setAttr("coord_transforms",
op.getAttr("coord_transforms"));
Expand All @@ -5853,7 +5855,9 @@ struct BlockwiseGemmV2RewritePattern
xdlopsGemmV2Op1.setAttr("m", op.getAttr("m"));
xdlopsGemmV2Op1.setAttr("n", op.getAttr("n"));
xdlopsGemmV2Op1.setAttr("k", op.getAttr("k"));
xdlopsGemmV2Op1.setAttr("m_per_wave", op.getAttr("m_per_wave"));
// TBD. hard-coded as 64 for now.
xdlopsGemmV2Op1.setAttr("m_per_wave", b.getI32IntegerAttr(64));
//xdlopsGemmV2Op1.setAttr("m_per_wave", op.getAttr("m_per_wave"));
xdlopsGemmV2Op1.setAttr("n_per_wave", op.getAttr("n_per_wave"));
xdlopsGemmV2Op1.setAttr("coord_transforms",
op.getAttr("coord_transforms"));
Expand All @@ -5878,7 +5882,9 @@ struct BlockwiseGemmV2RewritePattern
xdlopsGemmV2Op0.setAttr("m", op.getAttr("m"));
xdlopsGemmV2Op0.setAttr("n", op.getAttr("n"));
xdlopsGemmV2Op0.setAttr("k", op.getAttr("k"));
xdlopsGemmV2Op0.setAttr("m_per_wave", op.getAttr("m_per_wave"));
// TBD. hard-coded as 64 for now.
xdlopsGemmV2Op0.setAttr("m_per_wave", b.getI32IntegerAttr(64));
//xdlopsGemmV2Op0.setAttr("m_per_wave", op.getAttr("m_per_wave"));
xdlopsGemmV2Op0.setAttr("n_per_wave", op.getAttr("n_per_wave"));
xdlopsGemmV2Op0.setAttr("coord_transforms",
op.getAttr("coord_transforms"));
Expand All @@ -5896,7 +5902,9 @@ struct BlockwiseGemmV2RewritePattern
xdlopsGemmV2Op1.setAttr("m", op.getAttr("m"));
xdlopsGemmV2Op1.setAttr("n", op.getAttr("n"));
xdlopsGemmV2Op1.setAttr("k", op.getAttr("k"));
xdlopsGemmV2Op1.setAttr("m_per_wave", op.getAttr("m_per_wave"));
// TBD. hard-coded as 64 for now.
xdlopsGemmV2Op1.setAttr("m_per_wave", b.getI32IntegerAttr(64));
//xdlopsGemmV2Op1.setAttr("m_per_wave", op.getAttr("m_per_wave"));
xdlopsGemmV2Op1.setAttr("n_per_wave", op.getAttr("n_per_wave"));
xdlopsGemmV2Op1.setAttr("coord_transforms",
op.getAttr("coord_transforms"));
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/MIOpen/lowering_xdlops_gemm_v2.mlir
@@ -1,5 +1,5 @@
// XFAIL: *
// RUN: mlir-opt -miopen-lowering-step4 %s | FileCheck %s
// RUN: mlir-opt -miopen-lowering-step5 %s | FileCheck %s

func @miopen_xdlops_gemm_v2_two_results(%matrixA : memref<12288xf32, 3>, %matrixB : memref<12288xf32, 3>) -> (vector<32xf32>, vector<32xf32>) {
%c0 = constant 0 : index
Expand Down

0 comments on commit f8b0090

Please sign in to comment.