Skip to content

Commit

Permalink
First implementation of threadwise_copy considering generic tensors.
Browse files Browse the repository at this point in the history
  • Loading branch information
whchung committed Jun 6, 2020
1 parent 45d62cc commit 014b84f
Showing 1 changed file with 155 additions and 20 deletions.
175 changes: 155 additions & 20 deletions mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h
Expand Up @@ -920,7 +920,10 @@ static void affixThreadwiseCopyAttributes(miopen::ThreadwiseCopyOp top,
top.setAttr("dest_data_per_write", b.getI32IntegerAttr(1));
} else {
top.setAttr("dim_access_order", bop.getAttr("dest_dim_access_order"));
top.setAttr("vector_read_write_dim", bop.getAttr("dest_vector_write_dim"));
// XXX. Figure this out. Symmetry is somehow lost here.
// top.setAttr("vector_read_write_dim",
// bop.getAttr("dest_vector_write_dim"));
top.setAttr("vector_read_write_dim", bop.getAttr("source_vector_read_dim"));
top.setAttr("source_data_per_read", b.getI32IntegerAttr(1));
top.setAttr("dest_data_per_write", bop.getAttr("dest_data_per_write"));
}
Expand Down Expand Up @@ -1259,6 +1262,14 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
auto n_block_data_on_global_i32 = b.create<IndexCastOp>(
loc, n_block_data_on_global, b.getIntegerType(32));

// llvm::errs() << "KPerBlock: " << KPerBlock << "\n";
// llvm::errs() << "MPerBlock: " << MPerBlock << "\n";
// llvm::errs() << "NPerBlock: " << NPerBlock << "\n";
// llvm::errs() << "matrix_a_source_data_per_read: " <<
// matrix_a_source_data_per_read << "\n"; llvm::errs() <<
// "matrix_b_source_data_per_read: " << matrix_b_source_data_per_read <<
// "\n";

// Compute ThreadClusterLengths for Matrix A.
int64_t GemmABlockCopyClusterLengths_GemmK =
KPerBlock / matrix_a_source_data_per_read;
Expand All @@ -1271,6 +1282,10 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
int64_t GemmABlockCopyThreadSliceLengths_GemmM =
MPerBlock / GemmABlockCopyClusterLengths_GemmM;

// llvm::errs() << "slice lengths for Matrix A\n";
// llvm::errs() << GemmABlockCopyThreadSliceLengths_GemmK << " ";
// llvm::errs() << GemmABlockCopyThreadSliceLengths_GemmM << "\n";

// Compute ThreadClusterLengths for Matrix B.
int64_t GemmBBlockCopyClusterLengths_GemmK =
KPerBlock /
Expand All @@ -1283,6 +1298,10 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
int64_t GemmBBlockCopyThreadSliceLengths_GemmN =
NPerBlock / GemmBBlockCopyClusterLengths_GemmN;

// llvm::errs() << "slice lengths for Matrix B\n";
// llvm::errs() << GemmBBlockCopyThreadSliceLengths_GemmK << " ";
// llvm::errs() << GemmBBlockCopyThreadSliceLengths_GemmN << "\n";

// Get current workitem ID.
auto tid = b.create<miopen::WorkitemIdOp>(loc, b.getIndexType());

