<a href="https://colab.research.google.com/github/hoang1007/CodeSpace/blob/master/mnist_simple_cnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [63]:
# Mount Google Drive so that paths under /content/drive (used later for the
# checkpoint directory) are reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [64]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import statistics
import math
from tqdm import tqdm
import os
import pickle

# Implement các model CNN

In [65]:
class ConvBlock(nn.Module):
    """Convolution -> batch normalisation -> ReLU, applied in that order.

    The convolution is built from ``in_channel``, ``out_channels`` and
    ``kernel_size`` only; every other ``nn.Conv2d`` option keeps its default.
    """

    def __init__(self, in_channel, out_channels, kernel_size):
        super().__init__()

        self.conv = nn.Conv2d(
            in_channels=in_channel,
            out_channels=out_channels,
            kernel_size=kernel_size,
        )
        self.batch_norm = nn.BatchNorm2d(num_features=out_channels)

    def forward(self, x):
        """Return ``relu(batch_norm(conv(x)))``."""
        normalized = self.batch_norm(self.conv(x))
        return F.relu(normalized)


# Smoke check: with a 3x3 kernel a 28x28 input comes out as 26x26 feature maps.
ConvBlock(3, 32, 3)(torch.rand((1, 3, 28, 28))).shape

torch.Size([1, 32, 26, 26])

## Model chỉ bao gồm các convolution block

In [66]:
class SingleCNN(nn.Module):
    """A chain of ``ConvBlock`` layers followed by one linear layer and batch norm.

    The flattened output of the last convolution block is projected to
    ``out_dim`` features, which are then passed through ``nn.BatchNorm1d``.
    """

    def __init__(self, in_dim, out_dim, in_channel, block_channels, kernel_size):
        '''
        in_dim: int, side length of the (square) 2D input image
        out_dim: int, dimension of the output
        in_channel: int, number of channels of the input image
        block_channels: sequence of int, output channels of each ConvBlock, in order
        kernel_size: int, kernel size shared by every ConvBlock

        Raises ValueError when block_channels is empty or when the stacked
        convolutions would shrink the image to a non-positive size.
        '''
        super().__init__()

        block_channels = list(block_channels)
        if not block_channels:
            raise ValueError("block_channels must contain at least one entry")

        # Block i consumes the output channels of block i - 1; the first block
        # consumes the image channels.
        input_channels = [in_channel, *block_channels[:-1]]

        self.conv_blocks = nn.Sequential(*[
            ConvBlock(c_in, c_out, kernel_size)
            for c_in, c_out in zip(input_channels, block_channels)
        ])

        self.linear = self._linear_block(in_dim, out_dim, block_channels, kernel_size)
        self.batch_norm = nn.BatchNorm1d(out_dim)

    def _linear_block(self, in_dim, out_dim, block_channels, kernel_size):
        '''
        Build the linear layer that maps the flattened conv output to out_dim.
        '''
        # ConvBlock uses nn.Conv2d defaults, so each block shortens the side
        # length by (kernel_size - 1).
        conv_out_dim = in_dim - (kernel_size - 1) * len(block_channels)

        # Squaring a negative side length would silently yield a plausible but
        # wrong layer size, so reject it up front.
        if conv_out_dim <= 0:
            raise ValueError(
                f"{len(block_channels)} conv blocks with kernel_size={kernel_size} "
                f"shrink an input of size {in_dim} to {conv_out_dim}"
            )

        linear_input_dim = conv_out_dim**2 * block_channels[-1]

        return nn.Linear(linear_input_dim, out_dim)

    def forward(self, x):
        '''
        x: tensor of images, (batch_size, in_channel, in_dim, in_dim)
        Returns a (batch_size, out_dim) tensor.
        '''
        # conv_outputs.shape == (batch_size, channels, img_size, img_size)
        conv_outputs = self.conv_blocks(x)
        conv_outputs = torch.flatten(conv_outputs, start_dim=1)

        linear_outputs = self.linear(conv_outputs)
        linear_outputs = self.batch_norm(linear_outputs)

        return linear_outputs


# Smoke check: two RGB 28x28 images map to a (2, 10) output.
SingleCNN(28, 10, 3, [5, 7], 3)(torch.rand((2, 3, 28, 28))).shape

torch.Size([2, 10])

# Utilities

## Majority Voting

In [67]:
import random

def arg_max_frequency(inputs: torch.Tensor):
    '''
    Return the value that occurs most often in ``inputs``.

    inputs: vector (1-D tensor); a plain sequence of hashable values is
        accepted as well
    Returns the most frequent value as a Python scalar. When several values
    share the highest count, one of them is picked with ``random.choice``.
    Raises IndexError (from ``random.choice``) when ``inputs`` is empty.
    '''
    from collections import Counter

    # One bulk conversion instead of one ``.item()`` call per element.
    values = inputs.tolist() if hasattr(inputs, "tolist") else list(inputs)

    counts = Counter(values)
    top_count = max(counts.values(), default=0)

    # Counter keeps first-seen order, so the candidates are listed in the
    # order in which the tied values first appear in the input.
    candidates = [value for value, count in counts.items() if count == top_count]

    return random.choice(candidates)

def majority_voting(*outputs):
    '''
    Combine the predictions of several models by majority vote.

    outputs: one score tensor per model, output.shape == (batch_size, num_classes)
    Returns a tensor holding one voted class index per sample; the vote for
    each sample (including tie handling) is delegated to arg_max_frequency.
    '''
    # Stacking the per-model argmax along dim 1 gives (batch_size, num_predict).
    predicted = torch.stack([scores.argmax(dim=-1) for scores in outputs], dim=1)

    winners = [arg_max_frequency(sample_votes) for sample_votes in predicted]

    return torch.tensor(winners).type_as(predicted)


# Demo: three random "models" voting on a batch of 5 samples with 4 classes.
majority_voting(*[torch.rand((5, 4)) for _ in range(3)])


tensor([3, 0, 3, 3, 3])

## Set up automatic recovery of model and training state

In [68]:
class IModel(nn.Module):
    """Base class listing the hooks that ``Trainer`` calls on a model.

    Subclasses must implement ``configure_optimizers``, ``training_step``,
    ``validation_step`` and ``restore``; the ``*_epoch_end`` hooks are
    optional and do nothing by default.
    """

    def __init__(self, model_name=None):
        '''
        model_name: name of the model; Trainer uses it as the checkpoint file name
        '''
        super().__init__()
        self.device = "cpu"
        self._model_name = model_name
        self._state = {}

    def configure_optimizers(self, *args, **kwargs):
        '''Return the ``(optimizer, scheduler)`` pair used for training.'''
        raise NotImplementedError()

    def training_step(self, batch, batch_idx):
        '''
        Process one training batch.

        Return the loss tensor, or a tuple whose first element is the loss
        (Trainer calls ``backward()`` on it).
        '''
        raise NotImplementedError()

    def validation_step(self, batch, batch_idx):
        '''Process one validation batch and return whatever validation_epoch_end needs.'''
        raise NotImplementedError()

    def training_epoch_end(self, train_outputs, epoch):
        '''Optional hook; receives the list of training_step results of the epoch.'''
        pass

    def validation_epoch_end(self, val_outputs, epoch):
        '''Optional hook; receives the list of validation_step results of the epoch.'''
        pass

    def restore(self, state=None):
        '''
        Restore the model from a previously saved ``state``.

        Trainer passes the value it earlier read from ``self.state``.
        '''
        raise NotImplementedError()

    @property
    def state(self):
        '''Dictionary of model state that Trainer stores in the checkpoint.'''
        return self._state

    @property
    def name(self):
        '''Name given at construction time.'''
        return self._model_name

    def to(self, device):
        '''Move the module to ``device`` and remember it in ``self.device``.'''
        self.device = device
        return super().to(device)

