Skip to content

Conversation

Hsiangkai
Copy link
Contributor

If we use vector.transfer_read to read from a 0-d value, we can convert it to memref.load from the 0-d value then broadcast the value to the target vector type.

It can avoid generating vector operations breaking the requirements of convertVectorToMMAOps. The patterns in convertVectorToMMAOps expect all vector.transfer_read have 2-D vector types.

Instead of
%s0 = vector.transfer_read %base[] : memref to vector
%s1 = vector.broadcast %s0 : vector to vector<d0...d1 x dtype>

Use
%s0 = memref.load %base[] : memref
%s1 = vector.broadcast %s0 : dtype to vector<d0...d1 x dtype>

If we use vector.transfer_read to read from a 0-d value, we can convert it
to memref.load from the 0-d value then broadcast the value to the target
vector type.

It can avoid generating vector operations breaking the requirements of
convertVectorToMMAOps. The patterns in convertVectorToMMAOps expect all
vector.transfer_read have 2-D vector types.

Instead of
  %s0 = vector.transfer_read %base[] : memref<dtype> to vector<dtype>
  %s1 = vector.broadcast %s0 : vector<dtype> to vector<d0...d1 x dtype>

Use
  %s0 = memref.load %base[] : memref<dtype>
  %s1 = vector.broadcast %s0 : dtype to vector<d0...d1 x dtype>
@llvmbot
Copy link
Member

llvmbot commented Sep 18, 2025

@llvm/pr-subscribers-mlir

Author: Hsiangkai Wang (Hsiangkai)

Changes

If we use vector.transfer_read to read from a 0-d value, we can convert it to memref.load from the 0-d value then broadcast the value to the target vector type.

It can avoid generating vector operations breaking the requirements of convertVectorToMMAOps. The patterns in convertVectorToMMAOps expect all vector.transfer_read have 2-D vector types.

Instead of
%s0 = vector.transfer_read %base[] : memref<dtype> to vector<dtype>
%s1 = vector.broadcast %s0 : vector<dtype> to vector<d0...d1 x dtype>

Use
%s0 = memref.load %base[] : memref<dtype>
%s1 = vector.broadcast %s0 : dtype to vector<d0...d1 x dtype>


Full diff: https://github.com/llvm/llvm-project/pull/159520.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp (+29-11)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir (+18)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 2cf8f0beaa4de..4f62b6a7f2fde 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -360,17 +360,35 @@ struct TransferOpReduceRank
     SmallVector<bool> newScalableDims(
         originalVecType.getScalableDims().take_back(reducedShapeRank));
 