Expand Down Expand Up @@ -2423,36 +2442,152 @@ struct ThreadwiseCopyRewritePattern
destLowerIndices, dstProjectionAttr);

} else {
auto vectorType = VectorType::get(4, sourceType.getElementType());
SmallVector<Value, 4> srcIndices;
SmallVector<Value, 4> dstIndices;

// TBD compute the actual indices following attached affine map, or
// coordinate.
for (unsigned i = 0; i < sourceType.getRank(); ++i) {
srcIndices.push_back(zeroConstantIndexOp);
// The more elaborated algorithm.
// Refer to ThreadwiseGenericTensorSliceCopy_v4r2::Run() for the original
// C++ implementation.

// llvm::errs() << "\nthreadwise_copy op:\n";
// op.dump();
// llvm::errs() << "\n";

auto dimAccessOrder =
op.getAttr("dim_access_order").template cast<ArrayAttr>();
auto vectorAccessDim = op.getAttr("vector_read_write_dim")
.template cast<IntegerAttr>()
.getInt();
auto srcDataPerRead = op.getAttr("source_data_per_read")
.template cast<IntegerAttr>()
.getInt();
auto destDataPerWrite = op.getAttr("dest_data_per_write")
.template cast<IntegerAttr>()
.getInt();

auto longVectorSize = math::lcm(srcDataPerRead, destDataPerWrite);

// llvm::errs() << "vector_read_write_dim: " << vectorAccessDim << "\n";
// llvm::errs() << "source_data_per_read: " << srcDataPerRead << "\n";
// llvm::errs() << "dest_data_per_write: " << destDataPerWrite << "\n";
// llvm::errs() << "longVectorSize: " << longVectorSize << "\n";

// Figure out which memref is the one without affine transformations.
SmallVector<int64_t, 2> sliceLengths;
if (sourceExternalTransform || sourceEmbeddedTransform) {
if (destExternalTransform || destEmbeddedTransform) {
// Couldn't happen.
llvm::errs()
<< "Unsupported case: both memrefs have affine transforms!\n";
return failure();
} else {
for (auto dim : destType.getShape())
sliceLengths.push_back(dim);
}
} else {
if (sourceExternalTransform || sourceEmbeddedTransform) {
// Couldn't happen.
llvm::errs()
<< "Unsupported case: both memrefs have affine transforms!\n";
return failure();
} else
for (auto dim : sourceType.getShape())
sliceLengths.push_back(dim);
}
// llvm::errs() << "slice lengths: ";
// for (unsigned i = 0; i < sliceLengths.size(); ++i)
// llvm::errs() << sliceLengths[i] << " ";
// llvm::errs() << "\n";

// Modify slice lengths per vector access dim.
sliceLengths[vectorAccessDim] =
sliceLengths[vectorAccessDim] / longVectorSize;
SmallVector<Value, 2> loopBounds;
for (unsigned iter = 0; iter < sliceLengths.size(); ++iter)
loopBounds.push_back(
b.create<ConstantIndexOp>(loc, sliceLengths[iter]));

// llvm::errs() << "modified slice lengths: ";
// for (unsigned i = 0; i < sliceLengths.size(); ++i)
// llvm::errs() << sliceLengths[i] << " ";
// llvm::errs() << "\n";

// Emit loops for vector loads / stores.
SmallVector<loop::ForOp, 2> loopOps;
SmallVector<OpBuilder, 2> loopBuilders;
SmallVector<Value, 2> loopIVs;
SmallVector<Value, 2> loopIV_i32s;
for (unsigned iter = 0; iter < dimAccessOrder.size(); ++iter) {
auto dim = dimAccessOrder[iter].template cast<IntegerAttr>().getInt();
auto loopBuilder = (iter == 0) ? b : loopBuilders[iter - 1];

auto loopOp = loopBuilder.create<loop::ForOp>(
loc, zeroConstantIndexOp, loopBounds[dim], oneConstantIndexOp);
loopOps.push_back(loopOp);
auto loopOpBuilder = OpBuilder::atBlockTerminator(loopOp.getBody());
loopBuilders.push_back(loopOpBuilder);
auto loopIV = loopOp.getInductionVar();
loopIVs.push_back(loopIV);
auto loopIV_i32 = loopOpBuilder.create<IndexCastOp>(
loc, loopIV, b.getIntegerType(32));
loopIV_i32s.push_back(loopIV_i32);
}

// Emit loop body.
auto innerLoopBuilder = loopBuilders[loopBuilders.size() - 1];

// Compute high-level coordinate for source memref.
// src_index = (iv_0, iv_1, ...) + sourceCoord
SmallVector<Value, 8> srcUpperIndices;
for (unsigned iter = 0; iter < loopIV_i32s.size(); ++iter)
srcUpperIndices.push_back(innerLoopBuilder.create<IndexCastOp>(
loc,
innerLoopBuilder.create<AddIOp>(loc, loopIV_i32s[iter],
sourceCoord[iter]),
b.getIndexType()));

// TBD improve logic here to adhere to the original C++ implementation.
// Apply affine transformations to compute the low-level coordinate.
SmallVector<Value, 8> srcLowerIndices;
if (sourceExternalTransform || sourceEmbeddedTransform)
srcLowerIndices = expandAffineMap(innerLoopBuilder, loc,
sourceTransform, srcUpperIndices)
.getValue();
else
srcLowerIndices = srcUpperIndices;

// Load from source.
auto sourceVectorType =
VectorType::get(srcDataPerRead, sourceType.getElementType());
auto srcExpr =
getAffineDimExpr(sourceType.getRank() - 1, op.getContext());
auto srcProjection = AffineMap::get(sourceType.getRank(), 0, srcExpr);
auto srcProjectionAttr = AffineMapAttr::get(srcProjection);
auto vectorValue = b.create<vector::TransferReadOp>(
loc, vectorType, op.source(), srcIndices, srcProjectionAttr,
zeroConstantFloatOp);
auto vectorValue = innerLoopBuilder.create<vector::TransferReadOp>(
loc, sourceVectorType, op.source(), srcLowerIndices,
srcProjectionAttr, zeroConstantFloatOp);

// TBD compute the actual indices following attached affine map, or
// coordinate.
for (unsigned i = 0; i < destType.getRank(); ++i) {
dstIndices.push_back(zeroConstantIndexOp);
}
// Compute high-level coordinate for dest memref.
// dst_index = (iv_0, iv_1, ...) + destCoord
SmallVector<Value, 8> destUpperIndices;
for (unsigned iter = 0; iter < loopIV_i32s.size(); ++iter)
destUpperIndices.push_back(innerLoopBuilder.create<IndexCastOp>(
loc,
innerLoopBuilder.create<AddIOp>(loc, loopIV_i32s[iter],
destCoord[iter]),
b.getIndexType()));

// Apply affine transformations to compute the low-level coordinate.
SmallVector<Value, 8> destLowerIndices;
if (destExternalTransform || destEmbeddedTransform)
destLowerIndices = expandAffineMap(innerLoopBuilder, loc, destTransform,
destUpperIndices)
.getValue();
else
destLowerIndices = destUpperIndices;

// Store to dest.
auto dstExpr = getAffineDimExpr(destType.getRank() - 1, op.getContext());
auto dstProjection = AffineMap::get(destType.getRank(), 0, dstExpr);
auto dstProjectionAttr = AffineMapAttr::get(dstProjection);
b.create<vector::TransferWriteOp>(loc, vectorValue, op.dest(), dstIndices,
dstProjectionAttr);
innerLoopBuilder.create<vector::TransferWriteOp>(
loc, vectorValue, op.dest(), destLowerIndices, dstProjectionAttr);
}

op.erase();
Expand Down

0 comments on commit 014b84f

Please sign in to comment.