Skip to content

Commit

Permalink
[mlir] support non-interprocedural dataflow analyses (#75583)
Browse files Browse the repository at this point in the history
The core implementation of the dataflow analysis framework is
interprocedural by design. While this offers better analysis precision,
it also comes with additional cost as it takes longer for the analysis
to reach the fixpoint state. Add a configuration mechanism to the
dataflow solver to control whether it operates interprocedurally or not
to offer clients a choice.

As a positive side effect, this change also adds hooks for explicitly
processing external/opaque function calls in the dataflow analyses,
e.g., based off of attributes present in the function declaration or
call operation such as alias scopes and modref available in the LLVM
dialect.

This change should not affect existing analyses and the default solver
configuration remains interprocedural.

Co-authored-by: Jacob Peng <jacobmpeng@gmail.com>
  • Loading branch information
ftynse and pengmai committed Dec 18, 2023
1 parent 82a1bff commit 32a4e3f
Show file tree
Hide file tree
Showing 12 changed files with 771 additions and 171 deletions.
40 changes: 30 additions & 10 deletions mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ namespace dataflow {
// CallControlFlowAction
//===----------------------------------------------------------------------===//

/// Indicates whether the control enters or exits the callee.
enum class CallControlFlowAction { EnterCallee, ExitCallee };
/// Indicates whether the control enters, exits, or skips over the callee (in
/// the case of external functions).
enum class CallControlFlowAction { EnterCallee, ExitCallee, ExternalCallee };

//===----------------------------------------------------------------------===//
// AbstractDenseLattice
Expand Down Expand Up @@ -131,14 +132,21 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {

/// Propagate the dense lattice forward along the call control flow edge,
/// which can be either entering or exiting the callee. Default implementation
/// just meets the states, meaning that operations implementing
/// `CallOpInterface` don't have any effect on the lattice that isn't already
/// expressed by the interface itself.
/// for enter and exit callee actions just meets the states, meaning that
/// operations implementing `CallOpInterface` don't have any effect on the
/// lattice that isn't already expressed by the interface itself. Default
/// implementation for the external callee action additionally sets the
/// "after" lattice to the entry state.
virtual void visitCallControlFlowTransfer(CallOpInterface call,
                                          CallControlFlowAction action,
                                          const AbstractDenseLattice &before,
                                          AbstractDenseLattice *after) {
  // Every kind of call edge starts by joining `before` into `after`.
  join(after, before);

  // Entering or exiting a callee requires nothing beyond the join.
  const bool isExternalCallee =
      action == CallControlFlowAction::ExternalCallee;
  if (!isExternalCallee)
    return;

  // For an external callee, additionally reset `after` to the entry state.
  // Note that `setToEntryState` may be a "partial fixpoint" for some
  // lattices, e.g., lattices that are lists of maps of other lattices will
  // only set fixpoint for "known" lattices.
  setToEntryState(after);
}

/// Visit a program point within a region branch operation with predecessors
Expand All @@ -155,7 +163,9 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {

/// Visit an operation for which the data flow is described by the
/// `CallOpInterface`.
void visitCallOperation(CallOpInterface call, AbstractDenseLattice *after);
void visitCallOperation(CallOpInterface call,
const AbstractDenseLattice &before,
AbstractDenseLattice *after);
};

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -361,14 +371,22 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {

/// Propagate the dense lattice backwards along the call control flow edge,
/// which can be either entering or exiting the callee. Default implementation
/// just meets the states, meaning that operations implementing
/// `CallOpInterface` don't have any effect on the lattice that isn't already
/// expressed by the interface itself.
/// for enter and exit callee action just meets the states, meaning that
/// operations implementing `CallOpInterface` don't have any effect on the
/// lattice that isn't already expressed by the interface itself. Default
/// implementation for external callee action additionally sets the result to
/// the exit (fixpoint) state.
virtual void visitCallControlFlowTransfer(CallOpInterface call,
                                          CallControlFlowAction action,
                                          const AbstractDenseLattice &after,
                                          AbstractDenseLattice *before) {
  // Every kind of call edge starts by meeting `after` into `before`.
  meet(before, after);

  // Entering or exiting a callee requires nothing beyond the meet.
  const bool isExternalCallee =
      action == CallControlFlowAction::ExternalCallee;
  if (!isExternalCallee)
    return;

  // For an external callee, additionally reset `before` to the exit state.
  // Note that `setToExitState` may be a "partial fixpoint" for some lattices,
  // e.g., lattices that are lists of maps of other lattices will only
  // set fixpoint for "known" lattices.
  setToExitState(before);
}

private:
Expand All @@ -394,7 +412,9 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
/// otherwise,
/// - meet that state with the state before the call-like op, or use the
/// custom logic if overridden by concrete analyses.
void visitCallOperation(CallOpInterface call, AbstractDenseLattice *before);
void visitCallOperation(CallOpInterface call,
const AbstractDenseLattice &after,
AbstractDenseLattice *before);

/// Symbol table for call-level control flow.
SymbolTableCollection &symbolTable;
Expand Down
55 changes: 55 additions & 0 deletions mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/ADT/SmallPtrSet.h"

Expand Down Expand Up @@ -199,6 +200,12 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
ArrayRef<const AbstractSparseLattice *> operandLattices,
ArrayRef<AbstractSparseLattice *> resultLattices) = 0;

/// The transfer function for calls to external functions. Pure virtual:
/// this is the type-erased hook, expressed in terms of
/// `AbstractSparseLattice`, that a derived analysis must implement.
/// `call` is the call operation being visited. `argumentLattices` are
/// read-only and `resultLattices` are mutable; presumably they correspond to
/// the call's arguments and results respectively (TODO confirm against the
/// caller, which is not visible here).
virtual void visitExternalCallImpl(
CallOpInterface call,
ArrayRef<const AbstractSparseLattice *> argumentLattices,
ArrayRef<AbstractSparseLattice *> resultLattices) = 0;

/// Given an operation with region control-flow, the lattices of the operands,
/// and a region successor, compute the lattice values for block arguments
/// that are not accounted for by the branching control flow (ex. the bounds
Expand Down Expand Up @@ -271,6 +278,14 @@ class SparseForwardDataFlowAnalysis
virtual void visitOperation(Operation *op, ArrayRef<const StateT *> operands,
ArrayRef<StateT *> results) = 0;

/// Visit a call operation to an externally defined function given the
/// lattices of its arguments. The default implementation does not inspect
/// `call` or `argumentLattices`; it only passes `resultLattices` to
/// `setAllToEntryStates`. Being virtual, it can be overridden by a concrete
/// analysis that wants a different treatment of external calls.
virtual void visitExternalCall(CallOpInterface call,
ArrayRef<const StateT *> argumentLattices,
ArrayRef<StateT *> resultLattices) {
setAllToEntryStates(resultLattices);
}

/// Given an operation with possible region control-flow, the lattices of the
/// operands, and a region successor, compute the lattice values for block
/// arguments that are not accounted for by the branching control flow (ex.
Expand Down Expand Up @@ -321,6 +336,17 @@ class SparseForwardDataFlowAnalysis
{reinterpret_cast<StateT *const *>(resultLattices.begin()),
resultLattices.size()});
}
/// Type-erased entry point for external calls: reinterprets the abstract
/// lattice pointers as `StateT` pointers and forwards to the typed
/// `visitExternalCall` hook.
void visitExternalCallImpl(
    CallOpInterface call,
    ArrayRef<const AbstractSparseLattice *> argumentLattices,
    ArrayRef<AbstractSparseLattice *> resultLattices) override {
  // View the same storage as arrays of the concrete lattice type; no
  // elements are copied.
  ArrayRef<const StateT *> typedArguments(
      reinterpret_cast<const StateT *const *>(argumentLattices.begin()),
      argumentLattices.size());
  ArrayRef<StateT *> typedResults(
      reinterpret_cast<StateT *const *>(resultLattices.begin()),
      resultLattices.size());
  visitExternalCall(call, typedArguments, typedResults);
}
void visitNonControlFlowArgumentsImpl(
Operation *op, const RegionSuccessor &successor,
ArrayRef<AbstractSparseLattice *> argLattices,
Expand Down Expand Up @@ -363,6 +389,11 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
Operation *op, ArrayRef<AbstractSparseLattice *> operandLattices,
ArrayRef<const AbstractSparseLattice *> resultLattices) = 0;

/// The transfer function for calls to external functions. Pure virtual:
/// this is the type-erased hook, expressed in terms of
/// `AbstractSparseLattice`, that a derived analysis must implement.
/// `call` is the call operation being visited. Here `operandLattices` are
/// the mutable lattices and `resultLattices` are read-only.
virtual void visitExternalCallImpl(
CallOpInterface call, ArrayRef<AbstractSparseLattice *> operandLattices,
ArrayRef<const AbstractSparseLattice *> resultLattices) = 0;

// Visit operands on branch instructions that are not forwarded.
virtual void visitBranchOperand(OpOperand &operand) = 0;

Expand Down Expand Up @@ -444,6 +475,19 @@ class SparseBackwardDataFlowAnalysis
virtual void visitOperation(Operation *op, ArrayRef<StateT *> operands,
ArrayRef<const StateT *> results) = 0;

/// Visit a call to an external function. This function is expected to set
/// lattice values of the call operands. The default implementation ignores
/// both lattice lists and instead calls `visitCallOperand` on every operand
/// of `call`. Being virtual, it can be overridden by a concrete analysis.
virtual void visitExternalCall(CallOpInterface call,
                               ArrayRef<StateT *> argumentLattices,
                               ArrayRef<const StateT *> resultLattices) {
  // The default handling works on the operands directly, so the lattice
  // lists are deliberately unused.
  (void)argumentLattices;
  (void)resultLattices;
  for (OpOperand &operand : call->getOpOperands())
    visitCallOperand(operand);
}

protected:
/// Get the lattice element for a value.
StateT *getLatticeElement(Value value) override {
Expand Down Expand Up @@ -474,6 +518,17 @@ class SparseBackwardDataFlowAnalysis
{reinterpret_cast<const StateT *const *>(resultLattices.begin()),
resultLattices.size()});
}

