Skip to content

Commit

Permalink
[mlir] An implementation of sparse data-flow analysis
Browse files Browse the repository at this point in the history
This patch introduces a (forward) sparse data-flow analysis implemented with the data-flow analysis framework. The analysis interacts with liveness information that can be provided by dead-code analysis to be conditional. This patch re-implements SCCP using dead-code analysis and (conditional) constant propagation analyses.

Depends on D127064

Reviewed By: rriddle, phisiart

Differential Revision: https://reviews.llvm.org/D127139
  • Loading branch information
Mogball committed Jul 7, 2022
1 parent 6611d58 commit 9432fbf
Show file tree
Hide file tree
Showing 7 changed files with 547 additions and 160 deletions.
18 changes: 18 additions & 0 deletions mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
Expand Up @@ -60,6 +60,24 @@ class ConstantValue {
Dialect *dialect;
};

//===----------------------------------------------------------------------===//
// SparseConstantPropagation
//===----------------------------------------------------------------------===//

/// This analysis implements sparse constant propagation, which attempts to
/// determine constant-valued results for operations using constant-valued
/// operands, by speculatively folding operations. When combined with dead-code
/// analysis, this becomes sparse conditional constant propagation (SCCP).
class SparseConstantPropagation
: public SparseDataFlowAnalysis<Lattice<ConstantValue>> {
public:
using SparseDataFlowAnalysis::SparseDataFlowAnalysis;

void visitOperation(Operation *op,
ArrayRef<const Lattice<ConstantValue> *> operands,
ArrayRef<Lattice<ConstantValue> *> results) override;
};

} // end namespace dataflow
} // end namespace mlir

