[mlir][x86vector] Sink vector.transfer_read and vector.load before the consumer #169333
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter.

@rengolin @adam-smnk @rolfmorel @shahidact Please review.
@llvm/pr-subscribers-mlir

Author: Arun Thangamani (arun-thmn)

Changes

Adds a pattern that sinks vector producer ops (vector.load and vector.transfer_read) forward in a block to their earliest legal use, reducing live ranges and improving scheduling opportunities.

The lowering pipeline: batch_reduce.matmul (input) -> register tiling (M, N) -> vectorization (to vector.contract) -> unroll vector.contract (unit dims) -> hoisting transformation (move C loads/stores outside the batch/k loops) -> sink vector producers -> apply licm, canonicalization, and bufferize -> vector.contract to fma -> sink vector producers.

Full diff: https://github.com/llvm/llvm-project/pull/169333.diff

6 Files Affected:
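In sketch form (input mirrors the first test in the diff below; %A and %acc are illustrative names, not from the patch), the rewrite turns:

%0 = vector.load %A[%c0, %c0] : memref<16x16xf32>, vector<8xf32>
%1 = vector.load %A[%c0, %c8] : memref<16x16xf32>, vector<8xf32>
%2 = vector.load %A[%c8, %c0] : memref<16x16xf32>, vector<8xf32>
%3 = vector.load %A[%c8, %c8] : memref<16x16xf32>, vector<8xf32>
%4 = vector.fma %0, %1, %acc : vector<8xf32>
%5 = vector.fma %2, %3, %4 : vector<8xf32>

into:

%0 = vector.load %A[%c0, %c0] : memref<16x16xf32>, vector<8xf32>
%1 = vector.load %A[%c0, %c8] : memref<16x16xf32>, vector<8xf32>
%4 = vector.fma %0, %1, %acc : vector<8xf32>
%2 = vector.load %A[%c8, %c0] : memref<16x16xf32>, vector<8xf32>
%3 = vector.load %A[%c8, %c8] : memref<16x16xf32>, vector<8xf32>
%5 = vector.fma %2, %3, %4 : vector<8xf32>

so %2 and %3 are no longer live across the first fma.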
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
index 3c5294ff14fc7..12ba5e9f11141 100644
--- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -38,6 +38,17 @@ def ApplyVectorContractToPackedTypeDotProductPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplySinkVectorProducerOpsPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.x86vector.sink_vector_producer_ops",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Collect patterns to sink vector producer operations forward in a block to
+ place them immediately before their first use.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
#endif // X86VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index fc46dff63c2b7..b9c9054f57890 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -91,6 +91,10 @@ void populateVectorContractToFMAPatterns(RewritePatternSet &patterns);
void populateVectorContractToPackedTypeDotProductPatterns(
RewritePatternSet &patterns);
+/// Performs forward scheduling of vector producer ops to minimize their live
+/// ranges by placing them at their earliest legal use site.
+void populateSinkVectorProducerOpsPatterns(RewritePatternSet &patterns);
+
//===----------------------------------------------------------------------===//
/// Helpers extracted from:
/// - clang/lib/Headers/avxintrin.h
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
index 95db208207672..25772f2aa57f4 100644
--- a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -32,6 +32,11 @@ void mlir::transform::ApplyVectorContractToPackedTypeDotProductPatternsOp::
x86vector::populateVectorContractToPackedTypeDotProductPatterns(patterns);
}
+void mlir::transform::ApplySinkVectorProducerOpsPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ x86vector::populateSinkVectorProducerOpsPatterns(patterns);
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
index 2cab50fb591c4..cc4d3cac0f7ea 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRX86VectorTransforms
LegalizeForLLVMExport.cpp
VectorContractToFMA.cpp
VectorContractToPackedTypeDotProduct.cpp
+ SinkVectorProducerOps.cpp
LINK_LIBS PUBLIC
MLIRArithDialect
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/SinkVectorProducerOps.cpp b/mlir/lib/Dialect/X86Vector/Transforms/SinkVectorProducerOps.cpp
new file mode 100644
index 0000000000000..b31636958e158
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/SinkVectorProducerOps.cpp
@@ -0,0 +1,93 @@
+//===- SinkVectorProducerOps.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+/// Sink vector producers forward to reduce live ranges.
+/// This pattern applies to ops such as vector.load and vector.transfer_read.
+template <typename ProducerOp>
+struct SinkVectorProducerOps final : public OpRewritePattern<ProducerOp> {
+ using OpRewritePattern<ProducerOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ProducerOp op,
+ PatternRewriter &rewriter) const override {
+
+ // Collect all users of the producer op.
+ llvm::SmallVector<Operation *> users;
+ for (OpResult result : op->getResults())
+ for (Operation *user : result.getUsers())
+ users.push_back(user);
+
+ // If there are no users, nothing to sink.
+ if (users.empty())
+ return failure();
+
+ // If there is no next op, or the next op is already a user, do not move.
+ Operation *nextOp = op->getNextNode();
+ if (!nextOp || llvm::is_contained(users, nextOp))
+ return failure();
+
+ // Prevent pathological looping:
+ // If the next op produces values used by any of op's users, don't move.
+ llvm::SmallVector<Operation *> nextOpUsers;
+ for (OpResult result : nextOp->getResults())
+ for (Operation *user : result.getUsers())
+ nextOpUsers.push_back(user);
+
+ Operation *nextFirstUser = nextOp->getNextNode();
+ while (nextFirstUser) {
+ if (llvm::is_contained(nextOpUsers, nextFirstUser))
+ break;
+
+ nextFirstUser = nextFirstUser->getNextNode();
+ }
+
+ if (llvm::is_contained(users, nextFirstUser))
+ return failure();
+
+ // Find the nearest user by scanning forward.
+ while (nextOp) {
+ if (llvm::is_contained(users, nextOp))
+ break;
+
+ nextOp = nextOp->getNextNode();
+ }
+
+ if (!nextOp)
+ return failure();
+
+ // Both ops must be in the same block to safely move.
+ if (op->getBlock() != nextOp->getBlock())
+ return failure();
+
+ // Move producer immediately before its first user.
+ op->moveBefore(nextOp);
+
+ return success();
+ }
+};
+
+void x86vector::populateSinkVectorProducerOpsPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<SinkVectorProducerOps<vector::TransferReadOp>,
+ SinkVectorProducerOps<vector::LoadOp>>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/X86Vector/sink-vector-producer-ops.mlir b/mlir/test/Dialect/X86Vector/sink-vector-producer-ops.mlir
new file mode 100644
index 0000000000000..11af315e69e66
--- /dev/null
+++ b/mlir/test/Dialect/X86Vector/sink-vector-producer-ops.mlir
@@ -0,0 +1,199 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+func.func @sink_vector_loads(%arg0: memref<16x16xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
+ %c0 = arith.constant 0 : index
+ %c8 = arith.constant 8 : index
+ %0 = vector.load %arg0[%c0, %c0] : memref<16x16xf32>, vector<8xf32>
+ %1 = vector.load %arg0[%c0, %c8] : memref<16x16xf32>, vector<8xf32>
+ %2 = vector.load %arg0[%c8, %c0] : memref<16x16xf32>, vector<8xf32>
+ %3 = vector.load %arg0[%c8, %c8] : memref<16x16xf32>, vector<8xf32>
+ %4 = vector.fma %0, %1, %arg1 : vector<8xf32>
+ %5 = vector.fma %2, %3, %4 : vector<8xf32>
+ return %5 : vector<8xf32>
+}
+
+// CHECK-LABEL: @sink_vector_loads
+// CHECK: vector.load
+// CHECK-NEXT: vector.load
+// CHECK-NEXT: vector.fma
+// CHECK-NEXT: vector.load
+// CHECK-NEXT: vector.load
+// CHECK-NEXT: vector.fma
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %0 {
+ transform.apply_patterns.x86vector.sink_vector_producer_ops
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @sink_vector_transfer_reads(%arg0: memref<16x16xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
+ %c0 = arith.constant 0 : index
+ %c8 = arith.constant 8 : index
+ %0 = ub.poison : f32
+ %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true]} : memref<16x16xf32>, vector<8xf32>
+ %2 = vector.transfer_read %arg0[%c0, %c8], %0 {in_bounds = [true]} : memref<16x16xf32>, vector<8xf32>
+ %3 = vector.transfer_read %arg0[%c8, %c0], %0 {in_bounds = [true]} : memref<16x16xf32>, vector<8xf32>
+ %4 = vector.transfer_read %arg0[%c8, %c8], %0 {in_bounds = [true]} : memref<16x16xf32>, vector<8xf32>
+ %5 = vector.fma %1, %2, %arg1 : vector<8xf32>
+ %6 = vector.fma %3, %4, %5 : vector<8xf32>
+ return %6 : vector<8xf32>
+}
+
+// CHECK-LABEL: @sink_vector_transfer_reads
+// CHECK: vector.transfer_read
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.fma
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.fma
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %0 {
+ transform.apply_patterns.x86vector.sink_vector_producer_ops
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @sink_vector_transfer_reads_tensor(%arg0: tensor<16x16xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
+ %c0 = arith.constant 0 : index
+ %c8 = arith.constant 8 : index
+ %0 = ub.poison : f32
+ %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true]} : tensor<16x16xf32>, vector<8xf32>
+ %2 = vector.transfer_read %arg0[%c0, %c8], %0 {in_bounds = [true]} : tensor<16x16xf32>, vector<8xf32>
+ %3 = vector.transfer_read %arg0[%c8, %c0], %0 {in_bounds = [true]} : tensor<16x16xf32>, vector<8xf32>
+ %4 = vector.transfer_read %arg0[%c8, %c8], %0 {in_bounds = [true]} : tensor<16x16xf32>, vector<8xf32>
+ %5 = vector.fma %1, %2, %arg1 : vector<8xf32>
+ %6 = vector.fma %3, %4, %5 : vector<8xf32>
+ return %6 : vector<8xf32>
+}
+
+// CHECK-LABEL: @sink_vector_transfer_reads_tensor
+// CHECK: vector.transfer_read
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.fma
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.fma
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %0 {
+ transform.apply_patterns.x86vector.sink_vector_producer_ops
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
+
+func.func @sink_vector_transfer_reads_bf16(%arg0: tensor<4x64x32x2xbf16>, %arg1: tensor<4x32x64x2xbf16>, %arg2: vector<1x16xf32>) -> vector<1x16xf32> {
+ %0 = ub.poison : bf16
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c16 = arith.constant 16 : index
+ %extracted_slice = tensor.extract_slice %arg0[%c0, %c0, %c0, 0] [1, 4, 1, 2] [1, 1, 1, 1] : tensor<4x64x32x2xbf16> to tensor<1x4x1x2xbf16>
+ %extracted_slice_0 = tensor.extract_slice %arg1[%c0, %c0, %c0, 0] [1, 1, 32, 2] [1, 1, 1, 1] : tensor<4x32x64x2xbf16> to tensor<1x1x32x2xbf16>
+ %1 = vector.transfer_read %extracted_slice[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x4x1x2xbf16>, vector<1x1x1x2xbf16>
+ %2 = vector.transfer_read %extracted_slice[%c0, %c1, %c0, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x4x1x2xbf16>, vector<1x1x1x2xbf16>
+ %3 = vector.transfer_read %extracted_slice_0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x1x32x2xbf16>, vector<1x1x16x2xbf16>
+ %4 = vector.transfer_read %extracted_slice_0[%c0, %c0, %c16, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x1x32x2xbf16>, vector<1x1x16x2xbf16>
+ %5 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %3, %arg2 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
+ %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %4, %5 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
+ %7 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %3, %6 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
+ %8 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %4, %7 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
+ return %8 : vector<1x16xf32>
+}
+
+// CHECK-LABEL: @sink_vector_transfer_reads_bf16
+// CHECK: vector.transfer_read
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.contract
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.contract
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.contract
+// CHECK-NEXT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %0 {
+ transform.apply_patterns.x86vector.sink_vector_producer_ops
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @negative_no_infinite_looping(%arg0: memref<16x16xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
+ %c0 = arith.constant 0 : index
+ %c8 = arith.constant 8 : index
+ %0 = vector.load %arg0[%c0, %c0] : memref<16x16xf32>, vector<8xf32>
+ %1 = vector.load %arg0[%c0, %c8] : memref<16x16xf32>, vector<8xf32>
+ %2 = vector.fma %0, %1, %arg1 : vector<8xf32>
+ return %2: vector<8xf32>
+}
+
+// CHECK-LABEL: @negative_no_infinite_looping
+// CHECK: vector.load
+// CHECK-NEXT: vector.load
+// CHECK-NEXT: vector.fma
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %0 {
+ transform.apply_patterns.x86vector.sink_vector_producer_ops
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @negative_no_sink_outside_block(%arg0: memref<8x16xf32>, %arg1: i1) -> vector<8xf32> {
+ %c0 = arith.constant 0 : index
+ %c8 = arith.constant 8 : index
+ %0 = vector.load %arg0[%c0, %c0] : memref<8x16xf32>, vector<8xf32>
+ %1 = vector.load %arg0[%c0, %c8] : memref<8x16xf32>, vector<8xf32>
+ %2 = scf.if %arg1 -> (vector<8xf32>) {
+ scf.yield %0 : vector<8xf32>
+ } else {
+ scf.yield %1 : vector<8xf32>
+ }
+ return %2 : vector<8xf32>
+}
+
+// CHECK-LABEL: @negative_no_sink_outside_block
+// CHECK: vector.load
+// CHECK-NEXT: vector.load
+// CHECK-NEXT: scf.if
+// CHECK-NEXT: scf.yield
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %0 {
+ transform.apply_patterns.x86vector.sink_vector_producer_ops
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
How does this relate to #163382?
    return failure();

  // Prevent pathological looping:
  // If the next op produces values used by any of op's users, don't move.
I see the problem:
%0 = my_op
%1 = next_op
...
%m = first_user %0, %1
Both %0 and %1 can appear in either order, so if %0 moves past %1 now, then on the next run %1 will match and move back past %0, in a cycle.
However, this is not only true for the next operation, but for any operation between %0 and %m. That makes this a high-degree polynomial-time algorithm, especially with intersection checks on two lists.
What if we just intersect users?
users = all users of all results of "op";
other = op;
while ((other = other->getNextNode())) {
  // If the two ops share a user: O(num_ops * num_users)
  if (llvm::any_of(other->getUsers(),
                   [&](Operation *u) { return llvm::is_contained(users, u); })) {
    // Move "op" just _before_ "other" and stop.
    op->moveBefore(other);
    break;
  }
}
Now, "op" will never pass "other", so if you run again, it will do nothing to "op".
Running the pass multiple times should compact the IR:
// Original
%a = my_op
...
%b = my_other_op
...
%c = my_last_op
...
%val = first_user (%a, %b, %c)
// First pass
...
%a = my_op
...
%b = my_other_op
...
%c = my_last_op
%val = first_user (%a, %b, %c)
// Second pass
...
...
%a = my_op
...
%b = my_other_op
%c = my_last_op
%val = first_user (%a, %b, %c)
// Third pass
...
...
...
%a = my_op
%b = my_other_op
%c = my_last_op
%val = first_user (%a, %b, %c)
Alternatively, you can take the first op (%a), get its first user (%val), and then iterate backwards, moving %c first, then %b, then %a, and you only need one pass to compact the IR. You can even cache the last producer sunk and use it for the moveBefore() call without having to look for it again.
Yes, thanks. If we intersect the current op's users with the next op's users, a few valid patterns can bypass this transformation. For example, in the code below:
Example:
%a = prod1
%b = prod2
%c = prod3
%d = prod4
%c1 = fma %a %c %arg
%c2 = fma %a %d %c1
%c3 = fma %b %c %c2
%c4 = fma %b %d %c3
prod2 users: {c3, c4}; prod3 users: {c1, c3}. If the logic were based on intersection, the rewrite would be:
%b = prod2
%c = prod3
%a = prod1
%c1 = fma %a %c %arg
%d = prod4
%c2 = fma %a %d %c1
%c3 = fma %b %c %c2
%c4 = fma %b %d %c3
prod2 could be moved before %c3, but wasn't, since %c3 is a user of both ops. So instead we check the first users of both the current and the next producer to decide the shift.
Yes, you're right. We only need to check is_contained. Intersection would create false positives.
However, here you're only moving past one operation at a time, so you don't need to check ops that do not participate. That is wasteful, though, as you're recalculating the same user list for every move.
In my example above, each pass moves each producer to the best place in one step, and we'd still need multiple passes to move all producers. In your implementation, the pass runs as many times as there are instructions between a producer and its first consumer, for every producer, recalculating users and scanning those lists over and over on the same operations.
My proposal is to do one forward pass (while (getNextNode())), collecting all producers in program order into a separate ordered list together with the first consumer, and then to operate only on those, possibly in reverse order.
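A minimal sketch of that proposal (a hypothetical helper, not code from this patch), assuming the producers and their common first consumer all sit in one block:

#include "mlir/IR/Operation.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

// Sink a group of producers (given in program order) so they end up packed,
// in their original relative order, immediately before their first consumer.
// Iterating in reverse lets each moveBefore() target the producer sunk last,
// so no extra scanning is needed.
static void sinkProducerGroup(ArrayRef<Operation *> producers,
                              Operation *firstConsumer) {
  Operation *insertPt = firstConsumer;
  for (Operation *producer : llvm::reverse(producers)) {
    producer->moveBefore(insertPt);
    insertPt = producer; // Cached: the next producer moves before this one.
  }
}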
batch_reduce.matmul(input) -> register-tiling(M, N) -> Vectorization (tovector.contract) ->unrollvector.contract (unitdims) ->hoistingtransformation (moveCloads/store outside batch/k loop) -> sink vector producers -> applylicm,canonicalization, andbufferize->vector.contracttofma-> sink vector producers.