Expand Up
@@ -116,12 +116,56 @@ class LatticeValue {
Dialect *constantDialect;
};
// / This class contains various state used when computing the lattice of a
// / callable operation.
class CallableLatticeState {
public:
// / Build a lattice state with a given callable region, and a specified number
// / of results to be initialized to the default lattice value (Unknown).
CallableLatticeState (Region *callableRegion, unsigned numResults)
: callableArguments(callableRegion->front ().getArguments()),
resultLatticeValues(numResults) {}
// / Returns the arguments to the callable region.
Block::BlockArgListType getCallableArguments () const {
return callableArguments;
}
// / Returns the lattice value for the results of the callable region.
MutableArrayRef<LatticeValue> getResultLatticeValues () {
return resultLatticeValues;
}
// / Add a call to this callable. This is only used if the callable defines a
// / symbol.
void addSymbolCall (Operation *op) { symbolCalls.push_back (op); }
// / Return the calls that reference this callable. This is only used
// / if the callable defines a symbol.
ArrayRef<Operation *> getSymbolCalls () const { return symbolCalls; }
private:
// / The arguments of the callable region.
Block::BlockArgListType callableArguments;
// / The lattice state for each of the results of this region. The return
// / values of the callable aren't SSA values, so we need to track them
// / separately.
SmallVector<LatticeValue, 4 > resultLatticeValues;
// / The calls referencing this callable if this callable defines a symbol.
// / This removes the need to recompute symbol references during propagation.
// / Value based references are trivial to resolve, so they can be done
// / in-place.
SmallVector<Operation *, 4 > symbolCalls;
};
// / This class represents the solver for the SCCP analysis. This class acts as
// / the propagation engine for computing which values form constants.
class SCCPSolver {
public:
// / Initialize the solver with a given set of regions .
SCCPSolver (MutableArrayRef<Region> regions );
// / Initialize the solver with the given top-level operation .
SCCPSolver (Operation *op );
// / Run the solver until it converges.
void solve ();
Expand All
@@ -132,6 +176,11 @@ class SCCPSolver {
void rewrite (MLIRContext *context, MutableArrayRef<Region> regions);
private:
// / Initialize the set of symbol defining callables that can have their
// / arguments and results tracked. 'op' is the top-level operation that SCCP
// / is operating on.
void initializeSymbolCallables (Operation *op);
// / Replace the given value with a constant if the corresponding lattice
// / represents a constant. Returns success if the value was replaced, failure
// / otherwise.
Expand All
@@ -149,6 +198,13 @@ class SCCPSolver {
// / Visit the given operation and compute any necessary lattice state.
void visitOperation (Operation *op);
// / Visit the given call operation and compute any necessary lattice state.
void visitCallOperation (CallOpInterface op);
// / Visit the given callable operation and compute any necessary lattice
// / state.
void visitCallableOperation (Operation *op);
// / Visit the given operation, which defines regions, and compute any
// / necessary lattice state. This also resolves the lattice state of both the
// / operation results and any nested regions.
Expand All
@@ -168,6 +224,11 @@ class SCCPSolver {
void visitTerminatorOperation (Operation *op,
ArrayRef<Attribute> constantOperands);
// / Visit the given terminator operation that exits a callable region. These
// / are terminators with no CFG successors.
void visitCallableTerminatorOperation (Operation *callable,
Operation *terminator);
// / Visit the given block and compute any necessary lattice state.
void visitBlock (Block *block);
Expand Down
Expand Up
@@ -235,11 +296,20 @@ class SCCPSolver {
// / A worklist of operations that need to be processed.
SmallVector<Operation *, 64 > opWorklist;
// / The callable operations that have their argument/result state tracked.
DenseMap<Operation *, CallableLatticeState> callableLatticeState;
// / A map between a call operation and the resolved symbol callable. This
// / avoids re-resolving symbol references during propagation. Value based
// / callables are trivial to resolve, so they can be done in-place.
DenseMap<Operation *, Operation *> callToSymbolCallable;
};
} // end anonymous namespace
SCCPSolver::SCCPSolver (MutableArrayRef<Region> regions) {
for (Region ®ion : regions) {
SCCPSolver::SCCPSolver (Operation *op) {
// / Initialize the solver with the regions within this operation.
for (Region ®ion : op->getRegions ()) {
if (region.empty ())
continue ;
Block *entryBlock = ®ion.front ();
Expand All
@@ -251,6 +321,7 @@ SCCPSolver::SCCPSolver(MutableArrayRef<Region> regions) {
// as overdefined.
markAllOverdefined (entryBlock->getArguments ());
}
initializeSymbolCallables (op);
}
void SCCPSolver::solve () {
Expand Down
Expand Up
@@ -310,6 +381,73 @@ void SCCPSolver::rewrite(MLIRContext *context,
}
}
void SCCPSolver::initializeSymbolCallables (Operation *op) {
// Initialize the set of symbol callables that can have their state tracked.
// This tracks which symbol callable operations we can propagate within and
// out of.
auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
Region &symbolTableRegion = symTable->getRegion (0 );
Block *symbolTableBlock = &symbolTableRegion.front ();
for (auto callable : symbolTableBlock->getOps <CallableOpInterface>()) {
// We won't be able to track external callables.
Region *callableRegion = callable.getCallableRegion ();
if (!callableRegion)
continue ;
// We only care about symbol defining callables here.
auto symbol = dyn_cast<SymbolOpInterface>(callable.getOperation ());
if (!symbol)
continue ;
callableLatticeState.try_emplace (callable, callableRegion,
callable.getCallableResults ().size ());
// If not all of the uses of this symbol are visible, we can't track the
// state of the arguments.
if (symbol.isPublic () || (!allUsesVisible && symbol.isNested ()))
markAllOverdefined (callableRegion->front ().getArguments ());
}
if (callableLatticeState.empty ())
return ;
// After computing the valid callables, walk any symbol uses to check
// for non-call references. We won't be able to track the lattice state
// for arguments to these callables, as we can't guarantee that we can see
// all of its calls.
Optional<SymbolTable::UseRange> uses =
SymbolTable::getSymbolUses (&symbolTableRegion);
if (!uses) {
// If we couldn't gather the symbol uses, conservatively assume that
// we can't track information for any nested symbols.
op->walk ([&](CallableOpInterface op) { callableLatticeState.erase (op); });
return ;
}
for (const SymbolTable::SymbolUse &use : *uses) {
// If the use is a call, track it to avoid the need to recompute the
// reference later.
if (auto callOp = dyn_cast<CallOpInterface>(use.getUser ())) {
Operation *symCallable = callOp.resolveCallable ();
auto callableLatticeIt = callableLatticeState.find (symCallable);
if (callableLatticeIt != callableLatticeState.end ()) {
callToSymbolCallable.try_emplace (callOp, symCallable);
// We only need to record the call in the lattice if it produces any
// values.
if (callOp.getOperation ()->getNumResults ())
callableLatticeIt->second .addSymbolCall (callOp);
}
continue ;
}
// This use isn't a call, so don't we know all of the callers.
auto *symbol = SymbolTable::lookupSymbolIn (op, use.getSymbolRef ());
auto it = callableLatticeState.find (symbol);
if (it != callableLatticeState.end ())
markAllOverdefined (it->second .getCallableArguments ());
}
};
SymbolTable::walkSymbolTables (op, /* allSymUsesVisible=*/ !op->getBlock (),
walkFn);
}
LogicalResult SCCPSolver::replaceWithConstant (OpBuilder &builder,
OperationFolder &folder,
Value value) {
Expand Down
Expand Up
@@ -347,6 +485,16 @@ void SCCPSolver::visitOperation(Operation *op) {
if (op->isKnownTerminator ())
visitTerminatorOperation (op, operandConstants);
// Process call operations. The call visitor processes result values, so we
// can exit afterwards.
if (CallOpInterface call = dyn_cast<CallOpInterface>(op))
return visitCallOperation (call);
// Process callable operations. These are specially handled region operations
// that track dataflow via calls.
if (isa<CallableOpInterface>(op))
return visitCallableOperation (op);
// Process region holding operations. The region visitor processes result
// values, so we can exit afterwards.
if (op->getNumRegions ())
Expand Down
Expand Up
@@ -399,6 +547,62 @@ void SCCPSolver::visitOperation(Operation *op) {
}
}
void SCCPSolver::visitCallableOperation (Operation *op) {
// Mark the regions as executable.
bool isTrackingLatticeState = callableLatticeState.count (op);
for (Region ®ion : op->getRegions ()) {
if (region.empty ())
continue ;
Block *entryBlock = ®ion.front ();
markBlockExecutable (entryBlock);
// If we aren't tracking lattice state for this callable, mark all of the
// region arguments as overdefined.
if (!isTrackingLatticeState)
markAllOverdefined (entryBlock->getArguments ());
}
// TODO: Add support for non-symbol callables when necessary. If the callable
// has non-call uses we would mark overdefined, otherwise allow for
// propagating the return values out.
markAllOverdefined (op, op->getResults ());
}
void SCCPSolver::visitCallOperation (CallOpInterface op) {
ResultRange callResults = op.getOperation ()->getResults ();
// Resolve the callable operation for this call.
Operation *callableOp = nullptr ;
if (Value callableValue = op.getCallableForCallee ().dyn_cast <Value>())
callableOp = callableValue.getDefiningOp ();
else
callableOp = callToSymbolCallable.lookup (op);
// The callable of this call can't be resolved, mark any results overdefined.
if (!callableOp)
return markAllOverdefined (op, callResults);
// If this callable is tracking state, merge the argument operands with the
// arguments of the callable.
auto callableLatticeIt = callableLatticeState.find (callableOp);
if (callableLatticeIt == callableLatticeState.end ())
return markAllOverdefined (op, callResults);
OperandRange callOperands = op.getArgOperands ();
auto callableArgs = callableLatticeIt->second .getCallableArguments ();
for (auto it : llvm::zip (callOperands, callableArgs)) {
BlockArgument callableArg = std::get<1 >(it);
if (latticeValues[callableArg].meet (latticeValues[std::get<0 >(it)]))
visitUsers (callableArg);
}
// Merge in the lattice state for the callable results as well.
auto callableResults = callableLatticeIt->second .getResultLatticeValues ();
for (auto it : llvm::zip (callResults, callableResults))
meet (/* owner=*/ op, /* to=*/ latticeValues[std::get<0 >(it)],
/* from=*/ std::get<1 >(it));
}
void SCCPSolver::visitRegionOperation (Operation *op,
ArrayRef<Attribute> constantOperands) {
// Check to see if we can reason about the internal control flow of this
Expand Down
Expand Up
@@ -509,9 +713,14 @@ void SCCPSolver::visitTerminatorOperation(
Operation *op, ArrayRef<Attribute> constantOperands) {
// If this operation has no successors, we treat it as an exiting terminator.
if (op->getNumSuccessors () == 0 ) {
// Check to see if the parent tracks region control flow.
Region *parentRegion = op->getParentRegion ();
Operation *parentOp = parentRegion->getParentOp ();
// Check to see if this is a terminator for a callable region.
if (isa<CallableOpInterface>(parentOp))
return visitCallableTerminatorOperation (parentOp, op);
// Otherwise, check to see if the parent tracks region control flow.
auto regionInterface = dyn_cast<RegionBranchOpInterface>(parentOp);
if (!regionInterface || !isBlockExecutable (parentOp->getBlock ()))
return ;
Expand Down
Expand Up
@@ -552,6 +761,42 @@ void SCCPSolver::visitTerminatorOperation(
markEdgeExecutable (block, succ);
}
void SCCPSolver::visitCallableTerminatorOperation (Operation *callable,
Operation *terminator) {
// If there are no exiting values, we have nothing to track.
if (terminator->getNumOperands () == 0 )
return ;
// If this callable isn't tracking any lattice state there is nothing to do.
auto latticeIt = callableLatticeState.find (callable);
if (latticeIt == callableLatticeState.end ())
return ;
assert (callable->getNumResults () == 0 && " expected symbol callable" );
// If this terminator is not "return-like", conservatively mark all of the
// call-site results as overdefined.
auto callableResultLattices = latticeIt->second .getResultLatticeValues ();
if (!terminator->hasTrait <OpTrait::ReturnLike>()) {
for (auto &it : callableResultLattices)
it.markOverdefined ();
for (Operation *call : latticeIt->second .getSymbolCalls ())
markAllOverdefined (call, call->getResults ());
return ;
}
// Merge the terminator operands into the results.
bool anyChanged = false ;
for (auto it : llvm::zip (terminator->getOperands (), callableResultLattices))
anyChanged |= std::get<1 >(it).meet (latticeValues[std::get<0 >(it)]);
if (!anyChanged)
return ;
// If any of the result lattices changed, update the callers.
for (Operation *call : latticeIt->second .getSymbolCalls ())
for (auto it : llvm::zip (call->getResults (), callableResultLattices))
meet (call, latticeValues[std::get<0 >(it)], std::get<1 >(it));
}
void SCCPSolver::visitBlock (Block *block) {
// If the block is not the entry block we need to compute the lattice state
// for the block arguments. Entry block argument lattices are computed
Expand Down
Expand Up
@@ -663,7 +908,7 @@ void SCCP::runOnOperation() {
Operation *op = getOperation ();
// Solve for SCCP constraints within nested regions.
SCCPSolver solver (op-> getRegions () );
SCCPSolver solver (op);
solver.solve ();
// Cleanup any operations using the solver analysis.
Expand Down