[mlir][vector] Teach TransferOptimization to forward masked stores #87794
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector

Author: Benjamin Maxwell (MacDue)

Changes: This only handles one case (that's fairly common in practice*): storing a masked constant splat, then reloading it with the same mask and a padding value that matches the splat.

* For SVE/SME (without peeling) this occurs when you have a `linalg.fill` preceding a `linalg.matmul`.
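To illustrate the idea, here is a minimal sketch (hypothetical IR, not taken from the patch; the `@example` name, shapes, and value names are made up). A splat of 0.0 is stored under a mask and then read back with the same mask and a padding value of 0.0: masked-off lanes of the read yield the padding value and masked-on lanes yield the stored splat, so every lane equals the splat and the read can be forwarded to the stored vector. The write itself only becomes dead if a later write overwrites it, as in the added tests.

```mlir
// Hypothetical example (names and shapes are illustrative).
func.func @example(%buffer : memref<?xf32>, %mask : vector<4xi1>) -> vector<4xf32> {
  %c0 = arith.constant 0 : index
  %splat = arith.constant dense<0.0> : vector<4xf32>
  %pad = arith.constant 0.0 : f32
  // Store the constant splat under %mask.
  vector.transfer_write %splat, %buffer[%c0], %mask {in_bounds = [true]} : vector<4xf32>, memref<?xf32>
  // Reload with the same mask; the padding (0.0) matches the splat, so every
  // lane of %0 is 0.0 and the read can be replaced by %splat.
  %0 = vector.transfer_read %buffer[%c0], %pad, %mask {in_bounds = [true]} : memref<?xf32>, vector<4xf32>
  return %0 : vector<4xf32>
}
```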
Full diff: https://github.com/llvm/llvm-project/pull/87794.diff (2 files affected)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3e6425879cc67f..1dacafe3d7fabc 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -170,12 +170,37 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
       shapedType.getContext());
 }
 
+static bool couldBeSameValueWithMasking(vector::TransferWriteOp defWrite,
+                                        vector::TransferReadOp read) {
+  if (!defWrite.getMask() && !read.getMask())
+    return true; // Success: No masks (values will be the same).
+  // Check for constant splats. These will be the same value if the read is
+  // masked (and padded with the splat value), and the write is unmasked or has
+  // the same mask.
+  bool couldBeSameSplatValue =
+      read.getMask() &&
+      (!defWrite.getMask() || defWrite.getMask() == read.getMask());
+  if (!couldBeSameSplatValue)
+    return false;
+  DenseElementsAttr splatAttr;
+  if (!matchPattern(defWrite.getVector(),
+                    m_Constant<DenseElementsAttr>(&splatAttr)) ||
+      !splatAttr.isSplat()) {
+    return false;
+  }
+  Attribute padAttr;
+  if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
+    return false;
+  return padAttr == splatAttr.getSplatValue<Attribute>();
+}
+
 bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
                                      vector::TransferReadOp read) {
-  return !defWrite.hasOutOfBoundsDim() && !defWrite.getMask() &&
-         !read.getMask() && defWrite.getIndices() == read.getIndices() &&
+  return !defWrite.hasOutOfBoundsDim() &&
+         defWrite.getIndices() == read.getIndices() &&
          defWrite.getVectorType() == read.getVectorType() &&
-         defWrite.getPermutationMap() == read.getPermutationMap();
+         defWrite.getPermutationMap() == read.getPermutationMap() &&
+         couldBeSameValueWithMasking(defWrite, read);
 }
 
 bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
diff --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
index 13957af014b89e..2c8f105cd5c14b 100644
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -222,7 +222,7 @@ func.func @forward_dead_store_negative(%arg0: i1, %arg1 : memref<4x4xf32>,
// `vector.transfer_write` would not be safe:
// %1 = vector.transfer_read %subview
// vector.transfer_write %1, %alloca
-// vector.transfer_write %vec, %collapse_shape
+// vector.transfer_write %vec, %collapse_shape
// %2 = vector.transfer_read %alloca
// vector.transfer_write %1, %subview
// Indeed, %alloca and %collapse_shape alias and hence %2 != %1. Instead, the
@@ -360,3 +360,53 @@ func.func @forward_dead_store_dynamic_non_overlap_trailing_dim(
vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
return
}
+
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking
+// CHECK-NOT: vector.transfer_write
+// CHECK-NOT: vector.transfer_read
+// CHECK: scf.for
+// CHECK: }
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @forward_dead_constant_splat_store_with_masking(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+  %cst = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+  %cst_f32 = arith.constant 0.0 : f32
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c512 = arith.constant 512 : index
+  %vscale = vector.vscale
+  vector.transfer_write %cst, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+  %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+    %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+    scf.yield %1 : vector<[8]x[8]xf32>
+  }
+  vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  return
+}
+
+// Negative test, the padding does not match the constant splat, so we can't
+// forward the store.
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+// CHECK: scf.for
+// CHECK: }
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @forward_dead_constant_splat_store_with_masking_negative(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+  %cst = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+  %cst_f32 = arith.constant 1.0 : f32
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c512 = arith.constant 512 : index
+  %vscale = vector.vscale
+  vector.transfer_write %cst, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+  %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+    %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+    scf.yield %1 : vector<[8]x[8]xf32>
+  }
+  vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  return
+}
One minor comment, otherwise LGTM cheers
LGTM, thanks!