## Install dependencies

In [2]:
%pip install lightning wandb torchvision

Collecting lightning
  Downloading lightning-2.5.1-py3-none-any.whl.metadata (39 kB)
Collecting wandb
  Downloading wandb-0.19.8-py3-none-win_amd64.whl.metadata (10 kB)
Collecting lightning-utilities<2.0,>=0.10.0 (from lightning)
  Downloading lightning_utilities-0.14.2-py3-none-any.whl.metadata (5.6 kB)
Collecting torchmetrics<3.0,>=0.7.0 (from lightning)
  Downloading torchmetrics-1.7.0-py3-none-any.whl.metadata (21 kB)
Collecting pytorch-lightning (from lightning)
  Downloading pytorch_lightning-2.5.1-py3-none-any.whl.metadata (20 kB)
Collecting click!=8.0.0,>=7.1 (from wandb)
  Using cached click-8.1.8-py3-none-any.whl.metadata (2.3 kB)
Collecting docker-pycreds>=0.4.0 (from wandb)
  Using cached docker_pycreds-0.4.0-py2.py3-none-any.whl.metadata (1.8 kB)
Collecting gitpython!=3.1.29,>=1.0.0 (from wandb)
  Using cached GitPython-3.1.44-py3-none-any.whl.metadata (13 kB)
Collecting protobuf!=4.21.0,!=5.28.0,<6,>=3.19.0 (from wandb)
  Downloading protobuf-5.29.4-cp310-abi3-win_amd64.w


[notice] A new release of pip is available: 24.2 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


## Import packages

In [23]:
from pathlib import Path

import PIL.Image
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import lightning as L
import wandb
from torch.utils.data import Dataset, DataLoader
from lightning.pytorch.loggers import WandbLogger

from jassair.utils import get_dataset_path, Datasets

## WandB login for experiment tracking

In [None]:
wandb.login()

## Global variable definition

In [24]:
DATA_DIR = get_dataset_path(Datasets.SYNTHETIC_SINGLE)
BATCH_SIZE = 32
NUM_CLASSES = 16

## Custom Synth-data Dataset

In [36]:
class YoloDataset(Dataset):
    def __init__(self, root_dir: Path, transform=None):
        self.root_dir = root_dir
        self.image_dir = root_dir / 'images'
        self.label_dir = root_dir / 'labels'
        self.transform = transform
        self.image_files: list[Path] = [f for f in self.image_dir.iterdir()]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_path: Path = self.image_files[idx]
        image = PIL.Image.open(image_path).convert("RGB")
        label_path = self.label_dir / f"{image_path.stem}.txt"
        if not label_path.exists():
            raise FileNotFoundError(label_path)
        with label_path.open("r", encoding="utf-8") as f:
            label = torch.tensor([int(f.readline().split()[0])])
        if self.transform:
            image = self.transform(image)
        return image, label

## Pre-trained Model / Data transform

In [43]:
MODEL_WEIGHTS = models.ResNet50_Weights.IMAGENET1K_V2
TRANSFORM = MODEL_WEIGHTS.transforms()
MODEL = models.resnet50(weights=MODEL_WEIGHTS)

'ResNet'

## DataLoader definition

In [41]:
train_dataset = YoloDataset(DATA_DIR / "train", transform=TRANSFORM)
val_dataset = YoloDataset(DATA_DIR / "valid", transform=TRANSFORM)
test_dataset = YoloDataset(DATA_DIR / "test", transform=TRANSFORM)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

# Sanity check
image, label = train_dataset[1]
print(f"Image shape: {image.shape}, Label: {label}")

C:\git\DSPRO2-jassAIr\data\synth_single\train\images\Eichel-6_1.png
Image shape: torch.Size([3, 224, 224]), Label: tensor([1])


## Baseline Model

In [42]:
class ImageClassifier(L.LightningModule):
    def __init__(self, lr: float, weight_decay: float, finetune_only: bool):
        super().__init__()
        self.save_hyperparameters()
        
        self.model = MODEL
        
        # If set, only train the newly attached FC layer
        if self.hparams.finetune_only:
            for param in self.model.parameters():
                param.requires_grad = False
                
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features, NUM_CLASSES)

        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        acc = (outputs.argmax(dim=1) == labels).float().mean()
        
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log("train_acc", acc, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        acc = (outputs.argmax(dim=1) == labels).float().mean()
        
        self.log("val_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        self.log("val_acc", acc, on_epoch=True, prog_bar=True, logger=True)

    def test_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        acc = (outputs.argmax(dim=1) == labels).float().mean()
        
        self.log("test_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        self.log("test_acc", acc, on_epoch=True, prog_bar=True, logger=True)

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        scheduler = optim.lr_scheduler.CyclicLR(optimizer, self.hparams.lr / 100, self.hparams.lr)
        return [optimizer], [scheduler]

## Train Model

In [None]:
HYPERPARAMETERS = {
    "lr": 1e-4,
    "wd": 1e-8,
    "ft_only": True,
    "epochs": 10,
}

In [None]:
wandb.init(
    entity="jassair",
    project="BaselineModel",
    name=f"{MODEL._get_name()}_lr{HYPERPARAMETERS['lr']}",
    config=HYPERPARAMETERS,
)
wandb_logger = WandbLogger(project="NLP_Project_1")

In [None]:
trainer = L.Trainer(
    max_epochs=wandb.config.get("epochs"),
    accelerator="auto",
    logger=wandb_logger,
    log_every_n_steps=10
)

model = ImageClassifier(
    wandb.config.get("lr"), 
    wandb.config.get("ft_only"),
    wandb.config.get("wd")
)
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
wandb.finish()