Skip to content

Commit

Permalink
Add flag to enable checking, and turn on checking in tests. (#2900)
Browse files Browse the repository at this point in the history
Fix an error in check_jaxpr.
  • Loading branch information
gnecula committed May 1, 2020
1 parent e06bde8 commit 2e9047d
Show file tree
Hide file tree
Showing 8 changed files with 45 additions and 10 deletions.
9 changes: 9 additions & 0 deletions jax/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,13 @@ def __getattr__(self, name):

config = Config()
flags = config
FLAGS = flags.FLAGS

already_configured_with_absl = False

flags.DEFINE_bool(
'jax_enable_checks',
bool_env('JAX_ENABLE_CHECKS', False),
help=
'Turn on invariant checking (core.skip_checks = False)'
)
23 changes: 19 additions & 4 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,27 @@
import numpy as onp

from . import dtypes
from .config import FLAGS
from . import linear_util as lu

from .util import safe_zip, safe_map, partial, curry, prod, partialmethod
from .pprint_util import pp, vcat, hcat, pp_kv_pairs, PrettyPrint

# TODO(dougalm): the trace cache breaks the leak detector. Consisder solving.
check_leaks = False
# TODO(dougalm): put this behind a flag that's enabled during testing
skip_checks = True # not __debug__ # google doesn't use -O

"""Disables internal invariant checks."""
skip_checks = not FLAGS.jax_enable_checks # not __debug__ # google doesn't use -O

@contextmanager
def skipping_checks():
"""Context manager for temporarily disabling checks."""
global skip_checks
old_value, skip_checks = skip_checks, True
try:
yield
finally:
skip_checks = old_value

zip = safe_zip
map = safe_map
Expand Down Expand Up @@ -653,7 +665,10 @@ class Bot(AbstractValue): pass
bot = Bot()

class AbstractUnit(AbstractValue):
def join(self, other): return self
def join(self, other):
if not skip_checks:
assert other is abstract_unit, other
return self
def _eq(self, self_traced, other): return get_aval(other) is self

abstract_unit = AbstractUnit()
Expand Down Expand Up @@ -1048,7 +1063,7 @@ def write_env(env: Set[Var], v: Var):
map(write, jaxpr.constvars)
map(write, jaxpr.invars)
for eqn in jaxpr.eqns:
if eqn.primitive.call_primitive or eqn.map_primitive:
if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
if "call_jaxpr" not in eqn.params:
raise Exception("Call primitive {} should have a 'call_jaxpr' parameter"
.format(eqn.primitive))
Expand Down
3 changes: 2 additions & 1 deletion jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .. import core
from .. import linear_util as lu
from ..abstract_arrays import ShapedArray, ConcreteArray, raise_to_shaped
from ..ad_util import zero
from ..util import (unzip2, safe_zip, safe_map, toposort, partial, split_list,
wrap_name, cache)
from ..core import (Trace, Tracer, new_master, Jaxpr, Literal, get_aval,
Expand All @@ -49,7 +50,7 @@ def __new__(cls, xs: Tuple[Optional[AbstractValue], core.Value]):
if not core.skip_checks:
# type checks
assert isinstance(pv, (AbstractValue, type(None))), xs
assert isinstance(const, core.Tracer) or core.valid_jaxtype(const), xs
assert isinstance(const, core.Tracer) or const is zero or core.valid_jaxtype(const), xs
# invariant checks
if isinstance(pv, AbstractValue):
assert const == core.unit, xs
Expand Down
3 changes: 3 additions & 0 deletions jax/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,9 @@ class JaxTestCase(parameterized.TestCase):
# def tearDown(self) -> None:
# assert core.reset_trace_state()

def setUp(self):
core.skip_checks = False

def assertArraysAllClose(self, x, y, check_dtypes, atol=None, rtol=None):
"""Assert that x and y are close (up to numerical tolerances)."""
self.assertEqual(x.shape, y.shape)
Expand Down
3 changes: 2 additions & 1 deletion tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3075,7 +3075,8 @@ def f_vjp(x, y):
api.grad(lambda x, y: f(x, y)[0])(1., 2.) # doesn't crash

def test_custom_transforms_vjp_nones(self):
# issue rasied by jsnoek@ and jumper@
core.skip_checks = True # Fails with checks
# issue raised by jsnoek@ and jumper@
@jax.custom_transforms
def solve(a, b):
return np.dot(np.linalg.inv(a), b)
Expand Down
5 changes: 4 additions & 1 deletion tests/dtypes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import numpy as onp

import jax
from jax import core
from jax import dtypes
from jax import numpy as np
from jax import test_util as jtu
Expand Down Expand Up @@ -167,7 +168,9 @@ class AnEnum(enum.IntEnum):
A = 42
B = 101
onp.testing.assert_equal(onp.array(42), onp.array(AnEnum.A))
onp.testing.assert_equal(np.array(42), np.array(AnEnum.A))
with core.skipping_checks():
# Passing AnEnum.A to np.array fails the type check in bind
onp.testing.assert_equal(np.array(42), np.array(AnEnum.A))
onp.testing.assert_equal(onp.int32(101), onp.int32(AnEnum.B))
onp.testing.assert_equal(np.int32(101), np.int32(AnEnum.B))

Expand Down
2 changes: 1 addition & 1 deletion tests/lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2643,7 +2643,7 @@ def f2(x, y):
expected = onp.array(0.0)
self.assertAllClose(ans, expected, check_dtypes=False)

with self.assertRaises(TypeError):
with self.assertRaises(TypeError if core.skip_checks else AssertionError):
lax.stop_gradient(lambda x: x)

# TODO(mattjj): make this a more systematic test
Expand Down
7 changes: 5 additions & 2 deletions tests/nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import numpy as onp

from jax import core
from jax import test_util as jtu
from jax.test_util import check_grads
from jax import nn
Expand Down Expand Up @@ -92,12 +93,14 @@ def testDtypeMatchesInput(self, dtype, fn):
@jtu.skip_on_devices("gpu", "tpu")
def testEluMemory(self):
# see https://github.com/google/jax/pull/1640
jax.make_jaxpr(nn.elu)(np.ones((10 ** 12,))) # don't oom
with core.skipping_checks(): # With checks we materialize the array
jax.make_jaxpr(nn.elu)(np.ones((10 ** 12,))) # don't oom

@jtu.skip_on_devices("gpu", "tpu")
def testHardTanhMemory(self):
# see https://github.com/google/jax/pull/1640
jax.make_jaxpr(nn.hard_tanh)(np.ones((10 ** 12,))) # don't oom
with core.skipping_checks(): # With checks we materialize the array
jax.make_jaxpr(nn.hard_tanh)(np.ones((10 ** 12,))) # don't oom

def testOneHot(self):
actual = nn.one_hot(np.array([0, 1, 2]), 3)
Expand Down

0 comments on commit 2e9047d

Please sign in to comment.