bn rename some variables
mariogeiger committed Jan 22, 2024
1 parent 74b110b commit a031682
Showing 1 changed file with 18 additions and 22 deletions.
e3nn_jax/_src/batchnorm/bn.py (18 additions, 22 deletions)

@@ -57,62 +57,58 @@ def _roll_avg(curr, update):
         # [batch, sample, mul, repr]
         if ir.is_scalar():  # scalars
             if is_instance:
-                field_mean = chunk.mean(1).reshape(batch, mul)  # [batch, mul]
+                mean = chunk.mean(1).reshape(batch, mul)  # [batch, mul]
             else:
                 if mask is None:
-                    field_mean = chunk.mean([0, 1]).reshape(mul)  # [mul]
+                    mean = chunk.mean([0, 1]).reshape(mul)  # [mul]
                 else:
-                    field_mean = (chunk.mean(1).squeeze(2) * mask[:, None]).sum(
+                    mean = (chunk.mean(1).squeeze(2) * mask[:, None]).sum(
                         0
                     ) / mask.sum()  # [mul]
-                new_ra_means.append(
-                    _roll_avg(ra_mean[i_rmu : i_rmu + mul], field_mean)
-                )
+                new_ra_means.append(_roll_avg(ra_mean[i_rmu : i_rmu + mul], mean))

                 if use_running_average:
-                    field_mean = ra_mean[i_rmu : i_rmu + mul]
+                    mean = ra_mean[i_rmu : i_rmu + mul]
                 i_rmu += mul

             # [batch, sample, mul, repr]
-            chunk = chunk - field_mean.reshape(-1, 1, mul, 1)
+            chunk = chunk - mean.reshape(-1, 1, mul, 1)

         if normalization == "norm":
-            field_norm = jnp.square(chunk).sum(3)  # [batch, sample, mul]
+            norm_squared = jnp.square(chunk).sum(3)  # [batch, sample, mul]
         elif normalization == "component":
-            field_norm = jnp.square(chunk).mean(3)  # [batch, sample, mul]
+            norm_squared = jnp.square(chunk).mean(3)  # [batch, sample, mul]
         else:
             raise ValueError(
                 "Invalid normalization option {}".format(normalization)
             )

         if reduce == "mean":
-            field_norm = field_norm.mean(1)  # [batch, mul]
+            norm_squared = norm_squared.mean(1)  # [batch, mul]
         elif reduce == "max":
-            field_norm = field_norm.max(1)  # [batch, mul]
+            norm_squared = norm_squared.max(1)  # [batch, mul]
         else:
             raise ValueError("Invalid reduce option {}".format(reduce))

         if not is_instance:
             if mask is None:
-                field_norm = field_norm.mean(0)  # [mul]
+                norm_squared = norm_squared.mean(0)  # [mul]
             else:
-                field_norm = (field_norm * mask[:, None]).sum(0) / mask.sum()
-            new_ra_vars.append(_roll_avg(ra_var[i_wei : i_wei + mul], field_norm))
+                norm_squared = (norm_squared * mask[:, None]).sum(0) / mask.sum()
+            new_ra_vars.append(_roll_avg(ra_var[i_wei : i_wei + mul], norm_squared))

             if use_running_average:
-                field_norm = ra_var[i_wei : i_wei + mul]
+                norm_squared = ra_var[i_wei : i_wei + mul]

-        field_norm = jax.lax.rsqrt(
-            (1 - epsilon) * field_norm + epsilon
+        inverse = jax.lax.rsqrt(
+            (1 - epsilon) * norm_squared + epsilon
         )  # [(batch,) mul]

         if use_affine:
             sub_weight = weight[i_wei : i_wei + mul]  # [mul]
-            field_norm = field_norm * sub_weight  # [(batch,) mul]
+            inverse = inverse * sub_weight  # [(batch,) mul]

         # TODO add test case for when mul == 0
-        field_norm = field_norm[..., None, :, None]  # [(batch,) 1, mul, 1]
-        chunk = chunk * field_norm  # [batch, sample, mul, repr]
+        chunk = chunk * inverse[..., None, :, None]  # [batch, sample, mul, repr]

         if use_affine and ir.is_scalar():  # scalars
             sub_bias = bias[i_bia : i_bia + mul]  # [mul]
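For readers skimming the rename: the hunk above computes, per irrep chunk, a (possibly running) mean, a squared norm, and the inverse square root used to rescale the chunk. Below is a minimal standalone sketch of that computation for a single scalar chunk, assuming "component" normalization, "mean" reduction, batch (non-instance) mode, and no mask; the function name `normalize_scalar_chunk` and its signature are hypothetical and not e3nn-jax's API.

```python
# Minimal sketch of the renamed quantities (mean, norm_squared, inverse);
# an illustration under the assumptions stated above, not the library code.
import jax
import jax.numpy as jnp


def normalize_scalar_chunk(chunk, weight, epsilon=1e-5):
    # chunk: [batch, sample, mul, repr], with repr == 1 for scalar irreps
    mul = chunk.shape[2]
    mean = chunk.mean([0, 1]).reshape(mul)  # [mul]
    chunk = chunk - mean.reshape(-1, 1, mul, 1)  # center the scalars
    norm_squared = jnp.square(chunk).mean(3)  # "component": [batch, sample, mul]
    norm_squared = norm_squared.mean(1).mean(0)  # reduce samples, then batch: [mul]
    inverse = jax.lax.rsqrt((1 - epsilon) * norm_squared + epsilon)  # [mul]
    inverse = inverse * weight  # affine scale, [mul]
    return chunk * inverse[..., None, :, None]  # [batch, sample, mul, repr]


x = jnp.ones((4, 10, 3, 1))  # batch=4, sample=10, mul=3, scalar repr
y = normalize_scalar_chunk(x, jnp.ones(3))
print(y.shape)  # (4, 10, 3, 1)
```

Note the `(1 - epsilon) * norm_squared + epsilon` form: it interpolates the variance toward 1, so the rsqrt stays finite even for constant (zero-variance) input.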