In [19]:
import torchvision
from torchvision.transforms import v2
import lightning as pl
import torch.nn.functional as F
import torch
from torch.utils.data import DataLoader, random_split
from torchvision.models.resnet import resnet50, ResNet50_Weights
from lightning.pytorch import loggers as pl_loggers


class ResNetModule(pl.LightningModule):
    def __init__(
        self, num_classes, optim, learning_rate=None, momentum=0, weight_decay=0
    ):
        super().__init__()

        # Save hyperparameters
        self.save_hyperparameters()
        self.num_classes = num_classes
        self.optim = optim
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.weight_decay = weight_decay

        # Load pretrained ResNet50 model and replace last layer to fit the number of classes
        self.model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        self.model.fc = torch.nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, inputs):
        return self.model(inputs)

    def training_step(self, batch):
        inputs, target = batch
        output = self(inputs)
        loss = F.cross_entropy(output, target)

        pred = output.argmax(dim=1, keepdim=True)
        correct = pred.eq(target.view_as(pred)).sum().item()
        accuracy = correct / len(target)

        self.log_dict({"train_loss": loss, "train_accuracy": accuracy})

        return loss

    def configure_optimizers(self):
        if self.optim == "adam":
            return torch.optim.Adam(
                self.model.parameters(),
                lr=self.learning_rate,
                weight_decay=self.weight_decay,
            )
        elif self.optim == "sgd":
            return torch.optim.SGD(
                self.model.parameters(),
                lr=self.learning_rate,
                momentum=self.momentum,
                weight_decay=self.weight_decay,
            )
        else:
            raise ValueError(f"Unsupported optimizer: {self.optim}")

    def test_step(self, batch):
        inputs, target = batch
        output = self(inputs)
        loss = F.cross_entropy(output, target)

        pred = output.argmax(dim=1, keepdim=True)
        correct = pred.eq(target.view_as(pred)).sum().item()
        accuracy = correct / len(target)

        self.log_dict({"eval_loss": loss, "eval_accuracy": accuracy})

    def validation_step(self, batch):
        inputs, target = batch
        output = self(inputs)
        loss = F.cross_entropy(output, target)

        pred = output.argmax(dim=1, keepdim=True)
        correct = pred.eq(target.view_as(pred)).sum().item()
        accuracy = correct / len(target)

        self.log_dict({"val_loss": loss, "val_accuracy": accuracy})


