Skip to content

Conversation

@arun-thmn
Copy link
Contributor

@arun-thmn arun-thmn commented Nov 24, 2025

Adds a pattern that sinks vector producer ops (vector.load and vector.transfer_read) forward in a block to their earliest legal use, reducing live ranges and improving scheduling opportunities.

The lowering pattern: batch_reduce.matmul (input) -> register-tiling(M, N) -> Vectorization (to vector.contract) -> unroll vector.contract (unit dims) -> hoisting transformation (move C loads/store outside batch/k loop) -> sink vector producers -> apply licm, canonicalization, and bufferize -> vector.contract to fma -> sink vector producers.

@github-actions
Copy link

github-actions bot commented Nov 24, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@arun-thmn
Copy link
Contributor Author

@rengolin @adam-smnk @rolfmorel @shahidact Please review.

@llvmbot
Copy link
Member

llvmbot commented Nov 24, 2025

@llvm/pr-subscribers-mlir

Author: Arun Thangamani (arun-thmn)

Changes

Adds a pattern that sinks vector producer ops (vector.load and vector.transfer_read) forward in a block to their earliest legal use, reducing live ranges and improving scheduling opportunities.

The lowering pattern: batch_reduce.matmul (input) -> register-tiling(M, N) -> Vectorization (to vector.contract) -> unroll vector.contract (unit dims) -> hoisting transformation (move C loads/store outside batch/k loop) -> sink vector producers -> apply licm, canonicalization, and bufferize -> vector.contract to fma -> sink vector producers.


Full diff: https://github.com/llvm/llvm-project/pull/169333.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td (+11)
  • (modified) mlir/include/mlir/Dialect/X86Vector/Transforms.h (+4)
  • (modified) mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp (+5)
  • (modified) mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/X86Vector/Transforms/SinkVectorProducerOps.cpp (+93)
  • (added) mlir/test/Dialect/X86Vector/sink-vector-producer-ops.mlir (+199)
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
index 3c5294ff14fc7..12ba5e9f11141 100644
--- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -38,6 +38,17 @@ def ApplyVectorContractToPackedTypeDotProductPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplySinkVectorProducerOpsPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.x86vector.sink_vector_producer_ops",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collect patterns to sink vector producer operations forward in a block to 
+         place them immediately before their first use.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 
 #endif // X86VECTOR_TRANSFORM_OPS
 
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index fc46dff63c2b7..b9c9054f57890 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -91,6 +91,10 @@ void populateVectorContractToFMAPatterns(RewritePatternSet &patterns);
 void populateVectorContractToPackedTypeDotProductPatterns(
     RewritePatternSet &patterns);
 
+// Performs forward scheduling of vector producer ops to minimize their live
+// range by placing them at their earliest legal use site
+void populateSinkVectorProducerOpsPatterns(RewritePatternSet &patterns);
+
 //===----------------------------------------------------------------------===//
 /// Helpers extracted from:
 ///   - clang/lib/Headers/avxintrin.h
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
index 95db208207672..25772f2aa57f4 100644
--- a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -32,6 +32,11 @@ void mlir::transform::ApplyVectorContractToPackedTypeDotProductPatternsOp::
   x86vector::populateVectorContractToPackedTypeDotProductPatterns(patterns);
 }
 
+void mlir::transform::ApplySinkVectorProducerOpsPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  x86vector::populateSinkVectorProducerOpsPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
index 2cab50fb591c4..cc4d3cac0f7ea 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRX86VectorTransforms
   LegalizeForLLVMExport.cpp
   VectorContractToFMA.cpp
   VectorContractToPackedTypeDotProduct.cpp
