## Install dependencies

In [1]:
%pip install -q lightning wandb torchvision torchmetrics matplotlib

Note: you may need to restart the kernel to use updated packages.


## Import packages

In [2]:
from pathlib import Path
import time

import PIL.Image
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms.v2 as v2
import torchmetrics
import lightning as L
import wandb
from torch.utils.data import Dataset, DataLoader
from lightning.pytorch.loggers import WandbLogger

import matplotlib.pyplot as plt

from jassair.utils import get_dataset_path, Datasets

## Lower matmul precision

## WandB login for experiment tracking

In [3]:
wandb.login()

[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mschurtenberger-david[0m ([33mdavid-schurtenberger[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

## Global variable definition

In [4]:
DATA_DIR = get_dataset_path(Datasets.S_1TO1_36C_NOVLP)
BATCH_SIZE = 32
NUM_CLASSES = 36

## Custom Synth-data Dataset

In [5]:
class YoloStyleDataset(Dataset):
    def __init__(self, root_dir: Path, transform=None):
        self.root_dir = root_dir
        self.image_dir = root_dir / 'images'
        self.label_dir = root_dir / 'labels'
        self.transform = transform
        self.image_files: list[Path] = [f for f in self.image_dir.iterdir()]
        self.labels: list[int] = []
        for file in self.image_files:
            label_file = self.label_dir / f"{file.stem}.txt"
            with label_file.open("r", encoding="utf-8") as f:
                self.labels.append(int(f.readline().split()[0]))

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_path: Path = self.image_files[idx]
        label = self.labels[idx]
        image = PIL.Image.open(image_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

## Baseline Model

In [18]:
class ImageClassifier(L.LightningModule):
    def __init__(self, lr: float, weight_decay: float, finetune_only: bool):
        super().__init__()
        self.save_hyperparameters()
        
        self.model = MODEL
        
        # If set, only train the newly attached FC layer
        if self.hparams.finetune_only:
            for param in self.model.parameters():
                param.requires_grad = False

        match self.model._get_name().lower():
            case "efficientnet":
                in_features = self.model.classifier[1].in_features
                self.model.classifier[1] = nn.Linear(in_features, NUM_CLASSES)
            case "resnet":
                in_features = self.model.fc.in_features
                self.model.fc = nn.Linear(in_features, NUM_CLASSES)

        self.criterion = nn.CrossEntropyLoss()
        
        self._train_acc = torchmetrics.Accuracy("multiclass", num_classes=NUM_CLASSES)
        self._train_loss = []
        self._valid_acc = torchmetrics.Accuracy("multiclass", num_classes=NUM_CLASSES)
        self._valid_loss = []
        self._test_preds = []
        self._test_labels = []

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        self._train_acc(outputs, labels)
        self._train_loss.append(loss)
        
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        self._valid_acc(outputs, labels)
        self._valid_loss.append(loss)
        
    def on_train_epoch_end(self):
        loss = torch.stack(self._train_loss).mean()
        self.log_dict({'train_loss': loss, 'train_acc': self._train_acc.compute()}, prog_bar=True)
        self._train_loss.clear()
        self._train_acc.reset()

    def on_validation_epoch_end(self):
        loss = torch.stack(self._valid_loss).mean()
        self.log_dict({'val_loss': loss, 'val_acc': self._valid_acc.compute()}, prog_bar=True)
        self._valid_loss.clear()
        self._valid_acc.reset()

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)

        self._test_preds.append(preds.detach().cpu())
        self._test_labels.append(y.detach().cpu())

        return {}

    def test_epoch_end(self):
        preds = torch.cat(self._test_preds)
        targets = torch.cat(self._test_labels)

        cm = self._confusion_matrix(preds, targets)

        precision = []
        recall = []
        f1 = []

        for i in range(self.num_classes):
            true_positives = cm[i, i].item()
            false_positives = cm[:, i].sum().item() - true_positives
            false_negatives = cm[i, :].sum().item() - true_positives

            _precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0.0
            _recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0.0
            _f1 = 2 * _precision * _recall / (_precision + _recall) if (_precision + _recall) > 0 else 0.0

            precision.append(_precision)
            recall.append(_recall)
            f1.append(_f1)

            self.log(f'test/precision_class_{i}', _precision)
            self.log(f'test/recall_class_{i}', _recall)
            self.log(f'test/f1_class_{i}', _f1)

        # Overall Accuracy
        acc = (preds == targets).sum().item() / len(targets)
        self.log('test/accuracy', acc)

        # Macro averages
        self.log('test/macro_precision', sum(precision) / self.num_classes)
        self.log('test/macro_recall', sum(recall) / self.num_classes)
        self.log('test/macro_f1', sum(f1) / self.num_classes)

    def _confusion_matrix(self, preds, targets):
        cm = torch.zeros(self.num_classes, self.num_classes, dtype=torch.int32)
        for t, p in zip(targets, preds):
            cm[t, p] += 1
        return cm

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        scheduler = optim.lr_scheduler.CyclicLR(optimizer, self.hparams.lr / 1000, self.hparams.lr)
        return [optimizer], [scheduler]

#### Training routine

In [19]:
def get_run_name(config, ft=False):
    params = []
    for k, v in config.items():
        if isinstance(v, float):
            params.append(f"{k}{v:.4f}")
        else:
            params.append(f"{k}{v}")
    return f"{MODEL_NAME}_{'_'.join(params)}_ft{ft}{time.strftime('%y%m%d%H%M%S')}"

In [20]:
def train_classifier(config, logger, *callbacks):
    L.seed_everything(42)
    model = ImageClassifier(
        wandb.config.get("lr"), 
        wandb.config.get("wd"),
        wandb.config.get("ft_only"),
    )
    trainer = L.Trainer(
        callbacks=list(callbacks),
        max_epochs=wandb.config.get("epochs"),
        accelerator="auto",
        precision='16-mixed',
        logger=logger,
    )
    trainer.fit(model, train_loader, val_loader)

In [21]:
def wandb_train_run(run):
    config = wandb.config
    run.name = get_run_name(config)
    checkpoint_callback = L.pytorch.callbacks.ModelCheckpoint(
        dirpath="./lightning_checkpoints",
        filename=run.name + "_{epoch:02d}_{val_acc:.2f}",
        monitor="val_acc",
        save_last=True,
        mode="max"
    )
    early_stopping_callback = L.pytorch.callbacks.EarlyStopping(
        min_delta=0.005, # Improve at least 0.5%
        monitor="val_acc",
        patience=5,
        mode="max",
    )
    wandb_logger = WandbLogger()
    train_classifier(config, wandb_logger, checkpoint_callback, early_stopping_callback)
    artifact = wandb.Artifact(run.name, "model")
    artifact.add_file(checkpoint_callback.best_model_path)
    wandb.log_artifact(artifact)

In [None]:
def plot_cm(predictions, labels, name):
    cm = torchmetrics.ConfusionMatrix("multiclass", num_classes=5)
    cm.update(predictions, labels)
    fig, ax = cm.plot(labels=["A", "B", "C", "D", "E"], cmap="plasma")
    cbar = fig.colorbar(ax.images[0], ax=ax)
    plt.savefig(f"{name}.png", bbox_inches='tight')

In [None]:
def test_classifier(checkpoint, logger, name):
    L.seed_everything(42)
    model = ImageClassifier.load_from_checkpoint(checkpoint)
    trainer = L.Trainer(
        accelerator="auto",
        logger=logger,
    )
    trainer.test(model, test_loader)
    plot_cm(model._test_preds, model._test_labels, name)

In [None]:
def wandb_test_run(run, checkpoint):
    config = wandb.config
    run.name = get_run_name(config)
    wandb_logger = WandbLogger()
    test_classifier(checkpoint, wandb_logger, run.name)

In [58]:
def classifier_sweep(config=None):
    with wandb.init(config=config) as run:
        wandb_train_run(run)

## EfficientNet

#### Load model and weights

In [23]:
MODEL_WEIGHTS = models.EfficientNet_V2_S_Weights.IMAGENET1K_V1
TRANSFORM = MODEL_WEIGHTS.transforms()
MODEL = models.efficientnet_v2_s(weights=MODEL_WEIGHTS)
MODEL_NAME = MODEL._get_name()

In [24]:
print(f"TRANSFORM: {TRANSFORM}")

TRANSFORM: ImageClassification(
    crop_size=[384]
    resize_size=[384]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)


#### Data augmentation

In [61]:
gpu_transforms = v2.Compose([
    v2.ToImage(),
    v2.RandomAffine(degrees=180, translate=(0.1, 0.1)),
    v2.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
    v2.Resize(384, interpolation=v2.InterpolationMode.BILINEAR),
    v2.ToDtype(torch.float32, scale=True),  # Converts to float [0,1]
    v2.Normalize(mean=[0.485, 0.456, 0.406],
                 std=[0.229, 0.224, 0.225]),
])
print(transform)

Compose(
      Resize(size=[384], interpolation=InterpolationMode.BILINEAR, antialias=True)
      ToImage()
      ColorJitter(brightness=(0.7, 1.3), contrast=(0.7, 1.3), saturation=(0.7, 1.3), hue=(-0.1, 0.1))
      ToDtype(scale=True)
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False)
)


#### DataLoader

In [62]:
train_dataset = YoloStyleDataset(DATA_DIR / "train", transform)
val_dataset = YoloStyleDataset(DATA_DIR / "valid", TRANSFORM)
test_dataset = YoloStyleDataset(DATA_DIR / "test", TRANSFORM)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8, persistent_workers=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)

# Sanity check
image, label = next(iter(test_loader))
print(f"Test Image shape: {image.shape}, Label: {label.shape}")

Test Image shape: torch.Size([32, 3, 384, 384]), Label: torch.Size([32])


## Test run

In [None]:
with wandb.init(
    entity="jassair",
    project="Baseline",
    config={
        "lr": 1e-2,
        "wd": 1e-6,
        "ft_only": True,
        "epochs": 50
    },
) as run:
    wandb_train_run(run)

## Hyperparameter Sweep

In [63]:
sweep_config = {
    "name": "EffiecientNet-Sweep",
    "method": "grid",
    "metric": {"name": "val_acc", "goal": "maximize"},
    "parameters": {
        "lr": {"values": [1e-3, 1e-2]},
        "wd": {"values": [1e-6, 1e-5, 1e-4]},
        "ft_only": {"value": True},
        "epochs": {"value": 50},
    }
}

In [64]:
sweep_id = wandb.sweep(
    sweep=sweep_config, 
    entity="jassair",
    project="BaselineModel",
)

Create sweep with ID: q21zodb0
Sweep URL: https://wandb.ai/jassair/BaselineModel/sweeps/q21zodb0


In [None]:
wandb.agent(sweep_id=sweep_id, function=classifier_run)

[34m[1mwandb[0m: Agent Starting Run: dv4cgc1j with config:
[34m[1mwandb[0m: 	epochs: 50
[34m[1mwandb[0m: 	ft_only: True
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	wd: 1e-06
[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mschurtenberger-david[0m ([33mjassair[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Seed set to 42
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/opt/conda/lib/python3.12/site-packages/lightning/pytorch/loggers/wandb.py:397: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/opt/conda/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/jovyan/DSPRO2-jassAIr/lightning_checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type               | Params | Mode 
----------------------------------------------------------
0 | model      | EfficientNet       | 20.2 M | train
1 | criterion  | CrossEntropyLoss   | 0      | train
2 | _train_acc | MulticlassAccuracy | 0      | train
3 | _valid_acc | MulticlassAccuracy | 0      | train


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

In [None]:
wandb.api.stop_sweep(sweep_id)

## ResNet

#### Load model and weights

In [None]:
MODEL_WEIGHTS = models.ResNet50_Weights.IMAGENET1K_V2
TRANSFORM = MODEL_WEIGHTS.transforms()
MODEL = models.resnet50(weights=MODEL_WEIGHTS)
MODEL_NAME = MODEL._get_name()

In [None]:
print(f"TRANSFORM: {TRANSFORM}")

## Data augmentation

In [None]:
transform = v2.Compose([
    v2.ToImage(),
    v2.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
    v2.Resize(232, interpolation=v2.InterpolationMode.BILINEAR),
    v2.ToDtype(torch.float32, scale=True),  # Converts to float [0,1]
    v2.Normalize(mean=[0.485, 0.456, 0.406],
                 std=[0.229, 0.224, 0.225]),
])
print(transform)

#### Data loaders

In [None]:
train_dataset = YoloStyleDataset(DATA_DIR / "train", transform=transform)
val_dataset = YoloStyleDataset(DATA_DIR / "valid", transform=TRANSFORM)
test_dataset = YoloStyleDataset(DATA_DIR / "test", transform=TRANSFORM)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=16, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=16, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=16, pin_memory=True)

# Sanity check
image, label = train_dataset[1]
print(f"Train Image shape: {image.shape}, Label: {label}")
image, label = test_dataset[1]
print(f"Test Image shape: {image.shape}, Label: {label}")

In [None]:
sweep_config = {
    "name": "ResNet-Sweep",
    "method": "random",
    "metric": {"name": "val_acc", "goal": "maximize"},
    "parameters": {
        "lr": {"values": [1e-4, 1e-3, 1e-2]},
        "wd": {"values": [1e-6, 1e-5, 1e-4]},
        "ft_only": {"values": [True, False]},
        "epochs": {"value": 50},
    }
}

In [None]:
sweep_id = wandb.sweep(
    sweep=sweep_config, 
    entity="jassair",
    project="BaselineModel",
)

In [None]:
wandb.agent(sweep_id=sweep_id, function=classifier_sweep, count=6)

In [None]:
wandb.api.stop_sweep(sweep_id)

## Fine-tune model

In [None]:
FINETUNE = False

In [None]:
if FINETUNE:
    DATA_DIR = get_dataset_path(Datasets.R_1TO1_36C_NOVLP)
    
    train_dataset = YoloStyleDataset(DATA_DIR / "train", transform=transform)
    val_dataset = YoloStyleDataset(DATA_DIR / "valid", transform=TRANSFORM)
    test_dataset = YoloStyleDataset(DATA_DIR / "test", transform=TRANSFORM)
    
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
    
    # Sanity check
    image, label = train_dataset[1]
    print(f"Image shape: {image.shape}, Label: {label}")

In [None]:
if FINETUNE:
    HYPERPARAMETERS = {
        "lr": 1e-1,
        "wd": 1e-7,
        "ft_only": False,
        "epochs": 50,
    }

In [None]:
if FINETUNE:
    RUN_NAME = f"ft_{time.strftime('%y%m%d-%H%M%S')}_{MODEL._get_name()}_lr{HYPERPARAMETERS['lr']}"
    wandb.init(
        entity="jassair",
        project="BaselineModel",
        name=RUN_NAME,
        config=HYPERPARAMETERS,
    )
    wandb_logger = WandbLogger(project="BaselineModel")

In [None]:
if FINETUNE:
    artifact = wandb.use_artifact("jassair/BaselineModel/250418-133900_EfficientNet_lr0.001:v0", type="model")
    artifact_dir = artifact.download()
    checkpoint_path = f"{artifact_dir}/250418-133900_EfficientNet_lr0.001-epoch=71-val_acc=0.93.ckpt"

In [None]:
if FINETUNE:
    early_stopping_callback = L.pytorch.callbacks.EarlyStopping(
        monitor="val_acc",
        min_delta=0.005, # Increase accuracy by at least 0.5%
        patience=15,
        mode="max"
    )
    
    checkpoint_callback = L.pytorch.callbacks.ModelCheckpoint(
        dirpath="./lightning_checkpoints",
        filename=RUN_NAME + "_{epoch:02d}_{val_acc:.2f}",
        monitor="val_acc",
        save_last=True,
        mode="max"
    )
    
    L.seed_everything(42)
    trainer = L.Trainer(
        callbacks=[checkpoint_callback, early_stopping_callback],
        max_epochs=wandb.config.get("epochs"),
        accelerator="auto",
        logger=wandb_logger,
    )
    model = ImageClassifier.load_from_checkpoint(
        checkpoint_path,
        lr=wandb.config.get("lr"), 
        weight_decay=wandb.config.get("wd"),
        finetune_only=wandb.config.get("ft_only"),
    )
    model.configure_optimizers()
    trainer.fit(model, train_loader, val_loader)
    artifact = wandb.Artifact(RUN_NAME, type='model')
    artifact.add_file(checkpoint_callback.best_model_path)
    wandb.log_artifact(artifact)
    wandb.finish()

### Evaluation

In [None]:
def run_tests(cps: list[str], test_set: str):
    for checkpoint in cps:
        with wandb.init(entity="jassair", project="BaselineModel", config=dict(test=test_set)) as run:
            wandb_test_run(run, checkpoint)

#### ResNet

In [None]:
MODEL_WEIGHTS = models.ResNet50_Weights.IMAGENET1K_V2
TRANSFORM = MODEL_WEIGHTS.transforms()
MODEL = models.resnet50(weights=MODEL_WEIGHTS)
MODEL_NAME = MODEL._get_name()

In [None]:
checkpoints = [
    
]

##### Synthetic data

In [None]:
test_dataset = YoloStyleDataset(get_dataset_path(Datasets.S_1TO1_36C_NOVLP) / "test")
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)

In [None]:
run_tests(checkpoints, "synth")

##### Real world data

In [None]:
test_dataset = YoloStyleDataset(get_dataset_path(Datasets.R_1TO1_36C_NOVLP) / "test")
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)

In [None]:
run_tests(checkpoints, "real")

#### EfficientNet

In [None]:
MODEL_WEIGHTS = models.EfficientNet_V2_S_Weights.IMAGENET1K_V1
TRANSFORM = MODEL_WEIGHTS.transforms()
MODEL = models.efficientnet_v2_s(weights=MODEL_WEIGHTS)
MODEL_NAME = MODEL._get_name()

In [None]:
checkpoints = [
    
]

##### Synthetic Data

In [None]:
test_dataset = YoloStyleDataset(get_dataset_path(Datasets.S_1TO1_36C_NOVLP) / "test")
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)

In [None]:
run_tests(checkpoints, "synth")

##### Real data

In [None]:
test_dataset = YoloStyleDataset(get_dataset_path(Datasets.R_1TO1_36C_NOVLP) / "test")
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)

In [None]:
run_tests(checkpoints, "real")