In [27]:
%matplotlib inline
import os
import sys
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import ConcatDataset
import torch.optim as optim
from torchmetrics import Accuracy
from torchinfo import summary
from torch.utils.tensorboard import SummaryWriter


import matplotlib.pyplot as plt
import numpy as np

# Settings 

torch.set_printoptions(precision=3)

In [28]:
# useful constants
BATCH_SIZE = 128

In [29]:
transform_normalize = transforms.Compose([
    # transforms.Resize(32),  # Resize the images to 224x224
    transforms.ToTensor(),  # Convert to tensor
    transforms.Normalize((0.1307), (0.30811))  # Normalize with MNIST's mean and std for each channel
])

mnist_trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                download=True, transform=transform_normalize)

mnist_testset = torchvision.datasets.MNIST(root='./data', train=False,
                                 download=True, transform=transform_normalize)

classes = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')

train_dataset, val_dataset = torch.utils.data.random_split(mnist_trainset, [50000, 10000])

trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=2)

testloader = torch.utils.data.DataLoader(mnist_testset, batch_size=BATCH_SIZE,
                                            shuffle=False, num_workers=2)

valloader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE,
                                            shuffle=False, num_workers=2)

In [30]:
class ModifiedVGG16(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedVGG16, self).__init__()
        
        self.features = nn.Sequential(
            # Adjust the first conv layer to accept 1-channel input
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                init.constant_(m.bias, 0)

In [31]:
class VGG16_MNIST(nn.Module):
    def __init__(self):
        super(VGG16_MNIST, self).__init__()
        
        self.features = nn.Sequential(
            # Block 1
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # Block 2
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.classifier = nn.Sequential(
            nn.Linear(128 * 7 * 7, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 10)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [33]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
accuracy = Accuracy(task='multiclass', num_classes=10)

accuracy = accuracy.to(device)
model_vgg16 = VGG16_MNIST().to(device)

criterion = nn.CrossEntropyLoss()
optimizer_adam = optim.Adam(model_vgg16.parameters(), lr=1e-4)

optimizer_sgd = torch.optim.SGD(model_vgg16.parameters(), lr=0.01, momentum=0.9)

In [34]:
summary(model=model_vgg16, input_size=(1, 1, 28, 28), col_width=17,
                  col_names=['input_size', 'output_size', 'num_params', 'trainable'], row_settings=['var_names'], verbose=0)

Layer (type (var_name))                  Input Shape       Output Shape      Param #           Trainable
VGG16_MNIST (VGG16_MNIST)                [1, 1, 28, 28]    [1, 10]           --                True
├─Sequential (features)                  [1, 1, 28, 28]    [1, 128, 7, 7]    --                True
│    └─Conv2d (0)                        [1, 1, 28, 28]    [1, 64, 28, 28]   640               True
│    └─ReLU (1)                          [1, 64, 28, 28]   [1, 64, 28, 28]   --                --
│    └─Conv2d (2)                        [1, 64, 28, 28]   [1, 64, 28, 28]   36,928            True
│    └─ReLU (3)                          [1, 64, 28, 28]   [1, 64, 28, 28]   --                --
│    └─MaxPool2d (4)                     [1, 64, 28, 28]   [1, 64, 14, 14]   --                --
│    └─Conv2d (5)                        [1, 64, 14, 14]   [1, 128, 14, 14]  73,856            True
│    └─ReLU (6)                          [1, 128, 14, 14]  [1, 128, 14, 14]  --                --
│  

In [35]:
# Track the loss and accuracy
timestamp = datetime.now().strftime("%Y-%m-%d")
experiment_name = f'VGG16_MNIST_{timestamp}'
model_name = 'Vgg16'
log_dir = os.path.join('runs', timestamp, experiment_name, model_name)
log_writer = SummaryWriter(log_dir=log_dir)

In [36]:
for epoch in range(10):  # loop over the dataset multiple times
    train_loss, train_acc = 0, 0
    cumulative_batch = 0


    for X, y in trainloader:
        
        model_vgg16.train()
        X, y = X.to(device), y.to(device)

        cumulative_batch += BATCH_SIZE
        print(f'Batch {cumulative_batch} / 50000')

        y_pred = model_vgg16(X)

        loss = criterion(y_pred, y)
        train_loss += loss.item()

        acc = accuracy(y_pred, y)
        train_acc += acc
        print(f'Loss: {train_loss}, Accuracy: {acc}')

        optimizer_adam.zero_grad()
        loss.backward()
        optimizer_adam.step()
    
    train_loss /= len(trainloader)
    train_acc /= len(trainloader)

    val_loss, val_acc = 0, 0
    model_vgg16.eval()

    with torch.inference_mode():
        for X, y in valloader:
            X, y = X.to(device), y.to(device)
            y_pred = model_vgg16(X)
            loss = criterion(y_pred, y)
            val_loss += loss.item()
            acc = accuracy(y_pred, y)
            val_acc += acc
            print(f'Val Loss: {val_loss}, Val Accuracy: {acc}')
        
        val_loss /= len(valloader)
        val_acc /= len(valloader)


    log_writer.add_scalars(main_tag="Loss", tag_scalar_dict={"train/loss": train_loss, "val/loss": val_loss}, global_step=epoch)
    log_writer.add_scalars(main_tag="Accuracy", tag_scalar_dict={"train/acc": train_acc, "val/acc": val_acc}, global_step=epoch)
    print(f'Epoch {epoch+1}, Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}')
    print(f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.4f}')

Batch 128 / 50000
Loss: 2.3054158687591553, Accuracy: 0.078125
Batch 256 / 50000
Loss: 4.602585554122925, Accuracy: 0.1484375
Batch 384 / 50000
Loss: 6.888758659362793, Accuracy: 0.265625
Batch 512 / 50000
Loss: 9.168584823608398, Accuracy: 0.21875
Batch 640 / 50000
Loss: 11.444666862487793, Accuracy: 0.171875
Batch 768 / 50000
Loss: 13.693395853042603, Accuracy: 0.296875
Batch 896 / 50000
Loss: 15.918555498123169, Accuracy: 0.3125
Batch 1024 / 50000
Loss: 18.113880395889282, Accuracy: 0.3515625
Batch 1152 / 50000
Loss: 20.311052560806274, Accuracy: 0.25
Batch 1280 / 50000
Loss: 22.476406812667847, Accuracy: 0.3046875
Batch 1408 / 50000
Loss: 24.548532724380493, Accuracy: 0.3984375
Batch 1536 / 50000
Loss: 26.574649572372437, Accuracy: 0.4609375
Batch 1664 / 50000
Loss: 28.540095925331116, Accuracy: 0.515625
Batch 1792 / 50000
Loss: 30.454214334487915, Accuracy: 0.5390625
Batch 1920 / 50000
Loss: 32.247523188591, Accuracy: 0.59375
Batch 2048 / 50000
Loss: 33.872238874435425, Accuracy: 

In [37]:
torch.save(model_vgg16.state_dict(), 'mini_vgg16.pth')