<a href="https://colab.research.google.com/github/edgarriba/kornia-benchmark/blob/master/kornia_lightning_mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Kornia and PyTorch Lightning GPU data augmentation demo

In this tutorial we show how to combine Kornia and PyTorch Lightning to perform batched data augmentation while training a model, on either CPU or GPU, without additional effort.


Enjoy the example!

## Install Kornia and PyTorch Lightning
First, we install Kornia and PyTorch Lightning.

In [1]:
# PyTorch Lightning is pinned to the 0.7.1 release.
! pip install pytorch-lightning==0.7.1
# Kornia is installed straight from its GitHub repository, i.e. not pinned to a release.
# NOTE(review): consider pinning a tag or commit so a re-run installs the same code.
! pip install git+https://github.com/kornia/kornia

Collecting pytorch-lightning==0.7.1
[?25l  Downloading https://files.pythonhosted.org/packages/61/9e/db4e1e3036e045a25d5c37617ded31a673a61f4befc62c5231818810b3a7/pytorch-lightning-0.7.1.tar.gz (6.0MB)
[K     |████████████████████████████████| 6.0MB 3.2MB/s 
Collecting future>=0.17.1
[?25l  Downloading https://files.pythonhosted.org/packages/45/0b/38b06fd9b92dc2b68d58b75f900e97884c45bedd2ff83203d933cf5851c9/future-0.18.2.tar.gz (829kB)
[K     |████████████████████████████████| 829kB 35.4MB/s 
Building wheels for collected packages: pytorch-lightning, future
  Building wheel for pytorch-lightning (setup.py) ... [?25l[?25hdone
  Created wheel for pytorch-lightning: filename=pytorch_lightning-0.7.1-cp36-none-any.whl size=145306 sha256=766ed99af5d80d2e54b2592983d615d6fe0c0af73fc6b378e70d2cf9effbe8d4
  Stored in directory: /root/.cache/pip/wheels/dc/93/61/14094d2116ff739513dda993007501ae5701b78386b39d5912
  Building wheel for future (setup.py) ... [?25l[?25hdone
  Created wheel for f

## Define lightning model

In [0]:
import os

import torch
import torch.nn as nn

from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

import pytorch_lightning as pl
import kornia as K

# based on: https://github.com/pytorch/examples/blob/master/mnist/main.py

class Net(nn.Module):
    """Small convolutional classifier producing 10-way log-probabilities.

    Architecture: two 3x3 convolutions (1 -> 32 -> 64 channels), a 2x2
    max-pool, dropout, then two fully connected layers (9216 -> 128 -> 10).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Channel-wise dropout on the 4-D feature maps coming out of the pool.
        self.dropout1 = nn.Dropout2d(0.25)
        # Element-wise dropout: this layer receives the 2-D (N, 128) output of
        # fc1. Dropout2d is meant for 3-D/4-D inputs and recent PyTorch emits a
        # deprecation warning when it is given a 2-D tensor.
        self.dropout2 = nn.Dropout(0.5)
        # 9216 = 64 channels * 12 * 12, the pooled map size for a 28x28 input.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Run the network on a batch of single-channel images.

        Args:
            x: float tensor of shape (N, 1, H, W); H = W = 28 is required for
                the flattened features to match ``fc1``.

        Returns:
            Tensor of shape (N, 10) holding log-probabilities (``log_softmax``
            over dim 1), suitable for ``F.nll_loss``.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        # Keep the batch dimension, flatten channels and spatial dims.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

class CoolSystem(pl.LightningModule):
    """LightningModule that trains ``Net`` on MNIST with batched Kornia preprocessing.

    The Kornia pipeline in ``self.transform`` is applied to whole batches
    inside the step methods, so it runs on whatever device the batch is on.
    The same pipeline is used for training, validation and testing so the
    model always sees identically normalised inputs.
    """

    def __init__(self):
        super(CoolSystem, self).__init__()
        self.model = Net()

        # Batched Kornia pipeline applied to the image tensor in every step.
        self.transform = torch.nn.Sequential(
            K.color.Normalize(0.1307, 0.3081),
        )

        # Per-sample dataset transform handed to torchvision's MNIST.
        # NOTE(review): confirm that K.image_to_tensor accepts the object MNIST
        # passes to `transform`, and which dtype / value range it returns.
        self.pil_to_tensor = lambda x: K.image_to_tensor(x)

    def forward(self, x):
        """Return the log-probabilities produced by the wrapped ``Net``."""
        return self.model(x)

    def _preprocess(self, x):
        """Apply the shared batched input pipeline used by all step methods."""
        # Cast first so the normalisation and the convolutions run in floating
        # point whatever dtype the dataset transform produced (no-op for float).
        return self.transform(x.float())

    def training_step(self, batch, batch_idx):
        # REQUIRED
        x, y = batch
        x_aug = self._preprocess(x)  # => we perform GPU/Batched data augmentation
        y_hat = self.forward(x_aug)
        loss = F.nll_loss(y_hat, y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        # OPTIONAL
        x, y = batch
        # Same preprocessing as in training; feeding the raw batch here would
        # evaluate the model on inputs it was never trained on.
        y_hat = self.forward(self._preprocess(x))
        return {'val_loss': F.nll_loss(y_hat, y)}

    def validation_end(self, outputs):
        # OPTIONAL
        # `outputs` is the list of dicts returned by validation_step.
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}

    def test_step(self, batch, batch_idx):
        # OPTIONAL
        x, y = batch
        # Same preprocessing as in training (see validation_step).
        y_hat = self.forward(self._preprocess(x))
        return {'test_loss': F.nll_loss(y_hat, y)}

    def test_end(self, outputs):
        # OPTIONAL
        # `outputs` is the list of dicts returned by test_step.
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        tensorboard_logs = {'test_loss': avg_loss}
        return {'avg_test_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning_rate schedulers
        # (LBFGS it is automatically supported, no need for closure function)
        return torch.optim.Adadelta(self.parameters(), lr=0.0004)

    def prepare_data(self):
        """Download both MNIST splits into the current working directory."""
        MNIST(os.getcwd(), train=True, download=True, transform=self.pil_to_tensor)
        MNIST(os.getcwd(), train=False, download=True, transform=self.pil_to_tensor)

    def _make_loader(self, train, shuffle=False):
        """Build a batch-size-32 DataLoader over the requested MNIST split.

        Args:
            train: ``True`` for the training split, ``False`` for the test split.
            shuffle: whether the loader reshuffles the samples every epoch.
        """
        dataset = MNIST(os.getcwd(), train=train, download=False, transform=self.pil_to_tensor)
        return DataLoader(dataset, batch_size=32, shuffle=shuffle)

    def train_dataloader(self):
        # REQUIRED
        # Shuffle so each epoch visits the training samples in a new order.
        return self._make_loader(train=True, shuffle=True)

    def val_dataloader(self):
        # NOTE(review): this validates on the training split (train=True), so
        # val_loss does not measure generalisation. Confirm this is intended
        # or carve out a held-out subset.
        return self._make_loader(train=True)

    def test_dataloader(self):
        return self._make_loader(train=False)

## Run training

In [10]:
from pytorch_lightning import Trainer

model = CoolSystem()

# Train for a single epoch with the profiler enabled. Use one GPU when CUDA is
# available and fall back to CPU otherwise (gpus=None), so the cell also runs
# on CPU-only runtimes instead of failing on the hard-coded GPU request.
trainer = Trainer(
    profiler=True,
    max_epochs=1,
    gpus=1 if torch.cuda.is_available() else None,
)
trainer.fit(model)

HBox(children=(IntProgress(value=0, description='Validation sanity check', layout=Layout(flex='2'), max=5, sty…



HBox(children=(IntProgress(value=1, bar_style='info', layout=Layout(flex='2'), max=1), HTML(value='')), layout…

HBox(children=(IntProgress(value=0, description='Validating', layout=Layout(flex='2'), max=1875, style=Progres…




1