Skip to content

Commit

Permalink
WIP: attempt to reduce stack usage by moving allocation inside xdlops…
Browse files Browse the repository at this point in the history
…_gemm_v2 to gridwise_gemm_v2.
  • Loading branch information
whchung committed Sep 17, 2020
1 parent e15b6ff commit 2467036
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 29 deletions.
66 changes: 41 additions & 25 deletions mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h
Expand Up @@ -3282,6 +3282,20 @@ struct GridwiseGemmV2RewritePattern : public OpRewritePattern<miopen::GridwiseGe
blockwiseCopyBDst, threadBOddAllocOp);
affixBlockwiseCopyAttributes(blockwiseCopyB, op, b, /*isMatrixA=*/false);

// ---------------------
// TBD. FloatA / FloatB could be vectorized via KPack. Ignore this for now.
auto arrayAType =
//MemRefType::get({K * MRepeats}, dataType, {},
MemRefType::get({KPerBlock * MRepeats}, dataType, {},
gpu::GPUDialect::getPrivateAddressSpace());
auto arrayA = b.create<miopen::GpuAllocOp>(loc, arrayAType);
auto arrayBType =
//MemRefType::get({K * NRepeats}, dataType, {},
MemRefType::get({KPerBlock * NRepeats}, dataType, {},
gpu::GPUDialect::getPrivateAddressSpace());
auto arrayB = b.create<miopen::GpuAllocOp>(loc, arrayBType);
// ---------------------

// Emit loop.
// Compute loop iterations from attributes.

Expand Down Expand Up @@ -3322,7 +3336,7 @@ struct GridwiseGemmV2RewritePattern : public OpRewritePattern<miopen::GridwiseGe
// Emit blockwise V2 GEMM.
auto blockwiseGemmV2EvenOp = mfmalb.create<miopen::BlockwiseGemmV2Op>(
loc, vectorCTypes, lds2DMatrixAEvenSubviewOp, lds2DMatrixBEvenSubviewOp,
mMyThreadOffsetA, mMyThreadOffsetB, mfmaLoopOp.getRegionIterArgs());
mMyThreadOffsetA, mMyThreadOffsetB, arrayA, arrayB, mfmaLoopOp.getRegionIterArgs());
affixBlockwiseGemmV2Attributes(blockwiseGemmV2EvenOp, op, b);

// Blockwise copy from register (naive tensor) to LDS (naive tensor).
Expand Down Expand Up @@ -3363,7 +3377,7 @@ struct GridwiseGemmV2RewritePattern : public OpRewritePattern<miopen::GridwiseGe
// Emit blockwise V2 GEMM.
auto blockwiseGemmV2OddOp = mfmalb.create<miopen::BlockwiseGemmV2Op>(
loc, vectorCTypes, lds2DMatrixAOddSubviewOp, lds2DMatrixBOddSubviewOp,
mMyThreadOffsetA, mMyThreadOffsetB, blockwiseGemmV2EvenOp.getResults());
mMyThreadOffsetA, mMyThreadOffsetB, arrayA, arrayB, blockwiseGemmV2EvenOp.getResults());
affixBlockwiseGemmV2Attributes(blockwiseGemmV2OddOp, op, b);

// Blockwise copy from register (naive tensor) to LDS (naive tensor).
Expand All @@ -3389,12 +3403,12 @@ struct GridwiseGemmV2RewritePattern : public OpRewritePattern<miopen::GridwiseGe
// Emit blockwise GEMM for the loop tail.
auto blockwiseGemmV2TailEvenOp = b.create<miopen::BlockwiseGemmV2Op>(
loc, vectorCTypes, lds2DMatrixAEvenSubviewOp, lds2DMatrixBEvenSubviewOp,
mMyThreadOffsetA, mMyThreadOffsetB, mfmaLoopOp.getResults());
mMyThreadOffsetA, mMyThreadOffsetB, arrayA, arrayB, mfmaLoopOp.getResults());
affixBlockwiseGemmV2Attributes(blockwiseGemmV2TailEvenOp, op, b);

