In [73]:
import jax
from typing import Any
from jax import numpy as jnp
from flax import linen as nn
import optax

This shows BatchNorm that needs to maintain `batch_stats` variables. See `Flax basics.ipynb` for the case w/o variables.

# Data Prep

In [74]:
x_key, noise_key, model_key = jax.random.split(jax.random.PRNGKey(0), 3)
xs = jax.random.normal(x_key, (100, 1))
noise = jax.random.normal(noise_key, (100, 1))
W, b = 2, -1
ys = xs + noise + b

# Layer with a batch norm

In [75]:
class MultiLayerDnnWithBatchNorm(nn.Module):
    
    @nn.compact
    def __call__(self, x, train: bool):
        x = nn.Dense(512)(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.relu(x)
        x = nn.Dense(256)(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        return x
    

In [76]:
from flax.training import train_state

class MyTrainState(train_state.TrainState):
    batch_stats: Any

In [77]:
@jax.jit
def train_step(state: MyTrainState, xs, ys):
    
    def loss_fn(params):
        # When there're variables, their updates is returned by model.apply.
        yhats, updates = state.apply_fn(
            {
                'params': params, 
                'batch_stats': state.batch_stats # Variables are specified.
            },
            xs, 
            train=True, # For batch norm train vs eval
            mutable=['batch_stats'] # For variables updates
        )
        loss = jnp.mean((ys - yhats) ** 2)
        return loss, updates
    
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, updates), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    # Required to update the batch statistics.
    state = state.replace(batch_stats=updates['batch_stats'])
    metrics = {'loss': loss}
    return state, metrics

@jax.jit
def eval_step(state: MyTrainState, xs, ys):
    yhats, _ = state.apply_fn({'params': state.params}, xs, train=False)
    loss = jnp.mean((ys - yhats) ** 2)
    metrics = {'loss': loss}
    return metrics        

In [78]:
model = MultiLayerDnnWithBatchNorm()
variables = model.init(model_key, xs, train=False)
params, batch_stats = variables['params'], variables['batch_stats']

state = MyTrainState.create(
    apply_fn=model.apply,
    params=params,
    batch_stats=batch_stats,
    tx=optax.adam(1e-3),
)

for i in range(1001):
    state, metrics = train_step(state, xs, ys)
    if i % 100 == 0:
        print(f'Loss @ {i}: {metrics["loss"]}')

Loss @ 0: 2.3187077045440674
Loss @ 100: 1.0786985158920288
Loss @ 200: 1.0493273735046387
Loss @ 300: 1.0176451206207275
Loss @ 400: 0.9789206981658936
Loss @ 500: 0.9367572069168091
Loss @ 600: 0.9019906520843506
Loss @ 700: 0.8809128403663635
Loss @ 800: 0.8683042526245117
Loss @ 900: 0.8628705143928528
Loss @ 1000: 0.8435438871383667


Inspect params

In [79]:
import json

def pretty_dict(d):
  return json.dumps(d, indent=2)

print(jax.tree_util.tree_map(jnp.shape, state.batch_stats))
print(pretty_dict(jax.tree_util.tree_map(lambda x: x.tolist()[:3], state.batch_stats)))

{'BatchNorm_0': {'mean': (512,), 'var': (512,)}, 'BatchNorm_1': {'mean': (256,), 'var': (256,)}}
{
  "BatchNorm_0": {
    "mean": [
      -0.02687761001288891,
      0.03383558616042137,
      0.017856178805232048
    ],
    "var": [
      1.5422340631484985,
      2.033207893371582,
      0.8356245160102844
    ]
  },
  "BatchNorm_1": {
    "mean": [
      -1.114841341972351,
      2.4079277515411377,
      3.353090286254883
    ],
    "var": [
      0.9122587442398071,
      11.84886646270752,
      5.084463119506836
    ]
  }
}