class Caltech101DataModule(pl.LightningDataModule):
    def __init__(self, batch_size=32, split=(0.8, 0.1, 0.1)):
        super().__init__()
        self.batch_size = batch_size
        self.split = split

    def prepare_data(self):
        torchvision.datasets.Caltech101(root="datasets/caltech101", download=True)

    def setup(self, stage):
        transform = v2.Compose(
            [
                v2.RGB(),
                v2.Resize(size=(224, 224)),
                v2.RandomHorizontalFlip(),
                v2.ToImage(),
                v2.ToDtype(torch.float32, scale=True),
                v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

        dataset = torchvision.datasets.Caltech101(
            root="datasets/caltech101", transform=transform
        )
        self.train, self.val, self.test = random_split(
            dataset, self.split, generator=torch.Generator().manual_seed(42)
        )

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size)


class CIFAR100DataModule(pl.LightningDataModule):
    def __init__(self, batch_size=32):
        super().__init__()
        self.batch_size = batch_size

    def prepare_data(self):
        torchvision.datasets.CIFAR100(
            root="datasets/cifar100", train=True, download=True
        )
        torchvision.datasets.CIFAR100(
            root="datasets/cifar100", train=False, download=True
        )

    def setup(self, stage):
        transform = v2.Compose(
            [
                v2.RGB(),
                v2.Resize(size=(32, 32)),
                v2.RandomHorizontalFlip(),
                v2.ToImage(),
                v2.ToDtype(torch.float32, scale=True),
                v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

        dataset = torchvision.datasets.CIFAR100(
            root="datasets/cifar100", transform=transform, train=True
        )
        self.train, self.val = random_split(
            dataset, [0.8, 0.2], generator=torch.Generator().manual_seed(42)
        )
        dataset = torchvision.datasets.CIFAR100(
            root="datasets/cifar100", transform=transform, train=False
        )
        self.test = dataset

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size)


tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs", name="caltech101")
dataset = Caltech101DataModule()
model_name = "resnet-caltech101-min-val-loss"
trainer = pl.Trainer(
    max_epochs=20,
    logger=tb_logger,
    callbacks=[
        # Save the model with the lowest validation loss
        pl.pytorch.callbacks.ModelCheckpoint(
            dirpath="checkpoints",
            monitor="val_loss",
            mode="min",
            save_top_k=1,
            filename=model_name,
            enable_version_counter=False,
        )
    ],
)

TRAIN = True

if TRAIN:
    model = ResNetModule(
        num_classes=101,
        optim="sgd",
        learning_rate=0.01,
        momentum=0.9,
        weight_decay=0.0005,
    )
    trainer.fit(model, datamodule=dataset)

    # Test with the best model from the checkpoint
    results = trainer.test(datamodule=dataset, ckpt_path="best")
else:
    model = ResNetModule.load_from_checkpoint(f"checkpoints/{model_name}.ckpt")
    results = trainer.test(model, datamodule=dataset)

print(results)

tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs", name="cifar100")
dataset = CIFAR100DataModule()
model_name = "resnet-cifar100-min-val-loss"
trainer = pl.Trainer(
    max_epochs=20,
    logger=tb_logger,
    callbacks=[
        # Save the model with the lowest validation loss
        pl.pytorch.callbacks.ModelCheckpoint(
            dirpath="checkpoints",
            monitor="val_loss",
            mode="min",
            save_top_k=1,
            filename=model_name,
            enable_version_counter=False,
        )
    ],
)

TRAIN = True

if TRAIN:
    model = ResNetModule(
        num_classes=100,
        optim="adam",
        learning_rate=0.001,
        weight_decay=0.0001,
    )
    trainer.fit(model, datamodule=dataset)

    # Test with the best model from the checkpoint
    results = trainer.test(datamodule=dataset, ckpt_path="best")
else:
    model = ResNetModule.load_from_checkpoint(f"checkpoints/{model_name}.ckpt")
    results = trainer.test(model, datamodule=dataset)

print(results)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params | Mode 
-----------------------------------------
0 | model | ResNet | 23.7 M | train
-----------------------------------------
23.7 M    Trainable params
0         Non-trainable params
23.7 M    Total params
94.860    Total estimated model params size (MB)
151       Modules in train mode
0         Modules in eval mode


Epoch 19: 100%|██████████| 217/217 [00:31<00:00,  6.93it/s, v_num=0]       

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 217/217 [00:31<00:00,  6.93it/s, v_num=0]


Restoring states from the checkpoint path at /home/sentinel/Development/master-thesis/project/checkpoints/resnet-caltech101-min-val-loss.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/sentinel/Development/master-thesis/project/checkpoints/resnet-caltech101-min-val-loss.ckpt
/home/sentinel/.conda/envs/master-thesis/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 28/28 [00:02<00:00, 11.09it/s]

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs



────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      eval_accuracy         0.9769319295883179
        eval_loss           0.0922703742980957
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
[{'eval_loss': 0.0922703742980957, 'eval_accuracy': 0.9769319295883179}]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params | Mode 
-----------------------------------------
0 | model | ResNet | 23.7 M | train
-----------------------------------------
23.7 M    Trainable params
0         Non-trainable params
23.7 M    Total params
94.852    Total estimated model params size (MB)
151       Modules in train mode
0         Modules in eval mode


Epoch 19: 100%|██████████| 1250/1250 [00:36<00:00, 34.71it/s, v_num=0]      

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 1250/1250 [00:36<00:00, 34.71it/s, v_num=0]


Restoring states from the checkpoint path at /home/sentinel/Development/master-thesis/project/checkpoints/resnet-cifar100-min-val-loss.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/sentinel/Development/master-thesis/project/checkpoints/resnet-cifar100-min-val-loss.ckpt


Testing DataLoader 0: 100%|██████████| 313/313 [00:04<00:00, 75.79it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      eval_accuracy         0.4975000023841858
        eval_loss            2.005612373352051
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
[{'eval_loss': 2.005612373352051, 'eval_accuracy': 0.4975000023841858}]