auto blockwiseGemmV2TailOddOp = b.create<miopen::BlockwiseGemmV2Op>(
loc, vectorCTypes, lds2DMatrixAOddSubviewOp, lds2DMatrixBOddSubviewOp,
mMyThreadOffsetA, mMyThreadOffsetB, blockwiseGemmV2TailEvenOp.getResults());
mMyThreadOffsetA, mMyThreadOffsetB, arrayA, arrayB, blockwiseGemmV2TailEvenOp.getResults());
affixBlockwiseGemmV2Attributes(blockwiseGemmV2TailOddOp, op, b);

// Matrix C write out logic.
Expand Down Expand Up @@ -5524,15 +5538,16 @@ struct XdlopsGemmV2RewritePattern
auto laneId = b.create<SignedRemIOp>(
loc, tid, b.create<ConstantIndexOp>(loc, wave_size));

// TBD. FloatA / FloatB could be vectorized via KPack. Ignore this for now.
auto arrayAType =
MemRefType::get({K * MRepeats}, dataType, {},
gpu::GPUDialect::getPrivateAddressSpace());
auto arrayA = b.create<miopen::GpuAllocOp>(loc, arrayAType);
auto arrayBType =
MemRefType::get({K * NRepeats}, dataType, {},
gpu::GPUDialect::getPrivateAddressSpace());
auto arrayB = b.create<miopen::GpuAllocOp>(loc, arrayBType);
// // TBD. FloatA / FloatB could be vectorized via KPack. Ignore this for now.
// auto arrayAType =
// MemRefType::get({K * MRepeats}, dataType, {},
// gpu::GPUDialect::getPrivateAddressSpace());
// auto arrayA = b.create<miopen::GpuAllocOp>(loc, arrayAType);
// auto arrayBType =
// MemRefType::get({K * NRepeats}, dataType, {},
// MemRefType::get({1}, dataType, {},
// gpu::GPUDialect::getPrivateAddressSpace());
// auto arrayB = b.create<miopen::GpuAllocOp>(loc, arrayBType);

// TBD. FloatA / FloatB could be vectorized via KPack tuning parameter. Ignore this for now.
// use arrayA as pa for now.
Expand Down Expand Up @@ -5588,7 +5603,7 @@ struct XdlopsGemmV2RewritePattern

auto valueA = ilmkb.create<LoadOp>(loc, dataType, op.matrixA(),
ValueRange{sourceOffsetA});
ilmkb.create<StoreOp>(loc, valueA, arrayA, ValueRange{destOffsetA});
ilmkb.create<StoreOp>(loc, valueA, op.bufferA(), ValueRange{destOffsetA});

// Original C++ logic.
// for(index_t n_i = 0; n_i < NRepeats; ++n_i)
Expand Down Expand Up @@ -5620,7 +5635,7 @@ struct XdlopsGemmV2RewritePattern

auto valueB = ilnkb.create<LoadOp>(loc, dataType, op.matrixB(),
ValueRange{sourceOffsetB});
ilnkb.create<StoreOp>(loc, valueB, arrayB, ValueRange{destOffsetB});
ilnkb.create<StoreOp>(loc, valueB, op.bufferB(), ValueRange{destOffsetB});

// Original C++ logic.
// for(index_t k_i = 0; k_i < K * KRepeats; ++k_i)
Expand All @@ -5638,8 +5653,8 @@ struct XdlopsGemmV2RewritePattern
auto loopKiv = loopK.getInductionVar();

auto offset = loopKb.create<MulIOp>(loc, loopKiv, KBaseConstantOp);
auto argA = loopKb.create<LoadOp>(loc, dataType, arrayA, ValueRange{offset});
auto argB = loopKb.create<LoadOp>(loc, dataType, arrayB, ValueRange{offset});
auto argA = loopKb.create<LoadOp>(loc, dataType, op.bufferA(), ValueRange{offset});
auto argB = loopKb.create<LoadOp>(loc, dataType, op.bufferB(), ValueRange{offset});

