Skip to content

Commit 9f6ba4b

Browse files
committed
[mlir][vector] Extend transfer_write to read propagation
Folding of transfer_write into transfer_read is already supported, but this requires the read and the write to have the same permutation map. After linalg vectorization it is common to have different permutation maps for a write followed by a read, even though such cases could still be propagated. This canonicalization handles cases where the permutation maps are different but the data read and written match, replacing the transfer ops with vector.broadcast and vector.transpose. Differential Revision: https://reviews.llvm.org/D130135
1 parent 3878973 commit 9f6ba4b

File tree

3 files changed

+150
-1
lines changed

3 files changed

+150
-1
lines changed

mlir/include/mlir/Interfaces/VectorInterfaces.td

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,36 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
243243
fun(resultIdx, indicesIdx);
244244
}]
245245
>,
246+
InterfaceMethod<
247+
/*desc=*/[{
248+
Return an upper-bound shape accessed by the transfer op within the
249+
tensor/memref operand.
250+
For example:
251+
```
252+
vector.transfer %w0[%i, %j] {
253+
permutation_map = affine_map<(d0, d1) -> (d1, d0, 0)>} :
254+
tensor<?x?xf32>, vector<4x2x6xf32>
255+
```
256+
returns a shape [2, 4].
257+
}],
258+
/*retTy=*/"SmallVector<int64_t>",
259+
/*methodName=*/"getTransferChunkAccessed",
260+
/*args=*/(ins),
261+
/*methodBody=*/"",
262+
/*defaultImplementation=*/[{
263+
SmallVector<int64_t> dimSizes($_op.getPermutationMap().getNumDims(), 0);
264+
for (auto vecDims : llvm::zip($_op.getPermutationMap().getResults(),
265+
$_op.getVectorType().getShape())) {
266+
AffineExpr dim = std::get<0>(vecDims);
267+
int64_t size = std::get<1>(vecDims);
268+
// Skip broadcast.
269+
if (dim.isa<AffineConstantExpr>())
270+
continue;
271+
dimSizes[dim.cast<AffineDimExpr>().getPosition()] = size;
272+
}
273+
return dimSizes;
274+
}]
275+
>,
246276
];
247277
}
248278

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3284,11 +3284,91 @@ struct FoldExtractSliceIntoTransferRead
32843284
return success();
32853285
}
32863286
};
3287+
3288+
/// Store to load forwarding for transfer operations with permuation maps.
3289+
/// Even if the permutation maps are different we can still propagate the store
3290+
/// into the load if the size of the dimensions read and written match. Then we
3291+
/// can replace the transfer_read + transfer_write by vector.broadcast and
3292+
/// vector.transpose.
3293+
/// Example:
3294+
/// ```
3295+
/// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
3296+
/// {in_bounds = [true, true],
3297+
/// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
3298+
/// vector<4x1xf32>, tensor<4x4x4xf32>
3299+
/// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
3300+
/// {in_bounds = [true, true, true, true],
3301+
/// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
3302+
/// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
3303+
/// ```
3304+
/// To:
3305+
/// ```
3306+
/// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32>
3307+
/// %r = vector.transpose %0, [3, 0, 2, 1] :
3308+
/// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
3309+
/// ```
3310+
struct TransferReadAfterWriteToBroadcast
3311+
: public OpRewritePattern<TransferReadOp> {
3312+
using OpRewritePattern<TransferReadOp>::OpRewritePattern;
3313+
3314+
LogicalResult matchAndRewrite(TransferReadOp readOp,
3315+
PatternRewriter &rewriter) const override {
3316+
if (readOp.hasOutOfBoundsDim() ||
3317+
!readOp.getShapedType().isa<RankedTensorType>())
3318+
return failure();
3319+
auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
3320+
if (!defWrite)
3321+
return failure();
3322+
3323+
SmallVector<int64_t> readDims = readOp.getTransferChunkAccessed();
3324+
Value vec;
3325+
if (readOp.getIndices() == defWrite.getIndices() &&
3326+
readOp.getMask() == defWrite.getMask()) {
3327+
SmallVector<int64_t> writeDims = defWrite.getTransferChunkAccessed();
3328+
// TODO: If the writeDim is a superset of the read dims we could do an
3329+
// extract_strided_slice.
3330+
if (writeDims == readDims)
3331+
vec = defWrite.getVector();
3332+
}
3333+
// TODO: loop through the chain of transfer_write if we can prove that they
3334+
// don't overlap with the transfer_read. This requires improving
3335+
// `isDisjointTransferIndices` helper.
3336+
if (!vec)
3337+
return failure();
3338+
SmallVector<unsigned> permutation;
3339+
AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
3340+
AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
3341+
AffineMap map = readMap.compose(writeMap);
3342+
if (map.getNumResults() == 0)
3343+
return failure();
3344+
// Calculate the permuation to apply to go from the vector stored to the
3345+
// vector read.
3346+
if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
3347+
return failure();
3348+
3349+
Location loc = readOp.getLoc();
3350+
// Calculate the broadcast shape by applying the reverse permuation to the
3351+
// final shape we want.
3352+
ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
3353+
SmallVector<int64_t> broadcastShape(destShape.size());
3354+
for (const auto &pos : llvm::enumerate(permutation))
3355+
broadcastShape[pos.value()] = destShape[pos.index()];
3356+
VectorType broadcastedType = VectorType::get(
3357+
broadcastShape, defWrite.getVectorType().getElementType());
3358+
vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
3359+
SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
3360+
rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
3361+
transposePerm);
3362+
return success();
3363+
}
3364+
};
32873365
} // namespace
32883366