/// Type-erased entry point for external calls: reinterprets the abstract
/// lattice pointers as `StateT` pointers and forwards to the typed
/// `visitExternalCall` hook.
void visitExternalCallImpl(
    CallOpInterface call, ArrayRef<AbstractSparseLattice *> operandLattices,
    ArrayRef<const AbstractSparseLattice *> resultLattices) override {
  // View the same storage as arrays of the concrete lattice type; no
  // elements are copied.
  ArrayRef<StateT *> typedOperands(
      reinterpret_cast<StateT *const *>(operandLattices.begin()),
      operandLattices.size());
  ArrayRef<const StateT *> typedResults(
      reinterpret_cast<const StateT *const *>(resultLattices.begin()),
      resultLattices.size());
  visitExternalCall(call, typedOperands, typedResults);
}
};

} // end namespace dataflow
Expand Down
38 changes: 38 additions & 0 deletions mlir/include/mlir/Analysis/DataFlowFramework.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,32 @@ struct ProgramPoint
/// Forward declaration of the data-flow analysis class.
class DataFlowAnalysis;

//===----------------------------------------------------------------------===//
// DataFlowConfig
//===----------------------------------------------------------------------===//

/// Options for the data flow solver and its child analyses. Setters return
/// the object itself so that calls can be chained (fluent API pattern).
class DataFlowConfig {
  /// Whether the solver operates interprocedurally. Enabled by default.
  bool interprocedural = true;

public:
  DataFlowConfig() = default;

