In [1]:
import os
import sys

project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.append(project_root)

import torch
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from lit_modules.custom_wide_resnet_lit import WideResnetLit
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import MLFlowLogger

L.seed_everything(42)
torch.set_float32_matmul_precision("medium")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
torch.cuda.empty_cache()

Seed set to 42


Using device: cuda


In [2]:
IMAGE_SIZE = 32
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
# source: https://pytorch.org/vision/stable/transforms.html
transforms_train = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)
transforms_test = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)
whole_dataset = datasets.EuroSAT(
    root="./data", download=True, transform=transforms_train
)
N = len(whole_dataset)
num_val = int(0.2 * N)
num_test = int(0.2 * N)
num_train = N - num_val - num_test
all_indices = list(range(N))
all_indices = torch.randperm(N).tolist()
indices_val = torch.randperm(N)[:num_val]
indices_test = torch.randperm(N)[:num_test]
train_indices = all_indices[:num_train]
val_indices = all_indices[num_train : num_train + num_val]
test_indices = all_indices[num_train + num_val : num_train + num_val + num_test]

train_dataset = torch.utils.data.Subset(whole_dataset, train_indices)
validation_dataset = torch.utils.data.Subset(whole_dataset, val_indices)
test_dataset = torch.utils.data.Subset(whole_dataset, test_indices)

train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=512,
    num_workers=30,
    persistent_workers=True,
)
validation_loader = torch.utils.data.DataLoader(
    dataset=validation_dataset, batch_size=512, num_workers=30, persistent_workers=True
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=512, num_workers=30, persistent_workers=True
)

In [3]:
class DataModule(L.LightningDataModule):
    def __init__(self, train_loader, validation_loader, test_loader):
        super().__init__()
        self.train_loader = train_loader
        self.validation_loader = validation_loader
        self.test_loader = test_loader

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.validation_loader

    def test_dataloader(self):
        return self.test_loader


data = DataModule(train_loader, validation_loader, test_loader)
model = WideResnetLit(depth=32, weight_decay=10e-4)
logger = MLFlowLogger(experiment_name="WideResnet", save_dir="mlruns")
trainer = L.Trainer(
    max_epochs=50,
    logger=logger,
    callbacks=[
        ModelCheckpoint(
            monitor="val_acc",
            mode="max",
            dirpath="checkpoints/wide_resnet",
            filename="{epoch:02d}-{val_acc:.3f}",
        )
    ],
    precision="16-mixed",
    num_sanity_val_steps=0,
)
trainer.fit(model, datamodule=data)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Experiment with name WideResnet not found. Creating it.
/home/dxzielinski/Desktop/github-repositories/optimization-data-analysis/.venv/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/dxzielinski/Desktop/github-repositories/optimization-data-analysis/model_training_notebooks/checkpoints/wide_resnet exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.
/home/dxzielinski/Desktop/github-repositories/optimization-data-analysis/.venv/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (32) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.

  | 

Epoch 0:   3%|▎         | 1/32 [00:00<00:22,  1.41it/s, v_num=7c1c, train_loss_step=13.60, train_acc_step=0.201]



Epoch 49: 100%|██████████| 32/32 [00:08<00:00,  3.76it/s, v_num=7c1c, train_loss_step=0.0142, train_acc_step=0.997, val_loss=0.171, val_acc=0.944, train_loss_epoch=0.0207, train_acc_epoch=0.997]

`Trainer.fit` stopped: `max_epochs=50` reached.


Epoch 49: 100%|██████████| 32/32 [00:08<00:00,  3.73it/s, v_num=7c1c, train_loss_step=0.0142, train_acc_step=0.997, val_loss=0.171, val_acc=0.944, train_loss_epoch=0.0207, train_acc_epoch=0.997]


In [4]:
best_ckpt = trainer.checkpoint_callback.best_model_path
print("Best checkpoint path:", best_ckpt)
trainer.test(model, datamodule=data, ckpt_path=best_ckpt)

Restoring states from the checkpoint path at /home/dxzielinski/Desktop/github-repositories/optimization-data-analysis/model_training_notebooks/checkpoints/wide_resnet/epoch=49-val_acc=0.944.ckpt


Best checkpoint path: /home/dxzielinski/Desktop/github-repositories/optimization-data-analysis/model_training_notebooks/checkpoints/wide_resnet/epoch=49-val_acc=0.944.ckpt


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/dxzielinski/Desktop/github-repositories/optimization-data-analysis/model_training_notebooks/checkpoints/wide_resnet/epoch=49-val_acc=0.944.ckpt


Testing DataLoader 0: 100%|██████████| 11/11 [00:00<00:00, 13.57it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.9435185194015503
        test_loss           0.18384797871112823
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.18384797871112823, 'test_acc': 0.9435185194015503}]