-
Notifications
You must be signed in to change notification settings - Fork 11k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Restrict DropInnerMostUnitDimsTransfer{Read|Write} #96218
[mlir][vector] Restrict DropInnerMostUnitDimsTransfer{Read|Write} #96218
Conversation
@llvm/pr-subscribers-mlir-vector Author: Andrzej Warzyński (banach-space) Changes
Full diff: https://github.com/llvm/llvm-project/pull/96218.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index b824508728ac8..890cfe2746dae 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1395,6 +1395,11 @@ class DropInnerMostUnitDimsTransferWrite
if (dimsToDrop == 0)
return failure();
+ // Make sure that the indices to be dropped are equal 0.
+ // TODO: Deal with cases when the indices are not 0.
+ if (!llvm::all_of(writeOp.getIndices().take_back(dimsToDrop), isZeroIndex))
+ return failure();
+
auto resultTargetVecType =
VectorType::get(targetType.getShape().drop_back(dimsToDrop),
targetType.getElementType(),
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index 5183205db1b47..df1ae547bcdfa 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -1,5 +1,7 @@
// RUN: mlir-opt %s -test-vector-transfer-collapse-inner-most-dims -split-input-file | FileCheck %s
+// TODO: Unify how memref and vectors are named
+
//-----------------------------------------------------------------------------
// 1. vector.transfer_read
//-----------------------------------------------------------------------------
@@ -254,14 +256,14 @@ func.func @negative_non_unit_inner_memref_dim(%arg0: memref<4x8xf32>) -> vector<
// 2. vector.transfer_write
//-----------------------------------------------------------------------------
-func.func @drop_two_inner_most_dim(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x16x1x1xf32>, %arg2: index) {
+func.func @contiguous_inner_most(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x16x1x1xf32>, %arg2: index) {
%c0 = arith.constant 0 : index
vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
{in_bounds = [true, true, true, true, true]}
: vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
return
}
-// CHECK: func.func @drop_two_inner_most_dim
+// CHECK: func.func @contiguous_inner_most
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
@@ -276,14 +278,14 @@ func.func @drop_two_inner_most_dim(%arg0: memref<1x512x16x1x1xf32>, %arg1: vecto
// dim scalable. Note that this example only makes sense when "16 = [16]" (i.e.
// vscale = 1). This is assumed (implicitly) via the `in_bounds` attribute.
-func.func @drop_two_inner_most_dim_scalable_inner_dim(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x[16]x1x1xf32>, %arg2: index) {
+func.func @contiguous_inner_most_scalable_inner_dim(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x[16]x1x1xf32>, %arg2: index) {
%c0 = arith.constant 0 : index
vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
{in_bounds = [true, true, true, true, true]}
: vector<1x16x[16]x1x1xf32>, memref<1x512x16x1x1xf32>
return
}
-// CHECK: func.func @drop_two_inner_most_dim_scalable_inner_dim
+// CHECK: func.func @contiguous_inner_most_scalable_inner_dim
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
@@ -325,6 +327,73 @@ func.func @negative_scalable_one_trailing_dim(%arg0: memref<1x512x16x1x1xf32>, %
// -----
+func.func @contiguous_inner_most_dynamic_outer(%a: index, %b: index, %arg0: memref<?x?x16x1xf32>, %arg1: vector<8x1xf32>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %arg1, %arg0[%a, %b, %c0, %c0] {in_bounds = [true, true]} : vector<8x1xf32>, memref<?x?x16x1xf32>
+ return
+}
+// CHECK-LABEL: func.func @contiguous_inner_most_dynamic_outer(
+// CHECK-SAME: %[[IDX_0:.*]]: index, %[[IDX_1:.*]]: index,
+// CHECK-SAME: %[[MEM:.*]]: memref<?x?x16x1xf32>,
+// CHECK-SAME: %[[VEC:.*]]: vector<8x1xf32>) {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM0:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?x?x16x1xf32>
+// CHECK: %[[DIM1:.*]] = memref.dim %[[MEM]], %[[C1]] : memref<?x?x16x1xf32>
+// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0, 0, 0] {{\[}}%[[DIM0]], %[[DIM1]], 16, 1] [1, 1, 1, 1] : memref<?x?x16x1xf32> to memref<?x?x16xf32, strided<[?, 16, 1], offset: ?>>
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<8x1xf32> to vector<8xf32>
+// CHECK: vector.transfer_write %[[SC]], %[[SV]]{{\[}}%[[IDX_0]], %[[IDX_1]], %[[C0]]] {in_bounds = [true]} : vector<8xf32>, memref<?x?x16xf32, strided<[?, 16, 1], offset: ?>>
+
+// Same as the top example within this split, but with the outer vector
+// dim scalable. Note that this example only makes sense when "8 = [8]" (i.e.
+// vscale = 1). This is assumed (implicitly) via the `in_bounds` attribute.
+
+func.func @contiguous_inner_most_dynamic_outer_scalable_inner_dim(%a: index, %b: index, %arg0: memref<?x?x16x1xf32>, %arg1: vector<[8]x1xf32>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %arg1, %arg0[%a, %b, %c0, %c0] {in_bounds = [true, true]} : vector<[8]x1xf32>, memref<?x?x16x1xf32>
+ return
+}
+// CHECK-LABEL: func.func @contiguous_inner_most_dynamic_outer_scalable_inner_dim(
+// CHECK-SAME: %[[IDX_0:.*]]: index, %[[IDX_1:.*]]: index,
+// CHECK-SAME: %[[MEM:.*]]: memref<?x?x16x1xf32>,
+// CHECK-SAME: %[[VEC:.*]]: vector<[8]x1xf32>) {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM0:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?x?x16x1xf32>
+// CHECK: %[[DIM1:.*]] = memref.dim %[[MEM]], %[[C1]] : memref<?x?x16x1xf32>
+// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0, 0, 0] {{\[}}%[[DIM0]], %[[DIM1]], 16, 1] [1, 1, 1, 1] : memref<?x?x16x1xf32> to memref<?x?x16xf32, strided<[?, 16, 1], offset: ?>>
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<[8]x1xf32> to vector<[8]xf32>
+// CHECK: vector.transfer_write %[[SC]], %[[SV]]{{\[}}%[[IDX_0]], %[[IDX_1]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf32>, memref<?x?x16xf32, strided<[?, 16, 1], offset: ?>>
+
+// -----
+
+func.func @contiguous_inner_most_non_zero_idxs(%arg0: memref<16x1xf32>, %arg1: vector<8x1xf32>, %i: index) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %arg1, %arg0[%i, %c0] {in_bounds = [true, true]} : vector<8x1xf32>, memref<16x1xf32>
+ return
+}
+// CHECK-LABEL: func.func @contiguous_inner_most_non_zero_idxs(
+// CHECK-SAME: %[[MEM:.*]]: memref<16x1xf32>,
+// CHECK-SAME: %[[VEC:.*]]: vector<8x1xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) {
+// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [16, 1] [1, 1] : memref<16x1xf32> to memref<16xf32, strided<[1]>>
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<8x1xf32> to vector<8xf32>
+// CHECK: vector.transfer_write %[[SC]], %[[SV]]{{\[}}%[[IDX]]] {in_bounds = [true]} : vector<8xf32>, memref<16xf32, strided<[1]>>
+
+// The index to be dropped is != 0 - this is currently not supported.
+
+func.func @negative_contiguous_inner_most_dim_non_zero_idxs(%arg0: memref<16x1xf32>, %arg1: vector<8x1xf32>, %i: index) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %arg1, %arg0[%i, %i] {in_bounds = [true, true]} : vector<8x1xf32>, memref<16x1xf32>
+ return
+}
+// CHECK-LABEL: func @negative_contiguous_inner_most_dim_non_zero_idxs
+// CHECK-NOT: memref.subview
+// CHECK-NOT: memref.shape_cast
+// CHECK: vector.transfer_write
+
+// -----
+
func.func @drop_inner_most_dim(%arg0: memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) {
%c0 = arith.constant 0 : index
vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0]
@@ -345,27 +414,6 @@ func.func @drop_inner_most_dim(%arg0: memref<1x512x16x1xf32, strided<[8192, 16,
// -----
-func.func @outer_dyn_drop_inner_most_dim(%arg0: memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) {
- %c0 = arith.constant 0 : index
- vector.transfer_write %arg1, %arg0[%arg2, %c0, %c0, %c0]
- {in_bounds = [true, true, true, true]}
- : vector<1x16x16x1xf32>, memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>
- return
-}
-// CHECK: func.func @outer_dyn_drop_inner_most_dim
-// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
-// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
-// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[D0:.+]] = memref.dim %[[SRC]], %[[C0]]
-// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0, 0, 0] [%[[D0]], 512, 16, 1]
-// CHECK-SAME: memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>> to memref<?x512x16xf32, strided<[8192, 16, 1], offset: ?>>
-// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x16x1xf32> to vector<1x16x16xf32>
-// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]
-// CHECK-SAME: [%[[IDX]], %[[C0]], %[[C0]]]
-
-// -----
-
func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], offset: ?>>, %arg1: vector<16x16x1xf32>, %arg2: index) {
%c0 = arith.constant 0 : index
vector.transfer_write %arg1, %arg0[%arg2, %c0, %c0]
|
@llvm/pr-subscribers-mlir Author: Andrzej Warzyński (banach-space) Changes
Full diff: https://github.com/llvm/llvm-project/pull/96218.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index b824508728ac8..890cfe2746dae 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1395,6 +1395,11 @@ class DropInnerMostUnitDimsTransferWrite
if (dimsToDrop == 0)
return failure();
+ // Make sure that the indices to be dropped are equal 0.
+ // TODO: Deal with cases when the indices are not 0.
+ if (!llvm::all_of(writeOp.getIndices().take_back(dimsToDrop), isZeroIndex))
+ return failure();
+
auto resultTargetVecType =
VectorType::get(targetType.getShape().drop_back(dimsToDrop),
targetType.getElementType(),
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index 5183205db1b47..df1ae547bcdfa 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -1,5 +1,7 @@
// RUN: mlir-opt %s -test-vector-transfer-collapse-inner-most-dims -split-input-file | FileCheck %s
+// TODO: Unify how memref and vectors are named
+
//-----------------------------------------------------------------------------
// 1. vector.transfer_read
//-----------------------------------------------------------------------------
@@ -254,14 +256,14 @@ func.func @negative_non_unit_inner_memref_dim(%arg0: memref<4x8xf32>) -> vector<
// 2. vector.transfer_write
//-----------------------------------------------------------------------------
-func.func @drop_two_inner_most_dim(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x16x1x1xf32>, %arg2: index) {
+func.func @contiguous_inner_most(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x16x1x1xf32>, %arg2: index) {
%c0 = arith.constant 0 : index
vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
{in_bounds = [true, true, true, true, true]}
: vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
return
}
-// CHECK: func.func @drop_two_inner_most_dim
+// CHECK: func.func @contiguous_inner_most
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
@@ -276,14 +278,14 @@ func.func @drop_two_inner_most_dim(%arg0: memref<1x512x16x1x1xf32>, %arg1: vecto
// dim scalable. Note that this example only makes sense when "16 = [16]" (i.e.
// vscale = 1). This is assumed (implicitly) via the `in_bounds` attribute.
-func.func @drop_two_inner_most_dim_scalable_inner_dim(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x[16]x1x1xf32>, %arg2: index) {
+func.func @contiguous_inner_most_scalable_inner_dim(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x[16]x1x1xf32>, %arg2: index) {
%c0 = arith.constant 0 : index
vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
{in_bounds = [true, true, true, true, true]}
: vector<1x16x[16]x1x1xf32>, memref<1x512x16x1x1xf32>
return
}
-// CHECK: func.func @drop_two_inner_most_dim_scalable_inner_dim
+// CHECK: func.func @contiguous_inner_most_scalable_inner_dim
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
@@ -325,6 +327,73 @@ func.func @negative_scalable_one_trailing_dim(%arg0: memref<1x512x16x1x1xf32>, %
// -----
+func.func @contiguous_inner_most_dynamic_outer(%a: index, %b: index, %arg0: memref<?x?x16x1xf32>, %arg1: vector<8x1xf32>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %arg1, %arg0[%a, %b, %c0, %c0] {in_bounds = [true, true]} : vector<8x1xf32>, memref<?x?x16x1xf32>
+ return
+}
+// CHECK-LABEL: func.func @contiguous_inner_most_dynamic_outer(
+// CHECK-SAME: %[[IDX_0:.*]]: index, %[[IDX_1:.*]]: index,
+// CHECK-SAME: %[[MEM:.*]]: memref<?x?x16x1xf32>,
+// CHECK-SAME: %[[VEC:.*]]: vector<8x1xf32>) {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM0:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?x?x16x1xf32>
+// CHECK: %[[DIM1:.*]] = memref.dim %[[MEM]], %[[C1]] : memref<?x?x16x1xf32>
+// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0, 0, 0] {{\[}}%[[DIM0]], %[[DIM1]], 16, 1] [1, 1, 1, 1] : memref<?x?x16x1xf32> to memref<?x?x16xf32, strided<[?, 16, 1], offset: ?>>
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<8x1xf32> to vector<8xf32>
+// CHECK: vector.transfer_write %[[SC]], %[[SV]]{{\[}}%[[IDX_0]], %[[IDX_1]], %[[C0]]] {in_bounds = [true]} : vector<8xf32>, memref<?x?x16xf32, strided<[?, 16, 1], offset: ?>>
+
+// Same as the top example within this split, but with the outer vector
+// dim scalable. Note that this example only makes sense when "8 = [8]" (i.e.
+// vscale = 1). This is assumed (implicitly) via the `in_bounds` attribute.
+
+func.func @contiguous_inner_most_dynamic_outer_scalable_inner_dim(%a: index, %b: index, %arg0: memref<?x?x16x1xf32>, %arg1: vector<[8]x1xf32>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %arg1, %arg0[%a, %b, %c0, %c0] {in_bounds = [true, true]} : vector<[8]x1xf32>, memref<?x?x16x1xf32>
+ return
+}
+// CHECK-LABEL: func.func @contiguous_inner_most_dynamic_outer_scalable_inner_dim(
+// CHECK-SAME: %[[IDX_0:.*]]: index, %[[IDX_1:.*]]: index,
+// CHECK-SAME: %[[MEM:.*]]: memref<?x?x16x1xf32>,
+// CHECK-SAME: %[[VEC:.*]]: vector<[8]x1xf32>) {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM0:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?x?x16x1xf32>
+// CHECK: %[[DIM1:.*]] = memref.dim %[[MEM]], %[[C1]] : memref<?x?x16x1xf32>
+// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0, 0, 0] {{\[}}%[[DIM0]], %[[DIM1]], 16, 1] [1, 1, 1, 1] : memref<?x?x16x1xf32> to memref<?x?x16xf32, strided<[?, 16, 1], offset: ?>>
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<[8]x1xf32> to vector<[8]xf32>
+// CHECK: vector.transfer_write %[[SC]], %[[SV]]{{\[}}%[[IDX_0]], %[[IDX_1]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf32>, memref<?x?x16xf32, strided<[?, 16, 1], offset: ?>>
+
+// -----
+
+func.func @contiguous_inner_most_non_zero_idxs(%arg0: memref<16x1xf32>, %arg1: vector<8x1xf32>, %i: index) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %arg1, %arg0[%i, %c0] {in_bounds = [true, true]} : vector<8x1xf32>, memref<16x1xf32>
+ return
+}
+// CHECK-LABEL: func.func @contiguous_inner_most_non_zero_idxs(
+// CHECK-SAME: %[[MEM:.*]]: memref<16x1xf32>,
+// CHECK-SAME: %[[VEC:.*]]: vector<8x1xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) {
+// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [16, 1] [1, 1] : memref<16x1xf32> to memref<16xf32, strided<[1]>>
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<8x1xf32> to vector<8xf32>
+// CHECK: vector.transfer_write %[[SC]], %[[SV]]{{\[}}%[[IDX]]] {in_bounds = [true]} : vector<8xf32>, memref<16xf32, strided<[1]>>
+
+// The index to be dropped is != 0 - this is currently not supported.
+
+func.func @negative_contiguous_inner_most_dim_non_zero_idxs(%arg0: memref<16x1xf32>, %arg1: vector<8x1xf32>, %i: index) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %arg1, %arg0[%i, %i] {in_bounds = [true, true]} : vector<8x1xf32>, memref<16x1xf32>
+ return
+}
+// CHECK-LABEL: func @negative_contiguous_inner_most_dim_non_zero_idxs
+// CHECK-NOT: memref.subview
+// CHECK-NOT: memref.shape_cast
+// CHECK: vector.transfer_write
+
+// -----
+
func.func @drop_inner_most_dim(%arg0: memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) {
%c0 = arith.constant 0 : index
vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0]
@@ -345,27 +414,6 @@ func.func @drop_inner_most_dim(%arg0: memref<1x512x16x1xf32, strided<[8192, 16,
// -----
-func.func @outer_dyn_drop_inner_most_dim(%arg0: memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) {
- %c0 = arith.constant 0 : index
- vector.transfer_write %arg1, %arg0[%arg2, %c0, %c0, %c0]
- {in_bounds = [true, true, true, true]}
- : vector<1x16x16x1xf32>, memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>
- return
-}
-// CHECK: func.func @outer_dyn_drop_inner_most_dim
-// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
-// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
-// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[D0:.+]] = memref.dim %[[SRC]], %[[C0]]
-// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0, 0, 0] [%[[D0]], 512, 16, 1]
-// CHECK-SAME: memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>> to memref<?x512x16xf32, strided<[8192, 16, 1], offset: ?>>
-// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x16x1xf32> to vector<1x16x16xf32>
-// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]
-// CHECK-SAME: [%[[IDX]], %[[C0]], %[[C0]]]
-
-// -----
-
func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], offset: ?>>, %arg1: vector<16x16x1xf32>, %arg2: index) {
%c0 = arith.constant 0 : index
vector.transfer_write %arg1, %arg0[%arg2, %c0, %c0]
|
b2e8f75
to
265ace6
Compare
@@ -367,6 +367,33 @@ func.func @contiguous_inner_most_dynamic_outer_scalable_inner_dim(%a: index, %b: | |||
|
|||
// ----- | |||
|
|||
func.func @contiguous_inner_most_non_zero_idxs(%arg0: memref<16x1xf32>, %arg1: vector<8x1xf32>, %i: index) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What additional feature does this test show compared to contiguous_inner_most_dynamic_outer
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch! At one point I convinced myself that there was a difference, but now I see that I was wrong 😅 Let me delete this.
I'm not convinced that the transform you show in the description is invalid.
Looking at this there's no mask and both dims are marked as in_bounds. Therefore So the program is invalid (and I think it invokes UB) if If So, I think you just need to check the |
Restrict `DropInnerMostUnitDimsTransferWrite` so that it fails when one of the indices to be dropped could be != 0, e.g. ```mlir func.func @negative_example( %arg0: memref<16x1xf32>, %arg1: vector<8x1xf32>, %idx_1: index, %idx_2: index) { %c0 = arith.constant 0 : index vector.transfer_write %arg1, %arg0[%idx_1, %idx_2] {in_bounds = [true, true]} : vector<8x1xf32>, memref<16x1xf32> return } ``` This is an edge case that could represent an out-of-bounds access, though that will depend on the actual value of `%i`. Importantly, _without this change_ it would be transformed as follows: ```mlir func.func @negative_example( %arg0: memref<16x1xf32>, %arg1: vector<8x1xf32>, %idx_1: index, %idx_2: index) { %subview = memref.subview %arg0[0, 0] [16, 1] [1, 1] : memref<16x1xf32> to memref<16xf32, strided<[1]>> %0 = vector.shape_cast %arg1 : vector<8x1xf32> to vector<8xf32> vector.transfer_write %0, %subview[%idx_1] {in_bounds = [true]} : vector<8xf32>, memref<16xf32, strided<[1]>> return } ``` This is incorrect - `%idx_2` is ignored. Hence the extra restriction to avoid such cases. NOTE: This PR is limited to `vector.transfer_write`. Similar patch for `vector.transfer_read`: llvm#94904
Remove duplicate test
ccc440d
to
9441eb0
Compare
✅ With the latest revision this PR passed the C/C++ code formatter. |
Allow non-zero indices when in-bounds
9441eb0
to
6a3fe47
Compare
That's a very good point, thanks for taking the time to analyse this.
👍🏻 For that I'd like to clarify/formalise the interplay between masks and
Agreed, thanks! That's already updated (see the latest commit). I have also added a few tests to make the distinct cases very clear. Also, while |
bool indexOutOfBounds = true; | ||
if (writeOp.getInBounds()) | ||
indexOutOfBounds = llvm::any_of( | ||
llvm::zip(writeOp.getInBounds()->getValue().take_back(dimsToDrop), | ||
writeOp.getIndices().take_back(dimsToDrop)), | ||
[](auto zipped) { | ||
auto inBounds = cast<BoolAttr>(std::get<0>(zipped)).getValue(); | ||
auto nonZeroIdx = !isZeroIndex(std::get<1>(zipped)); | ||
return !inBounds && nonZeroIdx; | ||
}); | ||
else | ||
indexOutOfBounds = !llvm::all_of( | ||
writeOp.getIndices().take_back(dimsToDrop), isZeroIndex); | ||
if (indexOutOfBounds) | ||
return failure(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks much more complex than I'd expect. I don't see the need to check the indices at all.
Also, dropping the dim if the index is zero and marked as out-of-bounds does not seem valid. If index zero is out-of-bounds, then we can't safely write to that unit dimension (I'm not sure of the precise meaning of that, but it does not seem like a safe transform in that case...).
Why not just:
auto inBounds = writeOp.getInBoundsValues();
auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
if (llvm::is_contained(droppedInBounds, false))
return failure();
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(or maybe even):
if (writeOp.hasOutOfBoundsDim())
return failure();
out-of-bounds dims are an edge-case (that I've not really seen used in practice, so are probably not worth worrying about much).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Side note: If index 0 and in-bounds = false for a unit-dim actually means in-bounds = true, that should be an xferOp fold (not something everyone that looks at in-bounds values needs to interpret).
Edit: Looks like it already is! https://godbolt.org/z/9Gqo37YeG :) It's an actual fold too (foldTransferInBoundsAttribute
, so it's running pretty much all the time).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dropping the dim if the index is zero and marked as out-of-bounds does not seem valid. If index zero is out-of-bounds, then we can't safely write to that unit dimension
As you observed further down:
If index 0 and in-bounds = false for a unit-dim actually means in-bounds = true
:) So:
- if the index is == 0 then
in_bounds
is effectively irrelevant (safe to collapse) - if index is != 0, but
in_bounds = true
, the index is effectively ==0 and (safe to collapse) - if index != 0 and
in_bounds = false
, bail out.
Note that once in_bounds
is mandatory, I will be able to simply the above as:
if(llvm::any_of(
llvm::zip(writeOp.getInBounds()->getValue().take_back(dimsToDrop),
writeOp.getIndices().take_back(dimsToDrop)),
[](auto zipped) {
auto inBounds = cast<BoolAttr>(std::get<0>(zipped)).getValue();
auto nonZeroIdx = !isZeroIndex(std::get<1>(zipped));
return !inBounds && nonZeroIdx;
}))
return failure();
out-of-bounds dims are an edge-case
Not sure we can say that - the default for in_bounds
is "out of bounds".
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's an actual fold too (foldTransferInBoundsAttribute, so it's running pretty much all the time).
I knew about the folder, but incorrectly assumed that the "folder" wasn't guaranteed to be run before the pattern. I was wrong:
llvm-project/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
Lines 107 to 108 in 1ed84a8
/// Also performs folding and simple dead-code elimination before attempting to | |
/// match any of the provided patterns. |
/// Also performs folding and simple dead-code elimination before attempting to
/// match any of the provided patterns.
So:
Why not just:
auto inBounds = writeOp.getInBoundsValues();
auto droppedInBounds = ArrayRef(inBounds).take_back(dimsToDrop);
if (llvm::is_contained(droppedInBounds, false))
return failure();
Yeah, that's perfectly sufficient and covers all the cases. Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure we can say that - the default for in_bounds is "out of bounds".
I think this is definitely safer.
…nsferWrite Simplify as per Ben's suggestion. Given that the Op folder is guaranteed to run before the pattern, we can safely assume that the in_bounds attribute already contains all the info that we need.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM (though I think #94904 should be updated like this too :))
(also PR description needs updating) |
…nsferWrite Extend to xfer_read
7f3fbf8
to
2342b2f
Compare
LGTM 👍 |
…vm#96218) Restrict `DropInnerMostUnitDimsTransfer{Read|Write}` so that it fails when one of the indices to be dropped could be != 0 and "out of bounds": ```mlir func.func @negative_example(%arg0: memref<16x1xf32>, %arg1: vector<8x1xf32>, %idx_1: index, %idx_2: index) { vector.transfer_write %arg1, %arg0[%idx_1, %idx_2] {in_bounds = [true, false]} : vector<8x1xf32>, memref<16x1xf32> return } ``` This is an edge case that could represent an out-of-bounds access, though that will depend on the actual value of %i. Importantly, without this change it would be transformed as follows: ```mlir func.func @negative_example(%arg0: memref<16x1xf32>, %arg1: vector<8x1xf32>, %arg2: index, %arg3: index) { %subview = memref.subview %arg0[0, 0] [16, 1] [1, 1] : memref<16x1xf32> to memref<16xf32, strided<[1]>> %0 = vector.shape_cast %arg1 : vector<8x1xf32> to vector<8xf32> vector.transfer_write %0, %subview[%arg2] {in_bounds = [true]} : vector<8xf32>, memref<16xf32, strided<[1]>> return } ``` This is incorrect - `%idx_2` is ignored and the "out of bounds" flags is not propagated. Hence the extra restriction to avoid such cases. NOTE: This is a follow-up for: llvm#94904
Restrict
DropInnerMostUnitDimsTransfer{Read|Write}
so that it fails whenone of the indices to be dropped could be != 0 and "out of bounds":
This is an edge case that could represent an out-of-bounds access,
though that will depend on the actual value of %i. Importantly, without
this change it would be transformed as follows:
This is incorrect -
%idx_2
is ignored and the "out of bounds" flags isnot propagated. Hence the extra restriction to avoid such cases.
NOTE: This is a follow-up for: #94904