Skip to content

Conversation

nirvedhmeshram
Copy link
Contributor

Current logic just bails out if the first extract producer fails the control function, this PR fixes that.

@llvmbot
Copy link
Member

llvmbot commented Sep 12, 2025

@llvm/pr-subscribers-mlir-linalg

Author: Nirvedh Meshram (nirvedhmeshram)

Changes

Current logic just bails out if the first extract producer fails the control function, this PR fixes that.


Full diff: https://github.com/llvm/llvm-project/pull/158348.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp (+23-13)
  • (modified) mlir/test/Dialect/Linalg/data-layout-propagation.mlir (+30)
  • (modified) mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp (+7-2)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index ed2efd6fea5f7..6c17c3c2d0cab 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -1245,21 +1245,21 @@ struct SliceDimInfo {
   OpFoldResult outputSize;
 };
 
-/// Return the first input extract slice operand, if present, for the current
+/// Return all extract slice operands, if present, for the current
 /// generic op.
-static FailureOr<OpOperand *> getSliceOperand(GenericOp genericOp) {
-  OpOperand *sliceOperand = nullptr;
+static FailureOr<SmallVector<OpOperand *>>
+getSliceOperands(GenericOp genericOp) {
+  SmallVector<OpOperand *> sliceOperands;
   for (auto operand : genericOp.getDpsInputOperands()) {
     auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
     if (!extractOp)
       continue;
-    sliceOperand = operand;
-    break;
+    sliceOperands.push_back(operand);
   }
-  if (!sliceOperand) {
+  if (sliceOperands.empty()) {
     return failure();
   }
-  return sliceOperand;
+  return sliceOperands;
 }
 
 // Return a map of dims that have partial slices on them so that other operands
@@ -1336,14 +1336,24 @@ pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
         genericOp,
         "propagation through generic with gather semantics is unsupported.");
   // Collect the sliced operand, if present.
-  auto maybeSliceOperand = getSliceOperand(genericOp);
-  if (failed(maybeSliceOperand))
+  auto maybeSliceOperands = getSliceOperands(genericOp);
+  if (failed(maybeSliceOperands))
     return failure();
-  OpOperand *sliceOperand = *maybeSliceOperand;
-  unsigned OperandIndex = sliceOperand->getOperandNumber();
-
-  if (!controlFn(sliceOperand))
+  SmallVector<OpOperand *> sliceOperands = *maybeSliceOperands;
+  OpOperand *sliceOperand;
+
+  bool foundValidOperand = false;
+  for (auto currSliceOperand : sliceOperands) {
+    if (controlFn(currSliceOperand)) {
+      sliceOperand = currSliceOperand;
+      foundValidOperand = true;
+      break;
+    }
+  }
+  if (!foundValidOperand) {
     return failure();
+  }
+  unsigned OperandIndex = sliceOperand->getOperandNumber();
 
   tensor::ExtractSliceOp producerSliceOp =
       sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index fb16e1e7dcda4..a5f8d63a3e912 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1577,3 +1577,33 @@ func.func @push_extract_through_generic_rank0_operand(%arg0: tensor<128x128xf32>
 // CHECK:         %[[GENERIC:.+]] = linalg.generic
 // CHECK:         %[[EXTRACT:.+]] = tensor.extract_slice %[[GENERIC]]         
 // CHECK:         return %[[EXTRACT]]
+
+// -----
+// Test that if one extract doesnt pass the control function which in this case is set to
+// only allow extracts from the same block, then an extract from a later operand can still be pushed
+// down.
+func.func @push_extract_through_generic_secondextract(%arg0: tensor<128x128xf32>, %arg1: tensor<?x?xbf16>, %arg2: index) -> tensor<?x?xbf16> {
+  %c0 = arith.constant 0 : index
+  %c32 = arith.constant 32 : index
+  %extracted_slice1 = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+  %for = scf.for %arg3 = %c0 to %c32 step %arg2 iter_args(%arg4 = %arg1) -> tensor<?x?xbf16> {
+    %extracted_slice = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+    %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,affine_map<(d0, d1) -> (d0, d1)> ,affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice1, %extracted_slice : tensor<?x?xf32>,  tensor<?x?xf32>) outs(%arg1 : tensor<?x?xbf16>) {
+    ^bb0(%in: f32, %in_1 : f32, %out: bf16):
+      %1 = arith.truncf %in : f32 to bf16
+      linalg.yield %1 : bf16
+    } -> tensor<?x?xbf16>
+    scf.yield %0 : tensor<?x?xbf16>
+  }
+ return %for : tensor<?x?xbf16>
+}
+
+// CHECK-LABEL: func.func @push_extract_through_generic_secondextract
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[EXTRACT:.+]] = tensor.extract_slice
+// CHECK:         %[[FOR:.+]] = scf.for
+// CHECK:           %[[PAD:.+]] = tensor.pad %[[EXTRACT]]
+// CHECK:           %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:        ins(%[[PAD]], %[[ARG0]]
+// CHECK:           %[[EXTRACT2:.+]] =  tensor.extract_slice %[[GENERIC]]
+// CHECK:           scf.yield %[[EXTRACT2]]
diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
index 2cf25d8fc8c19..d332270468ea8 100644
--- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
@@ -34,8 +34,13 @@ struct TestDataLayoutPropagationPass
     RewritePatternSet patterns(context);
     linalg::populateDataLayoutPropagationPatterns(
         patterns, [](OpOperand *opOperand) { return true; });
