-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[Linalg] Fix bug in control function logic of push down extract pattern #158348
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-linalg Author: Nirvedh Meshram (nirvedhmeshram) Changes: The current logic just bails out if the first extract producer fails the control function; this PR fixes that. Full diff: https://github.com/llvm/llvm-project/pull/158348.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index ed2efd6fea5f7..6c17c3c2d0cab 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -1245,21 +1245,21 @@ struct SliceDimInfo {
OpFoldResult outputSize;
};
-/// Return the first input extract slice operand, if present, for the current
+/// Return all extract slice operands, if present, for the current
/// generic op.
-static FailureOr<OpOperand *> getSliceOperand(GenericOp genericOp) {
- OpOperand *sliceOperand = nullptr;
+static FailureOr<SmallVector<OpOperand *>>
+getSliceOperands(GenericOp genericOp) {
+ SmallVector<OpOperand *> sliceOperands;
for (auto operand : genericOp.getDpsInputOperands()) {
auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
if (!extractOp)
continue;
- sliceOperand = operand;
- break;
+ sliceOperands.push_back(operand);
}
- if (!sliceOperand) {
+ if (sliceOperands.empty()) {
return failure();
}
- return sliceOperand;
+ return sliceOperands;
}
// Return a map of dims that have partial slices on them so that other operands
@@ -1336,14 +1336,24 @@ pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
genericOp,
"propagation through generic with gather semantics is unsupported.");
// Collect the sliced operand, if present.
- auto maybeSliceOperand = getSliceOperand(genericOp);
- if (failed(maybeSliceOperand))
+ auto maybeSliceOperands = getSliceOperands(genericOp);
+ if (failed(maybeSliceOperands))
return failure();
- OpOperand *sliceOperand = *maybeSliceOperand;
- unsigned OperandIndex = sliceOperand->getOperandNumber();
-
- if (!controlFn(sliceOperand))
+ SmallVector<OpOperand *> sliceOperands = *maybeSliceOperands;
+ OpOperand *sliceOperand;
+
+ bool foundValidOperand = false;
+ for (auto currSliceOperand : sliceOperands) {
+ if (controlFn(currSliceOperand)) {
+ sliceOperand = currSliceOperand;
+ foundValidOperand = true;
+ break;
+ }
+ }
+ if (!foundValidOperand) {
return failure();
+ }
+ unsigned OperandIndex = sliceOperand->getOperandNumber();
tensor::ExtractSliceOp producerSliceOp =
sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index fb16e1e7dcda4..a5f8d63a3e912 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1577,3 +1577,33 @@ func.func @push_extract_through_generic_rank0_operand(%arg0: tensor<128x128xf32>
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[GENERIC]]
// CHECK: return %[[EXTRACT]]
+
+// -----
+// Test that if one extract doesnt pass the control function which in this case is set to
+// only allow extracts from the same block, then an extract from a later operand can still be pushed
+// down.
+func.func @push_extract_through_generic_secondextract(%arg0: tensor<128x128xf32>, %arg1: tensor<?x?xbf16>, %arg2: index) -> tensor<?x?xbf16> {
+ %c0 = arith.constant 0 : index
+ %c32 = arith.constant 32 : index
+ %extracted_slice1 = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+ %for = scf.for %arg3 = %c0 to %c32 step %arg2 iter_args(%arg4 = %arg1) -> tensor<?x?xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+ %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,affine_map<(d0, d1) -> (d0, d1)> ,affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice1, %extracted_slice : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg1 : tensor<?x?xbf16>) {
+ ^bb0(%in: f32, %in_1 : f32, %out: bf16):
+ %1 = arith.truncf %in : f32 to bf16
+ linalg.yield %1 : bf16
+ } -> tensor<?x?xbf16>
+ scf.yield %0 : tensor<?x?xbf16>
+ }
+ return %for : tensor<?x?xbf16>
+}
+
+// CHECK-LABEL: func.func @push_extract_through_generic_secondextract
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice
+// CHECK: %[[FOR:.+]] = scf.for
+// CHECK: %[[PAD:.+]] = tensor.pad %[[EXTRACT]]
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[PAD]], %[[ARG0]]
+// CHECK: %[[EXTRACT2:.+]] = tensor.extract_slice %[[GENERIC]]
+// CHECK: scf.yield %[[EXTRACT2]]
diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
index 2cf25d8fc8c19..d332270468ea8 100644
--- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
@@ -34,8 +34,13 @@ struct TestDataLayoutPropagationPass
RewritePatternSet patterns(context);
linalg::populateDataLayoutPropagationPatterns(
patterns, [](OpOperand *opOperand) { return true; });
- linalg::populateExtractSliceSinkingPatterns(
- patterns, [](OpOperand *opOperand) { return true; });
+ linalg::ControlPropagationFn controlExtract =
+ [](OpOperand *opOperand) -> bool {
+ Operation *producer = opOperand->get().getDefiningOp();
+ Operation *consumer = opOperand->getOwner();
+ return consumer->getBlock() == producer->getBlock();
+ };
+ linalg::populateExtractSliceSinkingPatterns(patterns, controlExtract);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
|
@llvm/pr-subscribers-mlir Author: Nirvedh Meshram (nirvedhmeshram) Changes: The current logic just bails out if the first extract producer fails the control function; this PR fixes that. Full diff: https://github.com/llvm/llvm-project/pull/158348.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index ed2efd6fea5f7..6c17c3c2d0cab 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -1245,21 +1245,21 @@ struct SliceDimInfo {
OpFoldResult outputSize;
};
-/// Return the first input extract slice operand, if present, for the current
+/// Return all extract slice operands, if present, for the current
/// generic op.
-static FailureOr<OpOperand *> getSliceOperand(GenericOp genericOp) {
- OpOperand *sliceOperand = nullptr;
+static FailureOr<SmallVector<OpOperand *>>
+getSliceOperands(GenericOp genericOp) {
+ SmallVector<OpOperand *> sliceOperands;
for (auto operand : genericOp.getDpsInputOperands()) {
auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
if (!extractOp)
continue;
- sliceOperand = operand;
- break;
+ sliceOperands.push_back(operand);
}
- if (!sliceOperand) {
+ if (sliceOperands.empty()) {
return failure();
}
- return sliceOperand;
+ return sliceOperands;
}
// Return a map of dims that have partial slices on them so that other operands
@@ -1336,14 +1336,24 @@ pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
genericOp,
"propagation through generic with gather semantics is unsupported.");
// Collect the sliced operand, if present.
- auto maybeSliceOperand = getSliceOperand(genericOp);
- if (failed(maybeSliceOperand))
+ auto maybeSliceOperands = getSliceOperands(genericOp);
+ if (failed(maybeSliceOperands))
return failure();
- OpOperand *sliceOperand = *maybeSliceOperand;
- unsigned OperandIndex = sliceOperand->getOperandNumber();
-
- if (!controlFn(sliceOperand))
+ SmallVector<OpOperand *> sliceOperands = *maybeSliceOperands;
+ OpOperand *sliceOperand;
+
+ bool foundValidOperand = false;
+ for (auto currSliceOperand : sliceOperands) {
+ if (controlFn(currSliceOperand)) {
+ sliceOperand = currSliceOperand;
+ foundValidOperand = true;
+ break;
+ }
+ }
+ if (!foundValidOperand) {
return failure();
+ }
+ unsigned OperandIndex = sliceOperand->getOperandNumber();
tensor::ExtractSliceOp producerSliceOp =
sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index fb16e1e7dcda4..a5f8d63a3e912 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1577,3 +1577,33 @@ func.func @push_extract_through_generic_rank0_operand(%arg0: tensor<128x128xf32>
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[GENERIC]]
// CHECK: return %[[EXTRACT]]
+
+// -----
+// Test that if one extract doesnt pass the control function which in this case is set to
+// only allow extracts from the same block, then an extract from a later operand can still be pushed
+// down.
+func.func @push_extract_through_generic_secondextract(%arg0: tensor<128x128xf32>, %arg1: tensor<?x?xbf16>, %arg2: index) -> tensor<?x?xbf16> {
+ %c0 = arith.constant 0 : index
+ %c32 = arith.constant 32 : index
+ %extracted_slice1 = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+ %for = scf.for %arg3 = %c0 to %c32 step %arg2 iter_args(%arg4 = %arg1) -> tensor<?x?xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+ %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,affine_map<(d0, d1) -> (d0, d1)> ,affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice1, %extracted_slice : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg1 : tensor<?x?xbf16>) {
+ ^bb0(%in: f32, %in_1 : f32, %out: bf16):
+ %1 = arith.truncf %in : f32 to bf16
+ linalg.yield %1 : bf16
+ } -> tensor<?x?xbf16>
+ scf.yield %0 : tensor<?x?xbf16>
+ }
+ return %for : tensor<?x?xbf16>
+}
+
+// CHECK-LABEL: func.func @push_extract_through_generic_secondextract
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice
+// CHECK: %[[FOR:.+]] = scf.for
+// CHECK: %[[PAD:.+]] = tensor.pad %[[EXTRACT]]
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[PAD]], %[[ARG0]]
+// CHECK: %[[EXTRACT2:.+]] = tensor.extract_slice %[[GENERIC]]
+// CHECK: scf.yield %[[EXTRACT2]]
diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
index 2cf25d8fc8c19..d332270468ea8 100644
--- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
@@ -34,8 +34,13 @@ struct TestDataLayoutPropagationPass
RewritePatternSet patterns(context);
linalg::populateDataLayoutPropagationPatterns(
patterns, [](OpOperand *opOperand) { return true; });
- linalg::populateExtractSliceSinkingPatterns(
- patterns, [](OpOperand *opOperand) { return true; });
+ linalg::ControlPropagationFn controlExtract =
+ [](OpOperand *opOperand) -> bool {
+ Operation *producer = opOperand->get().getDefiningOp();
+ Operation *consumer = opOperand->getOwner();
+ return consumer->getBlock() == producer->getBlock();
+ };
+ linalg::populateExtractSliceSinkingPatterns(patterns, controlExtract);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
|
Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
The current logic just bails out if the first extract producer fails the control function; this PR fixes that.