## Install `PADL-Extensions`

In [None]:
!pip install padl-extensions[pytorch_lightning]
!pip install torchvision

In [None]:
# These might be useful if there are errors regarding ipywidgets while downloading torchvision.datasets
# !pip install ipywidgets
# !jupyter nbextension enable --py widgetsnbextension

## Imports

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision
from torchvision import models

import padl
from padl import transform

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

## Using PADL with PyTorch Lightning

## Dataset:
This notebook uses the MNIST dataset that ships with torchvision. The dataset can be downloaded separately from the MNIST website or loaded as shown below.

More details on torchvision's MNIST dataset can be found here: https://pytorch.org/vision/stable/datasets.html#mnist

In [None]:
mnist_train_dataset = torchvision.datasets.MNIST('data', train=True, download=True)
mnist_test_dataset = torchvision.datasets.MNIST('data', train=False, download=True)

## 1. Model Definition

We will build a simple convolutional network to classify `MNIST` handwritten digits. In the cell below, a plain `torch.nn.Module` is defined with the decorator `@transform`. This is enough to wrap the PyTorch model into a `padl.Transform` object.

In [None]:
import torch.nn.functional as F
import torchvision.models.resnet 
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim import lr_scheduler


@transform
class SimpleNet(torch.nn.Module):
    def __init__(self):
        super().__init__()

        # Conv 1
        # size : input: 28x28x1 -> output : 26 x 26 x 32
        self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=3)
        self.batchnorm1 = torch.nn.BatchNorm2d(32)

        # Conv 2
        # size : input: 26x26x32 -> output : 24 x 24 x 32
        self.conv2 = torch.nn.Conv2d(32, 32, kernel_size=3)
        self.batchnorm2 = torch.nn.BatchNorm2d(32)

        # Conv 3
        # size : input: 24x24x32 -> output : 12 x 12 x 32
        self.conv3 = torch.nn.Conv2d(32, 32, kernel_size=2, stride = 2)
        self.batchnorm3 = torch.nn.BatchNorm2d(32)

        # Conv 4
        # size : input : 12 x 12 x 32 -> output : 8 x 8 x 64
        self.conv4 = torch.nn.Conv2d(32, 64, kernel_size=5)
        self.batchnorm4 = torch.nn.BatchNorm2d(64)

        # Conv 5
        # size : input: 8x8x64 -> output : 4 x 4 x 64 -> Linearize = 1024
        self.conv5 = torch.nn.Conv2d(64, 64, kernel_size=2, stride = 2)
        self.batchnorm5 = torch.nn.BatchNorm2d(64)

        # dropout layer 
        self.conv5_drop = torch.nn.Dropout2d()

        # FC 1 
        self.fc1 = torch.nn.Linear(1024, 128)

        # FC 2
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        x = self.batchnorm1(F.relu(self.conv1(x)))
        x = self.batchnorm2(F.relu(self.conv2(x)))
        x = self.batchnorm3(F.relu(self.conv3(x)))
        x = self.batchnorm4(F.relu(self.conv4(x)))
        x = self.batchnorm5(F.relu(self.conv5(x)))
        x = self.conv5_drop(x)
        x = x.view(-1, 1024)
        x = F.relu(self.fc1(x))
        x = F.log_softmax(self.fc2(x), dim=1)
        return x
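
The size comments in `SimpleNet` follow from the standard convolution output formula, `out = (in + 2 * padding - kernel) // stride + 1`. As a quick sanity check, the sketch below replays that arithmetic in plain Python, assuming no padding (the `Conv2d` default); the helper `conv_out` is introduced here for illustration only and is not part of PADL or PyTorch.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel_size, stride) for conv1 .. conv5, copied from SimpleNet
layers = [(3, 1), (3, 1), (2, 2), (5, 1), (2, 2)]

size = 28  # MNIST images are 28 x 28
sizes = []
for kernel, stride in layers:
    size = conv_out(size, kernel, stride)
    sizes.append(size)

# conv5 has 64 output channels, so flattening gives size * size * 64 features
flat_features = size * size * 64

print(sizes)          # spatial size after each convolution
print(flat_features)  # input width expected by fc1
```

The final value is what `x.view(-1, 1024)` and `torch.nn.Linear(1024, 128)` rely on, so changing any kernel size or stride means recomputing it.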

In [None]:
@transform
def convert_to_tensor(img):
    arr = np.asarray(img)
    return torch.tensor(arr).type(torch.FloatTensor)

preprocess = (
    convert_to_tensor / convert_to_tensor
    >> padl.same.reshape(-1, 28, 28) / padl.Identity()
)

simplenet = SimpleNet()
loss_func = transform(F.nll_loss)
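
To make the operators in `preprocess` easier to read: `>>` chains transforms one after another, while `/` applies one transform to each element of a tuple, here the `(image, label)` pair. Since `/` binds tighter than `>>` in Python, the pipeline groups as `(a / b) >> (c / d)`. The toy class below imitates that behaviour for pairs only; `Step` is a hypothetical stand-in written for this illustration and is not how PADL is implemented.

```python
class Step:
    """Toy stand-in for a transform; NOT part of the padl package."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, x):
        return self.fn(x)

    def __rshift__(self, other):
        # a >> b: feed the output of a into b
        return Step(lambda x: other(self(x)))

    def __truediv__(self, other):
        # a / b: apply a to the first tuple element and b to the second
        return Step(lambda pair: (self(pair[0]), other(pair[1])))


to_float = Step(float)
add_channel = Step(lambda v: [v])  # plays the role of the reshape step
identity = Step(lambda v: v)

# Same structure as `preprocess`: (f / f) >> (g / identity)
toy_preprocess = to_float / to_float >> add_channel / identity

result = toy_preprocess((3, 7))
print(result)
```

The first element goes through both conversion steps while the second is only converted and then passed through unchanged, which mirrors how the image is reshaped and the label is left alone.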

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device to be used: ', device)

Let's define our training model, which must include the loss as the last step in the pipeline.

In [None]:
train_model = (
    preprocess
    >> padl.Batchify()
    >> simplenet / padl.same.type(torch.long)
    >> transform(F.nll_loss)
)

train_model.pd_to(device)
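
The last step works because `SimpleNet` ends in `log_softmax`, and `nll_loss` then picks out the negative log-probability of the correct class. The plain-Python sketch below shows that computation for a single sample; the helper functions are simplified re-implementations written for illustration (the real `F.nll_loss` also averages over the batch by default), and the logits are made-up numbers.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of scores."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_sum for v in logits]


def nll_loss(log_probs, target):
    """Negative log-likelihood of the target class for one sample."""
    return -log_probs[target]


logits = [2.0, 1.0, 0.1]  # made-up scores for three classes
target = 0                # index of the correct class

log_probs = log_softmax(logits)
loss = nll_loss(log_probs, target)
print(round(loss, 4))
```

The more probability the network puts on the correct class, the closer the loss gets to zero.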

For inference, let's define a separate pipeline that returns the digit predicted by the model.

In [None]:
infer_preprocess =(
    padl.same[0]
    >> convert_to_tensor
)
infer_model = (
    infer_preprocess
    >> padl.Batchify()
    >> padl.same.unsqueeze(1) 
    >> simplenet
    >> padl.transform(lambda x: x.max(1).indices)
)
infer_model
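
The final step, `x.max(1).indices`, selects for every row of the network output the index of the largest log-probability, which is the predicted digit. As a rough plain-Python illustration of the same idea (the values below are made up, with three classes instead of ten):

```python
# One row of log-probabilities per sample (made-up values)
log_probs = [
    [-2.3, -0.2, -1.9],
    [-0.1, -3.0, -2.5],
]

# Index of the largest entry in each row, i.e. an argmax over dimension 1
predictions = [max(range(len(row)), key=row.__getitem__) for row in log_probs]
print(predictions)
```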

## 2. Converting a PADL model into a Lightning Module

The class `PadlLightning` is already a `LightningModule`, so inheriting from it allows for all the regular customizations available in PyTorch Lightning.

In [None]:
from padl_ext.pytorch_lightning import PadlLightning, padl_data_loader

In [None]:
PadlLightning?

Any default in `PadlLightning` can easily be overridden. For example, let's replace the default optimizer by overriding the method `configure_optimizers` as shown below:

In [None]:
class MyModule(PadlLightning):
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [None]:
batch_size = 256
num_workers = 0

padl_lightning_module = MyModule(
    train_model,  # train_model with the loss function
    train_data=mnist_train_dataset,  # list of training data points
    val_data=mnist_test_dataset,  # list of validation data points
    batch_size=batch_size,
    num_workers=num_workers
)
# padl_lightning_module is a LightningModule !

### Model Training
First we set up the PyTorch Lightning trainer.

In [None]:
log_interval = 10
nepoch = 1
ngpus = 1 if device == 'cuda' else 0

# Define callbacks to be given to the trainer
early_stop = EarlyStopping(monitor="val_loss", mode="min")
model_checkpoint = ModelCheckpoint(monitor="val_loss", every_n_epochs=1, save_top_k=1)
callbacks = [early_stop, model_checkpoint]

trainer = pl.Trainer(
    callbacks=callbacks,
    gpus=ngpus,
    val_check_interval=10,
    max_epochs=nepoch,
    default_root_dir='test',
    log_every_n_steps=log_interval
)
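
An integer `val_check_interval` tells the trainer to run validation after that many training batches, rather than once per epoch. To get a feel for how often that happens with the settings above, here is a small back-of-the-envelope calculation; it assumes the 60,000-image MNIST training split and that the last, partially filled batch is kept.

```python
import math

train_size = 60_000       # images in the MNIST training split
batch_size = 256          # as passed to MyModule above
val_check_interval = 10   # as passed to pl.Trainer above

# Number of training batches in one epoch (last partial batch included)
batches_per_epoch = math.ceil(train_size / batch_size)

# Validation runs triggered within one epoch
val_runs_per_epoch = batches_per_epoch // val_check_interval

print(batches_per_epoch, val_runs_per_epoch)
```

Since `EarlyStopping` monitors `val_loss` at each of these validation runs, training can stop well before the epoch is finished.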

Now let's train the model!

In [None]:
trainer.fit(padl_lightning_module)

## 3. Infer a few images from the test data

First we define a few transforms to plot the images from the MNIST test set.

In [None]:
@transform
def plot_image(img_tensor):
    fig= plt.figure(figsize=(2,2))
    ax = fig.add_subplot(111)
    ax.imshow(img_tensor, cmap='gray')
    plt.axis('off')
    return fig

@transform
def img_to_array(img):
    return np.asarray(img)

convert_plot = (
    img_to_array
    >> plot_image
)

plot_datapoint = (convert_plot - 'image')/ (padl.Identity() - 'label')

Now we plot the MNIST image and the model prediction!

In [None]:
for _ in range(3):
    data_point = mnist_test_dataset[np.random.randint(len(mnist_test_dataset))]
    output = plot_datapoint(data_point)
    plt.show()
    print(f'Prediction: {infer_model.infer_apply(data_point).item()}')
    print('-'*30)

## 4. Inspect the PyTorch Lightning checkpoint
We can locate and inspect the checkpoint file saved by the PyTorch Lightning trainer as follows.

In [None]:
ckpt_file = trainer.checkpoint_callback.best_model_path
print(ckpt_file)

In [None]:
ckpt = torch.load(ckpt_file, map_location=torch.device("cpu"))
ckpt.keys()

As we can see, PyTorch Lightning saves quite a lot of information about the state of our model, optimizer, and trainer. Additionally, we have added the key `padl_models`, which lists the locations of all saved PADL models that resulted from this training.

In [None]:
ckpt['padl_models']

## 5. Restart training from the PyTorch Lightning checkpoint
To initialize `MyModule` from a PyTorch Lightning checkpoint file, we need to use the `MyModule.load_from_checkpoint` function. We also need to provide some additional arguments: `padl_model`, `train_data`, and `val_data`.

In [None]:
loaded_module = MyModule.load_from_checkpoint(
    ckpt_file,
    padl_model=ckpt['padl_models'][-1],
    train_data=mnist_train_dataset,
    val_data=mnist_test_dataset
)

In [None]:
# trainer.fit(loaded_module)

## 6. Export the trained `torch.nn.Module` layer into a separate model

In [None]:
from padl import load

In [None]:
loaded_model = load(ckpt['padl_models'][-1])

Inspecting the loaded model shows that the trained layer sits at position 3:

In [None]:
loaded_model

The trained layer can be accessed either by index:

In [None]:
loaded_model[3][0]

or by providing the name of the `padl.transform`, which is `"simplenet"` in our case:

In [None]:
loaded_model['simplenet']

Finally, we create an inference model using the `simplenet` layer from `loaded_model`:

In [None]:
infer_preprocess =(
    padl.same[0]
    >> convert_to_tensor
)
trained_net = loaded_model['simplenet']
infer_model = (
    infer_preprocess
    >> padl.Batchify()
    >> padl.same.unsqueeze(1) 
    >> trained_net
    >> padl.transform(lambda x: x.max(1).indices)
)

In [None]:
for _ in range(3):
    data_point = mnist_test_dataset[np.random.randint(len(mnist_test_dataset))]
    output = plot_datapoint(data_point)
    plt.show()
    print(f'Prediction: {infer_model.infer_apply(data_point).item()}')
    print('-'*30)