Skip to content

Commit

Permalink
Use scalar loads/stores when transferring data from/to naive tensors for now.
Browse files Browse the repository at this point in the history
  • Loading branch information
whchung committed Jun 10, 2020
1 parent b316228 commit b93f6ed
Showing 1 changed file with 30 additions and 13 deletions.
43 changes: 30 additions & 13 deletions mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h
Expand Up @@ -1058,11 +1058,15 @@ static void affixThreadwiseCopyAttributes(miopen::ThreadwiseCopyOp top,
if (isMatrixA) {
top.setAttr("n_slice_row", bop.getAttr("k_per_thread"));
top.setAttr("n_slice_col", bop.getAttr("m_per_thread"));
top.setAttr("data_per_access", bop.getAttr("m_per_thread"));
// XXX: TBD review how vector load/store attributes are passed down.
//top.setAttr("data_per_access", bop.getAttr("m_per_thread"));
top.setAttr("data_per_access", b.getI32IntegerAttr(1));
} else {
top.setAttr("n_slice_row", bop.getAttr("k_per_thread"));
top.setAttr("n_slice_col", bop.getAttr("n_per_thread"));
top.setAttr("data_per_access", bop.getAttr("n_per_thread"));
// XXX: TBD review how vector load/store attributes are passed down.
//top.setAttr("data_per_access", bop.getAttr("n_per_thread"));
top.setAttr("data_per_access", b.getI32IntegerAttr(1));
}
}

Expand Down Expand Up @@ -2535,14 +2539,22 @@ struct ThreadwiseCopyRewritePattern
else
srcLowerIndices = srcUpperIndices;

Value vectorValue;
Value scalarValue;
// Load from source.
auto vectorType =
VectorType::get(DataPerAccess, sourceType.getElementType());
auto srcExpr =
getAffineDimExpr(sourceType.getRank() - 1, op.getContext());
auto srcProjection = AffineMap::get(sourceType.getRank(), 0, srcExpr);
auto vectorValue = lib.create<vector::TransferReadOp>(
loc, vectorType, op.source(), srcLowerIndices, srcProjection);
if (DataPerAccess > 1) {
// Issue vector load.
auto vectorType =
VectorType::get(DataPerAccess, sourceType.getElementType());
auto srcExpr =
getAffineDimExpr(sourceType.getRank() - 1, op.getContext());
auto srcProjection = AffineMap::get(sourceType.getRank(), 0, srcExpr);
vectorValue = lib.create<vector::TransferReadOp>(
loc, vectorType, op.source(), srcLowerIndices, srcProjection);
} else {
// Issue scalar load.
scalarValue = lib.create<LoadOp>(loc, sourceType.getElementType(), op.source(), srcLowerIndices);
}

// Compute high-level coordinate for dest memref.
// dst_index = (ivo_i32, ivi_i32) + destCoord
Expand All @@ -2564,10 +2576,15 @@ struct ThreadwiseCopyRewritePattern
destLowerIndices = destUpperIndices;

// Store to dest.
auto dstExpr = getAffineDimExpr(destType.getRank() - 1, op.getContext());
auto dstProjection = AffineMap::get(destType.getRank(), 0, dstExpr);
lib.create<vector::TransferWriteOp>(loc, vectorValue, op.dest(),
destLowerIndices, dstProjection);
if (DataPerAccess > 1) {
auto dstExpr = getAffineDimExpr(destType.getRank() - 1, op.getContext());
auto dstProjection = AffineMap::get(destType.getRank(), 0, dstExpr);
lib.create<vector::TransferWriteOp>(loc, vectorValue, op.dest(),
destLowerIndices, dstProjection);
} else {
// Issue scalar store.
lib.create<StoreOp>(loc, scalarValue, op.dest(), destLowerIndices);
}

} else {
// The more elaborate algorithm.
Expand Down

0 comments on commit b93f6ed

Please sign in to comment.