Skip to content

Commit

Permalink
[linen] More minor cleanup in normalization compute_stats.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 549515492
  • Loading branch information
chr1sj0nes authored and Flax Authors committed Jul 20, 2023
1 parent aa7b14e commit b4591c1
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions flax/linen/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,24 +88,26 @@ def _compute_stats(
dtype = jnp.promote_types(dtype, jnp.float32)
x = jnp.asarray(x, dtype)

def pmean(x):
def mean(x, axes=axes):
mu = x.mean(axes)
if axis_name is None:
return x
return lax.pmean(x, axis_name, axis_index_groups=axis_index_groups)
return mu
return lax.pmean(mu, axis_name, axis_index_groups=axis_index_groups)

if use_mean:
if use_fast_variance:
mean, mean2 = pmean(jnp.stack([x.mean(axes), _abs_sq(x).mean(axes)]))
axes = _canonicalize_axes(x.ndim, axes)
mu, mu2 = mean(jnp.stack([x, _abs_sq(x)]), axes=[a + 1 for a in axes])
# mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
# to floating point round-off errors.
var = jnp.maximum(0.0, mean2 - _abs_sq(mean))
var = jnp.maximum(0.0, mu2 - _abs_sq(mu))
else:
mean = pmean(x.mean(axes))
var = pmean(_abs_sq(x - jnp.expand_dims(mean, axes)).mean(axes))
mu = mean(x)
var = mean(_abs_sq(x - jnp.expand_dims(mu, axes)))
else:
var = pmean(_abs_sq(x).mean(axes))
mean = jnp.zeros_like(var)
return mean, var
var = mean(_abs_sq(x))
mu = jnp.zeros_like(var)
return mu, var


def _normalize(mdl: Module, x: Array, mean: Array, var: Array,
Expand Down

0 comments on commit b4591c1

Please sign in to comment.