SmallVector<Value, 4> mfmas;
for (int64_t i = 0; i < vectorNumber; ++i) {
Expand Down Expand Up @@ -5699,7 +5714,7 @@ struct XdlopsGemmV2RewritePattern

auto valueA = lklb.create<LoadOp>(loc, dataType, op.matrixA(),
ValueRange{sourceOffsetA});
lklb.create<StoreOp>(loc, valueA, arrayA, ValueRange{lkliv});
lklb.create<StoreOp>(loc, valueA, op.bufferA(), ValueRange{lkliv});

auto sourceOffsetB = lklb.create<AddIOp>(
loc, op.threadOffsetB(),
Expand All @@ -5711,7 +5726,7 @@ struct XdlopsGemmV2RewritePattern

auto valueB = lklb.create<LoadOp>(loc, dataType, op.matrixB(),
ValueRange{sourceOffsetB});
lklb.create<StoreOp>(loc, valueB, arrayB, ValueRange{lkliv});
lklb.create<StoreOp>(loc, valueB, op.bufferB(), ValueRange{lkliv});

// Original C++ logic.
// for(index_t k_i = 0; k_i < K; k_i += mfma_type.num_input_blks)
Expand All @@ -5732,8 +5747,8 @@ struct XdlopsGemmV2RewritePattern
auto innerLoopiv = innerLoop.getInductionVar();

auto offset = innerLoopb.create<MulIOp>(loc, innerLoopb.create<AddIOp>(loc, innerLoopb.create<MulIOp>(loc, outerLoopiv, KRepeatsConstantOp), innerLoopiv), KBaseConstantOp);
auto argA = innerLoopb.create<LoadOp>(loc, dataType, arrayA, ValueRange{offset});
auto argB = innerLoopb.create<LoadOp>(loc, dataType, arrayB, ValueRange{offset});
auto argA = innerLoopb.create<LoadOp>(loc, dataType, op.bufferA(), ValueRange{offset});
auto argB = innerLoopb.create<LoadOp>(loc, dataType, op.bufferB(), ValueRange{offset});

SmallVector<Value, 4> mfmas;
for (int64_t i = 0; i < vectorNumber; ++i) {
Expand Down Expand Up @@ -5807,7 +5822,7 @@ struct BlockwiseGemmV2RewritePattern

auto xdlopsGemmV2Op = b.create<miopen::XdlopsGemmV2Op>(
loc, resultTypes, op.matrixA(), op.matrixB(), op.threadOffsetA(),
op.threadOffsetB(), op.vectorCs());
op.threadOffsetB(), op.bufferA(), op.bufferB(), op.vectorCs());

xdlopsGemmV2Op.setAttr("m", op.getAttr("m"));
xdlopsGemmV2Op.setAttr("n", op.getAttr("n"));
Expand All @@ -5830,7 +5845,7 @@ struct BlockwiseGemmV2RewritePattern

auto xdlopsGemmV2Op0 = b.create<miopen::XdlopsGemmV2Op>(
loc, resultTypes0, op.matrixA(), op.matrixB(), op.threadOffsetA(),
op.threadOffsetB(), ValueRange{op.vectorCs()[0], op.vectorCs()[1]});
op.threadOffsetB(), op.bufferA(), op.bufferB(), ValueRange{op.vectorCs()[0], op.vectorCs()[1]});

xdlopsGemmV2Op0.setAttr("m", op.getAttr("m"));
xdlopsGemmV2Op0.setAttr("n", op.getAttr("n"));
Expand All @@ -5850,7 +5865,7 @@ struct BlockwiseGemmV2RewritePattern
auto xdlopsGemmV2Op1 = b.create<miopen::XdlopsGemmV2Op>(
loc, resultTypes1, op.matrixA(), op.matrixB(),
b.create<AddIOp>(loc, op.threadOffsetA(), MPerXdlopsConstantOp),
op.threadOffsetB(), ValueRange{op.vectorCs()[2], op.vectorCs()[3]});
op.threadOffsetB(), op.bufferA(), op.bufferB(), ValueRange{op.vectorCs()[2], op.vectorCs()[3]});

xdlopsGemmV2Op1.setAttr("m", op.getAttr("m"));
xdlopsGemmV2Op1.setAttr("n", op.getAttr("n"));
Expand All @@ -5877,7 +5892,7 @@ struct BlockwiseGemmV2RewritePattern

auto xdlopsGemmV2Op0 = b.create<miopen::XdlopsGemmV2Op>(
loc, resultTypes0, op.matrixA(), op.matrixB(), op.threadOffsetA(),
op.threadOffsetB(), ValueRange{op.vectorCs()[0], op.vectorCs()[1]});
op.threadOffsetB(), op.bufferA(), op.bufferB(), ValueRange{op.vectorCs()[0], op.vectorCs()[1]});

