In [11]:
import os
from collections.abc import Callable
from pathlib import Path
from typing import Any

import lightning as L
import skimage as ski
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from prettytable import PrettyTable
from torch.utils.data import DataLoader
from torchvision.datasets import DatasetFolder
from torchvision.models import get_model, get_weight

# Setup

In [9]:
# Dataset name -> root directory of the prepared dataset (paths are relative
# to the notebook's working directory).
DATASET_PATHS = {
    "fog-detection": Path("./datasets/fog-detection-dataset-prepared"),
    "fog-or-smog": Path("./datasets/fog-or-smog-detection-dataset-prepared"),
    "foggy-cityscapes": Path("./datasets/foggy-cityscapes-image-dataset-prepared"),
}

In [None]:
def load_image(path: str):
    """Read an image file with scikit-image and return it as a PIL image.

    Two-dimensional (grayscale) arrays are expanded to three channels and
    four-channel (RGBA) arrays are converted to RGB before the pixel data is
    converted to unsigned 8-bit.

    Args:
        path: Location of the image file, passed straight to ``ski.io.imread``.

    Returns:
        A ``PIL.Image.Image`` built from the converted array.
    """
    pixels = ski.io.imread(path)
    if pixels.ndim == 2:
        # Grayscale: gray2rgb always yields three channels, so the RGBA
        # branch below can never apply to this case.
        pixels = ski.color.gray2rgb(pixels)
    elif pixels.shape[-1] == 4:
        # RGBA: drop the alpha channel by blending into RGB.
        pixels = ski.color.rgba2rgb(pixels)
    return Image.fromarray(ski.util.img_as_ubyte(pixels))


def get_dataloader(
    path: Path | str,
    transform: callable | None = None,
    batch_size: int = 32,
    shuffle: bool = False,
    num_workers: int = 0,
    pin_memory: bool = True,
):
    return DataLoader(
        dataset=DatasetFolder(
            root=path,
            loader=load_image,
            extensions=[".jpg", ".png", ".jpeg"],
            transform=transform,
        ),
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=True if num_workers > 0 else False,
    )

In [12]:
def recursive_getattr(obj: object, name: str):
    """Resolve a dotted attribute path such as ``"a.b.c"`` starting at ``obj``.

    Args:
        obj: Object the lookup starts from.
        name: Attribute names joined by dots.

    Returns:
        The attribute reached after following every segment of ``name``.

    Raises:
        AttributeError: If any segment is missing on the object reached so far.
    """
    current = obj
    for segment in name.split("."):
        current = getattr(current, segment)
    return current

def recursive_setattr(obj: object, value: Any, name: str):
    """Assign ``value`` to the attribute addressed by a dotted path.

    Every segment of ``name`` except the last is followed with ``getattr``;
    the last segment is then set on the object reached that way.

    Args:
        obj: Object the path starts from.
        value: Value to store.
        name: Attribute names joined by dots, e.g. ``"a.b.c"``.

    Raises:
        AttributeError: If an intermediate segment does not exist.
    """
    *parents, leaf = name.split(".")
    owner = obj
    for segment in parents:
        owner = getattr(owner, segment)
    setattr(owner, leaf, value)

# Training

In [None]:
# Run configuration: output root, backbone, training dataset and run version.
SAVE_DIR = Path('runs/classify/CNN')
MODEL_NAME = "ResNet18"
DATASET = 'fog-or-smog'
VERSION = 1

trainer = L.Trainer(
    max_epochs=10,
    logger=L.pytorch.loggers.TensorBoardLogger(
        save_dir=SAVE_DIR,
        # Interpolate the model name; previously the literal text
        # "MODEL_NAME" ended up in the log directory name.
        name=f"{MODEL_NAME}-{DATASET}",
        version=VERSION,
    ),
    callbacks=[
        # Stop when val_loss has not improved for 5 validation checks.
        L.pytorch.callbacks.early_stopping.EarlyStopping(
            monitor="val_loss", mode="min",
            patience=5,
            verbose=False
        ),
        # Keep the checkpoint with the highest val_f1.
        # NOTE(review): this directory does not include DATASET, unlike the
        # logger name above -- confirm whether that is intended.
        L.pytorch.callbacks.ModelCheckpoint(
            monitor="val_f1", mode="max",
            dirpath=SAVE_DIR / MODEL_NAME / f"version_{VERSION}",
            filename="{epoch}-{val_loss:.2f}-{val_f1:.2f}"
        )
    ],
    log_every_n_steps=1
)

# CNNClassifier is defined outside this cell; it receives the dataset path
# itself, so presumably it builds its own train/val loaders -- TODO confirm.
model = CNNClassifier(
    model_name=MODEL_NAME,
    num_classes=2,
    weights=f"{MODEL_NAME}_Weights.DEFAULT",
    dataset_path=DATASET_PATHS[DATASET],
    num_workers=11
)

trainer.fit(
    model
)

# Reload the best checkpoint recorded by the first checkpoint callback.
model = CNNClassifier.load_from_checkpoint(
    trainer.checkpoint_callbacks[0].best_model_path
)

# Evaluate on the `test` split of every known dataset, using the
# preprocessing transforms that belong to the model's weights.
_transform = get_weight(model.weights).transforms()
res = {
    dataset_name: trainer.test(
        model,
        # get_dataloader's parameter is `path`; the previous `dataset_path=`
        # keyword raised a TypeError.
        get_dataloader(path=dataset_root / 'test', transform=_transform),
    )[0]
    for dataset_name, dataset_root in DATASET_PATHS.items()
}

# One row per dataset, one column per metric reported by trainer.test.
table = PrettyTable()
table.field_names = [
    "Dataset", *list(next(iter(res.values())).keys())
]
table.add_rows([
    [dataset, *[round(m, 4) for m in metrics.values()]] for dataset, metrics in res.items()
])
table