# Tutorial 4-1: The Art of Initialization & Normalization

**Course:** CSEN 342: Deep Learning  
**Topic:** Weight Initialization, Vanishing/Exploding Activations, and Batch Normalization

## Objective
Deep neural networks are hard to train in part because of **vanishing** and **exploding** signals. As activations pass forward through many layers (and gradients pass backward through them), their variance can shrink toward zero, so later layers receive almost no signal, or grow without bound, which makes training unstable.

In this tutorial, we will visualize this phenomenon and apply the standard fixes:
1.  **Visualize Activations:** Monitor the distribution of data as it passes through a deep network.
2.  **Bad Initialization:** See what happens when weights are initialized too large ($\sigma=1$) or too small ($\sigma=0.01$).
3.  **Kaiming Initialization:** Apply the mathematical fix derived by Kaiming He to keep variance constant.
4.  **Batch Normalization:** Apply the architectural fix that forces layer inputs to be stable, making initialization less critical.

---

## Part 1: Monitoring Activation Statistics

We need a way to peek inside the network. We will create a Deep MLP that stores the `mean` and `std` (standard deviation) of the activations at every layer during the forward pass.

In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torchvision.transforms as transforms

# Import utility functions
import os
import sys
sys.path.append(os.path.abspath(os.path.join('..')))
from utils import download_fashion_mnist

# Device config
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class ActivationMonitorNet(nn.Module):
    def __init__(self, depth=10, hidden_dim=100, use_bn=False):
        super().__init__()
        self.layers = nn.ModuleList()
        self.use_bn = use_bn
        self.activation_stats = [] # To store (mean, std) for visualization
        
        # Create a deep network
        for i in range(depth):
            # Input layer handles 784 features, others handle hidden_dim
            input_dim = 28*28 if i == 0 else hidden_dim
            
            # Define blocks: FC -> [BN] -> ReLU
            block = nn.Sequential()
            block.add_module(f'linear_{i}', nn.Linear(input_dim, hidden_dim))
            
            if use_bn:
                # Batch Norm is typically applied BEFORE the activation (Slide 22)
                block.add_module(f'bn_{i}', nn.BatchNorm1d(hidden_dim))
                
            block.add_module(f'relu_{i}', nn.ReLU())
            
            self.layers.append(block)

    def forward(self, x):
        self.activation_stats = [] # Clear previous stats
        x = x.view(x.size(0), -1)  # Flatten
        
        for i, layer in enumerate(self.layers):
            x = layer(x)
            # Record statistics of the output of this layer
            mean = x.mean().item()
            std = x.std().item()
            self.activation_stats.append((mean, std))
            
        return x

    def init_weights(self, method, std=0.01):
        # Helper to re-initialize weights easily
        for m in self.modules():
            if isinstance(m, nn.Linear):
                if method == 'normal':
                    nn.init.normal_(m.weight, mean=0.0, std=std)
                elif method == 'kaiming':
                    nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
                nn.init.zeros_(m.bias)

print("Network defined.")
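
The class above records statistics inside its own `forward`. As an alternative sketch, the same `(mean, std)` pairs can be collected with PyTorch forward hooks, which leaves the model code untouched. The helper name `attach_stat_hooks` and the small `toy` network are hypothetical and only illustrate the idea; they are not part of the tutorial's code.

```python
import torch
import torch.nn as nn

def attach_stat_hooks(model, stats):
    """Register forward hooks that append (mean, std) of every ReLU output to `stats`."""
    handles = []
    for module in model.modules():
        if isinstance(module, nn.ReLU):
            def hook(mod, inputs, output):
                # Called after the module's forward pass with its output tensor
                stats.append((output.mean().item(), output.std().item()))
            handles.append(module.register_forward_hook(hook))
    return handles

torch.manual_seed(0)
# Hypothetical toy network, used only to demonstrate the hooks
toy = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 16), nn.ReLU())

stats = []
handles = attach_stat_hooks(toy, stats)
with torch.no_grad():
    toy(torch.randn(32, 8))

# Remove the hooks once the statistics have been collected
for h in handles:
    h.remove()

for depth, (mean, std) in enumerate(stats):
    print(f"ReLU {depth}: mean={mean:.3f}, std={std:.3f}")
```

Hooks are convenient when the network cannot be edited (for example, a pretrained model), while storing statistics in `forward`, as we do here, keeps everything in one place.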

---

## Part 2: The Initialization Problem

We will feed a batch of data through the network and observe the **standard deviation** of activations at each layer. 
* Ideally, `std` should remain stable (around 1.0) so the signal propagates.
* If `std` drops to 0, the signal has vanished.
* If `std` grows from layer to layer, the signal is blowing up. ReLU does not saturate for positive inputs, so nothing caps the growth.

In [None]:
# Load one batch of data for testing
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
download_fashion_mnist()
trainset = torchvision.datasets.FashionMNIST(root='../data', train=True, download=False, transform=transform)
loader = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=True)
dummy_batch, _ = next(iter(loader))
dummy_batch = dummy_batch.to(device)

def visualize_flow(model, title):
    # Run forward pass
    model.to(device)
    with torch.no_grad():
        _ = model(dummy_batch)
    
    # Extract stats
    means, stds = zip(*model.activation_stats)
    layers = range(len(model.activation_stats))
    
    # Plot
    plt.figure(figsize=(12, 5))
    
    plt.subplot(1, 2, 1)
    plt.plot(layers, means, 'o-', label='Mean')
    plt.title(f"{title} - Activations Mean")
    plt.xlabel("Layer Depth")
    plt.grid(True)
    
    plt.subplot(1, 2, 2)
    plt.plot(layers, stds, 'o-', color='orange', label='Std Dev')
    plt.title(f"{title} - Activations Std Dev")
    plt.xlabel("Layer Depth")
    plt.grid(True)
    plt.show()

# Experiment 1: Small Weights (std=0.01)
model_small = ActivationMonitorNet(depth=10)
model_small.init_weights('normal', std=0.01)
visualize_flow(model_small, "Small Init (std=0.01)")

# Experiment 2: Large Weights (std=1.0)
model_large = ActivationMonitorNet(depth=10)
model_large.init_weights('normal', std=1.0)
visualize_flow(model_large, "Large Init (std=1.0)")

### Discussion
1.  **Small Init:** You should see the standard deviation collapse to 0 very quickly (Vanishing Activations). The network is essentially dead after a few layers.
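2.  **Large Init:** You should see both the mean and the standard deviation grow by a roughly constant factor at every layer, reaching very large values by the last layer (Exploding Activations). The exact numbers depend on the random draw, but the growth itself does not. Gradients are scaled in the same way, which makes training unstable.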
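
The size of both effects can be estimated by hand. A rough sketch, assuming zero biases, zero-mean Gaussian weights with standard deviation $\sigma$, and layers wide enough for the usual independence approximation: each `Linear -> ReLU` layer with $n$ inputs multiplies the root-mean-square activation by about $\sqrt{n\sigma^2/2}$. The helper name `relu_layer_gain` is ours, not part of the tutorial's code.

```python
import math

def relu_layer_gain(fan_in, weight_std):
    """Approximate factor by which the activation RMS changes across one Linear+ReLU layer."""
    return math.sqrt(fan_in * weight_std ** 2 / 2.0)

hidden_dim = 100
hidden_layers = 9  # in a depth-10 network, layers 1..9 take hidden_dim inputs

configs = {
    "small (std=0.01)": 0.01,
    "large (std=1.0)": 1.0,
    "kaiming (std=sqrt(2/n))": math.sqrt(2.0 / hidden_dim),
}

for name, std in configs.items():
    gain = relu_layer_gain(hidden_dim, std)
    total = gain ** hidden_layers
    print(f"{name}: per-layer gain {gain:.4f}, after {hidden_layers} layers {total:.3e}")
```

Under these assumptions the per-layer factor is about 0.07 for `std=0.01` and about 7 for `std=1.0`, so nine such layers change the signal scale by many orders of magnitude in either direction.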

---

## Part 3: The Fix (Kaiming Initialization)

Kaiming He derived a formula specifically for ReLU networks. It sets the variance of weights to $\frac{2}{n_{in}}$ to preserve the variance of activations through the ReLU nonlinearity.
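
As a quick sanity check of that formula, the sketch below initializes a single layer with `kaiming_normal_` (same arguments as in `init_weights`) and compares the empirical standard deviation of its weights with $\sqrt{2/n_{in}}$. The layer size matches the first layer of our network; the seed is arbitrary.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)

# Same shape as the first layer of ActivationMonitorNet: 784 inputs, 100 outputs
layer = nn.Linear(784, 100)
nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')

expected_std = math.sqrt(2.0 / layer.in_features)  # sqrt(2 / n_in)
empirical_std = layer.weight.std().item()

print(f"expected std:  {expected_std:.4f}")
print(f"empirical std: {empirical_std:.4f}")
```

The two values should agree closely, with a small difference due to sampling.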

In [None]:
# Experiment 3: Kaiming He Initialization
model_kaiming = ActivationMonitorNet(depth=10)
model_kaiming.init_weights('kaiming')
visualize_flow(model_kaiming, "Kaiming/He Init")

**Success!** The standard deviation should remain roughly stable (constant) across all 10 layers. This allows deep networks to train effectively.

---

## Part 4: The Stabilizer (Batch Normalization)

Sometimes we cannot tune initialization perfectly. **Batch Normalization (BN)** normalizes each feature of a layer's pre-activations to mean 0 and variance 1 over the mini-batch, and then applies a learnable scale and shift.

Let's assume we messed up and used the **Bad Initialization** (std=0.01), but this time we enable Batch Norm.
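
To make the operation concrete, here is a minimal sketch that reproduces the training-mode output of `nn.BatchNorm1d` by hand. It assumes a freshly created layer, where the learnable scale is 1 and the shift is 0; the input `x` is synthetic and deliberately small in scale, mimicking what a badly initialized layer would produce.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Synthetic pre-activations with a very small scale: 100 samples, 5 features
x = 0.05 * torch.randn(100, 5)

bn = nn.BatchNorm1d(5)  # scale (weight) = 1 and shift (bias) = 0 at creation
bn.train()              # training mode: normalize with the statistics of this batch
y_module = bn(x)

# Manual version: per-feature statistics over the batch dimension
mean = x.mean(dim=0, keepdim=True)
var = x.var(dim=0, unbiased=False, keepdim=True)  # biased variance, as used for normalization
y_manual = (x - mean) / torch.sqrt(var + bn.eps)

print("input std per feature: ", x.std(dim=0, unbiased=False))
print("output std per feature:", y_manual.std(dim=0, unbiased=False))
print("matches nn.BatchNorm1d:", torch.allclose(y_module, y_manual, atol=1e-4))
```

Whatever the scale of the input, the output of each feature has mean close to 0 and standard deviation close to 1, which is why the weight scale of the preceding `Linear` layer matters much less.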

In [None]:
# Experiment 4: Bad Init + Batch Norm
model_bn = ActivationMonitorNet(depth=10, use_bn=True)
model_bn.init_weights('normal', std=0.01) # Deliberately bad init
visualize_flow(model_bn, "Bad Init + Batch Norm")

### Observation
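Even though we initialized with weights that are far too small ($\sigma = 0.01$), Batch Norm rescales the pre-activations at every layer, so the recorded standard deviation stays at a stable, nonzero value across depth instead of collapsing. Because our statistics are recorded after the ReLU, expect a value somewhat below 1.0 rather than exactly 1.0. This makes the network far less sensitive to the choice of initialization.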
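
Note that `ActivationMonitorNet` records its statistics after the ReLU, so the plotted values describe a rectified signal. As a rough guide, assuming the normalized pre-activations are approximately standard normal, the mean and standard deviation of `ReLU(z)` have a closed form, which the sketch below also checks by simulation.

```python
import math
import random

# Closed form for z ~ N(0, 1): E[ReLU(z)] = 1/sqrt(2*pi), Var[ReLU(z)] = 1/2 - 1/(2*pi)
relu_mean = 1.0 / math.sqrt(2.0 * math.pi)
relu_std = math.sqrt(0.5 - 1.0 / (2.0 * math.pi))

# Monte Carlo check with standard-normal samples
random.seed(0)
samples = [max(0.0, random.gauss(0.0, 1.0)) for _ in range(200_000)]
mc_mean = sum(samples) / len(samples)
mc_std = math.sqrt(sum((s - mc_mean) ** 2 for s in samples) / len(samples))

print(f"closed form: mean={relu_mean:.3f}, std={relu_std:.3f}")
print(f"simulation:  mean={mc_mean:.3f}, std={mc_std:.3f}")
```

Under this assumption the post-ReLU mean is about 0.40 and the standard deviation about 0.58, so a flat curve near those values in the Batch Norm plot indicates a healthy signal.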

---

## Part 5: Impact on Training

Visualizing stats is nice, but does it actually help the model learn? Let's train three models for about 200 mini-batches (well under one epoch) and compare their loss curves.

In [None]:
def train_quick(model, name, max_batches=200):
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    losses = []

    # Classification head: the monitored network outputs hidden_dim features
    # (the width of its last Linear layer) and Fashion-MNIST has 10 classes
    feature_dim = model.layers[-1][0].out_features
    head = nn.Linear(feature_dim, 10).to(device)

    # Optimize the network and the head together
    params = list(model.parameters()) + list(head.parameters())
    optimizer = torch.optim.SGD(params, lr=0.01, momentum=0.9)

    model.train()
    for i, (inputs, labels) in enumerate(loader):
        if i >= max_batches:
            break  # Stop after max_batches mini-batches to see initial convergence
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        features = model(inputs)
        outputs = head(features)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Record the loss every 10 iterations
        if i % 10 == 0:
            losses.append(loss.item())

    print(f"{name}: last recorded loss {losses[-1]:.3f}")
    return losses

print("Training 'Bad Init'...")
loss_bad = train_quick(model_small, "Bad Init")

print("Training 'Kaiming Init'...")
loss_kaiming = train_quick(model_kaiming, "Kaiming Init")

print("Training 'Batch Norm' (with Bad Init)...")
loss_bn = train_quick(model_bn, "Batch Norm")

# Plot Loss Curves
plt.figure(figsize=(10, 6))
plt.plot(loss_bad, label="Bad Init (0.01)")
plt.plot(loss_kaiming, label="Kaiming Init")
plt.plot(loss_bn, label="Batch Norm (with Bad Init)")
plt.title("Training Convergence Speed")
plt.xlabel("Iterations (x10)")
plt.ylabel("Loss")
plt.legend()
plt.grid(True)
plt.show()

### Conclusion
1.  **Bad Init:** Likely stays near the loss of a uniform guess over 10 classes (`ln(10) ≈ 2.3`) because almost no signal reaches the classification head and almost no gradient reaches the early layers.
2.  **Kaiming Init:** Starts learning immediately.
3.  **Batch Norm:** Also starts learning immediately and often converges **faster** than Kaiming init alone. Batch Norm keeps activation scales stable throughout training, not only at initialization, and it typically tolerates higher learning rates.
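
The `ln(10)` plateau in the first point is simply the cross-entropy of a classifier that assigns equal probability to all 10 classes, which is roughly what the head produces when its input features are nearly zero. A minimal sketch of that arithmetic, using a hand-written `cross_entropy` helper (our own name, written in plain Python rather than with `nn.CrossEntropyLoss`):

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy of a single example from raw logits, via a stable log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

num_classes = 10
uniform_logits = [0.0] * num_classes  # identical logits give a uniform softmax

loss = cross_entropy(uniform_logits, target=3)
print(f"loss with uniform predictions: {loss:.4f}")
print(f"ln({num_classes}):                        {math.log(num_classes):.4f}")
```

A loss curve that hovers around this value therefore indicates that the model is not yet doing better than guessing.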