diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h index 3c87c453a4cf0..5b7b45fdd1d58 100644 --- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h @@ -127,6 +127,18 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis { /// them into the same equivalent class. virtual void buildOperationEquivalentLatticeAnchor(Operation *op) {} + /// Visit a block and propagate the dense lattice forward along the control + /// flow edge from predecessor to block. `point` corresponds to the program + /// point before `block`. The default implementation merges in the state from + /// the predecessor's terminator. + virtual void visitBlockTransfer(Block *block, ProgramPoint *point, + Block *predecessor, + const AbstractDenseLattice &before, + AbstractDenseLattice *after) { + // Merge in the state from the predecessor's terminator. + join(after, before); + } + /// Propagate the dense lattice forward along the control flow edge from /// `regionFrom` to `regionTo` regions of the `branch` operation. `nullopt` /// values correspond to control flow branches originating at or targeting the @@ -259,6 +271,22 @@ class DenseForwardDataFlowAnalysis branch, regionFrom, regionTo, before, after); } + /// Hook for customizing the behavior of lattice propagation along the control + /// flow edges between blocks. The control flows from `predecessor` to + /// `block`. The lattice is propagated forward along this edge. The lattices + /// are as follows: + /// - `before` is the lattice at the end of the predecessor block; + /// - `after` is the lattice at the beginning of the block. + /// By default, the `after` state is simply joined with the `before` state. + /// Concrete analyses can override this behavior or delegate to the parent + /// call for the default behavior. + virtual void visitBlockTransfer(Block *block, ProgramPoint *point, + Block *predecessor, const LatticeT &before, + LatticeT *after) { + AbstractDenseForwardDataFlowAnalysis::visitBlockTransfer( + block, point, predecessor, before, after); + } + protected: /// Get the dense lattice on this lattice anchor. LatticeT *getLattice(LatticeAnchor anchor) override { @@ -306,6 +334,13 @@ class DenseForwardDataFlowAnalysis static_cast(before), static_cast(after)); } + void visitBlockTransfer(Block *block, ProgramPoint *point, Block *predecessor, + const AbstractDenseLattice &before, + AbstractDenseLattice *after) final { + visitBlockTransfer(block, point, predecessor, + static_cast(before), + static_cast(after)); + } }; //===----------------------------------------------------------------------===// @@ -388,6 +423,17 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis { /// them into the same equivalent class. virtual void buildOperationEquivalentLatticeAnchor(Operation *op) {} + /// Visit a block and propagate the dense lattice backward along the control + /// flow edge from successor to block. `point` corresponds to the program + /// point after `block`. The default implementation merges in the state from + /// the successor's first operation or the block itself when empty. + virtual void visitBlockTransfer(Block *block, ProgramPoint *point, + Block *successor, + const AbstractDenseLattice &after, + AbstractDenseLattice *before) { + meet(before, after); + } + /// Propagate the dense lattice backwards along the control flow edge from /// `regionFrom` to `regionTo` regions of the `branch` operation. `nullopt` /// values correspond to control flow branches originating at or targeting the @@ -531,6 +577,22 @@ class DenseBackwardDataFlowAnalysis branch, regionFrom, regionTo, after, before); } + /// Hook for customizing the behavior of lattice propagation along the control + /// flow edges between blocks. The control flows from `successor` to + /// `block`. The lattice is propagated back along this edge. The lattices + /// are as follows: + /// - `after` is the lattice at the beginning of the successor block; + /// - `before` is the lattice at the end of the block. + /// By default, the `before` state is simply met with the `after` state. + /// Concrete analyses can override this behavior or delegate to the parent + /// call for the default behavior. + virtual void visitBlockTransfer(Block *block, ProgramPoint *point, + Block *successor, const LatticeT &after, + LatticeT *before) { + AbstractDenseBackwardDataFlowAnalysis::visitBlockTransfer( + block, point, successor, after, before); + } + protected: /// Get the dense lattice at the given lattice anchor. LatticeT *getLattice(LatticeAnchor anchor) override { @@ -577,6 +639,13 @@ class DenseBackwardDataFlowAnalysis static_cast(after), static_cast(before)); } + void visitBlockTransfer(Block *block, ProgramPoint *point, Block *successor, + const AbstractDenseLattice &after, + AbstractDenseLattice *before) final { + visitBlockTransfer(block, point, successor, + static_cast(after), + static_cast(before)); + } }; } // end namespace dataflow diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp index 0682e5f26785a..22bc0b32a9bd1 100644 --- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp @@ -266,9 +266,10 @@ void AbstractDenseForwardDataFlowAnalysis::visitBlock(Block *block) { } LDBG() << " Joining state from predecessor " << predecessor; + const AbstractDenseLattice &before = *getLatticeFor( + point, getProgramPointAfter(predecessor->getTerminator())); // Merge in the state from the predecessor's terminator. - join(after, *getLatticeFor( - point, getProgramPointAfter(predecessor->getTerminator()))); + visitBlockTransfer(block, point, predecessor, before, after); } } @@ -614,7 +615,9 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) { LDBG() << " Meeting state from successor " << successor; // Merge in the state from the successor: either the first operation, or the // block itself when empty. - meet(before, *getLatticeFor(point, getProgramPointBefore(successor))); + visitBlockTransfer(block, point, successor, + *getLatticeFor(point, getProgramPointBefore(successor)), + before); } }