In [None]:
import pandas as pd
import numpy as np
import torch
from torchvision.datasets import CIFAR100
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch import Trainer
from lightning.pytorch import loggers as pl_loggers

from library.models.resnet import ResNetModel
from library.datasets.cifar100 import CIFAR100DataModule

# Trade some float32 matmul precision ("medium") for faster training.
torch.set_float32_matmul_precision("medium")

# Read the class names from the training split. With download=False the files
# are expected to already exist under datasets/cifar100.
cifar100_labels = CIFAR100(
    root="datasets/cifar100",
    train=True,
    download=False,
).classes
print(f"CIFAR-100 classes: {len(cifar100_labels)}")

CIFAR-100 classes: 100


In [2]:
# Configuration
# TRAIN = True  -> train_dataset_model fits a new model, then tests its best checkpoint.
# TRAIN = False -> fitting is skipped; the model is loaded from
#                  checkpoints/<model_name>.ckpt and only the test loop runs.
TRAIN = True  # Set to True to train models from scratch


def train_dataset_model(
    dataset_module,
    dataset_name,
    logger_name,
    model_name,
    num_classes,
    training_config,
    architecture="resnet152",
    checkpoint_dir="checkpoints",
    train=None,
):
    """Train or reload a ResNet model for one dataset and run the test loop.

    In training mode a fresh ``ResNetModel`` is fitted, the checkpoint with the
    lowest ``val_loss`` is kept, and that checkpoint is evaluated on the test
    data. In evaluation-only mode the previously saved checkpoint is loaded
    and tested without any fitting.

    Args:
        dataset_module: Data module handed to ``Trainer.fit`` / ``Trainer.test``.
        dataset_name: Human-readable dataset label. Accepted for call-site
            readability; it is not used inside this function.
        logger_name: Run name for the TensorBoard logger (saved under ``logs``).
        model_name: Checkpoint file stem; the checkpoint is read from
            ``<checkpoint_dir>/<model_name>.ckpt`` in evaluation-only mode.
        num_classes: Number of classes passed to ``ResNetModel``.
        training_config: Dict with the keys ``"max_epochs"``, ``"optim"`` and
            ``"optim_kwargs"``. An optional ``"addition"`` dict supplies extra
            keyword arguments that are forwarded to ``ResNetModel``.
        architecture: Architecture name passed to ``ResNetModel``. Defaults to
            ``"resnet152"``, the value that used to be hard-coded.
        checkpoint_dir: Directory used both for saving the best checkpoint and
            for loading it again in evaluation-only mode.
        train: ``True`` to train, ``False`` to only evaluate. ``None``
            (default) falls back to the notebook-level ``TRAIN`` flag.

    Returns:
        The value returned by ``Trainer.test`` (a list of metric dicts).
    """
    # Resolve the mode once; existing callers that rely on the global switch
    # keep working because the default defers to TRAIN.
    should_train = TRAIN if train is None else train

    # Only build a TensorBoard logger when something will actually be logged.
    logger = (
        pl_loggers.TensorBoardLogger(save_dir="logs", name=logger_name)
        if should_train
        else False
    )

    trainer = Trainer(
        max_epochs=training_config["max_epochs"],
        logger=logger,
        callbacks=[
            # Keep exactly one checkpoint (lowest val_loss) under a fixed file
            # name, overwriting earlier runs instead of adding -v1, -v2, ...
            ModelCheckpoint(
                dirpath=checkpoint_dir,
                monitor="val_loss",
                mode="min",
                save_top_k=1,
                filename=model_name,
                enable_version_counter=False,
            )
        ],
    )

    if should_train:
        model = ResNetModel(
            num_classes=num_classes,
            architecture=architecture,
            optim=training_config["optim"],
            optim_kwargs=training_config["optim_kwargs"],
            # Extra model options (e.g. LR scheduler settings) are optional.
            **training_config.get("addition", {}),
        )
        trainer.fit(model, datamodule=dataset_module)
        # Evaluate the best checkpoint, not the weights of the final epoch.
        return trainer.test(datamodule=dataset_module, ckpt_path="best")

    # Evaluation only: reuse the checkpoint written by an earlier training run.
    model = ResNetModel.load_from_checkpoint(f"{checkpoint_dir}/{model_name}.ckpt")
    return trainer.test(model, datamodule=dataset_module)

In [3]:
# CIFAR-100 baseline: build the data module, describe the run, then train and test.
cifar100_dataset = CIFAR100DataModule()
cifar100_config = {
    "max_epochs": 50,
    "optim": "adamw",
    "optim_kwargs": {"lr": 0.001, "weight_decay": 0.001},
    # Extra keyword arguments forwarded to ResNetModel (here: the LR schedule).
    "addition": {
        "lr_scheduler": "multistep",
        # NOTE(review): milestones 60 and 80 lie beyond max_epochs=50. If the
        # scheduler steps once per epoch, only the decay at epoch 30 ever takes
        # effect -- confirm whether max_epochs or the milestones should change.
        "lr_scheduler_kwargs": {"milestones": [30, 60, 80], "gamma": 0.1},
    },
}

print("Training CIFAR-100 baseline model...")
cifar100_results = train_dataset_model(
    dataset_module=cifar100_dataset,
    dataset_name="CIFAR-100",
    logger_name="cifar100_baseline",
    model_name="resnet152-cifar100-min-val-loss",
    num_classes=len(cifar100_labels),
    training_config=cifar100_config,
)

print(f"CIFAR-100 Test Results: {cifar100_results}")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Training CIFAR-100 baseline model...


/home/bjoern/miniconda3/envs/master-thesis/lib/python3.13/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/bjoern/dev/master-thesis/project/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params | Mode 
-------------------------------------------------------
0 | model     | ResNet           | 60.9 M | train
1 | criterion | CrossEntropyLoss | 0      | train
-------------------------------------------------------
60.9 M    Trainable params
0         Non-trainable params
60.9 M    Total params
243.776   Total estimated model params size (MB)
434       Modules in train mode
0         Modules in eval mode


Epoch 49: 100%|██████████| 157/157 [00:33<00:00,  4.68it/s, v_num=2]       

`Trainer.fit` stopped: `max_epochs=50` reached.


Epoch 49: 100%|██████████| 157/157 [00:33<00:00,  4.68it/s, v_num=2]


Restoring states from the checkpoint path at /home/bjoern/dev/master-thesis/project/checkpoints/resnet152-cifar100-min-val-loss.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/bjoern/dev/master-thesis/project/checkpoints/resnet152-cifar100-min-val-loss.ckpt


Testing DataLoader 0: 100%|██████████| 40/40 [00:03<00:00, 12.73it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      eval_accuracy         0.6047000288963318
        eval_loss           1.9386216402053833
        hp_metric           0.6047000288963318
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
CIFAR-100 Test Results: [{'eval_loss': 1.9386216402053833, 'eval_accuracy': 0.6047000288963318, 'hp_metric': 0.6047000288963318}]
