Skip to content

Commit

Permalink
Continue fill in logic of miopen.xdlops_gemm lowering.
Browse files Browse the repository at this point in the history
  • Loading branch information
whchung committed Jul 28, 2020
1 parent 3d476a0 commit 3816da8
Showing 1 changed file with 116 additions and 27 deletions.
143 changes: 116 additions & 27 deletions mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h
Expand Up @@ -1664,7 +1664,7 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
// MPerThread is MPerWave.
// NPerThread is NPerWave.
int64_t TotalRegSize = MPerThread * NPerThread / WaveSize;
llvm::errs() << "TotalRegSize: " << TotalRegSize << "\n";
//llvm::errs() << "TotalRegSize: " << TotalRegSize << "\n";
auto threadCRegisterMemRefType =
MemRefType::get({TotalRegSize}, elementType, {},
gpu::GPUDialect::getPrivateAddressSpace());
Expand Down Expand Up @@ -3191,6 +3191,10 @@ struct XdlopsGemmRewritePattern
.getElementType()
.template dyn_cast<FloatType>();

auto MConstantIndexOp = b.create<ConstantIndexOp>(loc, M);
auto NConstantIndexOp = b.create<ConstantIndexOp>(loc, N);
auto KConstantIndexOp = b.create<ConstantIndexOp>(loc, K);

// Determine which XDLOPS be used.
int64_t MPerXdlops = 0, NPerXdlops = 0, MRepeats = 0, NRepeats = 0;
StringRef mfmaInstr = "";
Expand Down Expand Up @@ -3769,6 +3773,9 @@ struct XdlopsGemmRewritePattern
bool IsABroadcast = (NPerXdlops >= MPerXdlops);
bool IsKReduction = (num_output_blks == 1) && (num_input_blks > 1);

int64_t RegSizePerXdlops = MPerXdlops * NPerXdlops / wave_size;
auto RegSizePerXdlopsConstantIndexOp = b.create<ConstantIndexOp>(loc, RegSizePerXdlops);

// Original C++ logic.
// const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
// FloatA a[K * MRepeats];
Expand All @@ -3789,18 +3796,21 @@ struct XdlopsGemmRewritePattern
auto NXDlopsConstantIndexOp = b.create<ConstantIndexOp>(
loc, dataType.getWidth() / (dataType.getWidth() * k_base));

llvm::errs() << "mfmaInstr: " << mfmaInstr << "\n";
llvm::errs() << "MPerXdlops: " << MPerXdlops << "\n";
llvm::errs() << "NPerXdlops: " << NPerXdlops << "\n";
llvm::errs() << "MRepeats: " << MRepeats << "\n";
llvm::errs() << "NRepeats: " << NRepeats << "\n";
llvm::errs() << "IsABroadcast: " << IsABroadcast << "\n";
llvm::errs() << "IsKReduction: " << IsKReduction << "\n";
auto MPerXdlopsConstantIndexOp = b.create<ConstantIndexOp>(loc, MPerXdlops);
auto NPerXdlopsConstantIndexOp = b.create<ConstantIndexOp>(loc, NPerXdlops);
auto KBaseConstantIndexOp = b.create<ConstantIndexOp>(loc, k_base);

//llvm::errs() << "mfmaInstr: " << mfmaInstr << "\n";
//llvm::errs() << "MPerXdlops: " << MPerXdlops << "\n";
//llvm::errs() << "NPerXdlops: " << NPerXdlops << "\n";
//llvm::errs() << "MRepeats: " << MRepeats << "\n";
//llvm::errs() << "NRepeats: " << NRepeats << "\n";
//llvm::errs() << "IsABroadcast: " << IsABroadcast << "\n";
//llvm::errs() << "IsKReduction: " << IsKReduction << "\n";

auto zeroConstantIndexOp = b.create<ConstantIndexOp>(loc, 0);
auto oneConstantIndexOp = b.create<ConstantIndexOp>(loc, 1);

auto KConstantIndexOp = b.create<ConstantIndexOp>(loc, K);
auto MRepeatsConstantIndexOp = b.create<ConstantIndexOp>(loc, MRepeats);
auto NRepeatsConstantIndexOp = b.create<ConstantIndexOp>(loc, NRepeats);

