Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 23 additions & 13 deletions mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1245,21 +1245,21 @@ struct SliceDimInfo {
OpFoldResult outputSize;
};

/// Return the first input extract slice operand, if present, for the current
/// Return all extract slice operands, if present, for the current
/// generic op.
static FailureOr<OpOperand *> getSliceOperand(GenericOp genericOp) {
OpOperand *sliceOperand = nullptr;
static FailureOr<SmallVector<OpOperand *>>
getSliceOperands(GenericOp genericOp) {
SmallVector<OpOperand *> sliceOperands;
for (auto operand : genericOp.getDpsInputOperands()) {
auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
if (!extractOp)
continue;
sliceOperand = operand;
break;
sliceOperands.push_back(operand);
}
if (!sliceOperand) {
if (sliceOperands.empty()) {
return failure();
}
return sliceOperand;
return sliceOperands;
}

// Return a map of dims that have partial slices on them so that other operands
Expand Down Expand Up @@ -1336,14 +1336,24 @@ pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
genericOp,
"propagation through generic with gather semantics is unsupported.");
// Collect the sliced operand, if present.
auto maybeSliceOperand = getSliceOperand(genericOp);
if (failed(maybeSliceOperand))
auto maybeSliceOperands = getSliceOperands(genericOp);
if (failed(maybeSliceOperands))
return failure();
OpOperand *sliceOperand = *maybeSliceOperand;
unsigned OperandIndex = sliceOperand->getOperandNumber();

if (!controlFn(sliceOperand))
SmallVector<OpOperand *> sliceOperands = *maybeSliceOperands;
OpOperand *sliceOperand;

bool foundValidOperand = false;
for (auto currSliceOperand : sliceOperands) {
if (controlFn(currSliceOperand)) {
sliceOperand = currSliceOperand;
foundValidOperand = true;
break;
}
}
if (!foundValidOperand) {
return failure();
}
unsigned OperandIndex = sliceOperand->getOperandNumber();

tensor::ExtractSliceOp producerSliceOp =
sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
Expand Down
30 changes: 30 additions & 0 deletions mlir/test/Dialect/Linalg/data-layout-propagation.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1577,3 +1577,33 @@ func.func @push_extract_through_generic_rank0_operand(%arg0: tensor<128x128xf32>
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[GENERIC]]
// CHECK: return %[[EXTRACT]]

// -----
// Test that if one extract doesnt pass the control function which in this case is set to
// only allow extracts from the same block, then an extract from a later operand can still be pushed
// down.
func.func @push_extract_through_generic_secondextract(%arg0: tensor<128x128xf32>, %arg1: tensor<?x?xbf16>, %arg2: index) -> tensor<?x?xbf16> {
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%extracted_slice1 = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
%for = scf.for %arg3 = %c0 to %c32 step %arg2 iter_args(%arg4 = %arg1) -> tensor<?x?xbf16> {
%extracted_slice = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
%0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,affine_map<(d0, d1) -> (d0, d1)> ,affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice1, %extracted_slice : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg1 : tensor<?x?xbf16>) {
^bb0(%in: f32, %in_1 : f32, %out: bf16):
%1 = arith.truncf %in : f32 to bf16
linalg.yield %1 : bf16
} -> tensor<?x?xbf16>
scf.yield %0 : tensor<?x?xbf16>
}
return %for : tensor<?x?xbf16>
}

// CHECK-LABEL: func.func @push_extract_through_generic_secondextract
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice
// CHECK: %[[FOR:.+]] = scf.for
// CHECK: %[[PAD:.+]] = tensor.pad %[[EXTRACT]]
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[PAD]], %[[ARG0]]
// CHECK: %[[EXTRACT2:.+]] = tensor.extract_slice %[[GENERIC]]
// CHECK: scf.yield %[[EXTRACT2]]
9 changes: 7 additions & 2 deletions mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,13 @@ struct TestDataLayoutPropagationPass
RewritePatternSet patterns(context);
linalg::populateDataLayoutPropagationPatterns(
patterns, [](OpOperand *opOperand) { return true; });
linalg::populateExtractSliceSinkingPatterns(
patterns, [](OpOperand *opOperand) { return true; });
linalg::ControlPropagationFn controlExtract =
[](OpOperand *opOperand) -> bool {
Operation *producer = opOperand->get().getDefiningOp();
Operation *consumer = opOperand->getOwner();
return consumer->getBlock() == producer->getBlock();
};
linalg::populateExtractSliceSinkingPatterns(patterns, controlExtract);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
Expand Down