# How and Why BatchNorm Works - Part 2


In Part 1, we showed how batch normalization can help train neural networks by reducing internal covariate shift and smoothing gradient updates. We also touched on the idea that it makes the loss surface smoother and more convex (residual connections have a similar effect).

In their 2018 paper, *How Does Batch Normalization Help Optimization?*, Santurkar et al. argue that batch normalization does not substantially reduce internal covariate shift at all. In fact, when they inject additional noise into the activations to deliberately create covariate shift, batch normalization still improves training performance, even though covariate shift is not meaningfully reduced. They argue that the main reason for the improvement is instead a smoother loss surface.
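To make the noise-injection idea concrete, here is a minimal JAX sketch of a "noisy" BatchNorm layer. It normalizes each channel as usual and then adds independent noise whose per-channel mean and standard deviation are re-sampled on every call, so the layer's output distribution keeps drifting from one training step to the next. The function name `noisy_batchnorm` and the noise ranges (`mean_scale`, `min_std`, `max_std`) are our own assumptions for illustration, not the exact settings used by Santurkar et al.

```python
import jax
import jax.numpy as jnp


def noisy_batchnorm(x, key, eps=1e-5, mean_scale=1.0, min_std=0.5, max_std=2.0):
    """Batch-normalize channels-last x, then inject drifting noise.

    Hypothetical sketch: the noise distribution below is an assumption,
    not the one from the paper.
    """
    # Standard normalization over the batch (and spatial) axes, per channel.
    axes = tuple(range(x.ndim - 1))
    x_hat = (x - x.mean(axis=axes)) / jnp.sqrt(x.var(axis=axes) + eps)

    # Sample a fresh per-channel noise distribution on every call (i.e. every
    # training step), so the statistics seen by the next layer keep shifting.
    k_mean, k_std, k_noise = jax.random.split(key, 3)
    channels = x.shape[-1]
    noise_mean = mean_scale * jax.random.normal(k_mean, (channels,))
    noise_std = jax.random.uniform(k_std, (channels,), minval=min_std, maxval=max_std)

    # Independent noise for every activation: non-zero mean, non-unit variance.
    noise = noise_mean + noise_std * jax.random.normal(k_noise, x.shape)
    return x_hat + noise
```

In a training loop you would pass a different key at every step (for example `jax.random.fold_in(key, step)`) so that the injected shift actually changes over time.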

However, in our own experiments with a shallow neural network, we saw a substantial reduction in covariate shift. Let's get to the bottom of this discrepancy.

In this experiment, we will start by reproducing the results of Santurkar et al. and then try to understand why their results differ from those of our previous experiment.
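Before reproducing their results, it helps to pin down what "covariate shift" means in their paper. As we understand it, Santurkar et al. measure internal covariate shift at layer i as the change in that layer's gradient caused by updating only the layers before it: they compare the gradient G computed with all weights at step t against G′, computed after the preceding layers have taken their step-t update, and report both ‖G − G′‖₂ and the cosine similarity between G and G′. The sketch below applies that definition to a hypothetical two-layer toy network (`loss`, `ics` and the plain SGD update are our assumptions) rather than to the VGG model we use later.

```python
import jax
import jax.numpy as jnp


def loss(params, x, y):
    # Hypothetical two-layer toy network standing in for a deeper model.
    h = jnp.tanh(x @ params["w1"])
    pred = h @ params["w2"]
    return jnp.mean((pred - y) ** 2)


def ics(params, x, y, lr=0.1):
    """Internal covariate shift of layer 2 for one SGD step.

    G  : gradient w.r.t. w2 with every weight at step t.
    G' : gradient w.r.t. w2 after only the earlier layer (w1) has been
         updated, while w2 itself stays at step t.
    """
    grads = jax.grad(loss)(params, x, y)
    g = grads["w2"]

    # Apply the step-t SGD update to the preceding layer only.
    shifted = {"w1": params["w1"] - lr * grads["w1"], "w2": params["w2"]}
    g_prime = jax.grad(loss)(shifted, x, y)["w2"]

    l2 = jnp.linalg.norm(g - g_prime)
    cos = jnp.sum(g * g_prime) / (jnp.linalg.norm(g) * jnp.linalg.norm(g_prime))
    return l2, cos
```

With `lr=0.0` nothing moves, so the ℓ2 distance is zero and the cosine similarity is one; a large distance or a low cosine similarity means the layer is effectively chasing a moving target.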

## Let's start as usual by setting up the dataset
We will use the CIFAR-10 dataset for this experiment.




```python
from deepkit import load_CIFAR10

train_loader, test_loader = load_CIFAR10()

# Pull a single batch from the training set to inspect.
batch, labels = next(iter(train_loader))
```
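`nnx.Conv` works on channels-last arrays of shape (N, H, W, C), and the fully connected head of our model assumes 32×32 inputs. We haven't shown how `load_CIFAR10` lays out its batches, so the hypothetical helper `to_nhwc` below is a defensive sketch: it assumes batches arrive either channels-last or PyTorch-style channels-first (N, 3, 32, 32), transposes the latter, and rejects anything else.

```python
import numpy as np


def to_nhwc(images):
    """Return a CIFAR-10 batch in channels-last (N, 32, 32, 3) layout.

    Assumption: the loader yields either (N, 32, 32, 3) or (N, 3, 32, 32).
    """
    images = np.asarray(images)
    if images.ndim != 4:
        raise ValueError(f"expected a 4-D batch, got shape {images.shape}")
    if images.shape[1:] == (3, 32, 32):
        # Channels-first -> channels-last, the layout nnx.Conv expects.
        images = images.transpose(0, 2, 3, 1)
    if images.shape[1:] != (32, 32, 3):
        raise ValueError(f"unexpected CIFAR-10 batch shape {images.shape}")
    return images
```

Usage would be a single line, `batch = to_nhwc(batch)`, before converting the batch with `jnp.array`.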

## Next, let's set up the model
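We'll use a VGG-style network, as in the paper: the VGG-11 layout of eight 3×3 convolutions and five max-pooling layers, with BatchNorm after every convolution and a three-layer fully connected head sized for 32×32 CIFAR-10 images.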
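A quick way to check the wiring is to trace the feature-map shape through the convolutional stack: each 'SAME' 3×3 convolution keeps the spatial size and each 2×2 pool with stride 2 halves it, so a 32×32 image ends up as a 1×1×512 map. That is why the `fc1` layer below takes exactly 512 inputs and why `max_pool_after` is `[0, 1, 3, 5, 7]`. The `VGG11` list and the two helpers are illustrative only and are not part of the model code.

```python
# VGG-11 ("configuration A") layout: numbers are conv output channels,
# "M" marks a 2x2 max-pool with stride 2.
VGG11 = [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"]


def trace_shapes(config, height=32, width=32, channels=3):
    """Track (H, W, C) through the stack; 'SAME' 3x3 convs keep H and W."""
    shapes = []
    for item in config:
        if item == "M":
            height, width = height // 2, width // 2
        else:
            channels = item
        shapes.append((height, width, channels))
    return shapes


def pool_after_indices(config):
    """Indices of the conv layers that are immediately followed by a pool."""
    indices, conv_index = [], -1
    for item in config:
        if item == "M":
            indices.append(conv_index)
        else:
            conv_index += 1
    return indices


print(trace_shapes(VGG11)[-1])    # (1, 1, 512)
print(pool_after_indices(VGG11))  # [0, 1, 3, 5, 7]
```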






```python
import jax
import jax.numpy as jnp
import flax.nnx as nnx


class VGGBlock(nnx.Module):
    """3x3 'SAME' convolution -> BatchNorm -> ReLU."""

    def __init__(self, in_features: int, out_features: int, rngs: nnx.Rngs):
        self.conv = nnx.Conv(in_features=in_features,
                             out_features=out_features,
                             kernel_size=(3, 3),
                             padding='SAME',
                             rngs=rngs)
        self.bn = nnx.BatchNorm(num_features=out_features, use_running_average=False,
                                rngs=rngs)

    def __call__(self, x, training: bool):
        x = self.conv(x)
        # Training: normalize with batch statistics and update the running averages.
        # Evaluation: normalize with the stored running averages instead.
        x = self.bn(x, use_running_average=not training)
        x = nnx.relu(x)
        return x


class VGGNet(nnx.Module):
    """VGG-11 layout with BatchNorm, sized for 32x32x3 CIFAR-10 images."""

    def __init__(self, rngs: nnx.Rngs):
        self.convs = [
            VGGBlock(in_features=3, out_features=64, rngs=rngs),
            VGGBlock(in_features=64, out_features=128, rngs=rngs),
            VGGBlock(in_features=128, out_features=256, rngs=rngs),
            VGGBlock(in_features=256, out_features=256, rngs=rngs),
            VGGBlock(in_features=256, out_features=512, rngs=rngs),
            VGGBlock(in_features=512, out_features=512, rngs=rngs),
            VGGBlock(in_features=512, out_features=512, rngs=rngs),
            VGGBlock(in_features=512, out_features=512, rngs=rngs),
        ]

        # Five 2x2 poolings shrink 32x32 to 1x1, leaving 512 features.
        self.fc1 = nnx.Linear(in_features=512, out_features=512, rngs=rngs)
        self.fc2 = nnx.Linear(in_features=512, out_features=512, rngs=rngs)
        self.out = nnx.Linear(in_features=512, out_features=10, rngs=rngs)

    def __call__(self, x, training: bool = True):
        # Indices of the conv blocks that are followed by a 2x2 max-pool.
        max_pool_after = [0, 1, 3, 5, 7]
        for i, block in enumerate(self.convs):
            x = block(x, training)
            if i in max_pool_after:
                x = nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2))

        # Flatten (N, 1, 1, 512) -> (N, 512), keeping the batch axis even when N == 1.
        x = x.reshape(x.shape[0], -1)
        x = nnx.relu(self.fc1(x))
        x = nnx.relu(self.fc2(x))
        x = self.out(x)
        return x
```
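Each `VGGBlock` switches BatchNorm between two modes through `use_running_average`. The plain `jax.numpy` sketch below spells out what we understand those modes to compute with `nnx.BatchNorm`'s default settings (`momentum=0.99`, `epsilon=1e-5`, statistics over every axis except the channel axis): training mode normalizes with the current batch's mean and variance and updates exponential moving averages, while evaluation mode reuses those averages. The function names are hypothetical, and this is a sketch rather than Flax's actual implementation.

```python
import jax.numpy as jnp


def batchnorm_train(x, scale, bias, ra_mean, ra_var, momentum=0.99, eps=1e-5):
    """Training-mode BatchNorm for channels-last x (sketch)."""
    axes = tuple(range(x.ndim - 1))          # batch and spatial axes
    mean = x.mean(axis=axes)                 # one value per channel
    var = x.var(axis=axes)                   # biased (population) variance
    y = (x - mean) / jnp.sqrt(var + eps) * scale + bias
    # Exponential moving averages, used later when training=False.
    new_ra_mean = momentum * ra_mean + (1 - momentum) * mean
    new_ra_var = momentum * ra_var + (1 - momentum) * var
    return y, new_ra_mean, new_ra_var


def batchnorm_eval(x, scale, bias, ra_mean, ra_var, eps=1e-5):
    """Evaluation-mode BatchNorm: running statistics replace batch statistics."""
    return (x - ra_mean) / jnp.sqrt(ra_var + eps) * scale + bias
```

The difference between the two modes is exactly why the distribution a layer sees can differ between training and evaluation, even for the same input.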

```python
rng_key = jax.random.key(1337)
rngs = nnx.Rngs(rng_key)
model = VGGNet(rngs=rngs)
graphdef, state = nnx.split(model)
# `state` holds every variable: the trainable nnx.Param leaves and the BatchNorm
# running mean/variance (nnx.BatchStat), so this total includes both.
param_counts = sum(jax.tree_util.tree_leaves(jax.tree_util.tree_map(lambda x: x.size, state)))
print(f"Initialized model with {param_counts:,} parameters.")
nnx.display(state)
```

```
Initialized model with 9,761,930 parameters.
```
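The 9,761,930 figure sums every array in `state`, including the BatchNorm running mean and variance, which the optimizer never updates. The pure-Python tally below reproduces that number from the layer sizes and splits it into trainable weights and running statistics (the helper names are ours).

```python
def conv_params(c_in, c_out, k=3):
    return k * k * c_in * c_out + c_out      # 3x3 kernel + bias


def dense_params(n_in, n_out):
    return n_in * n_out + n_out              # weight matrix + bias


channels = [3, 64, 128, 256, 256, 512, 512, 512, 512]
conv = sum(conv_params(a, b) for a, b in zip(channels[:-1], channels[1:]))
bn_affine = sum(2 * c for c in channels[1:])   # BatchNorm scale + bias (trainable)
bn_stats = sum(2 * c for c in channels[1:])    # running mean + variance (not trained)
dense = 2 * dense_params(512, 512) + dense_params(512, 10)

trainable = conv + bn_affine + dense
print(f"{trainable:,} trainable + {bn_stats:,} running statistics = {trainable + bn_stats:,}")
# 9,756,426 trainable + 5,504 running statistics = 9,761,930
```

To count only the trainable part directly in NNX, filtering the state with `nnx.state(model, nnx.Param)` before summing should do it.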


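```python
import optax

# Plain SGD with momentum 0.9 and a fixed learning rate of 0.01.
tx = optax.sgd(learning_rate=0.01, momentum=0.9)
# Older Flax NNX optimizer API. Newer Flax releases require
# nnx.Optimizer(model, tx, wrt=nnx.Param) and optimizer.update(model, grads).
optimizer = nnx.Optimizer(model, tx)
```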
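`optax.sgd` with `momentum=0.9` keeps a running "velocity" for every parameter. In optax's default (non-Nesterov) form, each step computes v ← g + 0.9·v and then p ← p − 0.01·v. The scalar sketch below (a hypothetical `sgd_momentum_step`, not optax code) makes the update concrete for one parameter with a constant gradient of 0.5.

```python
def sgd_momentum_step(param, grad, velocity, lr=0.01, momentum=0.9):
    """One step of SGD with (non-Nesterov) momentum for a scalar parameter."""
    velocity = grad + momentum * velocity   # decaying sum of past gradients
    param = param - lr * velocity           # move against the accumulated direction
    return param, velocity


p, v = 1.0, 0.0
for step in range(2):
    p, v = sgd_momentum_step(p, grad=0.5, velocity=v)
    print(step, round(p, 6), round(v, 6))
# 0 0.995 0.5
# 1 0.9855 0.95
```

With a steady gradient the velocity grows toward g / (1 − 0.9) = 10·g, which is why momentum speeds up progress along consistent directions.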


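```python
def loss_fn(model, batch, targets):
    # model(batch) runs in training mode (training=True by default), so every
    # BatchNorm layer normalizes with the current batch statistics.
    logits = model(batch)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean()
    return loss


def step_fn(model: nnx.Module, optimizer: nnx.Optimizer, batch: jax.Array, labels: jax.Array):
    # Gradients are taken with respect to the model's nnx.Param variables; the
    # BatchNorm running statistics are updated during the forward pass.
    loss, grads = nnx.value_and_grad(loss_fn)(model, batch, labels)
    optimizer.update(grads)
    return loss, grads
```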

```python
num_steps = 50
step = 0
for batch, labels in train_loader:
    batch = jnp.array(batch)
    labels = jnp.array(labels)
    loss, grads = step_fn(model, optimizer, batch, labels)
    print(loss)
    step += 1
    if step == num_steps:
        break
```

```
1.5034597
1.5275065
1.5373363
1.4794407
1.6233981
1.5572228
1.725085
1.5889261
1.3287871
1.364134
1.5321941
1.4493389
1.6147181
1.5473435
1.6925219
1.4746581
1.5144938
1.5387815
1.510521
1.3089714
1.4482768
1.427937
1.3845593
1.7322769
1.3451676
1.4809359
1.3829606
1.5216036
1.500274
1.3754203
1.4822083
1.3936832
1.4991524
1.6139977
1.3755976
1.5631541
1.704683
1.3772638
1.3870444
1.4636002
1.4720532
1.3482751
1.3566463
1.4308008
1.4814514
1.3342916
1.4501083
1.501696
1.3991804
1.3116863
```
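`optax.softmax_cross_entropy_with_integer_labels` computes, for each example, logsumexp(z) − z_y, where z are the logits and y is the true class. The sketch below re-implements that formula (the helper name is ours) and evaluates it for a classifier that assigns equal scores to all ten classes, which gives ln 10 ≈ 2.30. A freshly initialized network usually starts near that value, whereas the losses above begin around 1.5; that suggests, though we haven't verified it, that the model had already been trained by earlier runs of the training cell, so it is worth re-creating the model and optimizer before comparing runs.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def cross_entropy_int_labels(logits, labels):
    """Mean softmax cross-entropy with integer class labels: logsumexp(z) - z[y]."""
    log_normalizer = logsumexp(logits, axis=-1)
    true_logit = jnp.take_along_axis(logits, labels[:, None], axis=-1)[:, 0]
    return jnp.mean(log_normalizer - true_logit)


# A classifier with no preference: all ten logits are equal for every example.
uniform_logits = jnp.zeros((4, 10))
labels = jnp.array([0, 3, 7, 9])
print(cross_entropy_int_labels(uniform_logits, labels))  # ≈ 2.3026, i.e. ln(10)
```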
