### packages

In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
from matplotlib import pyplot as plt

import torch
from torch import nn, optim
from torchvision import datasets, transforms, utils
from torch.utils import data
from torchkeras import summary, Model
from sklearn.metrics import precision_score, accuracy_score
import pandas as pd
import os
import datetime

### common code

In [None]:
# metric plot
def plot_metric(dfhistory, metric):
    """Plot one metric's training and validation curves against the epoch number.

    Args:
        dfhistory: history table indexable by ``metric`` and ``'val_' + metric``.
        metric: base metric name, e.g. ``'loss'``.

    The x axis is simply 1..len(series); the stored index of ``dfhistory`` is
    not used.
    """
    train_series = dfhistory[metric]
    valid_series = dfhistory['val_' + metric]
    x_axis = range(1, 1 + len(train_series))

    # Training curve: dashed blue with markers; validation curve: solid red.
    for series, fmt in ((train_series, 'bo--'), (valid_series, 'ro-')):
        plt.plot(x_axis, series, fmt)

    plt.title('Training and validation %s' % metric)
    plt.xlabel("Epochs")
    plt.ylabel(metric)
    plt.legend(['train_%s' % metric, 'valid_%s' % metric])
    plt.show()

# metrics
def precision_metrics(targets, labels):
    """Macro-averaged precision of a batch, returned as a 0-d tensor.

    Args:
        targets: model outputs; the predicted class is the index of the
            maximum along dimension 1.
        labels: ground-truth class indices (converted with ``.numpy()``).
    """
    _, predicted = targets.data.max(1)
    return torch.tensor(
        precision_score(labels.numpy(), predicted.numpy(), average='macro')
    )

def accuracy_metrics(targets, labels):
    """Classification accuracy of a batch, returned as a 0-d tensor.

    Args:
        targets: model outputs; the predicted class is the index of the
            maximum along dimension 1.
        labels: ground-truth class indices (converted with ``.numpy()``).
    """
    _, predicted = targets.data.max(1)
    return torch.tensor(accuracy_score(labels.numpy(), predicted.numpy()))


# training functions
def run_step(model, features, labels, train_mode=True):
    """Run one batch through the model and return its metrics.

    The model is expected to carry three extra attributes: ``loss_fn``,
    ``optim`` and ``metrics_dict`` (name -> callable(targets, labels) returning
    a tensor).

    Args:
        model: the network to run.
        features: input batch passed to ``model``.
        labels: ground-truth batch passed to the loss and metric functions.
        train_mode: when True, backpropagate and update the weights; when
            False, only evaluate (no gradients, no optimizer step).

    Returns:
        dict of Python floats keyed by ``'loss'`` plus each metric name; all
        keys are prefixed with ``'val_'`` when ``train_mode`` is False.
    """
    prefix = '' if train_mode else 'val_'

    # Build the autograd graph only when we are actually going to backpropagate.
    with torch.set_grad_enabled(train_mode):
        targets = model(features)
        loss = model.loss_fn(targets, labels)

    metrics = {'%sloss' % prefix: loss.item()}
    for metric_name, metric_fn in model.metrics_dict.items():
        metric_value = metric_fn(targets, labels)
        metrics['%s%s' % (prefix, metric_name)] = metric_value.item()

    # Validation batches must never update the weights.
    if train_mode:
        loss.backward()
        model.optim.step()
        model.optim.zero_grad()

    return metrics


def run_epoch(model, dataloader, train_mode=True, log_per_steps=200):
    """Iterate once over ``dataloader`` and return the running mean of each metric.

    Args:
        model: network passed on to ``run_step``; its train/eval mode is set
            from ``train_mode`` before iterating.
        dataloader: iterable yielding ``(features, labels)`` pairs.
        train_mode: forwarded to ``model.train`` and ``run_step``.
        log_per_steps: print the running means every this many steps.

    Returns:
        dict mapping each metric name reported by ``run_step`` to the
        unweighted mean of its per-step values.
    """
    model.train(train_mode)

    running = {}
    for step, batch in enumerate(dataloader, start=1):
        features, labels = batch
        step_metrics = run_step(model, features, labels, train_mode)

        # Incremental mean: fold the new step value into the average so far.
        for name, value in step_metrics.items():
            previous = running.get(name)
            if previous is None:
                running[name] = value
            else:
                running[name] = (step - 1) / step * previous + value / step

        if step % log_per_steps == 0:
            print(" - Step %d, %s" % (step, running))

    return running


