# Image Classification Task (MNIST)

In this example we will demonstrate how to train and evaluate an Image Classification model on the MNIST dataset using PyTorchWrapper.

#### Additional libraries

First of all we need to install the `torchvision` library in order to download the data.

In [None]:
! pip install torchvision


#### Import Statements

In [None]:
import torch
import torchvision
import math
import random
import numpy as np

from torch import nn
from torchvision.datasets import MNIST
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from pytorch_wrapper import modules, System
from pytorch_wrapper import evaluators as evaluators
from pytorch_wrapper.loss_wrappers import GenericPointWiseLossWrapper
from pytorch_wrapper.training_callbacks import EarlyStoppingCriterionCallback


#### Dataset Definition
Since `torchvision` provides a ready-to-use `Dataset` object for the MNIST dataset, we just need to wrap it in a custom class in order to adhere to the requirements of PyTorchWrapper, i.e. the data loaders must represent a batch as a dictionary.

In [None]:
class MNISTDatasetWrapper(Dataset):
    """Wrap torchvision's MNIST dataset so that examples are dictionaries.

    Every example is returned as ``{'input': image, 'target': label}``, where
    ``image`` and ``label`` are the two elements the wrapped dataset yields
    for the requested index.
    """

    def __init__(self, is_train):
        """
        :param is_train: Forwarded to ``MNIST(train=...)``; the notebook uses
            ``True`` for the train/validation data and ``False`` for the test
            data.
        """
        # The data is stored under 'data/mnist/' (download=True) and each
        # image goes through the ToTensor transform.
        self.dataset = MNIST(
            'data/mnist/',
            train=is_train,
            download=True,
            transform=torchvision.transforms.ToTensor()
        )

    def __getitem__(self, index):
        """
        :param index: Position of the example in the wrapped dataset.
        :return: Dict with the example's ``'input'`` and ``'target'``.
        """
        # Index the wrapped dataset only once: every lookup loads the example
        # and re-applies the transform, so fetching it twice doubles the work.
        example = self.dataset[index]
        return {'input': example[0], 'target': example[1]}

    def __len__(self):
        """Return the number of examples in the wrapped dataset."""
        return len(self.dataset)


#### Model Definition
We will use a simple 2-layer CNN with an MLP on top.

In [None]:
class Model(nn.Module):
    """Two convolutional stages followed by an MLP that outputs 10 values."""

    def __init__(self):
        super().__init__()

        # Two identical stages that differ only in their channel counts.
        conv_layers = [
            *self._conv_stage(in_channels=1, out_channels=10),
            *self._conv_stage(in_channels=10, out_channels=20),
        ]
        self.cnn = nn.Sequential(*conv_layers)

        # input_size=980 corresponds to 20 channels x 7 x 7, i.e. the
        # flattened CNN output assuming 28x28 input images (the padded
        # convolutions keep the spatial size, each pooling halves it).
        self.out_mlp = modules.MLP(
            input_size=980,
            num_hidden_layers=1,
            hidden_layer_size=128,
            hidden_activation=nn.ReLU,
            output_size=10,
            output_activation=None
        )

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Return the layers of one stage: conv 5x5 -> dropout -> ReLU -> max-pool 2x2."""
        return [
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=5,
                padding=2
            ),
            nn.Dropout(p=0.2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        ]

    def forward(self, x):
        """
        :param x: Batch of inputs fed to the CNN (first dimension is the batch).
        :return: Output of the MLP applied to the flattened CNN features.
        """
        features = self.cnn(x)
        batch_size = features.shape[0]
        # Collapse everything except the batch dimension before the MLP.
        flat_features = features.view(batch_size, -1)
        return self.out_mlp(flat_features)


#### Training

Next we create the dataset objects along with three data loaders (for training, validation, and testing).

In [None]:
# The training data is divided into train/validation parts below; the test
# data is kept separate.
train_val_dataset = MNISTDatasetWrapper(is_train=True)
test_dataset = MNISTDatasetWrapper(is_train=False)

# Shuffle all positions of the training data with a fixed seed and hold out
# the first 10% of the shuffled positions as the validation set.
train_val_indexes = list(range(len(train_val_dataset)))
val_size = math.floor(0.1 * len(train_val_dataset))
random.seed(12345)
random.shuffle(train_val_indexes)
val_indexes, train_indexes = train_val_indexes[:val_size], train_val_indexes[val_size:]

# Training and validation loaders read from the same dataset object and are
# restricted to their own subset of positions by the sampler.
train_dataloader = DataLoader(
    train_val_dataset,
    batch_size=128,
    sampler=SubsetRandomSampler(train_indexes)
)

val_dataloader = DataLoader(
    train_val_dataset,
    batch_size=128,
    sampler=SubsetRandomSampler(val_indexes)
)

# The test loader keeps the dataset order (no shuffling).
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=128)


Then we create the model and we wrap it with a `pytorch_wrapper.System` object.

In [None]:
model = Model()

# Softmax over the last dimension, handed to System as its `last_activation`.
last_activation = nn.Softmax(dim=-1)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
system = System(
    model,
    last_activation=last_activation,
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
)

Next we train the model on the training set, using the validation set for early stopping.

In [None]:
# Cross-entropy loss wrapped for use with PyTorchWrapper.
loss_wrapper = GenericPointWiseLossWrapper(nn.CrossEntropyLoss())

# Macro-averaged metrics; the keys are the names used to refer to them.
evals = {
    'prec': evaluators.MultiClassPrecisionEvaluator(average='macro'),
    'rec': evaluators.MultiClassRecallEvaluator(average='macro'),
    'f1': evaluators.MultiClassF1Evaluator(average='macro')
}

# Only parameters that require gradients are handed to the optimizer.
optimizer = torch.optim.Adam(
    p for p in system.model.parameters() if p.requires_grad
)

# Early stopping watches the 'f1' evaluator on the 'val' data loader.
early_stopping = EarlyStoppingCriterionCallback(
    patience=3,
    evaluation_data_loader_key='val',
    evaluator_key='f1',
    tmp_best_state_filepath='data/mnist_tmp_best.weights'
)

# The return value is not needed here, so it is bound to `_`.
_ = system.train(
    loss_wrapper,
    optimizer,
    train_data_loader=train_dataloader,
    evaluators=evals,
    evaluation_data_loaders={'val': val_dataloader},
    callbacks=[early_stopping]
)


Next we evaluate the model.

In [None]:
# Run every evaluator in `evals` on the test loader and print each result.
results = system.evaluate(test_dataloader, evals)
for result in results.values():
    print(result)


We can also use the `predict` method in order to predict for all the examples returned by a dataloader.

In [None]:
# Predict for every example of the test loader; the per-example outputs are
# read from predictions["outputs"] below.
predictions = system.predict(
    test_dataloader,
    perform_last_activation=True
)


In [None]:
# Compare the predicted class with the gold label for a single test example.
# The test loader was created with shuffle=False, so position `ex` in the
# predictions is expected to line up with position `ex` in `test_dataset`.
ex = 599
# The predicted class is the index of the largest output value.
print(f'Prediction for ex {ex}: {np.argmax(predictions["outputs"][ex])}')
print(f'Label of ex {ex}: {test_dataset[ex]["target"]}')


Finally we save the model's weights.

In [None]:
# Persist the trained model's state to 'data/mnist_final.weights'.
system.save_model_state('data/mnist_final.weights')
