# Implement parameter initialization for a CNN


## Problem Statement
You are tasked with implementing and evaluating parameter initialization strategies for a CNN model in PyTorch. Your goal is to initialize the weights and biases of the vanilla CNN model provided in the problem statement with each strategy and to comment on the implications of each one.

### Requirements
1. **Initialize** weights and biases in the following ways:
   - **Zero Initialization**: set the parameters to zero
   - **Random Initialization**: set the parameters to random values drawn from a normal distribution
   - **Xavier Initialization**: set them to random values from a normal distribution with **mean=0 and variance=1/n**
   - **Kaiming He Initialization**: set them to random values from a normal distribution with **mean=0 and variance=2/n**
2. Train and compute accuracy for each strategy

### Constraints
   - Use the given CNN model and the training and testing helper functions for accuracy computations.
   - Ensure the model is compatible with the CIFAR-10 dataset, which contains 10 classes.

**! Hint:**
   - Use `torch.nn.init` for weight initialization
   - Resources to read: [All you need is a good init](https://arxiv.org/pdf/1511.06422)


### Code template

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Load CIFAR-10 dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

# Define the CNN Model
# TODO: Add convolutional, pooling, and fully connected layers
class CNNModel(nn.Module):
    def __init__(self):
        ...

    def forward(self, x):
        ...

# Initialize the model, loss function, and optimizer
model = CNNModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
epochs = 10
for epoch in range(epochs):
    for images, labels in train_loader:
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}")


# Evaluate on the test set
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Test Accuracy: {100 * correct / total:.2f}%")
```

## Solution

### Rephrase

- initialize the weights and biases of a vanilla CNN model
   - Zero Initialization: set the parameters to zero
   - Random Initialization: set the parameters to random values drawn from a normal distribution
   - Xavier Initialization: set them to random values from a normal distribution with mean=0 and variance=1/n
   - Kaiming He Initialization: set them to random values from a normal distribution with mean=0 and variance=2/n
- comment on the implications of each strategy
- ensure the model is compatible with the CIFAR-10 dataset, which contains 10 classes.

**Note**:
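- By default, `nn.Conv2d` and `nn.Linear` initialize their weights with `kaiming_uniform_(weight, a=math.sqrt(5))`. Despite the name, this is not Kaiming initialization for ReLU: with `a=√5` the gain is `√(2/(1+5)) = 1/√3`, so the weights are drawn from `U(-1/√n_in, 1/√n_in)`, i.e. variance `1/(3·n_in)`. That is closer to Xavier/LeCun-style fan-in scaling (with an even smaller variance) than to Kaiming's `2/n_in`. For details, see https://github.com/pytorch/pytorch/issues/57109.
- Weights are re-initialized after the model is created and before training: here `__init__` only defines the architecture, and `forward` only computes the pass (re-initializing there would reset the weights on every call).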
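A minimal sketch of that pattern on a small stand-in network (the `nn.Sequential` and `init_fn` below are illustrative, not the notebook's model): `Module.apply(fn)` visits every child recursively and then the module itself, and the `nn.init.*_` functions modify tensors in place, so the `Parameter` objects (and any references to them) stay the same.

```python
import torch.nn as nn

visited = []
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
weight_before = net[0].weight          # keep a handle to the original Parameter object

def init_fn(m):
    visited.append(type(m).__name__)   # record the visiting order
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
        nn.init.zeros_(m.bias)

net.apply(init_fn)
print(visited)                          # ['Linear', 'ReLU', 'Linear', 'Sequential']: children first, then self
print(net[0].weight is weight_before)   # True: in-place init keeps the same Parameter
```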

## Implementation notes

| Strategy | Formula | Behavior | Expected accuracy |
|----------|----------|----------|----------|
| Zero Initialization | W = 0, b = 0 | All neurons in a layer are identical -> symmetry is not broken -> their gradients are identical -> the network does not learn. | ~10% (random guessing) |
| Random (Naive) | W ~ N(0, 1) | Symmetry is broken, but the scale ignores layer width: Var(output) = n_in * Var(weight) * Var(input), so with Var(weight) = 1 the variance grows ~n_in-fold per layer -> in deep networks activations and gradients explode (or vanish once saturating activations saturate). | Low/unstable |
| Xavier (Glorot) | W ~ N(0, 1/n_in) or N(0, 2/(n_in + n_out)) | Preserves signal variance on forward and backward passes -> suitable for tanh/sigmoid, but not optimal for ReLU. | Good (but not optimal for ReLU) |
| Kaiming (He) | W ~ N(0, 2/n_in) | Takes into account that ReLU zeroes out ~50% of the values -> compensates for the variance loss -> standard for ReLU/CNN. | Best (fast convergence, high accuracy) |

`n_in` = `fan_in` = number of input features (e.g. `in_channels * kernel_size²` for `Conv2d`).
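The table's variance argument can be written as a per-layer recursion. This is a sketch under the usual simplifying assumptions (zero-mean, symmetric i.i.d. weights independent of the inputs, zero biases, a ReLU after each layer), following the He et al. derivation, not an exact statement about the CNN above. For ReLU inputs the relevant quantity is the second moment $\mathbb{E}[x_l^2]$ rather than $\mathrm{Var}(x_l)$:

$$
\mathrm{Var}(y_l) = n_{in}\,\mathrm{Var}(W_l)\,\mathbb{E}[x_l^2],
\qquad
\mathbb{E}[x_{l+1}^2] = \tfrac{1}{2}\,\mathrm{Var}(y_l)
\quad\text{where } x_{l+1} = \mathrm{ReLU}(y_l)
$$

$$
\Rightarrow\quad
\mathbb{E}[x_{l+1}^2] = \underbrace{\tfrac{1}{2}\,n_{in}\,\mathrm{Var}(W_l)}_{\text{per-layer factor}}\;\mathbb{E}[x_l^2]
$$

Plugging in the three scales: $\mathrm{Var}(W) = 1$ gives a factor of $n_{in}/2$ (the signal explodes), $\mathrm{Var}(W) = 1/n_{in}$ gives $1/2$ (the signal halves at every layer), and $\mathrm{Var}(W) = 2/n_{in}$ gives exactly $1$, which is the Kaiming condition.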

### Zero Initialization Principles

- All neurons in a single layer receive the same weights (zeros).
- On the first forward pass, they compute the same outputs.
- During backpropagation, the gradients for all weights in a single layer will be the same.
- As a result, all weights will be updated identically.

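All neurons in a layer remain symmetric: they learn the same features and are essentially copies of each other. This completely defeats the purpose of having multiple neurons in a layer!

After the first update, `W := 0 - lr * grad_W`, so every weight in the layer becomes the same value `-lr * grad_W`.

When the biases are zero as well (as in the full solution below), this ReLU network is even worse off: all activations are 0, so every weight gradient is exactly 0 and only the output-layer bias can change. The loss therefore stays near ln(10) ≈ 2.303, which matches the training logs.

Biases are sometimes initialized to ones or to small positive values to mitigate "vanishing gradients".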
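A minimal sketch that checks this gradient argument on a tiny all-zero ReLU network (the layer sizes, batch size and 4 classes are arbitrary assumptions, chosen only to keep it fast): after one backward pass, every weight gradient and the hidden-bias gradient are exactly zero, and the loss equals ln(4) ≈ 1.3863 because all logits are 0.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
# Tiny stand-in network with ReLU, like the CNN above (sizes are arbitrary)
net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
for m in net:
    if isinstance(m, nn.Linear):
        nn.init.zeros_(m.weight)
        nn.init.zeros_(m.bias)

x = torch.randn(32, 8)
y = torch.randint(0, 4, (32,))
loss = nn.CrossEntropyLoss()(net(x), y)
loss.backward()

print(f"loss = {loss.item():.4f}, ln(4) = {math.log(4):.4f}")
for name, p in net.named_parameters():
    # Only '2.bias' (the output bias) can receive a nonzero gradient
    print(f"{name:8s} |grad| sum = {p.grad.abs().sum().item():.4f}")
```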

### Xavier (Glorot) Initialization Principles

- Goal: Make the variance of activations the same at the input and output of each layer, both on the forward and backward passes.
- W ~ N(0, σ²), with
  - σ² = 1 / n_in (forward-pass condition only, *fan-in* variant, also known as LeCun initialization) **or**
  - σ² = 2 / (n_in + n_out) (*Glorot normal*, the formula used by `xavier_normal_`): the harmonic mean of the forward-pass (1/n_in) and backward-pass (1/n_out) requirements. Suitable for activations that are symmetric and roughly linear near 0, such as `tanh` (derivative = 1 at 0); it is also commonly used with `sigmoid`, although its slope at 0 is only 1/4.
  - **`n_in`** = `fan_in` = number of inputs to the neuron (e.g. `3*3*3 = 27` for the first `Conv2d(3 -> 16)`)
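For concrete numbers on this notebook's architecture, here is a small sketch computing `fan_in`/`fan_out` for each parametrized layer and the resulting standard deviations under the three scales. The `fans` helper is hypothetical: it mirrors the usual rule (input channels times kernel area for conv, `in_features` for linear) rather than calling a PyTorch API.

```python
import math
import torch.nn as nn

# The four parametrized layers of the CNNModel used in this notebook
layers = [
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.Linear(32 * 8 * 8, 512),
    nn.Linear(512, 10),
]

def fans(weight):
    # Conv weight: (out_ch, in_ch, kh, kw); Linear weight: (out_features, in_features)
    receptive = weight[0][0].numel() if weight.dim() > 2 else 1
    return weight.shape[1] * receptive, weight.shape[0] * receptive

for layer in layers:
    fan_in, fan_out = fans(layer.weight)
    print(f"{type(layer).__name__:7s} fan_in={fan_in:5d} fan_out={fan_out:4d}  "
          f"std: 1/n_in={math.sqrt(1 / fan_in):.4f}  "
          f"glorot={math.sqrt(2 / (fan_in + fan_out)):.4f}  "
          f"kaiming={math.sqrt(2 / fan_in):.4f}")
```

The first conv layer has `fan_in = 27`, matching the example above; the first linear layer has `fan_in = 2048`, which is why its weights need a far smaller standard deviation than the 1.0 of naive N(0, 1).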

! **Not for ReLU**

ReLU zeroes out ~50% of the values → the output variance (second moment) is half of what it should be.

**Xavier undercompensates**: the signal still attenuates layer by layer.
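A quick sketch of that attenuation, assuming a plain stack of 20 bias-free `Linear(256, 256)` + ReLU layers instead of the CNN (depth, width and batch size are arbitrary; with `fan_in = fan_out`, Glorot's `2/(n_in + n_out)` equals `1/n_in`). By the halving argument, Xavier's activation scale should shrink by about √2 per layer (roughly 2⁻¹⁰ ≈ 10⁻³ after 20 layers), while Kaiming's stays of order 1. Exact values depend on the random seed.

```python
import torch
import torch.nn as nn

def final_activation_std(init_fn, depth=20, width=256, seed=0):
    """Push random inputs through `depth` Linear+ReLU layers and return the final activation std."""
    torch.manual_seed(seed)
    x = torch.randn(1024, width)
    with torch.no_grad():
        for _ in range(depth):
            layer = nn.Linear(width, width, bias=False)
            init_fn(layer.weight)          # overwrite PyTorch's default init
            x = torch.relu(layer(x))
    return x.std().item()

xavier_std = final_activation_std(nn.init.xavier_normal_)
kaiming_std = final_activation_std(lambda w: nn.init.kaiming_normal_(w, nonlinearity='relu'))
print(f"Xavier : std after 20 ReLU layers = {xavier_std:.2e}")
print(f"Kaiming: std after 20 ReLU layers = {kaiming_std:.2e}")
```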

### Kaiming (He) Initialization Principles

- Compensates for the loss of 50% variance in ReLU
- `W ~ N(0, σ²)`
   - σ² = 2 / n_in
- Although PyTorch provides `kaiming_normal_()` and `kaiming_uniform_()`, the default initialization in `nn.Conv2d`/`nn.Linear` (as of 2025) is `kaiming_uniform_(a=math.sqrt(5))`, whose variance `1/(3·n_in)` is closer to Xavier/LeCun-style scaling than to true Kaiming initialization for ReLU (https://github.com/pytorch/pytorch/issues/57109)

**Xavier** preserves signal variance for symmetric activations like tanh. **Kaiming** does the same for ReLU - it doubles the weight variance to compensate for ReLU zeroing out half of the values. Therefore, **for modern CNNs with ReLU, Kaiming is the standard**.
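A minimal sketch to check the default-initialization claim empirically on a layer shaped like this notebook's second conv (`Conv2d(16, 32, 3)`, so `fan_in = 16·3·3 = 144`): the default weights should have a variance close to `1/(3·fan_in)`, while `kaiming_normal_` with `nonlinearity='relu'` should give roughly `2/fan_in`, about six times larger.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(16, 32, kernel_size=3)     # PyTorch's default init
fan_in = conv.weight[0].numel()             # 16 * 3 * 3 = 144

default_var = conv.weight.var().item()
kaiming_w = torch.empty_like(conv.weight)
nn.init.kaiming_normal_(kaiming_w, mode='fan_in', nonlinearity='relu')

print(f"fan_in = {fan_in}")
print(f"default: var = {default_var:.5f}  vs 1/(3*fan_in) = {1 / (3 * fan_in):.5f}")
print(f"kaiming: var = {kaiming_w.var().item():.5f}  vs 2/fan_in = {2 / fan_in:.5f}")
```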

## Solution Code

### Fast Code Attempt (based on the provided template)

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Load CIFAR-10 dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

# Define the CNN Model: two conv -> ReLU -> max-pool blocks, then two fully connected layers
class CNNModel(nn.Module):
  def __init__(self):
    super().__init__()
    self.network = nn.Sequential(
      nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2, stride=2),   # 32x32 -> 16x16
      nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2, stride=2),   # 16x16 -> 8x8
      nn.Flatten(),
      nn.Linear(32 * 8 * 8, 512),
      nn.ReLU(),
      nn.Linear(512, 10)                        # 10 CIFAR-10 classes
    )

  def forward(self, x):
    return self.network(x)

init_strategies = ['zero', 'random']

# apply(fn) recursively applies fn to every submodule (including self), calling fn(module) for each.
# In this first attempt only the weights are re-initialized; biases keep their default values.
def make_initializer(strategy):
  def init_fn(m):
    if isinstance(m, (nn.Conv2d, nn.Linear)):
      if strategy == 'zero':
        nn.init.zeros_(m.weight)
      elif strategy == 'random':
        nn.init.normal_(m.weight, 0.0, 1.0)
  return init_fn


# Initialize the model, loss function, and optimizer
for init_strategy in init_strategies:
  model = CNNModel()
  criterion = nn.CrossEntropyLoss()
  print("\n")
  print("*"*25, '\n')
  print(f"*** Strategy: {init_strategy.upper()} ***\n")

  print("* Before redefining * \n")
  print("Part of weights: \n\n", model.network[0].weight[0, 0, :2, :2], "\n")
  print("Sum of weights: ", model.network[0].weight.sum())
  print("\n")

  model.apply(make_initializer(init_strategy))

  print("* After redefining * \n")
  print("Part of weights: \n\n", model.network[0].weight[0, 0, :2, :2], "\n")
  print("Sum of weights: ", model.network[0].weight.sum())
  print("\n")

  print("*** Training \n")

  optimizer = optim.Adam(model.parameters(), lr=0.001)

  # Training loop
  epochs = 10
  for epoch in range(epochs):
    for images, labels in train_loader:
      # Forward pass
      outputs = model(images)
      loss = criterion(outputs, labels)

      # Backward pass and optimization
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

    print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}")


  # Evaluate on the test set
  correct = 0
  total = 0
  with torch.no_grad():
    for images, labels in test_loader:
      outputs = model(images)
      _, predicted = torch.max(outputs, 1)
      total += labels.size(0)
      correct += (predicted == labels).sum().item()

  print(f"Test Accuracy: {100 * correct / total:.2f}%")
```



```text
*************************

*** Strategy: ZERO ***

* Before redefining *

Part of weights:

 tensor([[-0.1416, -0.0199],
        [ 0.0670, -0.0019]], grad_fn=<SliceBackward0>)

Sum of weights:  tensor(2.2325, grad_fn=<SumBackward0>)


* After redefining *

Part of weights:

 tensor([[0., 0.],
        [0., 0.]], grad_fn=<SliceBackward0>)

Sum of weights:  tensor(0., grad_fn=<SumBackward0>)


*** Training

Epoch [1/10], Loss: 2.3021
Epoch [2/10], Loss: 2.3059
Epoch [3/10], Loss: 2.3018
Epoch [4/10], Loss: 2.2996
Epoch [5/10], Loss: 2.3065
Epoch [6/10], Loss: 2.3032
Epoch [7/10], Loss: 2.3006
Epoch [8/10], Loss: 2.3030
Epoch [9/10], Loss: 2.3022
Epoch [10/10], Loss: 2.3032
Test Accuracy: 10.00%


*************************

*** Strategy: RANDOM ***

* Before redefining *

Part of weights:

 tensor([[0., 0.],
        [0., 0.]], grad_fn=<SliceBackward0>)

Sum of weights:  tensor(0., grad_fn=<SumBackward0>)


* After redefining *

Part of weights:

 tensor([[ 0.2839, -0.3263],
```

### Full solution code

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

results = {}

# Prepare init strategies.
# 'default' keeps PyTorch's built-in init: nn.Conv2d and nn.Linear both use kaiming_uniform_(a=math.sqrt(5)).
init_strategies = ['default', 'zero', 'random', 'xavier', 'kaiming']

# Load CIFAR-10 dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

# Define the CNN Model: two conv -> ReLU -> max-pool blocks, then two fully connected layers
class CNNModel(nn.Module):
  def __init__(self):
    super().__init__()
    self.network = nn.Sequential(
      nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2, stride=2),   # 32x32 -> 16x16
      nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2, stride=2),   # 16x16 -> 8x8
      nn.Flatten(),
      nn.Linear(32 * 8 * 8, 512),
      nn.ReLU(),
      nn.Linear(512, 10)                        # 10 CIFAR-10 classes
    )

  def forward(self, x):
    return self.network(x)

# apply(fn) recursively applies fn to every submodule (including self), calling fn(module) for each.
# Biases are set to 0 for every explicit strategy (common practice): the random weights
# already break symmetry, so the biases do not need to.
def make_initializer(strategy):
  def init_fn(m):
    if isinstance(m, (nn.Conv2d, nn.Linear)):
      if strategy == 'default':
        pass  # keep PyTorch's built-in initialization
      elif strategy == 'zero':
        nn.init.zeros_(m.weight)
        if m.bias is not None:
          nn.init.zeros_(m.bias)
      elif strategy == 'random':
        nn.init.normal_(m.weight, 0.0, 1.0)   # N(0, 1), no scaling by fan_in
        if m.bias is not None:
          nn.init.zeros_(m.bias)
      elif strategy == 'xavier':
        nn.init.xavier_normal_(m.weight)      # N(0, 2 / (fan_in + fan_out))
        if m.bias is not None:
          nn.init.zeros_(m.bias)
      elif strategy == 'kaiming':
        nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')  # N(0, 2 / fan_in)
        if m.bias is not None:
          nn.init.zeros_(m.bias)
  return init_fn

def train_and_evaluate(model, train_loader, test_loader, criterion, optimizer, epochs=10):
  loss_history = []
  for epoch in range(epochs):
    model.train()
    epoch_losses = []
    for images, labels in train_loader:
      outputs = model(images)
      loss = criterion(outputs, labels)

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      epoch_losses.append(loss.item())

    # Average training loss over the epoch (less noisy than the last batch's loss)
    avg_loss = sum(epoch_losses) / len(epoch_losses)
    loss_history.append(avg_loss)
    print(f"Epoch [{epoch + 1}/{epochs}], Loss: {avg_loss:.4f}")

  # Evaluate
  model.eval()
  correct = 0
  total = 0
  with torch.no_grad():
    for images, labels in test_loader:
      outputs = model(images)
      _, predicted = torch.max(outputs, 1)
      total += labels.size(0)
      correct += (predicted == labels).sum().item()

  accuracy = 100 * correct / total
  return loss_history, accuracy

for init_strategy in init_strategies:
  print("\n" + "="*40)
  print(f"*** Strategy: {init_strategy.upper()} ***")
  print("="*40)

  model = CNNModel()  # a NEW model for each strategy!

  print("* Before redefining * \n")
  print("Part of weights: \n\n", model.network[0].weight[0, 0, :2, :2], "\n")
  print("Sum of weights: ", model.network[0].weight.sum().item())
  print("\n")

  model.apply(make_initializer(init_strategy))

  print("* After redefining * \n")
  print("Part of weights: \n\n", model.network[0].weight[0, 0, :2, :2], "\n")
  print("Sum of weights: ", model.network[0].weight.sum().item())
  print("\n")

  criterion = nn.CrossEntropyLoss()
  # The optimizer is created after re-initialization, for the freshly initialized parameters
  optimizer = optim.Adam(model.parameters(), lr=0.001)

  loss_history, accuracy = train_and_evaluate(model, train_loader, test_loader, criterion, optimizer, epochs=8)

  results[init_strategy] = {
      'loss_history': loss_history,
      'accuracy': accuracy
  }

  print(f"\nFinal Test Accuracy: {accuracy:.2f}%\n")

print("\n" + "="*50)
print("Summary of Initialization Strategies")
print("="*50)
print(f"{'Strategy':<10} | {'Final Loss':<12} | {'Accuracy':<10}")
print("-" * 50)
for strategy, data in results.items():
    final_loss = data['loss_history'][-1]
    accuracy = data['accuracy']
    print(f"{strategy:<10} | {final_loss:<12.4f} | {accuracy:<9.2f}%")
print("="*50)
```


```text
*** Strategy: DEFAULT ***
* Before redefining *

Part of weights:

 tensor([[ 0.1077, -0.1501],
        [-0.0384,  0.0809]], grad_fn=<SliceBackward0>)

Sum of weights:  -1.3259371519088745


* After redefining *

Part of weights:

 tensor([[ 0.1077, -0.1501],
        [-0.0384,  0.0809]], grad_fn=<SliceBackward0>)

Sum of weights:  -1.3259371519088745


Epoch [1/8], Loss: 1.3928
Epoch [2/8], Loss: 1.0373
Epoch [3/8], Loss: 0.8621
Epoch [4/8], Loss: 0.7250
Epoch [5/8], Loss: 0.6030
Epoch [6/8], Loss: 0.4871
Epoch [7/8], Loss: 0.3760
Epoch [8/8], Loss: 0.2770

Final Test Accuracy: 71.13%


*** Strategy: ZERO ***
* Before redefining *

Part of weights:

 tensor([[-0.1660, -0.0113],
        [ 0.0037, -0.0995]], grad_fn=<SliceBackward0>)

Sum of weights:  0.19332683086395264


* After redefining *

Part of weights:

 tensor([[0., 0.],
        [0., 0.]], grad_fn=<SliceBackward0>)

Sum of weights:  0.0


Epoch [1/8], Loss: 2.3027
Epoch [2/8], Loss: 2.3027
Epoch [3/8], Loss: 2.3027
```

## Results

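| Strategy | Final train loss (epoch avg) | Test accuracy | Status |
|----------|------------|----------|--------|
| default  | 0.2770     | 71.13%   | ✅ **Best** |
| zero     | 2.3027     | 10.00%   | 🔴 **Failed** |
| random   | 57.3428    | 28.92%   | 🟡 **Poor** |
| xavier   | 0.1859     | 70.12%   | 🟢 **Good** |
| kaiming  | 0.1984     | 68.98%   | 🟢 **Good** |

---

- Zero initialization is disastrous: 10% accuracy, i.e. random guessing over 10 classes.
- Random N(0, 1) initialization without scaling is also bad: the training loss is huge (57.34) and accuracy reaches only 28.92%.
- Xavier and Kaiming reach almost the same test accuracy as the default, with an even lower final training loss, which may indicate slightly more overfitting rather than worse optimization.
- PyTorch's default shows the best test accuracy in this run. Since this is a single run without a fixed seed, a 1-2 point gap between default, Xavier and Kaiming is likely within run-to-run variation.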

## Why initialization is important in practice, thoughts

Before 2015, very deep networks (more than ~20-30 layers) were hard to train stably. They suffered from two problems:
- Vanishing gradients: gradients became too small
- Exploding gradients: gradients became too large

In early 2015, Kaiming He and colleagues at Microsoft Research proposed an initialization designed specifically for ReLU-family activations; in the same paper, their PReLU network reached 4.94% top-5 error on ImageNet, reported as the first result to surpass the ~5.1% human-level estimate. Later that year, the same team presented ResNet, a convolutional network with up to 152 layers:
- Top-5 error: 3.57% on the ImageNet test set (ensemble)
- Won the ImageNet (ILSVRC) 2015 classification competition

Key innovations of ResNet:
- Skip (identity shortcut) connections, used together with ReLU and BatchNorm
- Residual learning - learns the difference (residual) rather than a direct transformation

From 2010, the standard Xavier/Glorot initialization worked well for sigmoid/tanh networks, but it under-scales ReLU networks, and very deep ReLU models initialized with it could fail to converge.

Kaiming and Xavier are still the basis of production training. It is important to remember that in modern architectures (e.g. Transformers) these schemes are often modified and adapted to specific layers (e.g. linear layers and attention layers may be initialized differently).

### Fine-tuning
Don't reinitialize the embedding layers (or any other pretrained layers): they have already been trained!
For newly added head layers, however, Kaiming or Xavier initialization is a must.
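A minimal sketch of that advice, using a small hypothetical stand-in for a pretrained model (in practice the backbone would be loaded from a checkpoint): the old head is replaced by a fresh `nn.Linear` for a hypothetical 5-class task, only the new head is initialized explicitly, and the pretrained weights are left untouched.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-in for a pretrained network: conv backbone + old 10-class head
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(),
    nn.Linear(16, 10),
)
backbone_before = model[0].weight.detach().clone()

# New head for a hypothetical 5-class task; only this layer gets a fresh init
new_head = nn.Linear(16, 5)
nn.init.xavier_normal_(new_head.weight)  # output layer producing logits (no ReLU after it)
nn.init.zeros_(new_head.bias)
model[-1] = new_head                     # swap the old head for the new one

out = model(torch.randn(2, 3, 32, 32))
print(out.shape)                                       # torch.Size([2, 5])
print(torch.equal(model[0].weight, backbone_before))  # True: backbone untouched
```

Xavier is used for the head here because it outputs logits with no ReLU after it; for a new hidden layer followed by ReLU, `kaiming_normal_(..., nonlinearity='relu')` would be the natural choice.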

## Post-Thinking notes

Even with Kaiming/Xavier initialization and BatchNorm, training ultra-deep networks (hundreds or thousands of layers) remains challenging.
BatchNorm itself introduces limitations: it depends on batch size, breaks down in small-batch regimes (e.g., RL, medical imaging), and is awkward to use in recurrent architectures.

Modern approaches such as Fixup, ReZero, and DeepNorm shift the paradigm:
instead of preserving variance layer by layer (as Xavier/Kaiming do), they are designed to ensure stable signal propagation in the infinite-depth limit.

For example, they initialize residual branches to zero or scale them down with depth (e.g. by `1/√L`), so that the network behaves like an identity map at initialization; this enables training of 1000+ layer models even without normalization layers.

This group of methods is the next evolutionary step after He/Xavier for extreme scenarios.
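A minimal sketch of the zero-initialized residual idea, written in the spirit of ReZero (`y = x + alpha * F(x)` with a learnable scalar `alpha` initialized to 0). `ZeroInitResidualBlock` is a hypothetical name and this is not code from any of the papers. Because `alpha = 0`, a stack of 100 blocks is exactly the identity at initialization, yet each `alpha` still receives a gradient, so the branches can switch on during training.

```python
import torch
import torch.nn as nn

class ZeroInitResidualBlock(nn.Module):
    """Hypothetical ReZero-style block: y = x + alpha * F(x), with alpha initialized to 0."""
    def __init__(self, channels):
        super().__init__()
        self.branch = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        )
        # Kaiming init inside the branch; its output is still gated off by alpha = 0
        for m in self.branch:
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                nn.init.zeros_(m.bias)
        self.alpha = nn.Parameter(torch.zeros(1))  # learnable residual scale

    def forward(self, x):
        return x + self.alpha * self.branch(x)

torch.manual_seed(0)
deep = nn.Sequential(*[ZeroInitResidualBlock(8) for _ in range(100)])
x = torch.randn(2, 8, 16, 16)
y = deep(x)
print(torch.equal(y, x))                 # True: identity map at initialization

y.pow(2).mean().backward()
print(deep[0].alpha.grad.item() != 0)    # True: alpha can still learn
```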

## Interesting Publications

- The task, without a solution, was taken from [here](https://github.com/Exorust/TorchLeet/tree/main/torch/medium)

### Kaiming He, ResNet
- [Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification. (Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun)](https://arxiv.org/abs/1502.01852)
- [Deep Residual Learning for Image Recognition. (Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun)](https://arxiv.org/abs/1512.03385)
- [Identity Mappings in Deep Residual Networks. (Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun)](https://arxiv.org/abs/1603.05027)

### Useful niche solutions
- Fixup (Training without BatchNorm, batch size < 8) -> RL without normalization
  - [Fixup Initialization: Residual Learning Without Normalization. (Hongyi Zhang, Yann N. Dauphin, Tengyu Ma)](https://arxiv.org/abs/1901.09321)
- T-Fixup (Transformers without LayerNorm) -> Transformers on edge devices
- LayerScale (Weak residual links in ViTs) -> ViT research