[mlir][vector] Convert vector.transfer_read to scalar load and broadcast #159520
Conversation
When vector.transfer_read reads from a 0-d value, we can convert it to a memref.load of the 0-d value followed by a broadcast to the target vector type. This avoids generating vector operations that break the requirements of convertVectorToMMAOps, whose patterns expect every vector.transfer_read to have a 2-D vector type.
Instead of
%s0 = vector.transfer_read %base[] : memref<dtype> to vector<dtype>
%s1 = vector.broadcast %s0 : vector<dtype> to vector<d0...d1 x dtype>
use
%s0 = memref.load %base[] : memref<dtype>
%s1 = vector.broadcast %s0 : dtype to vector<d0...d1 x dtype>
@llvm/pr-subscribers-mlir
Author: Hsiangkai Wang (Hsiangkai)
Changes
If we use vector.transfer_read to read from a 0-d value, we can convert it to a memref.load of the 0-d value and then broadcast the value to the target vector type. This avoids generating vector operations that break the requirements of convertVectorToMMAOps; the patterns in convertVectorToMMAOps expect every vector.transfer_read to have a 2-D vector type.
Instead of
%s0 = vector.transfer_read %base[] : memref<dtype> to vector<dtype>
%s1 = vector.broadcast %s0 : vector<dtype> to vector<d0...d1 x dtype>
Use
%s0 = memref.load %base[] : memref<dtype>
%s1 = vector.broadcast %s0 : dtype to vector<d0...d1 x dtype>
Full diff: https://github.com/llvm/llvm-project/pull/159520.diff
2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 2cf8f0beaa4de..4f62b6a7f2fde 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -360,17 +360,35 @@ struct TransferOpReduceRank
SmallVector<bool> newScalableDims(
originalVecType.getScalableDims().take_back(reducedShapeRank));
- VectorType newReadType = VectorType::get(
- newShape, originalVecType.getElementType(), newScalableDims);
- ArrayAttr newInBoundsAttr =
- op.getInBounds()
- ? rewriter.getArrayAttr(
- op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
- : ArrayAttr();
- Value newRead = vector::TransferReadOp::create(
- rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(),
- AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
- newInBoundsAttr);
+ Value newRead;
+ if (newShape.size() == 0 && newScalableDims.size() == 0) {
+ // Handle the scalar case.
+ // Convert
+ // %val = vector.transfer_read %base[] : memref<dtype> to
+ // vector<d0 x d1 x dtype>
+ // into
+ // %scalar = memref.load %base[] : memref<dtype>
+ // %val = vector.broadcast %scalar : dtype to vector<d0 x d1 x dtype>
+ Type baseType = op.getBase().getType();
+ if (isa<MemRefType>(baseType)) {
+ newRead = memref::LoadOp::create(rewriter, op.getLoc(), op.getBase(),
+ op.getIndices());
+ }
+ }
+
+ if (!newRead) {
+ VectorType newReadType = VectorType::get(
+ newShape, originalVecType.getElementType(), newScalableDims);
+ ArrayAttr newInBoundsAttr =
+ op.getInBounds()
+ ? rewriter.getArrayAttr(
+ op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
+ : ArrayAttr();
+ newRead = vector::TransferReadOp::create(
+ rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(),
+ AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
+ newInBoundsAttr);
+ }
return vector::BroadcastOp::create(rewriter, op.getLoc(), originalVecType,
newRead)
.getVector();
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index 3ae18835c8367..16104aa76e692 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -388,6 +388,24 @@ func.func @xfer_read_minor_identitiy_bcast_dims(
return %res : vector<8x4x2x3xf32>
}
+// CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_scalar
+// CHECK-SAME: %[[MEM:.*]]: memref<f32>) -> vector<8x4x2x3xf32> {
+// CHECK: %[[LOAD:.*]] = memref.load %[[MEM]][] : memref<f32>
+// CHECK: %[[BC:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<8x4x2x3xf32>
+// CHECK: return %[[BC]] : vector<8x4x2x3xf32>
+func.func @xfer_read_minor_identitiy_bcast_scalar(
+ %mem: memref<f32>) -> vector<8x4x2x3xf32> {
+
+ %pad = arith.constant 0.000000e+00 : f32
+
+ %res = vector.transfer_read %mem[], %pad {
+ in_bounds = [true, true, true, true],
+ permutation_map = affine_map<() -> (0, 0, 0, 0)>
+ } : memref<f32>, vector<8x4x2x3xf32>
+
+ return %res : vector<8x4x2x3xf32>
+}
+
// CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims_scalable
// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>, %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> {
// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]]{{.*}} permutation_map = #[[$MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
Makes sense to me, thanks!
I've left some small suggestions, but since I am away next week, approving as is. No need to wait for me to take another look.
Note, rank-0 vectors have been a bit contentious. Once you address my comments, please give it at least 24hrs before landing (should other reviewers wish to chime in).
Thanks!
      //   %scalar = memref.load %base[] : memref<dtype>
      //   %val = vector.broadcast %scalar : dtype to vector<d0 x d1 x dtype>
      Type baseType = op.getBase().getType();
      if (isa<MemRefType>(baseType)) {
Why not support this for Tensors as well?
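For reference, a minimal sketch of what extending the scalar path to ranked tensors might look like. This is illustrative only, not part of this patch, and it assumes tensor::ExtractOp exposes a create overload analogous to memref::LoadOp's and that the Tensor dialect include/dependency is added:

      // Sketch only: handle both 0-d memref and 0-d ranked tensor bases.
      Type baseType = op.getBase().getType();
      if (isa<MemRefType>(baseType)) {
        // 0-d memref: load the scalar element directly.
        newRead = memref::LoadOp::create(rewriter, op.getLoc(), op.getBase(),
                                         op.getIndices());
      } else if (isa<RankedTensorType>(baseType)) {
        // 0-d tensor: extract the scalar element instead (assumes
        // mlir/Dialect/Tensor/IR/Tensor.h is included).
        newRead = tensor::ExtractOp::create(rewriter, op.getLoc(), op.getBase(),
                                            op.getIndices());
      }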
          AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
          newInBoundsAttr);
    Value newRead;
    if (newShape.size() == 0 && newScalableDims.size() == 0) {
There are no rank-0 scalable vectors, so you can skip the 2nd check.
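A minimal sketch of the simplified guard the reviewer suggests, assuming newShape keeps its SmallVector-style empty() accessor:

    // Rank-0 vectors cannot be scalable, so checking the shape alone suffices.
    if (newShape.empty()) {
      // Handle the scalar case (memref.load + vector.broadcast).
    }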
                                         op.getIndices());
      }
    }
Suggestion: add a "// Handle the non-scalar case." comment here (and, above the if (newShape.size() == 0 && newScalableDims.size() == 0) check, a "// Handle the scalar case." comment).
I'm pretty sure we removed this pattern earlier. This is not the right place to do it. If you want to do it, you should add a pattern from vector.load -> memref.load for scalars.
We should be lowering this to a vector.load on scalars plus a vector.broadcast; if we aren't, that's a bug in itself.
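A rough sketch of the kind of pattern being suggested here: rewriting a rank-0 vector.load into memref.load plus vector.broadcast. The pattern name is hypothetical and the snippet is an assumption-laden illustration, not code from this PR:

// Hypothetical pattern: scalarize a rank-0 vector.load into memref.load.
struct ScalarizeRank0VectorLoad final
    : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = cast<VectorType>(loadOp.getResult().getType());
    // Only rewrite loads that produce a rank-0 vector, i.e. a single scalar.
    if (vecType.getRank() != 0)
      return failure();

    // Load the scalar element directly from the memref.
    Value scalar = memref::LoadOp::create(
        rewriter, loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
    // Re-wrap the scalar as a rank-0 vector so the result type is preserved.
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType, scalar);
    return success();
  }
};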
Thanks for your review. I will abandon this patch and revisit the pipeline.
If we use vector.transfer_read to read from a 0-d value, we can convert it to a memref.load of the 0-d value and then broadcast the value to the target vector type.
This avoids generating vector operations that break the requirements of convertVectorToMMAOps; the patterns in convertVectorToMMAOps expect every vector.transfer_read to have a 2-D vector type.
Instead of
%s0 = vector.transfer_read %base[] : memref<dtype> to vector<dtype>
%s1 = vector.broadcast %s0 : vector<dtype> to vector<d0...d1 x dtype>
Use
%s0 = memref.load %base[] : memref<dtype>
%s1 = vector.broadcast %s0 : dtype to vector<d0...d1 x dtype>