  /// Choose whether the solver should operate interprocedurally, i.e. enter
  /// the callee body when available. Interprocedural analyses may be more
  /// precise, but also more expensive as more states need to be computed and
  /// the fixpoint convergence takes longer. Returns `*this` for chaining.
  DataFlowConfig &setInterprocedural(bool enable) {
    this->interprocedural = enable;
    return *this;
  }

  /// Return `true` if the solver operates interprocedurally, `false`
  /// otherwise.
  bool isInterprocedural() const { return this->interprocedural; }
};

//===----------------------------------------------------------------------===//
// DataFlowSolver
//===----------------------------------------------------------------------===//
Expand All @@ -195,6 +221,9 @@ class DataFlowAnalysis;
/// TODO: Optimize the internal implementation of the solver.
class DataFlowSolver {
public:
/// Construct a solver with the given configuration. When no argument is
/// passed, a default-constructed `DataFlowConfig` is used. The configuration
/// is copied into the solver's `config` member.
explicit DataFlowSolver(const DataFlowConfig &config = DataFlowConfig())
: config(config) {}

/// Load an analysis into the solver. Return the analysis instance.
template <typename AnalysisT, typename... Args>
AnalysisT *load(Args &&...args);
Expand Down Expand Up @@ -236,7 +265,13 @@ class DataFlowSolver {
/// dependent work items to the back of the queue.
void propagateIfChanged(AnalysisState *state, ChangeResult changed);

/// Get the configuration of the solver. Returns a read-only reference to the
/// solver's own `config` member.
const DataFlowConfig &getConfig() const { return config; }

private:
/// Configuration of the dataflow solver.
DataFlowConfig config;

/// The solver's work queue. Work items can be inserted to the front of the
/// queue to be processed greedily, speeding up computations that otherwise
/// quickly degenerate to quadratic due to propagation of state updates.
Expand Down Expand Up @@ -423,6 +458,9 @@ class DataFlowAnalysis {
return state;
}

/// Return the configuration of the solver used for this analysis. This is a
/// convenience accessor that forwards to `solver.getConfig()`.
const DataFlowConfig &getSolverConfig() const { return solver.getConfig(); }

#if LLVM_ENABLE_ABI_BREAKING_CHECKS
/// When compiling with debugging, keep a name for the analysis.
StringRef debugName;
Expand Down
42 changes: 29 additions & 13 deletions mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,22 @@ LogicalResult AbstractDenseForwardDataFlowAnalysis::visit(ProgramPoint point) {
}

void AbstractDenseForwardDataFlowAnalysis::visitCallOperation(
CallOpInterface call, AbstractDenseLattice *after) {
CallOpInterface call, const AbstractDenseLattice &before,
AbstractDenseLattice *after) {
// Allow for customizing the behavior of calls to external symbols, including
// when the analysis is explicitly marked as non-interprocedural.
auto callable =
dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
if (!getSolverConfig().isInterprocedural() ||
(callable && !callable.getCallableRegion())) {
return visitCallControlFlowTransfer(
call, CallControlFlowAction::ExternalCallee, before, after);
}

const auto *predecessors =
getOrCreateFor<PredecessorState>(call.getOperation(), call);
// If not all return sites are known, then conservatively assume we can't
// reason about the data-flow.
// Otherwise, if not all return sites are known, then conservatively assume we
// can't reason about the data-flow.
if (!predecessors->allPredecessorsKnown())
return setToEntryState(after);

Expand Down Expand Up @@ -108,7 +118,7 @@ void AbstractDenseForwardDataFlowAnalysis::processOperation(Operation *op) {
// If this is a call operation, then join its lattices across known return
// sites.
if (auto call = dyn_cast<CallOpInterface>(op))
return visitCallOperation(call, after);
return visitCallOperation(call, *before, after);

// Invoke the operation transfer function.
visitOperationImpl(op, *before, after);
Expand All @@ -130,8 +140,10 @@ void AbstractDenseForwardDataFlowAnalysis::visitBlock(Block *block) {
if (callable && callable.getCallableRegion() == block->getParent()) {
const auto *callsites = getOrCreateFor<PredecessorState>(block, callable);
// If not all callsites are known, conservatively mark all lattices as
// having reached their pessimistic fixpoints.
if (!callsites->allPredecessorsKnown())
// having reached their pessimistic fixpoints. Do the same if
// interprocedural analysis is not enabled.
if (!callsites->allPredecessorsKnown() ||
!getSolverConfig().isInterprocedural())
return setToEntryState(after);
for (Operation *callsite : callsites->getKnownPredecessors()) {
// Get the dense lattice before the callsite.
Expand Down Expand Up @@ -267,18 +279,20 @@ LogicalResult AbstractDenseBackwardDataFlowAnalysis::visit(ProgramPoint point) {
}

void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation(
CallOpInterface call, AbstractDenseLattice *before) {
CallOpInterface call, const AbstractDenseLattice &after,
AbstractDenseLattice *before) {
// Find the callee.
Operation *callee = call.resolveCallable(&symbolTable);
auto callable = dyn_cast_or_null<CallableOpInterface>(callee);
if (!callable)
return setToExitState(before);

// No region means the callee is only declared in this module and we shouldn't
// assume anything about it.
// No region means the callee is only declared in this module.
Region *region = callable.getCallableRegion();
if (!region || region->empty())
return setToExitState(before);
if (!region || region->empty() || !getSolverConfig().isInterprocedural()) {
return visitCallControlFlowTransfer(
call, CallControlFlowAction::ExternalCallee, after, before);
}

// Call-level control flow specifies the data flow here.
//
Expand Down Expand Up @@ -324,7 +338,7 @@ void AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) {
return visitRegionBranchOperation(op, branch, RegionBranchPoint::parent(),
before);
if (auto call = dyn_cast<CallOpInterface>(op))
return visitCallOperation(call, before);
return visitCallOperation(call, *after, before);

// Invoke the operation transfer function.
visitOperationImpl(op, *after, before);
Expand Down Expand Up @@ -359,8 +373,10 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
const auto *callsites = getOrCreateFor<PredecessorState>(block, callable);
// If not all call sites are known, conservatively mark all lattices as
// having reached their pessimistic fix points.
if (!callsites->allPredecessorsKnown())
if (!callsites->allPredecessorsKnown() ||
!getSolverConfig().isInterprocedural()) {
return setToExitState(before);
}

for (Operation *callsite : callsites->getKnownPredecessors()) {
const AbstractDenseLattice *after;
Expand Down
Loading

0 comments on commit 32a4e3f

Please sign in to comment.