In [None]:
import os
import sys

# Make the parent of the notebook's working directory importable so that the
# project-local `lit_modules` package resolves without being pip-installed.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if not any(entry == project_root for entry in sys.path):
    sys.path.append(project_root)

import torch
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from lit_modules.custom_wide_resnet_lit import WideResnetLit
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import MLFlowLogger

# Seed the global RNGs through Lightning so runs are repeatable.
L.seed_everything(42)
# Permit faster, reduced-precision kernels for float32 matmuls
# (see torch.set_float32_matmul_precision).
torch.set_float32_matmul_precision("medium")
# Informational only: Lightning's Trainer picks its own accelerator.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

In [None]:
IMAGE_SIZE = 32
# Per-channel normalization statistics used by both the train and eval transforms.
mean, std = [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]
VAL_FRACTION = 0.2  # fraction of the official train split held out for validation
SPLIT_SEED = 42  # dedicated seed so the split does not depend on prior global-RNG use
BATCH_SIZE = 512
# Cap worker processes at the machine's CPU count (os.cpu_count() may return None);
# persistent_workers=True below requires at least one worker.
NUM_WORKERS = min(30, os.cpu_count() or 1)

# source: https://pytorch.org/vision/stable/transforms.html
transforms_train = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)
transforms_test = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)
train_dataset = datasets.CIFAR100(
    root="./data", train=True, download=True, transform=transforms_train
)
# Same training images, but with the evaluation transforms. Validation samples are
# drawn from this copy so they never inherit training-only preprocessing (e.g. if
# augmentation is later added to transforms_train).
train_eval_dataset = datasets.CIFAR100(
    root="./data", train=True, download=True, transform=transforms_test
)
test_dataset = datasets.CIFAR100(
    root="./data", train=False, download=True, transform=transforms_test
)

# Random, disjoint train/validation split of the official training set:
# the first num_val entries of one permutation go to validation, the rest to training.
N = len(train_dataset)
num_val = int(VAL_FRACTION * N)
permutation = torch.randperm(N, generator=torch.Generator().manual_seed(SPLIT_SEED))
indices, train_indices = permutation[:num_val], permutation[num_val:]
validation_dataset = torch.utils.data.Subset(train_eval_dataset, indices=indices.tolist())
train_dataset = torch.utils.data.Subset(train_dataset, indices=train_indices.tolist())

loader_kwargs = dict(
    batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, persistent_workers=True
)
# Only the training loader shuffles. Without it, every epoch would see the same
# fixed batch order. Evaluation loaders keep a deterministic order.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, shuffle=True, **loader_kwargs
)
validation_loader = torch.utils.data.DataLoader(
    dataset=validation_dataset, **loader_kwargs
)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, **loader_kwargs)

In [None]:
import matplotlib.pyplot as plt
import numpy as np


def show_img(img, mean=None, std=None):
    """Display a single (C, H, W) image tensor with matplotlib.

    Args:
        img: image tensor laid out channels-first, e.g. the output of
            ``torchvision.utils.make_grid``.
        mean: optional per-channel means that were passed to ``transforms.Normalize``.
        std: optional per-channel standard deviations passed to ``transforms.Normalize``.

    When both ``mean`` and ``std`` are given, the normalization is undone
    (``x * std + mean``) and values are clamped to [0, 1]. Without this,
    matplotlib receives a zero-centred tensor, clips it, and shows distorted
    colours. With neither given, ``img`` is plotted unchanged, as before.
    """
    if mean is not None and std is not None:
        channel_std = torch.as_tensor(std, dtype=img.dtype).view(-1, 1, 1)
        channel_mean = torch.as_tensor(mean, dtype=img.dtype).view(-1, 1, 1)
        img = (img * channel_std + channel_mean).clamp(0, 1)
    npimg = img.numpy()
    # matplotlib expects channels-last (H, W, C).
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


images, labels = next(iter(train_loader))
# Undo the dataset normalization so the preview shows the images in true colour.
show_img(torchvision.utils.make_grid(images[:8], nrow=4, padding=2), mean, std)

In [None]:
class DataModule(L.LightningDataModule):
    """LightningDataModule that serves already-constructed DataLoaders.

    No datasets are built here. Each ``*_dataloader`` hook simply hands back the
    loader object that was passed to the constructor.
    """

    def __init__(self, train_loader, validation_loader, test_loader):
        super().__init__()
        # Store the caller's loaders verbatim; they are returned unchanged below.
        self.train_loader, self.validation_loader, self.test_loader = (
            train_loader,
            validation_loader,
            test_loader,
        )

    def train_dataloader(self):
        """Lightning hook: return the training loader."""
        return self.train_loader

    def val_dataloader(self):
        """Lightning hook: return the validation loader."""
        return self.validation_loader

    def test_dataloader(self):
        """Lightning hook: return the test loader."""
        return self.test_loader


data = DataModule(train_loader, validation_loader, test_loader)

MAX_EPOCHS = 50
# Same value the notebook always used (it was written as 10e-4, which equals 1e-3).
# NOTE(review): confirm 1e-3 is intended; the WRN paper uses 5e-4.
WEIGHT_DECAY = 1e-3
model = WideResnetLit(depth=32, weight_decay=WEIGHT_DECAY, num_classes=100)
logger = MLFlowLogger(experiment_name="WideResnet-CIFAR100", save_dir="mlruns")
# Lightning rejects fp16 mixed precision ("16-mixed") on CPU. Fall back to bfloat16
# mixed precision so the CPU path set up in the first cell still runs.
precision = "16-mixed" if torch.cuda.is_available() else "bf16-mixed"
trainer = L.Trainer(
    max_epochs=MAX_EPOCHS,
    logger=logger,
    callbacks=[
        # Keep the checkpoint with the highest logged validation accuracy.
        ModelCheckpoint(
            monitor="val_acc",
            mode="max",
            dirpath="checkpoints/wide_resnet",
            filename="{epoch:02d}-{val_acc:.3f}",
        )
    ],
    precision=precision,
    num_sanity_val_steps=0,
)
trainer.fit(model, datamodule=data)

In [None]:
# Evaluate the best checkpoint tracked by the ModelCheckpoint callback,
# not the final-epoch weights.
best_ckpt = trainer.checkpoint_callback.best_model_path
if not best_ckpt:
    # An empty path means no checkpoint was written; fail loudly instead of
    # handing Lightning an empty ckpt_path.
    raise RuntimeError(
        "No checkpoint was saved during training; cannot evaluate the best model."
    )
# Bind the returned metrics so they stay available and the cell does not echo the list.
test_results = trainer.test(model, datamodule=data, ckpt_path=best_ckpt)