## Install dependencies

In [1]:
%pip install -q lightning wandb torchvision torchmetrics

Note: you may need to restart the kernel to use updated packages.


## Import packages

In [2]:
from pathlib import Path
import time

import PIL.Image
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms.v2 as v2
import torchmetrics
import lightning as L
import wandb
from torch.utils.data import Dataset, DataLoader
from lightning.pytorch.loggers import WandbLogger

from jassair.utils import get_dataset_path, Datasets

## Lower matmul precision

In [3]:
torch.set_float32_matmul_precision('high')

## WandB login for experiment tracking

In [4]:
wandb.login()

[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mschurtenberger-david[0m ([33mdavid-schurtenberger[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

## Global variable definition

In [5]:
DATA_DIR = get_dataset_path(Datasets.S_1TO1_36C_NOVLP)
BATCH_SIZE = 32
NUM_CLASSES = 36

## Custom Synth-data Dataset

In [6]:
class YoloStyleDataset(Dataset):
    def __init__(self, root_dir: Path, transform=None):
        self.root_dir = root_dir
        self.image_dir = root_dir / 'images'
        self.label_dir = root_dir / 'labels'
        self.transform = transform
        self.image_files: list[Path] = [f for f in self.image_dir.iterdir()]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_path: Path = self.image_files[idx]
        image = PIL.Image.open(image_path).convert("RGB")
        label_path = self.label_dir / f"{image_path.stem}.txt"
        if not label_path.exists():
            raise FileNotFoundError(label_path)
        with label_path.open("r", encoding="utf-8") as f:
            label = int(f.readline().split()[0])
        if self.transform:
            image = self.transform(image)
        return image, label

## Pre-trained Model / Data transform

In [7]:
# Available models: https://pytorch.org/vision/stable/models.html#classification

MODEL_WEIGHTS = models.EfficientNet_V2_S_Weights.IMAGENET1K_V1
TRANSFORM = MODEL_WEIGHTS.transforms()
MODEL = models.efficientnet_v2_s(weights=MODEL_WEIGHTS)

In [8]:
print(f"TRANSFORM: {TRANSFORM}")

TRANSFORM: ImageClassification(
    crop_size=[384]
    resize_size=[384]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)


## DataLoader definition

In [9]:
transform = v2.Compose([
    v2.ToImage(),
    v2.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
    v2.RandomAffine(degrees=180, translate=(0.3, 0.3)),
    v2.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
    v2.Resize(384, interpolation=v2.InterpolationMode.BICUBIC),
    v2.ToDtype(torch.float32, scale=True),  # Converts to float [0,1]
    v2.Normalize(mean=[0.485, 0.456, 0.406],
                 std=[0.229, 0.224, 0.225]),
])
print(transform)

Compose(
      ToImage()
      ColorJitter(brightness=(0.7, 1.3), contrast=(0.7, 1.3), saturation=(0.7, 1.3), hue=(-0.1, 0.1))
      RandomAffine(degrees=[-180.0, 180.0], translate=(0.3, 0.3), interpolation=InterpolationMode.NEAREST, fill=0)
      GaussianBlur(kernel_size=(3, 3), sigma=[0.1, 2.0])
      Resize(size=[384], interpolation=InterpolationMode.BICUBIC, antialias=True)
      ToDtype(scale=True)
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False)
)


In [10]:
train_dataset = YoloStyleDataset(DATA_DIR / "train", transform=transform)
val_dataset = YoloStyleDataset(DATA_DIR / "valid", transform=TRANSFORM)
test_dataset = YoloStyleDataset(DATA_DIR / "test", transform=TRANSFORM)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8, pin_memory=True)

# Sanity check
image, label = train_dataset[1]
print(f"Train Image shape: {image.shape}, Label: {label}")
image, label = test_dataset[1]
print(f"Test Image shape: {image.shape}, Label: {label}")

Train Image shape: torch.Size([3, 384, 384]), Label: 34
Test Image shape: torch.Size([3, 384, 384]), Label: 31


## Baseline Model

In [11]:
class ImageClassifier(L.LightningModule):
    def __init__(self, lr: float, weight_decay: float, finetune_only: bool):
        super().__init__()
        self.save_hyperparameters()
        
        self.model = MODEL
        
        # If set, only train the newly attached FC layer
        if self.hparams.finetune_only:
            for param in self.model.parameters():
                param.requires_grad = False

        in_features = self.model.classifier[1].in_features
        self.model.classifier[1] = nn.Linear(in_features, NUM_CLASSES)

        self.criterion = nn.CrossEntropyLoss()
        
        self._train_acc = torchmetrics.Accuracy("multiclass", num_classes=NUM_CLASSES)
        self._train_loss = []
        self._valid_acc = torchmetrics.Accuracy("multiclass", num_classes=NUM_CLASSES)
        self._valid_loss = []

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        self._train_acc(outputs, labels)
        self._train_loss.append(loss)
        
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        self._valid_acc(outputs, labels)
        self._valid_loss.append(loss)
        
    def on_train_epoch_end(self):
        loss = torch.stack(self._train_loss).mean()
        self.log_dict({'train_loss': loss, 'train_acc': self._train_acc.compute()}, prog_bar=True)
        self._train_loss.clear()
        self._train_acc.reset()

    def on_validation_epoch_end(self):
        loss = torch.stack(self._valid_loss).mean()
        self.log_dict({'val_loss': loss, 'val_acc': self._valid_acc.compute()}, prog_bar=True)
        self._valid_loss.clear()
        self._valid_acc.reset()

    def test_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        acc = (outputs.argmax(dim=1) == labels).float().mean()
        
        self.log("test_loss", loss, prog_bar=True, logger=True)
        self.log("test_acc", acc, prog_bar=True, logger=True)

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        scheduler = optim.lr_scheduler.CyclicLR(optimizer, self.hparams.lr / 1000, self.hparams.lr)
        return [optimizer], [scheduler]

## Train Model

In [13]:
HYPERPARAMETERS = {
    "lr": 1e-3,
    "wd": 1e-5,
    "ft_only": True,
    "epochs": 100,
}

In [14]:
RUN_NAME = f"{time.strftime('%y%m%d-%H%M%S')}_{MODEL._get_name()}_lr{HYPERPARAMETERS['lr']}"
wandb.init(
    entity="jassair",
    project="BaselineModel",
    name=RUN_NAME,
    config=HYPERPARAMETERS,
)
wandb_logger = WandbLogger(project="BaselineModel")

In [15]:
checkpoint_callback = L.pytorch.callbacks.ModelCheckpoint(
    dirpath="./lightning_checkpoints",
    filename=RUN_NAME + "-{epoch:02d}-{val_acc:.2f}",
    monitor="val_acc",
    save_last=True,
    mode="max"
)

In [16]:
early_stopping_callback = L.pytorch.callbacks.EarlyStopping(
    monitor="val_acc",
    min_delta=0.005, # Increase accuracy by at least 0.5%
    patience=10,
    mode="max"
)

In [None]:
L.seed_everything(42)
trainer = L.Trainer(
    callbacks=[checkpoint_callback, early_stopping_callback],
    max_epochs=wandb.config.get("epochs"),
    accelerator="auto",
    logger=wandb_logger,
)
model = ImageClassifier(
    wandb.config.get("lr"), 
    wandb.config.get("wd"),
    wandb.config.get("ft_only"),
)
trainer.fit(model, train_loader, val_loader)
artifact = wandb.Artifact(RUN_NAME, type='model')
artifact.add_file(checkpoint_callback.best_model_path)
wandb.log_artifact(artifact)
wandb.finish()

Seed set to 42
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/opt/conda/lib/python3.12/site-packages/lightning/pytorch/loggers/wandb.py:397: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/opt/conda/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/jovyan/DSPRO2-jassAIr/lightning_checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type               | Params | Mode 
----------------------------------------------------------
0 | model      | EfficientNet       | 20.2 M | train
1 | criterion  | CrossEntropyLoss   | 0      | train
2 | _train_acc | MulticlassAccuracy | 0      | train
3 | _valid_acc | MulticlassAccuracy | 0      | train
--------------------------------------------

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Validation: |          | 0/? [00:00<?, ?it/s]

## Fine-tune model

In [13]:
DATA_DIR = get_dataset_path(Datasets.R_1TO1_36C_NOVLP)

In [14]:
train_dataset = YoloStyleDataset(DATA_DIR / "train", transform=transform)
val_dataset = YoloStyleDataset(DATA_DIR / "valid", transform=TRANSFORM)
test_dataset = YoloStyleDataset(DATA_DIR / "test", transform=TRANSFORM)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

# Sanity check
image, label = train_dataset[1]
print(f"Image shape: {image.shape}, Label: {label}")

Image shape: torch.Size([3, 384, 384]), Label: 5


In [52]:
HYPERPARAMETERS = {
    "lr": 1e-1,
    "wd": 1e-7,
    "ft_only": False,
    "epochs": 50,
}

In [53]:
RUN_NAME = f"ft_{time.strftime('%y%m%d-%H%M%S')}_{MODEL._get_name()}_lr{HYPERPARAMETERS['lr']}"
wandb.init(
    entity="jassair",
    project="BaselineModel",
    name=RUN_NAME,
    config=HYPERPARAMETERS,
)
wandb_logger = WandbLogger(project="BaselineModel")

In [54]:
artifact = wandb.use_artifact("jassair/BaselineModel/250418-133900_EfficientNet_lr0.001:v0", type="model")
artifact_dir = artifact.download()
checkpoint_path = f"{artifact_dir}/250418-133900_EfficientNet_lr0.001-epoch=71-val_acc=0.93.ckpt"

[34m[1mwandb[0m: Downloading large artifact 250418-133900_EfficientNet_lr0.001:v0, 78.37MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.5


In [55]:
checkpoint_callback = L.pytorch.callbacks.ModelCheckpoint(
    dirpath="./lightning_checkpoints",
    filename=RUN_NAME + "_{epoch:02d}_{val_acc:.2f}",
    monitor="val_acc",
    save_last=True,
    mode="max"
)

In [56]:
early_stopping_callback = L.pytorch.callbacks.EarlyStopping(
    monitor="val_acc",
    min_delta=0.005, # Increase accuracy by at least 0.5%
    patience=15,
    mode="max"
)

In [57]:
L.seed_everything(42)
trainer = L.Trainer(
    callbacks=[checkpoint_callback, early_stopping_callback],
    max_epochs=wandb.config.get("epochs"),
    accelerator="auto",
    logger=wandb_logger,
)
model = ImageClassifier.load_from_checkpoint(
    checkpoint_path,
    lr=wandb.config.get("lr"), 
    weight_decay=wandb.config.get("wd"),
    finetune_only=wandb.config.get("ft_only"),
)
model.configure_optimizers()
trainer.fit(model, train_loader, val_loader)
artifact = wandb.Artifact(RUN_NAME, type='model')
artifact.add_file(checkpoint_callback.best_model_path)
wandb.log_artifact(artifact)
wandb.finish()

Seed set to 42
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/opt/conda/lib/python3.12/site-packages/lightning/pytorch/loggers/wandb.py:397: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/opt/conda/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/jovyan/DSPRO2-jassAIr/lightning_checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type               | Params | Mode 
----------------------------------------------------------
0 | model      | EfficientNet       | 20.2 M | train
1 | criterion  | CrossEntropyLoss   | 0      | train
2 | _train_acc | MulticlassAccuracy | 0      | train
3 | _valid_acc | MulticlassAccuracy | 0      | train
--------------------------------------------

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/opt/conda/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (8) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

0,1
epoch,▁▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇████
train_acc,▅▃▃▁▂▄▄▄▅▃▅█▄█▄▇▃
train_loss,██▆▇▇▆▄▆▆▄▆▁▆▃▅▅▄
trainer/global_step,▁▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇████
val_acc,▇█▃▇▇█▇▆▅▆▅█▇▇▁▅▅
val_loss,▂▅▆▄▆▃▄▅▇▆▆▁▂▅▇█▅

0,1
epoch,16.0
train_acc,0.65333
train_loss,1.24545
trainer/global_step,135.0
val_acc,0.66667
val_loss,0.90733
