In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import lightning as L
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint

In [3]:
class CIFAR10Module(L.LightningDataModule):
    def __init__(self, data_dir, batch_size=4, img_size=(32, 32), num_workers=4):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.img_size = img_size
        self.num_workers = num_workers

        self.train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(self.img_size),
            transforms.RandomHorizontalFlip(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.2154, 0.2024))
        ])
        
        self.val_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(self.img_size),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.2154, 0.2024))
        ])

    def prepare_data(self):
        datasets.CIFAR10(root=self.data_dir, train=True, download=True)
        datasets.CIFAR10(root=self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        train_dataset = datasets.CIFAR10(root=self.data_dir, train=True, transform=self.train_transform)
        val_dataset = datasets.CIFAR10(root=self.data_dir, train=False, transform=self.val_transform)
        self.train_set = train_dataset
        self.val_set = val_dataset

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, persistent_workers=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size, num_workers=self.num_workers, persistent_workers=True)

In [4]:
class ResnetModel(L.LightningModule):
    def __init__(self, num_classes=10, initial_lr=0.001, weight_decay=5e-4, gamma=0.1, step_size=5):
        super().__init__()
        self.model = torchvision.models.resnet18(weights='DEFAULT')
        self.loss_module = nn.CrossEntropyLoss()
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features, num_classes)

        self.num_classes = num_classes
        self.initial_lr = initial_lr
        self.weight_decay = weight_decay
        self.gamma = gamma
        self.step_size = step_size

    def forward(self, images):
        return self.model(images)

    def training_step(self, batch, batch_idx):
        images, targets = batch
        outputs = self.forward(images)

        loss = F.cross_entropy(outputs, targets)
        self.log('train_loss', loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, targets = batch
        outputs = self.forward(images)

        # Calculate loss
        loss = F.cross_entropy(outputs, targets)
        self.log('val_loss', loss, on_step=False, on_epoch=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.initial_lr, weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.step_size, gamma=self.gamma)
        return [optimizer], [scheduler]

In [5]:
data_dir = 'data/'
data_module = CIFAR10Module(data_dir, batch_size=32)
model = ResnetModel(num_classes=10)

logger = TensorBoardLogger("lightning_logs", name="cifar10-resnet")
trainer = L.Trainer(
    accelerator="auto",
    max_epochs=50,
    callbacks=[
        EarlyStopping(monitor="val_loss", patience=1),
        ModelCheckpoint(monitor="val_loss", save_top_k=1)
    ],
    logger=logger
)

trainer.fit(model, data_module)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/davidroot/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 56.3MB/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type             | Params
-------------------------------------------------
0 | model       | ResNet           | 11.2 M
1 | loss_module | CrossEntropyLoss | 0     
-------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)


Epoch 7: 100%|██████████| 1563/1563 [00:16<00:00, 93.72it/s, v_num=2]      
