# <font style="color:blue">Combine them all: LeNet5 pipeline with Trainer</font>

Let's take a look at how we can build the training pipeline using the Trainer helper class and the other helper classes we've discussed before in this notebook.
Import all the necessary classes and functions:

In [None]:
# %matplotlib notebook
# %load_ext autoreload
# %autoreload 2

from operator import itemgetter

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from torchvision.transforms import functional as F
from torchvision import datasets, transforms

from torchvision import datasets, transforms
from torch.optim.lr_scheduler import MultiStepLR

from trainer import Trainer, hooks, configuration
from trainer.utils import setup_system, patch_configs
from trainer.metrics import AccuracyEstimator
from trainer.tensorboard_visualizer import TensorBoardVisualizer

import matplotlib.pyplot as plt

import os
import numpy as np
import pandas as pd
import time

from torch.utils.data import Dataset, DataLoader, random_split
# from typing import Iterable
# from dataclasses import dataclass

from PIL import Image

## <font style="color:Green">1. Get Training and Validation Data Loader</font>


Define the data wrappers and transformations (the same way as before):

In [None]:
class KenyanFood13Dataset(Dataset):
    """
    Image classification dataset described by ``<data_root>/train.csv``.

    Every row of the CSV (columns ``id`` and ``class``) becomes one sample whose
    image is read from ``<data_root>/images/images/<id>.jpg``. Class names are
    mapped to integer labels in order of first appearance in the CSV.

    The dataset itself is not split into training / validation parts; callers
    are expected to split it (e.g. with ``random_split``).
    """

    def __init__(self, data_root, image_shape=None, transform=None):
        """
        Read the label file and build the list of image paths and labels.

         Parameters:

         data_root (string): path of root directory containing ``train.csv``
                             and the ``images/images`` folder.

         image_shape (int or tuple or list): [optional] Default is None.
                                             If it is not None every image is resized to this
                                             shape. An int (or a 1-element sequence) is
                                             expanded to a square shape.

         transform (method): [optional] callable applied to the (optionally resized)
                             PIL image.

         Raises:

         AssertionError: if image_shape is a tuple/list with a length other than 1 or 2.

         NotImplementedError: if image_shape is neither None, int, tuple nor list.

        """

        # read the label file; the regex delimiter also swallows blanks around the commas,
        # which is why the python parsing engine is requested
        label_csv_path = os.path.join(data_root, 'train.csv')
        self.data_df = pd.read_csv(label_csv_path, delimiter=' *, *', engine='python')
        self.classes = self.data_df.iloc[:, 1].unique()
        self.num_classes = len(self.classes)
        self.image_ids = self.data_df.iloc[:, 0]

        # two-way mapping between integer labels and class names
        self.class_given_label = dict(enumerate(self.classes))
        self.label_given_class = {class_name: label for label, class_name in enumerate(self.classes)}

        # normalise image_shape: None stays None, everything else becomes a pair
        if image_shape is None:
            self.image_shape = None
        elif isinstance(image_shape, int):
            self.image_shape = (image_shape, image_shape)
        elif isinstance(image_shape, (tuple, list)):
            assert len(image_shape) == 1 or len(image_shape) == 2, 'Invalid image_shape tuple size'
            if len(image_shape) == 1:
                self.image_shape = (image_shape[0], image_shape[0])
            else:
                self.image_shape = image_shape
        else:
            raise NotImplementedError

        # set transform attribute
        self.transform = transform

        # Build the per-sample path / label lists column-wise. Iterating the
        # columns directly avoids DataFrame.iterrows(), which builds a Series
        # per row and may change value dtypes.
        img_dir = os.path.join(data_root, 'images', 'images')
        self.data_dict = {
            'image_path': [
                os.path.join(img_dir, str(image_id) + '.jpg')
                for image_id in self.data_df['id']
            ],
            'label': [
                self.label_given_class[image_class]
                for image_class in self.data_df['class']
            ],
        }

    def __len__(self):
        """
        return length of the dataset
        """
        return len(self.data_dict['label'])

    def __getitem__(self, idx):
        """
        For given index, return the (image, label) pair.

        The image is opened as RGB, resized if image_shape was given and then
        passed through transform if one was given.
        """

        image = Image.open(self.data_dict['image_path'][idx]).convert("RGB")

        if self.image_shape is not None:
            image = F.resize(image, self.image_shape)

        if self.transform is not None:
            image = self.transform(image)

        target = self.data_dict['label'][idx]

        return image, target

    def class_name(self, label):
        """
        class label to class name mapping

        Besides plain ints, objects exposing ``.item()`` (e.g. the 0-d tensors
        produced when iterating a batch of labels, or numpy scalars) are accepted.
        """
        if hasattr(label, 'item'):
            # tensors hash by identity, so convert to a Python scalar before the lookup
            label = label.item()
        return self.class_given_label[label]

