In [None]:
# https://www.kaggle.com/code/hojjatk/read-mnist-dataset/notebook
#
# This is a sample Notebook to demonstrate how to read "MNIST Dataset"
#
import numpy as np # linear algebra
import struct
from array import array
import os.path
import matplotlib.pyplot as plt
import torch
import torch.utils.data
from collections import defaultdict
import yaml
import glob


class Data:
    """Pair of input and label tensors plus a ``TensorDataset`` built from them.

    Attributes:
        inputs: tensor of samples, indexed along its first dimension.
        labels: tensor of targets with the same length as ``inputs``.
        dataset: ``torch.utils.data.TensorDataset(inputs, labels)``, ready to be
            handed to a ``DataLoader``.
    """

    def __init__(self, inputs, labels):
        # Lengths are compared first, then both arguments must be tensors.
        label_count = len(labels)
        sample_count = len(inputs)
        assert label_count == sample_count
        for candidate in (inputs, labels):
            assert isinstance(candidate, torch.Tensor)

        self.inputs, self.labels = inputs, labels
        self.dataset = torch.utils.data.TensorDataset(inputs, labels)
#
# MNIST Data Loader Class
#
class MnistDataloader(object):
    """Reads the raw MNIST IDX files found in one directory into ``Data`` objects."""

    def __init__(self, input_path):
        """Remember the paths of the four MNIST files located under ``input_path``."""
        join = os.path.join
        self.training_images_filepath = join(input_path, 'train-images.idx3-ubyte')
        self.training_labels_filepath = join(input_path, 'train-labels.idx1-ubyte')
        self.test_images_filepath = join(input_path, 't10k-images.idx3-ubyte')
        self.test_labels_filepath = join(input_path, 't10k-labels.idx1-ubyte')

    def read_images_labels(self, images_filepath, labels_filepath):
        """Load one image file and its label file.

        Args:
            images_filepath: path of an IDX image file (magic number 2051).
            labels_filepath: path of an IDX label file (magic number 2049).

        Returns:
            ``Data`` whose ``inputs`` is a float tensor of shape
            ``(count, 1, rows, cols)`` with pixel bytes 0..255 mapped linearly
            onto [-1, 1], and whose ``labels`` is an int64 tensor of length
            ``count``.

        Raises:
            ValueError: if a magic number is wrong or the image file holds
                fewer pixel bytes than its header announces.
        """
        with open(labels_filepath, 'rb') as file:
            # Header: two big-endian uint32 values (magic, item count).
            magic, size = struct.unpack(">II", file.read(8))
            if magic != 2049:
                raise ValueError('Magic number mismatch, expected 2049, got {}'.format(magic))
            label_bytes = file.read()

        with open(images_filepath, 'rb') as file:
            # Header: four big-endian uint32 values (magic, count, rows, cols).
            magic, size, rows, cols = struct.unpack(">IIII", file.read(16))
            if magic != 2051:
                raise ValueError('Magic number mismatch, expected 2051, got {}'.format(magic))
            image_bytes = file.read()

        pixel_count = size * rows * cols
        if len(image_bytes) < pixel_count:
            raise ValueError(
                'Image file truncated, expected {} pixel bytes, got {}'.format(
                    pixel_count, len(image_bytes)))

        # Reinterpret the whole payload at once instead of slicing image by
        # image; the image shape comes from the header rather than a literal.
        # frombuffer returns a read-only view, so copy before handing the
        # memory to torch.
        pixels = np.frombuffer(image_bytes, dtype=np.uint8, count=pixel_count)
        pixels = pixels.reshape(size, 1, rows, cols).copy()
        inputs = torch.from_numpy(pixels)
        assert list(inputs.shape)[1:] == [1, rows, cols]

        # Map bytes 0..255 onto [-1, 1]; multiplying by a float promotes the
        # uint8 tensor to the default floating point dtype.
        inputs = (inputs * (2.0/255.0)) - 1.0

        # int64 targets are what nn.CrossEntropyLoss expects for class indices.
        labels = torch.from_numpy(np.frombuffer(label_bytes, dtype=np.uint8).copy()).long()
        assert labels.shape[0] == inputs.shape[0]
        return Data(inputs, labels)

    def load_data(self):
        """Return ``(train_data, test_data)`` read from the configured files."""
        train_data = self.read_images_labels(self.training_images_filepath, self.training_labels_filepath)
        test_data = self.read_images_labels(self.test_images_filepath, self.test_labels_filepath)
        return train_data, test_data


