Skip to content

Commit

Permalink
Improving attribute affixing logic.
Browse files Browse the repository at this point in the history
- gridwise_gemm -> blockwise_copy
- gridwise_gemm -> blockwise_gemm
- gridwise_gemm -> threadwise_copy
- blockwise_copy -> threadwise_copy
- blockwise_gemm -> threadwise_copy
  • Loading branch information
whchung committed Jun 6, 2020
1 parent 1879895 commit a89becd
Showing 1 changed file with 105 additions and 45 deletions.
150 changes: 105 additions & 45 deletions mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h
Expand Up @@ -870,12 +870,23 @@ static void affixThreadwiseCopyAttributes(miopen::ThreadwiseCopyOp top, miopen::
// AddressSpace::Global, - addrspace on dest memref
// CGlobalMemoryDataOperation>( - NOT USED

top.setAttr("source_vector_read_dim", gop.getAttr("matrix_c_source_dest_vector_read_write_dim"));
top.setAttr("dim_access_order", b.getArrayAttr({
b.getI32IntegerAttr(0),
b.getI32IntegerAttr(1),
b.getI32IntegerAttr(2),
b.getI32IntegerAttr(3),
}));
top.setAttr("vector_read_write_dim",
gop.getAttr("matrix_c_source_dest_vector_read_write_dim"));
top.setAttr("source_data_per_read", b.getI32IntegerAttr(1));
top.setAttr("dest_data_per_write", gop.getAttr("matrix_c_dest_data_per_write"));
}

