<a href="https://colab.research.google.com/github/probml/pyprobml/blob/master/notebooks/layer_norm_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import math
import collections

try:
    import jax
except:
    %pip install jax jaxlib
    import jax
import jax.numpy as jnp

try:
    import matplotlib.pyplot as plt
except:
    %pip install matplotlib
    import matplotlib.pyplot as plt


try:
    import flax.linen as nn
except:
    %pip install flax
    import flax.linen as nn

key = jax.random.PRNGKey(seed=1)



In [2]:
# Toy input of shape (2, 3): batch size 2 (rows), feature size 3 (columns).
# It is an integer array; jnp.mean / jnp.std promote the results to float.
x = jnp.array([[1, 2, 3], [4, 5, 6]])

# Batch norm: standardize each feature (column) with the mean and population
# std (ddof=0) computed across the batch, i.e. over axis 0. The (3,)-shaped
# statistics broadcast against x column-wise.
print('batch norm')
mu_batch = jnp.mean(x, axis=0)
sigma_batch = jnp.std(x, axis=0)
x_batch_norm = (x - mu_batch) / sigma_batch
print(x_batch_norm)

# Layer norm: standardize each example (row) with statistics computed across
# its features, i.e. over axis 1. keepdims=True keeps the statistics at shape
# (2, 1) so they broadcast against x row-wise.
# No epsilon is added to the denominator here.
print('layer norm')
mu_layer = jnp.mean(x, axis=1, keepdims=True)
sigma_layer = jnp.std(x, axis=1, keepdims=True)
x_layer_norm = (x - mu_layer) / sigma_layer
print(x_layer_norm)


batch norm
[[-1. -1. -1.]
 [ 1.  1.  1.]]
layer norm
[[-1.2247448  0.         1.2247448]
 [-1.2247448  0.         1.2247448]]


In [3]:
# Cast the toy input to float32 before feeding it through the Flax layers.
x = x.astype(jnp.float32)

# --- Flax BatchNorm -------------------------------------------------------
# use_running_average=False puts the layer in training mode, so it normalizes
# with the statistics of the batch it is given.
batch_norm = nn.BatchNorm(use_running_average=False)

print('batch norm')
params_batch_norm = batch_norm.init(key, x)
# In training mode the layer updates its 'batch_stats' collection, so that
# collection must be declared mutable; apply then returns (output, updates).
# The updated statistics are not needed here and are discarded.
x_batch_norm_train, _ = batch_norm.apply(
    params_batch_norm,
    x,
    mutable=['batch_stats'],
)
print(x_batch_norm_train)
# The Flax result differs from the hand-computed one around the 6th decimal
# (e.g. -0.9999978 vs -1.0 in the printed output), presumably because of an
# epsilon in Flax's denominator, so compare with a loose tolerance.
assert jnp.allclose(x_batch_norm_train, x_batch_norm, atol=1e-3)

# --- Flax LayerNorm -------------------------------------------------------
layer_norm = nn.LayerNorm()

print('layer norm')
params_layer_norm = layer_norm.init(key, x)
x_layer_norm_train = layer_norm.apply(params_layer_norm, x)
print(x_layer_norm_train)
# Same loose tolerance as above: the printed Flax values differ slightly
# from the manual layer-norm result.
assert jnp.allclose(x_layer_norm_train, x_layer_norm, atol=1e-3)

batch norm
[[-0.9999978 -0.9999978 -0.9999978]
 [ 0.9999978  0.9999978  0.9999978]]
layer norm
[[-1.2247437  0.         1.2247437]
 [-1.2247428  0.         1.2247428]]
