In [None]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from lit_modules.torchvision_wide_resnet_lit import WideResnetLit
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import MLFlowLogger
import os
import sys

project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.append(project_root)

L.seed_everything(42)
torch.set_float32_matmul_precision("medium")

Seed set to 42


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [3]:
IMAGE_SIZE = 32
mean, std = [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]
# source: https://pytorch.org/vision/stable/transforms.html
transforms_train = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)


transforms_test = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)

train_dataset = datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transforms_train
)
test_dataset = datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transforms_test
)

N = len(train_dataset)
num_val = int(0.2 * N)
indices = torch.randperm(N)[:num_val]
mask = torch.ones(N, dtype=torch.bool)
mask[indices] = False
train_indices = torch.nonzero(mask, as_tuple=False).squeeze(1)
validation_dataset = torch.utils.data.Subset(train_dataset, indices=indices)
train_dataset = torch.utils.data.Subset(train_dataset, indices=train_indices)
validation_loader = torch.utils.data.DataLoader(
    dataset=validation_dataset, batch_size=512, num_workers=1, persistent_workers=True
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=512, num_workers=1, persistent_workers=True
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=512, num_workers=1, persistent_workers=True
)

In [4]:
class DataModule(L.LightningDataModule):
    def __init__(self, train_loader, validation_loader, test_loader):
        super().__init__()
        self.train_loader = train_loader
        self.validation_loader = validation_loader
        self.test_loader = test_loader

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.validation_loader

    def test_dataloader(self):
        return self.test_loader


data = DataModule(train_loader, validation_loader, test_loader)
model = WideResnetLit(depth=32, weight_decay=10e-4)
logger = MLFlowLogger(experiment_name="WideResnet", save_dir="mlruns")
trainer = L.Trainer(
    max_epochs=50,
    logger=logger,
    callbacks=[
        ModelCheckpoint(
            monitor="val_acc",
            mode="max",
            dirpath="checkpoints/wide_resnet",
            filename="{epoch:02d}-{val_acc:.3f}",
        )
    ],
    precision="16-mixed",
    num_sanity_val_steps=0,
)
trainer.fit(model, datamodule=data)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/dxzielinski/Desktop/github-repositories/optimization-data-analysis/.venv/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/dxzielinski/Desktop/github-repositories/optimization-data-analysis/models/checkpoints/wide_resnet exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.
/home/dxzielinski/Desktop/github-repositories/optimization-data-analysis/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.

  | Name      | Type               | Params | Mod

Epoch 0:   1%|▏         | 1/79 [00:00<00:54,  1.43it/s, v_num=50dd, train_loss_step=7.710, train_acc_step=0.117]



Epoch 49: 100%|██████████| 79/79 [00:20<00:00,  3.90it/s, v_num=50dd, train_loss_step=0.000691, train_acc_step=1.000, val_loss=0.407, val_acc=0.880, train_loss_epoch=0.00321, train_acc_epoch=1.000]

`Trainer.fit` stopped: `max_epochs=50` reached.


Epoch 49: 100%|██████████| 79/79 [00:20<00:00,  3.89it/s, v_num=50dd, train_loss_step=0.000691, train_acc_step=1.000, val_loss=0.407, val_acc=0.880, train_loss_epoch=0.00321, train_acc_epoch=1.000]


In [5]:
best_ckpt = trainer.checkpoint_callback.best_model_path
print("Best checkpoint path:", best_ckpt)
trainer.test(model, datamodule=data, ckpt_path=best_ckpt)

Restoring states from the checkpoint path at /home/dxzielinski/Desktop/github-repositories/optimization-data-analysis/models/checkpoints/wide_resnet/epoch=49-val_acc=0.880.ckpt


Best checkpoint path: /home/dxzielinski/Desktop/github-repositories/optimization-data-analysis/models/checkpoints/wide_resnet/epoch=49-val_acc=0.880.ckpt


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/dxzielinski/Desktop/github-repositories/optimization-data-analysis/models/checkpoints/wide_resnet/epoch=49-val_acc=0.880.ckpt
/home/dxzielinski/Desktop/github-repositories/optimization-data-analysis/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 20/20 [00:01<00:00, 13.17it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.8752999901771545
        test_loss           0.4262078106403351
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.4262078106403351, 'test_acc': 0.8752999901771545}]