In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [17]:
transform = transforms.Compose([
    transforms.ToTensor(),
])
train_data = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)
train_loader = DataLoader(
    train_data,
    batch_size=64,
    shuffle=True
)

In [32]:
class InitWeightNet(nn.Module):
    def __init__(self):
        super(InitWeightNet, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.output = nn.Linear(64, 10)
        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.kaiming_normal_(self.fc2.weight)
        nn.init.constant_(self.output.weight, 0.01)
        
        nn.init.zeros_(self.fc1.bias)
        nn.init.zeros_(self.fc2.bias)
        nn.init.zeros_(self.output.bias)
        
    def forward(self, x):
        x = x.view(-1, 28 * 28)  # Flatten the input
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.output(x)
        return x

In [33]:
model = InitWeightNet()
criterion = nn.CrossEntropyLoss()  # Loss function
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # Optimizer

In [34]:
e = []
l = []
for epoch in range(10):
    for images, labels in train_loader:
        # print(labels)
        
        outputs = model(images) # predict : Forward pass
        # print(outputs)
        # print(outputs.shape, labels.shape)
        loss = criterion(outputs, labels) # calculate loss
        
        optimizer.zero_grad() # clear previous gradients
        loss.backward() # backpropagation: compute gradients
        optimizer.step() # update weights using gradients

    e.append(epoch)
    l.append(loss.item())
    print(f'Epoch {epoch+1}, Loss: {loss.item():.4f}')
       

Epoch 1, Loss: 0.1530
Epoch 2, Loss: 0.0475
Epoch 3, Loss: 0.0233
Epoch 4, Loss: 0.0666
Epoch 5, Loss: 0.0626
Epoch 6, Loss: 0.1319
Epoch 7, Loss: 0.0161
Epoch 8, Loss: 0.0065
Epoch 9, Loss: 0.0712
Epoch 10, Loss: 0.0014
