## Install `PADL-Extensions`

In [None]:
!pip install padl-extensions[pytorch_lightning]
!pip install torchvision

In [None]:
# These might be useful if there are errors regarding ipywidgets while downloading torchvision.datasets
# !pip install ipywidgets
# !jupyter nbextension enable --py widgetsnbextension

## Imports

In [None]:
import torch
import torchvision
from torchvision import models
import numpy as np

import padl
from padl import transform

## Using PADL with PyTorch Lightning

## Dataset:
This notebook uses the MNIST dataset available through torchvision. You can download the dataset separately from the MNIST website, or load it as shown below.

More details on torchvision's MNIST dataset can be found here: https://pytorch.org/vision/stable/datasets.html#mnist

In [None]:
# Download (if needed) and load both MNIST splits under ./data. The `preprocess`
# pipeline below consumes their items as (image, label) pairs.
mnist_root = 'data'
mnist_train_dataset = torchvision.datasets.MNIST(root=mnist_root, train=True, download=True)
mnist_test_dataset = torchvision.datasets.MNIST(root=mnist_root, train=False, download=True)

## 1. Model Definition

We will build a simple convolutional network (`SimpleNet`) to classify `MNIST` handwritten digits. In the cell below, a plain `torch.nn.Module` is defined with the `@transform` decorator. That decorator is all it takes to wrap the PyTorch model into a `padl.Transform` object.

In [None]:
import torch.nn.functional as F
import torchvision.models.resnet 
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim import lr_scheduler


@transform
class SimpleNet(torch.nn.Module):
    """Small convolutional classifier for 28x28 single-channel MNIST digits.

    Five convolutions, each followed by ReLU and then BatchNorm, shrink the
    input to a 64 x 4 x 4 feature map. That map is flattened to 1024 features
    and passed through two fully connected layers to produce 10 class scores.

    Input: float tensor of shape (batch, 1, 28, 28).
    Output: (batch, 10) tensor of log-probabilities (``log_softmax`` over the
    class dimension), as expected by ``F.nll_loss``.
    """

    def __init__(self):
        super().__init__()

        # Conv 1: (1, 28, 28) -> (32, 26, 26)
        self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=3)
        self.batchnorm1 = torch.nn.BatchNorm2d(32)

        # Conv 2: (32, 26, 26) -> (32, 24, 24)
        self.conv2 = torch.nn.Conv2d(32, 32, kernel_size=3)
        self.batchnorm2 = torch.nn.BatchNorm2d(32)

        # Conv 3 (stride 2 halves the resolution): (32, 24, 24) -> (32, 12, 12)
        self.conv3 = torch.nn.Conv2d(32, 32, kernel_size=2, stride=2)
        self.batchnorm3 = torch.nn.BatchNorm2d(32)

        # Conv 4: (32, 12, 12) -> (64, 8, 8)
        self.conv4 = torch.nn.Conv2d(32, 64, kernel_size=5)
        self.batchnorm4 = torch.nn.BatchNorm2d(64)

        # Conv 5 (stride 2): (64, 8, 8) -> (64, 4, 4), i.e. 64 * 4 * 4 = 1024 features
        self.conv5 = torch.nn.Conv2d(64, 64, kernel_size=2, stride=2)
        self.batchnorm5 = torch.nn.BatchNorm2d(64)

        # Channel-wise dropout on the final feature map
        self.conv5_drop = torch.nn.Dropout2d()

        # FC 1: 1024 flattened features -> 128 hidden units
        self.fc1 = torch.nn.Linear(1024, 128)

        # FC 2: 128 hidden units -> 10 class scores
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        """Map a (batch, 1, 28, 28) image batch to (batch, 10) log-probabilities."""
        x = self.batchnorm1(F.relu(self.conv1(x)))
        x = self.batchnorm2(F.relu(self.conv2(x)))
        x = self.batchnorm3(F.relu(self.conv3(x)))
        x = self.batchnorm4(F.relu(self.conv4(x)))
        x = self.batchnorm5(F.relu(self.conv5(x)))
        x = self.conv5_drop(x)
        # Flatten each sample while keeping the batch dimension. Unlike
        # `view(-1, 1024)`, an unexpected spatial size now fails loudly in fc1
        # instead of silently regrouping samples into a different batch size.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)

