In [2]:
import lightning as pl
from lightning.pytorch import loggers as pl_loggers
from models import *
from datasets import *


# Trade some float32 matmul precision for faster training.
torch.set_float32_matmul_precision("medium")


# TensorBoard logs for this experiment are written under logs/cifar100.
tb_logger = pl_loggers.TensorBoardLogger(name="cifar100", save_dir="logs")
dataset = CIFAR100DataModule()
model_name = "resnet152-cifar100-min-val-loss"

# Keep only the single checkpoint with the lowest "val_loss".
# The version counter is disabled so the file is always written as
# checkpoints/<model_name>.ckpt, which is the path the test-only branch
# further down loads from.
checkpoint_callback = pl.pytorch.callbacks.ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=1,
    dirpath="checkpoints",
    filename=model_name,
    enable_version_counter=False,
)

trainer = pl.Trainer(
    logger=tb_logger,
    callbacks=[checkpoint_callback],
    max_epochs=30,
)

# Switch between the two modes of this cell:
#   True  -> fit a new ResNetModule, then test the best checkpoint of the run
#   False -> skip fitting and test the checkpoint already saved on disk
TRAIN = True

if TRAIN:
    hparams = {
        "architecture": "resnet152",
        "num_classes": 100,
        "optim": "adamw",
        "lr": 0.001,
        "weight_decay": 0.05,
    }
    model = ResNetModule(**hparams)
    trainer.fit(model, datamodule=dataset)

    # No model is passed here: ckpt_path="best" asks the trainer to load the
    # best checkpoint recorded during fit instead of the last-epoch weights.
    results = trainer.test(datamodule=dataset, ckpt_path="best")
else:
    # Reuse the checkpoint written by a previous training run.
    checkpoint_file = f"checkpoints/{model_name}.ckpt"
    model = ResNetModule.load_from_checkpoint(checkpoint_file)
    results = trainer.test(model, datamodule=dataset)

print(results)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params | Mode 
-------------------------------------------------------
0 | model     | ResNet           | 58.3 M | train
1 | criterion | CrossEntropyLoss | 0      | train
-------------------------------------------------------
58.3 M    Trainable params
0         Non-trainable params
58.3 M    Total params
233.395   Total estimated model params size (MB)
424       Modules in train mode
0         Modules in eval mode


Epoch 29: 100%|██████████| 157/157 [00:29<00:00,  5.27it/s, v_num=14]      

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 29: 100%|██████████| 157/157 [00:29<00:00,  5.27it/s, v_num=14]


Restoring states from the checkpoint path at /home/sentinel/Development/master-thesis/project/checkpoints/resnet152-cifar100-min-val-loss.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/sentinel/Development/master-thesis/project/checkpoints/resnet152-cifar100-min-val-loss.ckpt


Testing DataLoader 0: 100%|██████████| 40/40 [00:04<00:00,  8.43it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      eval_accuracy         0.5734000205993652
        eval_loss           1.6999752521514893
        hp_metric           0.5734000205993652
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
[{'eval_loss': 1.6999752521514893, 'eval_accuracy': 0.5734000205993652, 'hp_metric': 0.5734000205993652}]