Expand All @@ -3809,33 +3819,65 @@ struct XdlopsGemmRewritePattern
// static_if<!IsKReduction>{}([&](auto) {
// for(index_t m_i = 0; m_i < MRepeats; ++m_i)
// for(index_t k_i = 0; k_i < K; ++k_i)
// a[k_i + m_i * K] = p_a_wave[k_i * M + laneId + MPerXdlops *
// m_i];
// a[k_i + m_i * K] = p_a_wave[k_i * M + laneId + MPerXdlops * m_i];
// p_a_wave need to be offseted by threadOffsetA.

auto outerLoopM = b.create<scf::ForOp>(loc, zeroConstantIndexOp, MRepeatsConstantIndexOp, oneConstantIndexOp);
auto olmb = OpBuilder::atBlockTerminator(outerLoopM.getBody());
auto olmiv = outerLoopM.getInductionVar();
auto innerLoopMK = olmb.create<scf::ForOp>(loc, zeroConstantIndexOp, KConstantIndexOp, oneConstantIndexOp);
auto ilmkb = OpBuilder::atBlockTerminator(innerLoopMK.getBody());

// TBD
// a[k_i + m_i * K] = p_a_wave[k_i * M + laneId + MPerXdlops *
// m_i];

auto ilmkiv = innerLoopMK.getInductionVar();

// TBD. Check if we need to apply coord_transform as well.
// a[k_i + m_i * K] = p_a_wave[k_i * M + laneId + MPerXdlops * m_i];
// p_a_wave need to be offseted by threadOffsetA.
auto sourceOffsetA = ilmkb.create<AddIOp>(
loc, op.threadOffsetA(),
ilmkb.create<AddIOp>(
loc,
ilmkb.create<AddIOp>(
loc, ilmkb.create<MulIOp>(loc, ilmkiv, MConstantIndexOp),
laneId),
ilmkb.create<MulIOp>(loc, MPerXdlopsConstantIndexOp, olmiv)));
auto destOffsetA = ilmkb.create<AddIOp>(
loc, ilmkiv, ilmkb.create<MulIOp>(loc, olmiv, KConstantIndexOp));

auto valueA = ilmkb.create<LoadOp>(loc, dataType, op.matrixA(),
ValueRange{sourceOffsetA});
ilmkb.create<StoreOp>(loc, valueA, arrayA, ValueRange{destOffsetA});

// Original C++ logic.
// for(index_t n_i = 0; n_i < NRepeats; ++n_i)
// for(index_t k_i = 0; k_i < K; ++k_i)
// b[k_i + n_i * K] = p_b_wave[k_i * N + laneId + NPerXdlops *
// n_i];
// b[k_i + n_i * K] = p_b_wave[k_i * N + laneId + NPerXdlops * n_i];
// p_b_wave need to be offseted by threadOffsetB.

auto outerLoopN = b.create<scf::ForOp>(loc, zeroConstantIndexOp, NRepeatsConstantIndexOp, oneConstantIndexOp);
auto olnb = OpBuilder::atBlockTerminator(outerLoopM.getBody());
auto olnb = OpBuilder::atBlockTerminator(outerLoopN.getBody());
auto olniv = outerLoopN.getInductionVar();
auto innerLoopNK = olnb.create<scf::ForOp>(loc, zeroConstantIndexOp, KConstantIndexOp, oneConstantIndexOp);
auto ilnkb = OpBuilder::atBlockTerminator(innerLoopNK.getBody());

// TBD
// b[k_i + n_i * K] = p_b_wave[k_i * N + laneId + NPerXdlops *
// n_i];
auto ilnkiv = innerLoopNK.getInductionVar();

// TBD. Check if we need to apply coord_transform as well.
// b[k_i + n_i * K] = p_b_wave[k_i * N + laneId + NPerXdlops * n_i];
// p_b_wave need to be offseted by threadOffsetB.

auto sourceOffsetB = ilnkb.create<AddIOp>(
loc, op.threadOffsetB(),
ilnkb.create<AddIOp>(
loc,
ilnkb.create<AddIOp>(
loc, ilnkb.create<MulIOp>(loc, ilnkiv, NConstantIndexOp),
laneId),
ilnkb.create<MulIOp>(loc, NPerXdlopsConstantIndexOp, olniv)));
auto destOffsetB = ilnkb.create<AddIOp>(
loc, ilnkiv, ilnkb.create<MulIOp>(loc, olniv, KConstantIndexOp));

auto valueB = ilnkb.create<LoadOp>(loc, dataType, op.matrixB(),
ValueRange{sourceOffsetB});
ilnkb.create<StoreOp>(loc, valueB, arrayB, ValueRange{destOffsetB});