static void affixThreadwiseCopyAttributes(miopen::ThreadwiseCopyOp top, miopen::BlockwiseCopyOp bop, PatternRewriter &b) {
// XXX: Figure out a way to do away with isThreadwiseLoad parameter.
static void affixThreadwiseCopyAttributes(miopen::ThreadwiseCopyOp top,
miopen::BlockwiseCopyOp bop,
PatternRewriter &b,
bool isThreadwiseLoad) {
// Add attributes from C++ template arguments and ctor arguments.
//
// in blockwise_copy:
Expand All @@ -902,17 +913,23 @@ static void affixThreadwiseCopyAttributes(miopen::ThreadwiseCopyOp top, miopen::
// DstAddressSpace, - addrspace of dest memref
// DstInMemOp>; - NOT USE

// TBD need to figure out should we use attributes for Matrix A or Matrix B.
top.setAttr("source_vector_read_dim", bop.getAttr("matrix_a_source_vector_read_dim"));
top.setAttr("source_data_per_read", bop.getAttr("matrix_a_source_data_per_read"));
top.setAttr("dest_data_per_write", b.getI32IntegerAttr(1));

top.setAttr("source_vector_read_dim", bop.getAttr("matrix_b_source_vector_read_dim"));
top.setAttr("source_data_per_read", b.getI32IntegerAttr(1));
top.setAttr("dest_data_per_write", bop.getAttr("matrix_b_dest_data_per_write_dim_n"));
if (isThreadwiseLoad) {
top.setAttr("dim_access_order", bop.getAttr("source_dim_access_order"));
top.setAttr("vector_read_write_dim", bop.getAttr("source_vector_read_dim"));
top.setAttr("source_data_per_read", bop.getAttr("source_data_per_read"));
top.setAttr("dest_data_per_write", b.getI32IntegerAttr(1));
} else {
top.setAttr("dim_access_order", bop.getAttr("dest_dim_access_order"));
top.setAttr("vector_read_write_dim", bop.getAttr("dest_vector_write_dim"));
top.setAttr("source_data_per_read", b.getI32IntegerAttr(1));
top.setAttr("dest_data_per_write", bop.getAttr("dest_data_per_write"));
}
}

static void affixThreadwiseCopyAttributes(miopen::ThreadwiseCopyOp top, miopen::BlockwiseGemmOp bop, PatternRewriter &b) {
// XXX: figure out a better way to get rid of isMatrixA parameter.
static void affixThreadwiseCopyAttributes(miopen::ThreadwiseCopyOp top,
miopen::BlockwiseGemmOp bop,
PatternRewriter &b, bool isMatrixA) {
// in blockwise_gemm:
//
// constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixA, - source memref
Expand All @@ -926,9 +943,15 @@ static void affixThreadwiseCopyAttributes(miopen::ThreadwiseCopyOp top, miopen::
// NPerThreadSubC, - n_per_thread attribute
// ThreadGemmBDataPerRead_N>{}; - n_per_thread attribute

top.setAttr("k_per_thread", bop.getAttr("k_per_thread"));
top.setAttr("m_per_thread", bop.getAttr("m_per_thread"));
top.setAttr("n_per_thread", bop.getAttr("n_per_thread"));
if (isMatrixA) {
top.setAttr("n_slice_row", bop.getAttr("k_per_thread"));
top.setAttr("n_slice_col", bop.getAttr("m_per_thread"));
top.setAttr("data_per_access", bop.getAttr("m_per_thread"));
} else {
top.setAttr("n_slice_row", bop.getAttr("k_per_thread"));
top.setAttr("n_slice_col", bop.getAttr("n_per_thread"));
top.setAttr("data_per_access", bop.getAttr("n_per_thread"));
}
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1036,8 +1059,10 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
double_block_space = 2 * (a_block_space + b_block_space);
}

// XXX. Figure out a way to do away with isMatrixA parameter.
void affixBlockwiseCopyAttributes(miopen::BlockwiseCopyOp bop,
miopen::GridwiseGemmOp gop) const {
miopen::GridwiseGemmOp gop,
PatternRewriter &b, bool isMatrixA) const {
// Add attributes from C++ template arguments and ctor arguments.
// a_blockwise_copy:
// BlockSize - block_size attribute
Expand Down Expand Up @@ -1070,19 +1095,41 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
// BBlockCopyDstDataPerWrite_N - matrix_b_dest_data_per_write_dim_n attribute
bop.setAttr("block_size", gop.getAttr("block_size"));

bop.setAttr("matrix_a_source_vector_read_dim",
gop.getAttr("matrix_a_source_vector_read_dim"));
bop.setAttr("matrix_a_source_data_per_read",
gop.getAttr("matrix_a_source_data_per_read"));
bop.setAttr("matrix_a_dest_data_per_write_dim_m",
gop.getAttr("matrix_a_dest_data_per_write_dim_m"));

bop.setAttr("matrix_b_source_vector_read_dim",
gop.getAttr("matrix_b_source_vector_read_dim"));
bop.setAttr("matrix_b_source_data_per_read",
gop.getAttr("matrix_b_source_data_per_read"));
bop.setAttr("matrix_b_dest_data_per_write_dim_n",
gop.getAttr("matrix_b_dest_data_per_write_dim_n"));
if (isMatrixA) {
bop.setAttr("source_dim_access_order", b.getArrayAttr({
b.getI32IntegerAttr(1),
b.getI32IntegerAttr(0),
}));
bop.setAttr("dest_dim_access_order", b.getArrayAttr({
b.getI32IntegerAttr(0),
b.getI32IntegerAttr(1),
}));
bop.setAttr("source_vector_read_dim",
gop.getAttr("matrix_a_source_vector_read_dim"));
bop.setAttr("dest_vector_write_dim", b.getI32IntegerAttr(1));

bop.setAttr("source_data_per_read",
gop.getAttr("matrix_a_source_data_per_read"));
bop.setAttr("dest_data_per_write",
gop.getAttr("matrix_a_dest_data_per_write_dim_m"));
} else {
bop.setAttr("source_dim_access_order", b.getArrayAttr({
b.getI32IntegerAttr(0),
b.getI32IntegerAttr(1),
}));
bop.setAttr("dest_dim_access_order", b.getArrayAttr({
b.getI32IntegerAttr(0),
b.getI32IntegerAttr(1),
}));
bop.setAttr("source_vector_read_dim",
gop.getAttr("matrix_b_source_vector_read_dim"));
bop.setAttr("dest_vector_write_dim", b.getI32IntegerAttr(1));

bop.setAttr("source_data_per_read",
gop.getAttr("matrix_b_source_data_per_read"));
bop.setAttr("dest_data_per_write",
gop.getAttr("matrix_b_dest_data_per_write_dim_n"));
}
}

void affixBlockwiseGemmAttributes(miopen::BlockwiseGemmOp bop, miopen::GridwiseGemmOp gop) const {
Expand Down Expand Up @@ -1486,11 +1533,11 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
auto blockwiseCopyA = b.create<miopen::BlockwiseCopyOp>(
op.getLoc(), op.getOperand(0), lds2DMatrixAEvenSubviewOp,
blockwiseCopyASrc, blockwiseCopyADst, threadAOddAllocOp);
affixBlockwiseCopyAttributes(blockwiseCopyA, op);
affixBlockwiseCopyAttributes(blockwiseCopyA, op, b, /*isMatrixA=*/true);
auto blockwiseCopyB = b.create<miopen::BlockwiseCopyOp>(
op.getLoc(), op.getOperand(1), lds2DMatrixBEvenSubviewOp,
blockwiseCopyBSrc, blockwiseCopyBDst, threadBOddAllocOp);
affixBlockwiseCopyAttributes(blockwiseCopyB, op);
affixBlockwiseCopyAttributes(blockwiseCopyB, op, b, /*isMatrixA=*/false);

// Emit loop.
// Compute loop iterations from attributes.
Expand All @@ -1514,14 +1561,16 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
auto blockwiseCopyOpAEven = lb.create<miopen::BlockwiseCopyOp>(
op.getLoc(), op.getOperand(0), threadAEvenAllocOp, blockwiseCopyASrc,
blockwiseCopyADst, /*buffer=*/nullptr);
affixBlockwiseCopyAttributes(blockwiseCopyOpAEven, op);
affixBlockwiseCopyAttributes(blockwiseCopyOpAEven, op, b,
/*isMatrixA=*/true);
lb.create<miopen::MovePosOp>(
op.getLoc(), blockwiseCopyBSrc,
ValueRange{KPerBlockConstantI32Op, zeroConstantI32Op});
auto blockwiseCopyOpBEven = lb.create<miopen::BlockwiseCopyOp>(
op.getLoc(), op.getOperand(1), threadBEvenAllocOp, blockwiseCopyBSrc,
blockwiseCopyBDst, /*buffer=*/nullptr);
affixBlockwiseCopyAttributes(blockwiseCopyOpBEven, op);
affixBlockwiseCopyAttributes(blockwiseCopyOpBEven, op, b,
/*isMatrixA=*/false);

// Emit blockwise GEMM.
auto blockwiseGemmEvenOp = lb.create<miopen::BlockwiseGemmOp>(
Expand All @@ -1533,11 +1582,13 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
auto blockwiseCopyOpAOdd = lb.create<miopen::BlockwiseCopyOp>(
op.getLoc(), threadAEvenAllocOp, lds2DMatrixAOddSubviewOp,
blockwiseCopyASrc, blockwiseCopyADst, /*buffer=*/nullptr);
affixBlockwiseCopyAttributes(blockwiseCopyOpAOdd, op);
affixBlockwiseCopyAttributes(blockwiseCopyOpAOdd, op, b,
/*isMatrixA=*/true);
auto blockwiseCopyOpBOdd = lb.create<miopen::BlockwiseCopyOp>(
op.getLoc(), threadBEvenAllocOp, lds2DMatrixBOddSubviewOp,
blockwiseCopyBSrc, blockwiseCopyBDst, /*buffer=*/nullptr);
affixBlockwiseCopyAttributes(blockwiseCopyOpBOdd, op);
affixBlockwiseCopyAttributes(blockwiseCopyOpBOdd, op, b,
/*isMatrixA=*/false);

// LDS barrier.
lb.create<miopen::WorkgroupBarrierOp>(op.getLoc());
Expand All @@ -1550,15 +1601,17 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
lb.create<miopen::BlockwiseCopyOp>(
op.getLoc(), op.getOperand(0), threadAOddAllocOp, blockwiseCopyASrc,
blockwiseCopyADst, /*buffer=*/nullptr);
affixBlockwiseCopyAttributes(blockwiseCopyOpAOddSecondIteration, op);
affixBlockwiseCopyAttributes(blockwiseCopyOpAOddSecondIteration, op, b,
/*isMatrixA=*/true);
lb.create<miopen::MovePosOp>(
op.getLoc(), blockwiseCopyBSrc,
ValueRange{KPerBlockConstantI32Op, zeroConstantI32Op});
auto blockwiseCopyOpBOddSecondIteration =
lb.create<miopen::BlockwiseCopyOp>(
op.getLoc(), op.getOperand(1), threadBOddAllocOp, blockwiseCopyBSrc,
blockwiseCopyBDst, /*buffer=*/nullptr);
affixBlockwiseCopyAttributes(blockwiseCopyOpBOddSecondIteration, op);
affixBlockwiseCopyAttributes(blockwiseCopyOpBOddSecondIteration, op, b,
/*isMatrixA=*/false);

// Emit blockwise GEMM.
auto blockwiseGemmOddOp = lb.create<miopen::BlockwiseGemmOp>(
Expand All @@ -1570,15 +1623,16 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
auto blockwiseCopyAEvenSecondIteration = lb.create<miopen::BlockwiseCopyOp>(
op.getLoc(), threadAOddAllocOp, lds2DMatrixAEvenSubviewOp,
blockwiseCopyASrc, blockwiseCopyADst, /*buffer=*/nullptr);
affixBlockwiseCopyAttributes(blockwiseCopyAEvenSecondIteration, op);
affixBlockwiseCopyAttributes(blockwiseCopyAEvenSecondIteration, op, b,
/*isMatrixA=*/true);
auto blockwiseCopyBEvenSecondIteration = lb.create<miopen::BlockwiseCopyOp>(
op.getLoc(), threadBOddAllocOp, lds2DMatrixBEvenSubviewOp,
blockwiseCopyBSrc, blockwiseCopyBDst, /*buffer=*/nullptr);
affixBlockwiseCopyAttributes(blockwiseCopyBEvenSecondIteration, op);
affixBlockwiseCopyAttributes(blockwiseCopyBEvenSecondIteration, op, b,
/*isMatrixA=*/false);

// outside the loop.


// LDS barrier.
b.create<miopen::WorkgroupBarrierOp>(op.getLoc());

Expand Down Expand Up @@ -1712,7 +1766,8 @@ struct BlockwiseGemmRewritePattern : public OpRewritePattern<miopen::BlockwiseGe
auto threadwiseCopyAMatrixOp = lab.create<miopen::ThreadwiseCopyOp>(
op.getLoc(), op.getOperand(0), threadAAllocOp,
matrixAThreadwiseCopySourceAndDestCoords);
affixThreadwiseCopyAttributes(threadwiseCopyAMatrixOp, op, b);
affixThreadwiseCopyAttributes(threadwiseCopyAMatrixOp, op, b,
/*isMatrixA=*/true);

// read matrix B loop.
auto loopReadMatrixBIteration = NRepeat;
Expand All @@ -1735,7 +1790,8 @@ struct BlockwiseGemmRewritePattern : public OpRewritePattern<miopen::BlockwiseGe
auto threadwiseCopyBMatrixOp = lbb.create<miopen::ThreadwiseCopyOp>(
op.getLoc(), op.getOperand(1), threadBAllocOp,
matrixBThreadwiseCopySourceAndDestCoords);
affixThreadwiseCopyAttributes(threadwiseCopyBMatrixOp, op, b);
affixThreadwiseCopyAttributes(threadwiseCopyBMatrixOp, op, b,
/*isMatrixA=*/false);

lb.create<miopen::ThreadwiseGemmOp>(op.getLoc(), threadAAllocOp,
threadBAllocOp, op.getOperand(2));
Expand Down Expand Up @@ -1803,7 +1859,8 @@ struct BlockwiseCopyRewritePattern : public OpRewritePattern<miopen::BlockwiseCo
auto threadwiseCopyLoadOp = b.create<miopen::ThreadwiseCopyOp>(
op.getLoc(), op.source(), op.buffer(),
ThreadwiseCopySourceAndBufferCoords);
affixThreadwiseCopyAttributes(threadwiseCopyLoadOp, op, b);
affixThreadwiseCopyAttributes(threadwiseCopyLoadOp, op, b,
/*isThreadwiseLoad=*/true);

// Threadwise copy from register (naive tensor) to LDS (naive tensor).
SmallVector<Value, 4> ThreadwiseCopyBufferAndDestCoords;
Expand All @@ -1819,7 +1876,8 @@ struct BlockwiseCopyRewritePattern : public OpRewritePattern<miopen::BlockwiseCo
auto threadwiseCopyStoreOp = b.create<miopen::ThreadwiseCopyOp>(
op.getLoc(), op.buffer(), op.dest(),
ThreadwiseCopyBufferAndDestCoords);
affixThreadwiseCopyAttributes(threadwiseCopyStoreOp, op, b);
affixThreadwiseCopyAttributes(threadwiseCopyStoreOp, op, b,
/*isThreadwiseLoad=*/false);
} else if (sourceType.getMemorySpace() == 0 && destType.getMemorySpace() == 5) {
// Threadwise copy from global (generic tensor) to register (naive
// tensor).
Expand All @@ -1836,7 +1894,8 @@ struct BlockwiseCopyRewritePattern : public OpRewritePattern<miopen::BlockwiseCo
auto threadwiseCopyLoadOp = b.create<miopen::ThreadwiseCopyOp>(
op.getLoc(), op.source(), op.dest(),
ThreadwiseCopySourceAndDestCoords);
affixThreadwiseCopyAttributes(threadwiseCopyLoadOp, op, b);
affixThreadwiseCopyAttributes(threadwiseCopyLoadOp, op, b,
/*isThreadwiseLoad=*/true);
} else if (sourceType.getMemorySpace() == 5 && destType.getMemorySpace() == 3) {
// Threadwise copy from register (naive tensor) to LDS (naive tensor).
SmallVector<Value, 4> ThreadwiseCopySourceAndDestCoords;
Expand All @@ -1852,7 +1911,8 @@ struct BlockwiseCopyRewritePattern : public OpRewritePattern<miopen::BlockwiseCo
auto threadwiseCopyStoreOp = b.create<miopen::ThreadwiseCopyOp>(
op.getLoc(), op.source(), op.dest(),
ThreadwiseCopySourceAndDestCoords);
affixThreadwiseCopyAttributes(threadwiseCopyStoreOp, op, b);
affixThreadwiseCopyAttributes(threadwiseCopyStoreOp, op, b,
/*isThreadwiseLoad=*/false);
} else {
llvm::errs() << "UNSUPPORTED ThreadwiseCopyOp\n";
rewritten = false;
Expand Down

0 comments on commit a89becd

Please sign in to comment.