In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
from torchvision.transforms import ToTensor

import time
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

%matplotlib inline 

In [None]:
class CNN(nn.Module):
    """VGG-style convolutional classifier for single-channel images.

    The feature extractor has three stages of two 3x3 convolutions
    (128 channels, padding=1).
    - The second convolution of each stage is followed by BatchNorm.
    - The first two stages end with 2x2 max-pooling.
    - The classifier's ``128 * 7 * 7`` input size therefore implies
      28x28 inputs.

    The classifier is a 128-unit hidden layer with dropout, followed by a
    linear layer that emits ``n_classes`` logits.

    Args:
        n_classes: number of output logits (default 10).
    """

    def __init__(self, n_classes=10):
        super().__init__()

        layers = []
        in_channels = 1
        for stage in range(3):
            # conv -> ReLU -> conv -> BN -> ReLU, in exactly the original
            # order so the Sequential indices (state_dict keys) are unchanged.
            layers += [
                nn.Conv2d(in_channels, 128, 3, padding=1),
                nn.ReLU(),
                nn.Conv2d(128, 128, 3, padding=1),
                nn.BatchNorm2d(128),
                nn.ReLU(),
            ]
            in_channels = 128
            # Only the first two stages downsample (28 -> 14 -> 7).
            if stage < 2:
                layers.append(nn.MaxPool2d(2))
        self.features = nn.Sequential(*layers)

        hidden = 128
        self.classifier = nn.Sequential(
            nn.Linear(128 * 7 * 7, hidden),
            nn.ReLU(),
            nn.Dropout(.5),
            nn.Linear(hidden, n_classes),
        )

    def forward(self, data):
        """Return raw class logits of shape (batch, n_classes) for ``data``."""
        feature_map = self.features(data)
        flat = feature_map.view(feature_map.size(0), -1)
        return self.classifier(flat)

