From 97ce7f44d611a305fd96710532c0523297bd1bca Mon Sep 17 00:00:00 2001 From: "Song, Yunfei" Date: Tue, 3 Sep 2024 00:38:41 -0700 Subject: [PATCH] fix `whileProducerOutOfLoopBlock` --- .../Transforms/IterativeTilingAndFusion.cpp | 6 +- lib/gc/Transforms/TilingUsingInterfaceX.cpp | 60 ++++++++++--------- lib/gc/Transforms/TilingUsingInterfaceX.h | 2 +- .../iterative-tiling-and-fusion.mlir | 49 +++++++++++++++ 4 files changed, 86 insertions(+), 31 deletions(-) diff --git a/lib/gc/Transforms/IterativeTilingAndFusion.cpp b/lib/gc/Transforms/IterativeTilingAndFusion.cpp index 3a04e3356..cfbc6d9e2 100644 --- a/lib/gc/Transforms/IterativeTilingAndFusion.cpp +++ b/lib/gc/Transforms/IterativeTilingAndFusion.cpp @@ -257,8 +257,8 @@ tilingSizesIfMatchedFilter(RewriterBase &rewriter, if (defOrUse.isDef()) { SmallVector backwardSlice; FailureOr realProducer = - scfX::getRealProducerOfExtractSliceOp(otherCandidate, - backwardSlice); + scfX::getRealProducerFromExtractSliceOp(otherCandidate, + backwardSlice); if (succeeded(realProducer) && realProducer->getDefiningOp() == defOrUse.ownerOp) return failure(); @@ -476,7 +476,7 @@ tileAndFuseProducerOfOpOperand(RewriterBase &rewriter, OpOperand &operand, // stage, sorted from inner to outer. 
SmallVector backwardSlice; FailureOr realProducer = - scfX::getRealProducerOfExtractSliceOp(*closestSliceOp, backwardSlice); + scfX::getRealProducerFromExtractSliceOp(*closestSliceOp, backwardSlice); if (failed(realProducer)) return std::nullopt; diff --git a/lib/gc/Transforms/TilingUsingInterfaceX.cpp b/lib/gc/Transforms/TilingUsingInterfaceX.cpp index 25ada4c4f..ef74e05c4 100644 --- a/lib/gc/Transforms/TilingUsingInterfaceX.cpp +++ b/lib/gc/Transforms/TilingUsingInterfaceX.cpp @@ -191,7 +191,7 @@ tileAndFuseProducerOfSliceImpl(RewriterBase &rewriter, /// @param candidateSliceOp: %4 = extract %args2 /// @param backwardSlice: in-out parameter populated by backward extractSliceOps /// @return OpResult Producer : %0 = producer -FailureOr mlir::scfX::getRealProducerOfExtractSliceOp( +FailureOr mlir::scfX::getRealProducerFromExtractSliceOp( Operation *candidateSliceOp, SmallVector &backwardSlice, unsigned curDepth, unsigned maxDepth) { @@ -216,8 +216,8 @@ FailureOr mlir::scfX::getRealProducerOfExtractSliceOp( } if (auto sliceOp = rootSource.getDefiningOp()) { // walk up loop to find larger candidate extractSliceOp - return getRealProducerOfExtractSliceOp(sliceOp, backwardSlice, - curDepth + 1); + return getRealProducerFromExtractSliceOp(sliceOp, backwardSlice, + curDepth + 1); } break; } @@ -278,6 +278,21 @@ struct ErasedOpListener : public RewriterBase::Listener { bool isErased(Operation *op) { return erased.count(op); } }; +/// Check if it is the ForOp that yields the result of its inner loop +static LogicalResult isForOpYieldResultOfInnerLoop(LoopLikeOpInterface loop) { + if (auto forOp = dyn_cast(loop.getOperation())) { + Block::OpListType &opsInLoopBody = forOp.getBody()->getOperations(); + for (auto &&[index, op] : llvm::enumerate(opsInLoopBody)) { + // If the inner loop is the second-to-last op, i.e. right before the + // yieldOp of ForOp, the given loop must yield the result of inner loop. 
+ if (isa(op)) { + return success((index + 2) == opsInLoopBody.size()); + } + } + } + return failure(); +} + /// Enhanced version of `tileAndFuseProducerOfSliceImpl`, which can deal with /// multi-level `extractSliceOp`. E.g. /// @@ -293,7 +308,9 @@ std::optional mlir::scfX::tileAndFuseProducerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp) { SmallVector backwardSlice; - if (failed(getRealProducerOfExtractSliceOp(candidateSliceOp, backwardSlice))) + FailureOr realProducer = + getRealProducerFromExtractSliceOp(candidateSliceOp, backwardSlice); + if (failed(realProducer)) return std::nullopt; std::optional fuseProducerResult; @@ -303,14 +320,18 @@ mlir::scfX::tileAndFuseProducerOfSlice(RewriterBase &rewriter, for (auto &&[index, sliceOp] : llvm::enumerate(backwardSlice)) { // get nest loops between next candidate sliceOp and tiled producer. auto whileProducerOutOfLoopBlock = - [&fuseProducerResult](LoopLikeOpInterface loop) -> LogicalResult { - if (fuseProducerResult) { - Block &body = loop->getRegion(0).front(); - if (fuseProducerResult->tiledAndFusedProducer.getDefiningOp() - ->getBlock() == &body) - return failure(); - } - return success(); + [&fuseProducerResult, + &realProducer](LoopLikeOpInterface loop) -> LogicalResult { + // ensure that all surrounding outer loops are just yielding the result of + // the inner loops. + if (failed(isForOpYieldResultOfInnerLoop(loop))) + return failure(); + Operation *originalOp = + fuseProducerResult + ? 
fuseProducerResult->tiledAndFusedProducer.getDefiningOp() + : realProducer->getDefiningOp(); + Block &body = loop->getRegion(0).front(); + return success(originalOp->getBlock() != &body); }; SmallVector outerLoops = getOuterNestLoopsWhile(sliceOp->getParentOfType(), @@ -515,21 +536,6 @@ static FailureOr getConsumerFromUses(Value val, return operand; } -/// Check if it is the ForOp that yield the result of inner loop -static LogicalResult isForOpYieldResultOfInnerLoop(LoopLikeOpInterface loop) { - if (auto forOp = dyn_cast(loop.getOperation())) { - for (auto &&[index, op] : - llvm::enumerate(forOp.getBody()->getOperations())) { - // If the orderIndex of inner loop is the last second one before the - // yieldOp of ForOp, the given loop must yield the result of inner loop. - if (isa(op)) { - return success((index + 2) == forOp.getBody()->getOperations().size()); - } - } - } - return failure(); -} - /// Fetch the untiled consumer of a scf.for's result which is yielded by a /// tensor.insert_slice. This function makes the following assumptions that /// tensor.insert_slice has scf.yield as its only user. 
diff --git a/lib/gc/Transforms/TilingUsingInterfaceX.h b/lib/gc/Transforms/TilingUsingInterfaceX.h index 778021e94..630a084ce 100644 --- a/lib/gc/Transforms/TilingUsingInterfaceX.h +++ b/lib/gc/Transforms/TilingUsingInterfaceX.h @@ -18,7 +18,7 @@ SmallVector getOuterNestLoopsWhile( LoopLikeOpInterface loop, const std::function &pred); -FailureOr getRealProducerOfExtractSliceOp( +FailureOr getRealProducerFromExtractSliceOp( Operation *candidateSliceOp, SmallVector &backwardSlice, unsigned curDepth = 0, unsigned maxDepth = 5); diff --git a/test/mlir/test/gc/Transforms/iterative-tiling-and-fusion.mlir b/test/mlir/test/gc/Transforms/iterative-tiling-and-fusion.mlir index 2be28edd0..d9efaca66 100644 --- a/test/mlir/test/gc/Transforms/iterative-tiling-and-fusion.mlir +++ b/test/mlir/test/gc/Transforms/iterative-tiling-and-fusion.mlir @@ -535,4 +535,53 @@ module { /// CHECK: return %[[FINAL_RESULT]]#1, %[[FINAL_RESULT]]#0 return %0, %1 : tensor<16x32x32xf32>, tensor<16x32xf32> } +} + +// ----- + +#map = affine_map<(d0) -> (d0 * 128)> +module { + /// CHECK-LABEL: @fuse_tiled_producer + func.func @fuse_tiled_producer(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>) -> tensor<256x256xf32> { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %cst = arith.constant 0.000000e+00 : f32 + %dest0 = tensor.empty() : tensor<256x256xf32> + /// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%{{.*}}) in (2, 2) + %1 = scf.forall (%arg4, %arg5) in (2, 2) shared_outs(%arg6 = %dest0) -> tensor<256x256xf32> { + %iv0 = affine.apply #map(%arg4) + %iv1 = affine.apply #map(%arg5) + %extracted_slice_1 = tensor.extract_slice %arg6[%iv0, %iv1] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32> + %dest1 = linalg.fill ins(%cst : f32) outs(%extracted_slice_1 : tensor<128x128xf32>) -> tensor<128x128xf32> + %extracted_slice_2 = tensor.extract_slice %arg0[%iv0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32> + 
%extracted_slice_3 = tensor.extract_slice %arg1[0, %iv1] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32> + /// CHECK: scf.for + /// CHECK: scf.for + %2 = scf.for %arg7 = %c0 to %c128 step %c64 iter_args(%arg8 = %dest1) -> (tensor<128x128xf32>) { + %3 = scf.for %arg9 = %c0 to %c128 step %c64 iter_args(%arg10 = %arg8) -> (tensor<128x128xf32>) { + %extracted_slice_4 = tensor.extract_slice %arg10[%arg7, %arg9] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32> + %extracted_slice_5 = tensor.extract_slice %extracted_slice_2[%arg7, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32> + %extracted_slice_6 = tensor.extract_slice %extracted_slice_3[0, %arg9] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32> + /// CHECK: %[[FILL_OUT:.*]] = linalg.fill + /// CHECK: %[[MATMUL_OUT:.*]] = linalg.matmul + /// CHECK: %[[EXP_OUT:.*]] = linalg.exp + %4 = linalg.matmul ins(%extracted_slice_5, %extracted_slice_6 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_4 : tensor<64x64xf32>) -> tensor<64x64xf32> + %insert_slice = tensor.insert_slice %4 into %arg10[%arg7, %arg9] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32> + /// CHECK: scf.yield {{.*}}, {{.*}} : tensor<128x128xf32>, tensor<128x128xf32> + scf.yield %insert_slice : tensor<128x128xf32> + } + /// CHECK: scf.yield {{.*}}, {{.*}} : tensor<128x128xf32>, tensor<128x128xf32> + scf.yield %3 : tensor<128x128xf32> + } + scf.forall.in_parallel { + /// CHECK: tensor.parallel_insert_slice + /// CHECK: tensor.parallel_insert_slice + tensor.parallel_insert_slice %2 into %arg6[%iv0, %iv1] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32> + } + } + %2 = linalg.exp ins(%1 : tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32> + /// CHECK: return %[[FINAL_RESULT]]#1 + return %2 : tensor<256x256xf32> + } } \ No newline at end of file