32893367
void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
32903368
MLIRContext *context) {
3291-
results.add<FoldExtractSliceIntoTransferRead>(context);
3369+
results
3370+
.add<FoldExtractSliceIntoTransferRead, TransferReadAfterWriteToBroadcast>(
3371+
context);
32923372
}
32933373

32943374
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1058,6 +1058,45 @@ func.func @store_to_load_negative_tensor(%arg0 : tensor<4x4xf32>,
10581058

10591059
// -----
10601060

1061+
// CHECK-LABEL: func @store_to_load_tensor_broadcast
1062+
// CHECK-SAME: (%[[ARG:.*]]: tensor<4x4xf32>, %[[V0:.*]]: vector<4x2xf32>)
1063+
// CHECK: %[[B:.*]] = vector.broadcast %[[V0]] : vector<4x2xf32> to vector<6x4x2xf32>
1064+
// CHECK: %[[T:.*]] = vector.transpose %[[B]], [1, 2, 0] : vector<6x4x2xf32> to vector<4x2x6xf32>
1065+
// CHECK: return %[[T]] : vector<4x2x6xf32>
1066+
func.func @store_to_load_tensor_broadcast(%arg0 : tensor<4x4xf32>,
1067+
%v0 : vector<4x2xf32>) -> vector<4x2x6xf32> {
1068+
%c0 = arith.constant 0 : index
1069+
%cf0 = arith.constant 0.0 : f32
1070+
%w0 = vector.transfer_write %v0, %arg0[%c0, %c0] {in_bounds = [true, true]} :
1071+
vector<4x2xf32>, tensor<4x4xf32>
1072+
%0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true, true],
1073+
permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>} :
1074+
tensor<4x4xf32>, vector<4x2x6xf32>
1075+
return %0 : vector<4x2x6xf32>
1076+
}
1077+
1078+
// -----
1079+
1080+
// CHECK-LABEL: func @store_to_load_tensor_perm_broadcast
1081+
// CHECK-SAME: (%[[ARG:.*]]: tensor<4x4x4xf32>, %[[V0:.*]]: vector<4x1xf32>)
1082+
// CHECK: %[[B:.*]] = vector.broadcast %[[V0]] : vector<4x1xf32> to vector<100x5x4x1xf32>
1083+
// CHECK: %[[T:.*]] = vector.transpose %[[B]], [3, 0, 2, 1] : vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
1084+
// CHECK: return %[[T]] : vector<1x100x4x5xf32>
1085+
func.func @store_to_load_tensor_perm_broadcast(%arg0 : tensor<4x4x4xf32>,
1086+
%v0 : vector<4x1xf32>) -> vector<1x100x4x5xf32> {
1087+
%c0 = arith.constant 0 : index
1088+
%cf0 = arith.constant 0.0 : f32
1089+
%w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0] {in_bounds = [true, true],
1090+
permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
1091+
vector<4x1xf32>, tensor<4x4x4xf32>
1092+
%0 = vector.transfer_read %w0[%c0, %c0, %c0], %cf0 {in_bounds = [true, true, true, true],
1093+
permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
1094+
tensor<4x4x4xf32>, vector<1x100x4x5xf32>
1095+
return %0 : vector<1x100x4x5xf32>
1096+
}
1097+
1098+
// -----
1099+
10611100

10621101
// CHECK-LABEL: func @dead_store_tensor
10631102
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index

0 commit comments

Comments
 (0)