// Original C++ logic.
// // get pointer of registers
Expand All @@ -3846,14 +3888,24 @@ struct XdlopsGemmRewritePattern
// for(index_t k_i = 0; k_i < K; ++k_i) {
// for(index_t i = 0; i < nxdlops; ++i)

auto loopM = b.create<scf::ForOp>(loc, zeroConstantIndexOp, MRepeatsConstantIndexOp, oneConstantIndexOp);
auto loopM =
b.create<scf::ForOp>(loc, zeroConstantIndexOp,
MRepeatsConstantIndexOp, oneConstantIndexOp);
auto lmb = OpBuilder::atBlockTerminator(loopM.getBody());
auto loopN = lmb.create<scf::ForOp>(loc, zeroConstantIndexOp, NRepeatsConstantIndexOp, oneConstantIndexOp);
auto lmiv = loopM.getInductionVar();
auto loopN =
lmb.create<scf::ForOp>(loc, zeroConstantIndexOp,
NRepeatsConstantIndexOp, oneConstantIndexOp);
auto lnb = OpBuilder::atBlockTerminator(loopN.getBody());
auto loopK = lnb.create<scf::ForOp>(loc, zeroConstantIndexOp, KConstantIndexOp, oneConstantIndexOp);
auto lniv = loopN.getInductionVar();
auto loopK = lnb.create<scf::ForOp>(loc, zeroConstantIndexOp,
KConstantIndexOp, oneConstantIndexOp);
auto lkb = OpBuilder::atBlockTerminator(loopK.getBody());
auto loopI = lkb.create<scf::ForOp>(loc, zeroConstantIndexOp, NXDlopsConstantIndexOp, oneConstantIndexOp);
auto lkiv = loopK.getInductionVar();
auto loopI = lkb.create<scf::ForOp>(
loc, zeroConstantIndexOp, NXDlopsConstantIndexOp, oneConstantIndexOp);
auto lib = OpBuilder::atBlockTerminator(loopI.getBody());
auto liiv = loopI.getInductionVar();

// Original C++ logic.
// mfma_type.template run<MPerXdlops, NPerXdlops>(
Expand All @@ -3863,6 +3915,43 @@ struct XdlopsGemmRewritePattern
// n_i * K * nxdlops * mfma_type.k_base],
// p_c_thread + (NRepeats * m_i + n_i) *
// GetRegSizePerXdlops());
auto addressA = lib.create<AddIOp>(loc,
lib.create<MulIOp>(loc,
lib.create<AddIOp>(loc,
lib.create<MulIOp>(loc, lkiv, NXDlopsConstantIndexOp),
liiv),
KBaseConstantIndexOp),
lib.create<MulIOp>(loc,
lmiv,
lib.create<MulIOp>(loc,
KConstantIndexOp,
lib.create<MulIOp>(loc,
NXDlopsConstantIndexOp, KBaseConstantIndexOp))));
auto addressB = lib.create<AddIOp>(loc,
lib.create<MulIOp>(loc,
lib.create<AddIOp>(loc,
lib.create<MulIOp>(loc, lkiv, NXDlopsConstantIndexOp),
liiv),
KBaseConstantIndexOp),
lib.create<MulIOp>(loc,
lniv,
lib.create<MulIOp>(loc,
KConstantIndexOp,
lib.create<MulIOp>(loc,
NXDlopsConstantIndexOp, KBaseConstantIndexOp))));
// TBD: use vector.type_cast for FP16/BF16 types.
auto argA = lib.create<LoadOp>(loc, dataType, arrayA, ValueRange{addressA});
auto argB = lib.create<LoadOp>(loc, dataType, arrayB, ValueRange{addressB});

auto addressC = lib.create<MulIOp>(loc,
lib.create<AddIOp>(loc,
lib.create<MulIOp>(loc, NRepeatsConstantIndexOp, lmiv),
lniv),
RegSizePerXdlopsConstantIndexOp);
// TBD. use addressC.
auto mfma = lib.create<miopen::MFMAOp>(loc, argA, argB, op.matrixC());
mfma.setAttr("m_per_wave", lib.getI32IntegerAttr(MPerXdlops));
mfma.setAttr("n_per_wave", lib.getI32IntegerAttr(NPerXdlops));

} else {
// Original C++ logic.
Expand Down

0 comments on commit 3816da8

Please sign in to comment.