def train_model(model, dataloader_train, dataloader_valid, epochs, log_per_epochs=10, log_per_steps=200, skip_epoch=0):
    """Train and validate ``model`` for ``epochs`` epochs and return the history.

    Args:
        model: network passed on to ``run_epoch``.
        dataloader_train: batches used for the training pass of each epoch.
        dataloader_valid: batches used for the validation pass of each epoch.
        epochs: number of epochs to run in this call.
        log_per_epochs: print the epoch metrics every this many epochs.
        log_per_steps: forwarded to ``run_epoch`` for per-step logging.
        skip_epoch: offset added to the epoch numbering, so a resumed run
            continues counting after the epochs already done.

    Returns:
        pandas DataFrame indexed by ``epoch`` with one row per epoch holding
        the training and validation metrics. If no epoch is run
        (``epochs <= 0``) an empty frame with an ``epoch`` index is returned.
    """
    print("==========" * 6)
    print("= Training model")

    metrics_list = []
    start_epoch = 1 + skip_epoch
    end_epoch = epochs + 1 + skip_epoch
    for epoch in range(start_epoch, end_epoch):
        print("==========" * 6)
        print("= Epoch %d/%d @ %s" % (epoch, epochs + skip_epoch, datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
        metrics_train = run_epoch(model, dataloader_train, train_mode=True, log_per_steps=log_per_steps)
        metrics_valid = run_epoch(model, dataloader_valid, train_mode=False, log_per_steps=log_per_steps)

        # One flat record per epoch: epoch number, then train and val_ metrics.
        metrics = {'epoch': epoch}
        metrics.update(metrics_train)
        metrics.update(metrics_valid)
        metrics_list.append(metrics)

        if epoch % log_per_epochs == 0:
            print('= %s' % metrics)

    print("==========" * 6)

    # With no records there is no 'epoch' column, so set_index would raise
    # KeyError; return an empty, correctly indexed frame instead.
    if not metrics_list:
        return pd.DataFrame(index=pd.Index([], name='epoch'))
    return pd.DataFrame(metrics_list).set_index('epoch')


def predict_model(model, features):
    """Return the predicted class index for every sample in ``features``.

    Inference is always run in evaluation mode and without gradient tracking,
    so layers such as dropout do not randomise the result. The model's
    previous train/eval mode is restored before returning.

    Args:
        model: network whose output has the class scores along dimension 1.
        features: input batch passed to ``model``.

    Returns:
        tensor of class indices (argmax over dimension 1 of the model output).
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            targets = model(features).max(1)[1]
    finally:
        # Leave the caller's mode untouched, even if the forward pass fails.
        model.train(was_training)
    return targets


def plot_images(images, nrows=8, figsize=(2, 2)):
    """Show a batch of images as a single tiled picture without axis ticks.

    Args:
        images: image batch, laid out as (B, C, H, W).
        nrows: forwarded to ``utils.make_grid`` as its ``nrow`` argument.
        figsize: matplotlib figure size.
    """
    tiled = utils.make_grid(images, nrow=nrows)

    plt.figure(figsize=figsize)
    # Move the first (channel) axis last, which is the layout imshow expects.
    plt.imshow(tiled.permute(1, 2, 0).numpy())
    for hide_ticks in (plt.xticks, plt.yticks):
        hide_ticks([])
    plt.show()

### datasets and dataloader

In [None]:
# Paths: the dataset root plus the history / weight files stored under <root>/MNIST.
MNIST_PATH = os.path.join('..', 'data')
HISTORY_FILE, HISTORY2_FILE, WEIGHT_FILE, WEIGHT2_FILE = (
    os.path.join(MNIST_PATH, 'MNIST', file_name)
    for file_name in (
        'mnist_history.csv',
        'mnist_history2.csv',
        'mnist_weight.pth',
        'mnist_weight2.pth',
    )
)

NB_CLASSES = 10

# Preprocessing: convert to a tensor, then normalise with mean 0.5 / std 0.5
# (i.e. (x - 0.5) / 0.5).
data_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, ), (0.5, )),
])

# Training and validation splits (downloaded into MNIST_PATH when missing).
ds_train, ds_valid = (
    datasets.MNIST(MNIST_PATH, train=is_train, transform=data_tf, download=True)
    for is_train in (True, False)
)

# Both loaders shuffle; training uses batches of 32, validation batches of 64.
dl_train = data.DataLoader(ds_train, shuffle=True, batch_size=32)
dl_valid = data.DataLoader(ds_valid, shuffle=True, batch_size=64)

### batch sample plot (optional)

In [None]:
# Take one training batch and display it.
batch_images, batch_labels = next(iter(dl_train))

# Invert Normalize(0.5, 0.5) so the pixel values are back in the range imshow expects.
mean = std = 0.5
batch_images = std * batch_images + mean

plot_images(batch_images)

### network class

In [None]:
# network
class SimpleCNN(nn.Module):
    """Small two-convolution classifier for single-channel 28x28 images.

    Layout: conv(1->10, 5x5) -> max-pool(2) -> ReLU ->
            conv(10->20, 5x5) -> Dropout2d -> max-pool(2) -> ReLU ->
            flatten -> fc(320->50) -> ReLU -> fc(50->nb_classes) -> log-softmax.

    For a 28x28 input the second pooling stage yields 20 maps of 4x4, which is
    where the 320 input features of ``fc1`` come from.

    The final linear layer feeds LogSoftmax directly. No ReLU is applied to
    the class scores: clamping them to be non-negative would make every
    negative score indistinguishable and block its gradient.

    Args:
        nb_classes: number of output classes.
        *args, **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, nb_classes=10, *args, **kwargs):
        super(SimpleCNN, self).__init__(*args, **kwargs)
        # Convolutional stage 1
        self.conv1 = nn.Conv2d(1, 10, 5)
        self.max_pool1 = nn.MaxPool2d(2)
        self.relu1 = nn.ReLU()

        # Convolutional stage 2 (with channel-wise dropout)
        self.conv2 = nn.Conv2d(10, 20, 5)
        self.dropout1 = nn.Dropout2d()
        self.max_pool2 = nn.MaxPool2d(2)
        self.relu2 = nn.ReLU()

        self.flatten1 = nn.Flatten()

        # Classifier head
        self.fc1 = nn.Linear(320, 50)
        self.relu3 = nn.ReLU()

        self.fc2 = nn.Linear(50, nb_classes)

        self.logsoftmax1 = nn.LogSoftmax(1)

    def forward(self, input):
        """Return per-class log-probabilities, shape (batch, nb_classes)."""
        x = self.relu1(self.max_pool1(self.conv1(input)))
        x = self.relu2(self.max_pool2(self.dropout1(self.conv2(x))))

        x = self.flatten1(x)
        x = self.relu3(self.fc1(x))

        # Raw class scores go straight into log-softmax.
        x = self.fc2(x)
        return self.logsoftmax1(x)

### network topology (optional)

In [None]:
# Show the torchkeras summary of a fresh SimpleCNN for a (1, 28, 28) input shape.
Model(SimpleCNN(NB_CLASSES)).summary(input_shape=(1, 28, 28))

### training settings (loss, optim & metrics)

In [None]:
# Build the network and attach what run_step reads from it:
# the loss function, the optimizer and the metric functions.
model = SimpleCNN(NB_CLASSES)
model.loss_fn = nn.CrossEntropyLoss()
model.optim = optim.Adam(model.parameters(), lr=1e-3)
model.metrics_dict = dict(
    precision=precision_metrics,
    accuracy=accuracy_metrics,
)

### training and show history

In [None]:
# Train for 10 epochs, logging every epoch and every 200 steps, then show the history.
dfhistory = train_model(
    model,
    dl_train,
    dl_valid,
    epochs=10,
    log_per_steps=200,
    log_per_epochs=1,
)
dfhistory

### save history and weights

In [None]:
# Write the per-epoch metric history to HISTORY_FILE as CSV.
dfhistory.to_csv(HISTORY_FILE)

# Save only the model's state_dict (not the whole module) to WEIGHT_FILE.
torch.save(model.state_dict(), WEIGHT_FILE)

### load training history (optional)

In [None]:
# Reload the saved history from HISTORY_FILE, using the 'epoch' column as the index.
dfhistory = pd.read_csv(HISTORY_FILE, index_col='epoch')
dfhistory

### re-training

In [None]:
# Rebuild the model with a fresh Adam optimizer and the same loss / metrics.
model = SimpleCNN(NB_CLASSES)
model.loss_fn = nn.CrossEntropyLoss()
model.optim = optim.Adam(model.parameters(), lr=1e-3)
model.metrics_dict = dict(
    precision=precision_metrics,
    accuracy=accuracy_metrics,
)

# Restore the weights saved after the first training run.
weights = torch.load(WEIGHT_FILE)
model.load_state_dict(weights)

# Continue for 2 more epochs; skip_epoch=10 makes the numbering resume at epoch 11.
dfhistory2 = train_model(
    model,
    dl_train,
    dl_valid,
    epochs=2,
    skip_epoch=10,
    log_per_steps=200,
    log_per_epochs=1,
)
dfhistory2

### save re-training history and weights

In [None]:
# Write the re-training metric history to HISTORY2_FILE as CSV.
dfhistory2.to_csv(HISTORY2_FILE)

# Save the re-trained model's state_dict to WEIGHT2_FILE.
torch.save(model.state_dict(), WEIGHT2_FILE)

### metrics plot

In [None]:
# Training vs. validation loss for the re-training run.
plot_metric(dfhistory2, 'loss')

In [None]:
# Training vs. validation precision for the re-training run.
plot_metric(dfhistory2, 'precision')

### prediction

In [None]:
# Take one validation batch; the labels are not needed for prediction.
samples, _ = next(iter(dl_valid))

model = SimpleCNN(NB_CLASSES)

NROWS = 8

# load weights
weights = torch.load(WEIGHT2_FILE)
model.load_state_dict(weights)

# A freshly constructed module is in training mode, where Dropout2d randomly
# zeroes feature maps; switch to evaluation mode so predictions are deterministic.
model.eval()

targets = predict_model(model, samples)
# Print the predicted digits in rows of NROWS, matching the image grid below.
print(targets.numpy().reshape(-1, NROWS))

# Invert Normalize(0.5, 0.5) before displaying the images.
mean = 0.5
std = 0.5
samples = samples * std + mean
plot_images(samples, nrows=NROWS)