### Standard Imports

In [None]:
import os
import sys

sys.path.append(os.path.join(os.getcwd(), "..", "."))  # add parent dir to path

from typing import Tuple

import lightning as pl
import lightning.pytorch.callbacks as pl_callbacks
import matplotlib.pyplot as plt
import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torchinfo
import torchmetrics
import torchvision
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split, sampler
from torchvision import datasets, models
from torchvision import transforms as T  # for simplifying the transforms
from tqdm.notebook import tqdm

# Custom imports
from config import *
from data import *

# Use STIX font for math plotting
plt.rcParams["font.family"] = "STIXGeneral"

import warnings

import torchvision
from termcolor import colored
from torchvision import transforms

# Silence every Python warning for the rest of the notebook.
# NOTE: this also hides warnings that may be useful while debugging.
warnings.filterwarnings("ignore")

# Load the project configuration and derive all paths from the parent of the
# current working directory.
cfg = get_config()
cfg.root_dir = os.path.join(os.getcwd(), "..")
cfg.data_dir = os.path.join(cfg.root_dir, "data")
cfg.model_dir = os.path.join(cfg.root_dir, "weights")
print(colored("Config:", "green"))
print(cfg.to_yaml())

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(colored("Using device:", "green"), device)

# Seed for reproducibility
# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# seed_everything seeds Python's `random`, NumPy and torch in one call (the
# original seeded only torch and NumPy); workers=True additionally lets
# Lightning derive per-worker seeds for the DataLoader workers.
pl.seed_everything(cfg.seed, workers=True)

### Data Loading and Visualization

In [None]:
# For the preview the images are only converted to tensors -- no resizing or
# normalisation is applied in this cell.
train_transform = transforms.Compose([transforms.ToTensor()])

# CIFAR100 dataset (val_size=0; the other two returned values are unused here)
train_dataset, _, _ = get_cifar100_dataset(cfg.data_dir, train_transform, val_size=0)

# Visualize some images from the dataset
NUM_PREVIEW_IMAGES = 18
PREVIEW_COLUMNS = 6
images = torch.stack(
    [train_dataset[idx][0] for idx in range(NUM_PREVIEW_IMAGES)]
)
grid = torchvision.utils.make_grid(
    images, nrow=PREVIEW_COLUMNS, padding=2, pad_value=1
)

plt.figure(figsize=(8, 8))
# Move the first axis of the grid to the end before handing it to imshow.
plt.imshow(grid.permute(1, 2, 0))
plt.title("CIFAR100 Images")
plt.axis("off")
plt.show()

### Model

In [None]:
# ResNet50 from timm, created with pretrained=True and a 100-class output
model = timm.create_model("resnet50", pretrained=True, num_classes=100)

# Print the model summary for a single 3x224x224 input, two levels deep
summary_input_size = (1, 3, 224, 224)
torchinfo.summary(model, input_size=summary_input_size, depth=2, device="meta")

### Training

In [None]:
# Preprocessing constants shared by the training and evaluation pipelines.
IMAGE_SIZE = (224, 224)
# Per-channel normalisation statistics
# (presumably the CIFAR-100 mean/std -- TODO confirm).
NORM_MEAN = (0.5071, 0.4867, 0.4408)
NORM_STD = (0.2675, 0.2565, 0.2761)


def _build_transform(augment: bool) -> transforms.Compose:
    """Return the resize -> (optional flip) -> tensor -> normalise pipeline.

    Args:
        augment: When True, a random horizontal flip is inserted right after
            the resize step.
    """
    steps = [
        transforms.Resize(
            IMAGE_SIZE, interpolation=transforms.InterpolationMode.BICUBIC
        )
    ]
    if augment:
        steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(mean=list(NORM_MEAN), std=list(NORM_STD)))
    return transforms.Compose(steps)


# Training uses the flip augmentation; evaluation does not.
train_transform = _build_transform(augment=True)
test_transform = _build_transform(augment=False)

# Build the three loaders; val_size=0.1 is forwarded to the project helper.
train_dataloader, val_dataloader, test_dataloader = get_cifar100_loaders(
    cfg.data_dir,
    cfg.batch_size,
    cfg.num_workers,
    train_transform,
    test_transform,
    val_size=0.1,
)

