-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[mlir][vector] Missing indices on vectorization of 1-d reduction to 1-ranked memref #166959
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: Simone Pellegrini (simpel01) ChangesVectorization of a 1-d reduction where the output variable is a 1-ranked memref can generate an invalid vector.transfer_write"(%vec, %buff) <{...}> : (vector<f32>, memref<1xf32>) -> () This patch solves the problem by providing the expected amount of indices (i.e. matching the rank of the memref). Full diff: https://github.com/llvm/llvm-project/pull/166959.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 19d2d854a3838..4eb2a0cb200a0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -746,12 +746,12 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
auto vectorType = state.getCanonicalVecType(
getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
+ SmallVector<Value> indices(linalgOp.getRank(outputOperand),
+ arith::ConstantIndexOp::create(rewriter, loc, 0));
+
Operation *write;
if (vectorType.getRank() > 0) {
AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
- SmallVector<Value> indices(
- linalgOp.getRank(outputOperand),
- arith::ConstantIndexOp::create(rewriter, loc, 0));
value = broadcastIfNeeded(rewriter, value, vectorType);
assert(value.getType() == vectorType && "Incorrect type");
write = vector::TransferWriteOp::create(
@@ -762,7 +762,7 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
value = vector::BroadcastOp::create(rewriter, loc, vectorType, value);
assert(value.getType() == vectorType && "Incorrect type");
write = vector::TransferWriteOp::create(rewriter, loc, value,
- outputOperand->get(), ValueRange{});
+ outputOperand->get(), indices);
}
write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
index 9a14ab7d38d3e..6e63dfe7bb8e5 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
@@ -1523,6 +1523,49 @@ module attributes {transform.with_named_sequence} {
}
+// -----
+
+// CHECK-LABEL: func @reduce_1d_memref(
+// CHECK-SAME: %[[A:.*]]: memref<32xf32>
+// CHECK-SAME: %[[B:.*]]: memref<1xf32>
+func.func @reduce_1d_memref(%arg0: memref<32xf32>, %arg1: memref<1xf32>) {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+
+ // CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]]
+ // CHECK-SAME: : memref<32xf32>, vector<32xf32>
+ // CHECK: %[[init:.*]] = vector.transfer_read %[[B]][%[[C0]]]
+ // CHECK-SAME: : memref<1xf32>, vector<f32>
+ // CHECK: %[[init_scl:.*]] = vector.extract %[[init]][]
+ // CHECK-SAME: : f32 from vector<f32>
+ // CHECK: %[[red:.*]] = vector.multi_reduction <add>, %[[r]], %[[init_scl]] [0]
+ // CHECK-SAME: : vector<32xf32> to f32
+ // CHECK: %[[red_v1:.*]] = vector.broadcast %[[red]] : f32 to vector<f32>
+ // CHECK: vector.transfer_write %[[red_v1]], %[[B]][%[[C0]]]
+ // CHECK-SAME: : vector<f32>, memref<1xf32>
+ linalg.generic {
+ indexing_maps = [affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (0)>],
+ iterator_types = ["reduction"]}
+ ins(%arg0 : memref<32xf32>)
+ outs(%arg1 : memref<1xf32>) {
+ ^bb0(%a: f32, %b: f32):
+ %0 = arith.addf %a, %b : f32
+ linalg.yield %0 : f32
+ }
+
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+
// -----
// This test checks that vectorization does not occur when an input indexing map
|
|
Thanks for the fix!
Note, this problem is not really specific to memrefs, but rather to rank-1 shaped-types. I also suspect that we could find a repro for this that would not involve reductions. The fix makes sense to me, but I have a few asks:
Thanks again and please send fixes if you identify more issues :) |
banach-space
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM % comments
| // CHECK-LABEL: func @reduce_1d_memref( | ||
| // CHECK-SAME: %[[A:.*]]: memref<32xf32> | ||
| // CHECK-SAME: %[[B:.*]]: memref<1xf32> | ||
| func.func @reduce_1d_memref(%arg0: memref<32xf32>, %arg1: memref<1xf32>) { | ||
| // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index | ||
|
|
||
| // CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]] | ||
| // CHECK-SAME: : memref<32xf32>, vector<32xf32> | ||
| // CHECK: %[[init:.*]] = vector.transfer_read %[[B]][%[[C0]]] | ||
| // CHECK-SAME: : memref<1xf32>, vector<f32> | ||
| // CHECK: %[[init_scl:.*]] = vector.extract %[[init]][] | ||
| // CHECK-SAME: : f32 from vector<f32> | ||
| // CHECK: %[[red:.*]] = vector.multi_reduction <add>, %[[r]], %[[init_scl]] [0] | ||
| // CHECK-SAME: : vector<32xf32> to f32 | ||
| // CHECK: %[[red_v1:.*]] = vector.broadcast %[[red]] : f32 to vector<f32> | ||
| // CHECK: vector.transfer_write %[[red_v1]], %[[B]][%[[C0]]] | ||
| // CHECK-SAME: : vector<f32>, memref<1xf32> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please use UPPERCASE for FileCheck variables. I would also appreciate if you could update the previous test.
Also, could you use more descriptive variable names? Specifically, %arg0 -> %src and %arg1 -> %res. Please make the FileCheck variables match this. A more comprehensive guideline is available here:
dcaballe
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
…rank-1 shaped-types
Vectorization of a 1-d reduction where the output variable is a rank-1 shaped-type
can generate an invalid `vector.transfer_write` with no indices for the corresponding
input, e.g.:
vector.transfer_write"(%vec, %t) <{...}> : (vector<f32>, tensor<1xf32>) -> ()
This patch solves the problem by providing the expected amount of indices (i.e.
matching the rank of the shaped-type).
b67ae6a to
6e25c58
Compare
Vectorization of a 1-d reduction where the output variable is a 1-ranked memref can generate an invalid
vector.transfer_writewith no indices for the memref, e.g.:vector.transfer_write"(%vec, %buff) <{...}> : (vector, memref<1xf32>) -> ()
This patch solves the problem by providing the expected amount of indices (i.e. matching the rank of the memref).