@@ -38,6 +38,17 @@ def ApplyVectorContractToPackedTypeDotProductPatternsOp : Op<Transform_Dialect,
  let assemblyFormat = "attr-dict";
}

def ApplySinkVectorProducerOpsPatternsOp : Op<Transform_Dialect,
    "apply_patterns.x86vector.sink_vector_producer_ops",
    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
  let description = [{
    Collect patterns to sink vector producer operations forward in a block to
    place them immediately before their first use.
  }];

  let assemblyFormat = "attr-dict";
}


#endif // X86VECTOR_TRANSFORM_OPS

4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -91,6 +91,10 @@ void populateVectorContractToFMAPatterns(RewritePatternSet &patterns);
void populateVectorContractToPackedTypeDotProductPatterns(
    RewritePatternSet &patterns);

/// Performs forward scheduling of vector producer ops to minimize their live
/// ranges by placing each producer at its earliest legal use site.
void populateSinkVectorProducerOpsPatterns(RewritePatternSet &patterns);

//===----------------------------------------------------------------------===//
/// Helpers extracted from:
/// - clang/lib/Headers/avxintrin.h
@@ -32,6 +32,11 @@ void mlir::transform::ApplyVectorContractToPackedTypeDotProductPatternsOp::
  x86vector::populateVectorContractToPackedTypeDotProductPatterns(patterns);
}

void mlir::transform::ApplySinkVectorProducerOpsPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  x86vector::populateSinkVectorProducerOpsPatterns(patterns);
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
1 change: 1 addition & 0 deletions mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRX86VectorTransforms
  LegalizeForLLVMExport.cpp
  SinkVectorProducerOps.cpp
  VectorContractToFMA.cpp
  VectorContractToPackedTypeDotProduct.cpp

  LINK_LIBS PUBLIC
    MLIRArithDialect
95 changes: 95 additions & 0 deletions mlir/lib/Dialect/X86Vector/Transforms/SinkVectorProducerOps.cpp
@@ -0,0 +1,95 @@
//===- SinkVectorProducerOps.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"

#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;

/// Sink vector producers forward to reduce live ranges.
/// This pattern applies to ops such as vector.load and vector.transfer_read.
template <typename ProducerOp>
struct SinkVectorProducerOps final : public OpRewritePattern<ProducerOp> {
  using OpRewritePattern<ProducerOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ProducerOp op,
                                PatternRewriter &rewriter) const override {
    // Collect all users of the producer op.
    llvm::SmallVector<Operation *> users;
    for (OpResult result : op->getResults())
      for (Operation *user : result.getUsers())
        users.push_back(user);

    // If there are no users, there is nothing to sink.
    if (users.empty())
      return failure();

    // If the producer is the last op in its block, or the next op is
    // already a user, do not move.
    Operation *nextOp = op->getNextNode();
    if (!nextOp || llvm::is_contained(users, nextOp))
      return failure();

    // Prevent pathological looping: if the producer and the op that follows
    // it share the same first user, moving the producer would merely swap
    // the two ops, and repeated applications would cycle. Compare the first
    // users of both ops to detect this.
rengolin (Member) commented on Nov 24, 2025:
I see the problem:

  %0 = my_op
  %1 = next_op
  ...
  %m = first_user %0, %1

Both %0 and %1 can be in any order, so if %0 moves now, it becomes %1, and if you run the pass again, it will match and move back to being %0, in a cycle.

However, this is true not only for the next operation but for any operation between %0 and %m. This is a high-degree polynomial-time algorithm, especially with intersection checks on two lists.

What if we just intersect users?

  users = all users of all results of "op";
  other = op;
  while (other = other->getNextNode()) {
    // If ops share users: O(num_ops * num_users)
    if (is_contained(users, other->getUsers())) {
      // Move "op" just _before_ "other" and stop
      op->moveBefore(other);
      break;
    }
  }

Now, "op" will never pass "other", so if you run again, it will do nothing to "op".

Running the pass multiple times should compact the IR:

  // Original
  %a = my_op
  ...
  %b = my_other_op
  ...
  %c = my_last_op
  ...
  %val = first_user (%a, %b, %c)

  // First pass
  ...
  %a = my_op
  ...
  %b = my_other_op
  ...
  %c = my_last_op
  %val = first_user (%a, %b, %c)

  // Second pass
  ...
  ...
  %a = my_op
  ...
  %b = my_other_op
  %c = my_last_op
  %val = first_user (%a, %b, %c)

  // Third pass
  ...
  ...
  ...
  %a = my_op
  %b = my_other_op
  %c = my_last_op
  %val = first_user (%a, %b, %c)

Alternatively, you can take the first op (%a), get its first user (%val), and then iterate backwards, moving %c first, then %b, then %a; then you only need one pass to compact the IR. You can even cache the last producer sunk and use it for the moveBefore() call without having to look it up again.

arun-thmn (Contributor, Author) commented on Nov 25, 2025:
Yes, thanks. If we intersect the current op's users with the next op's users, a few valid patterns can bypass this transformation. For example, in the code below:

  %a = prod1
  %b = prod2
  %c = prod3
  %d = prod4
  %c1 = fma %a %c %arg
  %c2 = fma %a %d %c1
  %c3 = fma %b %c %c2
  %c4 = fma %b %d %c3
prod2's users are {c3, c4} and prod3's users are {c1, c3}. If the logic were based on intersection, the rewrite would be:
  %b = prod2
  %c = prod3
  %a = prod1
  %c1 = fma %a %c %arg
  %d = prod4
  %c2 = fma %a %d %c1
  %c3 = fma %b %c %c2
  %c4 = fma %b %d %c3

prod2 could be moved before %c3, but is not, because %c3 is a user of both producers. So instead we compare the first users of the current and the next producer to decide the shift.

rengolin (Member) replied:

Yes, you're right. We only need to check is_contained; intersection would create false positives.

Here, though, you're only moving past one operation at a time, so you don't need to check the ops that do not participate. But that is wasteful: you recalculate the same user list for every move.

In my example above, each pass moves each producer to the best place in one step, and we'd still require multiple passes to move all producers. In your implementation, the pattern runs as many times as there are instructions between producer and first consumer, for every producer, recalculating users and scanning those lists over and over on the same operations.

My proposal is to do one forward pass (while (getNextNode())), collect all producers in program order in a separate ordered list together with the first consumer, and then only look at those, possibly in reverse order.
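A minimal sketch of that proposal (illustrative only; compactProducers and the isProducer callback are not part of this patch, and, like the pattern under review, the sketch ignores the memory effects of the ops being hopped over):

  // Sketch only -- not part of this patch. One forward scan collects the
  // producers in program order and stops at the first op that consumes any
  // of them; one backward sweep then sinks each producer, caching the
  // insertion point. `isProducer` would match e.g. vector.load.
  static void compactProducers(Block &block,
                               llvm::function_ref<bool(Operation *)> isProducer) {
    llvm::SmallVector<Operation *> producers;
    Operation *firstConsumer = nullptr;
    for (Operation &candidate : block) {
      if (isProducer(&candidate)) {
        producers.push_back(&candidate);
        continue;
      }
      // The first op consuming any collected producer ends the scan.
      for (Operation *producer : producers) {
        if (llvm::is_contained(producer->getUsers(), &candidate)) {
          firstConsumer = &candidate;
          break;
        }
      }
      if (firstConsumer)
        break;
    }
    if (!firstConsumer)
      return;
    // Sinking in reverse program order keeps the producers' relative order
    // and needs no further user lookups.
    Operation *insertionPoint = firstConsumer;
    for (Operation *producer : llvm::reverse(producers)) {
      producer->moveBefore(insertionPoint);
      insertionPoint = producer;
    }
  }

By construction, no user of any collected producer appears before firstConsumer, so moving all of them directly above it preserves dominance in a single pass.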

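    // Collect all users of the op that immediately follows the producer.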
    llvm::SmallVector<Operation *> nextOpUsers;
    for (OpResult result : nextOp->getResults())
      for (Operation *user : result.getUsers())
        nextOpUsers.push_back(user);

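    // Find the first user of that next op by scanning forward.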
    Operation *nextFirstUser = nextOp->getNextNode();
    while (nextFirstUser) {
      if (llvm::is_contained(nextOpUsers, nextFirstUser))
        break;

      nextFirstUser = nextFirstUser->getNextNode();
    }

    // Find the producer's nearest user by scanning forward.
    Operation *firstUser = nextOp;
    while (firstUser) {
      if (llvm::is_contained(users, firstUser))
        break;

      firstUser = firstUser->getNextNode();
    }

    if (!firstUser)
      return failure();

    // The producer's first user and the next op's first user are the same;
    // bail out to avoid the shift cycle described above.
    if (firstUser == nextFirstUser)
      return failure();

    // The producer and its first user must be in the same block for the
    // move to be safe.
    if (op->getBlock() != firstUser->getBlock())
      return failure();

    // Move the producer immediately before its first user.
    op->moveBefore(firstUser);

    return success();
  }
};

void x86vector::populateSinkVectorProducerOpsPatterns(
    RewritePatternSet &patterns) {
  patterns.add<SinkVectorProducerOps<vector::TransferReadOp>,
               SinkVectorProducerOps<vector::LoadOp>>(patterns.getContext());
}
199 changes: 199 additions & 0 deletions mlir/test/Dialect/X86Vector/sink-vector-producer-ops.mlir
@@ -0,0 +1,199 @@
// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s

func.func @sink_vector_loads(%arg0: memref<16x16xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
  %c0 = arith.constant 0 : index
  %c8 = arith.constant 8 : index
  %0 = vector.load %arg0[%c0, %c0] : memref<16x16xf32>, vector<8xf32>
  %1 = vector.load %arg0[%c0, %c8] : memref<16x16xf32>, vector<8xf32>
  %2 = vector.load %arg0[%c8, %c0] : memref<16x16xf32>, vector<8xf32>
  %3 = vector.load %arg0[%c8, %c8] : memref<16x16xf32>, vector<8xf32>
  %4 = vector.fma %0, %1, %arg1 : vector<8xf32>
  %5 = vector.fma %2, %3, %4 : vector<8xf32>
  return %5 : vector<8xf32>
}

// CHECK-LABEL: @sink_vector_loads
// CHECK: vector.load
// CHECK-NEXT: vector.load
// CHECK-NEXT: vector.fma
// CHECK-NEXT: vector.load
// CHECK-NEXT: vector.load
// CHECK-NEXT: vector.fma

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.x86vector.sink_vector_producer_ops
    } : !transform.any_op
    transform.yield
  }
}

// -----

func.func @sink_vector_transfer_reads(%arg0: memref<16x16xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
  %c0 = arith.constant 0 : index
  %c8 = arith.constant 8 : index
  %0 = ub.poison : f32
  %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true]} : memref<16x16xf32>, vector<8xf32>
  %2 = vector.transfer_read %arg0[%c0, %c8], %0 {in_bounds = [true]} : memref<16x16xf32>, vector<8xf32>
  %3 = vector.transfer_read %arg0[%c8, %c0], %0 {in_bounds = [true]} : memref<16x16xf32>, vector<8xf32>
  %4 = vector.transfer_read %arg0[%c8, %c8], %0 {in_bounds = [true]} : memref<16x16xf32>, vector<8xf32>
  %5 = vector.fma %1, %2, %arg1 : vector<8xf32>
  %6 = vector.fma %3, %4, %5 : vector<8xf32>
  return %6 : vector<8xf32>
}

// CHECK-LABEL: @sink_vector_transfer_reads
// CHECK: vector.transfer_read
// CHECK-NEXT: vector.transfer_read
// CHECK-NEXT: vector.fma
// CHECK-NEXT: vector.transfer_read
// CHECK-NEXT: vector.transfer_read
// CHECK-NEXT: vector.fma

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.x86vector.sink_vector_producer_ops
    } : !transform.any_op
    transform.yield
  }
}

// -----

func.func @sink_vector_transfer_reads_tensor(%arg0: tensor<16x16xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
  %c0 = arith.constant 0 : index
  %c8 = arith.constant 8 : index
  %0 = ub.poison : f32
  %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true]} : tensor<16x16xf32>, vector<8xf32>
  %2 = vector.transfer_read %arg0[%c0, %c8], %0 {in_bounds = [true]} : tensor<16x16xf32>, vector<8xf32>
  %3 = vector.transfer_read %arg0[%c8, %c0], %0 {in_bounds = [true]} : tensor<16x16xf32>, vector<8xf32>
  %4 = vector.transfer_read %arg0[%c8, %c8], %0 {in_bounds = [true]} : tensor<16x16xf32>, vector<8xf32>
  %5 = vector.fma %1, %2, %arg1 : vector<8xf32>
  %6 = vector.fma %3, %4, %5 : vector<8xf32>
  return %6 : vector<8xf32>
}

// CHECK-LABEL: @sink_vector_transfer_reads_tensor
// CHECK: vector.transfer_read
// CHECK-NEXT: vector.transfer_read
// CHECK-NEXT: vector.fma
// CHECK-NEXT: vector.transfer_read
// CHECK-NEXT: vector.transfer_read
// CHECK-NEXT: vector.fma

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.x86vector.sink_vector_producer_ops
    } : !transform.any_op
    transform.yield
  }
}

// -----

#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>

func.func @sink_vector_transfer_reads_bf16(%arg0: tensor<4x64x32x2xbf16>, %arg1: tensor<4x32x64x2xbf16>, %arg2: vector<1x16xf32>) -> vector<1x16xf32> {
  %0 = ub.poison : bf16
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c16 = arith.constant 16 : index
  %extracted_slice = tensor.extract_slice %arg0[%c0, %c0, %c0, 0] [1, 4, 1, 2] [1, 1, 1, 1] : tensor<4x64x32x2xbf16> to tensor<1x4x1x2xbf16>
  %extracted_slice_0 = tensor.extract_slice %arg1[%c0, %c0, %c0, 0] [1, 1, 32, 2] [1, 1, 1, 1] : tensor<4x32x64x2xbf16> to tensor<1x1x32x2xbf16>
  %1 = vector.transfer_read %extracted_slice[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x4x1x2xbf16>, vector<1x1x1x2xbf16>
  %2 = vector.transfer_read %extracted_slice[%c0, %c1, %c0, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x4x1x2xbf16>, vector<1x1x1x2xbf16>
  %3 = vector.transfer_read %extracted_slice_0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x1x32x2xbf16>, vector<1x1x16x2xbf16>
  %4 = vector.transfer_read %extracted_slice_0[%c0, %c0, %c16, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x1x32x2xbf16>, vector<1x1x16x2xbf16>
  %5 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %3, %arg2 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
  %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %4, %5 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
  %7 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %3, %6 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
  %8 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %4, %7 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
  return %8 : vector<1x16xf32>
}

// CHECK-LABEL: @sink_vector_transfer_reads_bf16
// CHECK: vector.transfer_read
// CHECK-NEXT: vector.transfer_read
// CHECK-NEXT: vector.contract
// CHECK-NEXT: vector.transfer_read
// CHECK-NEXT: vector.contract
// CHECK-NEXT: vector.transfer_read
// CHECK-NEXT: vector.contract
// CHECK-NEXT: vector.contract

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.x86vector.sink_vector_producer_ops
    } : !transform.any_op
    transform.yield
  }
}

// -----

func.func @negative_no_infinite_looping(%arg0: memref<16x16xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
  %c0 = arith.constant 0 : index
  %c8 = arith.constant 8 : index
  %0 = vector.load %arg0[%c0, %c0] : memref<16x16xf32>, vector<8xf32>
  %1 = vector.load %arg0[%c0, %c8] : memref<16x16xf32>, vector<8xf32>
  %2 = vector.fma %0, %1, %arg1 : vector<8xf32>
  return %2 : vector<8xf32>
}

// CHECK-LABEL: @negative_no_infinite_looping
// CHECK: vector.load
// CHECK-NEXT: vector.load
// CHECK-NEXT: vector.fma

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.x86vector.sink_vector_producer_ops
    } : !transform.any_op
    transform.yield
  }
}

// -----

func.func @negative_no_sink_outside_block(%arg0: memref<8x16xf32>, %arg1: i1) -> vector<8xf32> {
  %c0 = arith.constant 0 : index
  %c8 = arith.constant 8 : index
  %0 = vector.load %arg0[%c0, %c0] : memref<8x16xf32>, vector<8xf32>
  %1 = vector.load %arg0[%c0, %c8] : memref<8x16xf32>, vector<8xf32>
  %2 = scf.if %arg1 -> (vector<8xf32>) {
    scf.yield %0 : vector<8xf32>
  } else {
    scf.yield %1 : vector<8xf32>
  }
  return %2 : vector<8xf32>
}

// CHECK-LABEL: @negative_no_sink_outside_block
// CHECK: vector.load
// CHECK-NEXT: vector.load
// CHECK-NEXT: scf.if
// CHECK-NEXT: scf.yield

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.x86vector.sink_vector_producer_ops
    } : !transform.any_op
    transform.yield
  }
}