### Train a CIFAR-10 classifier using ResNet-18 from TorchHub

In [None]:
import torch
import torchvision
import lightning
from lightning.pytorch.loggers import CSVLogger
from torchvision import datasets, transforms
from torch.utils.data import dataset, DataLoader
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from common_def import LightningModel, plot_metrics_from_csv_log
import torch.nn.functional as F

#### Define the DataModule

In [None]:
class CIFAR10DataModule(lightning.LightningDataModule):
    """LightningDataModule for CIFAR-10 with a train/validation split.

    Images are converted to tensors, and each RGB channel is normalized with
    mean 0.5 and std 0.5. The official training set is split into train and
    validation subsets. The official test set is used for testing.

    Args:
        data_dir: Directory where torchvision stores / looks for CIFAR-10.
        batch_size: Batch size used by every dataloader.
        val_size: Number of official training images held out for
            validation. The default 5000 reproduces the 45000/5000 split.
        num_workers: DataLoader worker processes (0 = load in main process).
        split_seed: If given, seeds a dedicated generator for the train/val
            split so it does not depend on the global RNG. If None, the
            global RNG (e.g. as seeded by ``lightning.seed_everything``) is
            used.
    """

    def __init__(self, data_dir='./dataset/cifar-10', batch_size=64,
                 val_size=5000, num_workers=0, split_seed=None):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.val_size = val_size
        self.num_workers = num_workers
        self.split_seed = split_seed
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        # Filled in by setup(). They start as None so setup() can tell
        # whether a stage's datasets were already built and avoid
        # re-splitting an already-split subset.
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def prepare_data(self):
        """Download the CIFAR-10 train and test archives if missing.

        Lightning calls this from a single process, so it only downloads and
        does not assign any state. The dataset objects are created in setup().
        """
        datasets.CIFAR10(root=self.data_dir, train=True, download=True)
        datasets.CIFAR10(root=self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Create the datasets required for ``stage``.

        Args:
            stage: 'fit', 'validate', 'test', or None (build everything).

        The train/val split is performed at most once, so calling setup()
        again for another stage (e.g. 'test' after 'fit') is safe.
        """
        if stage in (None, 'fit', 'validate') and self.train_dataset is None:
            full_train = datasets.CIFAR10(root=self.data_dir, train=True, transform=self.transform)
            train_size = len(full_train) - self.val_size
            split_kwargs = {}
            if self.split_seed is not None:
                split_kwargs['generator'] = torch.Generator().manual_seed(self.split_seed)
            self.train_dataset, self.val_dataset = dataset.random_split(
                full_train, [train_size, self.val_size], **split_kwargs)
        if stage in (None, 'test') and self.test_dataset is None:
            self.test_dataset = datasets.CIFAR10(root=self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Shuffled loader over the training subset."""
        return DataLoader(dataset=self.train_dataset, batch_size=self.batch_size,
                          shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        """Unshuffled loader over the held-out validation subset."""
        return DataLoader(dataset=self.val_dataset, batch_size=self.batch_size,
                          shuffle=False, num_workers=self.num_workers)

    def test_dataloader(self):
        """Unshuffled loader over the official CIFAR-10 test set."""
        return DataLoader(dataset=self.test_dataset, batch_size=self.batch_size,
                          shuffle=False, num_workers=self.num_workers)
    

#### Load ResNet-18 from TorchHub

In [10]:
# List the entrypoints published by the torchvision hub repo and show the
# ResNet-family ones. force_reload=True re-downloads the repo on every run.
model_entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
resnet_entrypoints = [name for name in model_entrypoints if 'resnet' in name]
for name in resnet_entrypoints:
    print(name)

Downloading: "https://github.com/pytorch/vision/zipball/main" to /home/tu/.cache/torch/hub/main.zip


deeplabv3_resnet101
deeplabv3_resnet50
fcn_resnet101
fcn_resnet50
resnet101
resnet152
resnet18
resnet34
resnet50
wide_resnet101_2
wide_resnet50_2


In [None]:
# Build the ResNet-18 architecture from the torchvision hub repo.
# weights=None loads only the architecture (random init, no pre-trained weights).
# num_classes=10 sizes the final fully-connected layer for the 10 CIFAR-10
# classes. Without it, torchvision's default 1000-way ImageNet head is used.
# NOTE(review): the stem is designed for ImageNet-sized inputs and shrinks
# 32x32 CIFAR images aggressively; a CIFAR-style stem may train better.
resnet18_model = torch.hub.load('pytorch/vision', 'resnet18', weights=None, num_classes=10)

#### Training

In [None]:
# Seed the Python, NumPy and torch RNGs before anything else so the run is reproducible.
lightning.seed_everything(123)

cifar_dm = CIFAR10DataModule()

lightning_model = LightningModel(
    torch_model=resnet18_model,
    num_classes=10,
    learning_rate=0.1,
)

# Metrics are written as CSV under lightning_logs/ResNet18-CIFAR10/version_*/.
csv_logger = CSVLogger('lightning_logs', name='ResNet18-CIFAR10')

trainer = lightning.Trainer(
    max_epochs=10,
    logger=csv_logger,
    deterministic=True,
    # This run is pinned to a single GPU. It is not configured for CPU
    # training, which would be very slow for ResNet-18.
    accelerator='gpu',
    devices=1,
)

trainer.fit(lightning_model, cifar_dm)