diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp index 1feda57d8de036..5ac3113f620e44 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp @@ -90,12 +90,13 @@ resolveSourceIndices(Location loc, PatternRewriter &rewriter, } /// Helpers to access the memref operand for each op. -static Value getMemRefOperand(memref::LoadOp op) { return op.memref(); } +template +static Value getMemRefOperand(LoadOrStoreOpTy op) { + return op.memref(); +} static Value getMemRefOperand(vector::TransferReadOp op) { return op.source(); } -static Value getMemRefOperand(memref::StoreOp op) { return op.memref(); } - static Value getMemRefOperand(vector::TransferWriteOp op) { return op.source(); } @@ -154,12 +155,12 @@ class StoreOpOfSubViewFolder final : public OpRewritePattern { PatternRewriter &rewriter) const; }; -template <> -void LoadOpOfSubViewFolder::replaceOp( - memref::LoadOp loadOp, memref::SubViewOp subViewOp, - ArrayRef sourceIndices, PatternRewriter &rewriter) const { - rewriter.replaceOpWithNewOp(loadOp, subViewOp.source(), - sourceIndices); +template +void LoadOpOfSubViewFolder::replaceOp( + LoadOpTy loadOp, memref::SubViewOp subViewOp, ArrayRef sourceIndices, + PatternRewriter &rewriter) const { + rewriter.replaceOpWithNewOp(loadOp, subViewOp.source(), + sourceIndices); } template <> @@ -178,12 +179,12 @@ void LoadOpOfSubViewFolder::replaceOp( /*mask=*/Value(), transferReadOp.in_boundsAttr()); } -template <> -void StoreOpOfSubViewFolder::replaceOp( - memref::StoreOp storeOp, memref::SubViewOp subViewOp, +template +void StoreOpOfSubViewFolder::replaceOp( + StoreOpTy storeOp, memref::SubViewOp subViewOp, ArrayRef sourceIndices, PatternRewriter &rewriter) const { - rewriter.replaceOpWithNewOp( - storeOp, storeOp.value(), subViewOp.source(), sourceIndices); + rewriter.replaceOpWithNewOp(storeOp, storeOp.value(), + subViewOp.source(), sourceIndices); } template <> @@ -239,8 +240,10 @@ StoreOpOfSubViewFolder::matchAndRewrite(OpTy storeOp, } void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) { - patterns.add, + patterns.add, + LoadOpOfSubViewFolder, LoadOpOfSubViewFolder, + StoreOpOfSubViewFolder, StoreOpOfSubViewFolder, StoreOpOfSubViewFolder>( patterns.getContext()); diff --git a/mlir/test/Dialect/MemRef/fold-subview-ops.mlir b/mlir/test/Dialect/MemRef/fold-subview-ops.mlir index e177bb37f936d1..fc06cd35cc8cd3 100644 --- a/mlir/test/Dialect/MemRef/fold-subview-ops.mlir +++ b/mlir/test/Dialect/MemRef/fold-subview-ops.mlir @@ -251,3 +251,24 @@ func @fold_vector_transfer_write_with_inner_rank_reduced_subview( // CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP1]](%[[ARG7]])[%[[ARG3]]] // CHECK-DAG: vector.transfer_write %[[ARG1]], %[[ARG0]][%[[IDX0]], %[[IDX1]], %[[C0]]] // CHECK-SAME: {in_bounds = [true], permutation_map = #[[MAP2]]} : vector<4xf32>, memref, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 { + %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]> + %1 = affine.load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]> + // CHECK-NEXT: affine.apply + // CHECK-NEXT: affine.apply + // CHECK-NEXT: affine.load + affine.store %1, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]> + // CHECK-NEXT: affine.apply + // CHECK-NEXT: affine.apply + // CHECK-NEXT: affine.store + // CHECK-NEXT: return + return %1 : f32 +}