Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
Expand Down Expand Up @@ -1316,7 +1317,15 @@ getUntiledProducerFromSliceSource(OpOperand *source,
ArrayRef<LoopLikeOpInterface> loops) {
std::optional<OpOperand *> destinationIterArg;
assert(!loops.empty() && "expected non empty loops container");

// The `extractOp` may not reside within the innermost loop. Compute the
// distance between its parent loop and the last LoopLikeOpInterface in
// `loops`; adding this distance to `loopIt` positions the iterator at the
// loop that actually contains the slice.
auto loopIt = loops.rbegin();
auto parentLoop = source->getOwner()->getParentOfType<LoopLikeOpInterface>();
const LoopLikeOpInterface *it = llvm::find(loops, parentLoop);
int64_t distance = std::distance(loops.begin(), it);
loopIt += (loops.size() - distance - 1);
while (loopIt != loops.rend() && isa<BlockArgument>(source->get())) {
auto iterArg = cast<BlockArgument>(source->get());
auto loop = *loopIt;
Expand Down Expand Up @@ -1347,7 +1356,6 @@ mlir::scf::tileAndFuseProducerOfSlice(

OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(candidateSliceOp);

// 2. Clone the fused producer
// 2a. Compute the destination operands to use for the cloned operation.
SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors;
Expand Down Expand Up @@ -1750,6 +1758,15 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
replacements};
}

// The extract_slice op is created in the innermost loop by default. Running
// `moveLoopInvariantCode` and `hoistLoopInvariantSubsets` hoists each
// extract_slice to the outermost loop where it is invariant, so the fused op
// is created in the correct loop.
for (LoopLikeOpInterface loop : loops)
(void)moveLoopInvariantCode(loop);
for (LoopLikeOpInterface loop : loops)
(void)hoistLoopInvariantSubsets(rewriter, loop);

// Since the loop gets potentially replaced during fusion, we need to track
// the mutation of replacement values. To do this, we attach a listener to
// update the replacements as they happen.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[INIT:.+]] = tensor.empty
// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] =
// CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT]])
// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] =
// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV0]], %[[IV1]]]
// CHECK: %[[FILL_TILE:.+]] = linalg.fill
Expand Down Expand Up @@ -141,6 +141,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[INIT0:.+]] = tensor.empty(%[[D0]], %[[D1]])
// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[RHS1]], %[[C1]]
// CHECK: %[[INIT1:.+]] = tensor.empty(%[[D0]], %[[D2]])
// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
// CHECK: scf.for %[[IV:[a-zA-Z0-9]+]] =
// CHECK-SAME: iter_args(%[[ITERARG:.+]] = %[[INIT1]])
// CHECK-DAG: %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
Expand All @@ -151,7 +152,6 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[GEMM0_TILE:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
// CHECK-SAME: outs(%[[FILL0_TILE]] :
// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG]][%[[IV]], 0]
// CHECK: %[[FILL1_TILE:.+]] = linalg.fill
// CHECK-SAME: outs(%[[INIT1_TILE]] :
Expand Down Expand Up @@ -444,6 +444,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ORIG_GEMM2]], %[[C0]]
// CHECK-DAG: %[[N2:.+]] = tensor.dim %[[ORIG_GEMM2]], %[[C1]]
// CHECK-DAG: %[[N3:.+]] = tensor.dim %[[ARG5]], %[[C1]]
// CHECK-DAG: %[[SLICE_ARG5:.+]] = tensor.extract_slice %[[ARG5]][0, 0] [%[[N2]], %[[N3]]]
// CHECK: %[[R0:.+]] = scf.for %[[IV:[a-zA-Z0-9_]+]] =
// CHECK-SAME: iter_args(%[[ARG8:.+]] = %[[ARG6]]) -> (tensor<?x?xf32>) {
// CHECK-DAG: %[[N1:.+]] = tensor.dim %[[ORIG_GEMM1]], %[[C1]]
Expand All @@ -458,7 +459,6 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[SLICE_ARG4:.+]] = tensor.extract_slice %[[ARG4]][%[[IV]], 0] [%[[TILE_M]], %[[N2]]]
// CHECK-DAG: %[[TILE_GEMM2:.+]] = linalg.matmul ins(%[[TILE_GEMM1]], %[[SLICE_ARG3]] :
// CHECK-SAME: outs(%[[SLICE_ARG4]] :
// CHECK-DAG: %[[SLICE_ARG5:.+]] = tensor.extract_slice %[[ARG5]][0, 0] [%[[N2]], %[[N3]]]
// CHECK-DAG: %[[SLICE_ARG6:.+]] = tensor.extract_slice %[[ARG8]][%[[IV]], 0] [%[[TILE_M]], %[[N3]]]
// CHECK-DAG: %[[TILE_GEMM3:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[TILE_GEMM2]], %[[SLICE_ARG5]] :
Expand Down Expand Up @@ -688,3 +688,44 @@ module attributes {transform.with_named_sequence} {
// CHECK: }
// CHECK: }

// -----

func.func @pooling_ncw_max_fill_fuse(%input: tensor<?x?x?xf32>, %fake: tensor<?xf32>, %init: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
  // Zero-fill the accumulator, then run a unit-stride, unit-dilation
  // max-pooling over the NCW-layout input.
  %zero = arith.constant 0.000000e+00 : f32
  %filled = linalg.fill ins(%zero : f32) outs(%init : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
  %pooled = linalg.pooling_ncw_max {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
    ins(%input, %fake : tensor<?x?x?xf32>, tensor<?xf32>)
    outs(%filled : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
  return %pooled : tensor<?x?x?xf32>
}

module attributes {transform.with_named_sequence} {
  // Tile the max-pooling op with sizes [1, 16, 1, 1] and fuse its producer
  // (the linalg.fill) into the resulting loop nest.
  transform.named_sequence @__transform_main(
      %root: !transform.any_op {transform.readonly}) {
    %pool = transform.structured.match ops{["linalg.pooling_ncw_max"]} in %root : (!transform.any_op) -> !transform.any_op
    %tiled, %loop0, %loop1, %loop2, %loop3 = transform.structured.fuse %pool tile_sizes [1, 16, 1, 1]
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}

// CHECK-LABEL: func.func @pooling_ncw_max_fill_fuse(
// CHECK-SAME: %[[INPUT:.*]]: tensor<?x?x?xf32>,
// CHECK-SAME: %[[FAKE:.*]]: tensor<?xf32>,
// CHECK-SAME: %[[INIT:.*]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] =
// CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT]])
// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] =
// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
// CHECK: scf.for %[[IV2:[a-zA-Z0-9]+]] =
// CHECK-SAME: iter_args(%[[ITERARG2:.+]] = %[[ITERARG1]])
// CHECK: %[[FILL_EXTRACT:.*]] = tensor.extract_slice %[[ITERARG2]]{{\[}}%[[IV0]], %[[IV1]], %[[IV2]]]
// CHECK: %[[TILED_FILL:.*]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[FILL_EXTRACT]] : tensor<1x?x1xf32>) -> tensor<1x?x1xf32>
// CHECK: scf.for %[[IV3:[a-zA-Z0-9]+]] =
// CHECK-SAME: iter_args(%[[ITERARG3:.*]] = %[[ITERARG2]], %[[ITERARG4:.*]] = %[[TILED_FILL]])
// CHECK: %[[TILED_INPUT:.*]] = tensor.extract_slice %[[INPUT]]{{\[}}%[[IV0]], %[[IV1]]
// CHECK: %[[TILED_FAKE:.*]] = tensor.extract_slice %[[FAKE]]{{\[}}%[[IV3]]]
// CHECK: linalg.pooling_ncw_max
// CHECK-SAME: ins(%[[TILED_INPUT]], %[[TILED_FAKE]] :
// CHECK-SAME: outs(%[[ITERARG4]] :
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
// CHECK: %[[RESULT:.+]]:2 = scf.for %[[IV:[a-zA-Z0-9]+]] =
// CHECK-SAME: iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT1]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]])
// CHECK-DAG: %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
Expand All @@ -47,7 +48,6 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[GEMM0_TILE:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
// CHECK-SAME: outs(%[[FILL0_TILE]] :
// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
// CHECK: %[[FILL1_TILE:.+]] = linalg.fill
// CHECK-SAME: outs(%[[INIT1_TILE]] :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
// CHECK: %[[RESULT:.+]]:2 = scf.forall (%[[IV:[a-zA-Z0-9]+]]) =
// CHECK-SAME: shared_outs(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT1]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]])
// CHECK-DAG: %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
Expand All @@ -47,7 +48,6 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[GEMM0_TILE:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
// CHECK-SAME: outs(%[[FILL0_TILE]] :
// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
// CHECK: %[[FILL1_TILE:.+]] = linalg.fill
// CHECK-SAME: outs(%[[INIT1_TILE]] :
Expand Down