# <font style="color:blue">Combine them all: LeNet5 pipeline with Trainer</font>

Let's take a look at how we can build the training pipeline using the Trainer helper class and the other helper classes we've discussed before in this notebook.
Import all the necessary classes and functions:

In [1]:
# %matplotlib notebook
# %load_ext autoreload
# %autoreload 2

from operator import itemgetter

import torch
import torch.nn as nn
import torch.optim as optim

from torchvision import datasets, transforms
from torchvision.transforms import functional as Fn
from torch.optim.lr_scheduler import MultiStepLR

from trainer import Trainer, hooks, configuration

from trainer.trainer_dataset import KenyanFood13Dataset, TransformedSubset
from trainer.test_dataset import KenyanFood13DatasetTest

from trainer.experinment_utils import get_mean_std, get_data
from trainer.configuration import Model
from trainer.utils import setup_system, patch_configs
from trainer.metrics import AccuracyEstimator
from trainer.tensorboard_visualizer import TensorBoardVisualizer

import matplotlib.pyplot as plt

import os
import numpy as np
import pandas as pd
import time

from torch.utils.data import Dataset, DataLoader, random_split

from PIL import Image
import torch.nn.functional as F




## <font style="color:Green">1. Get Training and Validation Data Loader</font>


Define the data wrappers and transformations (the same way as before):

In [None]:
# def PlotLoader(loader):
#     # Plot few images
#     plt.rcParams["figure.figsize"] = (15, 9)
#     plt.figure
#     for images, labels in loader:
#         for i in range(len(labels)):
#             plt.subplot(3, 5, i+1)
#             img = Fn.to_pil_image(images[i])
#             plt.imshow(img)
#             plt.gca().set_title('Target: {0}'.format(labels[i]))
#         plt.show()
#         break

In [None]:
# train_loader, test_loader, train_mean, train_std, classes = get_data(batch_size=15, data_root='../../../../data/Week7_project2_classification/KenyanFood13Dataset', num_workers=1)
# print(classes)
# PlotLoader(train_loader)
# print("---------")
# PlotLoader(test_loader)

## <font style="color:Green">2. Define the Model</font>

Define the model (the same way as before):

## <font style="color:Green">3. Start Experiment / Training</font>


Define the experiment with the given model and given data. It's the same idea again: we keep the less-likely-to-change things inside the object and configure it with the things that are more likely to change.

You may wonder why we put the specific metric and optimizer into the experiment code instead of passing them as parameters. The experiment class is just a handy way to store all the parts of your experiment in one place. If you change the loss function, the optimizer, or the model, it is effectively another experiment, right? So it deserves to be a separate class.

The inner structure of the Trainer class is a bit more complicated than what we've discussed above, so that it can cope with the different kinds of tasks we will discuss in this course. We will elaborate on the Trainer's inner structure in the following lectures; for now, take a look at how compact and self-descriptive the code is:

In [None]:
class Experiment:
    def __init__(
        self,
        system_config: configuration.SystemConfig = configuration.SystemConfig(),
        dataset_config: configuration.DatasetConfig = configuration.DatasetConfig(),
        dataloader_config: configuration.DataloaderConfig = configuration.DataloaderConfig(),
        optimizer_config: configuration.OptimizerConfig = configuration.OptimizerConfig(),
        trainer_config: configuration.TrainerConfig = configuration.TrainerConfig()
    ):
        self.loader_train, self.loader_test, self.train_mean, self.train_std, self.labels = get_data(
            batch_size=dataloader_config.batch_size,
            num_workers=dataloader_config.num_workers,
            data_root=dataset_config.root_dir
        )
        
        setup_system(system_config)

        self.model = Model1()
        self.loss_fn = nn.CrossEntropyLoss()
        self.metric_fn = AccuracyEstimator(topk=(1, ))
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=optimizer_config.learning_rate,
            weight_decay=optimizer_config.weight_decay,
            momentum=optimizer_config.momentum
        )
        self.lr_scheduler = MultiStepLR(
            self.optimizer, milestones=optimizer_config.lr_step_milestones, gamma=optimizer_config.lr_gamma
        )
        self.visualizer = TensorBoardVisualizer(trainer_config.tensor_board_dir)

    def run(self, trainer_config: configuration.TrainerConfig) -> tuple:

        device = torch.device(trainer_config.device)
        self.model = self.model.to(device)
        self.loss_fn = self.loss_fn.to(device)

        model_trainer = Trainer(
            model=self.model,
            loader_train=self.loader_train,
            loader_test=self.loader_test,
            loss_fn=self.loss_fn,
            metric_fn=self.metric_fn,
            optimizer=self.optimizer,
            lr_scheduler=self.lr_scheduler,
            device=device,
            data_getter=itemgetter(0),
            target_getter=itemgetter(1),
            stage_progress=trainer_config.progress_bar,
            get_key_metric=itemgetter("top1"),
            visualizer=self.visualizer,
            model_saving_frequency=trainer_config.model_saving_frequency,
            save_dir=trainer_config.model_dir,
            model_name_prefix=trainer_config.trainer_name
        )
        
        model_trainer.register_hook("end_epoch", hooks.end_epoch_hook_classification)
        self.metrics = model_trainer.fit(trainer_config.epoch_num)
        return self.metrics, self.train_mean, self.train_std, self.labels
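
The `data_getter`, `target_getter` and `get_key_metric` arguments passed to `Trainer` above are plain `operator.itemgetter` callables. Here is a minimal sketch of what they do, assuming each batch is an `(inputs, targets)` pair and the metric function reports a dictionary with a `"top1"` entry. The sample values are made up for illustration, and how `Trainer` calls these getters internally is an assumption:

```python
from operator import itemgetter

# Hypothetical stand-ins for one batch and one metric report.
batch = (["img_0", "img_1"], [3, 7])   # assumed layout: (inputs, targets)
metric_report = {"top1": 0.62}         # assumed format of the metric output

data_getter = itemgetter(0)            # picks the inputs out of a batch
target_getter = itemgetter(1)          # picks the targets out of a batch
get_key_metric = itemgetter("top1")    # picks the metric used to compare epochs

inputs = data_getter(batch)
targets = target_getter(batch)
key_metric = get_key_metric(metric_report)
print(inputs, targets, key_metric)
```

Passing getters instead of hard-coding `batch[0]` and `batch[1]` lets the same training loop work with datasets that return their samples in a different layout.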


In [None]:
def main():
    """Run the experiment."""
    # patch configs depending on cuda availability
    dataloader_config, trainer_config = patch_configs(epoch_num_to_set=5)
    dataset_config = configuration.DatasetConfig()
    experiment = Experiment(dataset_config=dataset_config, dataloader_config=dataloader_config)
    results, train_mean, train_std, labels = experiment.run(trainer_config)

    return results, train_mean, train_std, labels
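
The `patch_configs` helper is not shown in this notebook. The sketch below illustrates the general idea of adjusting configuration objects to the available hardware. The class fields, default values and the `patch_configs_sketch` function are hypothetical and may differ from the real `trainer` package:

```python
from dataclasses import dataclass


@dataclass
class DataloaderConfig:      # hypothetical stand-in for configuration.DataloaderConfig
    batch_size: int = 32
    num_workers: int = 4


@dataclass
class TrainerConfig:         # hypothetical stand-in for configuration.TrainerConfig
    device: str = "cuda"
    epoch_num: int = 20


def patch_configs_sketch(cuda_available: bool, epoch_num_to_set: int = 5):
    """Return (dataloader_config, trainer_config) adjusted to the hardware."""
    if cuda_available:
        dataloader_config = DataloaderConfig()
        trainer_config = TrainerConfig(device="cuda", epoch_num=epoch_num_to_set)
    else:
        # Smaller, cheaper settings for CPU-only machines (illustrative values)
        dataloader_config = DataloaderConfig(batch_size=16, num_workers=2)
        trainer_config = TrainerConfig(device="cpu", epoch_num=epoch_num_to_set)
    return dataloader_config, trainer_config


cpu_loader_cfg, cpu_trainer_cfg = patch_configs_sketch(cuda_available=False)
print(cpu_loader_cfg, cpu_trainer_cfg)
```

In a real run, the flag would typically come from `torch.cuda.is_available()`.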

In [None]:
if __name__ == '__main__':
    results, train_mean, train_std, labels = main()
    # print(train_mean, train_std, labels)
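
While the experiment runs, the `MultiStepLR` scheduler created in `Experiment` multiplies the learning rate by `gamma` each time a milestone epoch is reached. The following pure-Python sketch reproduces that rule without PyTorch. The base learning rate, milestones and `gamma` are illustrative values, not the defaults of `OptimizerConfig`:

```python
from bisect import bisect_right


def multistep_lr(base_lr, milestones, gamma, epoch):
    """Learning rate in effect at `epoch` (0-based) under a multi-step schedule."""
    # Count how many milestones have been reached and decay once per milestone
    reached = bisect_right(sorted(milestones), epoch)
    return base_lr * gamma ** reached


# Illustrative values only
schedule = [multistep_lr(base_lr=0.1, milestones=[2, 4], gamma=0.1, epoch=e)
            for e in range(6)]
for epoch, lr in enumerate(schedule):
    print(f"epoch {epoch}: lr = {lr:.4f}")
```

With these values the learning rate stays at its base value for the first two epochs and is then reduced tenfold at epoch 2 and again at epoch 4.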

# <font style="color:blue">Predictions</font><a name="predictions"></a>

## <font style="color:blue">Make Predictions</font>

## <font style="color:blue">Get Predictions on a Batch</font>

In [None]:
dataset_config = configuration.DatasetConfig()
data_root = dataset_config.root_dir #'../../../../data/Week7_project2_classification/KenyanFood13Dataset'

# dataset =  KenyanFood13DatasetTest(data_root, image_shape=256)

# # print('Length of the dataset: {}'.format(len(dataset)))

# img, img_id = dataset[5]
# print(img.size)
# print('Image_id: {}'.format(img_id))
# plt.imshow(img)
# plt.show()

In [None]:
def load_model(model, model_dir, model_file_name):
    model_path = os.path.join(model_dir, model_file_name)

    # Load the checkpoint onto the CPU (inference below runs on the CPU) and
    # restore the model parameters with load_state_dict
    checkpoint = torch.load(model_path, map_location="cpu")
    
    model.load_state_dict(checkpoint['model_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']

    return model, epoch, loss

In [None]:
def prediction(model, device, batch_input):
    # Inference runs on the CPU to match model.to("cpu") in get_sample_prediction;
    # the `device` argument is kept for interface compatibility but is not used.
    data = batch_input.to("cpu")

    with torch.no_grad():
        output = model(data)

    # Convert scores to probabilities using softmax
    prob = F.softmax(output, dim=1)

    # Get the max probability and its index in a single call
    pred_prob, pred_index = prob.max(dim=1)

    return pred_index.cpu().numpy(), pred_prob.cpu().numpy()
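
To make the softmax-then-max step in `prediction` concrete, here is a small pure-Python sketch for a single sample. The scores are made-up numbers, and the helper `softmax` is written here only for illustration (the notebook itself uses `F.softmax`):

```python
import math


def softmax(scores):
    """Turn raw scores into probabilities that sum to 1."""
    top = max(scores)
    # Subtracting the maximum does not change the result but avoids overflow
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


scores = [2.0, 0.5, -1.0]          # hypothetical model outputs for 3 classes
probs = softmax(scores)

# Index and value of the most probable class, like prob.max(dim=1) per row
pred_index = max(range(len(probs)), key=probs.__getitem__)
pred_prob = probs[pred_index]
print(pred_index, round(pred_prob, 3))
```

Softmax preserves the ordering of the scores, so the predicted class is the same as taking the maximum of the raw outputs; the softmax is only needed to report a probability.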

### <font style="color:green">Compulsory Preprocessing Transforms</font>

In [None]:
def image_compulsary_transforms():
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor()
        ])
    
    return preprocess
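
`Resize(256)` scales the image so that its shorter side becomes 256 pixels while keeping the aspect ratio, and `CenterCrop(224)` then cuts out the central 224x224 square. The sketch below computes the resulting geometry for a given input size in plain Python. The rounding follows our reading of torchvision's behaviour and should be treated as an approximation; the function name and the 640x480 input are made up for illustration:

```python
def resize_then_center_crop(width, height, resize_to=256, crop=224):
    """Return the resized (width, height) and the crop box (left, top, right, bottom)."""
    short, long = min(width, height), max(width, height)
    # The shorter side becomes `resize_to`; the longer side is scaled accordingly
    new_short, new_long = resize_to, int(resize_to * long / short)
    if width <= height:
        new_w, new_h = new_short, new_long
    else:
        new_w, new_h = new_long, new_short
    # The crop box is centred in the resized image
    left = int(round((new_w - crop) / 2.0))
    top = int(round((new_h - crop) / 2.0))
    return (new_w, new_h), (left, top, left + crop, top + crop)


resized, box = resize_then_center_crop(640, 480)
print(resized, box)
```

Whatever the input size, the tensor produced by the pipeline has a fixed spatial size of 224x224, which is what allows test images to be stacked into a batch.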

### <font style="color:green">Common Image Transforms</font>

In [None]:
def image_common_transforms(mean=(0.4611, 0.4359, 0.3905), std=(0.2193, 0.2150, 0.2109)):
    preprocess = image_compulsary_transforms()
    
    common_transforms = transforms.Compose([
        preprocess,
        transforms.Normalize(mean, std)
    ])
    
    return common_transforms
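
`transforms.Normalize(mean, std)` subtracts the per-channel mean and divides by the per-channel standard deviation. A minimal sketch for a single RGB pixel, using the default values from `image_common_transforms` (the helper `normalize_pixel` is written here only for illustration):

```python
mean = (0.4611, 0.4359, 0.3905)
std = (0.2193, 0.2150, 0.2109)


def normalize_pixel(rgb, mean, std):
    """Apply (value - mean) / std to each channel of one pixel."""
    return tuple((value - m) / s for value, m, s in zip(rgb, mean, std))


# A pixel equal to the dataset mean maps to zero in every channel
at_mean = normalize_pixel(mean, mean, std)

# A pixel one standard deviation above the mean maps to roughly one
one_std_above = normalize_pixel(tuple(m + s for m, s in zip(mean, std)), mean, std)
print(at_mean, one_std_above)
```

This is why the statistics passed at inference time should be the ones computed on the training set: the model has only seen inputs normalized with those values.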

In [None]:
def get_sample_prediction(model, data_root, train_mean, train_std, labels, output_root):
    test_dataset_trans = KenyanFood13DatasetTest(
        data_root,
        image_shape=None,
        transform=image_common_transforms(train_mean, train_std),
    )
    
    batch_size = 15
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # It is important to call model.eval() before prediction
    model.eval()

    # Inference is run on the CPU (prediction() also moves each batch to the CPU)
    model.to("cpu")

    data_len = len(test_dataset_trans)
    print("data_len: ", data_len)

    classes = []
    image_ids = []
    for start in range(0, data_len, batch_size):
        end = start + batch_size
        end = min(end, data_len)
        # print('start: {}, end: {}'.format(start, end))

        trans_images = []
        for index in range(start, end):
            trans_image, image_id = test_dataset_trans[index]
            # print('index: {}, img_id: {}'.format(index, img_id))
    
            trans_images.append(trans_image)
            image_ids.append(image_id)
        
        trans_images = torch.stack(trans_images)
        classes_index, prob = prediction(model, device, batch_input=trans_images)
        # print("classes_index:", classes_index)
        
        classes.extend([labels[class_index] for class_index in classes_index])
    
    data = {
        'id': image_ids,
        'class': classes
    }
    df = pd.DataFrame(data)
    
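    # Make sure the output directory exists before writing the submission file
    os.makedirs(output_root, exist_ok=True)
    label_csv_path = os.path.join(output_root, 'output.csv')
    df.to_csv(label_csv_path, sep=",", index=False)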
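
The batching loop in `get_sample_prediction` walks over the dataset in steps of `batch_size` and clips the last batch to the dataset length. The index arithmetic can be isolated in a small generator; `batch_ranges` is a hypothetical helper, and the dataset length of 37 is a made-up number:

```python
def batch_ranges(data_len, batch_size):
    """Yield (start, end) index pairs that cover range(data_len) in batches."""
    for start in range(0, data_len, batch_size):
        # The last batch may be shorter than batch_size
        yield start, min(start + batch_size, data_len)


ranges = list(batch_ranges(data_len=37, batch_size=15))
print(ranges)  # [(0, 15), (15, 30), (30, 37)]
```

Every sample is visited exactly once, and only the final batch is smaller than `batch_size`.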

## <font style="color:blue">Load Model and Run Inference</font>

In [None]:
# m = LeNet5()
# model_dir = "./checkpoints"
# model_file_name = "checkpoint1.pt"
# model, epoch, loss = load_model(m, model_dir, model_file_name)
# print(epoch, loss)

In [None]:
# train_mean=torch.tensor([0.5772715211, 0.4631873667, 0.3466044068])
# train_std =torch.tensor([0.2699360847, 0.2737641633, 0.2830057442])
# labels = ['githeri', 'ugali', 'kachumbari', 'matoke', 'sukumawiki', 'bhaji', 'mandazi',
#  'kukuchoma', 'nyamachoma', 'pilau', 'chapati', 'masalachips', 'mukimo']
# get_sample_prediction(model, data_root, train_mean, train_std, labels, "./submissions/")

# # PlotLoader(test_loader)

So in a few lines of code, we got a more robust system than we had before: we have richer visualizations, a more configurable training process, and a training pipeline that is separated from the model, so we can concentrate on the things that matter the most.

# <font style="color:blue">References</font>

You may wonder whether this is a common way of doing deep learning or whether we're overengineering here. We can assure you that this is a common way to do deep learning research in industry: most companies and research groups invest in building such DL training frameworks for their projects, and some of them are even published as open source. To name a few:
- https://github.com/NVlabs/SPADE
- https://github.com/pytorch/ignite
- https://github.com/PyTorchLightning/pytorch-lightning
- https://github.com/catalyst-team/catalyst
- https://github.com/open-mmlab/mmdetection
- https://github.com/fastai/fastai