Skip to content

Commit

Permalink
[mlir] Swap integer range inference to the new framework
Browse files Browse the repository at this point in the history
Integer range inference has been swapped to the new framework. The integer value range lattices automatically updates the corresponding constant value on update.
  • Loading branch information
Mogball committed Jun 27, 2022
1 parent ef55eb3 commit a14f859
Show file tree
Hide file tree
Showing 11 changed files with 328 additions and 334 deletions.
1 change: 0 additions & 1 deletion mlir/include/mlir/Analysis/DataFlowFramework.h
Expand Up @@ -226,7 +226,6 @@ class DataFlowSolver {
/// Push a work item onto the worklist.
void enqueue(WorkItem item) { worklist.push(std::move(item)); }

protected:
/// Get the state associated with the given program point. If it does not
/// exist, create an uninitialized state.
template <typename StateT, typename PointT>
Expand Down
78 changes: 66 additions & 12 deletions mlir/include/mlir/Analysis/IntRangeAnalysis.h
Expand Up @@ -15,27 +15,81 @@
#ifndef MLIR_ANALYSIS_INTRANGEANALYSIS_H
#define MLIR_ANALYSIS_INTRANGEANALYSIS_H

#include "mlir/Analysis/SparseDataFlowAnalysis.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"

namespace mlir {
namespace detail {
class IntRangeAnalysisImpl;
} // end namespace detail

class IntRangeAnalysis {
/// This lattice value represents the integer range of an SSA value.
class IntegerValueRange {
public:
/// Analyze all operations rooted under (but not including)
/// `topLevelOperation`.
IntRangeAnalysis(Operation *topLevelOperation);
IntRangeAnalysis(IntRangeAnalysis &&other);
~IntRangeAnalysis();
/// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)])
/// range that is used to mark the value as unable to be analyzed further,
/// where `t` is the type of `value`.
static IntegerValueRange getPessimisticValueState(Value value);

/// Get inferred range for value `v` if one exists.
Optional<ConstantIntRanges> getResult(Value v);
/// Create an integer value range lattice value.
IntegerValueRange(ConstantIntRanges value) : value(std::move(value)) {}

/// Get the known integer value range.
const ConstantIntRanges &getValue() const { return value; }

/// Compare two ranges.
bool operator==(const IntegerValueRange &rhs) const {
return value == rhs.value;
}

/// Take the union of two ranges.
static IntegerValueRange join(const IntegerValueRange &lhs,
const IntegerValueRange &rhs) {
return lhs.value.rangeUnion(rhs.value);
}

/// Print the integer value range.
void print(raw_ostream &os) const { os << value; }

private:
std::unique_ptr<detail::IntRangeAnalysisImpl> impl;
/// The known integer value range.
ConstantIntRanges value;
};

/// This lattice element represents the integer value range of an SSA value.
/// When this lattice is updated, it automatically updates the constant value
/// of the SSA value (if the range can be narrowed to one).
class IntegerValueRangeLattice : public Lattice<IntegerValueRange> {
public:
using Lattice::Lattice;

/// If the range can be narrowed to an integer constant, update the constant
/// value of the SSA value.
void onUpdate(DataFlowSolver *solver) const override;
};

/// Integer range analysis determines the integer value range of SSA values
/// using operations that define `InferIntRangeInterface` and also sets the
/// range of iteration indices of loops with known bounds.
class IntegerRangeAnalysis
: public SparseDataFlowAnalysis<IntegerValueRangeLattice> {
public:
using SparseDataFlowAnalysis::SparseDataFlowAnalysis;

/// Visit an operation. Invoke the transfer function on each operation that
/// implements `InferIntRangeInterface`.
void visitOperation(Operation *op,
ArrayRef<const IntegerValueRangeLattice *> operands,
ArrayRef<IntegerValueRangeLattice *> results) override;

/// Visit block arguments or operation results of an operation with region
/// control-flow for which values are not defined by region control-flow. This
/// function calls `InferIntRangeInterface` to provide values for block
/// arguments or tries to reduce the range on loop induction variables with
/// known bounds.
void
visitNonControlFlowArguments(Operation *op, const RegionSuccessor &successor,
ArrayRef<IntegerValueRangeLattice *> argLattices,
unsigned firstIndex) override;
};

} // end namespace mlir

#endif
33 changes: 33 additions & 0 deletions mlir/include/mlir/Analysis/SparseDataFlowAnalysis.h
Expand Up @@ -455,6 +455,14 @@ class AbstractSparseDataFlowAnalysis : public DataFlowAnalysis {
ArrayRef<const AbstractSparseLattice *> operandLattices,
ArrayRef<AbstractSparseLattice *> resultLattices) = 0;

/// Given an operation with region control-flow, the lattices of the operands,
/// and a region successor, compute the lattice values for block arguments
/// that are not accounted for by the branching control flow (ex. the bounds
/// of loops).
virtual void visitNonControlFlowArgumentsImpl(
Operation *op, const RegionSuccessor &successor,
ArrayRef<AbstractSparseLattice *> argLattices, unsigned firstIndex) = 0;

/// Get the lattice element of a value.
virtual AbstractSparseLattice *getLatticeElement(Value value) = 0;

Expand Down Expand Up @@ -515,6 +523,21 @@ class SparseDataFlowAnalysis : public AbstractSparseDataFlowAnalysis {
virtual void visitOperation(Operation *op, ArrayRef<const StateT *> operands,
ArrayRef<StateT *> results) = 0;

/// Given an operation with possible region control-flow, the lattices of the
/// operands, and a region successor, compute the lattice values for block
/// arguments that are not accounted for by the branching control flow (ex.
/// the bounds of loops). By default, this method marks all such lattice
/// elements as having reached a pessimistic fixpoint. `firstIndex` is the
/// index of the first element of `argLattices` that is set by control-flow.
virtual void visitNonControlFlowArguments(Operation *op,
const RegionSuccessor &successor,
ArrayRef<StateT *> argLattices,
unsigned firstIndex) {
markAllPessimisticFixpoint(argLattices.take_front(firstIndex));
markAllPessimisticFixpoint(argLattices.drop_front(
firstIndex + successor.getSuccessorInputs().size()));
}

protected:
/// Get the lattice element for a value.
StateT *getLatticeElement(Value value) override {
Expand Down Expand Up @@ -549,6 +572,16 @@ class SparseDataFlowAnalysis : public AbstractSparseDataFlowAnalysis {
{reinterpret_cast<StateT *const *>(resultLattices.begin()),
resultLattices.size()});
}
void visitNonControlFlowArgumentsImpl(
Operation *op, const RegionSuccessor &successor,
ArrayRef<AbstractSparseLattice *> argLattices,
unsigned firstIndex) override {
visitNonControlFlowArguments(
op, successor,
{reinterpret_cast<StateT *const *>(argLattices.begin()),
argLattices.size()},
firstIndex);
}
};

//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Interfaces/InferIntRangeInterface.td
Expand Up @@ -30,7 +30,7 @@ def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
since the dataflow analysis handles those case), the method should call
`setValueRange` with that `Value` as an argument. When `setValueRange`
is not called for some value, it will recieve a default value of the mimimum
and maximum values forits type (the unbounded range).
and maximum values for its type (the unbounded range).

When called on an op that also implements the RegionBranchOpInterface
or BranchOpInterface, this method should not attempt to infer the values
Expand Down
13 changes: 0 additions & 13 deletions mlir/lib/Analysis/DataFlowFramework.cpp
Expand Up @@ -87,19 +87,6 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
return failure();
}

// "Nudge" the state of the analysis by forcefully initializing states that
// are still uninitialized. All uninitialized states in the graph can be
// initialized in any order because the analysis reached fixpoint, meaning
// that there are no work items that would have further nudged the analysis.
for (AnalysisState &state :
llvm::make_pointee_range(llvm::make_second_range(analysisStates))) {
if (!state.isUninitialized())
continue;
DATAFLOW_DEBUG(llvm::dbgs() << "Default initializing " << state.debugName
<< " of " << state.point << "\n");
propagateIfChanged(&state, state.defaultInitialize());
}

// Iterate until all states are in some initialized state and the worklist
// is exhausted.
} while (!worklist.empty());
Expand Down

0 comments on commit a14f859

Please sign in to comment.