In [None]:
class ImageClassifier(pl.LightningModule):
    """Lightning wrapper that trains an image classification network.

    Training minimises cross-entropy loss. Validation and test batches log
    the loss together with the multiclass accuracy of the arg-max prediction.

    Args:
        model: Network whose output for an image batch is fed to the
            cross-entropy loss (i.e. per-class logits).
        cfg: Configuration object. Only ``cfg.lr`` is read here, via
            attribute access.
    """

    def __init__(self, model: nn.Module, cfg: dict):
        super().__init__()
        self.model = model
        self.cfg = cfg
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x: torch.Tensor):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

    def _evaluate(
        self, batch: Tuple[torch.Tensor, torch.Tensor], stage: str
    ) -> torch.Tensor:
        """Compute and log loss and accuracy for one evaluation batch.

        Shared by ``validation_step`` and ``test_step`` so the two cannot
        drift apart.

        Args:
            batch: ``(inputs, targets)`` pair.
            stage: Metric-name prefix; metrics are logged as
                ``"{stage}_loss"`` and ``"{stage}_acc"``.

        Returns:
            The cross-entropy loss for the batch.
        """
        x, y = batch
        y_hat = self(x)
        loss = self.loss(y_hat, y)

        # calculate accuracy; the class count is taken from the logits
        # (dim 1) instead of being hard-coded, so the module also works for
        # heads with a different number of classes.
        preds = y_hat.argmax(dim=1)
        acc = torchmetrics.functional.accuracy(
            preds, y, num_classes=y_hat.shape[1], task="multiclass"
        )

        self.log(f"{stage}_loss", loss, prog_bar=True)
        self.log(f"{stage}_acc", acc)
        return loss

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
        """Return the cross-entropy loss for a training batch and log it as ``train_loss``."""
        x, y = batch
        y_hat = self(x)
        loss = self.loss(y_hat, y)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
        """Log ``val_loss`` / ``val_acc`` for a validation batch and return the loss."""
        return self._evaluate(batch, "val")

    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
        """Log ``test_loss`` / ``test_acc`` for a test batch and return the loss."""
        return self._evaluate(batch, "test")

    def configure_optimizers(self):
        """Return Adam (learning rate ``cfg.lr``) with a StepLR schedule.

        The scheduler multiplies the learning rate by 0.1 every 30 scheduler
        steps.
        """
        optimizer = optim.Adam(self.parameters(), lr=self.cfg.lr)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
        return [optimizer], [scheduler]

In [None]:
# Colour scheme for the Rich progress bar
progress_colors = {
    "description": "black",
    "progress_bar": "cyan",
    "progress_bar_finished": "green",
    "progress_bar_pulse": "#6206E0",
    "batch_progress": "cyan",
    "time": "grey82",
    "processing_speed": "grey82",
    "metrics": "black",
}
theme = pl_callbacks.progress.rich_progress.RichProgressBarTheme(**progress_colors)

# Create the model: a fresh ResNet50 wrapped in the Lightning module
backbone = timm.create_model("resnet50", pretrained=True, num_classes=100)
model = ImageClassifier(backbone, cfg)

# Create a PyTorch Lightning trainer with the required callbacks.
# The built-in model summary is switched off; the RichModelSummary callback
# is supplied explicitly instead.
trainer_callbacks = [
    pl_callbacks.RichModelSummary(max_depth=3),
    pl_callbacks.RichProgressBar(theme=theme),
]
trainer = pl.Trainer(
    accelerator="auto",
    devices="auto",
    strategy="auto",
    max_epochs=cfg.num_epochs,
    enable_model_summary=False,
    callbacks=trainer_callbacks,
)

# Set PyTorch's float32 matmul precision to "medium" before training starts
torch.set_float32_matmul_precision("medium")

# Train the model
trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

In [None]:
# Save the model
# Create the weights directory first: torch.save does not create missing
# parent directories, and failing here would lose a finished training run.
os.makedirs(cfg.model_dir, exist_ok=True)
# NOTE: `model` is the ImageClassifier wrapper, which stores the network as
# `self.model`, so the saved state-dict keys carry a "model." prefix.
torch.save(model.state_dict(), os.path.join(cfg.model_dir, "resnet50_ft.pth"))

### Evaluation

In [None]:
# Run the test loop on the held-out test loader; the call is left as the last
# expression so the notebook displays the returned metrics.
trainer.test(model, dataloaders=test_dataloader)