Expand Down
19 changes: 15 additions & 4 deletions mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
Expand Up @@ -89,6 +89,9 @@ class Executable : public AnalysisState {
/// the predecessor to its entry block, and the exiting terminator or a callable
/// operation can be the predecessor of the call operation.
///
/// The state can optionally contain information about which values are
/// propagated from each predecessor to the successor point.
///
/// The state can indicate that it is underdefined, meaning that not all live
/// control-flow predecessors can be known.
class PredecessorState : public AnalysisState {
Expand Down Expand Up @@ -118,12 +121,17 @@ class PredecessorState : public AnalysisState {
return knownPredecessors.getArrayRef();
}

/// Add a known predecessor.
ChangeResult join(Operation *predecessor) {
return knownPredecessors.insert(predecessor) ? ChangeResult::Change
: ChangeResult::NoChange;
/// Get the successor inputs from a predecessor.
ValueRange getSuccessorInputs(Operation *predecessor) const {
return successorInputs.lookup(predecessor);
}

/// Add a known predecessor.
ChangeResult join(Operation *predecessor);

/// Add a known predecessor with successor inputs.
ChangeResult join(Operation *predecessor, ValueRange inputs);

private:
/// Whether all predecessors are known. Optimistically assume that we know
/// all predecessors.
Expand All @@ -133,6 +141,9 @@ class PredecessorState : public AnalysisState {
SetVector<Operation *, SmallVector<Operation *, 4>,
SmallPtrSet<Operation *, 4>>
knownPredecessors;

/// The successor inputs when branching from a given predecessor.
DenseMap<Operation *, ValueRange> successorInputs;
};

//===----------------------------------------------------------------------===//
Expand Down
132 changes: 132 additions & 0 deletions mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
Expand Up @@ -16,6 +16,7 @@
#define MLIR_ANALYSIS_DATAFLOW_SPARSEANALYSIS_H

#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/ADT/SmallPtrSet.h"

namespace mlir {
Expand Down Expand Up @@ -179,6 +180,137 @@ class Lattice : public AbstractSparseLattice {
Optional<ValueT> optimisticValue;
};

//===----------------------------------------------------------------------===//
// AbstractSparseDataFlowAnalysis
//===----------------------------------------------------------------------===//

/// Base class for sparse (forward) data-flow analyses. A sparse analysis
/// implements a transfer function on operations from the lattices of the
/// operands to the lattices of the results. This analysis will propagate
/// lattices across control-flow edges and the callgraph using liveness
/// information.
class AbstractSparseDataFlowAnalysis : public DataFlowAnalysis {
public:
/// Initialize the analysis by visiting every owner of an SSA value: all
/// operations and blocks.
LogicalResult initialize(Operation *top) override;

/// Visit a program point. If this is a block and all control-flow
/// predecessors or callsites are known, then the arguments lattices are
/// propagated from them. If this is a call operation or an operation with
/// region control-flow, then its result lattices are set accordingly.
/// Otherwise, the operation transfer function is invoked.
LogicalResult visit(ProgramPoint point) override;

protected:
explicit AbstractSparseDataFlowAnalysis(DataFlowSolver &solver);

/// The operation transfer function. Given the operand lattices, this
/// function is expected to set the result lattices.
virtual void
visitOperationImpl(Operation *op,
ArrayRef<const AbstractSparseLattice *> operandLattices,
ArrayRef<AbstractSparseLattice *> resultLattices) = 0;

/// Get the lattice element of a value.
virtual AbstractSparseLattice *getLatticeElement(Value value) = 0;

/// Get a read-only lattice element for a value and add it as a dependency to
/// a program point.
const AbstractSparseLattice *getLatticeElementFor(ProgramPoint point,
Value value);

/// Mark the given lattice elements as having reached their pessimistic
/// fixpoints and propagate an update if any changed.
void markAllPessimisticFixpoint(ArrayRef<AbstractSparseLattice *> lattices);

/// Join the lattice element and propagate and update if it changed.
void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs);

private:
/// Recursively initialize the analysis on nested operations and blocks.
LogicalResult initializeRecursively(Operation *op);

/// Visit an operation. If this is a call operation or an operation with
/// region control-flow, then its result lattices are set accordingly.
/// Otherwise, the operation transfer function is invoked.
void visitOperation(Operation *op);

/// Visit a block to compute the lattice values of its arguments. If this is
/// an entry block, then the argument values are determined from the block's
/// "predecessors" as set by `PredecessorState`. The predecessors can be
/// region terminators or callable callsites. Otherwise, the values are
/// determined from block predecessors.
void visitBlock(Block *block);

/// Visit a program point `point` with predecessors within a region branch
/// operation `branch`, which can either be the entry block of one of the
/// regions or the parent operation itself, and set either the argument or
/// parent result lattices.
void visitRegionSuccessors(ProgramPoint point, RegionBranchOpInterface branch,
Optional<unsigned> successorIndex,
ArrayRef<AbstractSparseLattice *> lattices);
};

//===----------------------------------------------------------------------===//
// SparseDataFlowAnalysis
//===----------------------------------------------------------------------===//

/// A sparse (forward) data-flow analysis for propagating SSA value lattices
/// across the IR by implementing transfer functions for operations.
///
/// `StateT` is expected to be a subclass of `AbstractSparseLattice`.
template <typename StateT>
class SparseDataFlowAnalysis : public AbstractSparseDataFlowAnalysis {
static_assert(
std::is_base_of<AbstractSparseLattice, StateT>::value,
"analysis state class expected to subclass AbstractSparseLattice");

public:
explicit SparseDataFlowAnalysis(DataFlowSolver &solver)
: AbstractSparseDataFlowAnalysis(solver) {}

/// Visit an operation with the lattices of its operands. This function is
/// expected to set the lattices of the operation's results.
virtual void visitOperation(Operation *op, ArrayRef<const StateT *> operands,
ArrayRef<StateT *> results) = 0;

protected:
/// Get the lattice element for a value.
StateT *getLatticeElement(Value value) override {
return getOrCreate<StateT>(value);
}

/// Get the lattice element for a value and create a dependency on the
/// provided program point.
const StateT *getLatticeElementFor(ProgramPoint point, Value value) {
return static_cast<const StateT *>(
AbstractSparseDataFlowAnalysis::getLatticeElementFor(point, value));
}

/// Mark the lattice elements of a range of values as having reached their
/// pessimistic fixpoint.
void markAllPessimisticFixpoint(ArrayRef<StateT *> lattices) {
AbstractSparseDataFlowAnalysis::markAllPessimisticFixpoint(
{reinterpret_cast<AbstractSparseLattice *const *>(lattices.begin()),
lattices.size()});
}

private:
/// Type-erased wrappers that convert the abstract lattice operands to derived
/// lattices and invoke the virtual hooks operating on the derived lattices.
void visitOperationImpl(
Operation *op, ArrayRef<const AbstractSparseLattice *> operandLattices,
ArrayRef<AbstractSparseLattice *> resultLattices) override {
visitOperation(
op,
{reinterpret_cast<const StateT *const *>(operandLattices.begin()),
operandLattices.size()},
{reinterpret_cast<StateT *const *>(resultLattices.begin()),
resultLattices.size()});
}
};

} // end namespace dataflow
} // end namespace mlir

Expand Down
69 changes: 69 additions & 0 deletions mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
Expand Up @@ -7,6 +7,10 @@
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/IR/OpDefinition.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "constant-propagation"

using namespace mlir;
using namespace mlir::dataflow;
Expand All @@ -20,3 +24,68 @@ void ConstantValue::print(raw_ostream &os) const {
return constant.print(os);
os << "<NO VALUE>";
}

//===----------------------------------------------------------------------===//
// SparseConstantPropagation
//===----------------------------------------------------------------------===//

void SparseConstantPropagation::visitOperation(
Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands,
ArrayRef<Lattice<ConstantValue> *> results) {
LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n");

// Don't try to simulate the results of a region operation as we can't
// guarantee that folding will be out-of-place. We don't allow in-place
// folds as the desire here is for simulated execution, and not general
// folding.
if (op->getNumRegions())
return;

SmallVector<Attribute, 8> constantOperands;
constantOperands.reserve(op->getNumOperands());
for (auto *operandLattice : operands)
constantOperands.push_back(operandLattice->getValue().getConstantValue());

// Save the original operands and attributes just in case the operation
// folds in-place. The constant passed in may not correspond to the real
// runtime value, so in-place updates are not allowed.
SmallVector<Value, 8> originalOperands(op->getOperands());
DictionaryAttr originalAttrs = op->getAttrDictionary();

// Simulate the result of folding this operation to a constant. If folding
// fails or was an in-place fold, mark the results as overdefined.
SmallVector<OpFoldResult, 8> foldResults;
foldResults.reserve(op->getNumResults());
if (failed(op->fold(constantOperands, foldResults))) {
markAllPessimisticFixpoint(results);
return;
}

// If the folding was in-place, mark the results as overdefined and reset
// the operation. We don't allow in-place folds as the desire here is for
// simulated execution, and not general folding.
if (foldResults.empty()) {
op->setOperands(originalOperands);
op->setAttrs(originalAttrs);
return;
}

// Merge the fold results into the lattice for this operation.
assert(foldResults.size() == op->getNumResults() && "invalid result size");
for (const auto it : llvm::zip(results, foldResults)) {
Lattice<ConstantValue> *lattice = std::get<0>(it);

// Merge in the result of the fold, either a constant or a value.
OpFoldResult foldResult = std::get<1>(it);
if (Attribute attr = foldResult.dyn_cast<Attribute>()) {
LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n");
propagateIfChanged(lattice,
lattice->join(ConstantValue(attr, op->getDialect())));
} else {
LLVM_DEBUG(llvm::dbgs()
<< "Folded to value: " << foldResult.get<Value>() << "\n");
AbstractSparseDataFlowAnalysis::join(
lattice, *getLatticeElement(foldResult.get<Value>()));
}
}
}
34 changes: 28 additions & 6 deletions mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
Expand Up @@ -59,6 +59,23 @@ void PredecessorState::print(raw_ostream &os) const {
os << " " << *op << "\n";
}

ChangeResult PredecessorState::join(Operation *predecessor) {
return knownPredecessors.insert(predecessor) ? ChangeResult::Change
: ChangeResult::NoChange;
}

ChangeResult PredecessorState::join(Operation *predecessor, ValueRange inputs) {
ChangeResult result = join(predecessor);
if (!inputs.empty()) {
ValueRange &curInputs = successorInputs[predecessor];
if (curInputs != inputs) {
curInputs = inputs;
result |= ChangeResult::Change;
}
}
return result;
}

//===----------------------------------------------------------------------===//
// CFGEdge
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -333,14 +350,18 @@ void DeadCodeAnalysis::visitRegionBranchOperation(
SmallVector<RegionSuccessor> successors;
branch.getSuccessorRegions(/*index=*/{}, *operands, successors);
for (const RegionSuccessor &successor : successors) {
// The successor can be either an entry block or the parent operation.
ProgramPoint point = successor.getSuccessor()
? &successor.getSuccessor()->front()
: ProgramPoint(branch);
// Mark the entry block as executable.
Region *region = successor.getSuccessor();
assert(region && "expected a region successor");
auto *state = getOrCreate<Executable>(&region->front());
auto *state = getOrCreate<Executable>(point);
propagateIfChanged(state, state->setToLive());
// Add the parent op as a predecessor.
auto *predecessors = getOrCreate<PredecessorState>(&region->front());
propagateIfChanged(predecessors, predecessors->join(branch));
auto *predecessors = getOrCreate<PredecessorState>(point);
propagateIfChanged(
predecessors,
predecessors->join(branch, successor.getSuccessorInputs()));
}
}

Expand All @@ -366,7 +387,8 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
// Add this terminator as a predecessor to the parent op.
predecessors = getOrCreate<PredecessorState>(branch);
}
propagateIfChanged(predecessors, predecessors->join(op));
propagateIfChanged(predecessors,
predecessors->join(op, successor.getSuccessorInputs()));
}
}

Expand Down

0 comments on commit 9432fbf

Please sign in to comment.