+  SinkVectorProducerOps.cpp
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/SinkVectorProducerOps.cpp b/mlir/lib/Dialect/X86Vector/Transforms/SinkVectorProducerOps.cpp
new file mode 100644
index 0000000000000..b31636958e158
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/SinkVectorProducerOps.cpp
@@ -0,0 +1,93 @@
+//===- SinkVectorProducerOps.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+/// Sink vector producers forward to reduce live ranges.
+/// This pattern applies to ops such as vector.load and vector.transfer_read.
+template <typename producerOp>
+struct SinkVectorProducerOps final : public OpRewritePattern<producerOp> {
+  using OpRewritePattern<producerOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(producerOp op,
+                                PatternRewriter &rewriter) const override {
+
+    // Collect all users of the producer op.
+    llvm::SmallVector<Operation *> users;
+    for (OpResult result : op->getResults())
+      for (Operation *user : result.getUsers())
+        users.push_back(user);
+
+    // If there are no users, nothing to sink.
+    if (users.empty())
+      return failure();
+
+    // If the next op is already a user, do not move.
+    Operation *nextOp = op->getNextNode();
+    if (llvm::is_contained(users, nextOp))
+      return failure();
+
+    // Prevent pathological looping:
+    // If the next op produces values used by any of op's users, don't move.
+    llvm::SmallVector<Operation *> nextOpUsers;
+    for (OpResult result : nextOp->getResults())
+      for (Operation *user : result.getUsers())
+        nextOpUsers.push_back(user);
+
+    Operation *nextFirstUser = nextOp->getNextNode();
+    while (nextFirstUser) {
+      if (llvm::is_contained(nextOpUsers, nextFirstUser))
+        break;
+
+      nextFirstUser = nextFirstUser->getNextNode();
+    }
+
+    if (llvm::is_contained(users, nextFirstUser))
+      return failure();
+
+    // Find the nearest user by scanning forward.
+    while (nextOp) {
+      if (llvm::is_contained(users, nextOp))
+        break;
+
+      nextOp = nextOp->getNextNode();
+    }
+
+    if (!nextOp)
+      return failure();
+
+    // // Both ops must be in the same block to safely move.
+    if (op->getBlock() != nextOp->getBlock())
+      return failure();
+
+    // Move producer immediately before its first user.
+    op->moveBefore(nextOp);
+
+    return success();
+  }
+};
+
+void x86vector::populateSinkVectorProducerOpsPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<SinkVectorProducerOps<vector::TransferReadOp>,
+               SinkVectorProducerOps<vector::LoadOp>>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/X86Vector/sink-vector-producer-ops.mlir b/mlir/test/Dialect/X86Vector/sink-vector-producer-ops.mlir
new file mode 100644
index 0000000000000..11af315e69e66
--- /dev/null
+++ b/mlir/test/Dialect/X86Vector/sink-vector-producer-ops.mlir
@@ -0,0 +1,199 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+func.func @sink_vector_loads(%arg0: memref<16x16xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %0 = vector.load %arg0[%c0, %c0] : memref<16x16xf32>, vector<8xf32>
+  %1 = vector.load %arg0[%c0, %c8] : memref<16x16xf32>, vector<8xf32>
+  %2 = vector.load %arg0[%c8, %c0] : memref<16x16xf32>, vector<8xf32>
+  %3 = vector.load %arg0[%c8, %c8] : memref<16x16xf32>, vector<8xf32>
+  %4 = vector.fma %0, %1, %arg1 : vector<8xf32>
+  %5 = vector.fma %2, %3, %4 : vector<8xf32>
+  return %5 : vector<8xf32>
+}
+
+// CHECK-LABEL: @sink_vector_loads
+// CHECK: vector.load
+// CHECK-NEXT: vector.load
+// CHECK-NEXT: vector.fma
+// CHECK-NEXT: vector.load
+// CHECK-NEXT: vector.load
+// CHECK-NEXT: vector.fma
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.x86vector.sink_vector_producer_ops
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @sink_vector_transfer_reads(%arg0: memref<16x16xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %0 = ub.poison : f32
+  %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true]} : memref<16x16xf32>, vector<8xf32>
+  %2 = vector.transfer_read %arg0[%c0, %c8], %0 {in_bounds = [true]} : memref<16x16xf32>, vector<8xf32>
+  %3 = vector.transfer_read %arg0[%c8, %c0], %0 {in_bounds = [true]} : memref<16x16xf32>, vector<8xf32>
+  %4 = vector.transfer_read %arg0[%c8, %c8], %0 {in_bounds = [true]} : memref<16x16xf32>, vector<8xf32>
+  %5 = vector.fma %1, %2, %arg1 : vector<8xf32>
+  %6 = vector.fma %3, %4, %5 : vector<8xf32>
+  return %6 : vector<8xf32>
+}
+
+// CHECK-LABEL: @sink_vector_transfer_reads
+// CHECK: vector.transfer_read
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.fma
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.fma
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.x86vector.sink_vector_producer_ops
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @sink_vector_transfer_reads_tensor(%arg0: tensor<16x16xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %0 = ub.poison : f32
+  %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true]} : tensor<16x16xf32>, vector<8xf32>
+  %2 = vector.transfer_read %arg0[%c0, %c8], %0 {in_bounds = [true]} : tensor<16x16xf32>, vector<8xf32>
+  %3 = vector.transfer_read %arg0[%c8, %c0], %0 {in_bounds = [true]} : tensor<16x16xf32>, vector<8xf32>
+  %4 = vector.transfer_read %arg0[%c8, %c8], %0 {in_bounds = [true]} : tensor<16x16xf32>, vector<8xf32>
+  %5 = vector.fma %1, %2, %arg1 : vector<8xf32>
+  %6 = vector.fma %3, %4, %5 : vector<8xf32>
+  return %6 : vector<8xf32>
+}
+
+// CHECK-LABEL: @sink_vector_transfer_reads_tensor
+// CHECK: vector.transfer_read
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.fma
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.fma
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.x86vector.sink_vector_producer_ops
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
+
+func.func @sink_vector_transfer_reads_bf16(%arg0: tensor<4x64x32x2xbf16>, %arg1: tensor<4x32x64x2xbf16>, %arg2: vector<1x16xf32>) -> vector<1x16xf32> {
+  %0 = ub.poison : bf16
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c16 = arith.constant 16 : index
+  %extracted_slice = tensor.extract_slice %arg0[%c0, %c0, %c0, 0] [1, 4, 1, 2] [1, 1, 1, 1] : tensor<4x64x32x2xbf16> to tensor<1x4x1x2xbf16>
+  %extracted_slice_0 = tensor.extract_slice %arg1[%c0, %c0, %c0, 0] [1, 1, 32, 2] [1, 1, 1, 1] : tensor<4x32x64x2xbf16> to tensor<1x1x32x2xbf16>
+  %1 = vector.transfer_read %extracted_slice[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x4x1x2xbf16>, vector<1x1x1x2xbf16>
+  %2 = vector.transfer_read %extracted_slice[%c0, %c1, %c0, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x4x1x2xbf16>, vector<1x1x1x2xbf16>
+  %3 = vector.transfer_read %extracted_slice_0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x1x32x2xbf16>, vector<1x1x16x2xbf16>
+  %4 = vector.transfer_read %extracted_slice_0[%c0, %c0, %c16, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x1x32x2xbf16>, vector<1x1x16x2xbf16>
+  %5 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %3, %arg2 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
+  %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %4, %5 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
+  %7 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %3, %6 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
+  %8 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %4, %7 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
+  return %8 : vector<1x16xf32>
+}
+
+// CHECK-LABEL: @sink_vector_transfer_reads_bf16
+// CHECK: vector.transfer_read
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.contract
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.contract
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.contract
+// CHECK-NEXT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.x86vector.sink_vector_producer_ops
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @negative_no_infinite_looping(%arg0: memref<16x16xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %0 = vector.load %arg0[%c0, %c0] : memref<16x16xf32>, vector<8xf32>
+  %1 = vector.load %arg0[%c0, %c8] : memref<16x16xf32>, vector<8xf32>
+  %2 = vector.fma %0, %1, %arg1 : vector<8xf32>
+  return %2: vector<8xf32>
+}
+
+// CHECK-LABEL: @negative_no_infinite_looping
+// CHECK: vector.load
+// CHECK-NEXT: vector.load
+// CHECK-NEXT: vector.fma
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.x86vector.sink_vector_producer_ops
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @negative_no_sink_outside_block(%arg0: memref<8x16xf32>, %arg1: i1) -> vector<8xf32> {
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %0 = vector.load %arg0[%c0, %c0] : memref<8x16xf32>, vector<8xf32>
+  %1 = vector.load %arg0[%c0, %c8] : memref<8x16xf32>, vector<8xf32>
+  %2 = scf.if %arg1 -> (vector<8xf32>) {
+    scf.yield %0 : vector<8xf32>
+  } else {
+    scf.yield %1 : vector<8xf32>
+  }
+  return %2 : vector<8xf32>
+}
+
+// CHECK-LABEL: @negative_no_sink_outside_block
+// CHECK: vector.load
+// CHECK-NEXT: vector.load
+// CHECK-NEXT: scf.if
+// CHECK-NEXT: scf.yield
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.x86vector.sink_vector_producer_ops
+    } : !transform.any_op
+    transform.yield
+  }
+}
+

