Skip to content

Commit

Permalink
Support hk.dynamic_unroll outside of transformed functions.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 318236679
Change-Id: I4ae9c5df495f378201a84dec592c85b54621ac28
  • Loading branch information
tomhennigan authored and Copybara-Service committed Jun 25, 2020
1 parent 5457f4b commit 66f9c69
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 1 deletion.
4 changes: 3 additions & 1 deletion haiku/_src/recurrent.py
Expand Up @@ -37,6 +37,7 @@
hk.get_parameter = base.get_parameter
hk.Module = module.Module
hk.scan = stateful.scan
inside_transform = base.inside_transform
del base, basic, conv, initializers, module


Expand Down Expand Up @@ -154,11 +155,12 @@ def dynamic_unroll(core, input_sequence, initial_state):
of shape ``[T, ...]``.
* **final_state** - Core state at time step ``T``.
"""
scan = hk.scan if inside_transform() else jax.lax.scan
# Swap the input and output of core.
def scan_f(prev_state, inputs):
outputs, next_state = core(inputs, prev_state)
return next_state, outputs
final_state, output_sequence = hk.scan(
final_state, output_sequence = scan(
scan_f,
initial_state,
input_sequence)
Expand Down
8 changes: 8 additions & 0 deletions haiku/_src/recurrent_test.py
Expand Up @@ -96,6 +96,14 @@ def test_core_unroll_nested(self, unroll):
for out in outs:
self.assertEqual(out.shape, (4, 8, 4))

@parameterized.parameters(recurrent.dynamic_unroll, recurrent.static_unroll)
def test_unroll_outside_transform(self, unroll):
core = lambda x, s: (x + 1, s + 1)
seqs = jnp.arange(8)
outs, state = unroll(core, seqs, 0)
np.testing.assert_allclose(outs, jnp.arange(9)[1:])
np.testing.assert_allclose(state, 8)


class LSTMTest(absltest.TestCase):

Expand Down
4 changes: 4 additions & 0 deletions haiku/_src/stateful.py
Expand Up @@ -386,6 +386,10 @@ def cond(pred, true_operand, true_fun, false_operand, false_fun):

def scan(f, init, xs, length=None, reverse=False):
"""Equivalent to `jax.lax.scan` but with Haiku state threaded in and out."""
if not base.inside_transform():
raise ValueError("hk.scan() should not be used outside of hk.transform(). "
"Use jax.scan() instead.")

if length is None:
length = jax.tree_leaves(xs)[0].shape[0]

Expand Down
5 changes: 5 additions & 0 deletions haiku/_src/stateful_test.py
Expand Up @@ -218,6 +218,11 @@ def test_difference_rng(self):
self.assertEmpty(diff.state)
self.assertIsNotNone(diff.rng)

def test_scan_no_transform(self):
xs = jnp.arange(3)
with self.assertRaises(ValueError, msg="Use jax.scan() instead"):
stateful.scan(lambda c, x: (c, x), (), xs)

@parameterized.parameters(0, 1, 2, 4, 8)
def test_scan_with_state(self, unroll_length):
def f(xs):
Expand Down

0 comments on commit 66f9c69

Please sign in to comment.