From 640cd804bd655fe1aabc8c6e78e4e8a3f0529dfa Mon Sep 17 00:00:00 2001 From: MaheshRavishankar Date: Mon, 10 Nov 2025 21:42:24 -0800 Subject: [PATCH 1/3] [mlir][SCF] Add `scf::tileAndFuseConsumer` that tiles a consumer into a given tiled loop nest. The existing `scf::tileAndFuseConsumerOfSlices` takes a list of slices (and the loops they are part of), tries to find the consumer of these slices (all slices are expected to feed the same consumer), and then tiles the consumer into the loop nest using the `TilingInterface`. A more natural way of doing consumer fusion is to start from the consumer and look for operands that are produced by the loop nest passed in as `loops` (presumably these loops are generated by tiling, but that is not a requirement for consumer fusion). Starting from the consumer, you can find the slices of its operands that are accessed within the loop, which you can then use to tile and fuse the consumer (using the `TilingInterface`). This handles more naturally the case where multiple operands of the consumer come from the loop nest. `scf::tileAndFuseConsumerOfSlices` was implemented as a mirror of `scf::tileAndFuseProducerOfSlice`. For the latter, the source of the slice has a single producer, which makes the slice a natural way of specifying producer fusion. But for consumers, a result might have multiple users, resulting in multiple candidates for fusion, as well as a fusion candidate using multiple results from the tiled loop nest. This means using slices (`tensor.insert_slice`/`tensor.parallel_insert_slice`) as a hook for consumer fusion turns out to be quite hard to navigate. Using the consumer directly avoids all those pain points. In time `scf::tileAndFuseConsumerOfSlices` should be deprecated in favor of `scf::tileAndFuseConsumer`. A lot of tech debt has accumulated in `scf::tileAndFuseConsumerOfSlices` that needs to be cleaned up. While that gets cleaned up and the required functionality is moved to `scf::tileAndFuseConsumer`, the old path is still maintained. The tests for `scf::tileAndFuseConsumerOfSlices` are copied from `tile-and-fuse-consumer.mlir` to `tile-and-fuse-consumer-using-slices.mlir`. All the tests that remain in `tile-and-fuse-consumer.mlir` now use the `tileAndFuseConsumer` method. The test op `transform.test.fuse_consumer` is modified to call `scf::tileAndFuseConsumer`, while a new op `transform.test.fuse_consumer_using_slice` keeps the old path tested while it is deprecated.
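For illustration, a minimal sketch of how the new entry point is intended to be driven (the surrounding driver code and variable names are hypothetical; only `scf::tileAndFuseConsumer` and `scf::SCFFuseConsumerOfSliceResult` come from this patch):

    // Assume `rewriter` is a RewriterBase, `tilingResult` is the
    // scf::SCFTilingResult from tiling a producer, and `consumer` is an
    // untiled operation that uses results of the tiled loop nest.
    SmallVector<LoopLikeOpInterface> loops = tilingResult.loops;
    FailureOr<scf::SCFFuseConsumerOfSliceResult> fused =
        scf::tileAndFuseConsumer(rewriter, consumer, loops);
    if (failed(fused))
      return failure();
    // On success, `fused->tiledOps` holds the tiled-and-fused consumer(s).
    // If no operand of `consumer` is produced by the loop nest, the result
    // is empty and nothing is fused.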
Signed-off-by: MaheshRavishankar --- mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 5 + .../SCF/Transforms/TileUsingInterface.h | 12 + .../SCF/Transforms/TileUsingInterface.cpp | 221 +++- .../transform-tile-and-fuse-pack-unpack.mlir | 4 +- .../tile-and-fuse-consumer-using-slices.mlir | 1156 +++++++++++++++++ .../tile-and-fuse-consumer.mlir | 380 +++--- .../TestTilingInterfaceTransformOps.cpp | 79 +- .../TestTilingInterfaceTransformOps.td | 24 +- 8 files changed, 1630 insertions(+), 251 deletions(-) create mode 100644 mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index cd033c140a233..8bdf3e0b566ef 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -613,6 +613,11 @@ def ForallOp : SCF_Op<"forall", [ getNumDynamicControlOperands() + getRank()); } + BlockArgument getTiedBlockArgument(OpResult opResult) { + assert(opResult.getDefiningOp() == getOperation() && "invalid OpResult"); + return getBody()->getArgument(getRank() + opResult.getResultNumber()); + } + ::mlir::Value getInductionVar(int64_t idx) { return getInductionVars()[idx]; } diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index 7c735d825b445..0005fad3d5c01 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -415,6 +415,10 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter, /// tiled in a manner that is consistent for all the passed slices. Note that /// the method replaces the uses of `candidateSlices` with the tiled and fused /// consumer value but does not delete the slice operations. +/// TODO(MaheshRavishankar): A more natural way of exposing the consumer fusion +/// is to take the consumer operation, and find the slices to use for fusion +/// by walking its operands to the `loops` and then into the body to get the +/// slices used for fusion. struct SCFFuseConsumerOfSliceResult { // Original untiled consumer operands. SmallVector<OpOperand *> origConsumerOperands; @@ -427,6 +431,14 @@ tileAndFuseConsumerOfSlices(RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices, MutableArrayRef<LoopLikeOpInterface> loops); +/// Fuse the `consumer` operation into the loop nest provided by `loops`. +/// The transformation looks for operands in the `consumer` that are defined +/// by the outermost loop of the loop nest in `loops`. The nested loop is +/// expected to have the structure of the loops generated through tiling. +FailureOr<SCFFuseConsumerOfSliceResult> +tileAndFuseConsumer(RewriterBase &rewriter, Operation *consumer, + MutableArrayRef<LoopLikeOpInterface> loops); + /// Method to lower an `op` that implements the `TilingInterface` to /// loops/scalars. FailureOr<SmallVector<scf::ForOp>> diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 29b770fb4b279..7e715ee189740 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -1092,7 +1092,7 @@ static LogicalResult addInitOperandsToLoopNest( for (auto [outerLoop, innerLoop] : llvm::zip_equal(loops.drop_back(), loops.drop_front())) { // Again assume that all the outer loops are scf.for operations.
- auto outerForLoop = cast<scf::ForOp>(outerLoop); + auto outerForLoop = cast<scf::ForOp>(outerLoop.getOperation()); auto outerLoopYield = cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator()); SmallVector<Value> newYields = @@ -2184,61 +2184,24 @@ cloneAsInsertSlices(RewriterBase &rewriter, return clonedSlices; } -/// Implementation of fusing consumer of a single slice by computing the -/// slice of the consumer in-place for scf loop. -FailureOr<scf::SCFFuseConsumerOfSliceResult> -mlir::scf::tileAndFuseConsumerOfSlices( - RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices, - MutableArrayRef<LoopLikeOpInterface> loops) { - if (candidateSlices.empty()) { - return rewriter.notifyMatchFailure( - rewriter.getUnknownLoc(), - "no candidate slices provided for consumer fusion"); - } - // Return if `loops` is empty, return an error for now. Caller is expected - // to handle this case. - if (loops.empty()) { - return rewriter.notifyMatchFailure( - candidateSlices.front(), - "cannot call tile and fuse consumer with an empty loop nest"); - } +static FailureOr<scf::SCFFuseConsumerOfSliceResult> +tileAndFuseConsumerOfSlicesImpl(RewriterBase &rewriter, Operation *consumerOp, + ArrayRef<OpOperand *> consumerOpOperands, + ArrayRef<Operation *> candidateSlices, + MutableArrayRef<LoopLikeOpInterface> loops) { + assert(!loops.empty() && "expected loops to be not empty"); - if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) || - llvm::all_of(candidateSlices, - llvm::IsaPred<tensor::ParallelInsertSliceOp>))) { + // 1. Check assumption for loop with `reorderOperations` disabled. + if (failed(checkAssumptionForLoop(loops.front(), consumerOp, false))) { return rewriter.notifyMatchFailure( - candidateSlices.front(), - "candidates slices need to be all `tensor.extract_slice`s or " - "`tensor.parallel_insert_slice`s"); - } - - // 1. Get the consumer of scf.for for the result yielded by - // tensor.insert_slice/parallel_insert_slice. - SmallVector<OpOperand *> consumerOpOperands; - Operation *consumerOp; - { - FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand = - getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops); - if (failed(maybeConsumerOpOperand)) { - return rewriter.notifyMatchFailure(candidateSlices.front(), - "could not fetch consumer to fuse"); - } - std::swap(consumerOpOperands, maybeConsumerOpOperand.value()); - consumerOp = consumerOpOperands.front()->getOwner(); + loops.front(), "the first user of loop should not dominate any define " + "of consumer operand(s)"); } LoopLikeOpInterface outerMostLoop = loops.front(); LoopLikeOpInterface innerMostLoop = loops.back(); - // Check assumption for loop with `reorderOperations` disabled. - if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) { - return rewriter.notifyMatchFailure( - outerMostLoop, "the first user of loop should not dominate any define " - "of consumer operand(s)"); - } - OpBuilder::InsertionGuard g(rewriter); - // 2. Check consumer is not using scf loop's output as init. auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp); if (!dstOp) @@ -2428,11 +2391,171 @@ mlir::scf::tileAndFuseConsumerOfSlices( llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) { return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum); }); + auto consumerOpOperandsVec = llvm::to_vector(consumerOpOperands); return scf::SCFFuseConsumerOfSliceResult{ - std::move(consumerOpOperands), std::move(tiledAndFusedOpOperands), + std::move(consumerOpOperandsVec), std::move(tiledAndFusedOpOperands), std::move(tileAndFuseResult->tiledOps)}; } +/// Implementation of fusing consumer of a single slice by computing the +/// slice of the consumer in-place for scf loop.
+FailureOr<scf::SCFFuseConsumerOfSliceResult> +mlir::scf::tileAndFuseConsumerOfSlices( + RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices, + MutableArrayRef<LoopLikeOpInterface> loops) { + if (candidateSlices.empty()) { + return rewriter.notifyMatchFailure( + rewriter.getUnknownLoc(), + "no candidate slices provided for consumer fusion"); + } + // Return if `loops` is empty, return an error for now. Caller is expected + // to handle this case. + if (loops.empty()) { + return rewriter.notifyMatchFailure( + candidateSlices.front(), + "cannot call tile and fuse consumer with an empty loop nest"); + } + + if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) || + llvm::all_of(candidateSlices, + llvm::IsaPred<tensor::ParallelInsertSliceOp>))) { + return rewriter.notifyMatchFailure( + candidateSlices.front(), + "candidates slices need to be all `tensor.extract_slice`s or " + "`tensor.parallel_insert_slice`s"); + } + + // Get the consumer of scf.for for the result yielded by + // tensor.insert_slice/parallel_insert_slice. + SmallVector<OpOperand *> consumerOpOperands; + Operation *consumerOp; + { + FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand = + getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops); + if (failed(maybeConsumerOpOperand)) { + return rewriter.notifyMatchFailure(candidateSlices.front(), + "could not fetch consumer to fuse"); + } + std::swap(consumerOpOperands, maybeConsumerOpOperand.value()); + consumerOp = consumerOpOperands.front()->getOwner(); + } + + return tileAndFuseConsumerOfSlicesImpl( + rewriter, consumerOp, consumerOpOperands, candidateSlices, loops); +} + +/// For a given `result` of a `forallOp` return the +/// `tensor.parallel_insert_slice` op (or combining op) that is used to +/// construct this result. +static std::optional<Operation *> +getProducingParallelInsertSlice(scf::ForallOp forallOp, OpResult result) { + if (result.getOwner() != forallOp) + return std::nullopt; + BlockArgument bbArg = forallOp.getTiedBlockArgument(result); + SmallVector<Operation *> combiningOps = forallOp.getCombiningOps(bbArg); + // If the number of combining ops is not 1, then this is unexpected. Return + // nullopt. + if (combiningOps.size() != 1) { + return std::nullopt; + } + return combiningOps[0]; +} + +/// For a given `result` of a tiled loop nest, return the insert slice-like +/// op that is used for consumer fusion. +std::optional<Operation *> +getProducingInsertSliceLikeOp(OpResult result, + ArrayRef<LoopLikeOpInterface> loops) { + assert(!loops.empty() && "Expected loops to be not empty"); + LoopLikeOpInterface outermostLoop = loops.front(); + + if (auto forallOp = dyn_cast<scf::ForallOp>(outermostLoop.getOperation())) { + assert(loops.size() == 1 && + "expected only a single loop when tiling using scf.forall"); + return getProducingParallelInsertSlice(forallOp, result); + } + // Assume that the loop nest is a nested `scf.for` that is created through + // tiling and retrieve the `tensor.insert_slice` operation used to construct + // the result.
+ while (loops.size() != 1) { + if (result.getOwner() != loops.front()) + return std::nullopt; + auto forOp = dyn_cast<scf::ForOp>(loops.front()); + if (!forOp) + return std::nullopt; + auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator()); + OpResult innerForResult = + dyn_cast<OpResult>(yieldOp.getOperand(result.getResultNumber())); + if (!innerForResult) + return std::nullopt; + result = innerForResult; + loops = loops.drop_front(); + } + if (result.getOwner() != loops.front()) + return std::nullopt; + auto forOp = dyn_cast<scf::ForOp>(loops.front()); + if (!forOp) + return std::nullopt; + auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator()); + auto insertSliceOp = yieldOp.getOperand(result.getResultNumber()) + .getDefiningOp<tensor::InsertSliceOp>(); + if (!insertSliceOp) + return std::nullopt; + return insertSliceOp; +} + +FailureOr<scf::SCFFuseConsumerOfSliceResult> +mlir::scf::tileAndFuseConsumer(RewriterBase &rewriter, Operation *user, + MutableArrayRef<LoopLikeOpInterface> loops) { + // Only handle users that implement the `TilingInterface`. + if (!isa<TilingInterface>(user)) { + return rewriter.notifyMatchFailure( + user, "unhandled user that does not implement TilingInterface"); + } + + // If `loops` is empty, return an error for now. Caller is expected + // to handle this case. + if (loops.empty()) { + return rewriter.notifyMatchFailure( + user, "cannot call tile and fuse consumer with an empty loop nest"); + } + + LoopLikeOpInterface outermostLoop = loops.front(); + + // Collect the operands of the user that come from the outermost loop of the + // loop nest. + SmallVector<OpOperand *> consumerFusableOperands; + for (OpOperand &opOperand : user->getOpOperands()) { + if (opOperand.get().getDefiningOp() == outermostLoop) { + consumerFusableOperands.push_back(&opOperand); + } + } + + // Nothing to fuse. Just return an empty set. + if (consumerFusableOperands.empty()) { + return mlir::scf::SCFFuseConsumerOfSliceResult{consumerFusableOperands, + SmallVector<OpOperand *>{}, + SmallVector<Operation *>{}}; + } + + // Collect the relevant tensor.insert_slice/tensor.parallel_insert_slices + // for fusion. + SmallVector<Operation *> candidateSlices; + candidateSlices.reserve(consumerFusableOperands.size()); + for (OpOperand *opOperand : consumerFusableOperands) { + std::optional<Operation *> slice = + getProducingInsertSliceLikeOp(cast<OpResult>(opOperand->get()), loops); + if (!slice) { + return rewriter.notifyMatchFailure( + user, + "couldn't find producing insert-slice like operation for operand"); + } + candidateSlices.push_back(slice.value()); + } + return tileAndFuseConsumerOfSlicesImpl( + rewriter, user, consumerFusableOperands, candidateSlices, loops); +} + //===----------------------------------------------------------------------===// // lowerToLoopsUsingSCFForOp implementation. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir index 185fb9b358055..d72ab080f3c5c 100644 --- a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir +++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir @@ -170,7 +170,7 @@ module { // Fuse the consumer operation into the tiled loop.
%slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice"> - transform.test.fuse_consumer %slice_op in (%forall_op) + transform.test.fuse_consumer_using_slice %slice_op in (%forall_op) : (!transform.op<"tensor.parallel_insert_slice">, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } @@ -231,7 +231,7 @@ module { // Fuse the consumer operation into the tiled loop. %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice"> - // Note that we cannot apply transform.test.fuse_consumer here because the extract_slice + // Note that we cannot apply transform.test.fuse_consumer_using_slice here because the extract_slice // is not qualified consumer operation. Forcing this will yeild "could not fetch consumer // to fuse" error. transform.yield diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir new file mode 100644 index 0000000000000..62dd7faec4eb7 --- /dev/null +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir @@ -0,0 +1,1156 @@ +// RUN: mlir-opt --transform-interpreter --cse --split-input-file --verify-diagnostics %s | FileCheck %s + +#map = affine_map<(d0) -> (d0)> +module { + func.func @fuse_tileable_consumer_scf_for(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) { + %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32> + %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) { + ^bb0(%in: f32, %in_16: f32, %out: f32): + %13 = arith.mulf %in, %in_16 : f32 + %14 = arith.addf %out, %13 : f32 + linalg.yield %14 : f32 + } -> tensor<32xf32> + %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32> + scf.yield %arg5, %4 : tensor<64xf32>, tensor<64xf32> + } + %in_operand_2 = tensor.empty() : tensor<64xf32> + %out_operand_3 = tensor.empty() : tensor<64xf32> + %2 = linalg.add ins(%1#1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand_3 : tensor<64xf32>) -> tensor<64xf32> + return %2 : tensor<64xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %loop = transform.structured.match ops{["scf.for"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer_using_slice %yield in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: func.func @fuse_tileable_consumer_scf_for( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>) +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %0 = 
tensor.empty() : tensor<64xf32> +// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]] +// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG:.*]] = %0) +// CHECK-SAME: { +// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1] +// CHECK: %[[MAT_OUT:.*]] = linalg.generic +// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>) +// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1] +// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %0[%[[IV]]] [32] [1] +// CHECK: %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1] +// CHECK: %[[ELEM_OUT:.*]] = linalg.add +// CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] : +// CHECK-SAME: outs(%[[SLICE_OUT]] : +// CHECK: %[[INSERT_ELEM:.*]] = tensor.insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1] +// CHECK: scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM]] : +// CHECK: } +// CHECK: return %[[FINAL_RESULT]]#2 : + +// ----- + +module { + func.func @fuse_tileable_consumer_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2, %arg6 = %arg2) -> (tensor<64x64xf32>, tensor<64x64xf32>) { + %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32> + %extracted_slice_1 = tensor.extract_slice %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32> + %3 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32> + tensor.parallel_insert_slice %extracted_slice_1 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32> + } + } + %in_operand_2 = tensor.empty() : tensor<64x64xf32> + %out_operand_3 = tensor.empty() : tensor<64x64xf32> + %2 = linalg.add ins(%1#1, %in_operand_2 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%out_operand_3 : tensor<64x64xf32>) -> tensor<64x64xf32> + return %2 : tensor<64x64xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %first_slice_op, %second_slice_op = transform.split_handle %slice_ops + : (!transform.any_op) + -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer_using_slice %first_slice_op in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: func.func @fuse_tileable_consumer_scf_forall( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32>) +// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<64x64xf32> +// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2) +// 
CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG:.*]] = %[[OUT_INIT]]) +// CHECK-SAME: { +// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[SECOND_ARG_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[MAT_OUT:.*]] = linalg.matmul +// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : +// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[ELEM_OUT:.*]] = linalg.add +// CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] : +// CHECK-SAME: outs(%[[SLICE_OUT]] : +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[MAT_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[SECOND_ARG_SLICE]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: } +// CHECK: } +// CHECK: return %[[FINAL_RESULT]]#2 : + +// ----- + +#map = affine_map<(d0) -> (d0)> +module { + func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) { + %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32> + %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) { + ^bb0(%in: f32, %in_16: f32, %out: f32): + %13 = arith.mulf %in, %in_16 : f32 + %14 = arith.addf %out, %13 : f32 + linalg.yield %14 : f32 + } -> tensor<32xf32> + %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32> + scf.yield %arg5, %4 : tensor<64xf32>, tensor<64xf32> + } + %in_operand_2 = tensor.empty() : tensor<64xf32> + %out_operand_3 = tensor.empty() : tensor<64xf32> + %out_operand_4 = tensor.empty() : tensor<64xf32> + %2:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%1#1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand_3, %out_operand_4 : tensor<64xf32>, tensor<64xf32>) { + ^bb0(%in: f32, %in_16: f32, %out_0: f32, %out_1: f32): + %13 = arith.mulf %in, %in_16 : f32 + %14 = arith.subf %out_0, %13 : f32 + %15 = arith.addf %out_1, %in : f32 + linalg.yield %14, %15 : f32, f32 + } -> (tensor<64xf32>, tensor<64xf32>) + return %2#1 : tensor<64xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.for"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer_using_slice %yield in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: func.func 
@fuse_tileable_consumer_scf_for_multi_yielding_consumer( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>) +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %0 = tensor.empty() : tensor<64xf32> +// CHECK: %[[FINAL_RESULT:.*]]:4 = scf.for %[[IV:.*]] = %[[C0]] +// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG_0:.*]] = %0, %[[ELEM_OUT_ARG_1:.*]] = %0) +// CHECK-SAME: { +// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1] +// CHECK: %[[MAT_OUT:.*]] = linalg.generic +// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>) +// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1] +// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %0[%[[IV]]] [32] [1] +// CHECK: %[[SLICE_OUT_0:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_0]][%[[IV]]] [32] [1] +// CHECK: %[[SLICE_OUT_1:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_1]][%[[IV]]] [32] [1] +// CHECK: %[[ELEM_OUT:.*]]:2 = linalg.generic +// CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] : +// CHECK-SAME: outs(%[[SLICE_OUT_0]], %[[SLICE_OUT_1]] : +// CHECK: %[[INSERT_ELEM_0:.*]] = tensor.insert_slice %[[ELEM_OUT]]#0 into %[[ELEM_OUT_ARG_0]][%[[IV]]] [32] [1] +// CHECK: %[[INSERT_ELEM_1:.*]] = tensor.insert_slice %[[ELEM_OUT]]#1 into %[[ELEM_OUT_ARG_1]][%[[IV]]] [32] [1] +// CHECK: scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM_0]], %[[INSERT_ELEM_1]] : +// CHECK: } +// CHECK: return %[[FINAL_RESULT]]#3 : + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64x32xf32>) -> (tensor<64x64xf32>, tensor<2048xf32>) { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %0:2 = scf.forall (%arg4, %arg5) in (2, 2) shared_outs(%arg6 = %arg3, %arg7 = %arg2) -> (tensor<64x32xf32>, tensor<64x64xf32>) { + %extracted_slice = tensor.extract_slice %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32> + %extracted_slice_0 = tensor.extract_slice %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32> + %6 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %6 into %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32> + tensor.parallel_insert_slice %extracted_slice_0 into %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32> + } + } + %1 = tensor.empty() : tensor<64x64xf32> + %2 = tensor.empty() : tensor<64x64xf32> + %3 = tensor.empty() : tensor<64x64xf32> + %4:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%0#1, %1 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%2, %3 : tensor<64x64xf32>, tensor<64x64xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32, %out_1: f32): + %6 = arith.mulf %in, %in_0 : f32 + %7 = arith.subf %out, %6 : f32 + %8 = arith.addf %out_1, %in : f32 + linalg.yield %7, %8 : f32, f32 + } -> (tensor<64x64xf32>, tensor<64x64xf32>) + %5 = tensor.empty() : tensor<2048xf32> + %unpack = linalg.unpack %0#0 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles 
= [32] into %5 : tensor<64x32xf32> -> tensor<2048xf32> + return %4#1, %unpack : tensor<64x64xf32>, tensor<2048xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %first_slice_op, %second_slice_op = transform.split_handle %slice_ops + : (!transform.any_op) + -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer_using_slice %first_slice_op in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32> +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: tensor<64x32xf32>) +// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<64x64xf32> +// CHECK: %[[FINAL_RESULT:.*]]:4 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2) +// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG3]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG_0:.*]] = %[[OUT_INIT]], %[[ELEM_OUT_ARG_1:.*]] = %[[OUT_INIT]]) +// CHECK-SAME: { +// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[SECOND_ARG_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[MAT_OUT:.*]] = linalg.matmul +// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : +// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[SLICE_OUT_0:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_0]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[SLICE_OUT_1:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_1]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[ELEM_OUT:.*]]:2 = linalg.generic +// CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] : +// CHECK-SAME: outs(%[[SLICE_OUT_0]], %[[SLICE_OUT_1]] : +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[MAT_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[SECOND_ARG_SLICE]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[ELEM_OUT]]#0 into %[[ELEM_OUT_ARG_0]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[ELEM_OUT]]#1 into %[[ELEM_OUT_ARG_1]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: } +// CHECK: } +// CHECK: %[[UNPACK:.*]] = linalg.unpack %[[FINAL_RESULT]]#0 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %{{.*}} : tensor<64x32xf32> -> tensor<2048xf32> +// CHECK: return %[[FINAL_RESULT]]#3, %[[UNPACK]] : + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @fuse_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2048xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) { + %extracted_slice = tensor.extract_slice 
%arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32> + %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) { + ^bb0(%in: f32, %in_16: f32, %out: f32): + %13 = arith.mulf %in, %in_16 : f32 + %14 = arith.addf %out, %13 : f32 + linalg.yield %14 : f32 + } -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32> + } + } + %output = tensor.empty() : tensor<2048xf32> + %unpack = linalg.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2048xf32> + return %unpack : tensor<2048xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer_using_slice %slice_op in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)> +// CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2048)> +// CHECK: func.func @fuse_unpack_consumer_into_scf_forall( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>) +// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<2048xf32> +// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) = (0, 0) to (64, 32) step (32, 32) +// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[UNPACK_OUT_ARG:.*]] = %[[OUT_INIT]]) +// CHECK-SAME: { +// CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[GENERIC_OUT:.*]] = linalg.generic +// CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] : +// CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]]) +// CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]]) +// CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1] +// CHECK: %[[TILED_UNPACK_OUT:.*]] = linalg.unpack %[[GENERIC_OUT]] +// CHECK-SAME: outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] +// CHECK-SAME: into %[[TILED_UNPACK_DEST]] +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1] +// CHECK: } +// CHECK: } +// CHECK: return %[[FINAL_RESULT]]#1 : + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @fuse_unaligned_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2047xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1 = scf.forall (%arg3, 
%arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) { + %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32> + %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) { + ^bb0(%in: f32, %in_16: f32, %out: f32): + %13 = arith.mulf %in, %in_16 : f32 + %14 = arith.addf %out, %13 : f32 + linalg.yield %14 : f32 + } -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32> + } + } + %output = tensor.empty() : tensor<2047xf32> + %unpack = linalg.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2047xf32> + return %unpack : tensor<2047xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer_using_slice %slice_op in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)> +// CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2047)> +// CHECK: func.func @fuse_unaligned_unpack_consumer_into_scf_forall( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>) +// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<2047xf32> +// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) = (0, 0) to (64, 32) step (32, 32) +// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[UNPACK_OUT_ARG:.*]] = %[[OUT_INIT]]) +// CHECK-SAME: { +// CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[GENERIC_OUT:.*]] = linalg.generic +// CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] : +// CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]]) +// CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]]) +// CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1] +// CHECK: %[[TILED_UNPACK_OUT:.*]] = linalg.unpack %[[GENERIC_OUT]] +// CHECK-SAME: outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] +// CHECK-SAME: into %[[TILED_UNPACK_DEST]] +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1] +// CHECK: } +// CHECK: } +// CHECK: return %[[FINAL_RESULT]]#1 : + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @fuse_perfect_tiling_pack_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> 
tensor<4x32x16xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1 = scf.forall (%arg3, %arg4) in (2, 1) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) { + %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32> + %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) { + ^bb0(%in: f32, %in_16: f32, %out: f32): + %13 = arith.mulf %in, %in_16 : f32 + %14 = arith.addf %out, %13 : f32 + linalg.yield %14 : f32 + } -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32> + } + } + %output = tensor.empty() : tensor<4x32x16xf32> + %pack = linalg.pack %1 inner_dims_pos = [0] inner_tiles = [16] into %output : tensor<64x32xf32> -> tensor<4x32x16xf32> + return %pack : tensor<4x32x16xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer_using_slice %slice_op in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> +// CHECK: func.func @fuse_perfect_tiling_pack_consumer( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>) +// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<4x32x16xf32> +// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 1) +// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]]) +// CHECK-SAME: { +// CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[GENERIC_OUT:.*]] = linalg.generic +// CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] : +// CHECK: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV1]]) +// CHECK: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1] +// CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[GENERIC_OUT]] +// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16] +// CHECK-SAME: into %[[TILED_PACK_DEST]] +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1] + +// ----- + +#map = affine_map<(d0) -> (-d0 + 4, 16)> +func.func @fuse_pack_consumer_if_single_iteration(%arg0: tensor<4x4xf32>) -> tensor<1x4x16x1xf32> { + %0 = tensor.empty() : tensor<1x4x16x1xf32> + %1 = tensor.empty() : tensor<4x4xf32> + %2 = scf.forall (%arg1) = (0) to (4) step (16) shared_outs(%arg2 = %1) -> (tensor<4x4xf32>) { + %3 = affine.min #map(%arg1) + %extracted_slice = tensor.extract_slice %arg0[%arg1, 0] [%3, 4] [1, 1] 
: tensor<4x4xf32> to tensor<?x4xf32> + %extracted_slice_0 = tensor.extract_slice %arg2[%arg1, 0] [%3, 4] [1, 1] : tensor<4x4xf32> to tensor<?x4xf32> + %4 = linalg.exp ins(%extracted_slice : tensor<?x4xf32>) outs(%extracted_slice_0 : tensor<?x4xf32>) -> tensor<?x4xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %4 into %arg2[%arg1, 0] [%3, 4] [1, 1] : tensor<?x4xf32> into tensor<4x4xf32> + } + } + %cst = arith.constant 0.000000e+00 : f32 + %pack = linalg.pack %2 padding_value(%cst : f32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 1] into %0 : tensor<4x4xf32> -> tensor<1x4x16x1xf32> + return %pack : tensor<1x4x16x1xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %consumer, %fused_consumer = transform.test.fuse_consumer_using_slice %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (-d0 + 4, 16)> +// CHECK: func.func @fuse_pack_consumer_if_single_iteration( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-DAG: %[[PACK_INIT:.*]] = tensor.empty() : tensor<1x4x16x1xf32> +// CHECK-DAG: %[[ELEM_INIT:.*]] = tensor.empty() : tensor<4x4xf32> +// CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (4) step (16) +// CHECK-SAME: shared_outs(%[[ELEM_OUT_ARG:.*]] = %[[ELEM_INIT]], %[[PACK_OUT_ARG:.*]] = %[[PACK_INIT]]) +// CHECK-DAG: %[[SIZE:.+]] = affine.min #[[MAP]](%[[IV]]) +// CHECK-DAG: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1] +// CHECK-DAG: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1] +// CHECK: %[[ELEM:.*]] = linalg.exp +// CHECK-SAME: ins(%[[ELEM_SRC]] +// CHECK-SAME: outs(%[[ELEM_DEST]] +// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[IV]], 0, 0, 0] [1, 4, 16, 1] [1, 1, 1, 1] +// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]] +// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32) +// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 1] +// CHECK-SAME: into %[[TILED_PACK_DEST]] +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][%[[IV]], 0, 0, 0] [1, 4, 16, 1] [1, 1, 1, 1] + +// ----- + +func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>, %arg2: tensor<2x64x16x1xf32>) -> tensor<2x64x16x1xf32> { + %0 = scf.forall (%arg3) = (0) to (32) step (16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) { + %src = tensor.extract_slice %arg0[0, %arg3] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> + %dest = tensor.extract_slice %arg4[0, %arg3] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> + %1 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %1 into %arg4[0, %arg3] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32> + } + } + %pack = linalg.pack %0 outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles =
[16, 1] into %arg2 : tensor<64x32xf32> -> tensor<2x64x16x1xf32> + return %pack : tensor<2x64x16x1xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %consumer, %fused_consumer = transform.test.fuse_consumer_using_slice %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> +// CHECK: func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] +// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16) +// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[ARG2]]) +// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1] +// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] +// CHECK: %[[ELEM:.*]] = linalg.exp +// CHECK-SAME: ins(%[[ELEM_SRC]] +// CHECK-SAME: outs(%[[ELEM_DEST]] +// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]]) +// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1] +// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]] +// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 1] +// CHECK-SAME: into %[[TILED_PACK_DEST]] +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1] + +// ----- + +// It is valid to fuse the pack op in perfect tiling scenario when the dimension +// is dynamic and padding is not needed. 
+ +func.func @fuse_pack_consumer_with_no_pad_dynamic_dim(%arg0: tensor<64x?xf32>, %arg1: tensor<64x?xf32>, %1: tensor<64x?x16xf32>) -> tensor<64x?x16xf32> { + %c1 = arith.constant 1 : index + %d1 = tensor.dim %arg0, %c1 : tensor<64x?xf32> + %0 = scf.forall (%arg2) = (0) to (%d1) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x?xf32>) { + %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x?xf32> to tensor<64x16xf32> + %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x?xf32> to tensor<64x16xf32> + %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x?xf32> + } + } + %pack = linalg.pack %0 inner_dims_pos = [1] inner_tiles = [16] into %1 : tensor<64x?xf32> -> tensor<64x?x16xf32> + return %pack : tensor<64x?x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %consumer, %fused_consumer = transform.test.fuse_consumer_using_slice %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> +// CHECK: func.func @fuse_pack_consumer_with_no_pad_dynamic_dim( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] +// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (%{{.+}}) step (16) +// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[ARG2]]) +// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1] +// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] +// CHECK: %[[ELEM:.*]] = linalg.exp +// CHECK-SAME: ins(%[[ELEM_SRC]] +// CHECK-SAME: outs(%[[ELEM_DEST]] +// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]]) +// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0] [64, 1, 16] [1, 1, 1] +// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]] +// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [16] +// CHECK-SAME: into %[[TILED_PACK_DEST]] +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0] [64, 1, 16] [1, 1, 1] + +// ----- + +// It is valid to fuse the pack op with padding semantics if it is a perfect +// tiling case. 
+ +func.func @fuse_pack_consumer_with_padding_semantics(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<22x2x3x16xf32> { + %0 = scf.forall (%arg2, %arg3) = (0, 0) to (64, 32) step (15, 16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) { + %size = affine.min affine_map<(d0) -> (-d0 + 64, 15)>(%arg2) + %src = tensor.extract_slice %arg0[%arg2, %arg3] [%size, 16] [1, 1] : tensor<64x32xf32> to tensor<?x16xf32> + %dest = tensor.extract_slice %arg4[%arg2, %arg3] [%size, 16] [1, 1] : tensor<64x32xf32> to tensor<?x16xf32> + %2 = linalg.exp ins(%src : tensor<?x16xf32>) outs(%dest : tensor<?x16xf32>) -> tensor<?x16xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %2 into %arg4[%arg2, %arg3] [%size, 16] [1, 1] : tensor<?x16xf32> into tensor<64x32xf32> + } + } + %1 = tensor.empty() : tensor<22x2x3x16xf32> + %cst = arith.constant 0.000000e+00 : f32 + %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<22x2x3x16xf32> + return %pack : tensor<22x2x3x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %consumer, %fused_consumer = transform.test.fuse_consumer_using_slice %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (-d0 + 64, 15)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (d0 floordiv 3)> +// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0) -> (d0 ceildiv 3)> +// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0) -> (d0 floordiv 16)> +// CHECK: func.func @fuse_pack_consumer_with_padding_semantics( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<22x2x3x16xf32> +// CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %{{.*}}:2 = scf.forall (%[[I:.*]], %[[J:.*]]) = (0, 0) to (64, 32) step (15, 16) +// CHECK-SAME: shared_outs(%[[ELEM_OUT:.*]] = %[[ARG1]], %[[PACK_OUT:.*]] = %[[OUT_INIT]]) +// CHECK: %[[SIZE:.+]] = affine.min #[[MAP0]](%[[I]]) +// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]] +// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1] +// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT]] +// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1] +// CHECK: %[[ELEM:.*]] = linalg.exp +// CHECK-SAME: ins(%[[ELEM_SRC]] +// CHECK-SAME: outs(%[[ELEM_DEST]] +// CHECK-DAG: %[[D0_OFFSET:.*]] = affine.apply #[[MAP1]](%[[I]]) +// CHECK-DAG: %[[D0_SIZE:.*]] = affine.apply #[[MAP2]](%[[SIZE]]) +// CHECK-DAG: %[[D1_OFFSET:.*]] = affine.apply #[[MAP3]](%[[J]]) +// CHECK-DAG: %[[PACK_INIT:.*]] = tensor.extract_slice %[[PACK_OUT]] +// CHECK-SAME: [%[[D0_OFFSET]], %[[D1_OFFSET]], 0, 0] [%[[D0_SIZE]], 1, 3, 16] [1, 1, 1, 1] +// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]] +// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32) +// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16] +// CHECK-SAME: into %[[PACK_INIT]] +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[ELEM_OUT]] +// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT]] +// CHECK-SAME: [%[[D0_OFFSET]], %[[D1_OFFSET]],
0, 0] [%[[D0_SIZE]], 1, 3, 16] [1, 1, 1, 1] + +// ----- + +// Imperfect tiling is not supported in pack op consumer fusion. + +#map = affine_map<(d0) -> (d0 * 5)> +#map1 = affine_map<(d0) -> (d0)> +func.func @nofuse_pack_with_imperfect_tiling(%arg0: tensor<30xf32>) -> tensor<5x6xf32> { + %0 = tensor.empty() : tensor<30xf32> + %1 = scf.forall (%arg1) in (6) shared_outs(%arg2 = %0) -> (tensor<30xf32>) { + %3 = affine.apply #map(%arg1) + %extracted_slice = tensor.extract_slice %arg0[%3] [5] [1] : tensor<30xf32> to tensor<5xf32> + %extracted_slice_0 = tensor.extract_slice %arg2[%3] [5] [1] : tensor<30xf32> to tensor<5xf32> + %4 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel"]} ins(%extracted_slice : tensor<5xf32>) outs(%extracted_slice_0 : tensor<5xf32>) { + ^bb0(%in: f32, %out: f32): + %5 = arith.addf %in, %in : f32 + linalg.yield %5 : f32 + } -> tensor<5xf32> + scf.forall.in_parallel { + // expected-error @below {{failed to fuse consumer of slice}} + tensor.parallel_insert_slice %4 into %arg2[%3] [5] [1] : tensor<5xf32> into tensor<30xf32> + } + } + %2 = tensor.empty() : tensor<5x6xf32> + %pack = linalg.pack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [6] into %2 : tensor<30xf32> -> tensor<5x6xf32> + return %pack : tensor<5x6xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %consumer, %fused_consumer = transform.test.fuse_consumer_using_slice %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +module { + func.func @fuse_add_multiple_tilable_consumers(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %cst = arith.constant 0.000000e+00 : f32 + %dest0 = tensor.empty() : tensor<256x256xf32> + %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest0) -> (tensor<256x256xf32>) { + %extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> + %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> + %extracted_slice_3 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> + %3 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice_1 : tensor<64x256xf32>) -> tensor<64x256xf32> + %insert_slice = tensor.insert_slice %3 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32> + scf.yield %insert_slice : tensor<256x256xf32> + } + %4 = linalg.mul ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32> + %5 = linalg.exp ins(%1 : tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32> + return %4, %5 : tensor<256x256xf32>, tensor<256x256xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %slice_op = 
transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %loop = transform.structured.match ops{["scf.for"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer_using_slice %slice_op in (%loop) num_consumer_to_fuse = 2
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK: func.func @fuse_add_multiple_tilable_consumers(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
+// CHECK: %[[LOOP_RESULT:.*]]:3 = scf.for %[[IV1:.*]] = %[[C0]]
+// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]])
+// CHECK-SAME: {
+// CHECK: %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[ADD_INS0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[ADD_INS1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[TILED_ADD_OUT:.*]] = linalg.add
+// CHECK-SAME: ins(%[[ADD_INS0_SLICE]], %[[ADD_INS1_SLICE]] :
+// CHECK-SAME: outs(%[[ADD_OUT_SLICE]] :
+// CHECK: %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[EXP_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[TILED_EXP_OUT:.*]] = linalg.exp
+// CHECK-SAME: ins(%[[TILED_ADD_OUT]] :
+// CHECK-SAME: outs(%[[EXP_OUT_SLICE]] :
+// CHECK: %[[MUL_INS2_SLICE:.*]] = tensor.extract_slice %[[ARG2]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[MUL_OUT_SLICE:.*]] = tensor.extract_slice %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[TILED_MUL_OUT:.*]] = linalg.mul
+// CHECK-SAME: ins(%[[TILED_ADD_OUT]], %[[MUL_INS2_SLICE]] :
+// CHECK-SAME: outs(%[[MUL_OUT_SLICE]] :
+// CHECK: %[[INSERT_EXP:.*]] = tensor.insert_slice %[[TILED_EXP_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[INSERT_MUL:.*]] = tensor.insert_slice %[[TILED_MUL_OUT]] into %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: scf.yield %[[INSERT_ADD]], %[[INSERT_EXP]], %[[INSERT_MUL]] :
+// CHECK: }
+// CHECK: return %[[LOOP_RESULT]]#2, %[[LOOP_RESULT]]#1 :
+
+// -----
+
+module {
+  func.func @no_fuse_only_dps_consumer(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<258x258xf32>) {
+    %c0 = arith.constant 0 : index
+    %c64 = arith.constant 64 : index
+    %c256 = arith.constant 256 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %dest0 = tensor.empty() : tensor<256x256xf32>
+    %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest0) -> (tensor<256x256xf32>) {
+      %extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+      %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+      %extracted_slice_3 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+      %3 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice_1 : tensor<64x256xf32>) -> tensor<64x256xf32>
+      %insert_slice = 
tensor.insert_slice %3 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32> + scf.yield %insert_slice : tensor<256x256xf32> + } + %dest1 = tensor.empty() : tensor<258x258xf32> + %4 = tensor.insert_slice %1 into %dest1[0, 0] [256, 256] [1, 1] : tensor<256x256xf32> into tensor<258x258xf32> + %5 = linalg.mul ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32> + return %5, %4 : tensor<256x256xf32>, tensor<258x258xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %slice_ops = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %slice_op, %other_slice = transform.split_handle %slice_ops : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer_using_slice %slice_op in (%loop) num_consumer_to_fuse = 1 + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: func.func @no_fuse_only_dps_consumer( +// CHECK: %[[LOOP_RESULT:.*]]:2 = scf.for {{.*}} { +// CHECK: linalg.add +// CHECK: linalg.mul +// CHECK: scf.yield +// CHECK: } +// CHECK: %[[RES_SLICE:.+]] = tensor.insert_slice +// CHECK: return %[[LOOP_RESULT]]#1, %[[RES_SLICE]] + +// ----- + +#map = affine_map<(d0, d1, d2) -> (d0, d1)> +#map1 = affine_map<(d0, d1, d2) -> (d2)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +module { + func.func @fuse_with_tilable_consumer_with_projected_permutations(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<24xf32>) -> tensor<256x256x24xf32> { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %0 = tensor.empty() : tensor<256x256xf32> + %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %0) -> (tensor<256x256xf32>) { + %extracted_slice = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> + %extracted_slice_0 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> + %extracted_slice_1 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> + %4 = linalg.add ins(%extracted_slice_0, %extracted_slice_1 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice : tensor<64x256xf32>) -> tensor<64x256xf32> + %inserted_slice = tensor.insert_slice %4 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32> + scf.yield %inserted_slice : tensor<256x256xf32> + } + %2 = tensor.empty() : tensor<256x256x24xf32> + %3 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1, %arg2 : tensor<256x256xf32>, tensor<24xf32>) outs(%2 : tensor<256x256x24xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %4 = arith.addf %in, %in_0 : f32 + linalg.yield %4 : f32 + } -> tensor<256x256x24xf32> + return %3 : tensor<256x256x24xf32> + } +} + +// CHECK: func.func @fuse_with_tilable_consumer_with_projected_permutations(%[[VAL_0:.*]]: tensor<256x256xf32>, %[[VAL_1:.*]]: tensor<256x256xf32>, %[[VAL_2:.*]]: tensor<24xf32>) -> tensor<256x256x24xf32> { +// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_4:.*]] = arith.constant 64 : index +// 
CHECK: %[[VAL_5:.*]] = arith.constant 256 : index
+// CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<256x256xf32>
+// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<256x256x24xf32>
+// CHECK: %[[VAL_8:.*]]:2 = scf.for %[[VAL_9:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_10:.*]] = %[[VAL_6]], %[[VAL_11:.*]] = %[[VAL_7]]) -> (tensor<256x256xf32>, tensor<256x256x24xf32>) {
+// CHECK: %[[VAL_12:.*]] = tensor.extract_slice %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
+// CHECK: %[[VAL_13:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
+// CHECK: %[[VAL_14:.*]] = tensor.extract_slice %[[VAL_1]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
+// CHECK: %[[VAL_15:.*]] = linalg.add ins(%[[VAL_13]], %[[VAL_14]] : tensor<64x256xf32>, tensor<64x256xf32>) outs(%[[VAL_12]] : tensor<64x256xf32>) -> tensor<64x256xf32>
+// CHECK: %[[VAL_16:.*]] = tensor.insert_slice %[[VAL_15]] into %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
+// CHECK: %[[VAL_17:.*]] = tensor.extract_slice %[[VAL_2]][0] [24] [1] : tensor<24xf32> to tensor<24xf32>
+// CHECK: %[[VAL_18:.*]] = tensor.extract_slice %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1]
+// CHECK: %[[VAL_19:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_15]], %[[VAL_17]] : tensor<64x256xf32>, tensor<24xf32>) outs(%[[VAL_18]] : tensor<64x256x24xf32>) {
+// CHECK: ^bb0(%[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32, %[[VAL_22:.*]]: f32):
+// CHECK: %[[VAL_23:.*]] = arith.addf %[[VAL_20]], %[[VAL_21]] : f32
+// CHECK: linalg.yield %[[VAL_23]] : f32
+// CHECK: } -> tensor<64x256x24xf32>
+// CHECK: %[[VAL_24:.*]] = tensor.insert_slice %[[VAL_25:.*]] into %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1]
+// CHECK: scf.yield %[[VAL_16]], %[[VAL_24]] : tensor<256x256xf32>, tensor<256x256x24xf32>
+// CHECK: }
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %loop = transform.structured.match ops{["scf.for"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer_using_slice %slice_op in (%loop) num_consumer_to_fuse = 1
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @multi_slice_fusion1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %dim1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %loop:2 = scf.forall (%iv0) = (%c0) to (%dim0) step (%arg3) shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?xf32>, tensor<?xf32>) {
+    %tilesize = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
+    %arg0_slice = tensor.extract_slice %arg0[%iv0, 0] [%tilesize, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+    %init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
+    %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
+    %generic:2 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>],
+        iterator_types = ["parallel", "reduction"]}
+        ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice, %init1_slice : tensor<?xf32>, tensor<?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+        %0 = arith.mulf %b0, %b1 : f32
+        %1 = arith.addf %b0, %b2 : f32
+        linalg.yield %0, %1 : f32, f32
+    } -> (tensor<?xf32>, tensor<?xf32>)
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %generic#0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
+      tensor.parallel_insert_slice %generic#1 into %init1[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
+    }
+  }
+  %empty = tensor.empty(%dim0) : tensor<?xf32>
+  %result = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]}
+      ins(%loop#0, %loop#1 : tensor<?xf32>, tensor<?xf32>) outs(%empty : tensor<?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %0 = arith.addf %b0, %b1 : f32
+      linalg.yield %0 : f32
+  } -> tensor<?xf32>
+  return %result : tensor<?xf32>
+}
+// CHECK-LABEL: func @multi_slice_fusion1(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
+// CHECK: %[[C0:.+]] = arith.constant 0
+// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0]])
+// CHECK: %[[RESULT:.+]]:3 = scf.forall (%[[IV:.+]]) =
+// CHECK-SAME: , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]])
+// CHECK: %[[TILESIZE:.+]] = affine.min
+// CHECK-DAG: %[[GENERIC:.+]]:2 = linalg.generic
+// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
+// CHECK: %[[FUSED:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[GENERIC]]#0, %[[GENERIC]]#1 :
+// CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
+// CHECK: return %[[RESULT]]#2
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer_using_slice %yield0, %yield1 in (%loop)
+      : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// Fuse a consumer that reads the results of two different operations tiled in the same loop.
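+// Unlike @multi_slice_fusion1 above, where both fused operands come from a single
+// multi-result linalg.generic, each operand here is produced by a separate op.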
+
+func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %dim1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %loop:2 = scf.forall (%iv0) = (%c0) to (%dim0) step (%arg3) shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?xf32>, tensor<?xf32>) {
+    %tilesize = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
+    %arg0_slice = tensor.extract_slice %arg0[%iv0, 0] [%tilesize, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+    %init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
+    %generic0 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+        iterator_types = ["parallel", "reduction"]}
+        ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice : tensor<?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        %0 = arith.mulf %b0, %b1 : f32
+        linalg.yield %0 : f32
+    } -> tensor<?xf32>
+    %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
+    %generic1 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+        iterator_types = ["parallel", "reduction"]}
+        ins(%arg0_slice : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        %0 = arith.addf %b0, %b1 : f32
+        linalg.yield %0: f32
+    } -> tensor<?xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %generic0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
+      tensor.parallel_insert_slice %generic1 into %init1[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
+    }
+  }
+  %empty = tensor.empty(%dim0) : tensor<?xf32>
+  %result = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]}
+      ins(%loop#0, %loop#1 : tensor<?xf32>, tensor<?xf32>) outs(%empty : tensor<?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %0 = arith.addf %b0, %b1 : f32
+      linalg.yield %0 : f32
+  } -> tensor<?xf32>
+  return %result : tensor<?xf32>
+}
+// CHECK-LABEL: func @multi_slice_fusion2(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
+// CHECK: %[[C0:.+]] = arith.constant 0
+// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0]])
+// CHECK: %[[RESULT:.+]]:3 = scf.forall (%[[IV:.+]]) =
+// CHECK-SAME: , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]])
+// CHECK: %[[TILESIZE:.+]] = affine.min
+// CHECK: %[[GENERIC0:.+]] = linalg.generic
+// CHECK: %[[GENERIC1:.+]] = linalg.generic
+// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
+// CHECK: %[[FUSED:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[GENERIC0]], %[[GENERIC1]] :
+// CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
+// CHECK: return %[[RESULT]]#2
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer_using_slice %yield0, %yield1 in (%loop)
+      : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @multi_slice_fusion_with_broadcast(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?xf32>,
+    %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+  %dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+  %loop:2 = scf.forall (%iv0, %iv1) = (%c0, %c0) to (%dim0, %dim1) step (%arg3, %arg4)
+      shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?x?xf32>, tensor<?xf32>) {
+    %tilesize0 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
+    %tilesize1 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv1)[%dim1, %arg4]
+    %arg0_slice = tensor.extract_slice %arg0[%iv0, %iv1, 0] [%tilesize0, %tilesize1, %dim2] [1, 1, 1]
+        : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+    %init0_slice = tensor.extract_slice %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+        : tensor<?x?xf32> to tensor<?x?xf32>
+    %generic0 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+        iterator_types = ["parallel", "parallel", "reduction"]}
+        ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init0_slice : tensor<?x?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        %0 = arith.mulf %b0, %b1 : f32
+        linalg.yield %0 : f32
+    } -> tensor<?x?xf32>
+    %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize0] [1] : tensor<?xf32> to tensor<?xf32>
+    %generic1 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+        iterator_types = ["parallel", "reduction"]}
+        ins(%generic0 : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        %0 = arith.addf %b0, %b1 : f32
+        linalg.yield %0: f32
+    } -> tensor<?xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+          : tensor<?x?xf32> into tensor<?x?xf32>
+      tensor.parallel_insert_slice %generic1 into %init1[%iv0] [%tilesize0] [1] : tensor<?xf32> into tensor<?xf32>
+    }
+  }
+  %empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
+  %result = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%loop#0, %loop#1 : tensor<?x?xf32>, tensor<?xf32>) outs(%empty : tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %0 = arith.addf %b0, %b1 : f32
+      linalg.yield %0 : f32
+  } -> tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer_using_slice %yield0, %yield1 in (%loop)
+      : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-LABEL: func @multi_slice_fusion_with_broadcast(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1
+// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0]], %[[DIM1]])
+// CHECK: %[[RESULT:.+]]:3 = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) =
+// CHECK-SAME: , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]])
+// CHECK-DAG: %[[TILESIZE0:.+]] = affine.min {{.+}}(%[[IV0]])
+// CHECK-DAG: %[[TILESIZE1:.+]] = affine.min {{.+}}(%[[IV1]])
+// CHECK: %[[GENERIC0:.+]] = linalg.generic
+// CHECK: %[[GENERIC1:.+]] = linalg.generic
+// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]] [%[[TILESIZE0]], %[[TILESIZE1]]]
+// CHECK: %[[FUSED:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[GENERIC0]], %[[GENERIC1]] :
+// CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV0]], %[[IV1]]] [%[[TILESIZE0]], %[[TILESIZE1]]]
+// CHECK: return %[[RESULT]]#2
+
+// -----
+
+// Check that when the given operand tiles are inconsistent, consumer fusion fails.
+
+func.func @multi_slice_fusion_invalid(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>,
+    %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+  %dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+  %loop:2 = scf.forall (%iv0, %iv1) = (%c0, %c0) to (%dim0, %dim1) step (%arg3, %arg4)
+      shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+    %tilesize0 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
+    %tilesize1 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv1)[%dim1, %arg4]
+    %arg0_slice = tensor.extract_slice %arg0[%iv0, %iv1, 0] [%tilesize0, %tilesize1, %dim2] [1, 1, 1]
+        : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+    %init0_slice = tensor.extract_slice %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+        : tensor<?x?xf32> to tensor<?x?xf32>
+    %generic0 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+        iterator_types = ["parallel", "parallel", "reduction"]}
+        ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init0_slice : tensor<?x?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        %0 = arith.mulf %b0, %b1 : f32
+        linalg.yield %0 : f32
+    } -> tensor<?x?xf32>
+    %init1_slice = tensor.extract_slice %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+        : tensor<?x?xf32> to tensor<?x?xf32>
+    %generic1 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+        iterator_types = ["parallel", "parallel", "reduction"]}
+        ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init1_slice: tensor<?x?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        %0 = arith.addf %b0, %b1 : f32
+        linalg.yield %0: f32
+    } -> tensor<?x?xf32>
+    scf.forall.in_parallel {
+      // expected-error @below {{failed to fuse consumer of slice}}
+      tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+          : tensor<?x?xf32> into tensor<?x?xf32>
+      tensor.parallel_insert_slice %generic1 into %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
+          : tensor<?x?xf32> into tensor<?x?xf32>
+    }
+  }
+  %empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
+  %result = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%loop#0, %loop#1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %0 = arith.addf %b0, %b1 : f32
+      linalg.yield %0 : f32
+  } -> tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield0, %yield1 = 
transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer_using_slice %yield0, %yield1 in (%loop)
+      : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index 78884625ce7dc..0137e2a69a46e 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -1,8 +1,8 @@
-// RUN: mlir-opt --transform-interpreter --cse --split-input-file --verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --transform-interpreter --cse --split-input-file --verify-diagnostics --mlir-print-local-scope %s | FileCheck %s
 
 #map = affine_map<(d0) -> (d0)>
 module {
-  func.func @fuse_tileable_consumer_scf_for(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+  func.func @fuse_tilable_consumer_scf_for(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
     %c4 = arith.constant 4 : index
     %c64 = arith.constant 64 : index
     %c0 = arith.constant 0 : index
@@ -28,14 +28,14 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %loop = transform.structured.match ops{["scf.for"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+    %add = transform.structured.match ops{["linalg.add"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    %a, %b = transform.test.fuse_consumer %yield in (%loop)
+    %a, %new_loop = transform.test.fuse_consumer %add into (%loop)
       : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }
-// CHECK: func.func @fuse_tileable_consumer_scf_for(
+// CHECK: func.func @fuse_tilable_consumer_scf_for(
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32>
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>)
@@ -60,8 +60,61 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+#map = affine_map<(d0) -> (d0)>
 module {
-  func.func @fuse_tileable_consumer_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  func.func @fuse_tilable_consumer_nested_scf_for(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2 : tensor<?x?xf32>,
+      %lb0 : index, %ub0 : index, %step0 : index,
+      %lb1 : index, %ub1 : index, %step1 : index) -> tensor<?x?xf32> {
+    %0 = scf.for %arg3 = %lb0 to %ub0 step %step0 iter_args(%init0 = %arg0) -> tensor<?x?xf32> {
+      %1 = scf.for %arg4 = %lb1 to %ub1 step %step1 iter_args(%init1 = %init0) -> tensor<?x?xf32> {
+        %extracted_slice = tensor.extract_slice %init1[%arg3, %arg4] [%step0, %step1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+        %2 = tensor.insert_slice %extracted_slice into %init1[%arg3, %arg4] [%step0, %step1] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+        scf.yield %2 : tensor<?x?xf32>
+      }
+      scf.yield %1 : tensor<?x?xf32>
+    }
+    %2 = linalg.add ins(%0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+    return %2 : tensor<?x?xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %loops = transform.structured.match ops{["scf.for"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %loop0, %loop1 = transform.split_handle %loops
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %add = transform.structured.match ops{["linalg.add"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %new_loop0, %new_loop1 = transform.test.fuse_consumer %add into (%loop0, %loop1)
+      : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK: func @fuse_tilable_consumer_nested_scf_for(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK: %[[OUTER_RESULT:.+]]:2 = scf.for
+// CHECK-SAME: iter_args(%[[INIT00:[a-zA-Z0-9_]+]] = %[[ARG0]], %[[INIT01:[a-zA-Z0-9_]+]] = %[[ARG2]])
+// CHECK: %[[INNER_RESULT:.+]]:2 = scf.for
+// CHECK-SAME: iter_args(%[[INIT10:[a-zA-Z0-9_]+]] = %[[INIT00]], %[[INIT11:[a-zA-Z0-9_]+]] = %[[INIT01]])
+// CHECK-DAG: %[[OPERAND1:.+]] = tensor.extract_slice %[[INIT10]]
+// CHECK-DAG: %[[OLD_INSERT_SLICE:.+]] = tensor.insert_slice %[[OPERAND1]] into %[[INIT10]]
+// CHECK-DAG: %[[OPERAND2:.+]] = tensor.extract_slice %[[ARG1]]
+// CHECK-DAG: %[[INIT:.+]] = tensor.extract_slice %[[INIT11]]
+// CHECK: %[[ADD:.+]] = linalg.add
+// CHECK-SAME: ins(%[[OPERAND1]], %[[OPERAND2]] :
+// CHECK-SAME: outs(%[[INIT]] :
+// CHECK: %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[ADD]] into %[[INIT11]]
+// CHECK: scf.yield %[[OLD_INSERT_SLICE]], %[[INSERT_SLICE]]
+// CHECK: scf.yield %[[INNER_RESULT]]#0, %[[INNER_RESULT]]#1
+// CHECK: return %[[OUTER_RESULT]]#1
+
+// -----
+
+module {
+  func.func @fuse_tilable_consumer_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
     %c4 = arith.constant 4 : index
     %c64 = arith.constant 64 : index
     %c0 = arith.constant 0 : index
@@ -83,19 +136,16 @@ module {
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
-    %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+    %add = transform.structured.match ops{["linalg.add"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
     %loop = transform.structured.match ops{["scf.forall"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
-    %first_slice_op, %second_slice_op = transform.split_handle %slice_ops
-      : (!transform.any_op)
-      -> (!transform.any_op, !transform.any_op)
-    %a, %b = transform.test.fuse_consumer %first_slice_op in (%loop)
+    %a, %new_loop = transform.test.fuse_consumer %add into (%loop)
       : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }
-// CHECK: func.func @fuse_tileable_consumer_scf_forall(
+// CHECK: func.func @fuse_tilable_consumer_scf_forall(
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32>)
@@ -124,7 +174,7 @@ module attributes {transform.with_named_sequence} {
 
 #map = affine_map<(d0) -> (d0)>
 module {
-  func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+  func.func @fuse_tilable_consumer_scf_for_multi_yielding_consumer(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
     %c4 = arith.constant 4 : index
     %c64 = arith.constant 64 : index
     %c0 = arith.constant 0 : index
@@ -155,16 +205,18 @@ module 
attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 + %generics = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %producer, %consumer = transform.split_handle %generics + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) %loop = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %yield in (%loop) + %a, %new_loop = transform.test.fuse_consumer %consumer into (%loop) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } -// CHECK: func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer( +// CHECK: func.func @fuse_tilable_consumer_scf_for_multi_yielding_consumer( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32> // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32> // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>) @@ -193,7 +245,7 @@ module attributes {transform.with_named_sequence} { #map = affine_map<(d0, d1) -> (d0, d1)> module { - func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64x32xf32>) -> (tensor<64x64xf32>, tensor<2048xf32>) { + func.func @fuse_tilable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64x32xf32>) -> (tensor<64x64xf32>, tensor<2048xf32>) { %c4 = arith.constant 4 : index %c64 = arith.constant 64 : index %c0 = arith.constant 0 : index @@ -224,19 +276,16 @@ module { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + %generic = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op %loop = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %first_slice_op, %second_slice_op = transform.split_handle %slice_ops - : (!transform.any_op) - -> (!transform.any_op, !transform.any_op) - %a, %b = transform.test.fuse_consumer %first_slice_op in (%loop) + %a, %new_loops = transform.test.fuse_consumer %generic into (%loop) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } -// CHECK: func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer( +// CHECK: func.func @fuse_tilable_consumer_scf_forall_multi_yielding_consumer( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32> @@ -293,17 +342,15 @@ module { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + %consumer = transform.structured.match ops{["linalg.unpack"]} in %arg1 : (!transform.any_op) -> !transform.any_op %loop = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %slice_op in (%loop) + %a, %new_loop = transform.test.fuse_consumer 
%consumer into (%loop) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } -// CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)> -// CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2048)> // CHECK: func.func @fuse_unpack_consumer_into_scf_forall( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> @@ -315,8 +362,8 @@ module attributes {transform.with_named_sequence} { // CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] // CHECK: %[[GENERIC_OUT:.*]] = linalg.generic // CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] : -// CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]]) -// CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]]) +// CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 * 32)>(%[[IV1]]) +// CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min affine_map<(d0) -> (1024, d0 * -32 + 2048)>(%[[IV1]]) // CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1] // CHECK: %[[TILED_UNPACK_OUT:.*]] = linalg.unpack %[[GENERIC_OUT]] // CHECK-SAME: outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] @@ -356,17 +403,15 @@ module { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + %consumer = transform.structured.match ops{["linalg.unpack"]} in %arg1 : (!transform.any_op) -> !transform.any_op %loop = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %slice_op in (%loop) + %a, %new_loop = transform.test.fuse_consumer %consumer into (%loop) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } -// CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)> -// CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2047)> // CHECK: func.func @fuse_unaligned_unpack_consumer_into_scf_forall( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> @@ -378,8 +423,8 @@ module attributes {transform.with_named_sequence} { // CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] // CHECK: %[[GENERIC_OUT:.*]] = linalg.generic // CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] : -// CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]]) -// CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]]) +// CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 * 32)>(%[[IV1]]) +// CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min affine_map<(d0) -> (1024, d0 * -32 + 2047)>(%[[IV1]]) // CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1] // CHECK: %[[TILED_UNPACK_OUT:.*]] = linalg.unpack %[[GENERIC_OUT]] // CHECK-SAME: outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] @@ -419,16 +464,15 @@ module { module attributes {transform.with_named_sequence} { 
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + %consumer = transform.structured.match ops{["linalg.pack"]} in %arg1 : (!transform.any_op) -> !transform.any_op %loop = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %slice_op in (%loop) + %a, %new_loop = transform.test.fuse_consumer %consumer into (%loop) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } -// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> // CHECK: func.func @fuse_perfect_tiling_pack_consumer( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> @@ -440,7 +484,7 @@ module attributes {transform.with_named_sequence} { // CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] // CHECK: %[[GENERIC_OUT:.*]] = linalg.generic // CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] : -// CHECK: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV1]]) +// CHECK: %[[PACK_RESULT_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%[[IV1]]) // CHECK: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1] // CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[GENERIC_OUT]] // CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16] @@ -471,13 +515,12 @@ func.func @fuse_pack_consumer_if_single_iteration(%arg0: tensor<4x4xf32>) -> ten module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %consumer = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused_consumer, %new_loop = transform.test.fuse_consumer %consumer into(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } -// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (-d0 + 4, 16)> // CHECK: func.func @fuse_pack_consumer_if_single_iteration( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-DAG: %[[PACK_INIT:.*]] = tensor.empty() : tensor<1x4x16x1xf32> @@ -485,7 +528,7 @@ module attributes {transform.with_named_sequence} { // CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32 // CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (4) step (16) // CHECK-SAME: shared_outs(%[[ELEM_OUT_ARG:.*]] = %[[ELEM_INIT]], %[[PACK_OUT_ARG:.*]] = %[[PACK_INIT]]) -// CHECK-DAG: %[[SIZE:.+]] = affine.min #[[MAP]](%[[IV]]) +// CHECK-DAG: %[[SIZE:.+]] = affine.min affine_map<(d0) -> (-d0 + 4, 16)>(%[[IV]]) // CHECK-DAG: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1] // CHECK-DAG: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1] // CHECK: %[[ELEM:.*]] = linalg.exp @@ -517,13 +560,12 @@ func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(%arg0: tensor< 
module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused_consumer, %new_loop = transform.test.fuse_consumer %0 into(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } -// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> // CHECK: func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] @@ -535,7 +577,7 @@ module attributes {transform.with_named_sequence} { // CHECK: %[[ELEM:.*]] = linalg.exp // CHECK-SAME: ins(%[[ELEM_SRC]] // CHECK-SAME: outs(%[[ELEM_DEST]] -// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]]) +// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%[[IV]]) // CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1] // CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]] // CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 1] @@ -566,13 +608,12 @@ func.func @fuse_pack_consumer_with_no_pad_dynamic_dim(%arg0: tensor<64x?xf32>, % module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused_consumer, %new_loop = transform.test.fuse_consumer %0 into(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } -// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> // CHECK: func.func @fuse_pack_consumer_with_no_pad_dynamic_dim( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] @@ -584,7 +625,7 @@ module attributes {transform.with_named_sequence} { // CHECK: %[[ELEM:.*]] = linalg.exp // CHECK-SAME: ins(%[[ELEM_SRC]] // CHECK-SAME: outs(%[[ELEM_DEST]] -// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]]) +// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%[[IV]]) // CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0] [64, 1, 16] [1, 1, 1] // CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]] // CHECK-SAME: inner_dims_pos = [1] inner_tiles = [16] @@ -616,16 +657,12 @@ func.func @fuse_pack_consumer_with_padding_semantics(%arg0: tensor<64x32xf32>, % module attributes {transform.with_named_sequence} { 
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused_consumer, %new_loop = transform.test.fuse_consumer %0 into(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } -// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (-d0 + 64, 15)> -// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (d0 floordiv 3)> -// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0) -> (d0 ceildiv 3)> -// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0) -> (d0 floordiv 16)> // CHECK: func.func @fuse_pack_consumer_with_padding_semantics( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] @@ -633,7 +670,7 @@ module attributes {transform.with_named_sequence} { // CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32 // CHECK: %{{.*}}:2 = scf.forall (%[[I:.*]], %[[J:.*]]) = (0, 0) to (64, 32) step (15, 16) // CHECK-SAME: shared_outs(%[[ELEM_OUT:.*]] = %[[ARG1]], %[[PACK_OUT:.*]] = %[[OUT_INIT]]) -// CHECK: %[[SIZE:.+]] = affine.min #[[MAP0]](%[[I]]) +// CHECK: %[[SIZE:.+]] = affine.min affine_map<(d0) -> (-d0 + 64, 15)>(%[[I]]) // CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]] // CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1] // CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT]] @@ -641,9 +678,9 @@ module attributes {transform.with_named_sequence} { // CHECK: %[[ELEM:.*]] = linalg.exp // CHECK-SAME: ins(%[[ELEM_SRC]] // CHECK-SAME: outs(%[[ELEM_DEST]] -// CHECK-DAG: %[[D0_OFFSET:.*]] = affine.apply #[[MAP1]](%[[I]]) -// CHECK-DAG: %[[D0_SIZE:.*]] = affine.apply #[[MAP2]](%[[SIZE]]) -// CHECK-DAG: %[[D1_OFFSET:.*]] = affine.apply #[[MAP3]](%[[J]]) +// CHECK-DAG: %[[D0_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 floordiv 3)>(%[[I]]) +// CHECK-DAG: %[[D0_SIZE:.*]] = affine.apply affine_map<(d0) -> (d0 ceildiv 3)>(%[[SIZE]]) +// CHECK-DAG: %[[D1_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%[[J]]) // CHECK-DAG: %[[PACK_INIT:.*]] = tensor.extract_slice %[[PACK_OUT]] // CHECK-SAME: [%[[D0_OFFSET]], %[[D1_OFFSET]], 0, 0] [%[[D0_SIZE]], 1, 3, 16] [1, 1, 1, 1] // CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]] @@ -674,20 +711,21 @@ func.func @nofuse_pack_with_imperfect_tiling(%arg0: tensor<30xf32>) -> tensor<5x linalg.yield %5 : f32 } -> tensor<5xf32> scf.forall.in_parallel { - // expected-error @below {{failed to fuse consumer of slice}} + tensor.parallel_insert_slice %4 into %arg2[%3] [5] [1] : tensor<5xf32> into tensor<30xf32> } } %2 = tensor.empty() : tensor<5x6xf32> + // expected-error @below {{failed to fuse consumer of slice}} %pack = linalg.pack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [6] into %2 : tensor<30xf32> -> tensor<5x6xf32> return %pack : tensor<5x6xf32> } module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %0 = transform.structured.match 
ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused_consumer, %new_loop = transform.test.fuse_consumer %0 into(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } @@ -717,11 +755,15 @@ module { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 + %mulop = transform.structured.match ops{["linalg.mul"]} in %arg1 : (!transform.any_op) -> !transform.any_op %loop = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %slice_op in (%loop) num_consumer_to_fuse = 2 + %fused_consumer, %new_loop = transform.test.fuse_consumer %mulop into (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %expop = transform.structured.match ops{["linalg.exp"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %fused_consumer_2, %new_loop_2 = transform.test.fuse_consumer %expop into (%new_loop) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } @@ -741,64 +783,20 @@ module attributes {transform.with_named_sequence} { // CHECK-SAME: ins(%[[ADD_INS0_SLICE]], %[[ADD_INS1_SLICE]] : // CHECK-SAME: outs(%[[ADD_OUT_SLICE]] : // CHECK: %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] -// CHECK: %[[EXP_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] -// CHECK: %[[TILED_EXP_OUT:.*]] = linalg.exp -// CHECK-SAME: ins(%[[TILED_ADD_OUT]] : -// CHECK-SAME: outs(%[[EXP_OUT_SLICE]] : // CHECK: %[[MUL_INS2_SLICE:.*]] = tensor.extract_slice %[[ARG2]][%[[IV1]], 0] [64, 256] [1, 1] -// CHECK: %[[MUL_OUT_SLICE:.*]] = tensor.extract_slice %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[MUL_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] // CHECK: %[[TILED_MUL_OUT:.*]] = linalg.mul // CHECK-SAME: ins(%[[TILED_ADD_OUT]], %[[MUL_INS2_SLICE]] : // CHECK-SAME: outs(%[[MUL_OUT_SLICE]] : -// CHECK: %[[INSERT_EXP:.*]] = tensor.insert_slice %[[TILED_EXP_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] -// CHECK: %[[INSERT_MUL:.*]] = tensor.insert_slice %[[TILED_MUL_OUT]] into %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] -// CHECK: scf.yield %[[INSERT_ADD]], %[[INSERT_EXP]], %[[INSERT_MUL]] : -// CHECK: } -// CHECK: return %[[LOOP_RESULT]]#2, %[[LOOP_RESULT]]#1 : - -// ----- - -module { - func.func @no_fuse_only_dps_consumer(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<258x258xf32>) { - %c0 = arith.constant 0 : index - %c64 = arith.constant 64 : index - %c256 = arith.constant 256 : index - %cst = arith.constant 0.000000e+00 : f32 - %dest0 = tensor.empty() : tensor<256x256xf32> - %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest0) -> (tensor<256x256xf32>) { - %extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> - %extracted_slice_2 = 
tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> - %extracted_slice_3 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> - %3 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice_1 : tensor<64x256xf32>) -> tensor<64x256xf32> - %insert_slice = tensor.insert_slice %3 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32> - scf.yield %insert_slice : tensor<256x256xf32> - } - %dest1 = tensor.empty() : tensor<258x258xf32> - %4 = tensor.insert_slice %1 into %dest1[0, 0] [256, 256] [1, 1] : tensor<256x256xf32> into tensor<258x258xf32> - %5 = linalg.mul ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32> - return %5, %4 : tensor<256x256xf32>, tensor<258x258xf32> - } -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %slice_ops = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 - : (!transform.any_op) -> !transform.any_op - %loop = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %slice_op, %other_slice = transform.split_handle %slice_ops : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %a, %b = transform.test.fuse_consumer %slice_op in (%loop) num_consumer_to_fuse = 1 - : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.yield - } -} -// CHECK: func.func @no_fuse_only_dps_consumer( -// CHECK: %[[LOOP_RESULT:.*]]:2 = scf.for {{.*}} { -// CHECK: linalg.add -// CHECK: linalg.mul -// CHECK: scf.yield +// CHECK: %[[EXP_OUT_SLICE:.*]] = tensor.extract_slice %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[TILED_EXP_OUT:.*]] = linalg.exp +// CHECK-SAME: ins(%[[TILED_ADD_OUT]] : +// CHECK-SAME: outs(%[[EXP_OUT_SLICE]] : +// CHECK: %[[INSERT_MUL:.*]] = tensor.insert_slice %[[TILED_MUL_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[INSERT_EXP:.*]] = tensor.insert_slice %[[TILED_EXP_OUT]] into %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: scf.yield %[[INSERT_ADD]], %[[INSERT_MUL]], %[[INSERT_EXP]] : // CHECK: } -// CHECK: %[[RES_SLICE:.+]] = tensor.insert_slice -// CHECK: return %[[LOOP_RESULT]]#1, %[[RES_SLICE]] +// CHECK: return %[[LOOP_RESULT]]#1, %[[LOOP_RESULT]]#2 : // ----- @@ -829,40 +827,41 @@ module { } } -// CHECK: func.func @fuse_with_tilable_consumer_with_projected_permutations(%[[VAL_0:.*]]: tensor<256x256xf32>, %[[VAL_1:.*]]: tensor<256x256xf32>, %[[VAL_2:.*]]: tensor<24xf32>) -> tensor<256x256x24xf32> { -// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_4:.*]] = arith.constant 64 : index -// CHECK: %[[VAL_5:.*]] = arith.constant 256 : index -// CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<256x256xf32> -// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<256x256x24xf32> -// CHECK: %[[VAL_8:.*]]:2 = scf.for %[[VAL_9:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_10:.*]] = %[[VAL_6]], %[[VAL_11:.*]] = %[[VAL_7]]) -> (tensor<256x256xf32>, tensor<256x256x24xf32>) { -// CHECK: %[[VAL_12:.*]] = tensor.extract_slice %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1] -// CHECK: %[[VAL_13:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1] -// CHECK: %[[VAL_14:.*]] = tensor.extract_slice 
%[[VAL_1]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1] -// CHECK: %[[VAL_15:.*]] = linalg.add ins(%[[VAL_13]], %[[VAL_14]] : tensor<64x256xf32>, tensor<64x256xf32>) outs(%[[VAL_12]] : tensor<64x256xf32>) -> tensor<64x256xf32> -// CHECK: %[[VAL_16:.*]] = tensor.insert_slice %[[VAL_15]] into %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1] -// CHECK: %[[VAL_17:.*]] = tensor.extract_slice %[[VAL_2]][0] [24] [1] : tensor<24xf32> to tensor<24xf32> -// CHECK: %[[VAL_18:.*]] = tensor.extract_slice %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1] -// CHECK: %[[VAL_19:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_15]], %[[VAL_17]] : tensor<64x256xf32>, tensor<24xf32>) outs(%[[VAL_18]] : tensor<64x256x24xf32>) { -// CHECK: ^bb0(%[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32, %[[VAL_22:.*]]: f32): -// CHECK: %[[VAL_23:.*]] = arith.addf %[[VAL_20]], %[[VAL_21]] : f32 -// CHECK: linalg.yield %[[VAL_23]] : f32 -// CHECK: } -> tensor<64x256x24xf32> -// CHECK: %[[VAL_24:.*]] = tensor.insert_slice %[[VAL_25:.*]] into %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1] -// CHECK: scf.yield %[[VAL_16]], %[[VAL_24]] : tensor<256x256xf32>, tensor<256x256x24xf32> -// CHECK: } - module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 + %consumer = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op %loop = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %slice_op in (%loop) num_consumer_to_fuse = 1 + %a, %b = transform.test.fuse_consumer %consumer into (%loop) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } +// CHECK: func.func @fuse_with_tilable_consumer_with_projected_permutations( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<256x256xf32>, %[[VAL_1:.*]]: tensor<256x256xf32>, %[[VAL_2:.*]]: tensor<24xf32>) -> tensor<256x256x24xf32> { +// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_4:.*]] = arith.constant 64 : index +// CHECK: %[[VAL_5:.*]] = arith.constant 256 : index +// CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<256x256xf32> +// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<256x256x24xf32> +// CHECK: %[[VAL_8:.*]]:2 = scf.for %[[VAL_9:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_10:.*]] = %[[VAL_6]], %[[VAL_11:.*]] = %[[VAL_7]]) -> (tensor<256x256xf32>, tensor<256x256x24xf32>) { +// CHECK: %[[VAL_12:.*]] = tensor.extract_slice %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1] +// CHECK: %[[VAL_13:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1] +// CHECK: %[[VAL_14:.*]] = tensor.extract_slice %[[VAL_1]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1] +// CHECK: %[[VAL_15:.*]] = linalg.add ins(%[[VAL_13]], %[[VAL_14]] : tensor<64x256xf32>, tensor<64x256xf32>) outs(%[[VAL_12]] : tensor<64x256xf32>) -> tensor<64x256xf32> +// CHECK: %[[VAL_16:.*]] = tensor.insert_slice %[[VAL_15]] into %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1] +// CHECK: %[[VAL_17:.*]] = tensor.extract_slice %[[VAL_2]][0] [24] [1] : tensor<24xf32> to tensor<24xf32> +// CHECK: %[[VAL_18:.*]] = tensor.extract_slice %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1] +// CHECK: %[[VAL_19:.*]] = linalg.generic +// CHECK-SAME: ins(%[[VAL_15]], 
+// CHECK:     ^bb0(%[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32, %[[VAL_22:.*]]: f32):
+// CHECK:       %[[VAL_23:.*]] = arith.addf %[[VAL_20]], %[[VAL_21]] : f32
+// CHECK:       linalg.yield %[[VAL_23]] : f32
+// CHECK:     } -> tensor<64x256x24xf32>
+// CHECK:     %[[VAL_24:.*]] = tensor.insert_slice %[[VAL_25:.*]] into %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1]
+// CHECK:     scf.yield %[[VAL_16]], %[[VAL_24]] : tensor<256x256xf32>, tensor<256x256x24xf32>
+// CHECK:   }
 
 // -----
 
@@ -878,12 +877,12 @@ func.func @multi_slice_fusion1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %
     %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
     %generic:2 = linalg.generic {
         indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>],
-      iterator_types = ["parallel", "reduction"]}
-      ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice, %init1_slice : tensor<?xf32>, tensor<?xf32>) {
+        iterator_types = ["parallel", "reduction"]}
+        ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice, %init1_slice : tensor<?xf32>, tensor<?xf32>) {
     ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
       %0 = arith.mulf %b0, %b1 : f32
-      %1 = arith.addf %b0, %b2 : f32
-      linalg.yield %0, %1 : f32, f32
+      %1 = arith.addf %b0, %b2 : f32
+      linalg.yield %0, %1 : f32, f32
     } -> (tensor<?xf32>, tensor<?xf32>)
     scf.forall.in_parallel {
       tensor.parallel_insert_slice %generic#0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
@@ -901,6 +900,19 @@ func.func @multi_slice_fusion1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %
   } -> tensor<?xf32>
   return %result : tensor<?xf32>
 }
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %generics = transform.structured.match ops{["linalg.generic"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %producer, %consumer = transform.split_handle %generics : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer %consumer into (%loop)
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
 // CHECK-LABEL: func @multi_slice_fusion1(
 // CHECK-SAME:    %[[ARG0:.+]]: tensor<?x?xf32>
 // CHECK:         %[[C0:.+]] = arith.constant 0
@@ -916,23 +928,9 @@ func.func @multi_slice_fusion1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %
 // CHECK:           tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
 // CHECK:         return %[[RESULT]]#2
 
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
-    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
-      : (!transform.any_op) -> !transform.any_op
-    %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
-      : (!transform.any_op) -> !transform.any_op
-    %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-    %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
-      : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-    transform.yield
-  }
-}
 
 // -----
 
-// Check that when the given operand tiles are inconsistent, tiling fails.
-
 func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -944,20 +942,20 @@ func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %
     %init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
     %generic0 = linalg.generic {
         indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
-      iterator_types = ["parallel", "reduction"]}
-      ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice : tensor<?xf32>) {
+        iterator_types = ["parallel", "reduction"]}
+        ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice : tensor<?xf32>) {
     ^bb0(%b0 : f32, %b1 : f32):
       %0 = arith.mulf %b0, %b1 : f32
-      linalg.yield %0 : f32
+      linalg.yield %0 : f32
     } -> tensor<?xf32>
     %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
     %generic1 = linalg.generic {
        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
-      iterator_types = ["parallel", "reduction"]}
-      ins(%arg0_slice : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) {
+        iterator_types = ["parallel", "reduction"]}
+        ins(%arg0_slice : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) {
     ^bb0(%b0 : f32, %b1 : f32):
-      %0 = arith.addf %b0, %b1 : f32
-      linalg.yield %0: f32
+      %0 = arith.addf %b0, %b1 : f32
+      linalg.yield %0: f32
     } -> tensor<?xf32>
     scf.forall.in_parallel {
       tensor.parallel_insert_slice %generic0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
@@ -975,6 +973,19 @@ func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %
   } -> tensor<?xf32>
   return %result : tensor<?xf32>
 }
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %generics = transform.structured.match ops{["linalg.generic"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %producer1, %producer2, %consumer = transform.split_handle %generics : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer %consumer into (%loop)
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
 // CHECK-LABEL: func @multi_slice_fusion2(
 // CHECK-SAME:    %[[ARG0:.+]]: tensor<?x?xf32>
 // CHECK:         %[[C0:.+]] = arith.constant 0
@@ -991,19 +1002,6 @@ func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %
 // CHECK:           tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
 // CHECK:         return %[[RESULT]]#2
 
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
-    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
-      : (!transform.any_op) -> !transform.any_op
-    %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
-      : (!transform.any_op) -> !transform.any_op
-    %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-    %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
-      : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-    transform.yield
-  }
-}
-
 // -----
 
 func.func @multi_slice_fusion_with_broadcast(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>,
@@ -1060,11 +1058,11 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %loop = transform.structured.match ops{["scf.forall"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+    %generics = transform.structured.match ops{["linalg.generic"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-    %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
-      : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %producer_1, %producer_2, %consumer = transform.split_handle %generics : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer %consumer into (%loop)
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }
@@ -1124,7 +1122,6 @@ func.func @multi_slice_fusion_invalid(%arg0 : tensor<?x?xf32>, %arg1 : tensor<
       linalg.yield %0: f32
     } -> tensor<?x?xf32>
     scf.forall.in_parallel {
-      // expected-error @below {{failed to fuse consumer of slice}}
       tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
         : tensor<?x?xf32> into tensor<?x?xf32>
       tensor.parallel_insert_slice %generic1 into %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
@@ -1132,6 +1129,7 @@ func.func @multi_slice_fusion_invalid(%arg0 : tensor<?x?xf32>, %arg1 : tensor<
     }
   }
   %empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
+  // expected-error @below {{failed to fuse consumer of slice}}
   %result = linalg.generic {
       indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel"]}
@@ -1146,11 +1144,11 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %loop = transform.structured.match ops{["scf.forall"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+    %generics = transform.structured.match ops{["linalg.generic"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-    %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
-      : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %producer_1, %producer_2, %consumer = transform.split_handle %generics : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer %consumer into (%loop)
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 326fec3ee5cf0..51dac0e866254 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -172,7 +172,71 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
 
 /// Apply fusing of consumer transformation to all payload ops and store both
 /// the original consumer operation as well as the fused consumer operation.
-static LogicalResult applyFuseConsumer(
+static LogicalResult
+applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp,
+                  Operation *consumer,
+                  MutableArrayRef<LoopLikeOpInterface> loops,
+                  TransformResults &transformResults) {
+  SmallVector<Operation *> fusedConsumerOps;
+
+  rewriter.setInsertionPoint(consumer);
+
+  FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
+      scf::tileAndFuseConsumer(rewriter, consumer, loops);
+
+  if (failed(fuseConsumerResults))
+    return consumer->emitOpError("failed to fuse consumer of slice");
+
+  // Report back the relevant handles to the transform op.
+  for (OpOperand *tiledAndFusedConsumerOperand :
+       fuseConsumerResults->tiledAndFusedConsumerOperands) {
+    fusedConsumerOps.push_back(tiledAndFusedConsumerOperand->getOwner());
+  }
+
+  transformResults.set(transformOp->getOpResult(0), fusedConsumerOps);
+  for (auto [index, loop] : llvm::enumerate(loops)) {
+    transformResults.set(transformOp->getOpResult(index + 1), {loop});
+  }
+  return success();
+}
+
+DiagnosedSilenceableFailure
+transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
+                                     TransformResults &transformResults,
+                                     TransformState &state) {
+  Operation *consumer = *state.getPayloadOps(getConsumer()).begin();
+
+  SmallVector<LoopLikeOpInterface> loops;
+  // Since the matcher works inside-out, we need to iterate the loops in reverse.
+  for (auto loop : llvm::reverse(getLoops())) {
+    auto loopLikeOp =
+        dyn_cast<LoopLikeOpInterface>(*state.getPayloadOps(loop).begin());
+    if (!loopLikeOp) {
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+    loops.push_back(loopLikeOp);
+  }
+  LogicalResult result = applyFuseConsumer(rewriter, getOperation(), consumer,
+                                           loops, transformResults);
+  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
+                        : DiagnosedSilenceableFailure::success();
+}
+
+void transform::TestFuseConsumerOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getConsumerMutable(), effects);
+  consumesHandle(getLoopsMutable(), effects);
+  producesHandle(getOperation()->getOpResults(), effects);
+  modifiesPayload(effects);
+}
+
+//===----------------------------------------------------------------------===//
+// TestFuseConsumerUsingSliceOp
+//===----------------------------------------------------------------------===//
+
+/// Apply fusing of consumer transformation to all payload ops and store both
+/// the original consumer operation as well as the fused consumer operation.
+static LogicalResult applyFuseConsumerUsingSlices(
     RewriterBase &rewriter, Operation *transformOp,
     ArrayRef<Operation *> slices, MutableArrayRef<LoopLikeOpInterface> loops,
     uint32_t numConsumerToFuse, TransformResults &transformResults) {
@@ -204,10 +268,9 @@ static LogicalResult applyFuseConsumer(
   return success();
 }
 
-DiagnosedSilenceableFailure
-transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
-                                     TransformResults &transformResults,
-                                     TransformState &state) {
+DiagnosedSilenceableFailure transform::TestFuseConsumerUsingSliceOp::apply(
+    TransformRewriter &rewriter, TransformResults &transformResults,
+    TransformState &state) {
   SmallVector<Operation *> slices;
   for (auto op : getTargets()) {
     auto sliceOp = *state.getPayloadOps(op).begin();
@@ -224,13 +287,13 @@ transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
     loops.push_back(loopLikeOp);
   }
   LogicalResult result =
-      applyFuseConsumer(rewriter, getOperation(), slices, loops,
-                        getNumConsumerToFuse(), transformResults);
+      applyFuseConsumerUsingSlices(rewriter, getOperation(), slices, loops,
+                                   getNumConsumerToFuse(), transformResults);
   return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
                        : DiagnosedSilenceableFailure::success();
 }
 
-void transform::TestFuseConsumerOp::getEffects(
+void transform::TestFuseConsumerUsingSliceOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   consumesHandle(getTargetsMutable(), effects);
   consumesHandle(getLoopsMutable(), effects);
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
index 694c4229eef62..bfefad02418ac 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
@@ -49,7 +49,7 @@ def TestFuseAndYieldOp : Op
, DeclareOpInterfaceMethods,
@@ -73,6 +73,28 @@ def TestFuseConsumerOp : Op
+def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
+      [DeclareOpInterfaceMethods<TransformOpInterface>,
+       DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+       ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Fuses the consumer operation pointed to by the `consumer` handle into
+    the loop nest pointed to by the `loops` handles.
+  }];
+
+  let arguments = (ins
+      TransformHandleTypeInterface:$consumer,
+      Variadic<TransformHandleTypeInterface>:$loops);
+  let results = (outs TransformHandleTypeInterface:$fused_consumer,
+      Variadic<TransformHandleTypeInterface>:$result_loops);
+
+  let assemblyFormat = [{
+    $consumer `into` `(` $loops `)`
+    attr-dict `:` functional-type(operands, results)
+  }];
+}
+
+
 def TestTileUsingForallOp : Op<Transform_Dialect, "test.tile_using_forall",
       [DeclareOpInterfaceMethods<TransformOpInterface>,
       DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,

From 7e3749038a66585b06087a0eb5c2da221d75eeeb Mon Sep 17 00:00:00 2001
From: MaheshRavishankar
Date: Wed, 12 Nov 2025 13:17:50 -0800
Subject: [PATCH 2/3] Fix warning (leading to build errors when warnings are
 treated as errors)

Signed-off-by: MaheshRavishankar
---
 mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 7e715ee189740..03ce5555f56ff 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -2478,9 +2478,10 @@ getProducingInsertSliceLikeOp(OpResult result,
   // tiling and retrieve the `tensor.insert_slice` operation used to construct
   // the result.
   while (loops.size() != 1) {
-    if (result.getOwner() != loops.front())
+    LoopLikeOpInterface loop = loops.front();
+    if (result.getOwner() != loop)
       return std::nullopt;
-    auto forOp = dyn_cast<scf::ForOp>(loops.front());
+    auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
     if (!forOp)
       return std::nullopt;
     auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
@@ -2491,9 +2492,10 @@
     result = innerForResult;
     loops = loops.drop_front();
   }
-  if (result.getOwner() != loops.front())
+  LoopLikeOpInterface loop = loops.front();
+  if (result.getOwner() != loop)
     return std::nullopt;
-  auto forOp = dyn_cast<scf::ForOp>(loops.front());
+  auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
  if (!forOp)
     return std::nullopt;
   auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());

From bf8c1de8b53c5421e81de919ee89e693c0119fca Mon Sep 17 00:00:00 2001
From: MaheshRavishankar
Date: Wed, 12 Nov 2025 13:19:02 -0800
Subject: [PATCH 3/3] Fix linter error.
Signed-off-by: MaheshRavishankar
---
 .../TilingInterface/TestTilingInterfaceTransformOps.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 51dac0e866254..194c052eb4682 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -207,7 +207,8 @@ transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
   Operation *consumer = *state.getPayloadOps(getConsumer()).begin();
 
   SmallVector<LoopLikeOpInterface> loops;
-  // Since the matcher works inside-out, we need to iterate the loops in reverse.
+  // Since the matcher works inside-out, we need to iterate the loops in
+  // reverse.
   for (auto loop : llvm::reverse(getLoops())) {
     auto loopLikeOp =
         dyn_cast<LoopLikeOpInterface>(*state.getPayloadOps(loop).begin());
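
For reference, a minimal sketch of driving the new entry point from C++. The wrapper name `fuseConsumerIntoTiledLoops` and the diagnostic text are illustrative only and not part of this patch; it assumes `tiledLoops` is a loop nest produced by a prior tiling and `consumerOp` uses results of that nest:

  // Sketch only; mirrors the call pattern used by the test transform op above.
  #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"

  static LogicalResult
  fuseConsumerIntoTiledLoops(RewriterBase &rewriter, Operation *consumerOp,
                             MutableArrayRef<LoopLikeOpInterface> tiledLoops) {
    // Position the rewriter at the untiled consumer; it is replaced by its
    // tiled-and-fused clone inside the loop nest.
    rewriter.setInsertionPoint(consumerOp);
    FailureOr<scf::SCFFuseConsumerOfSliceResult> fused =
        scf::tileAndFuseConsumer(rewriter, consumerOp, tiledLoops);
    if (failed(fused))
      return consumerOp->emitOpError("failed to fuse consumer into loop nest");
    // On success the tiled consumer operands are available in
    // fused->tiledAndFusedConsumerOperands.
    return success();
  }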
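
Likewise, a minimal transform-script sketch of the new test op's assembly format, mirroring the updated tests above (the matched op names are placeholders for whatever produces and consumes the tiled loop's results):

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
      // Handles to the untiled consumer and to the tiled loop that defines
      // its operands.
      %consumer = transform.structured.match ops{["linalg.mul"]} in %root
        : (!transform.any_op) -> !transform.any_op
      %loop = transform.structured.match ops{["scf.for"]} in %root
        : (!transform.any_op) -> !transform.any_op
      // Fuse the consumer into the loop; returns the fused consumer and the
      // updated loop.
      %fused, %new_loop = transform.test.fuse_consumer %consumer into (%loop)
        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }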