Skip to content

Commit

Permalink
Add m_waves and n_waves attributes in miopen.gridwise_gemm -> miopen.blockwise_gemm lowering.
Browse files Browse the repository at this point in the history
  • Loading branch information
whchung committed Jul 17, 2020
1 parent 20f3e05 commit 894d86c
Showing 1 changed file with 25 additions and 5 deletions.
30 changes: 25 additions & 5 deletions mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h
Expand Up @@ -1255,7 +1255,9 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
}
}

void affixBlockwiseGemmAttributes(miopen::BlockwiseGemmOp bop, miopen::GridwiseGemmOp gop) const {
void affixBlockwiseGemmAttributes(miopen::BlockwiseGemmOp bop,
miopen::GridwiseGemmOp gop,
PatternRewriter &b) const {
// Add attributes from C++ template arguments and ctor arguments.
//const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
// BlockSize, - block_size attribute
Expand All @@ -1279,6 +1281,24 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
bop.setAttr("m_level1_cluster", gop.getAttr("m_level1_cluster"));
bop.setAttr("n_level0_cluster", gop.getAttr("n_level0_cluster"));
bop.setAttr("n_level1_cluster", gop.getAttr("n_level1_cluster"));

// xdlops.
auto xdlopsAttr = gop.template getAttrOfType<BoolAttr>("xdlops");
if (xdlopsAttr && xdlopsAttr.getValue() == true) {
int64_t MPerBlock =
gop.getAttr("m_per_block").template dyn_cast<IntegerAttr>().getInt();
int64_t NPerBlock =
gop.getAttr("n_per_block").template dyn_cast<IntegerAttr>().getInt();
int64_t MPerThread =
gop.getAttr("m_per_thread").template dyn_cast<IntegerAttr>().getInt();
int64_t NPerThread =
gop.getAttr("n_per_thread").template dyn_cast<IntegerAttr>().getInt();
int64_t MWaves = MPerBlock / MPerThread;
int64_t NWaves = NPerBlock / NPerThread;

bop.setAttr("m_waves", b.getI32IntegerAttr(MWaves));
bop.setAttr("n_waves", b.getI32IntegerAttr(NWaves));
}
}

template <typename T>
Expand Down Expand Up @@ -1809,7 +1829,7 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
loc, lds2DMatrixAEvenSubviewOp, lds2DMatrixBEvenSubviewOp,
register2DMatrixCAllocOp, c_thread_mtx_index_row,
c_thread_mtx_index_col);
affixBlockwiseGemmAttributes(blockwiseGemmEvenOp, op);
affixBlockwiseGemmAttributes(blockwiseGemmEvenOp, op, b);

// Blockwise copy from register (naive tensor) to LDS (naive tensor).
auto blockwiseCopyOpAOdd = lb.create<miopen::BlockwiseCopyOp>(
Expand Down Expand Up @@ -1851,7 +1871,7 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
loc, lds2DMatrixAOddSubviewOp, lds2DMatrixBOddSubviewOp,
register2DMatrixCAllocOp, c_thread_mtx_index_row,
c_thread_mtx_index_col);
affixBlockwiseGemmAttributes(blockwiseGemmOddOp, op);
affixBlockwiseGemmAttributes(blockwiseGemmOddOp, op, b);

// Blockwise copy from register (naive tensor) to LDS (naive tensor).
auto blockwiseCopyAEvenSecondIteration = lb.create<miopen::BlockwiseCopyOp>(
Expand All @@ -1876,13 +1896,13 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
loc, lds2DMatrixAEvenSubviewOp, lds2DMatrixBEvenSubviewOp,
register2DMatrixCAllocOp, c_thread_mtx_index_row,
c_thread_mtx_index_col);
affixBlockwiseGemmAttributes(blockwiseGemmTailEvenOp, op);
affixBlockwiseGemmAttributes(blockwiseGemmTailEvenOp, op, b);
} else {
auto blockwiseGemmTailOddOp = b.create<miopen::BlockwiseGemmOp>(
loc, lds2DMatrixAOddSubviewOp, lds2DMatrixBOddSubviewOp,
register2DMatrixCAllocOp, c_thread_mtx_index_row,
c_thread_mtx_index_col);
affixBlockwiseGemmAttributes(blockwiseGemmTailOddOp, op);
affixBlockwiseGemmAttributes(blockwiseGemmTailOddOp, op, b);
}

// Threadwise copy from register (naive tensor) to global (generic tensor).
Expand Down

0 comments on commit 894d86c

Please sign in to comment.