From a0316821f549311858b6e69b794f46c95a516710 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 22 Jan 2024 15:40:46 +0100 Subject: [PATCH] bn rename some variables --- e3nn_jax/_src/batchnorm/bn.py | 40 ++++++++++++++++------------------- 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/e3nn_jax/_src/batchnorm/bn.py b/e3nn_jax/_src/batchnorm/bn.py index d1a569c..f4aa9e4 100644 --- a/e3nn_jax/_src/batchnorm/bn.py +++ b/e3nn_jax/_src/batchnorm/bn.py @@ -57,62 +57,58 @@ def _roll_avg(curr, update): # [batch, sample, mul, repr] if ir.is_scalar(): # scalars if is_instance: - field_mean = chunk.mean(1).reshape(batch, mul) # [batch, mul] + mean = chunk.mean(1).reshape(batch, mul) # [batch, mul] else: if mask is None: - field_mean = chunk.mean([0, 1]).reshape(mul) # [mul] + mean = chunk.mean([0, 1]).reshape(mul) # [mul] else: - field_mean = (chunk.mean(1).squeeze(2) * mask[:, None]).sum( + mean = (chunk.mean(1).squeeze(2) * mask[:, None]).sum( 0 ) / mask.sum() # [mul] - new_ra_means.append( - _roll_avg(ra_mean[i_rmu : i_rmu + mul], field_mean) - ) + new_ra_means.append(_roll_avg(ra_mean[i_rmu : i_rmu + mul], mean)) if use_running_average: - field_mean = ra_mean[i_rmu : i_rmu + mul] + mean = ra_mean[i_rmu : i_rmu + mul] i_rmu += mul # [batch, sample, mul, repr] - chunk = chunk - field_mean.reshape(-1, 1, mul, 1) + chunk = chunk - mean.reshape(-1, 1, mul, 1) if normalization == "norm": - field_norm = jnp.square(chunk).sum(3) # [batch, sample, mul] + norm_squared = jnp.square(chunk).sum(3) # [batch, sample, mul] elif normalization == "component": - field_norm = jnp.square(chunk).mean(3) # [batch, sample, mul] + norm_squared = jnp.square(chunk).mean(3) # [batch, sample, mul] else: raise ValueError( "Invalid normalization option {}".format(normalization) ) if reduce == "mean": - field_norm = field_norm.mean(1) # [batch, mul] + norm_squared = norm_squared.mean(1) # [batch, mul] elif reduce == "max": - field_norm = field_norm.max(1) # [batch, mul] + norm_squared = norm_squared.max(1) # [batch, mul] else: raise ValueError("Invalid reduce option {}".format(reduce)) if not is_instance: if mask is None: - field_norm = field_norm.mean(0) # [mul] + norm_squared = norm_squared.mean(0) # [mul] else: - field_norm = (field_norm * mask[:, None]).sum(0) / mask.sum() - new_ra_vars.append(_roll_avg(ra_var[i_wei : i_wei + mul], field_norm)) + norm_squared = (norm_squared * mask[:, None]).sum(0) / mask.sum() + new_ra_vars.append(_roll_avg(ra_var[i_wei : i_wei + mul], norm_squared)) if use_running_average: - field_norm = ra_var[i_wei : i_wei + mul] + norm_squared = ra_var[i_wei : i_wei + mul] - field_norm = jax.lax.rsqrt( - (1 - epsilon) * field_norm + epsilon + inverse = jax.lax.rsqrt( + (1 - epsilon) * norm_squared + epsilon ) # [(batch,) mul] if use_affine: sub_weight = weight[i_wei : i_wei + mul] # [mul] - field_norm = field_norm * sub_weight # [(batch,) mul] + inverse = inverse * sub_weight # [(batch,) mul] - # TODO add test case for when mul == 0 - field_norm = field_norm[..., None, :, None] # [(batch,) 1, mul, 1] - chunk = chunk * field_norm # [batch, sample, mul, repr] + chunk = chunk * inverse[..., None, :, None] # [batch, sample, mul, repr] if use_affine and ir.is_scalar(): # scalars sub_bias = bias[i_bia : i_bia + mul] # [mul]