Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions flax/core/lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def _split_in_out_axes(xs: Mapping[CollectionFilter, Any]):
def vmap(fn: Callable[..., Any],
variable_axes: Mapping[CollectionFilter, InOutAxis],
split_rngs: Mapping[PRNGSequenceFilter, bool],
in_axes=0, out_axes=0, axis_size=None) -> Callable[..., Any]:
in_axes=0, out_axes=0, axis_size=None, axis_name=None) -> Callable[..., Any]:
"""A lifted version of `jax.vmap`.

See `jax.vmap` for the unlifted batch transform in Jax.
Expand All @@ -333,8 +333,11 @@ def vmap(fn: Callable[..., Any],
of the batch dimension. Unsplit PRNGs will be broadcasted.
in_axes: Specifies the mapping of the input arguments (see `jax.vmap).
out_axes: Specifies the mapping of the return value (see `jax.vmap).
axes_size: Specifies the size of the batch axis. This only needs
axis_size: Specifies the size of the batch axis. This only needs
to be specified if it cannot be derived from the input arguments.
axis_name: Specifies a name for the batch axis. Can be used together
with parallel reduction primitives (e.g. `jax.lax.pmean`,
`jax.lax.ppermute`, etc.)
"""
variable_in_axes, variable_out_axes = _split_in_out_axes(variable_axes)
variable_in_groups, variable_in_axes = _unzip2(variable_in_axes.items())
Expand Down Expand Up @@ -369,7 +372,8 @@ def find_axis_size(axis, x):

@functools.partial(jax.vmap,
in_axes=(variable_in_axes, rng_axes, in_axes),
out_axes=(out_axes, variable_out_axes))
out_axes=(out_axes, variable_out_axes),
axis_name=axis_name)
@functools.wraps(fn)
def mapped(variable_groups, rng_groups, args):
scope = scope_fn(variable_groups, rng_groups)
Expand Down
28 changes: 28 additions & 0 deletions tests/linen/linen_transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,34 @@ def vmap(fn):
y2 = vmap_model.apply(init_variables, x2)
np.testing.assert_allclose(y1, y2, atol=1e-7)

def test_vmap_batchnorm(self):
key1, key2 = random.split(random.PRNGKey(3), 2)
x = random.uniform(key1, (4, 4))
x2 = random.uniform(key1, (5, 4, 4))

def vmap(cls):
return nn.vmap(cls,
in_axes=(0,),
variable_axes={'params': None, 'batch_stats': None},
split_rngs={'params': False},
axis_name='batch')
class MlpBn(nn.Module):
axis_name: Any = None

@nn.compact
def __call__(self, x):
x = nn.Dense(3)(x)
x = nn.BatchNorm(axis_name=self.axis_name, use_running_average=False)(x)
return x

normal_model = MlpBn()
vmap_model = vmap(MlpBn)(axis_name='batch')
init_variables = normal_model.init(key2, x)
y1 = normal_model.apply(init_variables, x2.reshape((-1, 4)), mutable=['batch_stats'])[0]
y1 = y1.reshape((5, 4, 3))
y2 = vmap_model.apply(init_variables, x2, mutable=['batch_stats'])[0]
np.testing.assert_allclose(y1, y2, atol=1e-6)

def test_scan(self):
class SimpleScan(nn.Module):
@nn.compact
Expand Down