Skip to content

Conversation

@simpel01
Copy link
Contributor

@simpel01 simpel01 commented Nov 7, 2025

Vectorization of a 1-d reduction where the output variable is a 1-ranked memref can generate an invalid vector.transfer_write with no indices for the memref, e.g.:

vector.transfer_write"(%vec, %buff) <{...}> : (vector, memref<1xf32>) -> ()

This patch solves the problem by providing the expected amount of indices (i.e. matching the rank of the memref).

@llvmbot
Copy link
Member

llvmbot commented Nov 7, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Simone Pellegrini (simpel01)

Changes

Vectorization of a 1-d reduction where the output variable is a 1-ranked memref can generate an invalid vector.transfer_write with no indices for the memref, e.g.:

vector.transfer_write"(%vec, %buff) <{...}> : (vector<f32>, memref<1xf32>) -> ()

This patch solves the problem by providing the expected amount of indices (i.e. matching the rank of the memref).


Full diff: https://github.com/llvm/llvm-project/pull/166959.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+4-4)
  • (modified) mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir (+43)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 19d2d854a3838..4eb2a0cb200a0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -746,12 +746,12 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
   auto vectorType = state.getCanonicalVecType(
       getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
 
+  SmallVector<Value> indices(linalgOp.getRank(outputOperand),
+                             arith::ConstantIndexOp::create(rewriter, loc, 0));
+
   Operation *write;
   if (vectorType.getRank() > 0) {
     AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
-    SmallVector<Value> indices(
-        linalgOp.getRank(outputOperand),
-        arith::ConstantIndexOp::create(rewriter, loc, 0));
     value = broadcastIfNeeded(rewriter, value, vectorType);
     assert(value.getType() == vectorType && "Incorrect type");
     write = vector::TransferWriteOp::create(
@@ -762,7 +762,7 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
       value = vector::BroadcastOp::create(rewriter, loc, vectorType, value);
     assert(value.getType() == vectorType && "Incorrect type");
     write = vector::TransferWriteOp::create(rewriter, loc, value,
-                                            outputOperand->get(), ValueRange{});
+                                            outputOperand->get(), indices);
   }
 
   write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
index 9a14ab7d38d3e..6e63dfe7bb8e5 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
@@ -1523,6 +1523,49 @@ module attributes {transform.with_named_sequence} {
 }
 
 
+// -----
+
+//  CHECK-LABEL: func @reduce_1d_memref(
+//   CHECK-SAME:   %[[A:.*]]: memref<32xf32>
+//   CHECK-SAME:   %[[B:.*]]: memref<1xf32>
+func.func @reduce_1d_memref(%arg0: memref<32xf32>, %arg1: memref<1xf32>) {
+  //  CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+
+  //      CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]]
+  // CHECK-SAME:   : memref<32xf32>, vector<32xf32>
+  //      CHECK: %[[init:.*]] = vector.transfer_read %[[B]][%[[C0]]]
+  // CHECK-SAME:   : memref<1xf32>, vector<f32>
+  //      CHECK: %[[init_scl:.*]] = vector.extract %[[init]][]
+  // CHECK-SAME:   : f32 from vector<f32>
+  //      CHECK: %[[red:.*]] = vector.multi_reduction <add>, %[[r]], %[[init_scl]] [0]
+  // CHECK-SAME:   : vector<32xf32> to f32
+  //      CHECK: %[[red_v1:.*]] = vector.broadcast %[[red]] : f32 to vector<f32>
+  //      CHECK: vector.transfer_write %[[red_v1]], %[[B]][%[[C0]]]
+  // CHECK-SAME:   : vector<f32>, memref<1xf32>
+  linalg.generic {
+         indexing_maps = [affine_map<(d0) -> (d0)>,
+                          affine_map<(d0) -> (0)>],
+         iterator_types = ["reduction"]}
+         ins(%arg0 : memref<32xf32>)
+         outs(%arg1 : memref<1xf32>) {
+    ^bb0(%a: f32, %b: f32):
+      %0 = arith.addf %a, %b : f32
+      linalg.yield %0 : f32
+    }
+
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+
 // -----
 
 // This test checks that vectorization does not occur when an input indexing map

@banach-space
Copy link
Contributor

Thanks for the fix!

Vectorization of a 1-d reduction where the output variable is a 1-ranked memref can generate an invalid vector.transfer_write with no indices for the memref, e.g.:

vector.transfer_write"(%vec, %buff) <{...}> : (vector, memref<1xf32>) -> ()

This patch solves the problem by providing the expected amount of indices (i.e. matching the rank of the memref).

Note, this problem is not really specific to memrefs, but rather to rank-1 shaped-types. I also suspect that we could find a repro for this that would not involve reductions.

The fix makes sense to me, but I have a few asks:

  • Replace references to memrefs in the summary with references to rank-1 shaped-types, i.e. the actual root cause.
  • There's no need to use memrefs in tests, is there? I suggest using tensors consistently.
  • Update test function names, i.e. %reduce_1d -> %reduce_to_rank_0, %reduce_1d_memref -> %reduce_to_rank_1 (or %reduce_to_rank_1_memref if you want to keep memrefs).

Thanks again and please send fixes if you identify more issues :)

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM % comments

Comment on lines 1528 to 1544
// CHECK-LABEL: func @reduce_1d_memref(
// CHECK-SAME: %[[A:.*]]: memref<32xf32>
// CHECK-SAME: %[[B:.*]]: memref<1xf32>
func.func @reduce_1d_memref(%arg0: memref<32xf32>, %arg1: memref<1xf32>) {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index

// CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]]
// CHECK-SAME: : memref<32xf32>, vector<32xf32>
// CHECK: %[[init:.*]] = vector.transfer_read %[[B]][%[[C0]]]
// CHECK-SAME: : memref<1xf32>, vector<f32>
// CHECK: %[[init_scl:.*]] = vector.extract %[[init]][]
// CHECK-SAME: : f32 from vector<f32>
// CHECK: %[[red:.*]] = vector.multi_reduction <add>, %[[r]], %[[init_scl]] [0]
// CHECK-SAME: : vector<32xf32> to f32
// CHECK: %[[red_v1:.*]] = vector.broadcast %[[red]] : f32 to vector<f32>
// CHECK: vector.transfer_write %[[red_v1]], %[[B]][%[[C0]]]
// CHECK-SAME: : vector<f32>, memref<1xf32>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use UPPERCASE for FileCheck variables. I would also appreciate if you could update the previous test.

Also, could you use more descriptive variable names? Specifically, %arg0 -> %src and %arg1 -> %res. Please make the FileCheck variables match this. A more comprehensive guideline is available here:

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

…rank-1 shaped-types

Vectorization of a 1-d reduction where the output variable is a rank-1 shaped-type
can generate an invalid `vector.transfer_write` with no indices for the corresponding
input, e.g.:

  vector.transfer_write"(%vec, %t) <{...}> : (vector<f32>, tensor<1xf32>) -> ()

This patch solves the problem by providing the expected amount of indices (i.e.
matching the rank of the shaped-type).
@simpel01 simpel01 force-pushed the vectorization-reduction-fix branch from b67ae6a to 6e25c58 Compare November 17, 2025 13:17
@RoboTux RoboTux merged commit 71e3de8 into llvm:main Nov 19, 2025
10 checks passed
@simpel01 simpel01 deleted the vectorization-reduction-fix branch November 19, 2025 18:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants