Skip to content

Commit

Permalink
Merge pull request #17 from ibab:no_transform
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 299432214
Change-Id: Ie01909e0c646a390b7a10ee76941eae79cc0fdd6
  • Loading branch information
Copybara-Service committed Mar 6, 2020
2 parents 8a81e5d + 47f78a6 commit ad067f2
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 0 deletions.
4 changes: 4 additions & 0 deletions haiku/_src/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ def module(self, module_state: ModuleState):
current_frame = frame_stack.peek


def inside_transform():
return bool(frame_stack)


def safe_get_module_name(module) -> Text:
# TODO(tomhennigan) Module specific code should be part of `module.py`.
if not hasattr(module, "module_name"):
Expand Down
11 changes: 11 additions & 0 deletions haiku/_src/stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ def value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False):
integers, the gradient is a tuple of values with the same shapes and types
as the corresponding arguments.
"""
if not base.inside_transform():
raise ValueError("hk.grad() should not be used outside of hk.transform(). "
"Use jax.grad() instead.")

@functools.wraps(fun)
def stateful_fun(*args, **kwargs):
Expand Down Expand Up @@ -204,6 +207,11 @@ def thread_hk_state_in_kwargs(dec_fun):
def wrapped_dec_fun(fun, *dec_args, **dec_kwargs):
"""Decorates a modified version of `fun` that passes haiku state."""

if not base.inside_transform():
raise ValueError(
"hk.{0}() should not be used outside of hk.transform. "
"Use jax.{0}() instead.".format(dec_fun.__name__))

@functools.wraps(fun)
def stateful_fun(*args, **kwargs):
with temporary_internal_state(kwargs.pop("hk_state")):
Expand Down Expand Up @@ -240,6 +248,9 @@ def new_branch_fun(operand):

def cond(pred, true_operand, true_fun, false_operand, false_fun):
"""Equivalent to `jax.lax.cond` but with Haiku state threaded in and out."""
if not base.inside_transform():
raise ValueError("hk.cond() should not be used outside of hk.transform(). "
"Use jax.cond() instead.")
state = internal_state()
out, state = jax.lax.cond(pred,
true_operand=(state, true_operand),
Expand Down
26 changes: 26 additions & 0 deletions haiku/_src/stateful_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,23 @@ def test_grad(self):
g = stateful.grad(SquareModule())(x)
np.testing.assert_allclose(g, 2 * x, rtol=1e-4)

def test_grad_no_transform(self):
x = jnp.array(3.)
with self.assertRaises(ValueError, msg="Use jax.grad() instead"):
stateful.grad(lambda x: x**2)(x)

@test_utils.transform_and_run
def test_value_and_grad(self):
x = jnp.array(2.)
y, g = stateful.value_and_grad(SquareModule())(x)
self.assertEqual(y, x ** 2)
np.testing.assert_allclose(g, 2 * x, rtol=1e-4)

def test_value_and_grad_no_transform(self):
x = jnp.array(3.)
with self.assertRaises(ValueError, msg="Use jax.grad() instead"):
stateful.value_and_grad(lambda x: x**2)(x)

@test_utils.transform_and_run
def test_grad_aux(self):
o = object()
Expand Down Expand Up @@ -98,6 +108,11 @@ def test_jit(self):
y = stateful.jit(mod)(x)
self.assertEqual(y, x ** 2)

def test_jit_no_transform(self):
x = jnp.array(2)
with self.assertRaises(ValueError, msg="Use jax.jit() instead"):
stateful.jit(lambda x: x**2)(x)

@test_utils.transform_and_run
def test_remat(self):
forward, backward = [], []
Expand Down Expand Up @@ -133,6 +148,11 @@ def test(remat):
self.assertGreater(num_forward_remat, num_forward_no_remat)
self.assertEqual(num_backward_remat, num_backward_no_remat)

def test_remat_no_transform(self):
x = jnp.array(3.)
with self.assertRaises(ValueError, msg="Use jax.remat() instead"):
stateful.remat(lambda x: x**2)(x)

def test_cond(self):
def f(x):
mod = SquareModule()
Expand All @@ -145,6 +165,11 @@ def f(x):
self.assertEqual(state, {"square_module": {"y": y}})
self.assertEqual(out, y)

def test_cond_no_transform(self):
x = jnp.array(3.)
with self.assertRaises(ValueError, msg="Use jax.cond() instead"):
stateful.cond(x == 2, x, lambda x: x**2, x, lambda x: (x + 1)**2)


def _callback_prim(forward, backward):
def f_impl(x):
Expand Down Expand Up @@ -183,5 +208,6 @@ def __call__(self, x):
base.set_state("y", y)
return y


if __name__ == "__main__":
absltest.main()

0 comments on commit ad067f2

Please sign in to comment.