diff --git a/haiku/_src/stateful.py b/haiku/_src/stateful.py index 6e5e842e6..5990fdeb8 100644 --- a/haiku/_src/stateful.py +++ b/haiku/_src/stateful.py @@ -15,6 +15,7 @@ """Wrappers for JAX transformations that respect Haiku internal state.""" import collections +import collections.abc import functools import inspect from typing import Any, Callable, Mapping, MutableMapping, Optional, Tuple, TypeVar @@ -734,6 +735,8 @@ def wrapper(*args, **kwargs): wrapper.require_split_rng = True return wrapper +list_to_tuple = lambda x: tuple(x) if isinstance(x, list) else x + @add_split_rng_error def vmap( @@ -793,7 +796,7 @@ def vmap( params_axes = state_axes = None rng_axes = (0 if split_rng else None) haiku_state_axes = InternalState(params_axes, state_axes, rng_axes) - in_axes = in_axes, haiku_state_axes + in_axes = list_to_tuple(in_axes), haiku_state_axes out_axes = out_axes, haiku_state_axes @functools.wraps(fun) diff --git a/haiku/_src/stateful_test.py b/haiku/_src/stateful_test.py index f43a17857..661d07623 100644 --- a/haiku/_src/stateful_test.py +++ b/haiku/_src/stateful_test.py @@ -591,6 +591,12 @@ def test_vmap_in_axes_different_size(self): ValueError, "vmap got inconsistent sizes for array axes to be mapped"): stateful.vmap(lambda a, b: None, in_axes=(0, 1), split_rng=False)(x, x) + @test_utils.transform_and_run + def test_vmap_in_axes_supports_list(self): + a = jnp.ones([4]) + b = stateful.vmap(lambda a: a * 2, in_axes=[0], split_rng=False)(a) + np.testing.assert_array_equal(b, a * 2) + @test_utils.transform_and_run def test_vmap_no_split_rng(self): key_before = base.next_rng_key()