# PyTorch Lightning: CIFAR-10

Fine-tunes a pretrained torchvision ResNet-18 on CIFAR-10 using PyTorch Lightning.

Source: https://github.com/bipinKrishnan/pytorch_lightning_examples/blob/main/pytorch_lightning_cifar10.ipynb

In [2]:
# Use %pip (not !pip) so the package is installed into the running kernel's
# environment, and pin the version this notebook was executed with: the
# Trainer arguments used below are specific to the 1.6.x API.
%pip install -q pytorch-lightning==1.6.1

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting pytorch-lightning
  Downloading pytorch_lightning-1.6.1-py3-none-any.whl (582 kB)
[K     |████████████████████████████████| 582 kB 17.9 MB/s eta 0:00:01
Collecting pyDeprecate<0.4.0,>=0.3.1
  Downloading pyDeprecate-0.3.2-py3-none-any.whl (10 kB)
Collecting torchmetrics>=0.4.1
  Downloading torchmetrics-0.8.0-py3-none-any.whl (408 kB)
[K     |████████████████████████████████| 408 kB 36.0 MB/s eta 0:00:01
Collecting aiohttp
  Downloading aiohttp-3.8.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.3 MB)
[K     |████████████████████████████████| 1.3 MB 33.3 MB/s eta 0:00:01
Collecting frozenlist>=1.1.1
  Downloading frozenlist-1.3.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (158 kB)
[K     |████████████████████████████████| 158 kB 30.9 MB/s eta 0:00:01
Collecting async-timeout<5.0,>

In [3]:
from torch import nn, optim
import torch.nn.functional as F
from torchvision.transforms import transforms
from torchvision import datasets, models
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule, Trainer  # import the pytorch lightning module
from pytorch_lightning.callbacks import TQDMProgressBar

In [4]:
def load_dataloader(name, transform, batch, root='../Data/cifar10_builtin'):
    """Build the training and validation DataLoaders for a named dataset.

    Args:
        name: Dataset identifier. Only 'cifar10' is supported.
        transform: Transform handed to the torchvision dataset for both splits.
        batch: Batch size used by both loaders.
        root: Directory the dataset is downloaded to / read from.

    Returns:
        A ``(train_dl, val_dl)`` tuple. The training loader shuffles its
        samples; the validation loader does not.

    Raises:
        ValueError: If ``name`` is not a supported dataset.
    """
    # Fail loudly instead of implicitly returning None, which would surface
    # later as an unrelated "cannot unpack NoneType" error at the call site.
    if name != 'cifar10':
        raise ValueError(f'Dataset not supported: {name!r}')

    train_ds = datasets.CIFAR10(root, train=True, download=True, transform=transform)
    val_ds = datasets.CIFAR10(root, train=False, download=True, transform=transform)

    train_dl = DataLoader(train_ds, batch, shuffle=True)
    val_dl = DataLoader(val_ds, batch, shuffle=False)

    return train_dl, val_dl

In [5]:
class Model(LightningModule):
    """LightningModule wrapping a pretrained ResNet-18 with a 10-output head.

    Args:
        model: Name of the backbone to build. Only 'resnet18' is supported.
        freeze: If truthy, every pretrained parameter has ``requires_grad``
            set to False so that only the newly created final layer trains.

    Raises:
        ValueError: If ``model`` names an unsupported backbone.
    """

    def __init__(self, model='resnet18', freeze=False):
        super().__init__()
        self.model = self._build_model(model, freeze)

    def _build_model(self, name, freeze):
        """Create the backbone and replace its final layer with a 10-output linear layer."""
        if name != 'resnet18':
            raise ValueError('Model not supported')

        model = models.resnet18(pretrained=True)

        if freeze:
            for param in model.parameters():
                param.requires_grad = False  # freeze the pretrained weights

        # The head is created after freezing, so it keeps requires_grad=True
        # and remains trainable whether or not the backbone is frozen.
        model.fc = nn.Linear(model.fc.in_features, 10)
        return model

    def forward(self, x):
        """Run the wrapped network on ``x`` and return its output."""
        return self.model(x)

    def _compute_loss(self, batch):
        """Cross-entropy loss of the network output for one ``(x, y)`` batch."""
        x, y = batch
        out = self(x)
        return F.cross_entropy(out, y)

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch."""
        loss = self._compute_loss(batch)
        # self.log("train_loss", loss, prog_bar=True)  # tensorboard logging
        return loss

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for one batch."""
        loss = self._compute_loss(batch)
        # self.log("val_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer (default settings) over the network's parameters."""
        return optim.Adam(self.model.parameters())

In [7]:
# Build the CIFAR-10 loaders (batch size 128, images converted to tensors)
# and a ResNet-18 whose pretrained weights are frozen.
train_dl, val_dl = load_dataloader(
    name='cifar10',
    transform=transforms.ToTensor(),
    batch=128,
)
model = Model(model='resnet18', freeze=True)

Files already downloaded and verified
Files already downloaded and verified


ValueError: Model not supported

In [9]:
# Train for 20 epochs on one GPU. `gpus=` and `progress_bar_refresh_rate=`
# are deprecated Trainer arguments; `accelerator`/`devices` and the
# TQDMProgressBar callback are their supported replacements.
trainer = Trainer(
    max_epochs=20,
    accelerator='gpu',
    devices=1,
    callbacks=[TQDMProgressBar(refresh_rate=60)],
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [10]:
# Run the training loop, validating on the held-out loader.
trainer.fit(
    model,
    train_dataloaders=train_dl,
    val_dataloaders=val_dl,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params
---------------------------------
0 | model | ResNet | 11.2 M
---------------------------------
5.1 K     Trainable params
11.2 M    Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]