diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index ed2efd6fea5f7..6c17c3c2d0cab 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -1245,21 +1245,22 @@ struct SliceDimInfo {
   OpFoldResult outputSize;
 };
 
-/// Return the first input extract slice operand, if present, for the current
+/// Return all input operands defined by an extract slice op for the current
 /// generic op.
-static FailureOr<OpOperand *> getSliceOperand(GenericOp genericOp) {
-  OpOperand *sliceOperand = nullptr;
+/// Fails when no input operand is defined by an extract slice op.
+static FailureOr<SmallVector<OpOperand *>>
+getSliceOperands(GenericOp genericOp) {
+  SmallVector<OpOperand *> sliceOperands;
   for (auto operand : genericOp.getDpsInputOperands()) {
     auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
     if (!extractOp)
       continue;
-    sliceOperand = operand;
-    break;
+    sliceOperands.push_back(operand);
   }
-  if (!sliceOperand) {
+  if (sliceOperands.empty()) {
     return failure();
   }
-  return sliceOperand;
+  return sliceOperands;
 }
 
 // Return a map of dims that have partial slices on them so that other operands
@@ -1336,14 +1337,23 @@ pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
         genericOp,
         "propagation through generic with gather semantics is unsupported.");
   // Collect the sliced operand, if present.
-  auto maybeSliceOperand = getSliceOperand(genericOp);
-  if (failed(maybeSliceOperand))
+  auto maybeSliceOperands = getSliceOperands(genericOp);
+  if (failed(maybeSliceOperands))
     return failure();
-  OpOperand *sliceOperand = *maybeSliceOperand;
-  unsigned OperandIndex = sliceOperand->getOperandNumber();
-
-  if (!controlFn(sliceOperand))
+  SmallVector<OpOperand *> sliceOperands = *maybeSliceOperands;
+
+  // Propagate through the first sliced operand accepted by the control
+  // function; give up if it rejects all of them.
+  OpOperand *sliceOperand = nullptr;
+  for (OpOperand *currSliceOperand : sliceOperands) {
+    if (controlFn(currSliceOperand)) {
+      sliceOperand = currSliceOperand;
+      break;
+    }
+  }
+  if (!sliceOperand)
     return failure();
+  unsigned OperandIndex = sliceOperand->getOperandNumber();
 
   tensor::ExtractSliceOp producerSliceOp =
       sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index fb16e1e7dcda4..a5f8d63a3e912 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1577,3 +1577,33 @@ func.func @push_extract_through_generic_rank0_operand(%arg0: tensor<128x128xf32>
 // CHECK: %[[GENERIC:.+]] = linalg.generic
 // CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[GENERIC]]
 // CHECK: return %[[EXTRACT]]
+
+// -----
+// Test that if one extract doesn't pass the control function, which in this case is set to
+// only allow extracts from the same block, then an extract from a later operand can still be pushed
+// down.
+func.func @push_extract_through_generic_secondextract(%arg0: tensor<128x128xf32>, %arg1: tensor<?x?xbf16>, %arg2: index) -> tensor<?x?xbf16> {
+  %c0 = arith.constant 0 : index
+  %c32 = arith.constant 32 : index
+  %extracted_slice1 = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+  %for = scf.for %arg3 = %c0 to %c32 step %arg2 iter_args(%arg4 = %arg1) -> tensor<?x?xbf16> {
+    %extracted_slice = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+    %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice1, %extracted_slice : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg1 : tensor<?x?xbf16>) {
+    ^bb0(%in: f32, %in_1 : f32, %out: bf16):
+      %1 = arith.truncf %in : f32 to bf16
+      linalg.yield %1 : bf16
+    } -> tensor<?x?xbf16>
+    scf.yield %0 : tensor<?x?xbf16>
+  }
+  return %for : tensor<?x?xbf16>
+}
+
+// CHECK-LABEL: func.func @push_extract_through_generic_secondextract
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice
+// CHECK: %[[FOR:.+]] = scf.for
+// CHECK: %[[PAD:.+]] = tensor.pad %[[EXTRACT]]
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[PAD]], %[[ARG0]]
+// CHECK: %[[EXTRACT2:.+]] = tensor.extract_slice %[[GENERIC]]
+// CHECK: scf.yield %[[EXTRACT2]]
diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
index 2cf25d8fc8c19..d332270468ea8 100644
--- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
@@ -34,8 +34,15 @@ struct TestDataLayoutPropagationPass
     RewritePatternSet patterns(context);
     linalg::populateDataLayoutPropagationPatterns(
         patterns, [](OpOperand *opOperand) { return true; });
-    linalg::populateExtractSliceSinkingPatterns(
-        patterns, [](OpOperand *opOperand) { return true; });
+    // Only sink extract slices whose producer lives in the same block as the
+    // consuming op; operands without a defining op are rejected.
+    linalg::ControlPropagationFn controlExtract =
+        [](OpOperand *opOperand) -> bool {
+      Operation *producer = opOperand->get().getDefiningOp();
+      Operation *consumer = opOperand->getOwner();
+      return producer && consumer->getBlock() == producer->getBlock();
+    };
+    linalg::populateExtractSliceSinkingPatterns(patterns, controlExtract);
     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
       return signalPassFailure();
   }