# How and Why BatchNorm Works - Part 2


In Part 1, we showed how batch normalization can help train neural networks by reducing internal covariate shift and smoothing gradient updates. We also touched on the idea that it makes the loss surface smoother and more convex (residual connections have a similar effect).

In their 2018 paper "How Does Batch Normalization Help Optimization?", Santurkar et al. argue that batch normalization does not substantially reduce internal covariate shift. When they deliberately inject noise into the activations after each BatchNorm layer to create covariate shift, batch normalization still improves training performance, even though covariate shift is no longer meaningfully reduced. They conclude that the main reason for the improvement is a smoother loss surface.
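To make the noise-injection experiment concrete, here is a minimal sketch of a "noisy" post-BatchNorm perturbation, written in JAX to match the rest of the notebook. `inject_covariate_shift` is a hypothetical helper, and the uniform shift and scale ranges are our assumptions rather than the paper's exact noise distributions. In a model, it would sit between the BatchNorm layer and the ReLU, with a fresh random key at every training step.

```python
import jax
import jax.numpy as jnp


def inject_covariate_shift(key, x, max_shift=1.0, max_scale=2.0):
    """Apply a random per-channel affine perturbation to NHWC activations.

    Hypothetical helper: the uniform ranges below are assumptions, not the
    exact noise distributions used by Santurkar et al.
    """
    shift_key, scale_key = jax.random.split(key)
    channels = x.shape[-1]
    # Non-zero mean: per-channel offset drawn from [-max_shift, max_shift].
    shift = jax.random.uniform(shift_key, (channels,), minval=-max_shift, maxval=max_shift)
    # Non-unit variance: per-channel scale drawn from [1 / max_scale, max_scale].
    scale = jax.random.uniform(scale_key, (channels,), minval=1.0 / max_scale, maxval=max_scale)
    return x * scale + shift


# A different key at each step makes the injected shift time-varying.
x = jax.random.normal(jax.random.key(0), (8, 4, 4, 16))
noisy_step_1 = inject_covariate_shift(jax.random.key(1), x)
noisy_step_2 = inject_covariate_shift(jax.random.key(2), x)
```

Because the perturbation is applied after normalization, the next layer sees a distribution that keeps moving from step to step, which is exactly the covariate shift BatchNorm was supposed to remove.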

However, in our own experiments with a shallow neural network, we saw a substantial reduction in covariate shift. Let's get to the bottom of this.

In this experiment, we will start by reproducing the results of Santurkar et al. Then we will try to understand why their results differ from those of our previous experiment.
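One thing worth keeping in mind when comparing results: Santurkar et al. quantify internal covariate shift through gradients rather than through activation statistics. For a given layer, they compare its gradient before and after the preceding layers take their update step, and report the l2 distance and cosine similarity between the two. The sketch below is our reading of that definition on a toy two-layer network; `mlp_loss` and `gradient_ics` are illustrative names, and the tiny MLP stands in for a real model.

```python
import jax
import jax.numpy as jnp


def mlp_loss(params, x, y):
    # Toy two-layer network standing in for a deeper model (assumption).
    h = jnp.tanh(x @ params["w1"])
    pred = h @ params["w2"]
    return jnp.mean((pred - y) ** 2)


def gradient_ics(params, x, y, lr=0.1):
    """Gradient-based internal covariate shift for the second layer:
    compare its gradient before and after the first layer takes an SGD step."""
    grads = jax.grad(mlp_loss)(params, x, y)
    g_before = grads["w2"]
    # Update only the preceding layer; the layer being measured stays fixed.
    updated = {"w1": params["w1"] - lr * grads["w1"], "w2": params["w2"]}
    g_after = jax.grad(mlp_loss)(updated, x, y)["w2"]
    l2 = jnp.linalg.norm(g_before - g_after)
    # vdot flattens both matrices, so this is the cosine of the flattened gradients.
    cos = jnp.vdot(g_before, g_after) / (jnp.linalg.norm(g_before) * jnp.linalg.norm(g_after))
    return l2, cos


k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4)
params = {"w1": 0.5 * jax.random.normal(k1, (5, 16)),
          "w2": 0.5 * jax.random.normal(k2, (16, 1))}
x = jax.random.normal(k3, (32, 5))
y = jax.random.normal(k4, (32, 1))
l2, cos = gradient_ics(params, x, y)
print(f"ICS (l2): {float(l2):.4f}, cosine: {float(cos):.4f}")
```

A large l2 distance or a low cosine means the layer's gradient changed a lot just because the layers before it moved, which is what this definition calls covariate shift.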

## Let's start as usual by setting up the dataset
We will use the CIFAR-10 dataset for this experiment.
CIFAR-10 is a widely used benchmark for image classification. It consists of 60,000 color images of 32x32 pixels, divided into 10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck. The images are split into 50,000 training images and 10,000 test images, with 6,000 images per class overall. The dataset was collected by Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton and is named after the Canadian Institute for Advanced Research (CIFAR). Its low resolution and high intra-class variability make it challenging, and it remains a standard dataset for developing and comparing deep learning models, especially convolutional neural networks (CNNs).
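If `load_CIFAR10` returns the standard integer labels 0 to 9, they map to class names in the dataset's canonical order. A small sketch of that mapping and of the split arithmetic; `CIFAR10_CLASSES` and `label_name` are our own names, not part of `deepkit`.

```python
# Canonical CIFAR-10 label order: index -> class name.
CIFAR10_CLASSES = (
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
)
NUM_TRAIN, NUM_TEST = 50_000, 10_000
IMAGE_SHAPE = (32, 32, 3)  # height, width, RGB channels

# The classes are balanced, so each split divides evenly across them.
images_per_class = (NUM_TRAIN + NUM_TEST) // len(CIFAR10_CLASSES)  # 6000
train_per_class = NUM_TRAIN // len(CIFAR10_CLASSES)                # 5000
values_per_image = IMAGE_SHAPE[0] * IMAGE_SHAPE[1] * IMAGE_SHAPE[2]  # 3072


def label_name(label: int) -> str:
    """Translate an integer label into its class name."""
    return CIFAR10_CLASSES[label]
```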

```python
from deepkit import load_CIFAR10
train_loader, test_loader = load_CIFAR10()
```

```python
# Grab one batch to inspect its shape before building the model.
batch, labels = next(iter(train_loader))
print(batch.shape, labels.shape)
```
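Flax's `nnx.Conv` expects channels-last (NHWC) input, i.e. `(batch, 32, 32, 3)` for CIFAR-10. We don't know which layout `load_CIFAR10` returns; loaders built on torchvision typically yield channels-first (NCHW) tensors. The hypothetical helper below converts such a batch using a simple shape heuristic that assumes 3-channel images.

```python
import numpy as np


def to_nhwc(images):
    """Return a float32 NHWC array, transposing if the batch looks channels-first.

    Hypothetical helper (not part of deepkit); assumes 3-channel images.
    """
    images = np.asarray(images, dtype=np.float32)
    if images.ndim == 4 and images.shape[1] == 3 and images.shape[-1] != 3:
        images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
    return images
```

Applied as `batch = to_nhwc(batch)`, it leaves a batch that is already NHWC unchanged.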

## Next, let's set up the model
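Following the paper, we use a VGG network: eight 3x3 convolutional layers, each followed by batch normalization and a ReLU, with five 2x2 max-pooling layers in between and three fully connected layers on top.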
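The convolution widths and pooling positions in the model match configuration A (VGG-11) of Simonyan and Zisserman. A short pure-Python sketch traces the spatial size through that configuration: 'SAME' padding keeps height and width unchanged and each 2x2, stride-2 pool halves them, so a 32x32 input shrinks to 1x1 with 512 channels. That is why the first dense layer takes 512 input features, and why pooling follows conv blocks 0, 1, 3, 5 and 7. `VGG11_CONFIG`, `trace_shapes` and `pool_after_indices` are illustrative names, not part of any library.

```python
# VGG configuration "A" (VGG-11): integers are 3x3 conv widths, "M" is a 2x2, stride-2 max pool.
VGG11_CONFIG = [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"]


def trace_shapes(config, size=32, channels=3):
    """Track (height/width, channels) after each layer of the feature extractor."""
    shapes = []
    for layer in config:
        if layer == "M":
            size //= 2        # pooling halves height and width
        else:
            channels = layer  # 'SAME' conv keeps the size, changes the width
        shapes.append((size, channels))
    return shapes


def pool_after_indices(config):
    """Indices of the conv layers that are immediately followed by a pool."""
    conv_index, indices = -1, []
    for layer in config:
        if layer == "M":
            indices.append(conv_index)
        else:
            conv_index += 1
    return indices


final_size, final_channels = trace_shapes(VGG11_CONFIG)[-1]
flattened = final_size * final_size * final_channels
print(final_size, final_channels, flattened)  # 1 512 512
print(pool_after_indices(VGG11_CONFIG))       # [0, 1, 3, 5, 7]
```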






```python
import jax
import jax.numpy as jnp
import flax.nnx as nnx


class VGGBlock(nnx.Module):
    """3x3 convolution -> batch normalization -> ReLU."""

    def __init__(self, in_features: int, out_features: int, rngs: nnx.Rngs):
        self.conv = nnx.Conv(in_features=in_features,
                             out_features=out_features,
                             kernel_size=(3, 3),
                             padding='SAME',  # preserve height and width
                             rngs=rngs)
        self.bn = nnx.BatchNorm(num_features=out_features,
                                use_running_average=False,
                                rngs=rngs)

    def __call__(self, x, training: bool):
        x = self.conv(x)
        # Training: normalize with batch statistics and update the running averages.
        # Evaluation: normalize with the stored running averages instead.
        x = self.bn(x, use_running_average=not training)
        x = nnx.relu(x)
        return x


class VGGNet(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.convs = [
            VGGBlock(in_features=3, out_features=64, rngs=rngs),
            VGGBlock(in_features=64, out_features=128, rngs=rngs),
            VGGBlock(in_features=128, out_features=256, rngs=rngs),
            VGGBlock(in_features=256, out_features=256, rngs=rngs),
            VGGBlock(in_features=256, out_features=512, rngs=rngs),
            VGGBlock(in_features=512, out_features=512, rngs=rngs),
            VGGBlock(in_features=512, out_features=512, rngs=rngs),
            VGGBlock(in_features=512, out_features=512, rngs=rngs),
        ]

        # Five 2x2 poolings shrink a 32x32 input to 1x1x512, hence 512 input features.
        self.fc1 = nnx.Linear(in_features=512, out_features=512, rngs=rngs)
        self.fc2 = nnx.Linear(in_features=512, out_features=512, rngs=rngs)
        self.out = nnx.Linear(in_features=512, out_features=10, rngs=rngs)

    def __call__(self, x, training: bool = True):
        # Conv blocks that are followed by a 2x2, stride-2 max pool.
        max_pool_after = [0, 1, 3, 5, 7]
        for i, block in enumerate(self.convs):
            x = block(x, training)
            if i in max_pool_after:
                x = nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2))

        # Flatten (batch, 1, 1, 512) -> (batch, 512); reshape keeps the batch
        # axis even for a single-image batch.
        x = x.reshape(x.shape[0], -1)
        x = nnx.relu(self.fc1(x))
        x = nnx.relu(self.fc2(x))
        x = self.out(x)  # unnormalized logits for the 10 classes
        return x
```
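The `training` flag matters because BatchNorm behaves differently in the two modes. The pure-JAX sketch below spells out what we understand a BatchNorm layer to compute on NHWC input: per-channel statistics over the batch, height and width axes, a learnable scale and shift, and exponential moving averages that replace the batch statistics at inference time. `batch_norm` is our own illustrative function; `momentum=0.99` and `eps=1e-5` follow what we believe are Flax's defaults, so check the Flax documentation before relying on them.

```python
import jax
import jax.numpy as jnp


def batch_norm(x, gamma, beta, running_mean, running_var, training,
               momentum=0.99, eps=1e-5):
    """Batch norm for NHWC inputs; statistics are per channel over (N, H, W)."""
    if training:
        mean = x.mean(axis=(0, 1, 2))
        var = x.var(axis=(0, 1, 2))
        # Exponential moving averages, used in place of batch statistics at inference.
        running_mean = momentum * running_mean + (1 - momentum) * mean
        running_var = momentum * running_var + (1 - momentum) * var
    else:
        mean, var = running_mean, running_var
    x_hat = (x - mean) / jnp.sqrt(var + eps)
    # Learnable per-channel scale (gamma) and shift (beta).
    return gamma * x_hat + beta, running_mean, running_var


x = 3.0 * jax.random.normal(jax.random.key(0), (16, 8, 8, 4)) + 2.0
c = x.shape[-1]
gamma, beta = jnp.ones(c), jnp.zeros(c)
mean0, var0 = jnp.zeros(c), jnp.ones(c)

y_train, new_mean, new_var = batch_norm(x, gamma, beta, mean0, var0, training=True)
y_eval, _, _ = batch_norm(x, gamma, beta, mean0, var0, training=False)
```

In training mode the output has roughly zero mean and unit variance per channel no matter how the input was shifted or scaled; in evaluation mode the result depends on how well the running averages have caught up, which is why evaluating with `training=True` would both give different numbers and keep changing the stored statistics.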

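```python
rng_key = jax.random.key(1337)
rngs = nnx.Rngs(rng_key)
model = VGGNet(rngs=rngs)
graphdef, state = nnx.split(model)

# Count only trainable parameters: the full state also contains BatchNorm's
# running mean and variance, which the optimizer does not update.
params = nnx.state(model, nnx.Param)
param_count = sum(x.size for x in jax.tree_util.tree_leaves(params))
print(f"Initialized model with {param_count:,} trainable parameters.")
nnx.display(state)
```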
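As a cross-check, the trainable-parameter count can be worked out by hand: a 3x3 convolution from `c_in` to `c_out` channels has `9 * c_in * c_out` weights plus `c_out` biases, each BatchNorm layer has a learnable scale and bias per channel, and each dense layer has `n_in * n_out + n_out` parameters. The sketch hard-codes the layer sizes of `VGGNet` and assumes every convolution and dense layer has a bias (the Flax default). BatchNorm also keeps a running mean and variance per channel; these are state rather than trainable parameters, so summing over every leaf of the full model state counts them as well.

```python
# (in, out) channels of the eight 3x3 convolutions in VGGNet.
CONV_CHANNELS = [(3, 64), (64, 128), (128, 256), (256, 256),
                 (256, 512), (512, 512), (512, 512), (512, 512)]
# (in, out) sizes of the three dense layers.
DENSE_SIZES = [(512, 512), (512, 512), (512, 10)]


def count_params(kernel=3):
    """Return (trainable parameters, BatchNorm running statistics)."""
    conv = sum(kernel * kernel * cin * cout + cout for cin, cout in CONV_CHANNELS)
    # BatchNorm: one learnable scale and one bias per output channel.
    bn_learnable = sum(2 * cout for _, cout in CONV_CHANNELS)
    # Running mean and variance: one of each per channel, not trained.
    bn_stats = sum(2 * cout for _, cout in CONV_CHANNELS)
    dense = sum(cin * cout + cout for cin, cout in DENSE_SIZES)
    return conv + bn_learnable + dense, bn_stats


trainable, running_stats = count_params()
print(f"{trainable:,} trainable parameters + {running_stats:,} BatchNorm statistics")
# 9,756,426 trainable parameters + 5,504 BatchNorm statistics
```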

```python
import optax

tx = optax.sgd(learning_rate=0.01, momentum=0.9)
optimizer = nnx.Optimizer(model, tx)
```
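For reference, here is a scalar sketch of the heavy-ball update we understand `optax.sgd(learning_rate, momentum)` to apply in its default, non-Nesterov form: the velocity accumulates `momentum * velocity + grad`, and the parameter moves by `-learning_rate * velocity`. `sgd_momentum_step` is our own illustrative function, not part of Optax.

```python
import jax.numpy as jnp


def sgd_momentum_step(param, grad, velocity, lr=0.01, momentum=0.9):
    """One heavy-ball SGD step: accumulate a velocity, then step against it."""
    velocity = momentum * velocity + grad
    return param - lr * velocity, velocity


param, velocity = jnp.array(1.0), jnp.array(0.0)
grad = jnp.array(2.0)  # pretend the gradient stays constant
param, velocity = sgd_momentum_step(param, grad, velocity)  # velocity 2.0, param ≈ 0.98
param, velocity = sgd_momentum_step(param, grad, velocity)  # velocity ≈ 3.8, param ≈ 0.942
```

With a constant gradient the velocity approaches `grad / (1 - momentum)`, so `momentum=0.9` makes the effective step up to ten times larger than the raw learning rate of 0.01.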

```python
def loss_fn(model, batch, targets):
    # Training-mode forward pass: BatchNorm uses batch statistics.
    logits = model(batch, training=True)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean()
    return loss


@nnx.jit
def step_fn(model: nnx.Module, optimizer: nnx.Optimizer, batch: jax.Array, labels: jax.Array):
    # One SGD step; nnx.jit propagates the updated parameters and BatchNorm statistics.
    loss, grads = nnx.value_and_grad(loss_fn)(model, batch, labels)
    optimizer.update(grads)
    return loss, grads


@nnx.jit
def accuracy(model: nnx.Module, batch: jax.Array, labels: jax.Array):
    # Evaluate with the running averages so that measuring accuracy
    # does not change BatchNorm's statistics.
    logits = model(batch, training=False)
    preds = jnp.argmax(logits, axis=-1)
    return jnp.mean(preds == labels)


def test_accuracy(model: nnx.Module, testloader):
    # Weight each batch by its size so a smaller final batch is not over-counted.
    correct, total = 0.0, 0
    for batch, labels in testloader:
        batch = jnp.array(batch)
        labels = jnp.array(labels)
        n = labels.shape[0]
        correct += float(accuracy(model, batch, labels)) * n
        total += n
    return correct / total
```
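Santurkar et al. back the smoothness argument by stepping along the gradient direction and recording how much the loss, and the gradient itself, change along the way: a narrow range of losses and small gradient changes indicate a smoother landscape. The sketch below is a minimal version of that probe under our own assumptions: `probe_along_gradient` is a hypothetical helper that takes any scalar `loss` over a pytree of parameters, the step sizes are arbitrary, and it is demonstrated on a toy quadratic rather than on `VGGNet`. Applying it to the network would mean first turning `loss_fn` into a function of the parameter state, for example with `nnx.split` and `nnx.merge`.

```python
import jax
import jax.numpy as jnp


def probe_along_gradient(loss, params, step_sizes):
    """Evaluate the loss and the gradient change at params - eta * grad for each eta."""
    grad = jax.grad(loss)(params)
    losses, grad_changes = [], []
    for eta in step_sizes:
        moved = jax.tree_util.tree_map(lambda p, g: p - eta * g, params, grad)
        losses.append(loss(moved))
        new_grad = jax.grad(loss)(moved)
        diff = jax.tree_util.tree_map(lambda a, b: a - b, new_grad, grad)
        # l2 norm of the gradient change across all parameter leaves.
        sq = sum(jnp.sum(d ** 2) for d in jax.tree_util.tree_leaves(diff))
        grad_changes.append(jnp.sqrt(sq))
    return jnp.stack(losses), jnp.stack(grad_changes)


# Toy quadratic with curvature 1 along one axis and 10 along the other.
curvature = jnp.array([1.0, 10.0])


def quadratic(p):
    return 0.5 * jnp.sum(curvature * p["w"] ** 2)


losses, grad_changes = probe_along_gradient(quadratic, {"w": jnp.ones(2)}, [0.0, 0.01, 0.05])
print("loss spread:", float(losses.max() - losses.min()))
print("gradient changes:", grad_changes)
```

For this quadratic the gradient change grows linearly with the step size, scaled by the curvature; a sharper (less smooth) loss would show larger changes for the same step.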



```python
num_epochs = 1
# len(train_loader) is the number of batches, i.e. optimizer steps, per epoch.
num_steps = num_epochs * len(train_loader)
```
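Since `len(train_loader)` counts batches, `num_steps` is the number of optimizer updates. The arithmetic depends on the batch size and on whether the loader drops the final partial batch. We don't know `load_CIFAR10`'s settings, so the batch size of 128 below is a hypothetical value, and `steps_per_epoch` is an illustrative helper.

```python
import math


def steps_per_epoch(num_examples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches per epoch; a final partial batch counts unless dropped."""
    if drop_last:
        return num_examples // batch_size
    return math.ceil(num_examples / batch_size)


# Hypothetical batch size; load_CIFAR10's actual default may differ.
print(steps_per_epoch(50_000, 128))                  # 391
print(steps_per_epoch(50_000, 128, drop_last=True))  # 390
```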

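```python
step = 0
for epoch in range(num_epochs):
    for batch, labels in train_loader:
        batch = jnp.array(batch)
        labels = jnp.array(labels)
        loss, grads = step_fn(model, optimizer, batch, labels)
        step += 1
        # Log every 50 steps; accuracy is measured on the current training batch.
        if step % 50 == 0:
            acc = accuracy(model, batch, labels)
            print(f"epoch {epoch} step {step}: loss {float(loss):.4f}, batch accuracy {float(acc):.3f}")
```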

```python
test_accuracy(model, test_loader)
```