From e91dab928efa7ca68017f3a4b480243f81e877fe Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Sat, 3 Feb 2024 00:17:56 +0000 Subject: [PATCH] [mlir][memref] Fold memref.subview into out-of-bounds vector transfer ops This PR removes a precondition check to fold a `memref.subview` into a `vector.transfer_read` or `vector.transfer_write` with `in_bounds` set to false. There is no reason not to do so as long as the same `in_bounds` is preserved after the folding, which is what the implementation was doing already. --- .../MemRef/Transforms/FoldMemRefAliasOps.cpp | 2 - .../Dialect/MemRef/fold-memref-alias-ops.mlir | 42 ++++++++++++++++++- 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp index aa44455ada7f9..c15056cd168c9 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp @@ -348,8 +348,6 @@ preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp, static_assert( !llvm::is_one_of::value, "must be a vector transfer op"); - if (xferOp.hasOutOfBoundsDim()) - return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim"); if (!subviewOp.hasUnitStride()) { return rewriter.notifyMatchFailure( xferOp, "non-1 stride subview, need to track strides in folded memref"); diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir index 5b853a6cc5a37..6ad13d2176b34 100644 --- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir +++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir @@ -111,6 +111,27 @@ func.func @fold_subview_with_transfer_read(%arg0 : memref<12x32xf32>, %arg1 : in // ----- +func.func @fold_subview_with_oob_transfer_read(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> vector<32xf32> { + %f1 = arith.constant 1.0 : f32 + + %0 
= memref.subview %arg0[%arg1, %arg2][8, 8][1, 1] : memref<12x32xf32> to memref<8x8xf32, strided<[256, 1], offset: ?>> + %1 = vector.transfer_read %0[%arg3, %arg4], %f1 {in_bounds = [false]} : memref<8x8xf32, strided<[256, 1], offset: ?>>, vector<32xf32> + return %1 : vector<32xf32> +} + +// CHECK: #[[MAP:[a-zA-Z0-9]+]] = affine_map<()[s0, s1] -> (s0 + s1)> +// CHECK: func @fold_subview_with_oob_transfer_read +// CHECK-SAME: %[[MEM:[a-zA-Z0-9_]+]]: memref<12x32xf32> +// CHECK-SAME: %[[SZ0:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[SZ1:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[IDX0:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[IDX1:[a-zA-Z0-9_]+]]: index +// CHECK: %[[M0:[a-zA-Z0-9_]+]] = affine.apply #[[MAP]]()[%[[SZ0]], %[[IDX0]]] +// CHECK: %[[M1:[a-zA-Z0-9_]+]] = affine.apply #[[MAP]]()[%[[SZ1]], %[[IDX1]]] +// CHECK: vector.transfer_read %[[MEM]][%[[M0]], %[[M1]]], %{{[a-zA-Z0-9]+}} : memref<12x32xf32>, vector<32xf32> + +// ----- + func.func @fold_static_stride_subview_with_transfer_write_0d( %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %v : vector) { @@ -141,6 +162,25 @@ func.func @fold_static_stride_subview_with_transfer_write(%arg0 : memref<12x32xf // ----- +func.func @fold_subview_with_oob_transfer_write(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : vector<32xf32>) { + %0 = memref.subview %arg0[%arg1, %arg2][8, 8][1, 1] : memref<12x32xf32> to memref<8x8xf32, strided<[256, 1], offset: ?>> + vector.transfer_write %arg5, %0[%arg3, %arg4] {in_bounds = [false]} : vector<32xf32>, memref<8x8xf32, strided<[256, 1], offset: ?>> + return +} +// CHECK: #[[MAP:[a-zA-Z0-9]+]] = affine_map<()[s0, s1] -> (s0 + s1)> +// CHECK: func @fold_subview_with_oob_transfer_write +// CHECK-SAME: %[[MEM:[a-zA-Z0-9_]+]]: memref<12x32xf32> +// CHECK-SAME: %[[SZ0:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[SZ1:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[IDX0:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[IDX1:[a-zA-Z0-9_]+]]: 
index +// CHECK-SAME: %[[ST1:[a-zA-Z0-9_]+]]: vector<32xf32> +// CHECK: %[[M0:[a-zA-Z0-9_]+]] = affine.apply #[[MAP]]()[%[[SZ0]], %[[IDX0]]] +// CHECK: %[[M1:[a-zA-Z0-9_]+]] = affine.apply #[[MAP]]()[%[[SZ1]], %[[IDX1]]] +// CHECK: vector.transfer_write %[[ST1]], %[[MEM]][%[[M0]], %[[M1]]] : vector<32xf32>, memref<12x32xf32> + +// ----- + func.func @fold_rank_reducing_subview_with_load (%arg0 : memref, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, @@ -633,7 +673,7 @@ func.func @fold_load_keep_nontemporal(%arg0 : memref<12x32xf32>, %arg1 : index, // ----- // CHECK-LABEL: func @fold_store_keep_nontemporal( -// CHECK: memref.store %{{.+}}, %{{.+}}[%{{.+}}, %{{.+}}] {nontemporal = true} : memref<12x32xf32> +// CHECK: memref.store %{{.+}}, %{{.+}}[%{{.+}}, %{{.+}}] {nontemporal = true} : memref<12x32xf32> func.func @fold_store_keep_nontemporal(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : f32) { %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, strided<[64, 3], offset: ?>>