In [69]:
class Trainer:
    """Runs the train/validate loop for an ``IModel`` and checkpoints every epoch.

    A checkpoint is a pickled dict with the keys ``model_state``, ``epoch``
    (the last completed epoch), ``optimizer`` and ``scheduler``. The last two
    hold ``state_dict()`` snapshots; checkpoints that stored the objects
    themselves are still accepted when restoring.
    """

    def __init__(self, checkpoint_dir, restore_if_available=True, update_bar_fraction=.02):
        '''
        checkpoint_dir: directory holding one checkpoint file per model name
        restore_if_available: resume from an existing checkpoint when True
        update_bar_fraction: fraction of an epoch between progress-bar loss updates
        '''
        self.checkpoint_dir = checkpoint_dir
        self.restore_if_available = restore_if_available
        self.update_bar_fraction = update_bar_fraction

    def fit(self, model: "IModel", train_dataloader, val_dataloader, epochs, device="cpu"):
        '''
        Train ``model`` up to and including epoch ``epochs``.

        If a checkpoint exists (and restoring is enabled) training continues
        with the epoch after the one stored in it; otherwise the model is
        validated once as "epoch 0" and training starts at epoch 1.
        '''
        model_checkpoint = os.path.join(self.checkpoint_dir, model.name)

        pre_state = self._get_previous_state(model_checkpoint) \
            if self.restore_if_available else None

        if pre_state is not None:
            print("Restoring from last session")
            model.restore(pre_state["model_state"])

        # Move the model first: the optimizer must be built on (and have its
        # state loaded against) the parameters that are actually trained.
        model = model.to(device)
        optimizer, scheduler = model.configure_optimizers()

        if pre_state is None:
            start_epoch = 1
            # Run val test on the first run
            self._val_per_epoch(model, val_dataloader, 0)
        else:
            # The stored epoch has already been completed, so resume after it.
            start_epoch = pre_state["epoch"] + 1
            optimizer.load_state_dict(self._as_state_dict(pre_state["optimizer"]))
            scheduler.load_state_dict(self._as_state_dict(pre_state["scheduler"]))

        for epoch in range(start_epoch, epochs + 1):
            self._train_per_epoch(model, train_dataloader, optimizer, scheduler, epoch)

            self._val_per_epoch(model, val_dataloader, epoch)

            self._save_state(model_checkpoint, {
                "model_state": model.state,
                "epoch": epoch,
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
            })

    @staticmethod
    def _as_state_dict(saved):
        '''Return ``saved`` as a state dict, unwrapping a pickled optimizer/scheduler object.'''
        return saved.state_dict() if hasattr(saved, "state_dict") else saved

    def _train_per_epoch(self, model: "IModel", train_dataloader, optimizer, scheduler, epoch):
        '''Run one training epoch, then step the scheduler and call training_epoch_end.'''
        model.train()
        train_outputs = []

        # At least 1, otherwise the modulo below divides by zero.
        update_bar_step = max(1, math.ceil(len(train_dataloader) * self.update_bar_fraction))

        with tqdm(train_dataloader, unit='batch') as training_bar:
            training_bar.set_description(f"Epoch {epoch}")

            # Losses since the last progress-bar refresh.
            loss_log = []

            for idx, batch in enumerate(training_bar):
                optimizer.zero_grad()

                train_output = model.training_step(batch, idx)

                train_outputs.append(train_output)

                # training_step may return the loss alone or a tuple led by the loss.
                loss = train_output[0] if isinstance(train_output, tuple) else train_output

                loss.backward()
                optimizer.step()

                loss_log.append(loss.item())

                if idx % update_bar_step == 0:
                    training_bar.set_postfix(loss=statistics.mean(loss_log))
                    loss_log.clear()

            # The learning-rate schedule advances once per epoch.
            scheduler.step()

            model.training_epoch_end(train_outputs, epoch)

    def _val_per_epoch(self, model: "IModel", val_dataloader, epoch):
        '''Run the model over ``val_dataloader`` without gradients and call validation_epoch_end.'''
        model.eval()

        val_outputs = []

        with torch.no_grad():
            for idx, batch in enumerate(val_dataloader):
                val_output = model.validation_step(batch, idx)

                val_outputs.append(val_output)

            model.validation_epoch_end(val_outputs, epoch)

    def _get_previous_state(self, filepath):
        '''Return the checkpoint stored at ``filepath``, or None when there is none.'''
        if not os.path.exists(filepath):
            return None

        # NOTE(security): unpickling executes arbitrary code; only load
        # checkpoints that this notebook wrote itself.
        with open(filepath, "rb") as f:
            return pickle.load(f)

    def _save_state(self, filepath, state):
        '''Pickle ``state`` to ``filepath``, creating the parent directory if needed.'''
        directory = os.path.dirname(filepath)
        if directory:
            os.makedirs(directory, exist_ok=True)

        # Write to a side file and swap it in, so an interrupted session does
        # not leave a truncated checkpoint behind.
        tmp_path = filepath + ".tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump(state, f, pickle.HIGHEST_PROTOCOL)
        os.replace(tmp_path, filepath)
        

