In [None]:
# Import necessary torch and torchvision libraries
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
from torchvision.datasets import CIFAR10

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Pick the best available backend: CUDA GPU first, then Apple-silicon MPS,
# and fall back to the CPU when neither accelerator is present.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [None]:
# Fetch CIFAR-10 into ./data (downloading when needed) and build both splits.
# The generator yields the training split first (train=True), then the test
# split (train=False); every image is converted to a tensor on access.
train_data, test_data = (
    CIFAR10(root='./data',
            train=is_train,
            download=True,
            transform=transforms.ToTensor())
    for is_train in (True, False)
)

In [None]:
# Label metadata exposed by the dataset object.
# class_names is the index -> name lookup: class_names[i] is the name of label i
# (for CIFAR-10, class_names[1] should be 'automobile').
class_names = train_data.classes
# cls_to_idx is the reverse lookup: class name -> integer label.
cls_to_idx = train_data.class_to_idx

print(class_names)
print(cls_to_idx)

In [None]:
# Wrap both splits in DataLoaders that serve mini-batches of 32 samples.
# Only the training loader shuffles; the test loader keeps the dataset order.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=False)

In [None]:
class MyCNN(nn.Module):
    """Convolutional image classifier.

    Three convolutional stages, each made of two 3x3 convolutions (padding 1,
    ReLU after each) followed by a 2x2 max-pool, feed a three-layer fully
    connected head. The first linear layer expects 256*4*4 features, which is
    what a 3 x 32 x 32 input produces after the three pooling steps.

    Args:
        num_classes: Number of output logits produced by the final layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # (in, mid, out) channel counts of the two convolutions in each stage.
        # For a 32x32 input the stage outputs are 64x16x16, 128x8x8, 256x4x4.
        stage_channels = [(3, 32, 64), (64, 128, 128), (128, 256, 256)]

        layers = []
        for c_in, c_mid, c_out in stage_channels:
            layers += [
                nn.Conv2d(c_in, c_mid, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.Conv2d(c_mid, c_out, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
            ]

        # Classifier head operating on the flattened feature map.
        layers += [
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, num_classes),
        ]

        # Kept as a single Sequential named `model` so layer indices (and thus
        # state-dict keys such as "model.0.weight") stay the same.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw class logits for a batch of images `x`."""
        return self.model(x)

# Seed the global RNG before building the network so the randomly initialised
# weights are the same on every fresh run. The only other seeding in this
# notebook happens in the training cell, which runs after the model exists.
torch.manual_seed(42)
firstModel = MyCNN().to(device)

In [None]:
# Metric class used to track classification accuracy
from torchmetrics import Accuracy

# Loss: cross-entropy computed from the model outputs and the integer labels
loss_fn = nn.CrossEntropyLoss()
# Optimizer: Adam over every parameter of firstModel, learning rate 1e-3
optimizer = torch.optim.Adam(firstModel.parameters(), lr=0.001)

# Multiclass accuracy sized to the dataset's label set and moved to `device`.
# NOTE(review): this single metric object is what the training cell passes to
# both train_step and test_step, so its state is shared between them.
accuracy = Accuracy(task="multiclass", num_classes=len(class_names)).to(device)

In [None]:
def train_step(model: torch.nn.Module,
               data_loader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               accuracy,
               device: torch.device = device):
    """Train `model` for one pass over `data_loader` and print the epoch metrics.

    Args:
        model: Network to optimise. It is moved to `device` and switched to
            training mode.
        data_loader: Iterable of `(X, y)` batches. It must support `len()` and
            expose `.dataset` (used for the progress message).
        loss_fn: Callable mapping `(model(X), y)` to a scalar loss tensor.
        optimizer: Optimizer that updates `model`'s parameters.
        accuracy: Stateful metric exposing `reset()`, `update(preds, target)`
            and `compute()` (e.g. a torchmetrics metric). It is reset on entry
            so the reported value covers this pass only, and reset again on
            exit so leftover training state cannot leak into the next user of
            the same metric object.
        device: Device the model and every batch are moved to.

    Returns:
        None. Progress lines and the mean loss / accuracy are printed.
    """
    model.to(device)
    # Mode does not change between batches, so set it once up front.
    model.train()
    # Drop any state left behind by a previous epoch or evaluation pass.
    accuracy.reset()

    train_loss = 0
    for batch, (X, y) in enumerate(data_loader):
        X = X.to(device)
        y = y.to(device)

        # Forward pass and per-batch loss
        y_pred = model(X)
        loss = loss_fn(y_pred, y)
        # Accumulate a detached copy: adding `loss` itself would keep autograd
        # history of every batch alive until the end of the epoch.
        train_loss += loss.detach()
        accuracy.update(y_pred, y)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Print progress every 400 batches
        if batch % 400 == 0:
            print(f"Looked at {batch * len(X)}/{len(data_loader.dataset)} samples")

    # Mean loss per batch; max() avoids a ZeroDivisionError on an empty loader.
    train_loss = train_loss / max(len(data_loader), 1)
    train_acc = accuracy.compute()
    accuracy.reset()
    print(f"Train loss: {train_loss:.5f} | Train accuracy: {train_acc*100:.2f}%")

In [None]:
# tqdm renders a progress bar over the epoch loop
from tqdm.auto import tqdm

# NOTE(review): this seed is set after firstModel was constructed in an earlier
# cell, so it does not influence the initial weights.
torch.manual_seed(42)

epochs = 8

# One training pass followed by one evaluation pass per epoch.
# NOTE(review): test_step is not defined in any cell visible here; confirm it
# is defined before this cell is run.
for epoch in tqdm(range(epochs)):
    print(f"Epoch: {epoch}\n---------")

    train_step(model=firstModel, data_loader=train_loader, loss_fn=loss_fn,
               optimizer=optimizer, accuracy=accuracy)

    test_step(model=firstModel, data_loader=test_loader, loss_fn=loss_fn,
              accuracy=accuracy)