Skip to content

Commit

Permalink
Implementation of miopen.blockwise_gemm_v2 lowering logic.
Browse files Browse the repository at this point in the history
- Considers MRepeats = NRepeats = 1 case.
- Considers MRepeats = 1, NRepeats = 2 case.
- Considers MRepeats = 2, NRepeats = 1 case.
- Add attributes.
  • Loading branch information
whchung committed Sep 8, 2020
1 parent b7bd26c commit c3f2b96
Showing 1 changed file with 132 additions and 1 deletion.
133 changes: 132 additions & 1 deletion mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h
Expand Up @@ -4313,7 +4313,138 @@ struct BlockwiseGemmV2RewritePattern
/// Lowers a miopen.blockwise_gemm_v2 op into one or two miopen.xdlops_gemm_v2
/// ops, depending on how many 64-wide repeats the per-wave tile needs.
///
/// Supported (MRepeats, NRepeats) combinations:
///   - (1, 1): a single xdlops gemm consuming all C vectors.
///   - (2, 1): two xdlops gemms; the second one reads matrix A at
///             threadOffsetA + MPerXdlops.
///   - (1, 2): two xdlops gemms; the second one reads matrix B at
///             threadOffsetB + NPerXdlops.
///
/// \param op the blockwise gemm op to lower.
/// \param b  the rewriter used to create the new ops and replace `op`.
/// \returns success() once `op` has been replaced; failure() (leaving the IR
///          untouched) when the required attributes are missing, the repeat
///          combination is not one of the above, or the op does not carry the
///          four C/D vectors the two-repeat cases index into.
LogicalResult matchAndRewrite(miopen::BlockwiseGemmV2Op op,
                              PatternRewriter &b) const override {
  auto loc = op.getLoc();

  // getAttrOfType yields a null attribute when the attribute is absent or of
  // a different kind, so it can be checked before calling getInt().
  auto mPerWaveAttr = op.template getAttrOfType<IntegerAttr>("m_per_wave");
  auto nPerWaveAttr = op.template getAttrOfType<IntegerAttr>("n_per_wave");
  if (!mPerWaveAttr || !nPerWaveAttr)
    return failure();

  int64_t MPerWave = mPerWaveAttr.getInt();
  int64_t NPerWave = nPerWaveAttr.getInt();

  // Original C++ logic.
  // static constexpr index_t MRepeats = (GemmMPerWave > 64) ? (GemmMPerWave /
  // 64) : 1; static constexpr index_t NRepeats = (GemmNPerWave > 64) ?
  // (GemmNPerWave / 64) : 1; static constexpr index_t MPerXdlops =
  // (GemmMPerWave > 64) ? 64 : GemmMPerWave; static constexpr index_t
  // NPerXdlops = (GemmNPerWave > 64) ? 64 : GemmNPerWave;
  int64_t MRepeats = (MPerWave > 64) ? (MPerWave / 64) : 1;
  int64_t NRepeats = (NPerWave > 64) ? (NPerWave / 64) : 1;
  int64_t MPerXdlops = (MPerWave > 64) ? 64 : MPerWave;
  int64_t NPerXdlops = (NPerWave > 64) ? 64 : NPerWave;

  const bool isSingle = (MRepeats == 1 && NRepeats == 1);
  const bool splitAlongM = (MRepeats == 2 && NRepeats == 1);
  const bool splitAlongN = (MRepeats == 1 && NRepeats == 2);

  // Anything else is not lowered here. Report failure instead of claiming a
  // rewrite that never happened, so the driver does not keep re-matching an
  // unchanged op.
  if (!isSingle && !splitAlongM && !splitAlongN)
    return failure();

  // Creates one xdlops gemm on the same A/B matrices as `op`, using the given
  // thread offsets and C vectors, and copies over the attributes it shares
  // with the blockwise op.
  auto emitXdlopsGemm = [&](Value offsetA, Value offsetB, ValueRange vectorCs,
                            ArrayRef<Type> resultTypes) {
    auto gemm = b.create<miopen::XdlopsGemmV2Op>(loc, resultTypes,
                                                 op.matrixA(), op.matrixB(),
                                                 offsetA, offsetB, vectorCs);
    for (const char *name :
         {"m", "n", "k", "m_per_wave", "n_per_wave", "coord_transforms"})
      gemm.setAttr(name, op.getAttr(name));
    return gemm;
  };

  if (isSingle) {
    SmallVector<Type, 2> resultTypes;
    for (auto result : op.vectorDs())
      resultTypes.push_back(result.getType());

    auto gemm = emitXdlopsGemm(op.threadOffsetA(), op.threadOffsetB(),
                               op.vectorCs(), resultTypes);

    // Go through the rewriter so the pattern driver is told about the
    // replacement and the erasure of `op`.
    b.replaceOp(op, gemm.vectorDs());
    return success();
  }

  // Two-repeat cases.
  // Original C++ logic (MRepeats == 2):
  // p_c_thread.s.x.l = XdlopsGemm.template Run<M, N, K>(p_a_block, p_b_block, p_c_thread.s.x.l);
  // p_c_thread.s.y.l = XdlopsGemm.template Run<M, N, K>(p_a_block + MPerXdlops, p_b_block, p_c_thread.s.y.l);
  // Original C++ logic (NRepeats == 2):
  // p_c_thread.s.x.l = XdlopsGemm.template Run<M, N, K>(p_a_block, p_b_block, p_c_thread.s.x.l);
  // p_c_thread.s.y.l = XdlopsGemm.template Run<M, N, K>(p_a_block, p_b_block + NPerXdlops, p_c_thread.s.y.l);
  auto inputCs = op.vectorCs();
  auto outputDs = op.vectorDs();

  // Both halves index entries [0..3]; bail out rather than index past the end.
  if (inputCs.size() < 4 || outputDs.size() < 4)
    return failure();

  // First half: unmodified thread offsets, C/D vectors 0 and 1.
  SmallVector<Type, 2> resultTypes0;
  resultTypes0.push_back(outputDs[0].getType());
  resultTypes0.push_back(outputDs[1].getType());
  auto gemm0 = emitXdlopsGemm(op.threadOffsetA(), op.threadOffsetB(),
                              ValueRange{inputCs[0], inputCs[1]}, resultTypes0);

  // Second half: C/D vectors 2 and 3, with the offset of the split matrix
  // advanced by one xdlops tile.
  SmallVector<Type, 2> resultTypes1;
  resultTypes1.push_back(outputDs[2].getType());
  resultTypes1.push_back(outputDs[3].getType());

  Value offsetA = op.threadOffsetA();
  Value offsetB = op.threadOffsetB();
  Value tileStep =
      b.create<ConstantIndexOp>(loc, splitAlongM ? MPerXdlops : NPerXdlops);
  if (splitAlongM)
    offsetA = b.create<AddIOp>(loc, offsetA, tileStep);
  else
    offsetB = b.create<AddIOp>(loc, offsetB, tileStep);

  auto gemm1 = emitXdlopsGemm(offsetA, offsetB,
                              ValueRange{inputCs[2], inputCs[3]}, resultTypes1);

  b.replaceOp(op, ValueRange{gemm0.vectorDs()[0], gemm0.vectorDs()[1],
                             gemm1.vectorDs()[0], gemm1.vectorDs()[1]});
  return success();
}
};
Expand Down

0 comments on commit c3f2b96

Please sign in to comment.