diff --git a/haiku/_src/recurrent.py b/haiku/_src/recurrent.py index 413025a93..9adccdf63 100644 --- a/haiku/_src/recurrent.py +++ b/haiku/_src/recurrent.py @@ -37,6 +37,7 @@ hk.get_parameter = base.get_parameter hk.Module = module.Module hk.scan = stateful.scan +inside_transform = base.inside_transform del base, basic, conv, initializers, module @@ -154,11 +155,12 @@ def dynamic_unroll(core, input_sequence, initial_state): of shape ``[T, ...]``. * **final_state** - Core state at time step ``T``. """ + scan = hk.scan if inside_transform() else jax.lax.scan # Swap the input and output of core. def scan_f(prev_state, inputs): outputs, next_state = core(inputs, prev_state) return next_state, outputs - final_state, output_sequence = hk.scan( + final_state, output_sequence = scan( scan_f, initial_state, input_sequence) diff --git a/haiku/_src/recurrent_test.py b/haiku/_src/recurrent_test.py index c7f2ef817..cbaaccc05 100644 --- a/haiku/_src/recurrent_test.py +++ b/haiku/_src/recurrent_test.py @@ -96,6 +96,14 @@ def test_core_unroll_nested(self, unroll): for out in outs: self.assertEqual(out.shape, (4, 8, 4)) + @parameterized.parameters(recurrent.dynamic_unroll, recurrent.static_unroll) + def test_unroll_outside_transform(self, unroll): + core = lambda x, s: (x + 1, s + 1) + seqs = jnp.arange(8) + outs, state = unroll(core, seqs, 0) + np.testing.assert_allclose(outs, jnp.arange(9)[1:]) + np.testing.assert_allclose(state, 8) + class LSTMTest(absltest.TestCase): diff --git a/haiku/_src/stateful.py b/haiku/_src/stateful.py index 69e3d320f..e2ace4a26 100644 --- a/haiku/_src/stateful.py +++ b/haiku/_src/stateful.py @@ -386,6 +386,10 @@ def cond(pred, true_operand, true_fun, false_operand, false_fun): def scan(f, init, xs, length=None, reverse=False): """Equivalent to `jax.lax.scan` but with Haiku state threaded in and out.""" + if not base.inside_transform(): + raise ValueError("hk.scan() should not be used outside of hk.transform(). 
" + "Use jax.lax.scan() instead.") + if length is None: length = jax.tree_leaves(xs)[0].shape[0] diff --git a/haiku/_src/stateful_test.py b/haiku/_src/stateful_test.py index dfcbf66f0..0ccf3378c 100644 --- a/haiku/_src/stateful_test.py +++ b/haiku/_src/stateful_test.py @@ -218,6 +218,11 @@ def test_difference_rng(self): self.assertEmpty(diff.state) self.assertIsNotNone(diff.rng) + def test_scan_no_transform(self): + xs = jnp.arange(3) + with self.assertRaises(ValueError, msg="Use jax.lax.scan() instead"): + stateful.scan(lambda c, x: (c, x), (), xs) + @parameterized.parameters(0, 1, 2, 4, 8) def test_scan_with_state(self, unroll_length): def f(xs):