Skip to content

Commit

Permalink
Support vmap where in_axes is a list rather than a tuple.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 479522483
  • Loading branch information
tomhennigan authored and Copybara-Service committed Oct 7, 2022
1 parent 96e513a commit 307cf7d
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
5 changes: 4 additions & 1 deletion haiku/_src/stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Wrappers for JAX transformations that respect Haiku internal state."""

import collections
import collections.abc
import functools
import inspect
from typing import Any, Callable, Mapping, MutableMapping, Optional, Tuple, TypeVar
Expand Down Expand Up @@ -734,6 +735,8 @@ def wrapper(*args, **kwargs):
wrapper.require_split_rng = True
return wrapper

list_to_tuple = lambda x: tuple(x) if isinstance(x, list) else x


@add_split_rng_error
def vmap(
Expand Down Expand Up @@ -793,7 +796,7 @@ def vmap(
params_axes = state_axes = None
rng_axes = (0 if split_rng else None)
haiku_state_axes = InternalState(params_axes, state_axes, rng_axes)
in_axes = in_axes, haiku_state_axes
in_axes = list_to_tuple(in_axes), haiku_state_axes
out_axes = out_axes, haiku_state_axes

@functools.wraps(fun)
Expand Down
6 changes: 6 additions & 0 deletions haiku/_src/stateful_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,12 @@ def test_vmap_in_axes_different_size(self):
ValueError, "vmap got inconsistent sizes for array axes to be mapped"):
stateful.vmap(lambda a, b: None, in_axes=(0, 1), split_rng=False)(x, x)

@test_utils.transform_and_run
def test_vmap_in_axes_supports_list(self):
a = jnp.ones([4])
b = stateful.vmap(lambda a: a * 2, in_axes=[0], split_rng=False)(a)
np.testing.assert_array_equal(b, a * 2)

@test_utils.transform_and_run
def test_vmap_no_split_rng(self):
key_before = base.next_rng_key()
Expand Down

0 comments on commit 307cf7d

Please sign in to comment.