In [1]:
%matplotlib inline
import os
import sys
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
from torchvision import transforms, datasets
from torch.utils.data import ConcatDataset
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torchmetrics import Accuracy
from torchinfo import summary

from torch.utils.tensorboard import SummaryWriter
sys.path.append('../')  

from Models.alexnet import AlexNet

import matplotlib.pyplot as plt
import numpy as np

# Settings

# Print tensors with 3 digits of precision to keep cell outputs compact.
torch.set_printoptions(precision=3)

In [2]:
# Mini-batch size for this experiment.
BATCH_SIZE = 64

In [3]:
class MNISTNoZero(datasets.MNIST):
    """MNIST with every '0' digit removed and the remaining labels shifted down by one.

    Construction arguments are forwarded unchanged to ``datasets.MNIST``.
    After construction ``self.data`` and ``self.targets`` contain only the
    samples whose original label was non-zero, and every remaining label ``d``
    is stored as ``d - 1`` so the labels start at 0.

    Attributes:
        non_zero_indices: list of int positions (in the unfiltered dataset)
            of the samples that were kept, in their original order.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # One vectorised comparison instead of iterating over tens of
        # thousands of 0-d tensors in Python.
        keep_mask = self.targets != 0

        # Kept for callers that want to map back to the unfiltered dataset.
        self.non_zero_indices = torch.nonzero(keep_mask, as_tuple=True)[0].tolist()

        # Boolean-mask indexing preserves the original sample order.
        self.data = self.data[keep_mask]
        self.targets = self.targets[keep_mask] - 1


# Preprocessing applied to every image: resize to 224, convert to a tensor,
# then normalize the single channel with mean 0.5 and std 0.5.
transform_normalize = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Unfiltered training split, kept so the '0' digits remain available.
mnist_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform_normalize,
)

# Training and test splits with the '0' digits filtered out.
mnist_trainset_no_zero = MNISTNoZero(
    root='./data',
    train=True,
    download=True,
    transform=transform_normalize,
)
mnist_testset_no_zero = MNISTNoZero(
    root='./data',
    train=False,
    download=True,
    transform=transform_normalize,
)

# Display names of the remaining digit classes: '1' through '9'.
classes = tuple(str(digit) for digit in range(1, 10))

# Sanity check: list the distinct labels left in each filtered split.
train_labels = mnist_trainset_no_zero.targets.unique()
test_labels = mnist_testset_no_zero.targets.unique()
print("Unique labels in the modified training set:", train_labels)
print("Unique labels in the modified test set:", test_labels)


Unique labels in the modified training set: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])
Unique labels in the modified test set: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])


In [4]:
# Hold out a fixed-size validation subset. The training subset takes whatever
# remains, so the split stays valid if the filtered dataset size changes
# (with the current data this is 45077 train / 9000 validation samples).
VAL_SIZE = 9000
SPLIT_SEED = 42  # fixed seed so the train/validation split is reproducible

train_size = len(mnist_trainset_no_zero) - VAL_SIZE
train_dataset, val_dataset = torch.utils.data.random_split(
    mnist_trainset_no_zero,
    [train_size, VAL_SIZE],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# All loaders share BATCH_SIZE; only the training loader shuffles.
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=0)

testloader = torch.utils.data.DataLoader(mnist_testset_no_zero, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=0)

valloader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE,
                                        shuffle=False, num_workers=0)


In [5]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


In [6]:
# Multiclass accuracy metric for the 9 classes, placed on the compute device.
accuracy = Accuracy(task='multiclass', num_classes=9).to(device)

In [7]:
# AlexNet configured for 1 input channel and 9 output classes,
# then moved to the selected device.
model_alexnet = AlexNet(num_classes=9, channels=1)
model_alexnet = model_alexnet.to(device)

In [8]:
# Classification loss.
criterion = nn.CrossEntropyLoss()

# Two optimizers are created over the same model parameters.
# NOTE(review): the training loop visible in this notebook only steps
# optimizer_adam; optimizer_sgd is not referenced there.
optimizer_adam = optim.Adam(params=model_alexnet.parameters(), lr=1e-4)
optimizer_sgd = optim.SGD(params=model_alexnet.parameters(), lr=1e-3, momentum=0.9)

In [9]:
# Per-layer overview of the network for one 1x224x224 input.
summary(
    model=model_alexnet,
    input_size=(1, 1, 224, 224),
    col_width=20,
    col_names=['input_size', 'output_size', 'num_params', 'trainable'],
    row_settings=['var_names'],
    verbose=0,
)

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
AlexNet (AlexNet)                        [1, 1, 224, 224]     [1, 9]               --                   True
├─Sequential (features)                  [1, 1, 224, 224]     [1, 256, 6, 6]       --                   True
│    └─Conv2d (0)                        [1, 1, 224, 224]     [1, 64, 55, 55]      7,808                True
│    └─ReLU (1)                          [1, 64, 55, 55]      [1, 64, 55, 55]      --                   --
│    └─MaxPool2d (2)                     [1, 64, 55, 55]      [1, 64, 27, 27]      --                   --
│    └─Conv2d (3)                        [1, 64, 27, 27]      [1, 192, 27, 27]     307,392              True
│    └─ReLU (4)                          [1, 192, 27, 27]     [1, 192, 27, 27]     --                   --
│    └─MaxPool2d (5)                     [1, 192, 27, 27]     [1, 192, 13, 13]     --                   --
│    └─Conv2d (6)     

In [10]:
# Track the loss and accuracy
run_started = datetime.now()
timestamp = run_started.strftime("%Y-%m-%d")
experiment_name = f'AlexNet_One_to_Nine_{timestamp}'
model_name = 'AlexNet_v1'

# The date alone is shared by every run started on the same day, so their
# event files would land in one directory and be overlaid in TensorBoard.
# A time-of-day leaf directory keeps each run separate.
run_id = run_started.strftime("%H-%M-%S")
log_dir = os.path.join('runs', timestamp, experiment_name, model_name, run_id)
log_writer = SummaryWriter(log_dir=log_dir)

In [None]:
# Train for a fixed number of epochs, validating and logging after each one.
NUM_EPOCHS = 15
LOG_EVERY_N_BATCHES = 100  # progress interval; printing every batch floods the cell output

num_train_samples = len(trainloader.dataset)
num_train_batches = len(trainloader)

for epoch in range(NUM_EPOCHS):  # loop over the dataset multiple times
    print(f'Epoch {epoch+1}')

    # ---- Training phase ----
    # Set once per epoch; the validation phase below switches to eval mode.
    model_alexnet.train()
    train_loss_sum, train_correct, train_seen = 0.0, 0.0, 0

    for batch_idx, (X, y) in enumerate(trainloader, start=1):
        X, y = X.to(device), y.to(device)
        batch_size = y.size(0)  # the last batch can be smaller than BATCH_SIZE

        y_pred = model_alexnet(X)
        loss = criterion(y_pred, y)

        batch_loss = loss.item()
        batch_acc = accuracy(y_pred, y).item()

        # Both values are per-batch means, so weight them by the batch size
        # to obtain exact per-sample averages at the end of the epoch.
        train_loss_sum += batch_loss * batch_size
        train_correct += batch_acc * batch_size
        train_seen += batch_size

        optimizer_adam.zero_grad()
        loss.backward()
        optimizer_adam.step()

        if batch_idx % LOG_EVERY_N_BATCHES == 0 or batch_idx == num_train_batches:
            print(f'Samples {train_seen} / {num_train_samples}, '
                  f'Batch Loss: {batch_loss:.4f}, Batch Accuracy: {batch_acc:.4f}')

    train_loss = train_loss_sum / train_seen
    train_acc = train_correct / train_seen

    # ---- Validation phase ----
    model_alexnet.eval()
    val_loss_sum, val_correct, val_seen = 0.0, 0.0, 0

    with torch.inference_mode():
        for X, y in valloader:
            X, y = X.to(device), y.to(device)
            batch_size = y.size(0)

            y_pred = model_alexnet(X)
            val_loss_sum += criterion(y_pred, y).item() * batch_size
            val_correct += accuracy(y_pred, y).item() * batch_size
            val_seen += batch_size

    val_loss = val_loss_sum / val_seen
    val_acc = val_correct / val_seen

    # ---- Logging ----
    # Metrics are plain Python floats here, not device tensors.
    log_writer.add_scalars(main_tag="Loss", tag_scalar_dict={"train/loss": train_loss, "val/loss": val_loss}, global_step=epoch)
    log_writer.add_scalars(main_tag="Accuracy", tag_scalar_dict={"train/acc": train_acc, "val/acc": val_acc}, global_step=epoch)
    log_writer.flush()  # make the epoch visible in TensorBoard while training continues
    print(f'Epoch {epoch+1}, Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}')
    print(f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.4f}')