Skip to content

Commit

Permalink
Revise stack allocations for src / dest coordinates for blockwise_cop…
Browse files Browse the repository at this point in the history
…y for Matrix A and B.
  • Loading branch information
whchung committed Jun 6, 2020
1 parent 228a600 commit cea2513
Showing 1 changed file with 22 additions and 15 deletions.
37 changes: 22 additions & 15 deletions mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h
Expand Up @@ -1251,24 +1251,31 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
// AddressSpace::Lds,
// InMemoryDataOperation::Set>(
// {0, n_block_data_on_global}, {0, 0});
auto blockwiseCopyType =
auto blockwiseCopyCoordType =
MemRefType::get({2}, b.getIntegerType(32), {}, registerMemorySpace);

auto blockwiseCopyASrc = b.create<miopen::GpuAllocOp>(op.getLoc(), blockwiseCopyType);
auto blockwiseCopyASrc = b.create<miopen::GpuAllocOp>(op.getLoc(), blockwiseCopyCoordType);
// TBD compute m_block_data_on_global. Use (0, 0) for now.
// TBD add thread_data_id_begin.
b.create<miopen::FillOp>(op.getLoc(), blockwiseCopyASrc, zeroConstantI32Op);

auto blockwiseCopyBSrc = b.create<miopen::GpuAllocOp>(op.getLoc(), blockwiseCopyType);
auto blockwiseCopyADst = b.create<miopen::GpuAllocOp>(op.getLoc(), blockwiseCopyCoordType);
// TBD add thread_data_id_begin.
b.create<miopen::FillOp>(op.getLoc(), blockwiseCopyADst, zeroConstantI32Op);

auto blockwiseCopyBSrc = b.create<miopen::GpuAllocOp>(op.getLoc(), blockwiseCopyCoordType);
// TBD compute n_block_data_on_global. Use (0, 0) for now.
// TBD add thread_data_id_begin.
b.create<miopen::FillOp>(op.getLoc(), blockwiseCopyBSrc, zeroConstantI32Op);

auto blockwiseCopyZero = b.create<miopen::GpuAllocOp>(op.getLoc(), blockwiseCopyType);
b.create<miopen::FillOp>(op.getLoc(), blockwiseCopyZero, zeroConstantI32Op);
auto blockwiseCopyBDst = b.create<miopen::GpuAllocOp>(op.getLoc(), blockwiseCopyCoordType);
// TBD add thread_data_id_begin.
b.create<miopen::FillOp>(op.getLoc(), blockwiseCopyBDst, zeroConstantI32Op);

b.create<miopen::BlockwiseCopyOp>(op.getLoc(), op.getOperand(0),
lds2DMatrixAEvenSubviewOp, blockwiseCopyASrc, blockwiseCopyZero);
lds2DMatrixAEvenSubviewOp, blockwiseCopyASrc, blockwiseCopyADst);
b.create<miopen::BlockwiseCopyOp>(op.getLoc(), op.getOperand(1),
lds2DMatrixBEvenSubviewOp, blockwiseCopyBSrc, blockwiseCopyZero);
lds2DMatrixBEvenSubviewOp, blockwiseCopyBSrc, blockwiseCopyBDst);

// Emit loop.
// Compute loop iterations from attributes.
Expand All @@ -1291,13 +1298,13 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
ValueRange{KPerBlockConstantI32Op, zeroConstantI32Op});
auto blockwiseCopyOpAEven = lb.create<miopen::BlockwiseCopyOp>(
op.getLoc(), op.getOperand(0), threadAEvenAllocOp, blockwiseCopyASrc,
blockwiseCopyZero);
blockwiseCopyADst);
lb.create<miopen::MovePosOp>(
op.getLoc(), blockwiseCopyBSrc,
ValueRange{KPerBlockConstantI32Op, zeroConstantI32Op});
auto blockwiseCopyOpBEven = lb.create<miopen::BlockwiseCopyOp>(
op.getLoc(), op.getOperand(1), threadBEvenAllocOp, blockwiseCopyBSrc,
blockwiseCopyZero);
blockwiseCopyBDst);

// Emit blockwise GEMM.
auto blockwiseGemmEvenOp = lb.create<miopen::BlockwiseGemmOp>(
Expand All @@ -1307,9 +1314,9 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm

// Blockwise copy from register (naive tensor) to LDS (naive tensor).
lb.create<miopen::BlockwiseCopyOp>(op.getLoc(), threadAEvenAllocOp,
lds2DMatrixAOddSubviewOp, blockwiseCopyASrc, blockwiseCopyZero);
lds2DMatrixAOddSubviewOp, blockwiseCopyASrc, blockwiseCopyADst);
lb.create<miopen::BlockwiseCopyOp>(op.getLoc(), threadBEvenAllocOp,
lds2DMatrixBOddSubviewOp, blockwiseCopyBSrc, blockwiseCopyZero);
lds2DMatrixBOddSubviewOp, blockwiseCopyBSrc, blockwiseCopyBDst);

// LDS barrier.
lb.create<miopen::LdsBarrierOp>(op.getLoc());
Expand All @@ -1320,13 +1327,13 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
ValueRange{KPerBlockConstantI32Op, zeroConstantI32Op});
auto blockwiseCopyOpAOdd = lb.create<miopen::BlockwiseCopyOp>(
op.getLoc(), op.getOperand(0), threadAOddAllocOp, blockwiseCopyASrc,
blockwiseCopyZero);
blockwiseCopyADst);
lb.create<miopen::MovePosOp>(
op.getLoc(), blockwiseCopyBSrc,
ValueRange{KPerBlockConstantI32Op, zeroConstantI32Op});
auto blockwiseCopyOpBOdd = lb.create<miopen::BlockwiseCopyOp>(
op.getLoc(), op.getOperand(1), threadBOddAllocOp, blockwiseCopyBSrc,
blockwiseCopyZero);
blockwiseCopyBDst);

// Emit blockwise GEMM.
auto blockwiseGemmOddOp = lb.create<miopen::BlockwiseGemmOp>(
Expand All @@ -1336,9 +1343,9 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm

// Blockwise copy from register (naive tensor) to LDS (naive tensor).
lb.create<miopen::BlockwiseCopyOp>(op.getLoc(), threadAOddAllocOp,
lds2DMatrixAEvenSubviewOp, blockwiseCopyZero, blockwiseCopyZero);
lds2DMatrixAEvenSubviewOp, blockwiseCopyASrc, blockwiseCopyADst);
lb.create<miopen::BlockwiseCopyOp>(op.getLoc(), threadBOddAllocOp,
lds2DMatrixBEvenSubviewOp, blockwiseCopyZero, blockwiseCopyZero);
lds2DMatrixBEvenSubviewOp, blockwiseCopyBSrc, blockwiseCopyBDst);

// outside the loop.

Expand Down

0 comments on commit cea2513

Please sign in to comment.