In [None]:
class Agent:
    """Train, evaluate, checkpoint and visualise a classifier on a torchvision dataset.

    The dataset class (e.g. ``MNIST``) is instantiated for the train and test
    splits with ``ToTensor()``. It is downloaded into a directory named after
    the class. During training, the best epoch so far is written to
    ``f'{name}.model'``. "Best" means highest test accuracy, then highest
    train accuracy, then lowest train loss.

    Args:
        model: ``nn.Module`` exposing ``features`` and ``classifier``
            submodules (used for the parameter-count report).
        data: dataset class, called as
            ``data(root, train=..., download=True, transform=ToTensor())``
            and expected to expose a ``classes`` list (used by ``predict``).
        lr: initial Adam learning rate; restored at the start of every
            ``train`` call.
        lr_decay: factor the learning rate is multiplied by after each epoch.
        batch_size: mini-batch size for both data loaders.
    """

    def __init__(self, model, data, lr=4e-4, lr_decay=.9, batch_size=64):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.model = model
        self.data = data
        # The dataset class name doubles as download directory and checkpoint stem.
        self.name = data.__name__

        self.lr = lr
        self.lr_decay = lr_decay
        self.loss_fn = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(model.parameters(), lr=lr)

        self.train_data = data(self.name, train=True, download=True, transform=ToTensor())
        self.train_loader = torch.utils.data.DataLoader(self.train_data,
                                                        batch_size=batch_size,
                                                        shuffle=True)

        # Deliberately not shuffled: test() reports misclassified samples by
        # their dataset index, which requires loader order == dataset order.
        self.test_data = data(self.name, train=False, download=True, transform=ToTensor())
        self.test_loader = torch.utils.data.DataLoader(self.test_data,
                                                       batch_size=batch_size)

        self.model.to(self.device)
        self.loss_fn.to(self.device)

        # Checkpoint dict of the best epoch so far; None until a save or load.
        self.best = None

        print('n_train', len(self.train_data))
        print('n_test', len(self.test_data))
        print('n_features', sum(p.numel() for p in self.model.features.parameters()))
        print('n_classifier', sum(p.numel() for p in self.model.classifier.parameters()))

    def _set_lr(self, lr):
        """Set the learning rate of every optimizer parameter group."""
        # Adam reads its rate from param_groups; assigning `optimizer.lr`
        # would only create an unused attribute.
        for group in self.optimizer.param_groups:
            group['lr'] = lr

    def _current_lr(self):
        """Return the learning rate of the first optimizer parameter group."""
        return self.optimizer.param_groups[0]['lr']

    def train(self, epochs=10):
        """Train for ``epochs`` epochs, evaluating and checkpointing after each.

        The learning rate is reset to ``self.lr`` on entry and multiplied by
        ``self.lr_decay`` after every epoch. Each epoch prints the summed batch
        loss, train/test accuracy, the learning rate used, and elapsed time.
        """
        self._set_lr(self.lr)

        timestamp = time.time()
        for epoch in range(epochs):
            self.model.train()

            train_loss = 0.0
            train_correct = 0
            train_cnt = 0

            for images, labels in self.train_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                train_cnt += len(images)

                self.optimizer.zero_grad()
                prediction = self.model(images)
                loss = self.loss_fn(prediction, labels)
                loss.backward()
                self.optimizer.step()

                train_loss += loss.item()
                train_correct += (prediction.argmax(dim=1) == labels).sum().item()

            train_acc = train_correct / train_cnt if train_cnt else 0.0
            test_acc, _ = self.test()

            duration = time.gmtime(time.time() - timestamp)

            print(f'epoch:{epoch+1:3d}',
                  f'   loss:{train_loss:9.3f}',
                  f'   train:{train_acc:.4f}',
                  f'   test:{test_acc:.4f}',
                  f'   lr:{self._current_lr():.6f}',
                  time.strftime('   %X', duration),)

            self.save_model(self.model, test_acc, train_acc, train_loss)

            self._set_lr(self._current_lr() * self.lr_decay)

    def test(self):
        """Evaluate the model on the test split.

        Returns:
            (accuracy, fail): ``accuracy`` is a float in [0, 1] (0.0 for an
            empty test set). ``fail`` lists the test-set indices that were
            misclassified, in ascending order.
        """
        self.model.eval()

        correct = 0
        fail = []
        offset = 0  # dataset index of the first sample in the current batch

        with torch.no_grad():
            for images, labels in self.test_loader:
                images, labels = images.to(self.device), labels.to(self.device)

                prediction = self.model(images)
                success = prediction.argmax(dim=1) == labels

                # Offset by the number of samples already seen rather than
                # batch_idx * len(images): the final batch may be shorter.
                wrong = (~success).nonzero(as_tuple=True)[0] + offset
                fail.extend(wrong.tolist())

                correct += success.sum().item()
                offset += len(images)

        n_test = len(self.test_data)
        accuracy = correct / n_test if n_test else 0.0
        return accuracy, fail

    def save_model(self, model, test_acc, train_acc, train_loss):
        """Snapshot ``model`` to ``f'{self.name}.model'`` if it beats ``self.best``."""
        if self.is_better(test_acc, train_acc, train_loss):
            self.best = {
                # state_dict() tensors share storage with the live parameters;
                # clone them so later training steps cannot overwrite the saved
                # "best", and keep them on CPU for a device-independent file.
                'model': {k: v.detach().cpu().clone()
                          for k, v in model.state_dict().items()},
                'test_acc': float(test_acc),
                'train_acc': float(train_acc),
                'train_loss': float(train_loss),
            }

            torch.save(self.best, f'{self.name}.model')

    def load_model(self, file_name=None):
        """Load a checkpoint written by ``save_model`` into ``self.model``.

        Args:
            file_name: checkpoint path; defaults to ``f'{self.name}.model'``.
        """
        if file_name is None:
            file_name = f'{self.name}.model'

        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        # NOTE: torch.load may unpickle arbitrary objects - only load trusted files.
        self.best = torch.load(file_name, map_location=self.device)
        self.model.load_state_dict(self.best['model'])

    def is_better(self, test_acc, train_acc, train_loss):
        """Return True if these metrics beat the stored best checkpoint.

        Higher test accuracy wins. Ties fall back to higher train accuracy,
        then to lower train loss. Always True when nothing is stored yet.
        """
        if self.best is None:
            return True

        if self.best['test_acc'] != test_acc:
            return test_acc > self.best['test_acc']
        if self.best['train_acc'] != train_acc:
            return train_acc > self.best['train_acc']
        return train_loss < self.best['train_loss']

    def predict(self, n=None):
        """Plot test sample ``n`` beside the model's predicted class probabilities.

        Args:
            n: index into the test split; drawn at random when None.
        """
        if n is None:
            n = np.random.randint(0, len(self.test_data))

        # Index the dataset directly; unsqueeze adds the batch dimension.
        image, label = self.test_data[n]
        classes = self.data.classes

        self.model.eval()
        with torch.no_grad():
            logits = self.model(image.unsqueeze(0).to(self.device))
        predict = torch.softmax(logits, dim=1).cpu().numpy()[0]

        label = f'\n{n}: {classes[int(label)]}\n'
        p_label = f'\n{n}: {classes[int(np.argmax(predict))]}\n'

        plt.figure(figsize=(16,7))

        plt.subplot(1,2,1)
        plt.imshow(image.numpy().squeeze(), cmap='gray')
        plt.title(label, fontsize=36)

        plt.subplot(1,2,2)
        sns.set(style="whitegrid", color_codes=True)
        pal = sns.color_palette("coolwarm_r", len(predict))
        # Colour each bar by the rank of its probability.
        rank = predict.argsort().argsort()
        sns.barplot(x=[classes[i] for i in range(len(predict))],
                    y=predict,
                    palette=np.array(pal[::-1])[rank])
        plt.title(p_label, fontsize=36)
        plt.xticks(rotation=70)
        plt.ylim(0, 1)

        plt.show()

In [None]:
# Seed the RNGs so weight initialisation, training-batch shuffling and the
# random sample picked by `agent.predict()` repeat across Restart & Run All
# (up to nondeterministic GPU kernels).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

model = CNN()
agent = Agent(model, MNIST)

In [None]:
# In this notebook's order this runs before `agent.train`, so it measures the
# untrained baseline.
acc, fail = agent.test()

print(f'Accuracy: {acc:.4f}\n')
# The failure list can hold thousands of indices; show the count and a short
# preview instead of dumping every index into the output.
print(f'Failed predictions: {len(fail)}')
print(f'First failures: {fail[:20]}')

In [None]:
agent.predict(n=340)  # plot test sample #340 next to its predicted class probabilities

In [None]:
agent.train(epochs=40)  # best epoch is checkpointed to f'{agent.name}.model'

In [None]:
agent.load_model(f'{agent.name}.model')  # restore the best checkpoint written during training