-    linalg::populateExtractSliceSinkingPatterns(
-        patterns, [](OpOperand *opOperand) { return true; });
+    linalg::ControlPropagationFn controlExtract =
+        [](OpOperand *opOperand) -> bool {
+      Operation *producer = opOperand->get().getDefiningOp();
+      Operation *consumer = opOperand->getOwner();
+      return consumer->getBlock() == producer->getBlock();
+    };
+    linalg::populateExtractSliceSinkingPatterns(patterns, controlExtract);
     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
       return signalPassFailure();
   }

@llvmbot
Copy link
Member

llvmbot commented Sep 12, 2025

@llvm/pr-subscribers-mlir

Author: Nirvedh Meshram (nirvedhmeshram)

Changes

Current logic just bails out if the first extract producer fails the control function, this PR fixes that.


Full diff: https://github.com/llvm/llvm-project/pull/158348.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp (+23-13)
  • (modified) mlir/test/Dialect/Linalg/data-layout-propagation.mlir (+30)
  • (modified) mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp (+7-2)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index ed2efd6fea5f7..6c17c3c2d0cab 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -1245,21 +1245,21 @@ struct SliceDimInfo {
   OpFoldResult outputSize;
 };
 
-/// Return the first input extract slice operand, if present, for the current
+/// Return all extract slice operands, if present, for the current
 /// generic op.
