<a href="https://colab.research.google.com/github/edgarriba/kornia-examples/blob/master/kornia_lightning_mnist_gpu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Kornia and PyTorch Lightning GPU data augmentation demo

In this tutorial we show how one can combine both Kornia and PyTorch Lightning to perform data augmentation to train a model using CPUs and GPUs in batch mode without additional effort.


Enjoy the example !

## Install Kornia and PyTorch Lightning
Next, we install Kornia and PyTorch Lightning

In [2]:
! pip install git+git://github.com/PyTorchLightning/pytorch-lightning.git@master --upgrade
! pip install kornia -q

Collecting git+git://github.com/PyTorchLightning/pytorch-lightning.git@master
  Cloning git://github.com/PyTorchLightning/pytorch-lightning.git (to revision master) to /tmp/pip-req-build-259a97b9
  Running command git clone -q git://github.com/PyTorchLightning/pytorch-lightning.git /tmp/pip-req-build-259a97b9
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Building wheels for collected packages: pytorch-lightning
  Building wheel for pytorch-lightning (PEP 517) ... [?25l[?25hdone
  Created wheel for pytorch-lightning: filename=pytorch_lightning-0.6.1.dev0-cp36-none-any.whl size=148144 sha256=0e5792706a76a81d4ddf66132826ca67d7250e59cc710b4dd3bda4c76c06f30e
  Stored in directory: /tmp/pip-ephem-wheel-cache-6vl82xvd/wheels/10/69/8d/aa7539d71fd3f79eec266a45eae3206c704f37fcf4ff976e0d
Successfully built pytorch-lightning
Installing collected packages: pytorch-lightning
  Found 

## Define lightning model

In [0]:
import os

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

import pytorch_lightning as pl
import kornia as K

class CoolSystem(pl.LightningModule):

    def __init__(self):
        super(CoolSystem, self).__init__()
        # not the best model...
        self.l1 = torch.nn.Linear(28 * 28, 10)

        self.transform = torch.nn.Sequential(
            K.augmentation.RandomRectangleErasing((.05, .1), (.3, 1/.3)),
            K.augmentation.RandomRotation((-15., 15.))
        )
    
        self.pil_to_tensor = lambda x: K.image_to_tensor(x)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        # REQUIRED
        x, y = batch
        x_aug = self.transform(x)  # => we perform GPU/Batched data augmentation
        y_hat = self.forward(x_aug)
        loss = F.cross_entropy(y_hat, y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}
        
    def test_step(self, batch, batch_idx):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        return {'test_loss': F.cross_entropy(y_hat, y)}

    def test_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        tensorboard_logs = {'test_loss': avg_loss}
        return {'avg_test_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning_rate schedulers
        # (LBFGS it is automatically supported, no need for closure function)
        return torch.optim.Adam(self.parameters(), lr=0.0004)

    def prepare_data(self):
        MNIST(os.getcwd(), train=True, download=True, transform=self.pil_to_tensor)
        MNIST(os.getcwd(), train=False, download=True, transform=self.pil_to_tensor)

    def train_dataloader(self):
        # REQUIRED
        dataset = MNIST(os.getcwd(), train=True, download=False, transform=self.pil_to_tensor)
        loader = DataLoader(dataset, batch_size=32)
        return loader

    def val_dataloader(self):
        dataset = MNIST(os.getcwd(), train=True, download=False, transform=self.pil_to_tensor)
        loader = DataLoader(dataset, batch_size=32)
        return loader

    def test_dataloader(self):
        dataset = MNIST(os.getcwd(), train=False, download=False, transform=self.pil_to_tensor)
        loader = DataLoader(dataset, batch_size=32)
        return loader

## Run training

In [0]:
from pytorch_lightning import Trainer

model = CoolSystem()

# most basic trainer, uses good defaults
trainer = Trainer(num_tpu_cores=8)
trainer.fit(model)

INFO:root:training on 8 TPU cores
INFO:root:INIT TPU local core: 0, global rank: 0
INFO:root:INIT TPU local core: 6, global rank: 6
INFO:root:INIT TPU local core: 3, global rank: 3
INFO:root:INIT TPU local core: 7, global rank: 7
INFO:root:INIT TPU local core: 2, global rank: 2
INFO:root:INIT TPU local core: 4, global rank: 4
INFO:root:INIT TPU local core: 5, global rank: 5
INFO:root:INIT TPU local core: 1, global rank: 1


INFO:root:
  | Name        | Type                 | Params
-------------------------------------------------
0 | l1          | Linear               | 7 K   
1 | transform   | Sequential           | 0     
2 | transform.0 | RandomHorizontalFlip | 0     


HBox(children=(FloatProgress(value=0.0, description='Validation sanity check', layout=Layout(flex='2'), max=5.…



HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=235.0, style=Pr…

  "Did not find hyperparameters at model.hparams. Saving checkpoint without"


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=235.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=235.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=235.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=235.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=235.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=235.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=235.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=235.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=235.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=235.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=235.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=235.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=235.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=235.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=235.0, style=Pr…




In [0]:
trainer.test(model)