W2DAY10_Weight Initialization

Goal:
- Initialize network weights using torch.nn.init functions such as xavier_uniform_ and kaiming_normal_.

Description:
- This notebook shows how to manually apply different weight initialization techniques (e.g., Xavier, He) to PyTorch models in order to improve convergence.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

In [2]:
# Transform applied to every MNIST sample when it is read from the dataset.
transform = transforms.ToTensor()

# Both splits share the same storage root, download flag and transform;
# only the `train` flag differs.
mnist_args = dict(root='data', download=True, transform=transform)
train_ds = datasets.MNIST(train=True, **mnist_args)
test_ds = datasets.MNIST(train=False, **mnist_args)

# Training batches are reshuffled on every pass; evaluation keeps dataset order.
train_dl = DataLoader(dataset=train_ds, batch_size=64, shuffle=True)
test_dl = DataLoader(dataset=test_ds, batch_size=300, shuffle=False)

100%|██████████| 9.91M/9.91M [00:00<00:00, 17.7MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 483kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.43MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 5.79MB/s]


In [3]:
class InitNet(nn.Module):
    """Fully connected classifier that demonstrates manual weight initialization.

    Each layer is initialized with a different ``torch.nn.init`` routine so the
    techniques can be seen side by side:

    * ``fc1`` weight  - ``xavier_uniform_``
    * ``fc2`` weight  - ``kaiming_normal_``
    * ``out`` weight  - ``constant_`` (0.1)
    * every bias      - ``zeros_``

    Args:
        in_features: Number of values per flattened input sample.
        hidden1: Width of the first hidden layer.
        hidden2: Width of the second hidden layer.
        num_classes: Number of output logits.

    The defaults reproduce the original 784 -> 256 -> 128 -> 10 network.
    """

    def __init__(self, in_features=28 * 28, hidden1=256, hidden2=128, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.out = nn.Linear(hidden2, num_classes)
        self._init_weights()

    def _init_weights(self):
        """Overwrite the default ``nn.Linear`` parameters with the demo initializers."""
        nn.init.xavier_uniform_(self.fc1.weight)
        # fc2 is followed by a ReLU in forward(); 'relu' gives the same sqrt(2)
        # gain as the default ('leaky_relu' with a=0) but states the intent.
        nn.init.kaiming_normal_(self.fc2.weight, nonlinearity='relu')
        # Constant init is kept purely to demonstrate constant_: with every
        # weight equal to 0.1 and zero bias, all class logits start out identical.
        nn.init.constant_(self.out.weight, 0.1)

        for layer in (self.fc1, self.fc2, self.out):
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Flatten ``x`` to ``(-1, in_features)`` and return the raw class logits.

        No softmax is applied here; the values returned are the direct output
        of the final linear layer.
        """
        # reshape (rather than view) also accepts non-contiguous inputs, e.g.
        # tensors produced by permute/transpose, copying only when required.
        x = x.reshape(-1, self.fc1.in_features)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.out(x)

In [5]:
# Seed PyTorch's global RNG before the model is built so the random parts of
# the weight initialization (xavier_uniform_, kaiming_normal_) are repeatable
# when the notebook is re-run from a fresh kernel.
SEED = 42
torch.manual_seed(SEED)

LEARNING_RATE = 0.001

model = InitNet()
# CrossEntropyLoss is fed the raw logits returned by InitNet.forward.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [6]:
from tqdm import tqdm

In [8]:
# Number of full passes over the training set.
NUM_EPOCHS = 7

for epoch in range(NUM_EPOCHS):
    model.train()
    loss_sum = 0.0      # sum of per-sample losses seen this epoch
    samples_seen = 0    # number of samples contributing to loss_sum
    pbar = tqdm(train_dl, desc=f'Epoch {epoch + 1}', leave=False)
    for image, label in pbar:
        optimizer.zero_grad()
        output = model(image)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

        # `loss` is the mean over the batch; weight it by the batch size so a
        # smaller final batch does not count as much as a full one.
        batch_loss = loss.item()
        batch_size = label.size(0)
        loss_sum += batch_loss * batch_size
        samples_seen += batch_size
        pbar.set_postfix(loss=batch_loss)

    avg_loss = loss_sum / samples_seen
    print(f"Epoch {epoch+1}, Avg Loss: {avg_loss:.4f}")



Epoch 1, Avg Loss: 0.1120




Epoch 2, Avg Loss: 0.0751




Epoch 3, Avg Loss: 0.0538




Epoch 4, Avg Loss: 0.0389




Epoch 5, Avg Loss: 0.0332




Epoch 6, Avg Loss: 0.0247


                                                                       

Epoch 7, Avg Loss: 0.0205