# Training

In [70]:
class SingleModel(IModel):
    """``SingleCNN`` classifier for single-channel images, with the training hooks filled in.

    Training batches are augmented with a random affine transform; the
    accuracy of every epoch is recorded in ``self.state`` under
    ``train_accur`` / ``val_accur``.
    """

    def __init__(self, block_channels, kernel_size, model_name):
        '''
        block_channels: output channels of each convolution block
        kernel_size: kernel size shared by all convolution blocks
        model_name: name of the model (used for the checkpoint file)
        '''
        super().__init__(model_name)

        self.random_affine = transforms.RandomAffine(MAX_ROTATION_DEGREE, MAX_TRANSLATION_FRACTION)

        self.cnn = SingleCNN(IMG_SIZE, N_DIGITS, 1, block_channels, kernel_size)

    def forward(self, x):
        return self.cnn(x)

    def training_step(self, batch, batch_idx):
        '''Augment the batch, run the model and return ``(loss, logits, labels)``.'''
        imgs, labels = batch

        imgs = imgs.to(self.device)
        labels = labels.to(self.device)

        # Transform image by image so each sample draws its own random
        # parameters; build a new tensor rather than overwriting the batch.
        imgs = torch.stack([self.random_affine(img) for img in imgs])

        logits = self(imgs) # (batch_size, n_digits)

        loss = F.cross_entropy(logits, labels)

        return loss, logits, labels

    def validation_step(self, batch, batch_idx):
        '''Run the model on an un-augmented batch and return ``(logits, labels)``.'''
        imgs, labels = batch

        imgs = imgs.to(self.device)
        labels = labels.to(self.device)

        logits = self(imgs) # (batch_size, n_digits)

        return logits, labels

    def configure_optimizers(self):
        '''Return an Adam optimizer and an exponential learning-rate scheduler.'''
        optimizer = torch.optim.Adam(self.parameters(), LR, weight_decay=WEIGHT_DECAY)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=EXPO_GAMMA)

        return optimizer, scheduler

    @staticmethod
    def _accuracy(logits_and_labels):
        '''
        Fraction of correct argmax predictions over ``(logits, labels)`` pairs.

        Returns NaN when the pairs contain no samples at all.
        '''
        num_correct = 0
        num_labels = 0

        for logits, labels in logits_and_labels:
            predicted = torch.argmax(logits, dim=-1)

            num_correct += torch.sum(predicted == labels).item()
            num_labels += predicted.size(0)

        if num_labels == 0:
            return float("nan")

        return num_correct / num_labels

    def validation_epoch_end(self, val_outputs, epoch):
        '''Record the validation accuracy (rounded to 2 decimals) and print the exact value.'''
        accuracy = self._accuracy(val_outputs)

        self._state.setdefault("val_accur", []).append(round(accuracy, 2))

        print(f"Accuracy on epoch {epoch}:", accuracy)

    def training_epoch_end(self, training_outputs, epoch):
        '''Snapshot the weights and record the training accuracy (rounded to 2 decimals).'''
        accuracy = self._accuracy(
            (logits, labels) for _, logits, labels in training_outputs
        )

        # backup state dict
        self._state["state_dict"] = self.state_dict()

        self._state.setdefault("train_accur", []).append(round(accuracy, 2))

    def restore(self, state):
        '''Load the weights stored in ``state["state_dict"]`` and adopt ``state`` as the model state.'''
        self.load_state_dict(state["state_dict"])
        self._state = state