In [None]:
@transform
def convert_to_tensor(img):
    """Convert an array-like value (a dataset image or its integer label) to a float32 CPU tensor."""
    return torch.tensor(np.asarray(img), dtype=torch.float32)

# Per-sample pre-processing of (image, label) pairs. In PADL, `/` applies two
# transforms in parallel (first to the image, second to the label), and `>>` chains stages:
#   1. convert both the image and the label to float tensors;
#   2. reshape the 28x28 image to (1, 28, 28), adding the channel axis; pass the label through.
preprocess = (
    (convert_to_tensor / convert_to_tensor)
    >> (padl.same.reshape(-1, 28, 28) / padl.Identity())
)

# The network (already a PADL Transform via @transform) and the NLL loss wrapped as a Transform.
simplenet = SimpleNet()
loss_func = transform(F.nll_loss)

In [None]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print('Device to be used: ', device)

In [None]:
# Full training pipeline, ending in the loss:
#   preprocess          -> per-sample (image, label) tensors
#   padl.Batchify()     -> boundary where samples are collated into batches
#   simplenet / cast    -> log-probabilities, in parallel with labels cast to int64
#                          (F.nll_loss expects integer class-index targets)
#   loss_func           -> NLL loss; reuses the `loss_func` Transform instead of
#                          wrapping F.nll_loss a second time
train_model = (
    preprocess
    >> padl.Batchify()
    >> simplenet / padl.same.type(torch.long)
    >> loss_func
)

train_model.pd_to(device)

## 2. Converting a PADL model into a `LightningModule`

### 2.1 Directly Initialize the class PadlLightning
If your `train_model` has the loss function as its final step, you can build the `PadlLightning` object directly, as shown below:

In [None]:
from padl_ext.pytorch_lightning import PadlLightning

In [None]:
# Show the docstring and constructor signature of PadlLightning (IPython help syntax).
PadlLightning?

In [None]:
# Loader settings for this first example.
num_workers = 0
batch_size = 256

# `train_model` already ends with the loss, so it can be wrapped directly.
padl_lightning_module = PadlLightning(
    train_model,
    batch_size=batch_size,
    num_workers=num_workers,
    train_data=mnist_train_dataset,  # torchvision MNIST training split
    val_data=mnist_test_dataset,  # torchvision MNIST test split, used for validation
)
# padl_lightning_module is a LightningModule!

### 2.2 Inherit from PadlLightning
The class `PadlLightning` is already a `LightningModule`, so inheriting from it allows all the regular customizations available in PyTorch Lightning.

In [None]:
learning_rate = 1e-4  # Adam step size read by MyModule.configure_optimizers


class MyModule(PadlLightning):
    """`PadlLightning` subclass with a custom optimizer.

    `PadlLightning` is a `LightningModule`, so any Lightning hook can be
    overridden here. This example only replaces `configure_optimizers`.
    """

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters of this module.

        The step size comes from the notebook-level ``learning_rate`` so it is
        tuned in one place rather than hard-coded here.
        """
        return torch.optim.Adam(self.parameters(), lr=learning_rate)

In [None]:
# Same batch size as in section 2.1; num_workers is raised to 4.
num_workers = 4
batch_size = 256

# MyModule does not override __init__, so it takes the same arguments as PadlLightning.
padl_lightning_module = MyModule(
    train_model,
    batch_size=batch_size,
    num_workers=num_workers,
    train_data=mnist_train_dataset,  # torchvision MNIST training split
    val_data=mnist_test_dataset,  # torchvision MNIST test split, used for validation
)
# padl_lightning_module is now a MyModule, and therefore still a LightningModule!

## 3. Training and validating the `train_model` with the PADL-PyTorch Lightning connector

In [None]:
import pytorch_lightning as pl

log_interval = 10  # log metrics every 10 training steps
val_interval = 10  # an int val_check_interval runs validation every 10 training batches
nepoch = 2

# `gpus=` was deprecated and then removed in Lightning 2.0. `accelerator` + `devices`
# select the same hardware: one GPU when CUDA is available, otherwise the CPU.
trainer = pl.Trainer(
    accelerator='gpu' if device == 'cuda' else 'cpu',
    devices=1,
    val_check_interval=val_interval,
    max_epochs=nepoch,
    default_root_dir='test',
    log_every_n_steps=log_interval,
)
trainer.fit(padl_lightning_module)