diff --git a/haiku/_src/base.py b/haiku/_src/base.py
index 2e253263b..b03e4b408 100644
--- a/haiku/_src/base.py
+++ b/haiku/_src/base.py
@@ -131,6 +131,10 @@ def module(self, module_state: ModuleState):
 current_frame = frame_stack.peek
 
 
+def inside_transform():
+  return bool(frame_stack)
+
+
 def safe_get_module_name(module) -> Text:
   # TODO(tomhennigan) Module specific code should be part of `module.py`.
   if not hasattr(module, "module_name"):
diff --git a/haiku/_src/stateful.py b/haiku/_src/stateful.py
index 245f90f86..132393e17 100644
--- a/haiku/_src/stateful.py
+++ b/haiku/_src/stateful.py
@@ -174,6 +174,9 @@ def value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False):
     integers, the gradient is a tuple of values with the same shapes and types
     as the corresponding arguments.
   """
+  if not base.inside_transform():
+    raise ValueError("hk.grad() should not be used outside of hk.transform(). "
+                     "Use jax.grad() instead.")
 
   @functools.wraps(fun)
   def stateful_fun(*args, **kwargs):
@@ -204,6 +207,11 @@ def thread_hk_state_in_kwargs(dec_fun):
 
   def wrapped_dec_fun(fun, *dec_args, **dec_kwargs):
     """Decorates a modified version of `fun` that passes haiku state."""
+    if not base.inside_transform():
+      raise ValueError(
+          "hk.{0}() should not be used outside of hk.transform(). "
+          "Use jax.{0}() instead.".format(dec_fun.__name__))
+
     @functools.wraps(fun)
     def stateful_fun(*args, **kwargs):
       with temporary_internal_state(kwargs.pop("hk_state")):
@@ -240,6 +248,9 @@ def new_branch_fun(operand):
 
 def cond(pred, true_operand, true_fun, false_operand, false_fun):
   """Equivalent to `jax.lax.cond` but with Haiku state threaded in and out."""
+  if not base.inside_transform():
+    raise ValueError("hk.cond() should not be used outside of hk.transform(). "
+                     "Use jax.lax.cond() instead.")
   state = internal_state()
   out, state = jax.lax.cond(pred,
                             true_operand=(state, true_operand),
diff --git a/haiku/_src/stateful_test.py b/haiku/_src/stateful_test.py
index 11cf79331..655f4c6e4 100644
--- a/haiku/_src/stateful_test.py
+++ b/haiku/_src/stateful_test.py
@@ -34,6 +34,11 @@ def test_grad(self):
     g = stateful.grad(SquareModule())(x)
     np.testing.assert_allclose(g, 2 * x, rtol=1e-4)
 
+  def test_grad_no_transform(self):
+    x = jnp.array(3.)
+    with self.assertRaises(ValueError, msg="Use jax.grad() instead"):
+      stateful.grad(lambda x: x**2)(x)
+
   @test_utils.transform_and_run
   def test_value_and_grad(self):
     x = jnp.array(2.)
@@ -41,6 +46,11 @@ def test_value_and_grad(self):
     self.assertEqual(y, x ** 2)
     np.testing.assert_allclose(g, 2 * x, rtol=1e-4)
 
+  def test_value_and_grad_no_transform(self):
+    x = jnp.array(3.)
+    with self.assertRaises(ValueError, msg="Use jax.grad() instead"):
+      stateful.value_and_grad(lambda x: x**2)(x)
+
   @test_utils.transform_and_run
   def test_grad_aux(self):
     o = object()
@@ -98,6 +108,11 @@ def test_jit(self):
     y = stateful.jit(mod)(x)
     self.assertEqual(y, x ** 2)
 
+  def test_jit_no_transform(self):
+    x = jnp.array(2)
+    with self.assertRaises(ValueError, msg="Use jax.jit() instead"):
+      stateful.jit(lambda x: x**2)(x)
+
   @test_utils.transform_and_run
   def test_remat(self):
     forward, backward = [], []
@@ -133,6 +148,11 @@ def test(remat):
     self.assertGreater(num_forward_remat, num_forward_no_remat)
     self.assertEqual(num_backward_remat, num_backward_no_remat)
 
+  def test_remat_no_transform(self):
+    x = jnp.array(3.)
+    with self.assertRaises(ValueError, msg="Use jax.remat() instead"):
+      stateful.remat(lambda x: x**2)(x)
+
   def test_cond(self):
     def f(x):
       mod = SquareModule()
@@ -145,6 +165,11 @@ def f(x):
     self.assertEqual(state, {"square_module": {"y": y}})
     self.assertEqual(out, y)
 
+  def test_cond_no_transform(self):
+    x = jnp.array(3.)
+    with self.assertRaises(ValueError, msg="Use jax.lax.cond() instead"):
+      stateful.cond(x == 2, x, lambda x: x**2, x, lambda x: (x + 1)**2)
+
 
 def _callback_prim(forward, backward):
   def f_impl(x):
@@ -183,5 +208,6 @@ def __call__(self, x):
     base.set_state("y", y)
     return y
 
+
 if __name__ == "__main__":
   absltest.main()
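For context, a minimal sketch of the behavior the new guards enforce (not part of the patch; assumes `haiku` is importable as `hk` and JAX is installed):

import haiku as hk
import jax
import jax.numpy as jnp

x = jnp.array(3.)

# Outside hk.transform there is no Haiku frame on the frame stack, so the
# stateful wrappers now fail fast with an actionable error message instead
# of misbehaving later.
try:
  hk.grad(lambda v: v ** 2)(x)
except ValueError as e:
  print(e)  # hk.grad() should not be used outside of hk.transform(). ...

# As the error suggests, the plain JAX equivalent is the right tool here:
print(jax.grad(lambda v: v ** 2)(x))  # 6.0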