Skip to content

Commit

Permalink
update weight standardization (#8)
Browse files Browse the repository at this point in the history
* update weight standardization

* update version
  • Loading branch information
chaoming0625 committed Apr 13, 2024
1 parent 81e5ce0 commit 7827b35
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 15 deletions.
2 changes: 1 addition & 1 deletion braintools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ==============================================================================


__version__ = "0.0.2"
__version__ = "0.0.3"

from . import metric
from . import input
Expand Down
37 changes: 23 additions & 14 deletions braintools/functional/_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import jax
import jax.numpy as jnp
import numpy as np
import braincore as bc


__all__ = [
Expand All @@ -28,35 +29,43 @@

def weight_standardization(
    w: jax.typing.ArrayLike,
    eps: float = 1e-4,
    gain: Optional[jax.Array] = None,
    out_axis: int = -1,
):
    """
    Scaled Weight Standardization,
    see `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization <https://paperswithcode.com/paper/weight-standardization>`_.

    Every slice of ``w`` along ``out_axis`` is shifted to zero mean and
    rescaled by ``1 / sqrt(max(var * fan_in, eps))``, where the mean and
    variance are taken over all axes except ``out_axis`` and ``fan_in`` is
    the product of the sizes of those axes.

    Parameters
    ----------
    w : jax.typing.ArrayLike
        The weight tensor. It is converted with ``jnp.asarray`` so nested
        lists and NumPy arrays are accepted as well.
    eps : float
        Lower bound applied to ``var * fan_in`` before the reciprocal
        square root, to avoid division by zero. By default 1e-4.
    gain : Array
        Optional multiplicative gain applied to the scale; it must
        broadcast against ``w``. By default None.
    out_axis : int
        The output axis, by default -1. Negative values count from the
        last axis.

    Returns
    -------
    jax.Array
        The scaled weight tensor, with the same shape as ``w``.

    Raises
    ------
    ValueError
        If ``out_axis`` is not a valid axis of ``w``.
    """
    w = jnp.asarray(w)
    # Reject invalid axes explicitly: otherwise no axis would be excluded
    # below and the statistics would silently be taken over the whole tensor.
    if not -w.ndim <= out_axis < w.ndim:
        raise ValueError(
            f"out_axis={out_axis} is out of range for a weight with "
            f"{w.ndim} dimension(s)."
        )
    out_axis %= w.ndim

    # All axes except the output axis contribute to the statistics.
    axes = tuple(i for i in range(w.ndim) if i != out_axis)
    # Shapes are static, so fan-in is a plain Python int (keeps the dtype
    # of ``var`` unchanged when multiplied).
    fan_in = int(np.prod([w.shape[i] for i in axes]))

    # normalize the weight
    mean = jnp.mean(w, axis=axes, keepdims=True)
    var = jnp.var(w, axis=axes, keepdims=True)
    scale = jax.lax.rsqrt(jnp.maximum(var * fan_in, eps))
    if gain is not None:
        scale = gain * scale
    shift = mean * scale
    return w * scale - shift

0 comments on commit 7827b35

Please sign in to comment.