Skip to content

[mlir][vector] Teach TransferOptimization to forward masked stores #87794

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 16, 2024

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Apr 5, 2024

This only handles one case (that's fairly common in practice*), storing a masked constant splat, then reloading again with the same mask and a padding value that matches the splat.

  • For SVE/SME (without peeling) this occurs when you have a linalg.fill preceding a linalg.matmul.

@llvmbot
Copy link
Member

llvmbot commented Apr 5, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Benjamin Maxwell (MacDue)

Changes

This only handles one case (that's fairly common in practice*), storing a masked constant splat, then reloading again with the same mask and a padding value that matches the splat.

  • For SVE/SME (without peeling) this occurs when you have a linalg.fill preceding a linalg.matmul.

Full diff: https://github.com/llvm/llvm-project/pull/87794.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+28-3)
  • (modified) mlir/test/Dialect/Vector/vector-transferop-opt.mlir (+51-1)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3e6425879cc67f..1dacafe3d7fabc 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -170,12 +170,37 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
       shapedType.getContext());
 }
 
+static bool couldBeSameValueWithMasking(vector::TransferWriteOp defWrite,
+                                        vector::TransferReadOp read) {
+  if (!defWrite.getMask() && !read.getMask())
+    return true; // Success: No masks (values will be the same).
+  // Check for constant splats. These will be the same value if the read is
+  // masked (and padded with the splat value), and the write is unmasked or has
+  // the same mask.
+  bool couldBeSameSplatValue =
+      read.getMask() &&
+      (!defWrite.getMask() || defWrite.getMask() == read.getMask());
+  if (!couldBeSameSplatValue)
+    return false;
+  DenseElementsAttr splatAttr;
+  if (!matchPattern(defWrite.getVector(),
+                    m_Constant<DenseElementsAttr>(&splatAttr)) ||
+      !splatAttr.isSplat()) {
+    return false;
+  }
+  Attribute padAttr;
+  if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
+    return false;
+  return padAttr == splatAttr.getSplatValue<Attribute>();
+}
+
 bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
                                      vector::TransferReadOp read) {
-  return !defWrite.hasOutOfBoundsDim() && !defWrite.getMask() &&
-         !read.getMask() && defWrite.getIndices() == read.getIndices() &&
+  return !defWrite.hasOutOfBoundsDim() &&
+         defWrite.getIndices() == read.getIndices() &&
          defWrite.getVectorType() == read.getVectorType() &&
-         defWrite.getPermutationMap() == read.getPermutationMap();
+         defWrite.getPermutationMap() == read.getPermutationMap() &&
+         couldBeSameValueWithMasking(defWrite, read);
 }
 
 bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
diff --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
index 13957af014b89e..2c8f105cd5c14b 100644
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -222,7 +222,7 @@ func.func @forward_dead_store_negative(%arg0: i1, %arg1 : memref<4x4xf32>,
 // `vector.transfer_write` would not be safe:
 //         %1 = vector.transfer_read %subview
 //         vector.transfer_write %1, %alloca
-//         vector.transfer_write %vec, %collapse_shape 
+//         vector.transfer_write %vec, %collapse_shape
 //         %2 = vector.transfer_read %alloca
 //         vector.transfer_write %1, %subview
 // Indeed, %alloca and %collapse_shape alias and hence %2 != %1. Instead, the
@@ -360,3 +360,53 @@ func.func @forward_dead_store_dynamic_non_overlap_trailing_dim(
   vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
   return
 }
+
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking
+//   CHECK-NOT:   vector.transfer_write
+//   CHECK-NOT:   vector.transfer_read
+//       CHECK:   scf.for
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func.func @forward_dead_constant_splat_store_with_masking(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+  %cst = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+  %cst_f32 = arith.constant 0.0 : f32
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c512 = arith.constant 512 : index
+  %vscale = vector.vscale
+  vector.transfer_write %cst, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+  %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+    %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+    scf.yield %1 : vector<[8]x[8]xf32>
+  }
+  vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  return
+}
+
+// Negative test, the padding does not match the constant splat, so we can't
+// forward the store.
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_read
+//       CHECK:   scf.for
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func.func @forward_dead_constant_splat_store_with_masking_negative(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+  %cst = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+  %cst_f32 = arith.constant 1.0 : f32
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c512 = arith.constant 512 : index
+  %vscale = vector.vscale
+  vector.transfer_write %cst, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%c0, %c0], %cst_f32, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+  %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+    %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+    scf.yield %1 : vector<[8]x[8]xf32>
+  }
+  vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  return
+}

@MacDue MacDue requested a review from dcaballe April 5, 2024 15:33
Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One minor comment, otherwise LGTM cheers

@MacDue MacDue force-pushed the transfer_opt_part_1 branch from 5790171 to e4435c3 Compare April 15, 2024 13:34
@MacDue MacDue force-pushed the transfer_opt_part_1 branch from e4435c3 to 15387f6 Compare April 29, 2024 15:23
@MacDue MacDue force-pushed the transfer_opt_part_1 branch from 15387f6 to 884e977 Compare May 13, 2024 13:19
MacDue added 4 commits May 15, 2024 13:48
This only handles one case (that's fairly common in practice*), storing
a masked constant splat, then reloading again with the same mask and a
padding value that matches the splat.

* For SVE/SME (without peeling) this occurs when you have a
`linalg.fill` preceding a `linalg.matmul`.
@MacDue MacDue force-pushed the transfer_opt_part_1 branch from 884e977 to 30267ed Compare May 15, 2024 14:55
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@MacDue MacDue merged commit ca02f36 into llvm:main May 16, 2024
@MacDue MacDue deleted the transfer_opt_part_1 branch May 16, 2024 08:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants