## Install dependencies

In [1]:
# Install the third-party packages this notebook imports (quiet mode).
# NOTE(review): versions are unpinned — consider pinning them so a fresh environment reproduces this run.
%pip install -q lightning wandb torchvision torchmetrics

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 24.2 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


## Import packages

In [13]:
from pathlib import Path
import time

import PIL.Image
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchmetrics
import lightning as L
import wandb
from torch.utils.data import Dataset, DataLoader
from lightning.pytorch.loggers import WandbLogger

from jassair.utils import get_dataset_path, Datasets

## Lower matmul precision and seed RNGs

In [14]:
# Fix the global seed through Lightning so runs are repeatable.
L.seed_everything(42)

# Allow reduced-precision float32 matrix multiplications ('high' instead of the default setting).
torch.set_float32_matmul_precision('high')

## WandB login for experiment tracking

In [15]:
# Authenticate with Weights & Biases; the call's return value is shown as the cell output.
wandb.login()

True

## Global variable definition

In [16]:
# Number of output classes of the classification head.
# NOTE(review): presumably one class per card in the deck — confirm against the label files.
NUM_CLASSES = 36

# Mini-batch size shared by the train/validation/test loaders.
BATCH_SIZE = 32

# Root folder of the synthetic single-card dataset (expected to contain train/valid/test splits).
DATA_DIR = get_dataset_path(Datasets.SYNTHETIC_SINGLE)

## Custom synthetic-data Dataset

In [17]:
class YoloStyleDataset(Dataset):
    """Classification dataset stored in a YOLO-style folder layout.

    Expected layout::

        root_dir/
            images/<name>.<ext>
            labels/<name>.txt

    The class index of a sample is the first whitespace-separated token on the
    first line of its label file.

    Args:
        root_dir: Folder that contains the ``images`` and ``labels`` sub-folders.
            May be a ``Path`` or a string.
        transform: Optional callable applied to the loaded RGB ``PIL.Image``.
    """

    def __init__(self, root_dir: Path, transform=None):
        self.root_dir = Path(root_dir)
        self.image_dir = self.root_dir / 'images'
        self.label_dir = self.root_dir / 'labels'
        self.transform = transform
        # ``iterdir`` yields entries in arbitrary, filesystem-dependent order.
        # Sorting makes sample indices stable across machines and runs, and
        # skipping non-files keeps stray sub-directories out of the sample list.
        self.image_files: list[Path] = sorted(
            entry for entry in self.image_dir.iterdir() if entry.is_file()
        )

    def __len__(self):
        """Return the number of image files found in ``images``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            A tuple ``(image, label)`` where ``image`` is the (optionally
            transformed) RGB image and ``label`` is the integer class index.

        Raises:
            FileNotFoundError: If the image has no matching label file.
            ValueError: If the label file is empty or its first token is not an integer.
        """
        image_path: Path = self.image_files[idx]
        # ``convert`` returns an independent image, so the underlying file
        # handle can be released right away instead of waiting for the GC.
        with PIL.Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        label_path = self.label_dir / f"{image_path.stem}.txt"
        if not label_path.exists():
            raise FileNotFoundError(label_path)
        with label_path.open("r", encoding="utf-8") as f:
            fields = f.readline().split()
        if not fields:
            raise ValueError(f"Label file has no class index on its first line: {label_path}")
        label = int(fields[0])

        if self.transform:
            image = self.transform(image)
        return image, label

## Pre-trained Model / Data transform

In [18]:
# Available models: https://pytorch.org/vision/stable/models.html#classification

# Pre-trained checkpoint to start from.
MODEL_WEIGHTS = models.ResNet50_Weights.IMAGENET1K_V2

# Backbone initialised with the weights selected above.
MODEL = models.resnet50(weights=MODEL_WEIGHTS)

# Preprocessing pipeline bundled with the chosen weights; applied to every image in the datasets.
TRANSFORM = MODEL_WEIGHTS.transforms()

## DataLoader definition

In [19]:
# One dataset per split folder, all sharing the same preprocessing.
train_dataset, val_dataset, test_dataset = (
    YoloStyleDataset(DATA_DIR / split_name, transform=TRANSFORM)
    for split_name in ("train", "valid", "test")
)

# Settings common to all three loaders; only the training loader shuffles.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=4)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

# Sanity check: load a single sample and show its tensor shape and class index
image, label = train_dataset[1]
print(f"Image shape: {image.shape}, Label: {label}")

Image shape: torch.Size([3, 224, 224]), Label: 29


## Baseline Model

In [20]:
class ImageClassifier(L.LightningModule):
    """Lightning module that fine-tunes the pre-trained ``MODEL`` for ``NUM_CLASSES`` classes.

    The backbone's final fully connected layer is replaced by a freshly
    initialised linear head with ``NUM_CLASSES`` outputs. Loss and accuracy are
    accumulated per step and logged once per epoch for training and validation.

    Args:
        lr: Upper learning-rate bound of the cyclic schedule; the lower bound is ``lr / 100``.
        weight_decay: Weight decay passed to AdamW.
        finetune_only: If true, freeze every backbone parameter so that only
            the new classification head is trained.
    """

    def __init__(self, lr: float, weight_decay: float, finetune_only: bool):
        super().__init__()
        self.save_hyperparameters()

        # NOTE: this is the module-level MODEL object itself, not a copy. The
        # freezing and head replacement below therefore modify MODEL in place.
        self.model = MODEL

        # If set, only train the newly attached FC layer
        if self.hparams.finetune_only:
            for param in self.model.parameters():
                param.requires_grad = False

        # The head is created after the freeze loop, so it stays trainable.
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features, NUM_CLASSES)

        self.criterion = nn.CrossEntropyLoss()

        # The class count comes from NUM_CLASSES (same value as the head above);
        # there is no ``output_dim`` hyperparameter on this module.
        self._train_acc = torchmetrics.Accuracy("multiclass", num_classes=NUM_CLASSES)
        self._train_loss = []
        self._valid_acc = torchmetrics.Accuracy("multiclass", num_classes=NUM_CLASSES)
        self._valid_loss = []

    def forward(self, x):
        """Return the class logits for a batch of images."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the training loss and update the running epoch statistics."""
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        self._train_acc(outputs, labels)
        # Store a detached copy: keeping the original tensor would hold on to
        # every step's autograd graph until the end of the epoch.
        self._train_loss.append(loss.detach())

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and update the running epoch statistics."""
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        self._valid_acc(outputs, labels)
        self._valid_loss.append(loss.detach())

    def on_train_epoch_end(self):
        """Log mean training loss and accuracy for the epoch, then reset the accumulators."""
        loss = torch.stack(self._train_loss).mean()
        self.log_dict({'train_loss': loss, 'train_acc': self._train_acc.compute()}, prog_bar=True)
        self._train_loss.clear()
        self._train_acc.reset()

    def on_validation_epoch_end(self):
        """Log mean validation loss and accuracy for the epoch, then reset the accumulators."""
        loss = torch.stack(self._valid_loss).mean()
        self.log_dict({'val_loss': loss, 'val_acc': self._valid_acc.compute()}, prog_bar=True)
        self._valid_loss.clear()
        self._valid_acc.reset()

    def test_step(self, batch, batch_idx):
        """Log loss and accuracy of one test batch."""
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        acc = (outputs.argmax(dim=1) == labels).float().mean()

        self.log("test_loss", loss, prog_bar=True, logger=True)
        self.log("test_acc", acc, prog_bar=True, logger=True)

    def configure_optimizers(self):
        """Create the AdamW optimizer and a cyclic learning-rate schedule.

        CyclicLR is meant to be stepped after every batch. Lightning steps
        schedulers once per epoch unless told otherwise, so the scheduler is
        registered with ``interval="step"``.
        """
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        scheduler = optim.lr_scheduler.CyclicLR(optimizer, self.hparams.lr / 100, self.hparams.lr)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }

## Train Model

In [21]:
# Run configuration handed to wandb; the training cell reads these values back by key.
HYPERPARAMETERS = dict(
    lr=1e-3,       # learning rate
    wd=1e-5,       # weight decay
    ft_only=True,  # train only the new classification head
    epochs=100,    # upper bound on training epochs
)

In [22]:
# Run identifier: timestamp, model class name and learning rate.
RUN_NAME = f"{time.strftime('%y%m%d-%H%M%S')}_{MODEL._get_name()}_lr{HYPERPARAMETERS['lr']}"

# Project name shared by the wandb run and the Lightning logger.
_WANDB_PROJECT = "BaselineModel"

# Start the run that records the hyperparameters for this experiment.
wandb.init(
    project=_WANDB_PROJECT,
    entity="jassair",
    config=HYPERPARAMETERS,
    name=RUN_NAME,
)

# Lightning-side logger for the same project.
wandb_logger = WandbLogger(project=_WANDB_PROJECT)

In [None]:
# Save checkpoints ranked by validation accuracy (higher is better),
# and additionally keep the most recent one.
checkpoint_callback = L.pytorch.callbacks.ModelCheckpoint(
    monitor="val_acc",
    mode="max",
    save_last=True,
    dirpath="./lightning_checkpoints",
    filename=RUN_NAME + "-{epoch:02d}-{val_acc:.2f}",
)

In [None]:
# Stop training once validation accuracy has stopped improving.
early_stopping_callback = L.pytorch.callbacks.EarlyStopping(
    monitor="val_acc",
    mode="max",
    patience=10,      # number of non-improving checks tolerated
    min_delta=0.005,  # an improvement must raise accuracy by at least 0.5%
)

In [None]:
# Trainer wired up with checkpointing, early stopping and the wandb logger.
trainer = L.Trainer(
    callbacks=[checkpoint_callback, early_stopping_callback],
    max_epochs=wandb.config.get("epochs"),
    accelerator="auto",
    logger=wandb_logger,
)

# Pass the hyperparameters by keyword. The constructor signature is
# (lr, weight_decay, finetune_only); passing them positionally in the order
# lr, ft_only, wd swapped the last two, giving weight_decay=True and
# finetune_only=1e-5.
model = ImageClassifier(
    lr=wandb.config.get("lr"),
    weight_decay=wandb.config.get("wd"),
    finetune_only=wandb.config.get("ft_only"),
)

trainer.fit(model, train_loader, val_loader)

# Close the wandb run once training has finished.
wandb.finish()

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/opt/conda/lib/python3.12/site-packages/lightning/pytorch/loggers/wandb.py:397: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params | Mode 
-------------------------------------------------------
0 | model     | ResNet           | 23.6 M | train
1 | criterion | CrossEntropyLoss | 0      | train
-------------------------------------------------------
73.8 K    Trainable params
23.5 M    Non-trainable params
23.6 M    Total params
94.327    Total estimated model params size (MB)
152       Modules in train mode
0         Modul

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]