In [None]:
!pip install padl-extensions[trainer]
!pip install padl

In [None]:
import sys

sys.path.append('..')

import torch
import torchvision.datasets
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim

import numpy as np
import padl
from padl_ext.trainer.trainer import Trainer

This tutorial and accompanying notebook show you how to implement a highly portable training object for PyTorch modules using PADL. We'll be using the classic MNIST dataset and a standard CNN for illustrative purposes. The same approach applies to arbitrary PyTorch models. For more background on PADL see here and here, 
and a fully working example here.

In [None]:
# MNIST train / test splits, stored under ./data and downloaded on first use.
train_data = torchvision.datasets.MNIST(root='data', train=True, download=True)
valid_data = torchvision.datasets.MNIST(root='data', train=False, download=True)

Here's our layer transform implementing the CNN. Notice the `@padl.transform` decorator - it is all that's needed
to access the full range of PADL functionality.

In [None]:
@padl.transform
class SimpleNet(torch.nn.Module):
    """Small CNN classifier: five conv + batch-norm stages, dropout, two linear layers.

    Takes single-channel image batches and returns 10 raw scores per sample. The
    flattening step expects the conv stack to leave 1024 features per sample
    (64 channels x 4 x 4, which is what 28x28 inputs produce).
    """

    def __init__(self):
        super().__init__()
        # Feature extractor; the two stride-2 convolutions do the downsampling.
        self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=3)
        self.batchnorm1 = torch.nn.BatchNorm2d(32)
        self.conv2 = torch.nn.Conv2d(32, 32, kernel_size=3)
        self.batchnorm2 = torch.nn.BatchNorm2d(32)
        self.conv3 = torch.nn.Conv2d(32, 32, kernel_size=2, stride=2)
        self.batchnorm3 = torch.nn.BatchNorm2d(32)
        self.conv4 = torch.nn.Conv2d(32, 64, kernel_size=5)
        self.batchnorm4 = torch.nn.BatchNorm2d(64)
        self.conv5 = torch.nn.Conv2d(64, 64, kernel_size=2, stride=2)
        self.batchnorm5 = torch.nn.BatchNorm2d(64)
        self.conv5_drop = torch.nn.Dropout2d()
        # Classifier head.
        self.fc1 = torch.nn.Linear(1024, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        """Return the raw class scores (no softmax applied) for a batch ``x``."""
        stages = (
            (self.conv1, self.batchnorm1),
            (self.conv2, self.batchnorm2),
            (self.conv3, self.batchnorm3),
            (self.conv4, self.batchnorm4),
            (self.conv5, self.batchnorm5),
        )
        # Every stage is conv -> ReLU -> batch norm, in that order.
        for conv, norm in stages:
            x = norm(F.relu(conv(x)))
        flat = self.conv5_drop(x).view(-1, 1024)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)


simplenet = SimpleNet()

All tensors in PADL are accessed by pushing data through "pipelines" or "transforms". Transforms are the basic
building blocks; pipelines are compositions and branches built up from transforms.

In the example below, `train_model` is built up of a preprocessor, which prepares tensors, and additionally
a relatively trivial branch for the target labels. The prepared tensors are pushed through the layer, followed by the loss
together with the labels.

The pipeline makes use of the overloaded operators `>>` (compose) and `/` (apply-in-parallel). For a fuller
introduction to these operators, see here.

In [None]:
# Per-sample preprocessing: raw sample -> float32 array -> float tensor -> (-1, 28, 28).
preprocess = (
    padl.transform(lambda img: np.array(img).astype(np.float32))
    >> padl.transform(lambda arr: torch.from_numpy(arr).type(torch.float))
    >> padl.same.reshape(-1, 28, 28)
)

# `/` binds tighter than `>>`; the parentheses only make that grouping explicit.
train_model = (
    (preprocess / padl.identity)
    >> padl.batch
    >> (simplenet / padl.same.type(torch.long))
    >> padl.transform(F.cross_entropy)
)

train_model

