Skip to content

Commit

Permalink
[RELAY] Non-recursive Graph Vistor and Rewriter (apache#4886)
Browse files Browse the repository at this point in the history
* First pass a defining a non-recursive Graph Vistor and Rewriter

autoformat

remove a currently empty test until testing is solidfied

* Make CalcDep from Dead Code Elimination non-recursive

* Partially working, not passing all tests yet

passes tests when disabling GetExprRefCount, I think I have a bug in visit counting

fix GetExprRefCount

Fix a subtle bug with nested recursive/non-recursive scopes

* Refactor

* improve comments

* respond to review comments on comments

* Fix a problem with default recursion for dataflow nodes

mark DataflowVisitor methods as override

* implement ScopeMutator

* convert forward_rewrite to ScopeMutator, remove DataflowMutator

* rewrite ExprRewriter and convert fast_math to use it

* switch BiasAddSimplifier to ExprRewriter

fix a clang warning

fix cpp lint

fix doc param error

* respond to review comments

* fix a typo in the iterative looping

* add a regression test for GetExprRefCount issue

* Normalize naming

* fix lint

* First pass a defining a non-recursive Graph Vistor and Rewriter

autoformat

remove a currently empty test until testing is solidfied

* Make CalcDep from Dead Code Elimination non-recursive

* Partially working, not passing all tests yet

passes tests when disabling GetExprRefCount, I think I have a bug in visit counting

fix GetExprRefCount

Fix a subtle bug with nested recursive/non-recursive scopes

* Refactor

* improve comments

* respond to review comments on comments

* Fix a problem with default recursion for dataflow nodes

mark DataflowVisitor methods as override

* implement ScopeMutator

* convert forward_rewrite to ScopeMutator, remove DataflowMutator

* rewrite ExprRewriter and convert fast_math to use it

* switch BiasAddSimplifier to ExprRewriter

fix a clang warning

fix cpp lint

fix doc param error

* respond to review comments

* fix a typo in the iterative looping

* add a regression test for GetExprRefCount issue

* Normalize naming

* fix lint

* respond to review comments
  • Loading branch information
mbrookhart authored and zhiics committed Apr 17, 2020
1 parent a3c951a commit 8689711
Show file tree
Hide file tree
Showing 10 changed files with 416 additions and 65 deletions.
11 changes: 11 additions & 0 deletions include/tvm/relay/analysis.h
Expand Up @@ -30,6 +30,7 @@
#include <tvm/ir/module.h>
#include <tvm/relay/type.h>
#include <string>
#include <unordered_map>

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -225,6 +226,16 @@ TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr);
*/
TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const IRModule& mod);

/*!
* \brief Get reference counter of each internal ExprNode in body.
*
* \param body The body expression.
*
* \return The reference count mapping.
*/
TVM_DLL std::unordered_map<const Object*, size_t>
GetExprRefCount(const Expr& body);

} // namespace relay
} // namespace tvm

Expand Down
183 changes: 183 additions & 0 deletions include/tvm/relay/expr_functor.h
Expand Up @@ -232,6 +232,189 @@ class ExprMutator
std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual> memo_;
};

/*!
* \brief A wrapper around ExprVisitor which traverses the Dataflow Normal AST.
*
* MixedModeVisitor treats Expr as dataflow graph, and visits in post-DFS order
*
* MixedModeVisitor provides the same recursive API as ExprVisitor, and uses
* recursion to traverse most forms of the IR, but under the hood it expands nested dataflow regions
* of the graph and processes them iteratatively to prevent stack overflows
*/
class MixedModeVisitor : public ::tvm::relay::ExprVisitor {
public:
/*! \brief The constructor of MixedModeVisitor
* \param visit_limit The number of times to allow visitation to a node. Usually 1, ocassionally
* higher (i.e., 2 for dead code elimiation), limited to 10 as a sanity check.
*/
explicit MixedModeVisitor(int visit_limit = 1);

/*!
* \brief VisitExpr is finalized to preserve call expansion of dataflow regions
*/
void VisitExpr(const Expr& expr) final;
void VisitExpr_(const CallNode* op) override;
void VisitExpr_(const TupleNode* op) override;
void VisitExpr_(const TupleGetItemNode* op) override;

protected:
/*!
* \brief A function to apply when reaching a leaf of the graph non-recursively
*/
virtual void VisitLeaf(const Expr& expr);
/*!
* \brief A function to determine if an expression has already been visited or needs to be
* re-visited
*/
virtual bool CheckVisited(const Expr& expr);
/*!
* \brief The max number of times to visit a node
*/
size_t visit_limit_;
};

/*! \brief Non-recursive DFS Graph Traversal for Custom Rewriting Passes
*
* MixedModeMutator treats Expr as dataflow graph, and only Rewrites each Expr once.
* The mutated results are memoized in a map and reused so that
* local transformation on the dataflow preserves the graph structure.
*
* MixedModeMutator provides the same recursive API as ExprMutator, and uses
* recursion to traverse most forms of the IR, but under the hood it expands nested dataflow regions
* of the graph and processes them iteratatively to prevent stack overflows
*
* Uses Rewrite_ API of ExprRewriter for a cleaner split between recrusive and non-recursive behavior.
*/
class MixedModeMutator : public ::tvm::relay::ExprMutator {
public:
Expr VisitExpr(const Expr& expr) final;
virtual Expr DispatchVisitExpr(const Expr& expr);
Expr VisitExpr_(const TupleNode* op) final { return Rewrite(op); };
Expr VisitExpr_(const CallNode* call_node) final { return Rewrite(call_node); };
Expr VisitExpr_(const TupleGetItemNode* op) final { return Rewrite(op); };
/*!
* \brief Users should override Rewrite_ methods to implement their pass. Rewrite_ functions will be
* able to rewrite the op only with data about the original node `pre` and the same node with
* modified inputs `post` and should not recurse.
*
* \param pre The expression node before rewriting.
* \param post The expression with rewritten inputs.
*/
virtual Expr Rewrite_(const TupleNode* pre, const Expr& post) { return post;}
virtual Expr Rewrite_(const CallNode* pre, const Expr& post) { return post; }
virtual Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) { return post; }