#
# Helper function to show a list of images with their corresponding titles
#
def show_images(inputs, title_texts):
    """Plot images on a grid five columns wide, each with an optional title.

    Args:
        inputs: sized iterable of image tensors; element ``[0]`` of each item
            is the 2-D image that gets drawn.
        title_texts: one title per image; an empty string means no title.
    """
    cols = 5
    # Ceiling division: exactly as many rows as needed (at least one), so a
    # multiple of five images no longer reserves an extra empty row.
    rows = max(1, -(-len(inputs) // cols))
    plt.figure(figsize=(30,20))
    for index, (image_tensor, title_text) in enumerate(zip(inputs, title_texts), start=1):
        # Inverse of the loader's x * (2/255) - 1 scaling: back to 0..255.
        image = (image_tensor[0] + 1.0) * 127.5
        plt.subplot(rows, cols, index)
        plt.imshow(image, cmap=plt.cm.gray)
        if title_text != '':
            plt.title(title_text, fontsize=15)


In [None]:
# Root directory of the experiment: raw MNIST files are read from its 'raw'
# sub-directory and training results are stored in its 'results' sub-directory.
# Set the MNIST_PROJECT_DIR environment variable to use another location; the
# default is the original hard-coded path.
project_dir = os.environ.get("MNIST_PROJECT_DIR", "/home/dking/data/mnist/")
loader = MnistDataloader(os.path.join(project_dir, 'raw'))
train_data, test_data = loader.load_data()

In [None]:
import random

# Preview ten randomly chosen training images together with their labels.
images_2_show = []
titles_2_show = []
for i in range(0, 10):
    # randrange excludes its upper bound, so every index from 0 to the last
    # image can be drawn and none is out of range (randint(1, 60000) could
    # return 60000 and never returned 0).
    r = random.randrange(len(train_data.inputs))
    images_2_show.append(train_data.inputs[r])
    # .item() turns the 0-d label tensor into a plain int for the title.
    titles_2_show.append('training image [' + str(r) + '] = ' + str(train_data.labels[r].item()))

show_images(images_2_show, titles_2_show)

In [None]:
# https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html

import torch
import torch.nn as nn
import torch.nn.functional as F

class MnistNet1(nn.Module):
    """Small convolutional classifier for single-channel 28x28 images.

    Two 5x5 convolutions, each followed by ReLU and 2x2 max-pooling, feed
    three fully connected layers that end in 10 output scores.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Spatial size for a 28x28 input (the tutorial this follows uses 32x32):
        #   conv1: 28 - 4 = 24, pool: 24 / 2 = 12
        #   conv2: 12 - 4 = 8,  pool: 8 / 2 = 4
        # which leaves 16 channels of 4x4 values.
        flat_features = 16 * 4 * 4
        self.fc1 = nn.Linear(flat_features, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return the 10 raw class scores for each image in the batch ``x``."""
        pooled1 = F.max_pool2d(F.relu(self.conv1(x)), kernel_size=2)
        pooled2 = F.max_pool2d(F.relu(self.conv2(pooled1)), kernel_size=2)
        # Keep the batch dimension, flatten everything else.
        flat = pooled2.flatten(start_dim=1)
        hidden1 = F.relu(self.fc1(flat))
        hidden2 = F.relu(self.fc2(hidden1))
        return self.fc3(hidden2)

In [None]:
class Results:
    """Evaluation outcome: the data, the predicted classes and the failing indices.

    Attributes:
        data: object whose ``labels`` tensor holds the expected classes.
        predictions: tensor with the same shape as ``data.labels``.
        failure_idxs: indices at which prediction and label differ.
    """

    def __init__(self, data, predictions, failure_idxs):
        self.data, self.predictions, self.failure_idxs = data, predictions, failure_idxs
        # There must be exactly one prediction per label.
        assert self.predictions.shape == data.labels.shape

    def failure_ratio(self):
        """Return the number of failures divided by the number of predictions."""
        failed = len(self.failure_idxs)
        total = len(self.predictions)
        return failed / total


def compute_results(net, data, batch_size=16):
    """Run ``net`` over ``data`` and record its predictions and mistakes.

    Args:
        net: callable mapping a batch of inputs to per-class scores of shape
            ``(batch, classes)``.
        data: object with ``inputs`` and ``labels`` tensors of equal length.
        batch_size: number of samples evaluated per forward pass.

    Returns:
        ``Results`` holding the predicted class for every sample (int tensor)
        and the ascending list of indices where prediction and label differ.
    """
    failure_idxs = []
    predictions = torch.zeros(data.labels.shape, dtype=torch.int)

    # Evaluate modules in eval mode so mode-dependent layers (dropout, batch
    # norm) behave deterministically, then restore the caller's mode because
    # this function is also called in the middle of training.
    is_module = isinstance(net, torch.nn.Module)
    was_training = is_module and net.training
    if is_module:
        net.eval()
    try:
        with torch.no_grad():
            for start in range(0, len(data.inputs), batch_size):
                stop = start + batch_size
                label = data.labels[start:stop]
                output = net(data.inputs[start:stop])
                # Index of the highest score along the class dimension.
                _, predict = torch.max(output, 1)
                predictions[start:stop] = predict
                # Positions inside the batch where the prediction is wrong,
                # shifted by the batch offset to become data set indices.
                wrong = (predict != label).nonzero(as_tuple=True)[0]
                failure_idxs.extend((wrong + start).tolist())
    finally:
        if is_module:
            net.train(was_training)
    return Results(data, predictions, failure_idxs)


def print_result_summary(results, name):
    """Print ``name`` followed by the error rate in percent with one decimal.

    The error rate is the number of failure indices divided by the number of
    labels in ``results.data``.
    """
    failed = len(results.failure_idxs)
    total = len(results.data.labels)
    error_percent = failed / total * 100.0
    print(f"{name} : error {error_percent:0.1f}%")


def print_result_breakdown(results, name):
    """Print the summary line, then the ten most frequent label/prediction mix-ups.

    Rows are ordered by descending (count, predicted, label); nothing beyond
    the summary line is printed when there are no failures.
    """
    print_result_summary(results, name)

    # Tally how often each (predicted, expected) pair occurs among the failures.
    confusion = defaultdict(int)
    for failure_idx in results.failure_idxs:
        predicted = results.predictions[failure_idx].item()
        expected = results.data.labels[failure_idx].item()
        confusion[(predicted, expected)] += 1

    if not confusion:
        return

    print("  LABEL  PREDICT  :  COUNT")
    ranked = [(count, predicted, expected) for (predicted, expected), count in confusion.items()]
    ranked.sort(reverse=True)
    for count, predicted, expected in ranked[:10]:
        print(f"  {expected:3d}      {predicted:3d}    :  {count}")


def show_failures(results, count=10):
    """Display up to ``count`` randomly chosen misclassified images.

    Each title shows the sample index, the expected label and the predicted
    class. Nothing is drawn when there are no failures.

    Args:
        results: ``Results`` whose ``failure_idxs`` are sampled.
        count: maximum number of failures to show.
    """
    # random.sample raises ValueError when asked for more items than exist,
    # so never request more than the number of recorded failures.
    count = min(count, len(results.failure_idxs))
    if count <= 0:
        return
    idxs = random.sample(results.failure_idxs, count)
    inputs = results.data.inputs[idxs]
    titles = [f'{i}, expect {results.data.labels[i].item()} got {results.predictions[i].item()}' for i in idxs]
    show_images(inputs, titles)


In [None]:
import torch.optim as optim
import time


def train(net_name, opt_name, batch_size=16, epoch_count=30):
    """Train a network on the global ``train_data`` and save per-epoch metrics.

    After every epoch the network is evaluated on the global ``train_data``
    and ``test_data``. Average loss, failure ratios and the total duration are
    written to ``<project_dir>/results/training_<net_name>_<opt_name>.yaml``.

    Args:
        net_name: network to build; only ``"net1"`` is known.
        opt_name: optimizer configuration: ``'sgd1'``, ``'sgd2'``, ``'sgdN'``,
            ``'adam'`` or ``'rmsprop'``.
        batch_size: mini-batch size of the shuffled training loader.
        epoch_count: number of passes over the training data.

    Returns:
        ``(train_results, test_results)`` for the network after training.

    Raises:
        RuntimeError: if ``net_name`` or ``opt_name`` is not recognised.
    """
    # Fixed seed so weight initialisation and shuffling are repeatable.
    torch.manual_seed(0)
    if net_name == "net1":
        net = MnistNet1()
    else:
        raise RuntimeError("invalid net name")

    print(net)
    print("Total Param Count", sum(p.numel() for p in net.parameters()))

    if opt_name == 'sgd1':
        optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.0)
    elif opt_name == 'sgd2':
        optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    elif opt_name == 'sgdN':
        optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, nesterov=True)
    elif opt_name == 'adam':
        optimizer = optim.Adam(net.parameters())
    elif opt_name == 'rmsprop':
        optimizer = optim.RMSprop(net.parameters(), lr=0.001)
    else:
        raise RuntimeError("invalid opt name")

    print(optimizer)

    # Create the output directory before training starts, so a missing or
    # unwritable location is detected before the work is done, not after.
    results_dir = os.path.join(project_dir, 'results')
    os.makedirs(results_dir, exist_ok=True)
    fn = os.path.join(results_dir, f'training_{net_name}_{opt_name}.yaml')

    criterion = nn.CrossEntropyLoss()

    data_loader = torch.utils.data.DataLoader(train_data.dataset, batch_size=batch_size, shuffle=True)

    train_failures = []
    test_failures = []
    losses = []
    train_results = None
    test_results = None

    start = time.time()
    for epoch in range(epoch_count):
        print(f"EPOCH {epoch:02d}")
        running_loss = []

        for input, target in data_loader:
            optimizer.zero_grad()
            output = net(input)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            running_loss.append(loss.item())

        avg_loss = float(np.array(running_loss).mean())
        print(f"  LOSS {avg_loss:0.3f}")
        test_results = compute_results(net, test_data)
        train_results = compute_results(net, train_data)
        print_result_summary(train_results, "  TRAIN")
        print_result_summary(test_results, "  TEST")
        train_failures.append(train_results.failure_ratio())
        test_failures.append(test_results.failure_ratio())
        losses.append(avg_loss)
    # Note: the measured duration includes the per-epoch evaluations.
    training_duration = time.time() - start

    print("Saving results to", fn)
    with open(fn, 'w') as fd:
        yaml.safe_dump({'net_name': net_name,
                        'opt_name': opt_name,
                        'losses': losses,
                        'train_failures': train_failures,
                        'test_failures': test_failures,
                        'training_duration': training_duration,
                        },
                       stream=fd)

    # The evaluation from the last epoch already describes the final network;
    # only evaluate here when no epoch ran at all.
    if train_results is None:
        train_results = compute_results(net, train_data)
        test_results = compute_results(net, test_data)
    return (train_results, test_results)


