# CIFAR10 Classifier: PyTorch Lightning Edition

We use PyTorch Lightning in this notebook. It abstracts away the boilerplate code that every PyTorch training workflow otherwise needs.

This notebook performs CIFAR10 classification on TPU using:
- ResNet-50 as backbone
- Minimal augmentation
- Experiment tracking using TensorBoard (Lightning writes its logs to `lightning_logs/`)

Most of the content is similar to the [Tensorflow](tf-resnet50.ipynb) and [Pytorch](pytorch-resnet50.ipynb) versions of the notebook.

---

@Date: 26-Sep-2020 | @Author: Katnoria

# 1. Imports & Setup

Install PyTorch/XLA and Lightning.

Ref: https://pytorch-lightning.readthedocs.io/en/latest/tpu.html

In [None]:
# Crash the runtime on purpose to get more RAM; otherwise the trainer will not work
import torch
torch.tensor([10.]*10000000000)

In [1]:
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  5115  100  5115    0     0  33874      0 --:--:-- --:--:-- --:--:-- 33651


In [3]:
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev

Updating... This may take around 2 minutes.
Updating TPU runtime to pytorch-nightly ...
Collecting cloud-tpu-client
  Downloading https://files.pythonhosted.org/packages/56/9f/7b1958c2886db06feb5de5b2c191096f9e619914b6c31fdf93999fdbbd8b/cloud_tpu_client-0.10-py3-none-any.whl
Collecting google-api-python-client==1.8.0
  Downloading https://files.pythonhosted.org/packages/9a/b4/a955f393b838bc47cbb6ae4643b9d0f90333d3b4db4dc1e819f36aad18cc/google_api_python_client-1.8.0-py3-none-any.whl (57kB)
Uninstalling torch-1.6.0+cu101:
Installing collected packages: google-api-python-client, cloud-tpu-client
  Found existing installation: google-api-python-client 1.7.12
    Uninstalling google-api-python-client-1.7.12:
      Successfully uninstalled google-api-python-client-1.7.12
Successfully installed cloud-tpu-client-0.10 google-api-python-client-1.8.0
Done updating TPU runtime
  Successfully uninstalled torch-1.6.0+cu101
Uninstalling 

In [4]:
%%capture
# Install Lightning (the %%capture cell magic has to be the first line of the cell)
! pip install git+https://github.com/PytorchLightning/pytorch-lightning.git@master --upgrade

In [25]:
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms, datasets, models, utils
from torch.utils.data import DataLoader, random_split

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.metrics.functional import accuracy

import torch_xla.core.xla_model as xm

In [6]:
pl.seed_everything(42)

42

In [7]:
def version_info(cls):
    print(f"{cls.__name__}: {cls.__version__}")

In [8]:
# Print version info
version_info(torch)
version_info(pl)
import torch_xla
version_info(torch_xla)

torch: 1.7.0a0+241afc9
pytorch_lightning: 0.9.1rc4
torch_xla: 1.6+1155541


# 2. Dataset

We load the CIFAR10 dataset from torchvision's built-in datasets.

In [9]:
#BATCH_SIZE=128 # Trainer cannot connect to TPU (timeout)
#BATCH_SIZE=64  # Trainer cannot connect to TPU (timeout)
BATCH_SIZE=8
NUM_WORKERS=4
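The batch size above is small, but it is worth spelling out what it means on 8 cores. The sketch below assumes that each TPU core runs its own process with its own data loader (the usual PyTorch/XLA multi-core setup) and that the 45,000-image training split created further down is sharded evenly across cores. `TPU_CORES`, `TRAIN_SIZE` and `per_core_steps` are illustrative names, not part of the notebook.

```python
import math

BATCH_SIZE = 8       # per-core batch size, as set above
TPU_CORES = 8        # assumption: matches tpu_cores=8 passed to the Trainer
TRAIN_SIZE = 45_000  # size of the training split created below


def per_core_steps(num_samples, cores, batch_size):
    """Steps per epoch on one core, given an even shard of the data."""
    samples_per_core = math.ceil(num_samples / cores)
    return math.ceil(samples_per_core / batch_size)


global_batch = BATCH_SIZE * TPU_CORES
steps = per_core_steps(TRAIN_SIZE, TPU_CORES, BATCH_SIZE)
print(f"effective global batch size: {global_batch}")  # 64
print(f"steps per epoch on each core: {steps}")        # 704
```

Under these assumptions, a per-core batch of 8 already behaves like a global batch of 64, which is worth keeping in mind when choosing the learning rate.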

In [10]:
tfms = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261])
])

In [11]:
# Load the dataset
train_ds = datasets.CIFAR10(
    root="./data", train=True, 
    download=True, transform=tfms
)
# Create train and validation splits
train, val = random_split(train_ds, [45000, 5000])

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz



Extracting ./data/cifar-10-python.tar.gz to ./data


In [12]:
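# A manual DistributedSampler should not be needed to use the TPU: by default
# (replace_sampler_ddp=True) the Lightning Trainer is expected to add a
# distributed sampler itself. The manual version is kept here for reference.

# sampler = torch.utils.data.distributed.DistributedSampler(
#     train,
#     num_replicas=xm.xrt_world_size(),
#     rank=xm.get_ordinal(),
#     shuffle=True
# )
# train_loader = DataLoader(train, batch_size=BATCH_SIZE, sampler=sampler)
train_loader = DataLoader(train, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)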

In [13]:
# val_sampler = torch.utils.data.distributed.DistributedSampler(
#     val,
#     num_replicas=xm.xrt_world_size(),
#     rank=xm.get_ordinal(),
#     shuffle=False
# )
# val_loader = DataLoader(val, batch_size=BATCH_SIZE, sampler=val_sampler)
val_loader = DataLoader(val, shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

In [14]:
len(train_loader.dataset), len(val_loader.dataset)

(45000, 5000)

In [15]:
test_ds = datasets.CIFAR10(
    root="./data", 
    train=False,
    download=True,
    transform=tfms
)
# test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, sampler=sampler)
test_loader = DataLoader(test_ds, shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
len(test_loader.dataset)

Files already downloaded and verified


10000

# 3. Build Model

We use an ImageNet pre-trained ResNet-50 as the backbone. You can swap in another base model such as ResNet-18 or ResNet-101. Just make sure the input size of the replacement `fc` layer (2048 for ResNet-50) matches the number of features the backbone produces; ResNet-18, for example, produces 512.

In [57]:
class CIFARTenLitModelV2(pl.LightningModule):
    """CIFAR10 Model"""
    def __init__(self, backbone, learning_rate):
        super().__init__()
        self.learning_rate = learning_rate
        self.backbone = backbone
        self.backbone.fc = nn.Linear(2048, 256)
        self.fc1 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.backbone(x)
        x = F.relu(x)
        out = self.fc1(x)
        return out
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = accuracy(y_hat, y, num_classes=10)        
        # result = pl.TrainResult(loss)
        # result.log("train_loss", loss,)
        # result.log("train_acc", acc, prog_bar=True)                
        # result.log("loss", {"train_loss": loss})
        result = {
            "loss": loss,
            "train_acc": acc,
        }
        return result
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = accuracy(y_hat, y, num_classes=10)        
        # result = pl.EvalResult(checkpoint_on=loss)
        # result.log("val_loss", loss)
        # result.log("val_acc", acc)
        result = {
            "val_loss": loss,
            "val_acc": acc,
        }
        return result
    
    def test_step(self, batch, batch_index):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = accuracy(y_hat, y, num_classes=10)
        # result = pl.EvalResult(checkpoint_on=loss)
        # result.log("test_acc", acc)
        # Report test metrics under their own keys
        result = {
            "test_loss": loss,
            "test_acc": acc,
        }
        return result
        
    
    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.learning_rate)
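Each step in the model computes the same two quantities: a cross-entropy loss on the raw logits and a top-1 accuracy. The `accuracy` helper used above comes from the Lightning version installed in this notebook, and other Lightning releases may not ship it at that import path. As a sketch under that assumption, both values can be computed with plain PyTorch; `top1_accuracy` is a hypothetical helper and the logits are made up.

```python
import torch
import torch.nn.functional as F


def top1_accuracy(logits, targets):
    """Fraction of samples whose highest-scoring class equals the target."""
    preds = logits.argmax(dim=1)
    return (preds == targets).float().mean()


# Made-up logits for a batch of 4 samples and 3 classes.
logits = torch.tensor([
    [2.0, 0.1, 0.3],   # predicts class 0
    [0.2, 1.5, 0.1],   # predicts class 1
    [0.3, 0.2, 0.9],   # predicts class 2
    [1.2, 0.4, 0.1],   # predicts class 0
])
targets = torch.tensor([0, 1, 2, 1])  # the last sample is misclassified

loss = F.cross_entropy(logits, targets)  # expects raw logits, not softmax outputs
acc = top1_accuracy(logits, targets)
print(f"loss={loss.item():.4f} acc={acc.item():.2f}")
```

Three of the four predictions match their targets, so the accuracy here is 0.75.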

# 4. Train

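Load the ImageNet pre-trained ResNet-50 from torchvision and freeze its weights. The model class then replaces the backbone's `fc` layer with a new one, so only the new layers are trained.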

In [58]:
backbone = models.resnet50(pretrained=True)
for param in backbone.parameters():
    param.requires_grad = False
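Freezing happens before the model class replaces `backbone.fc`, so the replacement layer is created with `requires_grad=True` and stays trainable. The toy example below illustrates that order of operations. `TinyBackbone` and `count_params` are made-up stand-ins for illustration, not ResNet-50.

```python
import torch.nn as nn


def count_params(module):
    """Return (trainable, frozen) parameter counts for a module."""
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in module.parameters() if not p.requires_grad)
    return trainable, frozen


class TinyBackbone(nn.Module):
    """Tiny stand-in for a backbone that ends in an `fc` layer."""

    def __init__(self):
        super().__init__()
        self.body = nn.Linear(8, 4)
        self.fc = nn.Linear(4, 2)


toy = TinyBackbone()
for param in toy.parameters():
    param.requires_grad = False  # freeze everything, as in the cell above

toy.fc = nn.Linear(4, 3)         # a freshly created layer is trainable again
trainable, frozen = count_params(toy)
print(f"trainable={trainable}, frozen={frozen}")  # trainable=15, frozen=36
```

The same check on the real model is a quick way to confirm that only the new head receives gradient updates.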

In [59]:
early_stop = EarlyStopping(
    monitor='val_loss',
    patience=3,
    strict=False,
    verbose=False,
    mode='min'
)
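With `mode='min'` and `patience=3`, training stops once the monitored `val_loss` has failed to improve for three validation checks in a row. The function below is a simplified sketch of that rule in plain Python, not Lightning's implementation. `epochs_until_stop` is a hypothetical name and the loss history is made up.

```python
def epochs_until_stop(val_losses, patience=3, min_delta=0.0):
    """Return the 1-based epoch at which training would stop, or None.

    Simplified rule: stop once `patience` consecutive epochs fail to
    improve on the best validation loss seen so far.
    """
    best = float("inf")
    bad_epochs = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return epoch
    return None


# Made-up validation losses: the best value is reached at epoch 3.
history = [0.90, 0.70, 0.65, 0.66, 0.68, 0.67, 0.64]
print(epochs_until_stop(history, patience=3))  # 6
```

In this made-up history the improvement at epoch 7 is never seen, which is the usual trade-off when choosing a small patience.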

With PyTorch Lightning, the `Trainer` takes care of setting up the training, validation and test loops. It also provides other goodies such as:
* a fast dev run (`fast_dev_run=True`) that pushes a single batch through training and validation to check the model for errors
* metric logging
* easy switching between accelerators (GPU/TPU) with minimal code change (just the dataloader sampler) 💯
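For comparison, here is roughly the kind of loop that the Trainer takes over, written in plain PyTorch on synthetic data. This is a sketch for illustration, not what Lightning runs internally: the data, model and hyperparameters are made up, and device placement, validation, checkpointing and logging are left out.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic stand-in data: 64 samples, 10 features, 3 classes.
inputs = torch.randn(64, 10)
labels = torch.randint(0, 3, (64,))
loader = DataLoader(TensorDataset(inputs, labels), batch_size=16, shuffle=True)

toy_model = nn.Linear(10, 3)
optimizer = torch.optim.Adam(toy_model.parameters(), lr=1e-2)

epoch_losses = []
for epoch in range(5):
    toy_model.train()
    running = 0.0
    for x, y in loader:
        optimizer.zero_grad()                    # reset gradients from the last step
        loss = F.cross_entropy(toy_model(x), y)  # forward pass and loss
        loss.backward()                          # backward pass
        optimizer.step()                         # parameter update
        running += loss.item() * x.size(0)
    epoch_losses.append(running / len(loader.dataset))
    print(f"epoch {epoch + 1}: train_loss={epoch_losses[-1]:.4f}")
```

In the Lightning model, the body of the inner loop corresponds to `training_step`, and the optimizer setup corresponds to `configure_optimizers`.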

In [60]:
EPOCHS=10
learning_rate=1e-3

In [61]:
trainer = pl.Trainer( 
    tpu_cores=8,
    progress_bar_refresh_rate=20,
    early_stop_callback=early_stop, 
    max_epochs=EPOCHS,
)

GPU available: False, used: False
TPU available: True, using: 8 TPU cores


In [62]:
model = CIFARTenLitModelV2(backbone, learning_rate)

In [None]:
# track training time
training_start = time.time()

trainer.fit(model, train_loader, val_dataloaders=val_loader)

# print training time
total_time = time.time() - training_start
print(f"Total training time {total_time//60 :.0f}m {total_time%60:.0f}s")

training on 8 TPU cores
INIT TPU local core: 0, global rank: 0 with XLA_USE_BF16=None
INIT TPU local core: 3, global rank: 3 with XLA_USE_BF16=None
INIT TPU local core: 4, global rank: 4 with XLA_USE_BF16=None
INIT TPU local core: 7, global rank: 7 with XLA_USE_BF16=None
INIT TPU local core: 5, global rank: 5 with XLA_USE_BF16=None
INIT TPU local core: 2, global rank: 2 with XLA_USE_BF16=None
INIT TPU local core: 1, global rank: 1 with XLA_USE_BF16=None
INIT TPU local core: 6, global rank: 6 with XLA_USE_BF16=None

  | Name     | Type   | Params
------------------------------------
0 | backbone | ResNet | 24 M  
1 | fc1      | Linear | 2 K   



In [None]:
trainer.test(model, test_dataloaders=test_loader)

In [23]:
# Start tensorboard.
!pip install tensorboard
%reload_ext tensorboard
%tensorboard --logdir lightning_logs/


# Next Steps

This is a very simple example of training a CIFAR10 classifier using a pre-trained network. It's your turn to turn the knobs and see if you can get the model to generalise better. Some ideas:

- make the model overfit your training data
- regularize the model to generalize better
- increase/decrease model capacity based on what you find in the above steps
- add image augmentation
- use a hyperparameter tuning library to find the best combination of settings
- roll your own model from scratch; you can use the tuning library to help design the network too