In [71]:
# Side length of the square input images, passed to SingleCNN as in_dim.
IMG_SIZE = 28
# Number of output classes, passed to SingleCNN as out_dim.
N_DIGITS = 10
# Arguments of transforms.RandomAffine used to augment training batches.
MAX_ROTATION_DEGREE = 20
MAX_TRANSLATION_FRACTION = (0.2, 0.2)
# Adam learning rate and weight decay.
LR = 1e-3
WEIGHT_DECAY = 1e-4
# gamma of ExponentialLR; the Trainer steps the scheduler once per epoch.
EXPO_GAMMA = 0.999
# Batch size of both the training and the validation DataLoader.
BATCH_SIZE = 120
# Number of epochs passed to Trainer.fit for every model.
EPOCHS = 100
# Number of models instantiated per architecture (M3, M5, M7).
N_MODEL = 20
# Directory in which the Trainer stores one checkpoint file per model name.
CHECKPOINT_DIR = "/content/drive/Shareddrives/colab/mnist_checkpoint/"

In [72]:
# MNIST is downloaded into ".data" when missing; train=False selects the
# split that this notebook uses for validation.
train_data = MNIST(root=".data", train=True, download=True, transform=transforms.ToTensor())
val_data = MNIST(root=".data", train=False, download=True, transform=transforms.ToTensor())

In [73]:
# Reshuffle the training set every epoch so the batches differ between
# epochs; the validation order stays fixed.
train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_data, batch_size=BATCH_SIZE)

In [74]:
# M3: ten 3x3 blocks with 32, 48, ..., 176 channels.
m3_models = [
    SingleModel(list(range(32, 177, 16)), 3, f"M3_{k}")
    for k in range(N_MODEL)
]

# M5: five 5x5 blocks with 32, 64, ..., 160 channels.
m5_models = [
    SingleModel(list(range(32, 161, 32)), 5, f"M5_{k}")
    for k in range(N_MODEL)
]

# M7: four 7x7 blocks with 48, 96, 144, 192 channels.
m7_models = [
    SingleModel(list(range(48, 193, 48)), 7, f"M7_{k}")
    for k in range(N_MODEL)
]

In [None]:
# Train m3 models

# A Trainer only holds configuration, so one instance can serve every model.
trainer = Trainer(CHECKPOINT_DIR)

for model in m3_models:
    trainer.fit(
        model,
        train_dataloader,
        val_dataloader,
        epochs=EPOCHS,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

Accuracy on epoch 0: 0.1135


Epoch 1:  36%|███▌      | 178/500 [00:40<01:12,  4.41batch/s, loss=0.476]

In [None]:
# Train m5 models

# A Trainer only holds configuration, so one instance can serve every model.
trainer = Trainer(CHECKPOINT_DIR)

for model in m5_models:
    trainer.fit(
        model,
        train_dataloader,
        val_dataloader,
        epochs=EPOCHS,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

In [None]:
# Train m7 models

# A Trainer only holds configuration, so one instance can serve every model.
trainer = Trainer(CHECKPOINT_DIR)

for model in m7_models:
    trainer.fit(
        model,
        train_dataloader,
        val_dataloader,
        epochs=EPOCHS,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )