From 0d2251aa3005c4e452659f76d0ba61eefeddea63 Mon Sep 17 00:00:00 2001
From: Jonathan Heek
Date: Tue, 16 Feb 2021 13:16:41 +0000
Subject: [PATCH] Add axis_name arg to lifted vmap

---
 flax/core/lift.py                    | 10 +++++++---
 tests/linen/linen_transforms_test.py | 28 ++++++++++++++++++++++++++++
 2 files changed, 35 insertions(+), 3 deletions(-)

diff --git a/flax/core/lift.py b/flax/core/lift.py
index 5b361c44e..de17597cd 100644
--- a/flax/core/lift.py
+++ b/flax/core/lift.py
@@ -310,7 +310,7 @@ def _split_in_out_axes(xs: Mapping[CollectionFilter, Any]):
 def vmap(fn: Callable[..., Any],
          variable_axes: Mapping[CollectionFilter, InOutAxis],
          split_rngs: Mapping[PRNGSequenceFilter, bool],
-         in_axes=0, out_axes=0, axis_size=None) -> Callable[..., Any]:
+         in_axes=0, out_axes=0, axis_size=None, axis_name=None) -> Callable[..., Any]:
   """A lifted version of `jax.vmap`.
 
   See `jax.vmap` for the unlifted batch transform in Jax.
@@ -333,8 +333,11 @@ def vmap(fn: Callable[..., Any],
       of the batch dimension. Unsplit PRNGs will be broadcasted.
     in_axes: Specifies the mapping of the input arguments (see `jax.vmap).
     out_axes: Specifies the mapping of the return value (see `jax.vmap).
-    axes_size: Specifies the size of the batch axis. This only needs
+    axis_size: Specifies the size of the batch axis. This only needs
       to be specified if it cannot be derived from the input arguments.
+    axis_name: Specifies a name for the batch axis. Can be used together
+      with parallel reduction primitives (e.g. `jax.lax.pmean`,
+      `jax.lax.ppermute`, etc.)
   """
   variable_in_axes, variable_out_axes = _split_in_out_axes(variable_axes)
   variable_in_groups, variable_in_axes = _unzip2(variable_in_axes.items())
@@ -369,7 +372,8 @@ def find_axis_size(axis, x):
 
   @functools.partial(jax.vmap,
                      in_axes=(variable_in_axes, rng_axes, in_axes),
-                     out_axes=(out_axes, variable_out_axes))
+                     out_axes=(out_axes, variable_out_axes),
+                     axis_name=axis_name)
   @functools.wraps(fn)
   def mapped(variable_groups, rng_groups, args):
     scope = scope_fn(variable_groups, rng_groups)
diff --git a/tests/linen/linen_transforms_test.py b/tests/linen/linen_transforms_test.py
index 6a1d1af0c..07c4e91f0 100644
--- a/tests/linen/linen_transforms_test.py
+++ b/tests/linen/linen_transforms_test.py
@@ -154,6 +154,34 @@ def vmap(fn):
     y2 = vmap_model.apply(init_variables, x2)
     np.testing.assert_allclose(y1, y2, atol=1e-7)
 
+  def test_vmap_batchnorm(self):
+    key1, key2 = random.split(random.PRNGKey(3), 2)
+    x = random.uniform(key1, (4, 4))
+    x2 = random.uniform(key1, (5, 4, 4))
+
+    def vmap(cls):
+      return nn.vmap(cls,
+                     in_axes=(0,),
+                     variable_axes={'params': None, 'batch_stats': None},
+                     split_rngs={'params': False},
+                     axis_name='batch')
+    class MlpBn(nn.Module):
+      axis_name: Any = None
+
+      @nn.compact
+      def __call__(self, x):
+        x = nn.Dense(3)(x)
+        x = nn.BatchNorm(axis_name=self.axis_name, use_running_average=False)(x)
+        return x
+
+    normal_model = MlpBn()
+    vmap_model = vmap(MlpBn)(axis_name='batch')
+    init_variables = normal_model.init(key2, x)
+    y1 = normal_model.apply(init_variables, x2.reshape((-1, 4)), mutable=['batch_stats'])[0]
+    y1 = y1.reshape((5, 4, 3))
+    y2 = vmap_model.apply(init_variables, x2, mutable=['batch_stats'])[0]
+    np.testing.assert_allclose(y1, y2, atol=1e-6)
+
   def test_scan(self):
     class SimpleScan(nn.Module):
       @nn.compact
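
Note: the patch forwards `axis_name` verbatim to `jax.vmap`, so it behaves
exactly like JAX's own named-axis mechanism: collectives such as
`jax.lax.pmean` reduce across all instances of the mapped axis. A minimal
sketch at the plain-JAX level (illustrative only, not part of the patch):

    import jax
    import jax.numpy as jnp

    def center(x):
      # pmean averages over the axis named 'batch', i.e. across all
      # vmapped instances; each instance then subtracts the global mean.
      return x - jax.lax.pmean(x, axis_name='batch')

    y = jax.vmap(center, axis_name='batch')(jnp.arange(4.))
    # y == [-1.5, -0.5, 0.5, 1.5]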
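
At the Linen level, the new test shows the intended use: passing `axis_name`
through `nn.vmap` lets `nn.BatchNorm` compute batch statistics across the
mapped axis rather than per slice. A usage sketch mirroring that test (the
`BatchedMlpBn` name, shapes, and PRNG keys are illustrative, not from the
patch):

    from typing import Any
    import flax.linen as nn
    from jax import random

    class MlpBn(nn.Module):
      axis_name: Any = None

      @nn.compact
      def __call__(self, x):
        x = nn.Dense(3)(x)
        # With axis_name set, BatchNorm's mean/variance are reduced over
        # the named (vmapped) axis in addition to the local batch axis.
        x = nn.BatchNorm(axis_name=self.axis_name,
                         use_running_average=False)(x)
        return x

    BatchedMlpBn = nn.vmap(
        MlpBn,
        in_axes=(0,),
        variable_axes={'params': None, 'batch_stats': None},  # broadcast, not mapped
        split_rngs={'params': False},  # one shared init RNG for all slices
        axis_name='batch')

    x = random.uniform(random.PRNGKey(0), (5, 4, 4))  # (mapped, batch, features)
    model = BatchedMlpBn(axis_name='batch')
    variables = model.init(random.PRNGKey(1), x)
    y, updates = model.apply(variables, x, mutable=['batch_stats'])

Mapping both 'params' and 'batch_stats' to axis None broadcasts the variables
across the mapped dimension, which is what makes this equivalent to running
the unbatched module on the flattened input, the property the test asserts.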