Skip to content

Commit

Permalink
Revise -miopen-lowering-step2 wrt gridwise_gemm lowering.
Browse files Browse the repository at this point in the history
LDS buffers are now passed to blockwise_gemm and blockwise_copy in 2D form.
  • Loading branch information
whchung committed Jun 6, 2020
1 parent 0dcee28 commit 562edbc
Showing 1 changed file with 56 additions and 49 deletions.
105 changes: 56 additions & 49 deletions mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h
Expand Up @@ -991,10 +991,18 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
Type outputElementType) const {
auto inputAffineMaps = inputType.getAffineMaps();

auto expr = getAffineDimExpr(0, op.getContext()) +
getAffineConstantExpr(offset, op.getContext());
AffineMap transformAffineMap =
AffineMap::get(1, 0, ArrayRef<AffineExpr>{expr}, op.getContext());
auto outputRank = outputShape.size();

auto expr = getAffineConstantExpr(offset, op.getContext());
unsigned stride = 1;
for (int i = outputRank - 1; i >= 0; --i) {
expr = expr + getAffineDimExpr(i, op.getContext()) *
getAffineConstantExpr(stride, op.getContext());
stride *= outputShape[i];
}

AffineMap transformAffineMap = AffineMap::get(
outputRank, 0, ArrayRef<AffineExpr>{expr}, op.getContext());
AffineMap outputAffineMap;
if (inputAffineMaps.size() != 0) {
auto inputAffineMap = inputAffineMaps[0];
Expand Down Expand Up @@ -1080,22 +1088,18 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
// // be careful of LDS alignment
// constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned(
// Sequence<KPerBlock, MPerBlock>{}, Number<max_lds_align>{});
auto lds2DMatrixAHeight = KPerBlock;
auto lds2DMatrixAWidth = MPerBlock;
auto lds2DMatrixAEvenMemRefType = computeSubviewResultType(
op, ldsBlockAEvenMemRefType, 0, {KPerBlock, MPerBlock}, elementType);

//auto lds2DMatrixAMemRefType =
// MemRefType::get({lds2DMatrixAHeight, lds2DMatrixAWidth}, b.getF32Type(), {}, ldsMemorySpace);
auto lds2DMatrixAOddMemRefType = computeSubviewResultType(
op, ldsBlockAOddMemRefType, 0, {KPerBlock, MPerBlock}, elementType);

//llvm::SmallVector<int64_t, 2> lds2DMatrixADim {lds2DMatrixAHeight, lds2DMatrixAWidth};
//llvm::SmallVector<NamedAttribute, 8> lds2DMatrixADimAttr {
// b.getNamedAttr("dimensions", b.getI64ArrayAttr(lds2DMatrixADim)),
//};

//auto lds2DMatrixAEvenSubviewOp = b.create<miopen::SubviewOp>(op.getLoc(), lds2DMatrixAMemRefType, ldsBlockAEvenSubviewOp, zeroConstantIndexOp);
//lds2DMatrixAEvenSubviewOp.setAttrs(lds2DMatrixADimAttr);
//auto lds2DMatrixAOddSubviewOp = b.create<miopen::SubviewOp>(op.getLoc(), lds2DMatrixAMemRefType, ldsBlockAOddSubviewOp, zeroConstantIndexOp);
//lds2DMatrixAOddSubviewOp.setAttrs(lds2DMatrixADimAttr);

auto lds2DMatrixAEvenSubviewOp = b.create<miopen::SubviewOp>(
op.getLoc(), lds2DMatrixAEvenMemRefType, ldsBlockAEvenSubviewOp,
zeroConstantIndexOp);
auto lds2DMatrixAOddSubviewOp =
b.create<miopen::SubviewOp>(op.getLoc(), lds2DMatrixAOddMemRefType,
ldsBlockAOddSubviewOp, zeroConstantIndexOp);

// Subviews for Matrix B.
auto ldsBlockBDoubleSize = ldsBlockBSize * 2;
Expand Down Expand Up @@ -1132,22 +1136,18 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
// // be careful of LDS alignment
// constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned(
// Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});
auto lds2DMatrixBHeight = KPerBlock;
auto lds2DMatrixBWidth = NPerBlock;
auto lds2DMatrixBEvenMemRefType = computeSubviewResultType(
op, ldsBlockBEvenMemRefType, 0, {KPerBlock, NPerBlock}, elementType);

//auto lds2DMatrixBMemRefType =
// MemRefType::get({lds2DMatrixBHeight, lds2DMatrixBWidth}, b.getF32Type(), {}, ldsMemorySpace);

//llvm::SmallVector<int64_t, 2> lds2DMatrixBDim {lds2DMatrixBHeight, lds2DMatrixBWidth};
//llvm::SmallVector<NamedAttribute, 1> lds2DMatrixBDimAttr {
// b.getNamedAttr("dimensions", b.getI64ArrayAttr(lds2DMatrixBDim)),
//};

//auto lds2DMatrixBEvenSubviewOp = b.create<miopen::SubviewOp>(op.getLoc(), lds2DMatrixBMemRefType, ldsBlockBEvenSubviewOp, zeroConstantIndexOp);
//lds2DMatrixBEvenSubviewOp.setAttrs(lds2DMatrixBDimAttr);
//auto lds2DMatrixBOddSubviewOp = b.create<miopen::SubviewOp>(op.getLoc(), lds2DMatrixBMemRefType, ldsBlockBOddSubviewOp, zeroConstantIndexOp);
//lds2DMatrixBOddSubviewOp.setAttrs(lds2DMatrixBDimAttr);
auto lds2DMatrixBOddMemRefType = computeSubviewResultType(
op, ldsBlockBOddMemRefType, 0, {KPerBlock, NPerBlock}, elementType);

auto lds2DMatrixBEvenSubviewOp = b.create<miopen::SubviewOp>(
op.getLoc(), lds2DMatrixBEvenMemRefType, ldsBlockBEvenSubviewOp,
zeroConstantIndexOp);
auto lds2DMatrixBOddSubviewOp =
b.create<miopen::SubviewOp>(op.getLoc(), lds2DMatrixBOddMemRefType,
ldsBlockBOddSubviewOp, zeroConstantIndexOp);

// Alloc for Matrix C on registers.
// Compute register size from attributes.
Expand Down Expand Up @@ -1267,8 +1267,10 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
// AddressSpace::Lds,
// InMemoryDataOperation::Set>(
// {0, n_block_data_on_global}, {0, 0});
b.create<miopen::BlockwiseCopyOp>(op.getLoc(), op.getOperand(0), ldsBlockAEvenSubviewOp);
b.create<miopen::BlockwiseCopyOp>(op.getLoc(), op.getOperand(1), ldsBlockBEvenSubviewOp);
b.create<miopen::BlockwiseCopyOp>(op.getLoc(), op.getOperand(0),
lds2DMatrixAEvenSubviewOp);
b.create<miopen::BlockwiseCopyOp>(op.getLoc(), op.getOperand(1),
lds2DMatrixBEvenSubviewOp);

// Emit loop.
// Compute loop iterations from attributes.
Expand All @@ -1291,13 +1293,16 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
blockwiseCopyOpBEven.setAttr("move_source_slice_window", b.getI32IntegerAttr(KPerBlock));

// Emit blockwise GEMM.
//auto blockwiseGemmEvenOp = lb.create<miopen::BlockwiseGemmOp>(op.getLoc(), lds2DMatrixAEvenSubviewOp, lds2DMatrixBEvenSubviewOp, register2DMatrixCSubviewOp);
auto blockwiseGemmEvenOp = lb.create<miopen::BlockwiseGemmOp>(op.getLoc(), ldsBlockAEvenSubviewOp, ldsBlockBEvenSubviewOp, threadCAllocOp);
auto blockwiseGemmEvenOp = lb.create<miopen::BlockwiseGemmOp>(
op.getLoc(), lds2DMatrixAEvenSubviewOp, lds2DMatrixBEvenSubviewOp,
threadCAllocOp);
affixBlockwiseGemmAttributes(blockwiseGemmEvenOp, op);

// Blockwise copy from register (native tensor) to LDS (naive tensor).
lb.create<miopen::BlockwiseCopyOp>(op.getLoc(), threadAEvenAllocOp, ldsBlockAOddSubviewOp);
lb.create<miopen::BlockwiseCopyOp>(op.getLoc(), threadBEvenAllocOp, ldsBlockBOddSubviewOp);
lb.create<miopen::BlockwiseCopyOp>(op.getLoc(), threadAEvenAllocOp,
lds2DMatrixAOddSubviewOp);
lb.create<miopen::BlockwiseCopyOp>(op.getLoc(), threadBEvenAllocOp,
lds2DMatrixBOddSubviewOp);

// LDS barrier.
lb.create<miopen::LdsBarrierOp>(op.getLoc());
Expand All @@ -1311,13 +1316,16 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
blockwiseCopyOpBOdd.setAttr("move_source_slice_window", b.getI32IntegerAttr(KPerBlock));

// Emit blockwise GEMM.
//auto blockwiseGemmOddOp = lb.create<miopen::BlockwiseGemmOp>(op.getLoc(), lds2DMatrixAOddSubviewOp, lds2DMatrixBOddSubviewOp, register2DMatrixCSubviewOp);
auto blockwiseGemmOddOp = lb.create<miopen::BlockwiseGemmOp>(op.getLoc(), ldsBlockAOddSubviewOp, ldsBlockBOddSubviewOp, threadCAllocOp);
auto blockwiseGemmOddOp = lb.create<miopen::BlockwiseGemmOp>(
op.getLoc(), lds2DMatrixAOddSubviewOp, lds2DMatrixBOddSubviewOp,
threadCAllocOp);
affixBlockwiseGemmAttributes(blockwiseGemmOddOp, op);

// Blockwise copy from register (native tensor) to LDS (naive tensor).
lb.create<miopen::BlockwiseCopyOp>(op.getLoc(), threadAOddAllocOp, ldsBlockAEvenSubviewOp);
lb.create<miopen::BlockwiseCopyOp>(op.getLoc(), threadBOddAllocOp, ldsBlockBEvenSubviewOp);
lb.create<miopen::BlockwiseCopyOp>(op.getLoc(), threadAOddAllocOp,
lds2DMatrixAEvenSubviewOp);
lb.create<miopen::BlockwiseCopyOp>(op.getLoc(), threadBOddAllocOp,
lds2DMatrixBEvenSubviewOp);

// outside the loop.

Expand All @@ -1327,12 +1335,14 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm

// Emit blockwise GEMM for the loop tail.
if (loopIteration % 2) {
//auto blockwiseGemmTailEvenOp = b.create<miopen::BlockwiseGemmOp>(op.getLoc(), lds2DMatrixAEvenSubviewOp, lds2DMatrixBEvenSubviewOp, register2DMatrixCSubviewOp);
auto blockwiseGemmTailEvenOp = b.create<miopen::BlockwiseGemmOp>(op.getLoc(), ldsBlockAEvenSubviewOp, ldsBlockBEvenSubviewOp, threadCAllocOp);
auto blockwiseGemmTailEvenOp = b.create<miopen::BlockwiseGemmOp>(
op.getLoc(), lds2DMatrixAEvenSubviewOp, lds2DMatrixBEvenSubviewOp,
threadCAllocOp);
affixBlockwiseGemmAttributes(blockwiseGemmTailEvenOp, op);
} else {
//auto blockwiseGemmTailOddOp = b.create<miopen::BlockwiseGemmOp>(op.getLoc(), lds2DMatrixAOddSubviewOp, lds2DMatrixBOddSubviewOp, register2DMatrixCSubviewOp);
auto blockwiseGemmTailOddOp = b.create<miopen::BlockwiseGemmOp>(op.getLoc(), ldsBlockAOddSubviewOp, ldsBlockBOddSubviewOp, threadCAllocOp);
auto blockwiseGemmTailOddOp = b.create<miopen::BlockwiseGemmOp>(
op.getLoc(), lds2DMatrixAOddSubviewOp, lds2DMatrixBOddSubviewOp,
threadCAllocOp);
affixBlockwiseGemmAttributes(blockwiseGemmTailOddOp, op);
}

Expand Down Expand Up @@ -1671,8 +1681,6 @@ struct SubviewRewritePattern : public OpRewritePattern<miopen::SubviewOp> {

LogicalResult matchAndRewrite(miopen::SubviewOp op,
PatternRewriter &b) const override {
auto loc = op.getLoc();
auto inputType = op.input().getType().cast<MemRefType>();
auto outputType = op.output().getType().cast<MemRefType>();

// Pass the output affine map to users of this op.
Expand Down Expand Up @@ -1700,7 +1708,6 @@ struct TransformRewritePattern : public OpRewritePattern<miopen::TransformOp> {

LogicalResult matchAndRewrite(miopen::TransformOp op,
PatternRewriter &b) const override {
auto loc = op.getLoc();
auto outputType = op.output().getType().cast<MemRefType>();

// Pass the output affine map to users of this op.
Expand Down

0 comments on commit 562edbc

Please sign in to comment.