# CIFAR 10 Classifier: Pytorch Lightning Version

This notebook uses Pytorch Lightning (PL) to demonstrate how easy it is to use TPU 😀, I also like the approach PL has taken in terms of organising the code. Anyone who writes PyTorch code ends up writing the same boiler plate over-and-over again, apart from the Model. 

---

@date: 03-Sep-2020 | @author: katnoria

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms, datasets, models, utils
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping

In [2]:
pl.seed_everything(42)

42

In [3]:
def version_info(cls):
    print(f"{cls.__name__}: {cls.__version__}")

In [4]:
version_info(torch)
version_info(pl)

torch: 1.6.0
pytorch_lightning: 0.9.0


Stuff from Keras-Tuner on Colab
```
{'conv_blocks': 5,
 'dropout': 0.0,
 'filters_0': 128,
 'filters_1': 160,
 'filters_2': 160,
 'filters_3': 192,
 'filters_4': 224,
 'hidden_size': 80,
 'learning_rate': 0.0025359172395390105,
 'pooling_0': 'max',
 'pooling_1': 'max',
 'pooling_2': 'max',
 'pooling_3': 'max',
 'pooling_4': 'max',
 'tuner/bracket': 2,
 'tuner/epochs': 30,
 'tuner/initial_epoch': 10,
 'tuner/round': 2,
 'tuner/trial_id': 'b780a5e9191d55c360ad5be6040decc4'}
```

In [5]:
class CIFARTenLitModel(pl.LightningModule):
    """CIFAR10 Model"""
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        self.backbone.fc = nn.Linear(2048, 10)
#         self.fc1 = nn.Linear(2048, 128)
#         self.fc2 = nn.Linear(128, 10)
#         nn.AvgPool2d()

    def forward(self, x):
        x = self.backbone(x)
#         x = F.relu(self.fc1(x))
#         out = self.fc2(x)
        return x
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        result = pl.TrainResult(loss)
        result.log("train_loss", loss, prog_bar=True)
        return result
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        result = pl.EvalResult(checkpoint_on=loss)
        result.log("val_loss", loss)
        return result
    
    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=0.02)

# Dataset

In [5]:
tfms = transforms.Compose([
    transforms.Resize(224),    
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261])
])

In [6]:
train_ds = datasets.CIFAR10(
    root="./data", 
    train=True,
    download=True,
    transform=tfms
)
train_loader = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=12)

Files already downloaded and verified


In [7]:
test_ds = datasets.CIFAR10(
    root="./data", 
    train=False,
    download=True,
    transform=tfms
)
test_loader = DataLoader(test_ds, batch_size=16, shuffle=False, num_workers=12)

Files already downloaded and verified


# Train

In [8]:
backbone = models.resnet50(pretrained=True)
for param in backbone.parameters():
    param.requires_grad = False

In [17]:
early_stop = EarlyStopping(
    monitor='val_loss',
    patience=3,
    strict=False,
    verbose=False,
    mode='min'
)

NameError: name 'EarlyStopping' is not defined

In [10]:
trainer = pl.Trainer(
    fast_dev_run=False, 
    gpus=1, 
    early_stop_callback=early_stop, 
    max_epochs=500,
    auto_lr_find=True
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]


In [11]:
model = CIFARTenLitModel(backbone)

In [13]:
trainer.fit(model, train_loader, val_dataloaders=test_loader)


  | Name     | Type   | Params
------------------------------------
0 | backbone | ResNet | 23 M  


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Saving latest checkpoint..





1

In [16]:
trainer.test(test_dataloaders=test_loader, ckpt_path="lightning_logs/version_12/checkpoints")

IsADirectoryError: [Errno 21] Is a directory: 'lightning_logs/version_12/checkpoints'

In [33]:
train_ds

Dataset CIFAR10
    Number of datapoints: 50000
    Split: train
    Root Location: ./data
    Transforms (if any): Compose(
                             ToTensor()
                             Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
                         )
    Target Transforms (if any): None

In [37]:
for batch in train_loader:
    print(batch[0].size())
    break

torch.Size([128, 3, 32, 32])
