Skip to content

Commit

Permalink
[linen] Only pmean the mean if it is non-zero, in normalization mod…
Browse files Browse the repository at this point in the history
…ules.

PiperOrigin-RevId: 545933541
  • Loading branch information
chr1sj0nes authored and Flax Authors committed Jul 6, 2023
1 parent 1e7a8b1 commit 0734d00
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions flax/linen/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

"""Normalization modules for Flax."""

import functools
from typing import (Any, Callable, Iterable, Optional, Tuple, Union)
from flax.linen.dtypes import canonicalize_dtype

from flax.linen.module import Module, compact, merge_param # pylint: disable=g-multiple-import
from jax import lax
from jax.nn import initializers
Expand Down Expand Up @@ -90,12 +90,13 @@ def _compute_stats(x: Array, axes: Optional[Axes],
mean = jnp.zeros(mean2.shape, dtype=dtype)

if axis_name is not None:
concatenated_mean = jnp.concatenate([mean, mean2])
mean, mean2 = jnp.split(
lax.pmean(
concatenated_mean,
axis_name=axis_name,
axis_index_groups=axis_index_groups), 2)
pmean = functools.partial(
lax.pmean, axis_name=axis_name, axis_index_groups=axis_index_groups
)
if use_mean:
mean, mean2 = jnp.split(pmean(jnp.concatenate([mean, mean2])), 2)
else:
mean2 = pmean(mean2)
# mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
# to floating point round-off errors.
var = jnp.maximum(0., mean2 - _abs_sq(mean))
Expand Down

0 comments on commit 0734d00

Please sign in to comment.