# EMTL Usage Example
This notebook shows an example of how we can use EMTL to study the training of a multi-task model on MNIST and Fashion-MNIST.

In [1]:
# Make the project root (the parent of the notebook's working directory) importable,
# so that `import emtl` in the next cell resolves to the local package.
import os
import sys

project_dir = os.path.abspath(os.pardir)

# Re-running this cell without restarting the kernel must not add a duplicate entry.
if project_dir not in sys.path:
    sys.path.append(project_dir)

In [2]:
# Standard library
from functools import partial

# PyTorch Imports
import torch
import torch.nn as nn
from torchvision import transforms as T
from torchvision import datasets as D

# EMTL Library Imports
from emtl import SimpleTask, Trainer
from emtl.algorithms import SequentialTraining

### Models
Here we define the models used in this example. We use the LeNet-5 architecture by Yann LeCun, split into two submodules: one containing the *convolutional layers* (our shared encoder / backbone), and one containing the *fully connected layers* (the task-specific head).

In [None]:
class LeNetConvolutions(nn.Module):
    """Convolutional part of LeNet-5, used here as the shared backbone.

    Two stages of conv -> batch-norm -> ReLU -> 2x2 max-pool, followed by a flatten.
    For an input of shape (N, 1, 32, 32) the output has shape (N, 400),
    i.e. 16 channels x 5 x 5 spatial positions.
    """

    def __init__(self):
        super().__init__()

        # stage 1: 1 -> 6 channels, 5x5 kernel without padding, then halve the spatial size
        self.cnn1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=0),
            nn.BatchNorm2d(num_features=6),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # stage 2: 6 -> 16 channels, same kernel / pooling layout
        self.cnn2 = nn.Sequential(
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0),
            nn.BatchNorm2d(num_features=16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # collapse (N, C, H, W) into (N, C*H*W) for the fully connected head
        self.flatten = nn.Flatten()

    def forward(self, x):
        """Map a batch of images to flat feature vectors."""
        features = x
        for stage in (self.cnn1, self.cnn2, self.flatten):
            features = stage(features)
        return features


class LeNetFullConnections(nn.Module):
    """Fully connected part of LeNet-5: in_features -> 120 -> 84 -> num_classes.

    Used as a task-specific head on top of the convolutional backbone.

    Args:
        in_features: size of the flattened feature vector fed to the head
            (400 for LeNetConvolutions applied to 32x32 single-channel images).
        num_classes: number of output scores, one per class.
    """

    def __init__(self, in_features=400, num_classes=10):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(in_features, 120), nn.ReLU(),
            nn.Linear(120, 84), nn.ReLU(),
            nn.Linear(84, num_classes)
        )

    def forward(self, x, **_):
        """Return raw class scores (no softmax) for a batch of feature vectors.

        Extra keyword arguments are accepted and ignored (presumably so the
        trainer can pass extra context to every head uniformly — confirm).
        """
        return self.model(x)

### Datasets
We create two datasets, MNIST and Fashion-MNIST, each split into a training set and a test set. Images are resized to 32×32 (the input size LeNet-5 expects) and normalized.

In [None]:
# DataLoader settings shared by both tasks
dataset_params = {
    'batch_size': 128,
    'num_workers': 8,
    'pin_memory': True
}

DATA_ROOT = './data'


def make_transform(mean, std):
    """Build the preprocessing pipeline: resize to 32x32, convert to a tensor, normalize.

    Args:
        mean: per-channel mean used by T.Normalize (one entry for grayscale).
        std: per-channel standard deviation used by T.Normalize.
    """
    return T.Compose([
        T.Resize((32, 32)),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std),
    ])


# Each dataset is normalized with its own training-set statistics: reusing the
# MNIST values for Fashion-MNIST would leave its inputs off-center.
transform = make_transform(mean=(0.1307,), std=(0.3081,))          # MNIST
FMNIST_transform = make_transform(mean=(0.2860,), std=(0.3530,))   # Fashion-MNIST

# Datasets (downloaded into DATA_ROOT on first run)
MNIST_trainset = D.MNIST(root=DATA_ROOT, train=True, download=True, transform=transform)
MNIST_testset = D.MNIST(root=DATA_ROOT, train=False, download=True, transform=transform)

FMNIST_trainset = D.FashionMNIST(root=DATA_ROOT, train=True, download=True, transform=FMNIST_transform)
FMNIST_testset = D.FashionMNIST(root=DATA_ROOT, train=False, download=True, transform=FMNIST_transform)

In [None]:
def top_k_acc(k, pred, true):
    """Fraction of samples whose true label is among the k highest-scoring classes.

    Args:
        k: number of top-scoring classes to consider. k <= 0 yields 0.0; k larger
            than the number of classes is clamped, so every label counts as a hit.
        pred: score tensor whose last dimension indexes classes, e.g. (batch, classes).
            Any number of leading dimensions is supported.
        true: integer label tensor with the shape of `pred` minus its last dimension.

    Returns:
        The top-k accuracy as a Python float.
    """
    k = min(k, pred.size(-1))
    if k <= 0:
        # no candidates considered -> no sample can be a hit
        return 0.0

    # indices of the k highest scores along the class axis
    kbest = pred.topk(k, dim=-1).indices

    # a sample is a hit when its label matches any of its k candidates
    hits = (kbest == true.unsqueeze(-1)).any(dim=-1)
    return hits.float().mean().item()

In [None]:
def make_classification_task(name, trainset, testset):
    """Build a SimpleTask with a fresh LeNet head, cross-entropy loss and top-1/top-2 accuracy.

    Both tasks in this notebook share the same configuration apart from their name
    and data, so it lives in one place instead of being copy-pasted.

    Args:
        name: display name of the task.
        trainset: training dataset.
        testset: evaluation dataset.
    """
    return SimpleTask(
        name=name,
        head=LeNetFullConnections(),  # new head per call: heads must not be shared
        trainset=trainset,
        testset=testset,
        dataloader_params=dataset_params,
        criterion=nn.CrossEntropyLoss(),
        optimizer_fn=torch.optim.Adam,
        metric_fns={
            # partial objects, unlike lambdas, can be pickled if the trainer needs to
            'accuracy': partial(top_k_acc, 1),
            'top-2 accuracy': partial(top_k_acc, 2),
        })


MNIST_task = make_classification_task('MNIST', MNIST_trainset, MNIST_testset)
FMNIST_task = make_classification_task('Fashion-MNIST', FMNIST_trainset, FMNIST_testset)

In [None]:
# One convolutional backbone shared by both task heads; the SequentialTraining
# algorithm (3 epochs) decides how the tasks are scheduled — see the emtl docs.
backbone = LeNetConvolutions()
trainer = Trainer(
    backbone=backbone,
    tasks=[MNIST_task, FMNIST_task],
    algorithm=SequentialTraining(epochs=3),
    config='config.ini',
)

# run training
trainer.launch()