In [None]:
# Train the same network once per optimizer configuration; each run writes its
# own results file. ('rmsprop' is also accepted by train() but not run here.)
optimizer_names = ('sgd1', 'sgd2', 'sgdN', 'adam')
for opt_name in optimizer_names:
    train('net1', opt_name)

In [None]:
import glob

# Compare all saved training runs: one curve per results file in each of three
# stacked panels (log loss, log train failure ratio, log test failure ratio),
# followed by a table of training durations.
fns = sorted(glob.glob(os.path.join(project_dir, 'results', 'training*.yaml')))
print(fns)

training_durations = {}

# (key in the results file, y-axis label) for the three panels, top to bottom.
# The stored failure values are ratios, not counts.
panels = (
    ('losses', 'log loss'),
    ('train_failures', 'log train failure ratio'),
    ('test_failures', 'log test failure ratio'),
)

fig = plt.figure("training", figsize=(20,15))
axes = fig.subplots(3, 1)
for ax, (_, ylabel) in zip(axes, panels):
    ax.set_ylabel(ylabel)

for fn in fns:
    with open(fn, 'r') as fd:
        y = yaml.safe_load(fd)
    label = f"{y['net_name']},{y['opt_name']}"
    training_durations[label] = y['training_duration']
    for ax, (key, _) in zip(axes, panels):
        values = np.asarray(y[key], dtype=float)
        # log(0) is undefined (a run can reach zero failures); plot such
        # points as gaps instead of -inf with a runtime warning.
        ax.plot(np.log(np.where(values > 0, values, np.nan)), '.-', label=label)

# Build each legend once, after all curves have been added.
if fns:
    for ax in axes:
        ax.legend(loc='best')
plt.show(block=False)

print("OPTIMIZER  : DURATION")
for label, duration in training_durations.items():
    print(f"{label:<10} : {duration:.1f}")
