Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
[RELAY] IR Wellform Checker (#1748)
  • Loading branch information
MarisaKirisame authored and tqchen committed Sep 22, 2018
1 parent e22ac6b commit 7beafdd
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 8 deletions.
12 changes: 5 additions & 7 deletions include/tvm/relay/error.h
Expand Up @@ -12,21 +12,19 @@
namespace tvm {
namespace relay {

struct Error : dmlc::Error {
struct Error : public dmlc::Error {
explicit Error(const std::string &msg) : dmlc::Error(msg) {}
};

struct InternalError : Error {
struct InternalError : public Error {
explicit InternalError(const std::string &msg) : Error(msg) {}
};

// TODO(@jroesch): we should change spanned errors to report
// errors against the Environment, inverting control to error definition.
struct FatalTypeError : dmlc::Error {
explicit FatalTypeError(const std::string &s) : dmlc::Error(s) {}
struct FatalTypeError : public Error {
explicit FatalTypeError(const std::string &s) : Error(s) {}
};

struct TypecheckerError : public dmlc::Error {
struct TypecheckerError : public Error {
explicit TypecheckerError(const std::string &msg) : Error(msg) {}
};

Expand Down
1 change: 1 addition & 0 deletions include/tvm/relay/expr.h
Expand Up @@ -8,6 +8,7 @@

#include <tvm/attrs.h>
#include <string>
#include <functional>
#include "./base.h"
#include "./type.h"

Expand Down
12 changes: 12 additions & 0 deletions include/tvm/relay/pass.h
Expand Up @@ -80,6 +80,18 @@ bool AlphaEqual(const Expr& e1, const Expr& e2);
*/
bool AlphaEqual(const Type& t1, const Type& t2);

/*! brief Check that each Var is only bind once.
*
* For example, the expression `let x = 1 in let x = 2 in 3` bound x twice.
*
* `let f = (\x -> x) in let g = (\x -> x + 1) in f(g(2))` also bound x twice, although x is not shadowed.
*
* \param e the expression to check.
*
* \return true iff all Var in e is bind at most once.
*/
bool WellFormed(const Expr & e);

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_H_
1 change: 1 addition & 0 deletions python/tvm/relay/_ir_pass.pyi
Expand Up @@ -4,3 +4,4 @@ from . import ir
def check_expr(env: Environment, expr: ir.Expr) -> ir.Type: ...
def generalize(env: Environment, expr: ir.Expr) -> ir.Expr: ...
def _get_checked_type(expr: ir.Expr) -> ir.Type: ...
def well_formed(expr: ir.Expr) -> bool: ...
2 changes: 2 additions & 0 deletions python/tvm/relay/ir_pass.py
Expand Up @@ -10,3 +10,5 @@
# Expose checking expression, should rename to infer_type.
# pylint: disable=invalid-name
check_expr = _ir_pass.check_expr

well_formed = _ir_pass.well_formed
61 changes: 61 additions & 0 deletions src/relay/pass/well_formed.cc
@@ -0,0 +1,61 @@
/*!
* Copyright (c) 2018 by Contributors
* \file well_formed.cc
* \brief check that expression is well formed.
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <unordered_set>

namespace tvm {
namespace relay {

struct NotWellFormed { };

//! brief make sure each Var is bind at most once.
class WellFormedChecker : private ExprVisitor {
bool well_formed = true;

std::unordered_set<Var, NodeHash, NodeEqual> s;

void Check(const Var & v) {
if (s.count(v) != 0) {
well_formed = false;
}
s.insert(v);
}

void VisitExpr_(const LetNode * l) final {
// we do letrec only for FunctionNode,
// but shadowing let in let binding is likely programming error, and we should forbidden it.
Check(l->var);
CheckWellFormed(l->value);
CheckWellFormed(l->body);
}

void VisitExpr_(const FunctionNode * f) final {
for (const Param & p : f->params) {
Check(p->var);
}
CheckWellFormed(f->body);
}

public:
bool CheckWellFormed(const Expr & e) {
this->VisitExpr(e);
return well_formed;
}
};

bool WellFormed(const Expr & e) {
return WellFormedChecker().CheckWellFormed(e);
}

TVM_REGISTER_API("relay._ir_pass.well_formed")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Expr e = args[0];
*ret = WellFormed(e);
});

} // namespace relay
} // namespace tvm
1 change: 0 additions & 1 deletion tests/python/relay/test_relay_op.py
Expand Up @@ -24,4 +24,3 @@ def test_op_level1():
if __name__ == "__main__":
test_op_attr()
test_op_level1()

18 changes: 18 additions & 0 deletions tests/python/relay/test_well_formed.py
@@ -0,0 +1,18 @@
import tvm
from tvm import relay
from tvm.relay.ir_pass import well_formed

def test_well_formed():
x = relay.Var("x")
assert well_formed(x)
v = relay.Constant(tvm.nd.array(10))
ty = None
let = relay.Let(x, v, x, ty)
assert well_formed(let)
assert not well_formed(relay.Let(x, v, let, ty))
f = relay.Function([relay.Param(x, ty)], ty, x)
assert well_formed(f)
# this test should pass in case of weak uniqueness (only test for shadowing)
# but we want all binder to be distinct from each other.
assert not well_formed(relay.Let(relay.Var("y"), f,
relay.Let(relay.Var("z"), f, v, ty), ty))

0 comments on commit 7beafdd

Please sign in to comment.