# Tutorial 10: Batch Normalization

This notebook implements BatchNorm from scratch and visualizes its effects on training.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Seed NumPy's and PyTorch's global RNGs so random data and weight init repeat.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

## Part 1: BatchNorm from Scratch

In [None]:
class BatchNorm1D:
    """Batch Normalization implemented from scratch with NumPy.

    Each feature (column) is normalized with the current batch's statistics
    in training mode, or with running statistics in eval mode, then scaled by
    ``gamma`` and shifted by ``beta``.

    Args:
        num_features: number of features (columns) in the input.
        eps: constant added to the variance before the square root, for
            numerical stability.
        momentum: weight of the current batch in the running-statistic update
            ``running = (1 - momentum) * running + momentum * batch``.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        self.eps = eps
        self.momentum = momentum

        # Learnable parameters
        self.gamma = np.ones(num_features)
        self.beta = np.zeros(num_features)

        # Parameter gradients, overwritten by backward()
        self.dgamma = np.zeros(num_features)
        self.dbeta = np.zeros(num_features)

        # Running statistics for inference
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)

        # Cache for backward pass: (x, x_norm, mu, var, std) from the last forward()
        self.cache = None
        # Whether that forward() used batch statistics; selects the dx formula
        self._cache_training = True
        self.training = True

    def forward(self, x):
        """Normalize, scale and shift ``x``.

        Args:
            x: input, shape (batch_size, num_features).

        Returns:
            ``gamma * (x - mu) / sqrt(var + eps) + beta``, same shape as ``x``.
        """
        if self.training:
            # Batch statistics. np.var is the biased (1/m) estimator. Note that
            # torch.nn.BatchNorm1d stores the *unbiased* variance in its
            # running_var, so running_var here differs slightly from PyTorch's.
            mu = np.mean(x, axis=0)
            var = np.var(x, axis=0)

            # Update running statistics
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mu
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:
            mu = self.running_mean
            var = self.running_var

        # Normalize
        std = np.sqrt(var + self.eps)
        x_norm = (x - mu) / std

        # Scale and shift
        out = self.gamma * x_norm + self.beta

        # Cache for backward, plus which statistics were used
        self.cache = (x, x_norm, mu, var, std)
        self._cache_training = self.training

        return out

    def backward(self, dout):
        """Backpropagate through the most recent forward() call.

        Args:
            dout: upstream gradient, shape (batch_size, num_features).

        Returns:
            Gradient w.r.t. the forward input ``x``, same shape as ``dout``.
            Also stores ``self.dgamma`` and ``self.dbeta``.

        Raises:
            RuntimeError: if forward() has not been called yet.
        """
        if self.cache is None:
            raise RuntimeError("backward() called before forward()")
        x, x_norm, _, _, std = self.cache
        m = x.shape[0]

        # Gradients for gamma and beta
        self.dgamma = np.sum(dout * x_norm, axis=0)
        self.dbeta = np.sum(dout, axis=0)

        dx_norm = dout * self.gamma

        if not self._cache_training:
            # Eval mode: mu and std are constants (running stats), so the
            # transform is affine in x and the gradient is just a rescale.
            return dx_norm / std

        # Training mode: mu and var depend on every sample in the batch.
        # Using the derived formula from theory.
        dx = (1 / (m * std)) * (m * dx_norm - np.sum(dx_norm, axis=0)
                                - x_norm * np.sum(dx_norm * x_norm, axis=0))

        return dx

# Sanity-check the scratch forward pass on deliberately off-center data.
bn = BatchNorm1D(4)
x = 3 + 5 * np.random.randn(32, 4)  # Non-zero mean, non-unit variance
out = bn.forward(x)

# Print per-feature mean/variance before and after normalization.
for heading, data in (
    ("Input statistics:", x),
    ("\nOutput statistics (should be ~0 mean, ~1 var):", out),
):
    print(heading)
    print(f"  Mean: {data.mean(axis=0)}")
    print(f"  Var:  {data.var(axis=0)}")

## Part 2: Effect on Activation Distributions

In [None]:
# Networks with and without BatchNorm, for side-by-side comparison
class DeepNetNoBN(nn.Module):
    """Baseline MLP: four Linear->Sigmoid hidden layers (784 -> 256 -> ... -> 256)
    followed by a Linear(256, 10) output layer, with no normalization."""

    def __init__(self):
        super().__init__()
        widths = [784, 256, 256, 256, 256]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.Sigmoid()]
        stack.append(nn.Linear(widths[-1], 10))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        return self.layers(x)

class DeepNetWithBN(nn.Module):
    """MLP with four Linear->BatchNorm1d->Sigmoid hidden layers of width 256
    (input 784) followed by a Linear(256, 10) output layer."""

    def __init__(self):
        super().__init__()
        modules = []
        in_dim = 784
        for _ in range(4):
            modules.extend([nn.Linear(in_dim, 256), nn.BatchNorm1d(256), nn.Sigmoid()])
            in_dim = 256
        modules.append(nn.Linear(in_dim, 10))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        return self.layers(x)

# Visualize activations
def get_activations(model, x, layer_type=nn.Sigmoid):
    """Run ``x`` through ``model.layers`` in order and collect layer outputs.

    Args:
        model: module whose ``layers`` attribute is an iterable of sub-modules
            applied one after another (e.g. an ``nn.Sequential``).
        x: input tensor for the first layer.
        layer_type: module class (or tuple of classes) whose outputs are
            recorded; defaults to ``nn.Sigmoid``.

    Returns:
        List of 1-D NumPy arrays, one per matching layer, each holding that
        layer's flattened output.
    """
    activations = []
    # Inspection only: don't build an autograd graph for these forward passes.
    with torch.no_grad():
        for layer in model.layers:
            x = layer(x)
            if isinstance(layer, layer_type):
                # .cpu() so this also works when the model lives on a GPU
                activations.append(x.detach().cpu().numpy().flatten())
    return activations

# Random input
x = torch.randn(100, 784)

model_no_bn = DeepNetNoBN()
model_with_bn = DeepNetWithBN()

act_no_bn = get_activations(model_no_bn, x)
# Use batch statistics (training mode). A freshly built BatchNorm1d still has
# its initial running stats (mean 0, var 1), so in eval() mode it would act as
# an almost-identity map and the "With BN" row would show no normalization.
model_with_bn.train()
act_with_bn = get_activations(model_with_bn, x)

# Plot
fig, axes = plt.subplots(2, 4, figsize=(16, 8))

for i, (ax, act) in enumerate(zip(axes[0], act_no_bn)):
    ax.hist(act, bins=50, alpha=0.7)
    ax.set_title(f'Layer {i+1} (No BN)\nmean={act.mean():.2f}, std={act.std():.2f}')
    ax.set_xlim(0, 1)

for i, (ax, act) in enumerate(zip(axes[1], act_with_bn)):
    ax.hist(act, bins=50, alpha=0.7, color='green')
    ax.set_title(f'Layer {i+1} (With BN)\nmean={act.mean():.2f}, std={act.std():.2f}')
    ax.set_xlim(0, 1)

plt.suptitle('Activation Distributions: Without vs With BatchNorm', fontsize=14)
plt.tight_layout()
plt.show()

print("Without BN: each layer's activation distribution is whatever the weights produce, and it shifts from layer to layer.")
print("With BN: pre-activations are re-standardized per feature, so every layer's activations keep a similar spread.")

## Part 3: Training Speed Comparison

In [None]:
# Load the MNIST training split, normalized with the dataset's mean and std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)

def train_epoch(model, loader, optimizer, criterion):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: network mapping flattened inputs (batch, features) to logits.
        loader: iterable of ``(x, y)`` batches; each sample in ``x`` is
            flattened before being fed to ``model``.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss function; assumed to return the *mean* loss over the
            batch (as ``nn.CrossEntropyLoss()`` does by default).

    Returns:
        ``(mean_loss, accuracy)``, both averaged over all samples in the epoch.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for x, y in loader:
        x = x.reshape(x.size(0), -1)  # flatten each sample, e.g. (B, 1, 28, 28) -> (B, 784)
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()

        batch_size = len(y)
        # Weight by batch size so a smaller final batch doesn't skew the mean,
        # keeping the loss a per-sample average like the accuracy.
        total_loss += loss.item() * batch_size
        correct += (out.argmax(1) == y).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("loader yielded no samples")
    return total_loss / total, correct / total

# Train both models
epochs = 10
lr = 0.1  # Larger LR to show BN's benefit

model_no_bn = DeepNetNoBN()
model_with_bn = DeepNetWithBN()

opt_no_bn = optim.SGD(model_no_bn.parameters(), lr=lr)
opt_with_bn = optim.SGD(model_with_bn.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

history_no_bn = {'loss': [], 'acc': []}
history_with_bn = {'loss': [], 'acc': []}

# Build the banner from `lr` so it stays correct if the value above changes.
print(f"Training with LR={lr} (large!)...")
for epoch in range(epochs):
    loss1, acc1 = train_epoch(model_no_bn, train_loader, opt_no_bn, criterion)
    loss2, acc2 = train_epoch(model_with_bn, train_loader, opt_with_bn, criterion)

    history_no_bn['loss'].append(loss1)
    history_no_bn['acc'].append(acc1)
    history_with_bn['loss'].append(loss2)
    history_with_bn['acc'].append(acc2)

    print(f"Epoch {epoch+1}: No BN acc={acc1:.2%}, With BN acc={acc2:.2%}")

In [None]:
# Plot comparison. Epochs are numbered from 1 on the x axis to match the
# "Epoch N" lines printed during training (plotting y alone would start at 0).
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

panels = [
    (axes[0], 'loss', 'Loss', 'Training Loss'),
    (axes[1], 'acc', 'Accuracy', 'Training Accuracy'),
]
series = [
    (history_no_bn, 'b-', 'Without BN'),
    (history_with_bn, 'g-', 'With BN'),
]
for ax, key, ylabel, title in panels:
    for history, fmt, label in series:
        epoch_numbers = range(1, len(history[key]) + 1)
        ax.plot(epoch_numbers, history[key], fmt, label=label, linewidth=2)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nBatchNorm enables training with much larger learning rates!")

## Part 4: Normalization Variants Comparison

In [None]:
# Visualize different normalization methods
# Create a fake activation tensor: (batch=4, channels=4, height=2, width=2).
# Four channels let GroupNorm use 2 groups of 2 channels, so it differs from
# both LayerNorm (all channels in 1 group) and InstanceNorm (1 channel per group).
x = torch.randn(4, 4, 2, 2) * 5 + 3

# Different normalizations
bn = nn.BatchNorm2d(4)(x)        # Normalize over batch, H, W (per channel)
ln = nn.LayerNorm([4, 2, 2])(x)  # Normalize over C, H, W (per sample)
# Named `inorm` rather than `In`, which would clobber IPython's input history.
inorm = nn.InstanceNorm2d(4)(x)  # Normalize over H, W (per sample and channel)
gn = nn.GroupNorm(2, 4)(x)       # Normalize over each group of 2 channels (per sample)

fig, axes = plt.subplots(1, 4, figsize=(16, 4))
titles = ['BatchNorm', 'LayerNorm', 'InstanceNorm', 'GroupNorm']
outputs = [bn, ln, inorm, gn]

# Pooled over all elements, every method gives mean ~0 / std ~1; what differs
# is which subsets of values were standardized together.
for ax, title, out in zip(axes, titles, outputs):
    ax.hist(out.detach().numpy().flatten(), bins=30, alpha=0.7)
    ax.set_title(f'{title}\nmean={out.mean():.3f}, std={out.std():.3f}')
    ax.set_xlabel('Value')

plt.tight_layout()
plt.show()

## Summary

**BatchNorm normalizes activations:**
$$y = \gamma \cdot \frac{x - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} + \beta$$

**Benefits:**
1. Enables larger learning rates
2. Reduces sensitivity to initialization
3. Acts as a mild regularizer (mini-batch statistics add noise during training)
4. Keeps pre-activations in the high-gradient (non-saturating) range of activations such as sigmoid