xdlopsGemmV2Op0.setAttr("m", op.getAttr("m"));
xdlopsGemmV2Op0.setAttr("n", op.getAttr("n"));
Expand All @@ -5897,6 +5912,7 @@ struct BlockwiseGemmV2RewritePattern
auto xdlopsGemmV2Op1 = b.create<miopen::XdlopsGemmV2Op>(
loc, resultTypes1, op.matrixA(), op.matrixB(), op.threadOffsetA(),
b.create<AddIOp>(loc, op.threadOffsetB(), NPerXdlopsConstantOp),
op.bufferA(), op.bufferB(),
ValueRange{op.vectorCs()[2], op.vectorCs()[3]});

xdlopsGemmV2Op1.setAttr("m", op.getAttr("m"));
Expand Down
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/MIOpen/MIOpenOps.td
Expand Up @@ -254,6 +254,8 @@ def MIOpen_BlockwiseGemmV2Op:
MemRefOf<[F32, F16, BF16]>:$matrixB,
Index:$threadOffsetA,
Index:$threadOffsetB,
MemRefOf<[F32, F16, BF16]>:$bufferA,
MemRefOf<[F32, F16, BF16]>:$bufferB,
Variadic<VectorOfRankAndType<[1], [F32, F16, BF16]>>:$vectorCs)>,
Results<(outs Variadic<VectorOfRankAndType<[1], [F32, F16, BF16]>>: $vectorDs)> {
let summary = "Blockwise GEMM XDLOPS version";
Expand Down Expand Up @@ -328,6 +330,8 @@ def MIOpen_XdlopsGemmV2Op:
MemRefOf<[F32, F16, BF16]>:$matrixB,
Index:$threadOffsetA,
Index:$threadOffsetB,
MemRefOf<[F32, F16, BF16]>:$bufferA,
MemRefOf<[F32, F16, BF16]>:$bufferB,
Variadic<VectorOfRankAndType<[1], [F32, F16, BF16]>>:$vectorCs)>,
Results<(outs Variadic<VectorOfRankAndType<[1], [F32, F16, BF16]>>: $vectorDs)> {
let summary = "XDLOPS GEMM V2";
Expand Down
12 changes: 8 additions & 4 deletions mlir/lib/Dialect/MIOpen/MIOpenOps.cpp
Expand Up @@ -633,9 +633,11 @@ static ParseResult parseXdlopsGemmV2Op(OpAsmParser &parser, OperationState &resu
parser.resolveOperand(ops[0], types[0], result.operands) ||
parser.resolveOperand(ops[1], types[1], result.operands) ||
parser.resolveOperand(ops[2], types[2], result.operands) ||
parser.resolveOperand(ops[3], types[3], result.operands);
parser.resolveOperand(ops[3], types[3], result.operands) ||
parser.resolveOperand(ops[4], types[4], result.operands) ||
parser.resolveOperand(ops[5], types[5], result.operands);

for (unsigned i = 4; i < ops.size(); ++i) {
for (unsigned i = 6; i < ops.size(); ++i) {
ret &= succeeded(parser.resolveOperand(ops[i], types[i], result.operands));
parser.addTypeToList(types[i], result.types);
}
Expand Down Expand Up @@ -666,9 +668,11 @@ static ParseResult parseBlockwiseGemmV2Op(OpAsmParser &parser, OperationState &r
parser.resolveOperand(ops[0], types[0], result.operands) ||
parser.resolveOperand(ops[1], types[1], result.operands) ||
parser.resolveOperand(ops[2], types[2], result.operands) ||
parser.resolveOperand(ops[3], types[3], result.operands);
parser.resolveOperand(ops[3], types[3], result.operands) ||
parser.resolveOperand(ops[4], types[4], result.operands) ||
parser.resolveOperand(ops[5], types[5], result.operands);

for (unsigned i = 4; i < ops.size(); ++i) {
for (unsigned i = 6; i < ops.size(); ++i) {
ret &= succeeded(parser.resolveOperand(ops[i], types[i], result.operands));
parser.addTypeToList(types[i], result.types);
}
Expand Down

0 comments on commit 2467036

Please sign in to comment.