When we eventually use the trained layer, we won't need the loss or the labels. For that reason we create an
additional pipeline, whose weights are tied to `train_model`, which we'll use in testing, demo-ing, serving, etc.

This model may contain non-PyTorch postprocessing (everything after the `unbatch`), which can come in handy
when communicating with other bits of your infrastructure, such as returning results in the body of a response. In this case, we add the results to a dictionary, along with the confidence estimate.

In [None]:
# Inference pipeline: same preprocessing and network as `train_model`, then per-sample
# postprocessing into a {'probability', 'prediction'} dict.
# The scores are normalised with torch.softmax rather than exp(x) / exp(x).sum(): the
# hand-written form overflows to inf/NaN for large scores and evaluates exp twice.
# NOTE(review): assumes each unbatched item is the 1-D vector of 10 class scores, so
# dim=-1 matches the original sum over all elements - confirm.
infer_model = (
    preprocess
    >> padl.batch
    >> simplenet
    >> padl.unbatch
    >> padl.transform(lambda x: torch.softmax(x, dim=-1))
    >> padl.transform(lambda x: x.topk(1))
    >> padl.transform(lambda x: {'probability': x[0].item(), 'prediction': x[1].item()})
)
infer_model

In order to monitor performance, let's create a metric.

In [None]:
def accuracy(x, y):
    """Fraction of predictions in ``x`` that equal the corresponding label in ``y``.

    :param x: Sequence of dicts, each holding the predicted class under ``'prediction'``.
    :param y: Iterable of ground-truth labels, aligned with ``x``.
    :return: Number of matching positions divided by ``len(x)``; ``0.0`` if ``x`` is empty.
    """
    if len(x) == 0:
        # Nothing to score - avoid dividing by zero.
        return 0.0
    hits = sum(pred['prediction'] == label for pred, label in zip(x, y))
    return hits / len(x)

The `padl-extensions` package contains a simple trainer, which may be configured to cover many use cases.
In order to extend the trainer, its methods may simply be overridden. Alternatively, simply create a new
transform object, with methods to manage training, saving, etc. The `@padl.transform` decorator along with
the methods `Transform.pre_load` and `Transform.post_load` will handle any important side effects which you
need in order to save the object.

In [None]:
# Bundle the training and inference pipelines with an optimizer and the metrics to track.
# Adam receives the parameters exposed by `train_model.pd_parameters()`.
t = Trainer(
    train_model=train_model,
    infer_model=infer_model,
    optimizer=torch.optim.Adam(train_model.pd_parameters()),
    metrics={'accuracy': accuracy}
)

In [None]:
# Split the validation set into inputs and labels in one pass; two separate
# comprehensions would iterate the dataset (and load every sample) twice.
metric_data, ground_truth = (list(column) for column in zip(*valid_data))

# Interrupting the kernel stops training cleanly instead of raising.
try:
    t.train(train_data, 'train.padl', valid_data=valid_data,
            save_interval=100, batch_size=100, metric_data=metric_data, ground_truth=ground_truth)
except KeyboardInterrupt:
    print('quitting training...')

In [None]:
from padl import load

# Load what is stored under 'train.padl' - presumably the trainer saved by the
# training cell above (confirm); it is used below to continue training.
s = load('train.padl')

In [None]:
# Continue training with the reloaded object, passing 'other.padl' as the second
# argument (presumably the save path - confirm). A kernel interrupt ends the run quietly.
try:
    s.train(train_data, 'other.padl', save_interval=100)
except KeyboardInterrupt:
    print('quitting training')
    

In [None]:
# Load the object stored under 'other.padl'.
r = padl.load('other.padl')

In [None]:
# Display the inference pipeline attached to the loaded object.
r.infer_model

In [None]:
from IPython.display import display
import random

# Spot-check the reloaded inference pipeline on randomly drawn validation inputs
# (drawn with replacement, so repeats are possible).
num_examples = 10
for _ in range(num_examples):
    datapoint = random.choice(metric_data)
    display(datapoint)
    prediction = r.infer_model.infer_apply(datapoint)
    print(prediction)