-    VectorType newReadType = VectorType::get(
-        newShape, originalVecType.getElementType(), newScalableDims);
-    ArrayAttr newInBoundsAttr =
-        op.getInBounds()
-            ? rewriter.getArrayAttr(
-                  op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
-            : ArrayAttr();
-    Value newRead = vector::TransferReadOp::create(
-        rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(),
-        AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
-        newInBoundsAttr);
+    Value newRead;
+    if (newShape.size() == 0 && newScalableDims.size() == 0) {
+      // Handle the scalar case.
+      // Convert
+      //   %val = vector.transfer_read %base[] : memref<dtype> to
+      //                                         vector<d0 x d1 x dtype>
+      // into
+      //   %scalar = memref.load %base[] : memref<dtype>
+      //   %val = vector.broadcast %scalar : dtype to vector<d0 x d1 x dtype>
+      Type baseType = op.getBase().getType();
+      if (isa<MemRefType>(baseType)) {
+        newRead = memref::LoadOp::create(rewriter, op.getLoc(), op.getBase(),
+                                         op.getIndices());
+      }
+    }
+
+    if (!newRead) {
+      VectorType newReadType = VectorType::get(
+          newShape, originalVecType.getElementType(), newScalableDims);
+      ArrayAttr newInBoundsAttr =
+          op.getInBounds()
+              ? rewriter.getArrayAttr(
+                    op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
+              : ArrayAttr();
+      newRead = vector::TransferReadOp::create(
+          rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(),
+          AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
+          newInBoundsAttr);
+    }
     return vector::BroadcastOp::create(rewriter, op.getLoc(), originalVecType,
                                        newRead)
         .getVector();
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index 3ae18835c8367..16104aa76e692 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -388,6 +388,24 @@ func.func @xfer_read_minor_identitiy_bcast_dims(
   return %res : vector<8x4x2x3xf32>
 }
 
+// CHECK-LABEL:   func.func @xfer_read_minor_identitiy_bcast_scalar
+//  CHECK-SAME:     %[[MEM:.*]]: memref<f32>) -> vector<8x4x2x3xf32> {
+//       CHECK:     %[[LOAD:.*]] = memref.load %[[MEM]][] : memref<f32>
+//       CHECK:     %[[BC:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<8x4x2x3xf32>
+//       CHECK:     return %[[BC]] : vector<8x4x2x3xf32>
+func.func @xfer_read_minor_identitiy_bcast_scalar(
+    %mem: memref<f32>) -> vector<8x4x2x3xf32> {
+
+  %pad = arith.constant 0.000000e+00 : f32
+
+  %res = vector.transfer_read %mem[], %pad {
+    in_bounds = [true, true, true, true],
+    permutation_map = affine_map<() -> (0, 0, 0, 0)>
+  } : memref<f32>, vector<8x4x2x3xf32>
+
+  return %res : vector<8x4x2x3xf32>
+}
+
 // CHECK-LABEL:   func.func @xfer_read_minor_identitiy_bcast_dims_scalable
 //  CHECK-SAME:     %[[MEM:.*]]: memref<?x?x?x?xf32>, %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> {
 //       CHECK:     %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]]{{.*}} permutation_map = #[[$MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>

@llvmbot
Copy link
Member

llvmbot commented Sep 18, 2025

@llvm/pr-subscribers-mlir-vector

Author: Hsiangkai Wang (Hsiangkai)

Changes

If we use vector.transfer_read to read from a 0-d value, we can convert it to memref.load from the 0-d value then broadcast the value to the target vector type.

It can avoid generating vector operations breaking the requirements of convertVectorToMMAOps. The patterns in convertVectorToMMAOps expect all vector.transfer_read have 2-D vector types.

Instead of
%s0 = vector.transfer_read %base[] : memref<dtype> to vector<dtype>
%s1 = vector.broadcast %s0 : vector<dtype> to vector<d0...d1 x dtype>

Use
%s0 = memref.load %base[] : memref<dtype>
%s1 = vector.broadcast %s0 : dtype to vector<d0...d1 x dtype>


Full diff: https://github.com/llvm/llvm-project/pull/159520.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp (+29-11)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir (+18)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 2cf8f0beaa4de..4f62b6a7f2fde 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -360,17 +360,35 @@ struct TransferOpReduceRank
     SmallVector<bool> newScalableDims(
         originalVecType.getScalableDims().take_back(reducedShapeRank));
 
-    VectorType newReadType = VectorType::get(
-        newShape, originalVecType.getElementType(), newScalableDims);
-    ArrayAttr newInBoundsAttr =
-        op.getInBounds()
-            ? rewriter.getArrayAttr(
-                  op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
-            : ArrayAttr();
-    Value newRead = vector::TransferReadOp::create(
-        rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(),
-        AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
-        newInBoundsAttr);
+    Value newRead;
+    if (newShape.size() == 0 && newScalableDims.size() == 0) {
+      // Handle the scalar case.
+      // Convert
+      //   %val = vector.transfer_read %base[] : memref<dtype> to
+      //                                         vector<d0 x d1 x dtype>
+      // into
+      //   %scalar = memref.load %base[] : memref<dtype>
+      //   %val = vector.broadcast %scalar : dtype to vector<d0 x d1 x dtype>
+      Type baseType = op.getBase().getType();
+      if (isa<MemRefType>(baseType)) {
+        newRead = memref::LoadOp::create(rewriter, op.getLoc(), op.getBase(),
+                                         op.getIndices());
+      }
+    }
+
+    if (!newRead) {
+      VectorType newReadType = VectorType::get(
+          newShape, originalVecType.getElementType(), newScalableDims);
+      ArrayAttr newInBoundsAttr =
+          op.getInBounds()
+              ? rewriter.getArrayAttr(
+                    op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
+              : ArrayAttr();
+      newRead = vector::TransferReadOp::create(
+          rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(),
+          AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
+          newInBoundsAttr);
+    }
     return vector::BroadcastOp::create(rewriter, op.getLoc(), originalVecType,
                                        newRead)
         .getVector();
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index 3ae18835c8367..16104aa76e692 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -388,6 +388,24 @@ func.func @xfer_read_minor_identitiy_bcast_dims(
   return %res : vector<8x4x2x3xf32>
 }
 
+// CHECK-LABEL:   func.func @xfer_read_minor_identitiy_bcast_scalar
+//  CHECK-SAME:     %[[MEM:.*]]: memref<f32>) -> vector<8x4x2x3xf32> {
+//       CHECK:     %[[LOAD:.*]] = memref.load %[[MEM]][] : memref<f32>
+//       CHECK:     %[[BC:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<8x4x2x3xf32>
+//       CHECK:     return %[[BC]] : vector<8x4x2x3xf32>
+func.func @xfer_read_minor_identitiy_bcast_scalar(
+    %mem: memref<f32>) -> vector<8x4x2x3xf32> {
+
+  %pad = arith.constant 0.000000e+00 : f32
+
+  %res = vector.transfer_read %mem[], %pad {
+    in_bounds = [true, true, true, true],
+    permutation_map = affine_map<() -> (0, 0, 0, 0)>
+  } : memref<f32>, vector<8x4x2x3xf32>
+
+  return %res : vector<8x4x2x3xf32>
+}
+
 // CHECK-LABEL:   func.func @xfer_read_minor_identitiy_bcast_dims_scalable
 //  CHECK-SAME:     %[[MEM:.*]]: memref<?x?x?x?xf32>, %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> {
 //       CHECK:     %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]]{{.*}} permutation_map = #[[$MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense to me, thanks!

I've left some small suggestions, but since I am away next week, approving as is. No need to wait for me to take another look.

Note, rank-0 vectors have been a bit contentious. Once you address my comments, please give it at least 24hrs before landing (should other reviewers wish to chime in).

Thanks!

// %scalar = memref.load %base[] : memref<dtype>
// %val = vector.broadcast %scalar : dtype to vector<d0 x d1 x dtype>
Type baseType = op.getBase().getType();
if (isa<MemRefType>(baseType)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not support this for Tensors as well?

AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
newInBoundsAttr);
Value newRead;
if (newShape.size() == 0 && newScalableDims.size() == 0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are no rank-0 scalable vectors, so you can skip the 2nd check.

op.getIndices());
}
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Handle the non-scalar case.

(and then, above if (newShape.size() == 0 && newScalableDims.size() == 0) {, // Handle the scalar case).

Copy link
Member

@Groverkss Groverkss left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm pretty sure we removed this pattern earlier. This is not the right place to do it. If you want to do it, you should add a pattern from vector.load -> memref.load for scalars.

We should be lowering this to a vector.load on scalars + a vector.broadcast, if we aren't, that's a bug in itself.

@Hsiangkai
Copy link
Contributor Author

I'm pretty sure we removed this pattern earlier. This is not the right place to do it. If you want to do it, you should add a pattern from vector.load -> memref.load for scalars.

We should be lowering this to a vector.load on scalars + a vector.broadcast, if we aren't, that's a bug in itself.

Thanks for your review. I will abandon this patch and revisit the pipeline.

@Hsiangkai Hsiangkai closed this Sep 26, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants