Skip to content

Commit

Permalink
Fix matrix C output logic.
Browse files Browse the repository at this point in the history
  • Loading branch information
whchung committed Aug 12, 2020
1 parent fca12c1 commit d359032
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 29 deletions.
112 changes: 84 additions & 28 deletions mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h
Expand Up @@ -1342,9 +1342,13 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
LogicalResult matchAndRewrite(miopen::GridwiseGemmOp op, PatternRewriter &b) const override {
auto loc = op.getLoc();

auto elementType = op.output().getType().cast<MemRefType>().getElementType();

// Prepare some useful constants.
auto zeroConstantFloatOp =
b.create<ConstantFloatOp>(loc, APFloat(0.0f), b.getF32Type());
auto oneConstantFloatOp =
b.create<ConstantFloatOp>(loc, APFloat(1.0f), b.getF32Type());
auto zeroConstantI32Op =
b.create<ConstantIntOp>(loc, 0, b.getIntegerType(32));

Expand Down Expand Up @@ -1406,10 +1410,16 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
int64_t MBlockWork = M / MPerBlock;
int64_t NBlockWork = N / NPerBlock;

// llvm::errs() << "M / MPerBlock: " << M << " / " << MPerBlock << "\n";
// llvm::errs() << "N / NPerBlock: " << N << " / " << NPerBlock << "\n";
// llvm::errs() << "MBlockWork: " << MBlockWork << "\n";
// llvm::errs() << "NBlockWork: " << NBlockWork << "\n";
llvm::errs() << "M: " << M << "\n";
llvm::errs() << "N:" << N << "\n";
llvm::errs() << "MPerBlock: " << MPerBlock << "\n";
llvm::errs() << "NPerBlock: " << NPerBlock << "\n";
llvm::errs() << "MBlockWork = M / MPerBlock: " << MBlockWork << "\n";
llvm::errs() << "NBlockWork = N / NPerBlock: " << NBlockWork << "\n";
llvm::errs() << "MPerWave: " << MPerWave << "\n";
llvm::errs() << "NPerWave: " << NPerWave << "\n";
llvm::errs() << "MWaves = MPerBlock / MPerWave: " << MWaves << "\n";
llvm::errs() << "NWaves = NPerBlock / NPerWave: " << NWaves << "\n";

auto MBlockWorkConstantOp = b.create<ConstantIndexOp>(loc, MBlockWork);
auto NBlockWorkConstantOp = b.create<ConstantIndexOp>(loc, NBlockWork);
Expand Down Expand Up @@ -1790,6 +1800,7 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
// const index_t waveId_n = waveId % GemmNWaves;
// mMyWaveOffsetA = waveId_m * GemmMPerWave;
// mMyWaveOffsetB = waveId_n * GemmNPerWave;

auto waveId = b.create<SignedDivIOp>(loc, tid, waveSizeConstantOp);
auto waveId_m = b.create<SignedDivIOp>(loc, waveId, NWavesConstantOp);
auto waveId_n = b.create<SignedRemIOp>(loc, waveId, NWavesConstantOp);
Expand Down Expand Up @@ -2025,8 +2036,23 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
// constexpr index_t NumBlks = blockwise_gemm.GetNumBlks();

int BlkSize = num_regs_blk;
int NumBlksPerXdlops = (m * n) * MRepeats * NRepeats;
int NumBlks = (MPerXdlops * NPerXdlops) / NumBlksPerXdlops;
int NumBlksPerXdlops = (MPerXdlops * NPerXdlops) / (m * n);
int NumBlks = NumBlksPerXdlops * MRepeats * NRepeats;

llvm::errs() << "MPerWave: " << MPerWave << "\n";
llvm::errs() << "NPerWave: " << NPerWave << "\n\n";

llvm::errs() << "MPerXlops: " << MPerXdlops << "\n";
llvm::errs() << "NPerXlops: " << NPerXdlops << "\n";
llvm::errs() << "m: " << m << "\n";
llvm::errs() << "n: " << n << "\n";
llvm::errs() << "MRepeat: " << MRepeats << "\n";
llvm::errs() << "NRepeat: " << NRepeats << "\n\n";

llvm::errs() << "BlkSize: " << BlkSize << "\n";
llvm::errs() << "NumBlksPerXdlops: " << NumBlksPerXdlops << "\n";
llvm::errs() << "NumBlks: " << NumBlks << "\n\n";

auto BlkSizeConstantI32Op = b.create<ConstantIntOp>(loc, BlkSize, b.getIntegerType(32));
auto NumBlksPerXdlopsConstantOp = b.create<ConstantIndexOp>(loc, NumBlksPerXdlops);
auto NumBlksConstantOp = b.create<ConstantIndexOp>(loc, NumBlks);
Expand All @@ -2049,25 +2075,33 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
// }
// };
//
// // M1 = num_groups;
// // M0 = group_size;
// // N1 = num_blks_per_wave;
// // N0 = num_threads_per_blks;
// // CLayout.M1() = num_groups;
// // CLayout.M0() = group_size;
// // CLayout.N1() = num_blks_per_wave;
// // CLayout.N0() = num_threads_per_blks;
// constexpr auto CLayout = blockwise_gemm.GetOutputLayout();
// constexpr index_t M0 = CLayout.M1();
// constexpr index_t M1 = CLayout.N1();
// constexpr index_t M2 = CLayout.M0();

int64_t M0 = num_groups_blk;
int64_t M3 = num_groups_blk;
int64_t M1 = num_input_blks;
int64_t M2 = group_size;
int64_t M0 = M / (M1 * M2);

llvm::errs() << "M0: " << M0 << "\n";
llvm::errs() << "M1: num_input_blks: " << M1 << "\n";
llvm::errs() << "M2: group_size: " << M2 << "\n";
llvm::errs() << "M3: num_groups_blk: " << M3 << "\n\n";

auto M0ConstantI32Op =
b.create<ConstantIntOp>(loc, M0, b.getIntegerType(32));
auto M1ConstantI32Op =
b.create<ConstantIntOp>(loc, M1, b.getIntegerType(32));
auto M2ConstantI32Op =
b.create<ConstantIntOp>(loc, M2, b.getIntegerType(32));
auto M3ConstantI32Op =
b.create<ConstantIntOp>(loc, M3, b.getIntegerType(32));
auto NConstantI32Op =
b.create<ConstantIntOp>(loc, N, b.getIntegerType(32));

Expand Down Expand Up @@ -2109,20 +2143,19 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
// Note: turns out this layout is not enough to cover the whole matrix C
// on VGPR. It only covers 1/NumBlks of it.

// A layout of Sequence<1, NumBlks, M0, M2> would cover the whole matrix C
// A layout of Sequence<NumBlks * M3, 1, M2, 1> would cover the whole matrix C
// on VGPR.
// build affine expression for Sequence<1, NumBlks, M0, M2>
// (d0, d1, d2, d3) -> (d1 * M0 * M2 + d2 * M2 + d3)
// build affine expression for Sequence<NumBlks * M3, 1, M2, 1>
// (d0, d1, d2, d3) -> (d0 * M2 + d2)
auto matrixCAffineMap4to1 = AffineMap::get(
4, 0,
{getAffineDimExpr(1, op.getContext()) * getAffineConstantExpr(M0, op.getContext()) * getAffineConstantExpr(M2, op.getContext()) +
getAffineDimExpr(2, op.getContext()) * getAffineConstantExpr(M2, op.getContext()) +
getAffineDimExpr(3, op.getContext())},
{getAffineDimExpr(0, op.getContext()) * getAffineConstantExpr(M2, op.getContext()) +
getAffineDimExpr(2, op.getContext())},
op.getContext());

// emit TransformOp for Matrix C on VGPR.
auto register4DMatrixCType = MemRefType::get(
{1, NumBlks, M0, M2}, elementType,
{NumBlks * M3, 1, M2, 1}, elementType,
{matrixCAffineMap4to1}, gpu::GPUDialect::getPrivateAddressSpace());
auto matrixCTransformOp = b.create<miopen::TransformOp>(
loc, register4DMatrixCType, registerMatrixCAllocOp);
Expand Down Expand Up @@ -2204,11 +2237,11 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
//
// Original C++ logic:
// const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);
// const index_t m_thread_data_on_global =
// m_block_data_on_global + c_thread_mtx_on_block.row;
// const index_t n_thread_data_on_global =
// n_block_data_on_global + c_thread_mtx_on_block.col;

// compute waveId.
auto tid = lb.create<miopen::WorkitemIdOp>(loc, b.getIndexType());
auto waveId = lb.create<SignedDivIOp>(loc, tid, waveSizeConstantOp);

// compute thread_mtx_on_blk_row and thread_mtx_on_blk_col.
auto xdlops_i = lb.create<SignedDivIOp>(loc, iv, NumBlksPerXdlopsConstantOp);
auto j = lb.create<SignedRemIOp>(loc, iv, NumBlksPerXdlopsConstantOp);
Expand Down Expand Up @@ -2245,6 +2278,9 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
// compute c_thread_mtx_index_row, c_thread_mtx_index_col.
// compute c_thread_mtx_index_row_i32, c_thread_mtx_index_col_i32.

// compute waveId.
auto waveId = lb.create<SignedDivIOp>(loc, tid, waveSizeConstantOp);

// Original C++ logic.
// const index_t col = (waveId % GemmNWaves) * GemmNPerWave + thread_mtx_on_blk.col;
c_thread_mtx_index_col = lb.create<AddIOp>(loc,
Expand Down Expand Up @@ -2275,8 +2311,9 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
loc, n_block_data_on_global_i32, c_thread_mtx_index_col_i32);

SmallVector<Value, 8> matrixCThreadwiseCopySourceAndDestCoords;
auto coord0 = lb.create<MulIOp>(loc, iv_i32, M2ConstantI32Op);
matrixCThreadwiseCopySourceAndDestCoords.push_back(coord0);
matrixCThreadwiseCopySourceAndDestCoords.push_back(zeroConstantI32Op);
matrixCThreadwiseCopySourceAndDestCoords.push_back(iv_i32);
matrixCThreadwiseCopySourceAndDestCoords.push_back(zeroConstantI32Op);
matrixCThreadwiseCopySourceAndDestCoords.push_back(zeroConstantI32Op);

Expand All @@ -2293,6 +2330,14 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
loc, matrixCTransformOp, newOutputTransformOp,
matrixCThreadwiseCopySourceAndDestCoords);
affixThreadwiseCopyAttributes(threadwiseCopyCMatrixOp, op, lb);

// affix bound attributes.
threadwiseCopyCMatrixOp.setAttr("bound", b.getArrayAttr({
b.getI32IntegerAttr(M3),
b.getI32IntegerAttr(1),
b.getI32IntegerAttr(M2),
b.getI32IntegerAttr(1),
}));
} else {
// non-XDLOPS path.

Expand Down Expand Up @@ -2904,6 +2949,8 @@ struct ThreadwiseCopyRewritePattern

auto zeroConstantFloatOp =
b.create<ConstantFloatOp>(loc, APFloat(0.0f), b.getF32Type());
auto oneConstantFloatOp =
b.create<ConstantFloatOp>(loc, APFloat(1.0f), b.getF32Type());
auto zeroConstantOp = b.create<ConstantIndexOp>(loc, 0);
auto oneConstantOp = b.create<ConstantIndexOp>(loc, 1);

Expand Down Expand Up @@ -3161,11 +3208,20 @@ struct ThreadwiseCopyRewritePattern
auto operandIndex =
dictAttr.get("operand").template cast<IntegerAttr>().getInt();
if (operandIndex == 0) {
auto domainAttr =
dictAttr.get("domain").template cast<ArrayAttr>();
for (unsigned i = 0; i < domainAttr.size(); ++i)
sliceLengths.push_back(
domainAttr[i].template cast<IntegerAttr>().getInt());
// The bound attribute takes precedence over the domain attribute.
if (op.getAttr("bound")) {
auto boundAttr =
op.getAttr("bound").template cast<ArrayAttr>();
for (unsigned i = 0; i < boundAttr.size(); ++i)
sliceLengths.push_back(
boundAttr[i].template cast<IntegerAttr>().getInt());
} else {
auto domainAttr =
dictAttr.get("domain").template cast<ArrayAttr>();
for (unsigned i = 0; i < domainAttr.size(); ++i)
sliceLengths.push_back(
domainAttr[i].template cast<IntegerAttr>().getInt());
}
}
}
} else {
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/Dialect/MIOpen/Transforms/AffixTuningParameters.cpp
Expand Up @@ -146,6 +146,10 @@ void AffixTuningParameters::runOnFunction() {
validParams.gemmMPerThread = 128;
validParams.gemmNPerThread = 64;
validParams.blockSize = 256;

// XXX. fix gridSize.
// need to use (M/MPerBlock)*(N/NPerBlock).
gridSize = 784;
}

if (launchDimCallback) {
Expand Down
2 changes: 1 addition & 1 deletion mlir/tools/mlir-miopen-driver/mlir-miopen-driver.cpp
Expand Up @@ -329,7 +329,7 @@ static LogicalResult populateHostHarnessLogic(ModuleOp &module, OpBuilder &build
ValueRange{inputMemRefCastOp, oneConstantFloatOp});
auto outputCpuMemsetOp = builder.create<CallOp>(
builder.getUnknownLoc(), mcpuMemset4DFloatFuncOp,
ValueRange{outputMemRefCastOp, oneConstantFloatOp});
ValueRange{outputMemRefCastOp, zeroConstantFloatOp});
block->push_back(filterCpuMemsetOp);
block->push_back(inputCpuMemsetOp);
block->push_back(outputCpuMemsetOp);
Expand Down

0 comments on commit d359032

Please sign in to comment.