[mlir][x86vector] Sink Vector.transfer_reads and vector.load before the consumer #169333
Open · arun-thmn wants to merge 4 commits into llvm:main from arun-thmn:vector-sink-prod
mlir/lib/Dialect/X86Vector/Transforms/SinkVectorProducerOps.cpp (95 additions, 0 deletions)

//===- SinkVectorProducerOps.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"

#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;

/// Sink vector producers forward to reduce live ranges.
/// This pattern applies to ops such as vector.load and vector.transfer_read.
template <typename producerOp>
struct SinkVectorProducerOps final : public OpRewritePattern<producerOp> {
  using OpRewritePattern<producerOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(producerOp op,
                                PatternRewriter &rewriter) const override {
    // Collect all users of the producer op.
    llvm::SmallVector<Operation *> users;
    for (OpResult result : op->getResults())
      for (Operation *user : result.getUsers())
        users.push_back(user);

    // If there are no users, there is nothing to sink.
    if (users.empty())
      return failure();

    // If the next op is already a user, do not move.
    Operation *nextOp = op->getNextNode();
    if (llvm::is_contained(users, nextOp))
      return failure();

    // Prevent pathological looping: collect the users of the next op so we
    // can tell whether moving would just swap two producers back and forth.
    llvm::SmallVector<Operation *> nextOpUsers;
    for (OpResult result : nextOp->getResults())
      for (Operation *user : result.getUsers())
        nextOpUsers.push_back(user);

    // Find the first user of the next op by scanning forward.
    Operation *nextFirstUser = nextOp->getNextNode();
    while (nextFirstUser) {
      if (llvm::is_contained(nextOpUsers, nextFirstUser))
        break;
      nextFirstUser = nextFirstUser->getNextNode();
    }

    // Find the nearest user of op by scanning forward.
    while (nextOp) {
      if (llvm::is_contained(users, nextOp))
        break;
      nextOp = nextOp->getNextNode();
    }

    if (!nextOp)
      return failure();

    // If op's first user and the next op's first user are the same, bail out
    // to avoid cycling the two producers back and forth.
    if (nextOp == nextFirstUser)
      return failure();

    // Both ops must be in the same block for the move to be safe.
    if (op->getBlock() != nextOp->getBlock())
      return failure();

    // Move the producer immediately before its first user.
    op->moveBefore(nextOp);

    return success();
  }
};

void x86vector::populateSinkVectorProducerOpsPatterns(
    RewritePatternSet &patterns) {
  patterns.add<SinkVectorProducerOps<vector::TransferReadOp>,
               SinkVectorProducerOps<vector::LoadOp>>(patterns.getContext());
}
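
For context only (not part of the diff): the new entry point follows the usual MLIR pattern-population convention, so outside the transform dialect the same patterns could be driven from C++ with the greedy rewriter. A minimal sketch, assuming the standard greedy driver entry point (applyPatternsGreedily on current main, applyPatternsAndFoldGreedily on older revisions); the sinkVectorProducers wrapper name is illustrative:

#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Populate the sinking patterns added by this PR and run them greedily on a
// function (or any other isolated op).
static LogicalResult sinkVectorProducers(Operation *funcOp) {
  RewritePatternSet patterns(funcOp->getContext());
  x86vector::populateSinkVectorProducerOpsPatterns(patterns);
  return applyPatternsGreedily(funcOp, std::move(patterns));
}

The tests below exercise the same patterns through the transform dialect via transform.apply_patterns.x86vector.sink_vector_producer_ops.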
mlir/test/Dialect/X86Vector/sink-vector-producer-ops.mlir (199 additions, 0 deletions)

// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s

func.func @sink_vector_loads(%arg0: memref<16x16xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
  %c0 = arith.constant 0 : index
  %c8 = arith.constant 8 : index
  %0 = vector.load %arg0[%c0, %c0] : memref<16x16xf32>, vector<8xf32>
  %1 = vector.load %arg0[%c0, %c8] : memref<16x16xf32>, vector<8xf32>
  %2 = vector.load %arg0[%c8, %c0] : memref<16x16xf32>, vector<8xf32>
  %3 = vector.load %arg0[%c8, %c8] : memref<16x16xf32>, vector<8xf32>
  %4 = vector.fma %0, %1, %arg1 : vector<8xf32>
  %5 = vector.fma %2, %3, %4 : vector<8xf32>
  return %5 : vector<8xf32>
}

// CHECK-LABEL: @sink_vector_loads
// CHECK:       vector.load
// CHECK-NEXT:  vector.load
// CHECK-NEXT:  vector.fma
// CHECK-NEXT:  vector.load
// CHECK-NEXT:  vector.load
// CHECK-NEXT:  vector.fma

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.x86vector.sink_vector_producer_ops
    } : !transform.any_op
    transform.yield
  }
}

// -----

func.func @sink_vector_transfer_reads(%arg0: memref<16x16xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
  %c0 = arith.constant 0 : index
  %c8 = arith.constant 8 : index
  %0 = ub.poison : f32
  %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true]} : memref<16x16xf32>, vector<8xf32>
  %2 = vector.transfer_read %arg0[%c0, %c8], %0 {in_bounds = [true]} : memref<16x16xf32>, vector<8xf32>
  %3 = vector.transfer_read %arg0[%c8, %c0], %0 {in_bounds = [true]} : memref<16x16xf32>, vector<8xf32>
  %4 = vector.transfer_read %arg0[%c8, %c8], %0 {in_bounds = [true]} : memref<16x16xf32>, vector<8xf32>
  %5 = vector.fma %1, %2, %arg1 : vector<8xf32>
  %6 = vector.fma %3, %4, %5 : vector<8xf32>
  return %6 : vector<8xf32>
}

// CHECK-LABEL: @sink_vector_transfer_reads
// CHECK:       vector.transfer_read
// CHECK-NEXT:  vector.transfer_read
// CHECK-NEXT:  vector.fma
// CHECK-NEXT:  vector.transfer_read
// CHECK-NEXT:  vector.transfer_read
// CHECK-NEXT:  vector.fma

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.x86vector.sink_vector_producer_ops
    } : !transform.any_op
    transform.yield
  }
}

// -----

func.func @sink_vector_transfer_reads_tensor(%arg0: tensor<16x16xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
  %c0 = arith.constant 0 : index
  %c8 = arith.constant 8 : index
  %0 = ub.poison : f32
  %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true]} : tensor<16x16xf32>, vector<8xf32>
  %2 = vector.transfer_read %arg0[%c0, %c8], %0 {in_bounds = [true]} : tensor<16x16xf32>, vector<8xf32>
  %3 = vector.transfer_read %arg0[%c8, %c0], %0 {in_bounds = [true]} : tensor<16x16xf32>, vector<8xf32>
  %4 = vector.transfer_read %arg0[%c8, %c8], %0 {in_bounds = [true]} : tensor<16x16xf32>, vector<8xf32>
  %5 = vector.fma %1, %2, %arg1 : vector<8xf32>
  %6 = vector.fma %3, %4, %5 : vector<8xf32>
  return %6 : vector<8xf32>
}

// CHECK-LABEL: @sink_vector_transfer_reads_tensor
// CHECK:       vector.transfer_read
// CHECK-NEXT:  vector.transfer_read
// CHECK-NEXT:  vector.fma
// CHECK-NEXT:  vector.transfer_read
// CHECK-NEXT:  vector.transfer_read
// CHECK-NEXT:  vector.fma

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.x86vector.sink_vector_producer_ops
    } : !transform.any_op
    transform.yield
  }
}

// -----

#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>

func.func @sink_vector_transfer_reads_bf16(%arg0: tensor<4x64x32x2xbf16>, %arg1: tensor<4x32x64x2xbf16>, %arg2: vector<1x16xf32>) -> vector<1x16xf32> {
  %0 = ub.poison : bf16
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c16 = arith.constant 16 : index
  %extracted_slice = tensor.extract_slice %arg0[%c0, %c0, %c0, 0] [1, 4, 1, 2] [1, 1, 1, 1] : tensor<4x64x32x2xbf16> to tensor<1x4x1x2xbf16>
  %extracted_slice_0 = tensor.extract_slice %arg1[%c0, %c0, %c0, 0] [1, 1, 32, 2] [1, 1, 1, 1] : tensor<4x32x64x2xbf16> to tensor<1x1x32x2xbf16>
  %1 = vector.transfer_read %extracted_slice[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x4x1x2xbf16>, vector<1x1x1x2xbf16>
  %2 = vector.transfer_read %extracted_slice[%c0, %c1, %c0, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x4x1x2xbf16>, vector<1x1x1x2xbf16>
  %3 = vector.transfer_read %extracted_slice_0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x1x32x2xbf16>, vector<1x1x16x2xbf16>
  %4 = vector.transfer_read %extracted_slice_0[%c0, %c0, %c16, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x1x32x2xbf16>, vector<1x1x16x2xbf16>
  %5 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %3, %arg2 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
  %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %4, %5 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
  %7 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %3, %6 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
  %8 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %4, %7 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
  return %8 : vector<1x16xf32>
}

// CHECK-LABEL: @sink_vector_transfer_reads_bf16
// CHECK:       vector.transfer_read
// CHECK-NEXT:  vector.transfer_read
// CHECK-NEXT:  vector.contract
// CHECK-NEXT:  vector.transfer_read
// CHECK-NEXT:  vector.contract
// CHECK-NEXT:  vector.transfer_read
// CHECK-NEXT:  vector.contract
// CHECK-NEXT:  vector.contract

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.x86vector.sink_vector_producer_ops
    } : !transform.any_op
    transform.yield
  }
}

// -----

func.func @negative_no_infinite_looping(%arg0: memref<16x16xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
  %c0 = arith.constant 0 : index
  %c8 = arith.constant 8 : index
  %0 = vector.load %arg0[%c0, %c0] : memref<16x16xf32>, vector<8xf32>
  %1 = vector.load %arg0[%c0, %c8] : memref<16x16xf32>, vector<8xf32>
  %2 = vector.fma %0, %1, %arg1 : vector<8xf32>
  return %2 : vector<8xf32>
}

// CHECK-LABEL: @negative_no_infinite_looping
// CHECK:       vector.load
// CHECK-NEXT:  vector.load
// CHECK-NEXT:  vector.fma

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.x86vector.sink_vector_producer_ops
    } : !transform.any_op
    transform.yield
  }
}

// -----

func.func @negative_no_sink_outside_block(%arg0: memref<8x16xf32>, %arg1: i1) -> vector<8xf32> {
  %c0 = arith.constant 0 : index
  %c8 = arith.constant 8 : index
  %0 = vector.load %arg0[%c0, %c0] : memref<8x16xf32>, vector<8xf32>
  %1 = vector.load %arg0[%c0, %c8] : memref<8x16xf32>, vector<8xf32>
  %2 = scf.if %arg1 -> (vector<8xf32>) {
    scf.yield %0 : vector<8xf32>
  } else {
    scf.yield %1 : vector<8xf32>
  }
  return %2 : vector<8xf32>
}

// CHECK-LABEL: @negative_no_sink_outside_block
// CHECK:       vector.load
// CHECK-NEXT:  vector.load
// CHECK-NEXT:  scf.if
// CHECK-NEXT:  scf.yield

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.x86vector.sink_vector_producer_ops
    } : !transform.any_op
    transform.yield
  }
}
I see the problem:
Both %0 and %1 can be in any order, so if %0 moves now it effectively becomes %1, and if you run the pass again it will match and move back to being %0, in a cycle.
However, this is not only true for the next operation but for any operation between %0 and %m. This is a high-polynomial-time algorithm, especially with intersection checks on two lists.
What if we just intersect users? Then "op" will never pass "other", so if you run again it will do nothing to "op". Running the pass multiple times should then compact the IR.
Alternatively, you can take the first op (%a), get its first user (%val), and then iterate backwards, moving %c first, then %b, then %a; that way a single pass compacts the IR. You can even cache the last producer sunk and use it for the moveBefore() call without having to look for it again.
Yes, thanks. If we do intersect between the current op's users and the next op's users, a few valid patterns can bypass this transformation. For example, in the code below:

  %a = prod1
  %b = prod2
  %c = prod3
  %d = prod4
  %c1 = fma %a %c %arg
  %c2 = fma %a %d %c1
  %c3 = fma %b %c %c2
  %c4 = fma %b %d %c3

prod2 users: {c3, c4} and prod3 users: {c1, c3}. If the logic is based on intersect, the rewrite will be:

  %b = prod2
  %c = prod3
  %a = prod1
  %c1 = fma %a %c %arg
  %d = prod4
  %c2 = fma %a %d %c1
  %c3 = fma %b %c %c2
  %c4 = fma %b %d %c3

prod2 could be moved before %c3, but isn't, because %c3 is a user of both producers. So we check the first users of both the current and the next producer before doing the shift.
Yes, you're right. We only need to check is_contained; intersection would create false positives.
However, here you are only moving past one operation at a time, so you don't need to check the operations that do not participate. Still, that's wasteful, as you're recalculating the same user list for every move.
In my example above, each pass moves each producer to the best place in one step, and we'd still require multiple passes to move all producers. In your implementation, the pass runs as many times as there are instructions between a producer and its first consumer, for every producer, recalculating users and scanning those lists over and over on the same operations.
My proposal is to do a forward pass (while (getNextNode())), collect all producers in program order in a separate ordered list, together with the first consumer, and then only look at those, possibly in reverse order.
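
One possible reading of that proposal, sketched below. This is not code from the patch or the review; it only illustrates the shape of the idea, using existing MLIR APIs (Block iteration, Operation::getUsers, Operation::moveBefore). The helper name sinkVectorProducersInBlock is illustrative.

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Walk a block once: remember vector.load / vector.transfer_read producers in
// program order, and when an op that consumes any of them is reached, sink
// those producers directly before it and stop tracking them.
static void sinkVectorProducersInBlock(Block &block) {
  llvm::SmallVector<Operation *> pending; // producers not yet consumed.
  for (Operation &op : block) {
    if (isa<vector::LoadOp, vector::TransferReadOp>(op)) {
      pending.push_back(&op);
      continue;
    }
    llvm::SmallVector<Operation *> stillPending;
    for (Operation *producer : pending) {
      if (llvm::is_contained(producer->getUsers(), &op)) {
        // `op` is this producer's first consumer: sink the producer directly
        // before it. Moving producers in program order preserves their
        // relative order.
        producer->moveBefore(&op);
      } else {
        stillPending.push_back(producer);
      }
    }
    pending = std::move(stillPending);
  }
}

Whether this ends up cheaper than the greedy pattern depends on how many ops sit between producers and their consumers; the caching of the last producer sunk mentioned above could be layered on top.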