@llvmbot
Copy link
Member

llvmbot commented Nov 24, 2025

@llvm/pr-subscribers-mlir-vector

Author: Arun Thangamani (arun-thmn)

Changes

Adds a pattern that sinks vector producer ops (vector.load and vector.transfer_read) forward in a block to their earliest legal use, reducing live ranges and improving scheduling opportunities.

The lowering pattern: batch_reduce.matmul (input) -> register-tiling(M, N) -> Vectorization (to vector.contract) -> unroll vector.contract (unit dims) -> hoisting transformation (move C loads/store outside batch/k loop) -> sink vector producers -> apply licm, canonicalization, and bufferize -> vector.contract to fma -> sink vector producers.


Full diff: https://github.com/llvm/llvm-project/pull/169333.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td (+11)
  • (modified) mlir/include/mlir/Dialect/X86Vector/Transforms.h (+4)
  • (modified) mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp (+5)
  • (modified) mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/X86Vector/Transforms/SinkVectorProducerOps.cpp (+93)
  • (added) mlir/test/Dialect/X86Vector/sink-vector-producer-ops.mlir (+199)
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
index 3c5294ff14fc7..12ba5e9f11141 100644
--- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -38,6 +38,17 @@ def ApplyVectorContractToPackedTypeDotProductPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplySinkVectorProducerOpsPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.x86vector.sink_vector_producer_ops",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collect patterns to sink vector producer operations forward in a block to 
+         place them immediately before their first use.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 
 #endif // X86VECTOR_TRANSFORM_OPS
 
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index fc46dff63c2b7..b9c9054f57890 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -91,6 +91,10 @@ void populateVectorContractToFMAPatterns(RewritePatternSet &patterns);
 void populateVectorContractToPackedTypeDotProductPatterns(
     RewritePatternSet &patterns);
 