In [None]:
def get_data(batch_size, data_root='data', num_workers=1):
    """Create the training and validation data loaders.

    The whole KenyanFood13Dataset found under ``data_root`` is randomly split
    80% / 20%; the larger part is served shuffled, the smaller part in order.

    Parameters:
        batch_size (int): number of samples per batch for both loaders.
        data_root (string): root directory handed to KenyanFood13Dataset.
        num_workers (int): number of loader worker processes for both loaders.

    Returns:
        tuple: (train_loader, test_loader)
    """
    # Same preprocessing for both parts: resize, central 224x224 crop, to tensor.
    preprocessing = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])

    full_dataset = KenyanFood13Dataset(data_root, image_shape=None, transform=preprocessing)

    # 80% of the samples go to training, the remainder to validation
    num_train = int(0.8 * len(full_dataset))
    num_test = len(full_dataset) - num_train
    train_subset, test_subset = random_split(full_dataset, [num_train, num_test])

    # settings shared by both loaders; only shuffling differs
    shared_kwargs = dict(batch_size=batch_size, num_workers=num_workers)
    train_loader = torch.utils.data.DataLoader(train_subset, shuffle=True, **shared_kwargs)
    test_loader = torch.utils.data.DataLoader(test_subset, shuffle=False, **shared_kwargs)

    return train_loader, test_loader

In [None]:
# Loaders used for the sample preview; 15 samples per batch fill a 3 x 5 grid.
train_loader, test_loader = get_data(
    data_root='../../../../data/Week7_project2_classification/KenyanFood13Dataset',
    batch_size=15,
    num_workers=1,
)

In [None]:
# Plot a few images from the first batch of the test loader
plt.rcParams["figure.figsize"] = (15, 9)
plt.figure()  # the call parentheses were missing, so no figure was created here

first_batch = next(iter(test_loader), None)
if first_batch is not None:
    images, labels = first_batch
    # The grid is 3 x 5, so show at most 15 samples even if the batch is larger.
    for i, (image, label) in enumerate(zip(images[:15], labels[:15])):
        plt.subplot(3, 5, i + 1)
        plt.imshow(F.to_pil_image(image))
        # int(...) turns the 0-d label tensor into a plain number for the title
        plt.gca().set_title('Target: {0}'.format(int(label)))
    plt.show()

## <font style="color:Green">2. Define the Model</font>

Define the model (the same way as before):

In [None]:
# class LeNet5(nn.Module):
#     def __init__(self):
#         super().__init__()

#         self._body = nn.Sequential(
#             nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),
#             nn.ReLU(inplace=True),
#             nn.MaxPool2d(kernel_size=2),
#             nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
#             nn.ReLU(inplace=True),
#             nn.MaxPool2d(kernel_size=2),
#         )
#         self._head = nn.Sequential(
#             nn.Linear(in_features=16 * 5 * 5, out_features=120), nn.ReLU(inplace=True),
#             nn.Linear(in_features=120, out_features=84), nn.ReLU(inplace=True),
#             nn.Linear(in_features=84, out_features=10)
#         )

#     def forward(self, x):
#         x = self._body(x)
#         x = x.view(x.size()[0], -1)
#         x = self._head(x)
#         return x

In [None]:
class LeNet5(nn.Module):
    """LeNet-style convolutional classifier for 3-channel 224x224 images.

    Two conv / ReLU / max-pool stages (``_body``) are followed by a two-layer
    fully connected classifier (``_head``) that outputs one logit per class.

    The first linear layer is sized for 224x224 inputs; other input sizes
    produce a different number of flattened features and will not fit.

    Parameters:
        num_classes (int): number of output logits. Defaults to 13, the value
            that used to be hard-coded.
    """

    def __init__(self, num_classes: int = 13):
        super().__init__()

        # convolution layers
        # spatial size for a 224x224 input:
        #   224 -conv7-> 218 -pool2-> 109 -conv5-> 105 -pool2-> 52
        self._body = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=7),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),

            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
        )

        # Fully connected layers
        # 64 channels * 52 * 52 spatial positions leave the feature extractor
        self._head = nn.Sequential(
            nn.Linear(in_features=64 * 52 * 52, out_features=1024),
            nn.ReLU(inplace=True),

            nn.Linear(in_features=1024, out_features=num_classes),
        )

    def forward(self, x):
        """Return the class logits, shape (batch_size, num_classes), for a batch of images."""
        # apply feature extractor
        x = self._body(x)
        # flatten everything except the batch dimension
        x = x.view(x.size(0), -1)
        # apply classification head
        return self._head(x)

## <font style="color:Green">3. Start Experiment / Training</font>


Define the experiment with the given model and given data. It's the same idea again: we keep the less-likely-to-change things inside the object and configure it with the things that are more likely to change.

You may wonder why we put the specific metric and optimizer into the experiment code instead of passing them in as parameters. The experiment class is just a handy way to store all the parts of your experiment in one place. If you change the loss function, the optimizer, or the model, it is really a different experiment, right? So it deserves to be a separate class.

The Trainer class's inner structure is a bit more complicated than what we've discussed above - this is only so that it can cope with the different kinds of tasks we will discuss in this course. We will elaborate on the Trainer's inner structure in the following lectures; for now, take a look at how compact and self-descriptive the code is:

