# Problem: Implement parameter initialization strategies for a CNN model in PyTorch

### Problem Statement

You are tasked with implementing and evaluating parameter initialization strategies for a CNN model in PyTorch.
Your goal is to initialize the weights and biases of the vanilla CNN model provided below and comment on the implications of each strategy.

### Requirements

1. **Initialize** weights and biases in the following ways:
   - **Zero Initialization**: set all weights and biases to zero
   - **Random Initialization**: set them to random values drawn from a standard normal distribution (mean=0, variance=1)
   - **Xavier Initialization**: set weights to random values from a normal distribution with **mean=0 and variance=2/(n_in + n_out)**, which reduces to **1/n** when n_in = n_out = n
   - **Kaiming He Initialization**: set weights to random values from a normal distribution with **mean=0 and variance=2/n**, where n is the layer's fan-in (number of inputs feeding each output unit)
2. Train and compute accuracy for each strategy
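
As a quick sanity check on the variance formulas above, the sketch below initializes a standalone `nn.Conv2d(32, 64, 3)` layer (the same shape as `conv2` in the model further down) with `torch.nn.init` and compares the measured weight standard deviation with each formula. It assumes PyTorch's convention that a conv layer's fan-in is `in_channels * kH * kW` and its fan-out is `out_channels * kH * kW`; note that PyTorch's `xavier_normal_` uses variance `2/(fan_in + fan_out)`.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(32, 64, kernel_size=3)  # same shape as conv2 in VanillaCNNModel
kh, kw = conv.kernel_size
fan_in = conv.in_channels * kh * kw    # 32 * 3 * 3 = 288
fan_out = conv.out_channels * kh * kw  # 64 * 3 * 3 = 576

# Xavier (Glorot) normal: variance 2 / (fan_in + fan_out)
nn.init.xavier_normal_(conv.weight)
xavier_std = conv.weight.std().item()

# Kaiming (He) normal for ReLU: variance 2 / fan_in
nn.init.kaiming_normal_(conv.weight, mode="fan_in", nonlinearity="relu")
kaiming_std = conv.weight.std().item()

print(f"Xavier : measured {xavier_std:.4f}, formula {math.sqrt(2 / (fan_in + fan_out)):.4f}")
print(f"Kaiming: measured {kaiming_std:.4f}, formula {math.sqrt(2 / fan_in):.4f}")
# The factor 2 in the Kaiming variance is the square of PyTorch's ReLU gain, sqrt(2)
print(f"ReLU gain: {nn.init.calculate_gain('relu'):.4f}")
```

Small differences between the measured and formula values are sampling noise from the finite number of weights.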

### Constraints

- Use the given CNN model and the provided `train_test_loop` helper, which trains the model and computes its test accuracy.
- Ensure the model is compatible with the CIFAR-10 dataset, which contains 10 classes.

<details>
  <summary>💡 Hint</summary>
  - Use `torch.nn.init` for weight initialization
  <br>
  - Resources to read: [All you need is a good init](https://arxiv.org/pdf/1511.06422)
</details>
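
To see why the variance matters, the sketch below pushes unit-variance inputs through a stack of `Linear` + ReLU layers initialized with each of the four strategies and reports the spread of the last layer's pre-activations. The width (256), depth (10), batch size, and the use of a bias-free MLP instead of the CNN are illustrative assumptions, not part of the exercise.

```python
import torch
import torch.nn as nn


def make_init(kind):
    # Weight-only versions of the four strategies (biases are omitted in this sketch)
    def init(w):
        if kind == "zeros":
            nn.init.zeros_(w)
        elif kind == "random":
            nn.init.normal_(w, mean=0.0, std=1.0)  # standard normal
        elif kind == "xavier":
            nn.init.xavier_normal_(w)  # variance 2 / (fan_in + fan_out)
        elif kind == "kaiming":
            nn.init.kaiming_normal_(w, nonlinearity="relu")  # variance 2 / fan_in
        else:
            raise ValueError(f"unknown init {kind!r}")
    return init


@torch.no_grad()
def preactivation_std(kind, width=256, depth=10, batch=512, seed=0):
    torch.manual_seed(seed)
    init = make_init(kind)
    a = torch.randn(batch, width)  # unit-variance input
    for _ in range(depth):
        layer = nn.Linear(width, width, bias=False)
        init(layer.weight)
        z = layer(a)        # pre-activation
        a = torch.relu(z)   # ReLU zeroes roughly half of the signal
    return z.std().item()   # spread of the final layer's pre-activations


for kind in ["zeros", "random", "xavier", "kaiming"]:
    print(f"{kind:8s} pre-activation std after 10 layers: {preactivation_std(kind):.3g}")
```

Under these assumptions one would expect zero init to collapse every activation to 0, N(0, 1) weights to grow the spread by roughly sqrt(n/2) per layer, Xavier to shrink it by about 1/sqrt(2) per layer (ReLU discards half the variance), and Kaiming to keep it roughly constant. Keeping activation variance stable across layers is also the goal of the linked paper.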


```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
```



```python
# ToTensor scales pixels to [0, 1]; Normalize with per-channel mean 0.5 and std 0.5 maps them to [-1, 1]
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

train_dataset = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

test_dataset = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transform
)
# Shuffling is unnecessary for evaluation: accuracy does not depend on batch order
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
```


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz




Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


```python
def train_test_loop(model, train_loader, test_loader, epochs=10):
    """Train `model` with Adam, then print and return its test accuracy (in %)."""
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # --- Training ---
    model.train()
    for epoch in range(epochs):
        for image, label in train_loader:
            pred = model(image)
            loss = criterion(pred, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Note: this is the loss of the epoch's last mini-batch, not the epoch average
        print(f"Training loss at epoch {epoch} = {loss.item()}")

    # --- Evaluation ---
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for image_test, label_test in test_loader:
            pred_test = model(image_test)
            # The predicted class is the index of the largest logit
            pred_test_vals = pred_test.argmax(dim=1)
            total += label_test.size(0)
            correct += (pred_test_vals == label_test).sum().item()
    accuracy = 100 * correct / total
    print(f"Test Accuracy = {accuracy}")
    return accuracy
```
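
The loss printed by `train_test_loop` belongs to the last mini-batch of each epoch, so it fluctuates a lot from epoch to epoch. A hypothetical `train_one_epoch` helper (not part of the exercise) that returns the sample-weighted average loss over the whole epoch is sketched below. The check at the end uses a tiny synthetic dataset and a learning rate of 0, so the parameters never move and the epoch average must equal the full-dataset loss.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model, loader, criterion, optimizer):
    """Hypothetical helper: one training epoch, returns the mean loss per sample."""
    model.train()
    running_loss, seen = 0.0, 0
    for images, labels in loader:
        loss = criterion(model(images), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # CrossEntropyLoss averages within a batch, so weight by batch size
        # (the last batch may be smaller than the others)
        running_loss += loss.item() * labels.size(0)
        seen += labels.size(0)
    return running_loss / seen


# Sanity check on synthetic data: 100 samples in batches of 32 (last batch has 4)
torch.manual_seed(0)
X = torch.randn(100, 20)
y = torch.randint(0, 10, (100,))
loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=False)
model = nn.Linear(20, 10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.0)  # lr=0: weights never change

avg = train_one_epoch(model, loader, criterion, optimizer)
with torch.no_grad():
    full = criterion(model(X), y).item()
print(f"epoch average = {avg:.6f}, full-dataset loss = {full:.6f}")
```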

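```python
class VanillaCNNModel(nn.Module):
    def __init__(self):
        super().__init__()
        # 3x32x32 input -> 32 feature maps (kernel 3 with padding 1 keeps 32x32)
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        # 32 -> 64 feature maps, still 32x32
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        # 2x2 max pooling halves the spatial size: 32x32 -> 16x16
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 16 * 16, 128)
        self.fc2 = nn.Linear(128, 10)  # one logit per CIFAR-10 class
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)  # flatten to (batch, 64 * 16 * 16)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)  # raw logits; CrossEntropyLoss applies log-softmax itself
        return x
```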
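
The `64 * 16 * 16` input size of `fc1` follows from the layer shapes: padding 1 with a 3x3 kernel preserves the 32x32 resolution, and the single 2x2 pooling halves it. The sketch below rebuilds just the convolutional part of the model with `nn.Sequential` (a standalone copy, so the block runs on its own) and derives that number from a dummy forward pass instead of computing it by hand.

```python
import torch
import torch.nn as nn

# Standalone copy of VanillaCNNModel's feature extractor
features = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1), nn.ReLU(),   # 32x32 -> 32x32
    nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), nn.ReLU(),  # 32x32 -> 32x32
    nn.MaxPool2d(kernel_size=2, stride=2),                             # 32x32 -> 16x16
)

dummy = torch.zeros(1, 3, 32, 32)  # one CIFAR-10-sized image
with torch.no_grad():
    flat_size = features(dummy).flatten(1).size(1)

print(flat_size)  # 16384, i.e. 64 * 16 * 16
```

Deriving the size this way avoids silent shape mismatches if the input resolution or the conv stack is changed later.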

```python
def config_init(init_type="kaiming"):
    """Return a function for `nn.Module.apply` that initializes Conv2d and Linear layers."""

    # Kaiming (He) normal: std = sqrt(2 / fan_in), suited to ReLU networks; biases zeroed
    def kaiming_init(m):
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu")
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    # Xavier (Glorot) normal: std = sqrt(2 / (fan_in + fan_out)); biases zeroed
    def xavier_init(m):
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    # Every weight and bias set to zero
    def zeros_init(m):
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            nn.init.zeros_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    # Every weight and bias drawn from a standard normal N(0, 1)
    def random_init(m):
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            nn.init.normal_(m.weight, mean=0.0, std=1.0)
            if m.bias is not None:
                nn.init.normal_(m.bias, mean=0.0, std=1.0)

    initializer_dict = {
        "kaiming": kaiming_init,
        "xavier": xavier_init,
        "zeros": zeros_init,
        "random": random_init,
    }

    # Fail loudly instead of passing None to model.apply()
    if init_type not in initializer_dict:
        raise ValueError(
            f"Unknown init_type {init_type!r}; expected one of {list(initializer_dict)}"
        )
    return initializer_dict[init_type]
```
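
One implication of zero initialization can be checked directly: with all weights and biases at zero, every hidden activation is exactly zero, so the gradient of every weight is zero as well. The sketch below uses a reduced stand-in network with the same layer types as `VanillaCNNModel` (8 conv channels and 16 hidden units are assumptions chosen to keep it small), zero-initializes it, and inspects the gradients after one backward pass on random inputs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Reduced stand-in for VanillaCNNModel (hypothetical sizes, same layer types)
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU(),  # index 0, 1
    nn.MaxPool2d(2), nn.Flatten(),                        # index 2, 3
    nn.Linear(8 * 16 * 16, 16), nn.ReLU(),                # index 4, 5
    nn.Linear(16, 10),                                    # index 6
)

# Zero initialization, as in zeros_init
for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.zeros_(m.weight)
        nn.init.zeros_(m.bias)

x = torch.randn(4, 3, 32, 32)
y = torch.tensor([0, 1, 2, 3])
loss = nn.CrossEntropyLoss()(model(x), y)  # all logits are 0, so loss = log(10)
loss.backward()

for name, p in model.named_parameters():
    print(f"{name:9s} max |grad| = {p.grad.abs().max().item():.4f}")
```

Only the output bias (`6.bias`) receives a nonzero gradient. After an optimizer step the hidden activations and the downstream weights are still zero, so the same holds at every later step. By this reasoning a zero-initialized model can at best learn a constant prediction, which would give about 10% accuracy on the class-balanced CIFAR-10 test set.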

```python
for name, model in zip(
    ["Vanilla", "Kaiming", "Xavier", "Zeros", "Random"],
    [
        VanillaCNNModel(),
        VanillaCNNModel().apply(config_init("kaiming")),
        VanillaCNNModel().apply(config_init("xavier")),
        VanillaCNNModel().apply(config_init("zeros")),
        VanillaCNNModel().apply(config_init("random")),
    ],
):
    print(f"_________{name}_______________________")
    train_test_loop(model, train_loader, test_loader)
```

_________Vanilla_______________________
Training loss at epoch 0 = 1.0964444875717163
Training loss at epoch 1 = 0.6543671488761902
Training loss at epoch 2 = 0.6886001229286194
Training loss at epoch 3 = 0.153495192527771
Training loss at epoch 4 = 0.5994948148727417
Training loss at epoch 5 = 0.16613924503326416
Training loss at epoch 6 = 0.14586687088012695
Training loss at epoch 7 = 0.23620009422302246
Training loss at epoch 8 = 0.1030372679233551
Training loss at epoch 9 = 0.08064061403274536
Test Accuracy = 67.84
_________Kaiming_______________________
Training loss at epoch 0 = 0.8084465265274048
Training loss at epoch 1 = 0.48890042304992676
Training loss at epoch 2 = 0.9989274144172668
Training loss at epoch 3 = 0.801313579082489
Training loss at epoch 4 = 0.6529940962791443
Training loss at epoch 5 = 0.19916988909244537
Training loss at epoch 6 = 0.03064189851284027
Training loss at epoch 7 = 0.20245900750160217
Training loss at epoch 8 = 0.04117992892861366
Training loss at 