+// Performs forward scheduling of vector producer ops to minimize their live
+// range by placing them at their earliest legal use site
+void populateSinkVectorProducerOpsPatterns(RewritePatternSet &patterns);
+
 //===----------------------------------------------------------------------===//
 /// Helpers extracted from:
 ///   - clang/lib/Headers/avxintrin.h
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
index 95db208207672..25772f2aa57f4 100644
--- a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -32,6 +32,11 @@ void mlir::transform::ApplyVectorContractToPackedTypeDotProductPatternsOp::
   x86vector::populateVectorContractToPackedTypeDotProductPatterns(patterns);
 }
 
+void mlir::transform::ApplySinkVectorProducerOpsPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  x86vector::populateSinkVectorProducerOpsPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
index 2cab50fb591c4..cc4d3cac0f7ea 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRX86VectorTransforms
   LegalizeForLLVMExport.cpp
   VectorContractToFMA.cpp
   VectorContractToPackedTypeDotProduct.cpp
+  SinkVectorProducerOps.cpp
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/SinkVectorProducerOps.cpp b/mlir/lib/Dialect/X86Vector/Transforms/SinkVectorProducerOps.cpp
new file mode 100644
index 0000000000000..b31636958e158
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/SinkVectorProducerOps.cpp
@@ -0,0 +1,93 @@
+//===- SinkVectorProducerOps.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+/// Sink vector producers forward to reduce live ranges.
+/// This pattern applies to ops such as vector.load and vector.transfer_read.
+template <typename producerOp>
+struct SinkVectorProducerOps final : public OpRewritePattern<producerOp> {
+  using OpRewritePattern<producerOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(producerOp op,
+                                PatternRewriter &rewriter) const override {
+
+    // Collect all users of the producer op.
+    llvm::SmallVector<Operation *> users;
+    for (OpResult result : op->getResults())
+      for (Operation *user : result.getUsers())
+        users.push_back(user);
+
+    // If there are no users, nothing to sink.
+    if (users.empty())
+      return failure();
+
+    // If the next op is already a user, do not move.
+    Operation *nextOp = op->getNextNode();
+    if (llvm::is_contained(users, nextOp))
+      return failure();
+
+    // Prevent pathological looping:
+    // If the next op produces values used by any of op's users, don't move.
+    llvm::SmallVector<Operation *> nextOpUsers;
+    for (OpResult result : nextOp->getResults())
+      for (Operation *user : result.getUsers())
+        nextOpUsers.push_back(user);
+
+    Operation *nextFirstUser = nextOp->getNextNode();
+    while (nextFirstUser) {
+      if (llvm::is_contained(nextOpUsers, nextFirstUser))
+        break;
+
+      nextFirstUser = nextFirstUser->getNextNode();
+    }
+
+    if (llvm::is_contained(users, nextFirstUser))
+      return failure();
+
+    // Find the nearest user by scanning forward.
+    while (nextOp) {
+      if (llvm::is_contained(users, nextOp))
+        break;
+
+      nextOp = nextOp->getNextNode();
+    }
+
+    if (!nextOp)
+      return failure();
+
+    // // Both ops must be in the same block to safely move.
+    if (op->getBlock() != nextOp->getBlock())
+      return failure();
+
+    // Move producer immediately before its first user.
+    op->moveBefore(nextOp);
+
+    return success();
+  }
+};
+
+void x86vector::populateSinkVectorProducerOpsPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<SinkVectorProducerOps<vector::TransferReadOp>,
+               SinkVectorProducerOps<vector::LoadOp>>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/X86Vector/sink-vector-producer-ops.mlir b/mlir/test/Dialect/X86Vector/sink-vector-producer-ops.mlir
new file mode 100644
index 0000000000000..11af315e69e66
--- /dev/null
+++ b/mlir/test/Dialect/X86Vector/sink-vector-producer-ops.mlir
@@ -0,0 +1,199 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+func.func @sink_vector_loads(%arg0: memref<16x16xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %0 = vector.load %arg0[%c0, %c0] : memref<16x16xf32>, vector<8xf32>
+  %1 = vector.load %arg0[%c0, %c8] : memref<16x16xf32>, vector<8xf32>
+  %2 = vector.load %arg0[%c8, %c0] : memref<16x16xf32>, vector<8xf32>
+  %3 = vector.load %arg0[%c8, %c8] : memref<16x16xf32>, vector<8xf32>
+  %4 = vector.fma %0, %1, %arg1 : vector<8xf32>
+  %5 = vector.fma %2, %3, %4 : vector<8xf32>
+  return %5 : vector<8xf32>
+}
+
+// CHECK-LABEL: @sink_vector_loads
+// CHECK: vector.load
+// CHECK-NEXT: vector.load
+// CHECK-NEXT: vector.fma
+// CHECK-NEXT: vector.load
+// CHECK-NEXT: vector.load
+// CHECK-NEXT: vector.fma
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.x86vector.sink_vector_producer_ops
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @sink_vector_transfer_reads(%arg0: memref<16x16xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %0 = ub.poison : f32
+  %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true]} : memref<16x16xf32>, vector<8xf32>
+  %2 = vector.transfer_read %arg0[%c0, %c8], %0 {in_bounds = [true]} : memref<16x16xf32>, vector<8xf32>
+  %3 = vector.transfer_read %arg0[%c8, %c0], %0 {in_bounds = [true]} : memref<16x16xf32>, vector<8xf32>
+  %4 = vector.transfer_read %arg0[%c8, %c8], %0 {in_bounds = [true]} : memref<16x16xf32>, vector<8xf32>
+  %5 = vector.fma %1, %2, %arg1 : vector<8xf32>
+  %6 = vector.fma %3, %4, %5 : vector<8xf32>
+  return %6 : vector<8xf32>
+}
+
+// CHECK-LABEL: @sink_vector_transfer_reads
+// CHECK: vector.transfer_read
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.fma
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.fma
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.x86vector.sink_vector_producer_ops
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @sink_vector_transfer_reads_tensor(%arg0: tensor<16x16xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %0 = ub.poison : f32
+  %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true]} : tensor<16x16xf32>, vector<8xf32>
+  %2 = vector.transfer_read %arg0[%c0, %c8], %0 {in_bounds = [true]} : tensor<16x16xf32>, vector<8xf32>
+  %3 = vector.transfer_read %arg0[%c8, %c0], %0 {in_bounds = [true]} : tensor<16x16xf32>, vector<8xf32>
+  %4 = vector.transfer_read %arg0[%c8, %c8], %0 {in_bounds = [true]} : tensor<16x16xf32>, vector<8xf32>
+  %5 = vector.fma %1, %2, %arg1 : vector<8xf32>
+  %6 = vector.fma %3, %4, %5 : vector<8xf32>
+  return %6 : vector<8xf32>
+}
+
+// CHECK-LABEL: @sink_vector_transfer_reads_tensor
+// CHECK: vector.transfer_read
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.fma
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.fma
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.x86vector.sink_vector_producer_ops
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
+
+func.func @sink_vector_transfer_reads_bf16(%arg0: tensor<4x64x32x2xbf16>, %arg1: tensor<4x32x64x2xbf16>, %arg2: vector<1x16xf32>) -> vector<1x16xf32> {
+  %0 = ub.poison : bf16
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c16 = arith.constant 16 : index
+  %extracted_slice = tensor.extract_slice %arg0[%c0, %c0, %c0, 0] [1, 4, 1, 2] [1, 1, 1, 1] : tensor<4x64x32x2xbf16> to tensor<1x4x1x2xbf16>
+  %extracted_slice_0 = tensor.extract_slice %arg1[%c0, %c0, %c0, 0] [1, 1, 32, 2] [1, 1, 1, 1] : tensor<4x32x64x2xbf16> to tensor<1x1x32x2xbf16>
+  %1 = vector.transfer_read %extracted_slice[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x4x1x2xbf16>, vector<1x1x1x2xbf16>
+  %2 = vector.transfer_read %extracted_slice[%c0, %c1, %c0, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x4x1x2xbf16>, vector<1x1x1x2xbf16>
+  %3 = vector.transfer_read %extracted_slice_0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x1x32x2xbf16>, vector<1x1x16x2xbf16>
+  %4 = vector.transfer_read %extracted_slice_0[%c0, %c0, %c16, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x1x32x2xbf16>, vector<1x1x16x2xbf16>
+  %5 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %3, %arg2 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
+  %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %4, %5 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
+  %7 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %3, %6 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
+  %8 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %4, %7 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
+  return %8 : vector<1x16xf32>
+}
+
+// CHECK-LABEL: @sink_vector_transfer_reads_bf16
+// CHECK: vector.transfer_read
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.contract
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.contract
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.contract
+// CHECK-NEXT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.x86vector.sink_vector_producer_ops
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @negative_no_infinite_looping(%arg0: memref<16x16xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %0 = vector.load %arg0[%c0, %c0] : memref<16x16xf32>, vector<8xf32>
+  %1 = vector.load %arg0[%c0, %c8] : memref<16x16xf32>, vector<8xf32>
+  %2 = vector.fma %0, %1, %arg1 : vector<8xf32>
+  return %2: vector<8xf32>
+}
+
+// CHECK-LABEL: @negative_no_infinite_looping
+// CHECK: vector.load
+// CHECK-NEXT: vector.load
+// CHECK-NEXT: vector.fma
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.x86vector.sink_vector_producer_ops
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @negative_no_sink_outside_block(%arg0: memref<8x16xf32>, %arg1: i1) -> vector<8xf32> {
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %0 = vector.load %arg0[%c0, %c0] : memref<8x16xf32>, vector<8xf32>
+  %1 = vector.load %arg0[%c0, %c8] : memref<8x16xf32>, vector<8xf32>
+  %2 = scf.if %arg1 -> (vector<8xf32>) {
+    scf.yield %0 : vector<8xf32>
+  } else {
+    scf.yield %1 : vector<8xf32>
+  }
+  return %2 : vector<8xf32>
+}
+
+// CHECK-LABEL: @negative_no_sink_outside_block
+// CHECK: vector.load
+// CHECK-NEXT: vector.load
+// CHECK-NEXT: scf.if
+// CHECK-NEXT: scf.yield
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.x86vector.sink_vector_producer_ops
+    } : !transform.any_op
+    transform.yield
+  }
+}
+

