diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index 28b4a01cf0ecd..53fa881c99fb0 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -509,10 +509,6 @@ def SymbolDCE : Pass<"symbol-dce"> { information on `Symbols`. }]; let constructor = "mlir::createSymbolDCEPass()"; - - let statistics = [ - Statistic<"numDCE", "num-dce'd", "Number of symbols DCE'd">, - ]; } def SymbolPrivatize : Pass<"symbol-privatize"> { diff --git a/mlir/include/mlir/Transforms/SymbolDceUtils.h b/mlir/include/mlir/Transforms/SymbolDceUtils.h new file mode 100644 index 0000000000000..04f19156a972d --- /dev/null +++ b/mlir/include/mlir/Transforms/SymbolDceUtils.h @@ -0,0 +1,32 @@ +//===- SymbolDceUtils.h.h ---------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TRANSFORMS_SYMBOLDCEUTILS_H +#define MLIR_TRANSFORMS_SYMBOLDCEUTILS_H + +#include "mlir/Support/LLVM.h" + +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/LogicalResult.h" + +namespace mlir { + +class Operation; +class SymbolTableCollection; + +/// Eliminate dead symbols in the symbolTableOp. +LogicalResult symbolDce(Operation *); + +/// Compute the liveness of the symbols within the given symbol table. +/// `symbolTableIsHidden` is true if this symbol table is known to be +/// unaccessible from operations in its parent regions. +LogicalResult computeLiveness(Operation *, SymbolTableCollection &, bool, + DenseSet &); +} // end namespace mlir + +#endif // MLIR_TRANSFORMS_SYMBOLDCEUTILS_H diff --git a/mlir/lib/Transforms/SymbolDCE.cpp b/mlir/lib/Transforms/SymbolDCE.cpp index 87885be00daa3..acfe647ec36e5 100644 --- a/mlir/lib/Transforms/SymbolDCE.cpp +++ b/mlir/lib/Transforms/SymbolDCE.cpp @@ -14,10 +14,7 @@ #include "mlir/Transforms/Passes.h" #include "mlir/IR/Operation.h" -#include "mlir/IR/SymbolTable.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/DebugLog.h" -#include "llvm/Support/InterleavedRange.h" +#include "mlir/Transforms/SymbolDceUtils.h" namespace mlir { #define GEN_PASS_DEF_SYMBOLDCE @@ -26,165 +23,16 @@ namespace mlir { using namespace mlir; -#define DEBUG_TYPE "symbol-dce" - namespace { struct SymbolDCE : public impl::SymbolDCEBase { void runOnOperation() override; - - /// Compute the liveness of the symbols within the given symbol table. - /// `symbolTableIsHidden` is true if this symbol table is known to be - /// unaccessible from operations in its parent regions. - LogicalResult computeLiveness(Operation *symbolTableOp, - SymbolTableCollection &symbolTable, - bool symbolTableIsHidden, - DenseSet &liveSymbols); }; } // namespace void SymbolDCE::runOnOperation() { Operation *symbolTableOp = getOperation(); - - // SymbolDCE should only be run on operations that define a symbol table. - if (!symbolTableOp->hasTrait()) { - symbolTableOp->emitOpError() - << " was scheduled to run under SymbolDCE, but does not define a " - "symbol table"; + if (failed(symbolDce(symbolTableOp))) return signalPassFailure(); - } - - // A flag that signals if the top level symbol table is hidden, i.e. not - // accessible from parent scopes. - bool symbolTableIsHidden = true; - SymbolOpInterface symbol = dyn_cast(symbolTableOp); - if (symbolTableOp->getParentOp() && symbol) - symbolTableIsHidden = symbol.isPrivate(); - - // Compute the set of live symbols within the symbol table. - DenseSet liveSymbols; - SymbolTableCollection symbolTable; - if (failed(computeLiveness(symbolTableOp, symbolTable, symbolTableIsHidden, - liveSymbols))) - return signalPassFailure(); - - // After computing the liveness, delete all of the symbols that were found to - // be dead. - symbolTableOp->walk([&](Operation *nestedSymbolTable) { - if (!nestedSymbolTable->hasTrait()) - return; - for (auto &block : nestedSymbolTable->getRegion(0)) { - for (Operation &op : llvm::make_early_inc_range(block)) { - if (isa(&op) && !liveSymbols.count(&op)) { - op.erase(); - ++numDCE; - } - } - } - }); -} - -/// Compute the liveness of the symbols within the given symbol table. -/// `symbolTableIsHidden` is true if this symbol table is known to be -/// unaccessible from operations in its parent regions. -LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp, - SymbolTableCollection &symbolTable, - bool symbolTableIsHidden, - DenseSet &liveSymbols) { - LDBG() << "computeLiveness: " - << OpWithFlags(symbolTableOp, OpPrintingFlags().skipRegions()); - // A worklist of live operations to propagate uses from. - SmallVector worklist; - - // Walk the symbols within the current symbol table, marking the symbols that - // are known to be live. - for (auto &block : symbolTableOp->getRegion(0)) { - // Add all non-symbols or symbols that can't be discarded. - for (Operation &op : block) { - SymbolOpInterface symbol = dyn_cast(&op); - if (!symbol) { - worklist.push_back(&op); - continue; - } - bool isDiscardable = (symbolTableIsHidden || symbol.isPrivate()) && - symbol.canDiscardOnUseEmpty(); - if (!isDiscardable && liveSymbols.insert(&op).second) - worklist.push_back(&op); - } - } - - // Process the set of symbols that were known to be live, adding new symbols - // that are referenced within. For operations that are not symbol tables, it - // considers the liveness with respect to the op itself rather than scope of - // nested symbol tables by enqueuing all the top level operations for - // consideration. - while (!worklist.empty()) { - Operation *op = worklist.pop_back_val(); - LDBG() << "processing: " - << OpWithFlags(op, OpPrintingFlags().skipRegions()); - - // If this is a symbol table, recursively compute its liveness. - if (op->hasTrait()) { - // The internal symbol table is hidden if the parent is, if its not a - // symbol, or if it is a private symbol. - SymbolOpInterface symbol = dyn_cast(op); - bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate(); - LDBG() << "\tsymbol table: " - << OpWithFlags(op, OpPrintingFlags().skipRegions()) - << " is hidden: " << symIsHidden; - if (failed(computeLiveness(op, symbolTable, symIsHidden, liveSymbols))) - return failure(); - } else { - LDBG() << "\tnon-symbol table: " - << OpWithFlags(op, OpPrintingFlags().skipRegions()); - // If the op is not a symbol table, then, unless op itself is dead which - // would be handled by DCE, we need to check all the regions and blocks - // within the op to find the uses (e.g., consider visibility within op as - // if top level rather than relying on pure symbol table visibility). This - // is more conservative than SymbolTable::walkSymbolTables in the case - // where there is again SymbolTable information to take advantage of. - for (auto ®ion : op->getRegions()) - for (auto &block : region.getBlocks()) - for (Operation &op : block) - if (op.getNumRegions()) - worklist.push_back(&op); - } - - // Get the first parent symbol table op. Note: due to enqueueing of - // top-level ops, we may not have a symbol table parent here, but if we do - // not, then we also don't have a symbol. - Operation *parentOp = op->getParentOp(); - if (!parentOp->hasTrait()) - continue; - - // Collect the uses held by this operation. - std::optional uses = SymbolTable::getSymbolUses(op); - if (!uses) { - return op->emitError() - << "operation contains potentially unknown symbol table, meaning " - << "that we can't reliable compute symbol uses"; - } - - SmallVector resolvedSymbols; - LDBG() << "uses of " << OpWithFlags(op, OpPrintingFlags().skipRegions()); - for (const SymbolTable::SymbolUse &use : *uses) { - LDBG() << "\tuse: " << use.getUser(); - // Lookup the symbols referenced by this use. - resolvedSymbols.clear(); - if (failed(symbolTable.lookupSymbolIn(parentOp, use.getSymbolRef(), - resolvedSymbols))) - // Ignore references to unknown symbols. - continue; - LDBG() << "\t\tresolved symbols: " - << llvm::interleaved(resolvedSymbols, ", "); - - // Mark each of the resolved symbols as live. - for (Operation *resolvedSymbol : resolvedSymbols) - if (liveSymbols.insert(resolvedSymbol).second) - worklist.push_back(resolvedSymbol); - } - } - - return success(); } std::unique_ptr mlir::createSymbolDCEPass() { diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt index 3ca16239ba33c..2a2d7af1ce7f8 100644 --- a/mlir/lib/Transforms/Utils/CMakeLists.txt +++ b/mlir/lib/Transforms/Utils/CMakeLists.txt @@ -10,6 +10,7 @@ add_mlir_library(MLIRTransformUtils LoopInvariantCodeMotionUtils.cpp RegionUtils.cpp WalkPatternRewriteDriver.cpp + SymbolDceUtils.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms diff --git a/mlir/lib/Transforms/Utils/SymbolDceUtils.cpp b/mlir/lib/Transforms/Utils/SymbolDceUtils.cpp new file mode 100644 index 0000000000000..ad7808cbaa275 --- /dev/null +++ b/mlir/lib/Transforms/Utils/SymbolDceUtils.cpp @@ -0,0 +1,164 @@ +//===- SymbolDceUtils.cpp -------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements eliminate dead symbols +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/SymbolDceUtils.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/SymbolTable.h" + +#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" +#include "llvm/Support/InterleavedRange.h" + +#define DEBUG_TYPE "symbol-dce" + +llvm::LogicalResult mlir::symbolDce(Operation *symbolTableOp) { + // SymbolDCE should only be run on operations that define a symbol table. + if (!symbolTableOp->hasTrait()) { + symbolTableOp->emitOpError() + << " was scheduled to run under SymbolDCE, but does not define a " + "symbol table"; + return failure(); + } + + // A flag that signals if the top level symbol table is hidden, i.e. not + // accessible from parent scopes. + bool symbolTableIsHidden = true; + SymbolOpInterface symbol = dyn_cast(symbolTableOp); + if (symbolTableOp->getParentOp() && symbol) + symbolTableIsHidden = symbol.isPrivate(); + + // Compute the set of live symbols within the symbol table. + DenseSet liveSymbols; + SymbolTableCollection symbolTable; + if (failed(computeLiveness(symbolTableOp, symbolTable, symbolTableIsHidden, + liveSymbols))) + return failure(); + + // After computing the liveness, delete all of the symbols that were found to + // be dead. + symbolTableOp->walk([&](Operation *nestedSymbolTable) { + if (!nestedSymbolTable->hasTrait()) + return; + for (auto &block : nestedSymbolTable->getRegion(0)) { + for (Operation &op : llvm::make_early_inc_range(block)) { + if (isa(&op) && !liveSymbols.count(&op)) { + op.erase(); + } + } + } + }); + return success(); +} + +/// Compute the liveness of the symbols within the given symbol table. +/// `symbolTableIsHidden` is true if this symbol table is known to be +/// unaccessible from operations in its parent regions. +llvm::LogicalResult mlir::computeLiveness(Operation *symbolTableOp, + SymbolTableCollection &symbolTable, + bool symbolTableIsHidden, + DenseSet &liveSymbols) { + LDBG() << "computeLiveness: " + << OpWithFlags(symbolTableOp, OpPrintingFlags().skipRegions()); + // A worklist of live operations to propagate uses from. + SmallVector worklist; + + // Walk the symbols within the current symbol table, marking the symbols that + // are known to be live. + for (auto &block : symbolTableOp->getRegion(0)) { + // Add all non-symbols or symbols that can't be discarded. + for (Operation &op : block) { + SymbolOpInterface symbol = dyn_cast(&op); + if (!symbol) { + worklist.push_back(&op); + continue; + } + bool isDiscardable = (symbolTableIsHidden || symbol.isPrivate()) && + symbol.canDiscardOnUseEmpty(); + if (!isDiscardable && liveSymbols.insert(&op).second) + worklist.push_back(&op); + } + } + + // Process the set of symbols that were known to be live, adding new symbols + // that are referenced within. For operations that are not symbol tables, it + // considers the liveness with respect to the op itself rather than scope of + // nested symbol tables by enqueuing all the top level operations for + // consideration. + while (!worklist.empty()) { + Operation *op = worklist.pop_back_val(); + LDBG() << "processing: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); + + // If this is a symbol table, recursively compute its liveness. + if (op->hasTrait()) { + // The internal symbol table is hidden if the parent is, if its not a + // symbol, or if it is a private symbol. + SymbolOpInterface symbol = dyn_cast(op); + bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate(); + LDBG() << "\tsymbol table: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()) + << " is hidden: " << symIsHidden; + if (failed(computeLiveness(op, symbolTable, symIsHidden, liveSymbols))) + return failure(); + } else { + LDBG() << "\tnon-symbol table: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); + // If the op is not a symbol table, then, unless op itself is dead which + // would be handled by DCE, we need to check all the regions and blocks + // within the op to find the uses (e.g., consider visibility within op as + // if top level rather than relying on pure symbol table visibility). This + // is more conservative than SymbolTable::walkSymbolTables in the case + // where there is again SymbolTable information to take advantage of. + for (auto ®ion : op->getRegions()) + for (auto &block : region.getBlocks()) + for (Operation &op : block) + if (op.getNumRegions()) + worklist.push_back(&op); + } + + // Get the first parent symbol table op. Note: due to enqueueing of + // top-level ops, we may not have a symbol table parent here, but if we do + // not, then we also don't have a symbol. + Operation *parentOp = op->getParentOp(); + if (!parentOp->hasTrait()) + continue; + + // Collect the uses held by this operation. + std::optional uses = SymbolTable::getSymbolUses(op); + if (!uses) { + return op->emitError() + << "operation contains potentially unknown symbol table, meaning " + << "that we can't reliable compute symbol uses"; + } + + SmallVector resolvedSymbols; + LDBG() << "uses of " << OpWithFlags(op, OpPrintingFlags().skipRegions()); + for (const SymbolTable::SymbolUse &use : *uses) { + LDBG() << "\tuse: " << use.getUser(); + // Lookup the symbols referenced by this use. + resolvedSymbols.clear(); + if (failed(symbolTable.lookupSymbolIn(parentOp, use.getSymbolRef(), + resolvedSymbols))) + // Ignore references to unknown symbols. + continue; + LDBG() << "\t\tresolved symbols: " + << llvm::interleaved(resolvedSymbols, ", "); + + // Mark each of the resolved symbols as live. + for (Operation *resolvedSymbol : resolvedSymbols) + if (liveSymbols.insert(resolvedSymbol).second) + worklist.push_back(resolvedSymbol); + } + } + + return success(); +}