<a href="https://colab.research.google.com/github/dvarkless/InnopolisDS/blob/main/homeworks/pytorch_CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
import sys
from collections import OrderedDict
from typing import OrderedDict as OrderedDictType

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchmetrics import Accuracy, Precision, Recall
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Normalize, ToTensor


class SeqModeler(nn.Sequential):
    def __init__(self, ord_dict: OrderedDictType, device='cpu') -> None:
        super().__init__(ord_dict)
        self.device = torch.device(device)
        self.to(self.device, non_blocking=True)

    def predict(self, X):
        X = self.forward(X)
        return torch.argmax(nn.Softmax(dim=1)(X), dim=1)


class Trainer:
    __defaults = {
        'batch_size': 100,
        'device': 'cpu',
        'epochs': 20,
        'criterion': nn.CrossEntropyLoss(),
        'enable_print': False,
        'metrics': None,
    }
    __must_have_params = ['model_class', 'model_params', 'set_optimizer']

    def __init__(self, **hp) -> None:
        self.config = self.__defaults.copy()
        for name, val in hp.items():
            self.config[name] = val

        for name in self.__must_have_params:
            if name not in self.config:
                print(f'Error: config parameter "{name}" is missing')
                sys.exit(1)

        self.model = self.config['model_class'](**self.config['model_params'])
        opt_config = self.config['set_optimizer'].copy()
        opt_config['params'] = self.model.parameters()
        optimizer_name = opt_config.pop('name')
        self.optimizer = getattr(torch.optim, optimizer_name)(**opt_config)
        self.criterion = self.config['criterion']
        self.device = torch.device(self.config['device'])

    @property
    def data_batch(self):
        return self._data_batch

    @data_batch.setter
    def data_batch(self, data, /):
        if isinstance(data, torch.Tensor):
            self._data_batch = data.to(self.device, non_blocking=True).float()
        elif isinstance(data, np.ndarray):
            self.data_batch = torch.Tensor(data)
        else:
            raise ValueError(f'data of type {type(data)} is unacceptable')

    @property
    def targets_batch(self):
        return self._targets_batch

    @targets_batch.setter
    def targets_batch(self, targets):
        if isinstance(targets, torch.Tensor):
            self._targets_batch = targets.to(
                self.device, non_blocking=True)
        elif isinstance(targets, (np.ndarray, list, tuple)):
            self.targets_batch = torch.Tensor(targets)
        else:
            raise ValueError(f'data of type {type(targets)} is unacceptable')

    def fit(self, train_dataset, eval_dataset=None):
        train_dl = DataLoader(train_dataset, self.config['batch_size'])
        for epoch in range(self.config['epochs']):
            avg_loss = []
            for (inputs, targets) in train_dl:
                self.data_batch, self.targets_batch = inputs, targets
                self.optimizer.zero_grad()
                yhat = self.model(self.data_batch)
                loss = self.criterion(yhat, self.targets_batch)
                avg_loss.append(loss)
                loss.backward()
                self.optimizer.step()
            avg_loss = torch.Tensor(avg_loss).mean()
            avg_loss.to(self.device)
            if self.config['enable_print']:
                print(
                    f'==========Epoch {epoch+1}/{self.config["epochs"]}==========')
                print(f'Loss: {avg_loss}')
                if self.config['metrics'] and eval_dataset:
                    metric_data = self.evaluate(eval_dataset)
                    for metric, data in zip(self.config['metrics'], metric_data):
                        print(f'{metric.__class__.__name__} = {data:.3f}')
        return self

    def evaluate(self, eval_dataset):
        eval_dl = DataLoader(eval_dataset, batch_size=10000)
        for data, targets in eval_dl:
            self.data_batch, self.targets_batch = data, targets
            predictions = self.model.predict(self.data_batch)
            metric_data = []
            for metric in self.config['metrics']:
                metric_data.append(metric(predictions, self.targets_batch))
            return tuple(metric_data)


if __name__ == "__main__":
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    trans = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
    train = MNIST('data/', train=True, download=True, transform=trans)
    test = MNIST('data/', train=False, download=True, transform=trans)
    model_params = OrderedDict([
        ('batch1', nn.BatchNorm2d(1)),
        ('conv1', nn.Conv2d(1, 16, (2, 2), stride=1, padding=1)),
        ('relu1', nn.ReLU()),
        ('maxpool1', nn.MaxPool2d((2, 2))),

        ('batch2', nn.BatchNorm2d(16)),
        ('conv2', nn.Conv2d(16, 64, (2, 2), stride=1, padding=1)),
        ('relu2', nn.ReLU()),
        ('maxpool2', nn.MaxPool2d((2, 2))),

        ('flatten3', nn.Flatten()),
        ('batch3', nn.BatchNorm1d(64*7*7)),
        ('linear3', nn.Linear(64*7*7, 100)),
        ('relu3', nn.ReLU()),
        ('linear4', nn.Linear(100, 10)),
        ('relu4', nn.ReLU()),
    ])
    optim_params = {
        'name': 'SGD',
        'params': None,
        'lr': 1e-4,
        'momentum': 0.9,
    }
    trainer_hp = {
        'batch_size': 50,
        'model_class': SeqModeler,
        'model_params': {'ord_dict': model_params, 'device': device},
        'set_optimizer': optim_params,
        'device': device,
        'criterion': nn.CrossEntropyLoss(),
        'enable_print': True,
        'metrics': [Accuracy(num_classes=10, average='macro').to(device), Recall(num_classes=10, average='macro').to(device), Precision(num_classes=10, average='macro').to(device)]
    }

    trainer = Trainer(**trainer_hp).fit(train, test)

Loss: 1.2967389822006226
Accuracy = 0.884
Recall = 0.884
Precision = 0.888
Loss: 0.4429587721824646
Accuracy = 0.923
Recall = 0.923
Precision = 0.923
Loss: 0.2905564308166504
Accuracy = 0.940
Recall = 0.940
Precision = 0.940
Loss: 0.2298288494348526
Accuracy = 0.950
Recall = 0.950
Precision = 0.950
Loss: 0.19461113214492798
Accuracy = 0.954
Recall = 0.954
Precision = 0.954
Loss: 0.1706569343805313
Accuracy = 0.959
Recall = 0.959
Precision = 0.959
Loss: 0.15299151837825775
Accuracy = 0.964
Recall = 0.964
Precision = 0.964
Loss: 0.1392328441143036
Accuracy = 0.966
Recall = 0.966
Precision = 0.966
Loss: 0.12814292311668396
Accuracy = 0.969
Recall = 0.969
Precision = 0.969
Loss: 0.11899322271347046
Accuracy = 0.971
Recall = 0.971
Precision = 0.971
Loss: 0.11129086464643478
Accuracy = 0.973
Recall = 0.973
Precision = 0.973
Loss: 0.1046813577413559
Accuracy = 0.974
Recall = 0.974
Precision = 0.974
Loss: 0.09894908964633942
Accuracy = 0.975
Recall = 0.975
Precision = 0.976
Loss: 0.09392621368