@rengolin
Copy link
Member

How does this relate to #163382?

return failure();

// Prevent pathological looping:
// If the next op produces values used by any of op's users, don't move.
Copy link
Member

@rengolin rengolin Nov 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see the problem:

  %0 = my_op
  %1 = next_op
  ...
  %m = first_user %0, %1

Both %0 and %1 can be in any order, so if %0 moves now, it will become %1 and if you run the pass again, it'll match and go back being %0, in a cycle.

However, this is not only true for the next operation, but any operation between %0 and %m. This is a high polynomial time algorithm, especially with intersection checks on two lists.

What if we just intersect users?

  users = all users of all results of "op";
  other = op;
  while (other =  other->getNextNode()) {
    // If ops share users: O(num_ops * num_users)
    if (is_contained(users, other->getUsers()) {
      // Move "op" just _before_ "nextOp" and stop
      op->moveBefore(other);
      break;
    }
  }

Now, "op" will never pass "other", so if you run again, it will do nothing to "op".

Running the pass multiple times should compact the IR:

  // Original
  %a = my_op
  ...
  %b = my_other_op
  ...
  %c = my_last_op
  ...
  %val = first_user (%a, %b, %c)

  // First pass
  ...
  %a = my_op
  ...
  %b = my_other_op
  ...
  %c = my_last_op
  %val = first_user (%a, %b, %c)

  // Second pass
  ...
  ...
  %a = my_op
  ...
  %b = my_other_op
  %c = my_last_op
  %val = first_user (%a, %b, %c)

  // Third pass
  ...
  ...
  ...
  %a = my_op
  %b = my_other_op
  %c = my_last_op
  %val = first_user (%a, %b, %c)

Alternatively, you can take the first op (%a), get it's first user (%val), and then iterate backwards, moving %c first, then %b, then %a, and you only need to pass once to compact the IR. You can even cache the last producer sunk, and use it for the moveBefore() call without having to look for it again.

Copy link
Contributor Author

@arun-thmn arun-thmn Nov 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes Thanks. If we do intersect between the current op users and next op users, few valid patterns can buy-pass this transformation. For example, in the below code:

Example:
%a = prod1
%b = prod2
%c = prod3
%d = prod4
%c1 = fma %a %c %arg
%c2 = fma %a %d %c1
%c3 = fma %b %c %c2
%c4 = fma %b %d %c3

prod2 users: {c3, c4} and prod3 users: {c1, c3}. If the logic is based on intersect, the rewrite will be:
%b = prod2
%c = prod3
%a = prod1
%c1 = fma %a %c %arg
%d = prod4
%c2 = fma %a %d %c1
%c3 = fma %b %c %c2
%c4 = fma %b %d %c3

prod2 can be moved before %c3, but didn't as %c3 is a user for both of them. So, we check for the first users of both the current and next producer to do the shift.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, you're right. We only need to check is_contained. Intersection would create false positives.

However, here, you're only moving past one operation at a time, so you don't need to check those that do not participate. However, that's wasteful, as you're recalculating the same user list for every move.

In my example above, each pass moves each producer to the best place in one step, and we'd still require multiple passes to move all producers. In your implementation, the pass runs as many times as there are instructions in between producer and first consumer, for every producer, recalculating users and scanning those lists over and over on the same operations.

My proposal is to do a forward pass (while(getNextOp)), collect all producers in program order in a separate ordered list, and the first consumer, and then only look at those, possibly in reverse order.

@arun-thmn
Copy link
Contributor Author

How does this relate to #163382?

PR: #163382 is an old one (lowers only specific pattern). I have cancelled it.
This patch is invoked before/after PR: #168074

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants