In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
from matplotlib import pyplot as plt

import torch
from torch import nn, optim
from torchvision import datasets, transforms, utils
from torch.utils import data
from torchkeras import summary, Model
from sklearn.metrics import precision_score
import pandas as pd
import os

In [12]:
MNIST_PATH = os.path.join('..', 'data')
HISTORY_FILE = os.path.join(MNIST_PATH, 'MNIST', 'mnist_history.csv')
WEIGHT_FILE = os.path.join(MNIST_PATH, 'MNIST', 'mnist_weight.pkl')

NB_CLASSES = 10

data_tf = transforms.Compose([
    transforms.ToTensor(),  # 0~255 -> 0~1
    transforms.Normalize((0.5, ), (0.5, ))  # 0~1 -> -1~1
])

ds_train = datasets.MNIST(MNIST_PATH, train=True, transform=data_tf, download=True)
ds_valid = datasets.MNIST(MNIST_PATH, train=False, transform=data_tf, download=True)

dl_train = data.DataLoader(ds_train, batch_size=32, shuffle=True)
dl_valid = data.DataLoader(ds_valid, batch_size=128, shuffle=True)

In [None]:
# batch sample plot
batch_images, batch_labels = next(iter(dl_train))
grid_image = utils.make_grid(batch_images, nrow=4).numpy().transpose(1, 2, 0)

mean = 0.5
std = 0.5
grid_image = grid_image * std + mean

# import numpy as np
# grid_image = np.array([
#     [
#         [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
#     ], [
#         [1.0, 1.0, 1.0], [0.0, 0.0, 0.0]
#     ]
# ])
plt.figure(figsize=(2, 2))
plt.imshow(grid_image)
plt.xticks([])
plt.yticks([])
plt.show()

In [18]:
class SimpleCNN(Model):
    def __init__(self, nb_classes=10, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv1 = nn.Conv2d(1, 10, 5)
        self.max_pool1 = nn.MaxPool2d(2)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv2d(10, 20, 5)
        self.dropout1 = nn.Dropout2d()
        self.max_pool2 = nn.MaxPool2d(2)
        self.relu2 = nn.ReLU()
        
        self.flatten1 = nn.Flatten()

        self.fc1 = nn.Linear(320, 50)
        self.relu3 = nn.ReLU()

        self.fc2 = nn.Linear(50, nb_classes)
        self.relu4 = nn.ReLU()

        self.logsoftmax1 = nn.LogSoftmax(1)

    def forward(self, input):
        input = self.conv1(input)
        input = self.max_pool1(input)
        input = self.relu1(input)

        input = self.conv2(input)
        input = self.dropout1(input)
        input = self.max_pool2(input)
        input = self.relu2(input)
        
        input = self.flatten1(input)
        input = self.fc1(input)
        input = self.relu3(input)

        input = self.fc2(input)
        input = self.relu4(input)

        input = self.logsoftmax1(input)

        return input

In [None]:
model = SimpleCNN(NB_CLASSES)
model.summary(input_shape=(1, 28, 28))

In [None]:
def precision_metrics(y_pred, y_true):
    y_pred = y_pred.data.max(1)[1]
    score = precision_score(y_true.numpy(), y_pred.numpy(), average='macro')
    return torch.tensor(score)

loss_fn = nn.CrossEntropyLoss()
opt = optim.Adam(model.parameters(), lr=1e-3)
metrics_dict = {
    'precision': precision_metrics
}

model.compile(loss_fn, opt, metrics_dict)

In [None]:
# training
dfhistory = model.fit(20, dl_train, dl_valid, 500)

### save history and weights

In [None]:
# save training history
dfhistory.to_csv(HISTORY_FILE)

# save weights
torch.save(model.state_dict(), WEIGHT_FILE)

### load training history

In [None]:
# load training history
dfhistory = pd.read_csv('../data/MNIST/history.csv', index_col=0)

In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

import matplotlib.pyplot as plt

def plot_metric(dfhistory, metric):
    train_metrics = dfhistory[metric]
    val_metrics = dfhistory['val_'+metric]
    epochs = range(1, len(train_metrics) + 1)
    plt.plot(epochs, train_metrics, 'bo--')
    plt.plot(epochs, val_metrics, 'ro-')
    plt.title('Training and validation '+ metric)
    plt.xlabel("Epochs")
    plt.ylabel(metric)
    plt.legend(["train_"+metric, 'val_'+metric])
    plt.show()

In [None]:
plot_metric(dfhistory, 'precision')

In [None]:
plot_metric(dfhistory, 'loss')

### load weights

In [19]:
# load weights
weights = torch.load(WEIGHT_FILE)

model = SimpleCNN(NB_CLASSES)
model.load_state_dict(weights)

<All keys matched successfully>

In [23]:
model.evaluate(dl_valid)

ModuleAttributeError: 'SimpleCNN' object has no attribute 'device'