diff --git a/lib/gc/Transforms/TilingUsingInterfaceX.cpp b/lib/gc/Transforms/TilingUsingInterfaceX.cpp index 468002dcb..25ada4c4f 100644 --- a/lib/gc/Transforms/TilingUsingInterfaceX.cpp +++ b/lib/gc/Transforms/TilingUsingInterfaceX.cpp @@ -21,6 +21,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Interfaces/TilingInterface.h" +#include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #include @@ -255,6 +256,28 @@ SmallVector mlir::scfX::getOuterNestLoopsWhile( return {nestLoops.rbegin(), nestLoops.rend()}; } +/// A listener that watches which ops were erased. +struct ErasedOpListener : public RewriterBase::Listener { +private: + /// Pointers to all erased operations and blocks. + DenseSet erased; + // Hook old listener. + OpBuilder::Listener *oldListenerHook = nullptr; + +public: + ErasedOpListener() = default; + ErasedOpListener(OpBuilder::Listener *oldListener) + : oldListenerHook(oldListener) {} + void notifyOperationErased(Operation *op) override { + // Call old listener hook. + if (auto *oldListener = + dyn_cast_if_present(oldListenerHook)) + oldListener->notifyOperationErased(op); + erased.insert(op); + } + bool isErased(Operation *op) { return erased.count(op); } +}; + /// Enhanced version of `tileAndFuseProducerOfSliceImpl`, which can deal with /// multi-level `extractSliceOp`. E.g. /// @@ -296,6 +319,55 @@ mlir::scfX::tileAndFuseProducerOfSlice(RewriterBase &rewriter, tileAndFuseProducerOfSliceImpl(rewriter, sliceOp, outerLoops); if (!fuseProducerResult) return std::nullopt; + + // Cache old listener. + OpBuilder::Listener *oldListener = rewriter.getListener(); + // Set new listener. + ErasedOpListener newListener = ErasedOpListener(oldListener); + rewriter.setListener(&newListener); + + auto producerOp = + cast(fuseProducerResult->origProducer.getDefiningOp()); + unsigned resultNumber = fuseProducerResult->origProducer.getResultNumber(); + // cache candidate slice + auto extractSliceOp = cast(candidateSliceOp); + SmallVector offsets = extractSliceOp.getMixedOffsets(), + sizes = extractSliceOp.getMixedSizes(), + strides = extractSliceOp.getMixedStrides(); + // Explicitly execute DCE. + (void)mlir::simplifyRegions(rewriter, {*producerOp->getParentRegion()}); + // If fused producer has multiple users. + bool yieldReplacement = !newListener.isErased(producerOp); + // Reset to old listener. + rewriter.setListener(oldListener); + + if (yieldReplacement) { + OpBuilder::InsertionGuard g(rewriter); + // Set insertPoint right before tiled op. + rewriter.setInsertionPoint(fuseProducerResult->tiledOps[0]); + // Manually clone new candidate slice. + auto clonedExtractSliceOp = rewriter.create( + producerOp->getLoc(), producerOp->getResult(resultNumber), offsets, + sizes, strides); + // Yield replacement for fused producer in avoid of repeated computation. + if (failed(scf::yieldReplacementForFusedProducer( + rewriter, clonedExtractSliceOp, fuseProducerResult.value(), + outerLoops))) + return std::nullopt; + // Erase cloned candidate slice. + rewriter.eraseOp(clonedExtractSliceOp); + + unsigned loopNumResults = outerLoops.front()->getNumResults(), + producerNumResults = producerOp->getNumResults(); + // Replace other users of fused producer with new loop results. + for (auto &&[index, result] : llvm::enumerate(producerOp->getResults())) { + rewriter.replaceAllUsesWith( + result, outerLoops.front()->getResult(loopNumResults - + producerNumResults + index)); + } + // Erase fused producer op. + rewriter.eraseOp(producerOp); + } } return fuseProducerResult; } diff --git a/test/mlir/test/gc/Transforms/iterative-tiling-and-fusion.mlir b/test/mlir/test/gc/Transforms/iterative-tiling-and-fusion.mlir index d40782333..b0480c37e 100644 --- a/test/mlir/test/gc/Transforms/iterative-tiling-and-fusion.mlir +++ b/test/mlir/test/gc/Transforms/iterative-tiling-and-fusion.mlir @@ -381,11 +381,11 @@ module { // ----- module { - // CHECK: func.func @fuse_generic_matmul( - // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> - // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<2x16x16xf32> - // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<4x16x16xf32> - func.func @fuse_generic_matmul(%arg0: tensor<32x32xf32>, %arg1: tensor<2x16x16xf32>, %arg2: tensor<4x16x16xf32>) -> tensor<32x64xf32> attributes {llvm.emit_c_interface} { + /// CHECK-LABEL: @fuse_generic_matmul + /// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> + /// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<2x16x16xf32> + /// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<4x16x16xf32> + func.func @fuse_generic_matmul(%arg0: tensor<32x32xf32>, %arg1: tensor<2x16x16xf32>, %arg2: tensor<4x16x16xf32>) -> tensor<32x64xf32> { /// CHECK: %[[EMPTY_OUT_0:.*]] = tensor.empty %0 = tensor.empty() : tensor<2x2x16x16xf32> %pack = tensor.pack %arg0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %0 : tensor<32x32xf32> -> tensor<2x2x16x16xf32> @@ -429,4 +429,33 @@ module { /// CHECK: return %[[FINAL_RESULT]]#1 return %unpack : tensor<32x64xf32> } +} + +// ----- + +module { + /// CHECK-LABEL: @yield_fused_producer + func.func @yield_fused_producer(%arg0: tensor<16x32x32xf32>) -> (tensor<16x32x32xf32>, tensor<16x32xf32>) { + /// CHECK: arith.constant + %cst_0 = arith.constant dense<2.000000e+00> : tensor<16x32x32xf32> + /// CHECK-NEXT: tensor.empty + %dest0 = tensor.empty() : tensor<16x32x32xf32> + %0 = linalg.powf ins(%arg0, %cst_0 : tensor<16x32x32xf32>, tensor<16x32x32xf32>) outs(%dest0 : tensor<16x32x32xf32>) -> tensor<16x32x32xf32> + /// CHECK-NEXT: tensor.empty + %dest1 = tensor.empty() : tensor<16x32xf32> + /// CHECK-NEXT: %[[FINAL_RESULT:.*]]:2 = scf.forall (%{{.*}}) in (16) + /// CHECK-NEXT: tensor.extract_slice + /// CHECK-NEXT: tensor.extract_slice + /// CHECK-NEXT: tensor.extract_slice + /// CHECK-NEXT: linalg.powf + /// CHECK-NEXT: tensor.extract_slice + /// CHECK-NEXT: linalg.reduce + %1 = linalg.reduce { arith.addf } ins(%0 : tensor<16x32x32xf32>) outs(%dest1 : tensor<16x32xf32>) dimensions = [2] + /// CHECK-NEXT: scf.forall.in_parallel + /// CHECK-NEXT: tensor.parallel_insert_slice + /// CHECK-NEXT: tensor.parallel_insert_slice + /// CHECK-NEXT: } + /// CHECK: return %[[FINAL_RESULT]]#1, %[[FINAL_RESULT]]#0 + return %0, %1 : tensor<16x32x32xf32>, tensor<16x32xf32> + } } \ No newline at end of file