3 changes: 3 additions & 0 deletions mlir/include/mlir/TableGen/OpInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ class OpInterface {
// Return the interfaces extra class declaration code.
llvm::Optional<StringRef> getExtraClassDeclaration() const;

// Return the traits extra class declaration code.
llvm::Optional<StringRef> getExtraTraitClassDeclaration() const;

// Return the verify method body if it has one.
llvm::Optional<StringRef> getVerify() const;

Expand Down
1 change: 1 addition & 0 deletions mlir/lib/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ add_mlir_library(MLIRIR
DEPENDS
MLIRCallInterfacesIncGen
MLIROpAsmInterfacesIncGen
MLIRSymbolInterfacesIncGen
)
target_link_libraries(MLIRIR
PUBLIC
Expand Down
45 changes: 38 additions & 7 deletions mlir/lib/IR/SymbolTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,6 @@ void SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
setSymbolName(symbol, nameBuffer);
}

/// Returns true if the given operation defines a symbol.
bool SymbolTable::isSymbol(Operation *op) {
return op->hasTrait<OpTrait::Symbol>() || getNameIfSymbol(op).hasValue();
}

/// Returns the name of the given symbol operation.
StringRef SymbolTable::getSymbolName(Operation *symbol) {
Optional<StringRef> name = getNameIfSymbol(symbol);
Expand Down Expand Up @@ -212,6 +207,35 @@ Operation *SymbolTable::getNearestSymbolTable(Operation *from) {
return from;
}

/// Walks all symbol table operations nested within, and including, `op`. For
/// each symbol table operation, the provided callback is invoked with the op
/// and a boolean signifying if the symbols within that symbol table can be
/// treated as if all uses are visible. `allSymUsesVisible` identifies whether
/// all of the symbol uses of symbols within `op` are visible.
void SymbolTable::walkSymbolTables(
Operation *op, bool allSymUsesVisible,
function_ref<void(Operation *, bool)> callback) {
bool isSymbolTable = op->hasTrait<OpTrait::SymbolTable>();
if (isSymbolTable) {
SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
allSymUsesVisible |= !symbol || symbol.isPrivate();
} else {
// Otherwise if 'op' is not a symbol table, any nested symbols are
// guaranteed to be hidden.
allSymUsesVisible = true;
}

for (Region &region : op->getRegions())
for (Block &block : region)
for (Operation &nestedOp : block)
walkSymbolTables(&nestedOp, allSymUsesVisible, callback);

// If 'op' had the symbol table trait, visit it after any nested symbol
// tables.
if (isSymbolTable)
callback(op, allSymUsesVisible);
}

/// Returns the operation registered with the given symbol name with the
/// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
/// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol
Expand Down Expand Up @@ -286,7 +310,7 @@ Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from,
// SymbolTable Trait Types
//===----------------------------------------------------------------------===//

LogicalResult OpTrait::impl::verifySymbolTable(Operation *op) {
LogicalResult detail::verifySymbolTable(Operation *op) {
if (op->getNumRegions() != 1)
return op->emitOpError()
<< "Operations with a 'SymbolTable' must have exactly one region";
Expand Down Expand Up @@ -316,7 +340,7 @@ LogicalResult OpTrait::impl::verifySymbolTable(Operation *op) {
return success();
}

LogicalResult OpTrait::impl::verifySymbol(Operation *op) {
LogicalResult detail::verifySymbol(Operation *op) {
// Verify the name attribute.
if (!op->getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName()))
return op->emitOpError() << "requires string attribute '"
Expand Down Expand Up @@ -866,3 +890,10 @@ LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
Region *from) {
return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
}

//===----------------------------------------------------------------------===//
// Symbol Interfaces
//===----------------------------------------------------------------------===//

/// Include the generated symbol interfaces.
#include "mlir/IR/SymbolInterfaces.cpp.inc"
6 changes: 6 additions & 0 deletions mlir/lib/TableGen/OpInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ llvm::Optional<StringRef> OpInterface::getExtraClassDeclaration() const {
return value.empty() ? llvm::Optional<StringRef>() : value;
}

// Return the traits extra class declaration code.
llvm::Optional<StringRef> OpInterface::getExtraTraitClassDeclaration() const {
auto value = def->getValueAsString("extraTraitClassDeclaration");
return value.empty() ? llvm::Optional<StringRef>() : value;
}

// Return the body for this method if it has one.
llvm::Optional<StringRef> OpInterface::getVerify() const {
auto value = def->getValueAsString("verify");
Expand Down
56 changes: 8 additions & 48 deletions mlir/lib/Transforms/Inliner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,50 +31,6 @@ using namespace mlir;
// Symbol Use Tracking
//===----------------------------------------------------------------------===//

/// Returns true if this operation can be discarded if it is a symbol and has no
/// uses. 'allUsesVisible' corresponds to if the parent symbol table is hidden
/// from above.
static bool canDiscardSymbolOnUseEmpty(Operation *op, bool allUsesVisible) {
if (!SymbolTable::isSymbol(op))
return false;

// TODO: This is essentially the same logic from SymbolDCE. Remove this when
// we have a 'Symbol' interface.
// Private symbols are always initially considered dead.
SymbolTable::Visibility visibility = SymbolTable::getSymbolVisibility(op);
if (visibility == mlir::SymbolTable::Visibility::Private)
return true;
// We only include nested visibility here if all uses are visible.
if (allUsesVisible && visibility == SymbolTable::Visibility::Nested)
return true;
// Otherwise, public symbols are never removable.
return false;
}

/// Walk all of the symbol table operations nested with 'op' along with a
/// boolean signifying if the symbols within can be treated as if all uses are
/// visible. The provided callback is invoked with the symbol table operation,
/// and a boolean signaling if all of the uses within the symbol table are
/// visible.
static void walkSymbolTables(Operation *op, bool allSymUsesVisible,
function_ref<void(Operation *, bool)> callback) {
if (op->hasTrait<OpTrait::SymbolTable>()) {
allSymUsesVisible = allSymUsesVisible || !SymbolTable::isSymbol(op) ||
SymbolTable::getSymbolVisibility(op) ==
SymbolTable::Visibility::Private;
callback(op, allSymUsesVisible);
} else {
// Otherwise if 'op' is not a symbol table, any nested symbols are
// guaranteed to be hidden.
allSymUsesVisible = true;
}

for (Region &region : op->getRegions())
for (Block &block : region)
for (Operation &nested : block)
walkSymbolTables(&nested, allSymUsesVisible, callback);
}

/// Walk all of the used symbol callgraph nodes referenced with the given op.
static void walkReferencedSymbolNodes(
Operation *op, CallGraph &cg,
Expand Down Expand Up @@ -171,8 +127,11 @@ CGUseList::CGUseList(Operation *op, CallGraph &cg) {
// If this is a callgraph operation, check to see if it is discardable.
if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
if (canDiscardSymbolOnUseEmpty(&op, allUsesVisible))
SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
if (symbol && (allUsesVisible || symbol.isPrivate()) &&
symbol.canDiscardOnUseEmpty()) {
discardableSymNodeUses.try_emplace(node, 0);
}
continue;
}
}
Expand All @@ -182,7 +141,8 @@ CGUseList::CGUseList(Operation *op, CallGraph &cg) {
}
}
};
walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(), walkFn);
SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
walkFn);

// Drop the use information for any discardable nodes that are always live.
for (auto &it : alwaysLiveNodes)
Expand Down Expand Up @@ -224,7 +184,7 @@ void CGUseList::eraseNode(CallGraphNode *node) {
bool CGUseList::isDead(CallGraphNode *node) const {
// If the parent operation isn't a symbol, simply check normal SSA deadness.
Operation *nodeOp = node->getCallableRegion()->getParentOp();
if (!SymbolTable::isSymbol(nodeOp))
if (!isa<SymbolOpInterface>(nodeOp))
return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->use_empty();

// Otherwise, check the number of symbol uses.
Expand All @@ -235,7 +195,7 @@ bool CGUseList::isDead(CallGraphNode *node) const {
bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
// If this isn't a symbol node, check for side-effects and SSA use count.
Operation *nodeOp = node->getCallableRegion()->getParentOp();
if (!SymbolTable::isSymbol(nodeOp))
if (!isa<SymbolOpInterface>(nodeOp))
return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->hasOneUse();

// Otherwise, check the number of symbol uses.
Expand Down
257 changes: 251 additions & 6 deletions mlir/lib/Transforms/SCCP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,56 @@ class LatticeValue {
Dialect *constantDialect;
};

/// This class contains various state used when computing the lattice of a
/// callable operation.
class CallableLatticeState {
public:
/// Build a lattice state with a given callable region, and a specified number
/// of results to be initialized to the default lattice value (Unknown).
CallableLatticeState(Region *callableRegion, unsigned numResults)
: callableArguments(callableRegion->front().getArguments()),
resultLatticeValues(numResults) {}

/// Returns the arguments to the callable region.
Block::BlockArgListType getCallableArguments() const {
return callableArguments;
}

/// Returns the lattice value for the results of the callable region.
MutableArrayRef<LatticeValue> getResultLatticeValues() {
return resultLatticeValues;
}

/// Add a call to this callable. This is only used if the callable defines a
/// symbol.
void addSymbolCall(Operation *op) { symbolCalls.push_back(op); }

/// Return the calls that reference this callable. This is only used
/// if the callable defines a symbol.
ArrayRef<Operation *> getSymbolCalls() const { return symbolCalls; }

private:
/// The arguments of the callable region.
Block::BlockArgListType callableArguments;

/// The lattice state for each of the results of this region. The return
/// values of the callable aren't SSA values, so we need to track them
/// separately.
SmallVector<LatticeValue, 4> resultLatticeValues;

/// The calls referencing this callable if this callable defines a symbol.
/// This removes the need to recompute symbol references during propagation.
/// Value based references are trivial to resolve, so they can be done
/// in-place.
SmallVector<Operation *, 4> symbolCalls;
};

/// This class represents the solver for the SCCP analysis. This class acts as
/// the propagation engine for computing which values form constants.
class SCCPSolver {
public:
/// Initialize the solver with a given set of regions.
SCCPSolver(MutableArrayRef<Region> regions);
/// Initialize the solver with the given top-level operation.
SCCPSolver(Operation *op);

/// Run the solver until it converges.
void solve();
Expand All @@ -132,6 +176,11 @@ class SCCPSolver {
void rewrite(MLIRContext *context, MutableArrayRef<Region> regions);

private:
/// Initialize the set of symbol defining callables that can have their
/// arguments and results tracked. 'op' is the top-level operation that SCCP
/// is operating on.
void initializeSymbolCallables(Operation *op);

/// Replace the given value with a constant if the corresponding lattice
/// represents a constant. Returns success if the value was replaced, failure
/// otherwise.
Expand All @@ -149,6 +198,13 @@ class SCCPSolver {
/// Visit the given operation and compute any necessary lattice state.
void visitOperation(Operation *op);

/// Visit the given call operation and compute any necessary lattice state.
void visitCallOperation(CallOpInterface op);

/// Visit the given callable operation and compute any necessary lattice
/// state.
void visitCallableOperation(Operation *op);

/// Visit the given operation, which defines regions, and compute any
/// necessary lattice state. This also resolves the lattice state of both the
/// operation results and any nested regions.
Expand All @@ -168,6 +224,11 @@ class SCCPSolver {
void visitTerminatorOperation(Operation *op,
ArrayRef<Attribute> constantOperands);

/// Visit the given terminator operation that exits a callable region. These
/// are terminators with no CFG successors.
void visitCallableTerminatorOperation(Operation *callable,
Operation *terminator);

/// Visit the given block and compute any necessary lattice state.
void visitBlock(Block *block);

Expand Down Expand Up @@ -235,11 +296,20 @@ class SCCPSolver {

/// A worklist of operations that need to be processed.
SmallVector<Operation *, 64> opWorklist;

/// The callable operations that have their argument/result state tracked.
DenseMap<Operation *, CallableLatticeState> callableLatticeState;

/// A map between a call operation and the resolved symbol callable. This
/// avoids re-resolving symbol references during propagation. Value based
/// callables are trivial to resolve, so they can be done in-place.
DenseMap<Operation *, Operation *> callToSymbolCallable;
};
} // end anonymous namespace

SCCPSolver::SCCPSolver(MutableArrayRef<Region> regions) {
for (Region &region : regions) {
SCCPSolver::SCCPSolver(Operation *op) {
/// Initialize the solver with the regions within this operation.
for (Region &region : op->getRegions()) {
if (region.empty())
continue;
Block *entryBlock = &region.front();
Expand All @@ -251,6 +321,7 @@ SCCPSolver::SCCPSolver(MutableArrayRef<Region> regions) {
// as overdefined.
markAllOverdefined(entryBlock->getArguments());
}
initializeSymbolCallables(op);
}

void SCCPSolver::solve() {
Expand Down Expand Up @@ -310,6 +381,73 @@ void SCCPSolver::rewrite(MLIRContext *context,
}
}

void SCCPSolver::initializeSymbolCallables(Operation *op) {
// Initialize the set of symbol callables that can have their state tracked.
// This tracks which symbol callable operations we can propagate within and
// out of.
auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
Region &symbolTableRegion = symTable->getRegion(0);
Block *symbolTableBlock = &symbolTableRegion.front();
for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) {
// We won't be able to track external callables.
Region *callableRegion = callable.getCallableRegion();
if (!callableRegion)
continue;
// We only care about symbol defining callables here.
auto symbol = dyn_cast<SymbolOpInterface>(callable.getOperation());
if (!symbol)
continue;
callableLatticeState.try_emplace(callable, callableRegion,
callable.getCallableResults().size());

// If not all of the uses of this symbol are visible, we can't track the
// state of the arguments.
if (symbol.isPublic() || (!allUsesVisible && symbol.isNested()))
markAllOverdefined(callableRegion->front().getArguments());
}
if (callableLatticeState.empty())
return;

// After computing the valid callables, walk any symbol uses to check
// for non-call references. We won't be able to track the lattice state
// for arguments to these callables, as we can't guarantee that we can see
// all of its calls.
Optional<SymbolTable::UseRange> uses =
SymbolTable::getSymbolUses(&symbolTableRegion);
if (!uses) {
// If we couldn't gather the symbol uses, conservatively assume that
// we can't track information for any nested symbols.
op->walk([&](CallableOpInterface op) { callableLatticeState.erase(op); });
return;
}

for (const SymbolTable::SymbolUse &use : *uses) {
// If the use is a call, track it to avoid the need to recompute the
// reference later.
if (auto callOp = dyn_cast<CallOpInterface>(use.getUser())) {
Operation *symCallable = callOp.resolveCallable();
auto callableLatticeIt = callableLatticeState.find(symCallable);
if (callableLatticeIt != callableLatticeState.end()) {
callToSymbolCallable.try_emplace(callOp, symCallable);

// We only need to record the call in the lattice if it produces any
// values.
if (callOp.getOperation()->getNumResults())
callableLatticeIt->second.addSymbolCall(callOp);
}
continue;
}
// This use isn't a call, so don't we know all of the callers.
auto *symbol = SymbolTable::lookupSymbolIn(op, use.getSymbolRef());
auto it = callableLatticeState.find(symbol);
if (it != callableLatticeState.end())
markAllOverdefined(it->second.getCallableArguments());
}
};
SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
walkFn);
}

LogicalResult SCCPSolver::replaceWithConstant(OpBuilder &builder,
OperationFolder &folder,
Value value) {
Expand Down Expand Up @@ -347,6 +485,16 @@ void SCCPSolver::visitOperation(Operation *op) {
if (op->isKnownTerminator())
visitTerminatorOperation(op, operandConstants);

// Process call operations. The call visitor processes result values, so we
// can exit afterwards.
if (CallOpInterface call = dyn_cast<CallOpInterface>(op))
return visitCallOperation(call);

// Process callable operations. These are specially handled region operations
// that track dataflow via calls.
if (isa<CallableOpInterface>(op))
return visitCallableOperation(op);

// Process region holding operations. The region visitor processes result
// values, so we can exit afterwards.
if (op->getNumRegions())
Expand Down Expand Up @@ -399,6 +547,62 @@ void SCCPSolver::visitOperation(Operation *op) {
}
}

void SCCPSolver::visitCallableOperation(Operation *op) {
// Mark the regions as executable.
bool isTrackingLatticeState = callableLatticeState.count(op);
for (Region &region : op->getRegions()) {
if (region.empty())
continue;
Block *entryBlock = &region.front();
markBlockExecutable(entryBlock);

// If we aren't tracking lattice state for this callable, mark all of the
// region arguments as overdefined.
if (!isTrackingLatticeState)
markAllOverdefined(entryBlock->getArguments());
}

// TODO: Add support for non-symbol callables when necessary. If the callable
// has non-call uses we would mark overdefined, otherwise allow for
// propagating the return values out.
markAllOverdefined(op, op->getResults());
}

void SCCPSolver::visitCallOperation(CallOpInterface op) {
ResultRange callResults = op.getOperation()->getResults();

// Resolve the callable operation for this call.
Operation *callableOp = nullptr;
if (Value callableValue = op.getCallableForCallee().dyn_cast<Value>())
callableOp = callableValue.getDefiningOp();
else
callableOp = callToSymbolCallable.lookup(op);

// The callable of this call can't be resolved, mark any results overdefined.
if (!callableOp)
return markAllOverdefined(op, callResults);

// If this callable is tracking state, merge the argument operands with the
// arguments of the callable.
auto callableLatticeIt = callableLatticeState.find(callableOp);
if (callableLatticeIt == callableLatticeState.end())
return markAllOverdefined(op, callResults);

OperandRange callOperands = op.getArgOperands();
auto callableArgs = callableLatticeIt->second.getCallableArguments();
for (auto it : llvm::zip(callOperands, callableArgs)) {
BlockArgument callableArg = std::get<1>(it);
if (latticeValues[callableArg].meet(latticeValues[std::get<0>(it)]))
visitUsers(callableArg);
}

// Merge in the lattice state for the callable results as well.
auto callableResults = callableLatticeIt->second.getResultLatticeValues();
for (auto it : llvm::zip(callResults, callableResults))
meet(/*owner=*/op, /*to=*/latticeValues[std::get<0>(it)],
/*from=*/std::get<1>(it));
}

void SCCPSolver::visitRegionOperation(Operation *op,
ArrayRef<Attribute> constantOperands) {
// Check to see if we can reason about the internal control flow of this
Expand Down Expand Up @@ -509,9 +713,14 @@ void SCCPSolver::visitTerminatorOperation(
Operation *op, ArrayRef<Attribute> constantOperands) {
// If this operation has no successors, we treat it as an exiting terminator.
if (op->getNumSuccessors() == 0) {
// Check to see if the parent tracks region control flow.
Region *parentRegion = op->getParentRegion();
Operation *parentOp = parentRegion->getParentOp();

// Check to see if this is a terminator for a callable region.
if (isa<CallableOpInterface>(parentOp))
return visitCallableTerminatorOperation(parentOp, op);

// Otherwise, check to see if the parent tracks region control flow.
auto regionInterface = dyn_cast<RegionBranchOpInterface>(parentOp);
if (!regionInterface || !isBlockExecutable(parentOp->getBlock()))
return;
Expand Down Expand Up @@ -552,6 +761,42 @@ void SCCPSolver::visitTerminatorOperation(
markEdgeExecutable(block, succ);
}

void SCCPSolver::visitCallableTerminatorOperation(Operation *callable,
Operation *terminator) {
// If there are no exiting values, we have nothing to track.
if (terminator->getNumOperands() == 0)
return;

// If this callable isn't tracking any lattice state there is nothing to do.
auto latticeIt = callableLatticeState.find(callable);
if (latticeIt == callableLatticeState.end())
return;
assert(callable->getNumResults() == 0 && "expected symbol callable");

// If this terminator is not "return-like", conservatively mark all of the
// call-site results as overdefined.
auto callableResultLattices = latticeIt->second.getResultLatticeValues();
if (!terminator->hasTrait<OpTrait::ReturnLike>()) {
for (auto &it : callableResultLattices)
it.markOverdefined();
for (Operation *call : latticeIt->second.getSymbolCalls())
markAllOverdefined(call, call->getResults());
return;
}

// Merge the terminator operands into the results.
bool anyChanged = false;
for (auto it : llvm::zip(terminator->getOperands(), callableResultLattices))
anyChanged |= std::get<1>(it).meet(latticeValues[std::get<0>(it)]);
if (!anyChanged)
return;

// If any of the result lattices changed, update the callers.
for (Operation *call : latticeIt->second.getSymbolCalls())
for (auto it : llvm::zip(call->getResults(), callableResultLattices))
meet(call, latticeValues[std::get<0>(it)], std::get<1>(it));
}

void SCCPSolver::visitBlock(Block *block) {
// If the block is not the entry block we need to compute the lattice state
// for the block arguments. Entry block argument lattices are computed
Expand Down Expand Up @@ -663,7 +908,7 @@ void SCCP::runOnOperation() {
Operation *op = getOperation();

// Solve for SCCP constraints within nested regions.
SCCPSolver solver(op->getRegions());
SCCPSolver solver(op);
solver.solve();

// Cleanup any operations using the solver analysis.
Expand Down
42 changes: 13 additions & 29 deletions mlir/lib/Transforms/SymbolDCE.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,9 @@ void SymbolDCE::runOnOperation() {
// A flag that signals if the top level symbol table is hidden, i.e. not
// accessible from parent scopes.
bool symbolTableIsHidden = true;
if (symbolTableOp->getParentOp() && SymbolTable::isSymbol(symbolTableOp)) {
symbolTableIsHidden = SymbolTable::getSymbolVisibility(symbolTableOp) ==
SymbolTable::Visibility::Private;
}
SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(symbolTableOp);
if (symbolTableOp->getParentOp() && symbol)
symbolTableIsHidden = symbol.isPrivate();

// Compute the set of live symbols within the symbol table.
DenseSet<Operation *> liveSymbols;
Expand All @@ -61,7 +60,7 @@ void SymbolDCE::runOnOperation() {
for (auto &block : nestedSymbolTable->getRegion(0)) {
for (Operation &op :
llvm::make_early_inc_range(block.without_terminator())) {
if (SymbolTable::isSymbol(&op) && !liveSymbols.count(&op))
if (isa<SymbolOpInterface>(&op) && !liveSymbols.count(&op))
op.erase();
}
}
Expand All @@ -80,30 +79,16 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
// Walk the symbols within the current symbol table, marking the symbols that
// are known to be live.
for (auto &block : symbolTableOp->getRegion(0)) {
// Add all non-symbols or symbols that can't be discarded.
for (Operation &op : block.without_terminator()) {
// Always add non symbol operations to the worklist.
if (!SymbolTable::isSymbol(&op)) {
SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
if (!symbol) {
worklist.push_back(&op);
continue;
}

// Check the visibility to see if this symbol may be referenced
// externally.
SymbolTable::Visibility visibility =
SymbolTable::getSymbolVisibility(&op);

// Private symbols are always initially considered dead.
if (visibility == mlir::SymbolTable::Visibility::Private)
continue;
// We only include nested visibility here if the symbol table isn't
// hidden.
if (symbolTableIsHidden && visibility == SymbolTable::Visibility::Nested)
continue;

// TODO(riverriddle) Add hooks here to allow symbols to provide additional
// information, e.g. linkage can be used to drop some symbols that may
// otherwise be considered "live".
if (liveSymbols.insert(&op).second)
bool isDiscardable = (symbolTableIsHidden || symbol.isPrivate()) &&
symbol.canDiscardOnUseEmpty();
if (!isDiscardable && liveSymbols.insert(&op).second)
worklist.push_back(&op);
}
}
Expand All @@ -117,10 +102,9 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
if (op->hasTrait<OpTrait::SymbolTable>()) {
// The internal symbol table is hidden if the parent is, if its not a
// symbol, or if it is a private symbol.
bool symbolIsHidden = symbolTableIsHidden || !SymbolTable::isSymbol(op) ||
SymbolTable::getSymbolVisibility(op) ==
SymbolTable::Visibility::Private;
if (failed(computeLiveness(op, symbolIsHidden, liveSymbols)))
SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate();
if (failed(computeLiveness(op, symIsHidden, liveSymbols)))
return failure();
}

Expand Down
257 changes: 257 additions & 0 deletions mlir/test/Transforms/sccp-callgraph.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
// RUN: mlir-opt -allow-unregistered-dialect %s -sccp -split-input-file | FileCheck %s -dump-input-on-failure
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="module(sccp)" -split-input-file | FileCheck %s --check-prefix=NESTED -dump-input-on-failure

/// Check that a constant is properly propagated through the arguments and
/// results of a private function.

// CHECK-LABEL: func @private(
func @private(%arg0 : i32) -> i32 attributes { sym_visibility = "private" } {
// CHECK: %[[CST:.*]] = constant 1 : i32
// CHECK: return %[[CST]] : i32

return %arg0 : i32
}

// CHECK-LABEL: func @simple_private(
func @simple_private() -> i32 {
// CHECK: %[[CST:.*]] = constant 1 : i32
// CHECK: return %[[CST]] : i32

%1 = constant 1 : i32
%result = call @private(%1) : (i32) -> i32
return %result : i32
}

// -----

/// Check that a constant is properly propagated through the arguments and
/// results of a visible nested function.

// CHECK: func @nested(
func @nested(%arg0 : i32) -> i32 attributes { sym_visibility = "nested" } {
// CHECK: %[[CST:.*]] = constant 1 : i32
// CHECK: return %[[CST]] : i32

return %arg0 : i32
}

// CHECK-LABEL: func @simple_nested(
func @simple_nested() -> i32 {
// CHECK: %[[CST:.*]] = constant 1 : i32
// CHECK: return %[[CST]] : i32

%1 = constant 1 : i32
%result = call @nested(%1) : (i32) -> i32
return %result : i32
}

// -----

/// Check that non-visible nested functions do not track arguments.
module {
// NESTED-LABEL: module @nested_module
module @nested_module attributes { sym_visibility = "public" } {

// NESTED: func @nested(
func @nested(%arg0 : i32) -> (i32, i32) attributes { sym_visibility = "nested" } {
// NESTED: %[[CST:.*]] = constant 1 : i32
// NESTED: return %[[CST]], %arg0 : i32, i32

%1 = constant 1 : i32
return %1, %arg0 : i32, i32
}

// NESTED: func @nested_not_all_uses_visible(
func @nested_not_all_uses_visible() -> (i32, i32) {
// NESTED: %[[CST:.*]] = constant 1 : i32
// NESTED: %[[CALL:.*]]:2 = call @nested
// NESTED: return %[[CST]], %[[CALL]]#1 : i32, i32

%1 = constant 1 : i32
%result:2 = call @nested(%1) : (i32) -> (i32, i32)
return %result#0, %result#1 : i32, i32
}
}
}

// -----

/// Check that public functions do not track arguments.

// CHECK-LABEL: func @public(
func @public(%arg0 : i32) -> (i32, i32) attributes { sym_visibility = "public" } {
%1 = constant 1 : i32
return %1, %arg0 : i32, i32
}

// CHECK-LABEL: func @simple_public(
func @simple_public() -> (i32, i32) {
// CHECK: %[[CST:.*]] = constant 1 : i32
// CHECK: %[[CALL:.*]]:2 = call @public
// CHECK: return %[[CST]], %[[CALL]]#1 : i32, i32

%1 = constant 1 : i32
%result:2 = call @public(%1) : (i32) -> (i32, i32)
return %result#0, %result#1 : i32, i32
}

// -----

/// Check that functions with non-call users don't have arguments tracked.

func @callable(%arg0 : i32) -> (i32, i32) attributes { sym_visibility = "private" } {
%1 = constant 1 : i32
return %1, %arg0 : i32, i32
}

// CHECK-LABEL: func @non_call_users(
func @non_call_users() -> (i32, i32) {
// CHECK: %[[CST:.*]] = constant 1 : i32
// CHECK: %[[CALL:.*]]:2 = call @callable
// CHECK: return %[[CST]], %[[CALL]]#1 : i32, i32

%1 = constant 1 : i32
%result:2 = call @callable(%1) : (i32) -> (i32, i32)
return %result#0, %result#1 : i32, i32
}

"live.user"() {uses = [@callable]} : () -> ()

// -----

/// Check that return values are overdefined in the presence of an unknown terminator.

func @callable(%arg0 : i32) -> i32 attributes { sym_visibility = "private" } {
"unknown.return"(%arg0) : (i32) -> ()
}

// CHECK-LABEL: func @unknown_terminator(
func @unknown_terminator() -> i32 {
// CHECK: %[[CALL:.*]] = call @callable
// CHECK: return %[[CALL]] : i32

%1 = constant 1 : i32
%result = call @callable(%1) : (i32) -> i32
return %result : i32
}

// -----

/// Check that return values are overdefined when the constant conflicts.

func @callable(%arg0 : i32) -> i32 attributes { sym_visibility = "private" } {
"unknown.return"(%arg0) : (i32) -> ()
}

// CHECK-LABEL: func @conflicting_constant(
func @conflicting_constant() -> (i32, i32) {
// CHECK: %[[CALL1:.*]] = call @callable
// CHECK: %[[CALL2:.*]] = call @callable
// CHECK: return %[[CALL1]], %[[CALL2]] : i32, i32

%1 = constant 1 : i32
%2 = constant 2 : i32
%result = call @callable(%1) : (i32) -> i32
%result2 = call @callable(%2) : (i32) -> i32
return %result, %result2 : i32, i32
}

// -----

/// Check that return values are overdefined when the constant conflicts with a
/// non-constant.

func @callable(%arg0 : i32) -> i32 attributes { sym_visibility = "private" } {
"unknown.return"(%arg0) : (i32) -> ()
}

// CHECK-LABEL: func @conflicting_constant(
func @conflicting_constant(%arg0 : i32) -> (i32, i32) {
// CHECK: %[[CALL1:.*]] = call @callable
// CHECK: %[[CALL2:.*]] = call @callable
// CHECK: return %[[CALL1]], %[[CALL2]] : i32, i32

%1 = constant 1 : i32
%result = call @callable(%1) : (i32) -> i32
%result2 = call @callable(%arg0) : (i32) -> i32
return %result, %result2 : i32, i32
}

// -----

/// Check a more complex interaction with calls and control flow.

// CHECK-LABEL: func @complex_inner_if(
func @complex_inner_if(%arg0 : i32) -> i32 attributes { sym_visibility = "private" } {
// CHECK-DAG: %[[TRUE:.*]] = constant 1 : i1
// CHECK-DAG: %[[CST:.*]] = constant 1 : i32
// CHECK: cond_br %[[TRUE]], ^bb1

%cst_20 = constant 20 : i32
%cond = cmpi "ult", %arg0, %cst_20 : i32
cond_br %cond, ^bb1, ^bb2

^bb1:
// CHECK: ^bb1:
// CHECK: return %[[CST]] : i32

%cst_1 = constant 1 : i32
return %cst_1 : i32

^bb2:
%cst_1_2 = constant 1 : i32
%arg_inc = addi %arg0, %cst_1_2 : i32
return %arg_inc : i32
}

func @complex_cond() -> i1

// CHECK-LABEL: func @complex_callee(
func @complex_callee(%arg0 : i32) -> i32 attributes { sym_visibility = "private" } {
// CHECK: %[[CST:.*]] = constant 1 : i32

%loop_cond = call @complex_cond() : () -> i1
cond_br %loop_cond, ^bb1, ^bb2

^bb1:
// CHECK: ^bb1:
// CHECK-NEXT: return %[[CST]] : i32
return %arg0 : i32

^bb2:
// CHECK: ^bb2:
// CHECK: call @complex_inner_if(%[[CST]]) : (i32) -> i32
// CHECK: call @complex_callee(%[[CST]]) : (i32) -> i32
// CHECK: return %[[CST]] : i32

%updated_arg = call @complex_inner_if(%arg0) : (i32) -> i32
%res = call @complex_callee(%updated_arg) : (i32) -> i32
return %res : i32
}

// CHECK-LABEL: func @complex_caller(
func @complex_caller(%arg0 : i32) -> i32 {
// CHECK: %[[CST:.*]] = constant 1 : i32
// CHECK: return %[[CST]] : i32

%1 = constant 1 : i32
%result = call @complex_callee(%1) : (i32) -> i32
return %result : i32
}

// -----

/// Check that non-symbol defining callables currently go to overdefined.

// CHECK-LABEL: func @non_symbol_defining_callable
func @non_symbol_defining_callable() -> i32 {
// CHECK: %[[RES:.*]] = call_indirect
// CHECK: return %[[RES]] : i32

%fn = "test.functional_region_op"() ({
%1 = constant 1 : i32
"test.return"(%1) : (i32) -> ()
}) : () -> (() -> i32)
%res = call_indirect %fn() : () -> (i32)
return %res : i32
}
3 changes: 2 additions & 1 deletion mlir/test/lib/Dialect/Test/TestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

include "mlir/IR/OpBase.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/SideEffects.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
Expand Down Expand Up @@ -1089,7 +1090,7 @@ def TestRecursiveRewriteOp : TEST_Op<"recursive_rewrite"> {
//===----------------------------------------------------------------------===//

def TestRegionBuilderOp : TEST_Op<"region_builder">;
def TestReturnOp : TEST_Op<"return", [Terminator]>,
def TestReturnOp : TEST_Op<"return", [ReturnLike, Terminator]>,
Arguments<(ins Variadic<AnyType>)>;
def TestCastOp : TEST_Op<"cast">,
Arguments<(ins Variadic<AnyType>)>, Results<(outs AnyType)>;
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/lib/IR/TestSymbolUses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ struct SymbolUsesPass
// Walk nested symbols.
SmallVector<FuncOp, 4> deadFunctions;
module.getBodyRegion().walk([&](Operation *nestedOp) {
if (SymbolTable::isSymbol(nestedOp))
if (isa<SymbolOpInterface>(nestedOp))
return operateOnSymbol(nestedOp, module, deadFunctions);
return WalkResult::advance();
});
Expand Down
2 changes: 2 additions & 0 deletions mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ static void emitTraitDecl(OpInterface &interface, raw_ostream &os,
os << " static LogicalResult verifyTrait(Operation* op) {\n"
<< std::string(tblgen::tgfmt(*verify, &traitCtx)) << "\n }\n";
}
if (auto extraTraitDecls = interface.getExtraTraitClassDeclaration())
os << extraTraitDecls << "\n";

os << " };\n";
}
Expand Down