protected:
/*! \brief Implement Rewrite API by calling ExprMutator's VisitExpr_(op) to get a `post` node with
* changed inputs.
*/
template <typename T>
Expr Rewrite(const T* op) {
Expr post = ExprMutator::VisitExpr_(op);
return Rewrite_(op, post);
}

virtual void VisitLeaf(const Expr& expr);
virtual bool CheckVisited(const Expr& expr);
};

#define RELAY_EXPR_REWRITER_DISPATCH(OP) \
vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, const Expr& post) { \
return self->Rewrite_(static_cast<const OP*>(n.get()), post); \
});

#define EXPR_REWRITER_REWRITE_DEFAULT \
{ return post; }

/*! \brief A non-iterating Expression Rewriter
*
* ExprRewriter provides a Rewrite interface for modifying graphs in Post-DFS order.
*
* The expectation is that ExprRewriter objects will be passed to PostOrderRewrite, which will
* non-recursively unroll the graph and call Rewriting on inputs. It will then pass the original
* node, called `pre`, and a node recreated with any alterned inputs, called `post`, to the
* ExprRewriter. The ExprRewriter can then use the information in those two nodes to do more complex
* graph rewriting.
*/
class ExprRewriter {
private:
using TSelf = ExprRewriter;
using FType = tvm::NodeFunctor<Expr(const ObjectRef& n, TSelf* self, const Expr& post)>;

public:
/*! \brief virtual destructor */
virtual ~ExprRewriter() {}
/*!
* \brief Same as call.
* \param pre The expression node before rewriting.
* \param post The expression node with rewritten inputs.
* \return The result of the call
*/
Expr operator()(const Expr& pre, const Expr& post) {
return Rewrite(pre, post);
}
/*!
* \brief The functor call.
* \param pre The expression node before rewriting.
* \param post The expression node with rewritten inputs.
* \return The result of the call
*/
virtual Expr Rewrite(const Expr& pre, const Expr& post) {
CHECK(pre.defined());
static FType vtable = InitVTable();
return vtable(pre, this, post);
}
// Functions that can be overriden by subclass, should not recurse
virtual Expr Rewrite_(const VarNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const GlobalVarNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const ConstantNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const TupleNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const FunctionNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const CallNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const LetNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const IfNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const OpNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const TupleGetItemNode* pre,
const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const RefCreateNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const RefReadNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const RefWriteNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const ConstructorNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const MatchNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;

private:
// initialize the vtable.
static FType InitVTable() {
FType vtable;
// Set dispatch
RELAY_EXPR_REWRITER_DISPATCH(ConstantNode);
RELAY_EXPR_REWRITER_DISPATCH(TupleNode);
RELAY_EXPR_REWRITER_DISPATCH(VarNode);
RELAY_EXPR_REWRITER_DISPATCH(GlobalVarNode);
RELAY_EXPR_REWRITER_DISPATCH(FunctionNode);
RELAY_EXPR_REWRITER_DISPATCH(CallNode);
RELAY_EXPR_REWRITER_DISPATCH(LetNode);
RELAY_EXPR_REWRITER_DISPATCH(IfNode);
RELAY_EXPR_REWRITER_DISPATCH(OpNode);
RELAY_EXPR_REWRITER_DISPATCH(TupleGetItemNode);
RELAY_EXPR_REWRITER_DISPATCH(RefCreateNode);
RELAY_EXPR_REWRITER_DISPATCH(RefReadNode);
RELAY_EXPR_REWRITER_DISPATCH(RefWriteNode);
RELAY_EXPR_REWRITER_DISPATCH(ConstructorNode);
RELAY_EXPR_REWRITER_DISPATCH(MatchNode);
return vtable;
}
};

/*! \brief Non-recursive DFS Graph Traversal for Custom Rewriting Passes
*
* PostOrderRewrite does a non-recursive traversal of the graph in Post-DFS order and calls the
* ExprRewriter's Rewrite functions on nodes once their inputs are rewritten. At each rewrite call,
* PostOrderRewrite provides the original node and the node with altered inputs for use by the
* ExprRewriter.
*/
Expr PostOrderRewrite(const Expr& expr, ExprRewriter* rewriter);

/*!
* \brief recursively visit the ir in post DFS order node, apply fvisit
* Each node is guaranteed to be visited only once.
Expand Down
2 changes: 1 addition & 1 deletion src/relay/analysis/util.cc
Expand Up @@ -330,7 +330,7 @@ TVM_REGISTER_GLOBAL("relay.analysis.all_type_vars")
*/
std::unordered_map<const Object*, size_t>
GetExprRefCount(const Expr& body) {
class ExprRefCounter : private ExprVisitor {
class ExprRefCounter : private MixedModeVisitor {
public:
std::unordered_map<const Object*, size_t>
Get(const Expr& body) {
Expand Down

0 comments on commit 8689711

Please sign in to comment.