Skip to content

Commit

Permalink
Compute threadwise_copy coordinates used in gridwise_gemm.
Browse files Browse the repository at this point in the history
  • Loading branch information
whchung committed Jun 6, 2020
1 parent 292108f commit d8c09f8
Showing 1 changed file with 20 additions and 11 deletions.
31 changes: 20 additions & 11 deletions mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h
Expand Up @@ -1569,12 +1569,22 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
b.create<MulIOp>(op.getLoc(), level1_m_id,
MPerLevel0ClusterConstantIndexOp),
b.create<MulIOp>(op.getLoc(), level0_m_id, MPerThreadConstantIndexOp));
auto c_thread_mtx_index_row_i32 = b.create<IndexCastOp>(
op.getLoc(), c_thread_mtx_index_row, b.getIntegerType(32));

// mMyThreadOffsetB = BlockMatrixB::GetOffsetFromMultiIndex{0, c_thread_mtx_index.col} = c_thread_mtx_index_col
auto c_thread_mtx_index_col = b.create<AddIOp>(
op.getLoc(),
b.create<MulIOp>(op.getLoc(), level1_n_id,
NPerLevel0ClusterConstantIndexOp),
b.create<MulIOp>(op.getLoc(), level0_n_id, NPerThreadConstantIndexOp));
auto c_thread_mtx_index_col_i32 = b.create<IndexCastOp>(
op.getLoc(), c_thread_mtx_index_col, b.getIntegerType(32));

auto m_thread_data_on_global_i32 = b.create<AddIOp>(
op.getLoc(), m_block_data_on_global_i32, c_thread_mtx_index_row_i32);
auto n_thread_data_on_global_i32 = b.create<AddIOp>(
op.getLoc(), n_block_data_on_global_i32, c_thread_mtx_index_col_i32);

// Emit BlockwiseCopy ops.
auto blockwiseCopyA = b.create<miopen::BlockwiseCopyOp>(
Expand Down Expand Up @@ -1701,7 +1711,7 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
}

// Threadwise copy from register (naive tensor) to global (generic tensor).
// TBD add attributes from C++ template arguments and ctor arguments.
// Add attributes from C++ template arguments and ctor arguments.
// ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_m0_m1_n0_n1_thread_desc),
// decltype(c_m0_m1_n0_n1_global_desc),
// decltype(c_m0_m1_n0_n1_thread_desc.GetLengths()),
Expand All @@ -1718,17 +1728,16 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
// n_thread_data_on_global / N1,
// n_thread_data_on_global % N1})
// .Run(p_c_thread, p_c_global);
// TBD use all 0 coordinates now. need to revisit this following original
// C++ implementation.
SmallVector<Value, 6> matrixCThreadwiseCopySourceAndDestCoords;
for (unsigned i = 0; i < threadCRegisterMemRefType.getRank(); ++i)
matrixCThreadwiseCopySourceAndDestCoords.push_back(zeroConstantI32Op);
for (unsigned i = 0;
i < op.getOperand(2).getType().cast<MemRefType>().getRank();
++i)
matrixCThreadwiseCopySourceAndDestCoords.push_back(zeroConstantI32Op);
// XXX. Use 2D coordinate only.
SmallVector<Value, 4> matrixCThreadwiseCopySourceAndDestCoords;
matrixCThreadwiseCopySourceAndDestCoords.push_back(zeroConstantI32Op);
matrixCThreadwiseCopySourceAndDestCoords.push_back(zeroConstantI32Op);
matrixCThreadwiseCopySourceAndDestCoords.push_back(
m_thread_data_on_global_i32);
matrixCThreadwiseCopySourceAndDestCoords.push_back(
n_thread_data_on_global_i32);
auto threadwiseCopyCMatrixOp = b.create<miopen::ThreadwiseCopyOp>(
op.getLoc(), register2DMatrixCAllocOp, op.getOperand(2),
op.getLoc(), register2DMatrixCAllocOp, op.output(),
matrixCThreadwiseCopySourceAndDestCoords);
affixThreadwiseCopyAttributes(threadwiseCopyCMatrixOp, op, b);

Expand Down

0 comments on commit d8c09f8

Please sign in to comment.