In [3]:
import lightning as pl
from lightning.pytorch import loggers as pl_loggers
import torch
from models import *
from datasets import *


# Trade some float32 matrix-multiplication precision for faster training.
torch.set_float32_matmul_precision("medium")


tb_logger = pl_loggers.TensorBoardLogger(name="caltech101", save_dir="logs")
dataset = Caltech101DataModule()
model_name = "resnet50-caltech101-min-val-loss"

# Keep only the single checkpoint with the lowest monitored "val_loss".
# The version counter is disabled so the checkpoint keeps the plain
# `model_name` file name, which is the path the evaluation code loads from.
best_checkpoint = pl.pytorch.callbacks.ModelCheckpoint(
    dirpath="checkpoints",
    filename=model_name,
    monitor="val_loss",
    mode="min",
    save_top_k=1,
    enable_version_counter=False,
)

trainer = pl.Trainer(max_epochs=10, logger=tb_logger, callbacks=[best_checkpoint])

# True trains a new model and then tests it; False only tests the checkpoint
# already stored under checkpoints/.
TRAIN = False

if not TRAIN:
    # Evaluation only: restore the previously saved model and test it.
    checkpoint_path = f"checkpoints/{model_name}.ckpt"
    model = ResNetModule.load_from_checkpoint(checkpoint_path)
    results = trainer.test(model, datamodule=dataset)
else:
    # Hyper-parameters for the ResNet-50 classifier (101 classes, SGD).
    model_kwargs = {
        "num_classes": 101,
        "architecture": "resnet50",
        "optim": "sgd",
        "lr": 0.01,
        "momentum": 0.9,
        "weight_decay": 5e-4,
    }
    model = ResNetModule(**model_kwargs)
    trainer.fit(model, datamodule=dataset)

    # No model is passed here: ckpt_path="best" asks the trainer to test the
    # best checkpoint recorded during fit rather than the last-epoch weights.
    results = trainer.test(datamodule=dataset, ckpt_path="best")

print(results)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/sentinel/.conda/envs/master-thesis/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 28/28 [00:02<00:00,  9.74it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      eval_accuracy         0.9711649417877197
        eval_loss           0.09238769114017487
        hp_metric           0.9711649417877197
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
[{'eval_loss': 0.09238769114017487, 'eval_accuracy': 0.9711649417877197, 'hp_metric': 0.9711649417877197}]