In [None]:
class Experiment:
    """Bundle of everything needed to train LeNet5 on the data returned by get_data.

    The constructor builds the data loaders, model, loss, metric, optimizer,
    learning-rate scheduler and visualizer; ``run`` hands them to a Trainer.
    """

    def __init__(
        self,
        system_config: configuration.SystemConfig = configuration.SystemConfig(),
        dataset_config: configuration.DatasetConfig = configuration.DatasetConfig(),
        dataloader_config: configuration.DataloaderConfig = configuration.DataloaderConfig(),
        optimizer_config: configuration.OptimizerConfig = configuration.OptimizerConfig()
    ):
        """Create all parts of the experiment from the given configurations.

        Parameters:
            system_config: passed to setup_system.
            dataset_config: its ``root_dir`` is used as the data root.
            dataloader_config: provides ``batch_size`` and ``num_workers``.
            optimizer_config: provides ``learning_rate``, ``weight_decay``,
                ``momentum``, ``lr_step_milestones`` and ``lr_gamma``.
        """
        # NOTE(review): the default config objects above are created once, when
        # the class is defined, and shared by every Experiment that relies on
        # the defaults - do not mutate them.

        # System setup has to happen before get_data: get_data performs a random
        # train/validation split, which previously ran before the setup call.
        # Presumably setup_system seeds the random generators - confirm in trainer.utils.
        setup_system(system_config)

        self.loader_train, self.loader_test = get_data(
            batch_size=dataloader_config.batch_size,
            num_workers=dataloader_config.num_workers,
            data_root=dataset_config.root_dir
        )

        self.model = LeNet5()
        self.loss_fn = nn.CrossEntropyLoss()
        self.metric_fn = AccuracyEstimator(topk=(1, ))
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=optimizer_config.learning_rate,
            weight_decay=optimizer_config.weight_decay,
            momentum=optimizer_config.momentum
        )
        self.lr_scheduler = MultiStepLR(
            self.optimizer, milestones=optimizer_config.lr_step_milestones, gamma=optimizer_config.lr_gamma
        )
        self.visualizer = TensorBoardVisualizer()

    def run(self, trainer_config: configuration.TrainerConfig) -> dict:
        """Train the model and return the metrics reported by the Trainer.

        Parameters:
            trainer_config: provides ``device``, ``progress_bar``,
                ``model_saving_frequency``, ``model_dir`` and ``epoch_num``.

        Returns:
            The value returned by ``Trainer.fit`` (also stored in ``self.metrics``).
        """
        # move the model and the loss to the requested device
        device = torch.device(trainer_config.device)
        self.model = self.model.to(device)
        self.loss_fn = self.loss_fn.to(device)

        model_trainer = Trainer(
            model=self.model,
            loader_train=self.loader_train,
            loader_test=self.loader_test,
            loss_fn=self.loss_fn,
            metric_fn=self.metric_fn,
            optimizer=self.optimizer,
            lr_scheduler=self.lr_scheduler,
            device=device,
            # each batch is an (image, target) pair
            data_getter=itemgetter(0),
            target_getter=itemgetter(1),
            stage_progress=trainer_config.progress_bar,
            get_key_metric=itemgetter("top1"),
            visualizer=self.visualizer,
            model_saving_frequency=trainer_config.model_saving_frequency,
            save_dir=trainer_config.model_dir
        )

        model_trainer.register_hook("end_epoch", hooks.end_epoch_hook_classification)
        self.metrics = model_trainer.fit(trainer_config.epoch_num)
        return self.metrics


In [None]:
def main():
    """Set up one Experiment on the KenyanFood13 data, run it and return its results."""
    # per the original note, patch_configs adapts the configs to CUDA availability;
    # a single epoch is requested here
    dataloader_config, trainer_config = patch_configs(epoch_num_to_set=1)

    dataset_config = configuration.DatasetConfig(
        root_dir="../../../../data/Week7_project2_classification/KenyanFood13Dataset"
    )

    return Experiment(
        dataset_config=dataset_config,
        dataloader_config=dataloader_config,
    ).run(trainer_config)

In [None]:
# Entry point: run the experiment only when this code is executed as the main module.
if __name__ == '__main__':
    main()

So in a few lines of code, we got a more robust system than we had before - we have richer visualizations, a more configurable training process, and we separated the training pipeline from the model - so we can concentrate on the things that matter the most.

# <font style="color:blue">References</font>

You may wonder whether this is a common way of doing deep learning or whether we are overengineering here. We can assure you that this is a common way to do deep learning research in industry - most companies and research groups invest in building such DL training frameworks for their projects, and some of them are even published as open source. To name a few of them:
- https://github.com/NVlabs/SPADE
- https://github.com/pytorch/ignite
- https://github.com/PyTorchLightning/pytorch-lightning
- https://github.com/catalyst-team/catalyst
- https://github.com/open-mmlab/mmdetection
- https://github.com/fastai/fastai