-static FailureOr<OpOperand *> getSliceOperand(GenericOp genericOp) {
-  OpOperand *sliceOperand = nullptr;
+static FailureOr<SmallVector<OpOperand *>>
+getSliceOperands(GenericOp genericOp) {
+  SmallVector<OpOperand *> sliceOperands;
   for (auto operand : genericOp.getDpsInputOperands()) {
     auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
     if (!extractOp)
       continue;
-    sliceOperand = operand;
-    break;
+    sliceOperands.push_back(operand);
   }
-  if (!sliceOperand) {
+  if (sliceOperands.empty()) {
     return failure();
   }
-  return sliceOperand;
+  return sliceOperands;
 }
 
 // Return a map of dims that have partial slices on them so that other operands
@@ -1336,14 +1336,24 @@ pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
         genericOp,
         "propagation through generic with gather semantics is unsupported.");
   // Collect the sliced operand, if present.
-  auto maybeSliceOperand = getSliceOperand(genericOp);
-  if (failed(maybeSliceOperand))
+  auto maybeSliceOperands = getSliceOperands(genericOp);
+  if (failed(maybeSliceOperands))
     return failure();
-  OpOperand *sliceOperand = *maybeSliceOperand;
-  unsigned OperandIndex = sliceOperand->getOperandNumber();
-
-  if (!controlFn(sliceOperand))
+  SmallVector<OpOperand *> sliceOperands = *maybeSliceOperands;
+  OpOperand *sliceOperand;
+
+  bool foundValidOperand = false;
+  for (auto currSliceOperand : sliceOperands) {
+    if (controlFn(currSliceOperand)) {
+      sliceOperand = currSliceOperand;
+      foundValidOperand = true;
+      break;
+    }
+  }
+  if (!foundValidOperand) {
     return failure();
+  }
+  unsigned OperandIndex = sliceOperand->getOperandNumber();
 
   tensor::ExtractSliceOp producerSliceOp =
       sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index fb16e1e7dcda4..a5f8d63a3e912 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1577,3 +1577,33 @@ func.func @push_extract_through_generic_rank0_operand(%arg0: tensor<128x128xf32>
 // CHECK:         %[[GENERIC:.+]] = linalg.generic
 // CHECK:         %[[EXTRACT:.+]] = tensor.extract_slice %[[GENERIC]]         
 // CHECK:         return %[[EXTRACT]]
+
+// -----
+// Test that if one extract doesnt pass the control function which in this case is set to
+// only allow extracts from the same block, then an extract from a later operand can still be pushed
+// down.
+func.func @push_extract_through_generic_secondextract(%arg0: tensor<128x128xf32>, %arg1: tensor<?x?xbf16>, %arg2: index) -> tensor<?x?xbf16> {
+  %c0 = arith.constant 0 : index
+  %c32 = arith.constant 32 : index
+  %extracted_slice1 = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+  %for = scf.for %arg3 = %c0 to %c32 step %arg2 iter_args(%arg4 = %arg1) -> tensor<?x?xbf16> {
+    %extracted_slice = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+    %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,affine_map<(d0, d1) -> (d0, d1)> ,affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice1, %extracted_slice : tensor<?x?xf32>,  tensor<?x?xf32>) outs(%arg1 : tensor<?x?xbf16>) {
+    ^bb0(%in: f32, %in_1 : f32, %out: bf16):
+      %1 = arith.truncf %in : f32 to bf16
+      linalg.yield %1 : bf16
+    } -> tensor<?x?xbf16>
+    scf.yield %0 : tensor<?x?xbf16>
+  }
+ return %for : tensor<?x?xbf16>
+}
+
+// CHECK-LABEL: func.func @push_extract_through_generic_secondextract
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[EXTRACT:.+]] = tensor.extract_slice
+// CHECK:         %[[FOR:.+]] = scf.for
+// CHECK:           %[[PAD:.+]] = tensor.pad %[[EXTRACT]]
+// CHECK:           %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:        ins(%[[PAD]], %[[ARG0]]
+// CHECK:           %[[EXTRACT2:.+]] =  tensor.extract_slice %[[GENERIC]]
+// CHECK:           scf.yield %[[EXTRACT2]]
diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
index 2cf25d8fc8c19..d332270468ea8 100644
--- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
@@ -34,8 +34,13 @@ struct TestDataLayoutPropagationPass
     RewritePatternSet patterns(context);
     linalg::populateDataLayoutPropagationPatterns(
         patterns, [](OpOperand *opOperand) { return true; });
-    linalg::populateExtractSliceSinkingPatterns(
-        patterns, [](OpOperand *opOperand) { return true; });
+    linalg::ControlPropagationFn controlExtract =
+        [](OpOperand *opOperand) -> bool {
+      Operation *producer = opOperand->get().getDefiningOp();
+      Operation *consumer = opOperand->getOwner();
+      return consumer->getBlock() == producer->getBlock();
+    };
+    linalg::populateExtractSliceSinkingPatterns(patterns, controlExtract);
     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
       return signalPassFailure();
   }

Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

@nirvedhmeshram nirvedhmeshram merged commit 9ac1f34 into llvm:main Sep 12, 2025
12 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants