<a href="https://colab.research.google.com/github/edgarriba/kornia-examples/blob/master/kornia_lightning_mnist_gpu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Kornia and PyTorch Lightning GPU data augmentation demo

In this tutorial we show how one can combine both Kornia and PyTorch Lightning to perform data augmentation to train a model using CPUs and GPUs in batch mode without additional effort.


Enjoy the example !

## Install Kornia and PyTorch Lightning
Next, we install Kornia and PyTorch Lightning

In [1]:
! pip install git+git://github.com/PyTorchLightning/pytorch-lightning.git@master --upgrade
! pip install kornia -q

Collecting git+git://github.com/PyTorchLightning/pytorch-lightning.git@master
  Cloning git://github.com/PyTorchLightning/pytorch-lightning.git (to revision master) to /tmp/pip-req-build-t2r5dn0h
  Running command git clone -q git://github.com/PyTorchLightning/pytorch-lightning.git /tmp/pip-req-build-t2r5dn0h
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Collecting tqdm>=4.35.0
[?25l  Downloading https://files.pythonhosted.org/packages/47/55/fd9170ba08a1a64a18a7f8a18f088037316f2a41be04d2fe6ece5a653e8f/tqdm-4.43.0-py2.py3-none-any.whl (59kB)
[K     |████████████████████████████████| 61kB 1.9MB/s 
[?25hCollecting future>=0.17.1
[?25l  Downloading https://files.pythonhosted.org/packages/45/0b/38b06fd9b92dc2b68d58b75f900e97884c45bedd2ff83203d933cf5851c9/future-0.18.2.tar.gz (829kB)
[K     |████████████████████████████████| 829kB 4.1MB/s 
Building wheels for collected pac

[?25l[K     |██▎                             | 10kB 23.8MB/s eta 0:00:01[K     |████▋                           | 20kB 1.7MB/s eta 0:00:01[K     |███████                         | 30kB 2.5MB/s eta 0:00:01[K     |█████████▏                      | 40kB 1.7MB/s eta 0:00:01[K     |███████████▌                    | 51kB 2.0MB/s eta 0:00:01[K     |█████████████▉                  | 61kB 2.4MB/s eta 0:00:01[K     |████████████████▏               | 71kB 2.8MB/s eta 0:00:01[K     |██████████████████▍             | 81kB 3.2MB/s eta 0:00:01[K     |████████████████████▊           | 92kB 3.6MB/s eta 0:00:01[K     |███████████████████████         | 102kB 2.7MB/s eta 0:00:01[K     |█████████████████████████▍      | 112kB 2.7MB/s eta 0:00:01[K     |███████████████████████████▋    | 122kB 2.7MB/s eta 0:00:01[K     |██████████████████████████████  | 133kB 2.7MB/s eta 0:00:01[K     |████████████████████████████████| 143kB 2.7MB/s 
[?25h

## Define lightning model

In [2]:
import os

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

import pytorch_lightning as pl
import kornia as K

class CoolSystem(pl.LightningModule):

    def __init__(self):
        super(CoolSystem, self).__init__()
        # not the best model...
        self.l1 = torch.nn.Linear(28 * 28, 10)

        self.transform = torch.nn.Sequential(
            K.augmentation.RandomRectangleErasing((.05, .1), (.3, 1/.3)),
            K.augmentation.RandomRotation((-15., 15.))
        )
    
        self.pil_to_tensor = lambda x: K.image_to_tensor(x)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        # REQUIRED
        x, y = batch
        x_aug = self.transform(x)  # => we perform GPU/Batched data augmentation
        y_hat = self.forward(x_aug)
        loss = F.cross_entropy(y_hat, y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}
        
    def test_step(self, batch, batch_idx):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        return {'test_loss': F.cross_entropy(y_hat, y)}

    def test_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        tensorboard_logs = {'test_loss': avg_loss}
        return {'avg_test_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning_rate schedulers
        # (LBFGS it is automatically supported, no need for closure function)
        return torch.optim.Adam(self.parameters(), lr=0.0004)

    def prepare_data(self):
        MNIST(os.getcwd(), train=True, download=True, transform=self.pil_to_tensor)
        MNIST(os.getcwd(), train=False, download=True, transform=self.pil_to_tensor)

    def train_dataloader(self):
        # REQUIRED
        dataset = MNIST(os.getcwd(), train=True, download=False, transform=self.pil_to_tensor)
        loader = DataLoader(dataset, batch_size=32)
        return loader

    def val_dataloader(self):
        dataset = MNIST(os.getcwd(), train=True, download=False, transform=self.pil_to_tensor)
        loader = DataLoader(dataset, batch_size=32)
        return loader

    def test_dataloader(self):
        dataset = MNIST(os.getcwd(), train=False, download=False, transform=self.pil_to_tensor)
        loader = DataLoader(dataset, batch_size=32)
        return loader



## Run training

In [3]:
from pytorch_lightning import Trainer

model = CoolSystem()

# most basic trainer, uses good defaults
trainer = Trainer()
trainer.fit(model)

INFO:root:
  | Name        | Type                   | Params
---------------------------------------------------
0 | l1          | Linear                 | 7 K   
1 | transform   | Sequential             | 0     
2 | transform.0 | RandomRectangleErasing | 0     
3 | transform.1 | RandomRotation         | 0     


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /content/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /content/MNIST/raw/train-images-idx3-ubyte.gz to /content/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /content/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /content/MNIST/raw/train-labels-idx1-ubyte.gz to /content/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /content/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))



Extracting /content/MNIST/raw/t10k-images-idx3-ubyte.gz to /content/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /content/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /content/MNIST/raw/t10k-labels-idx1-ubyte.gz to /content/MNIST/raw
Processing...
Done!


HBox(children=(FloatProgress(value=0.0, description='Validation sanity check', layout=Layout(flex='2'), max=5.…




HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1875.0, style=P…

  "Did not find hyperparameters at model.hparams. Saving checkpoint without"


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1875.0, style=P…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1875.0, style=P…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1875.0, style=P…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1875.0, style=P…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1875.0, style=P…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=1875.0, style=P…




1

In [4]:
trainer.test(model)

INFO:root:
  | Name        | Type                   | Params
---------------------------------------------------
0 | l1          | Linear                 | 7 K   
1 | transform   | Sequential             | 0     
2 | transform.0 | RandomRectangleErasing | 0     
3 | transform.1 | RandomRotation         | 0     
INFO:root:Model and Trainer restored from checkpoint: /content/lightning_logs/version_0/checkpoints/_ckpt_epoch_3.ckpt


HBox(children=(FloatProgress(value=0.0, description='Testing', layout=Layout(flex='2'), max=313.0, style=Progr…




HBox(children=(FloatProgress(value=0.0, description='Testing', layout=Layout(flex='2'), max=313.0, style=Progr…


