In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import numpy as np

import matplotlib.pyplot as plt

In [None]:
import urllib.request

# Base location of the MNIST IDX archives.
_MNIST_URL = "http://yann.lecun.com/exdb/mnist/"

# Each split is an (image_url, label_url) pair, as unpacked by `dl_mnist`.
training = (
    _MNIST_URL + "train-images-idx3-ubyte.gz",
    _MNIST_URL + "train-labels-idx1-ubyte.gz",
)
testing = tuple(
    _MNIST_URL + name
    for name in ("t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz")
)

def dl_mnist(ds, prefix, overwrite=False):
    """Download one MNIST split (image file + label file) into the working directory.

    The files are written as ``./{prefix}-img`` and ``./{prefix}-lab`` exactly as
    served (still gzip-compressed). If both already exist the download is skipped,
    so re-running the notebook does not re-fetch the data.

    Args:
        ds: ``(image_url, label_url)`` pair, e.g. ``training`` or ``testing``.
        prefix: filename prefix, e.g. ``"train"`` or ``"test"``.
        overwrite: download even when both files are already present.
    """
    import os

    img_url, lab_url = ds
    targets = ((img_url, f'./{prefix}-img'), (lab_url, f'./{prefix}-lab'))
    if not overwrite and all(os.path.exists(path) for _, path in targets):
        print(f"{prefix}: already downloaded, skipping")
        return

    print(f"Downloading {prefix}")
    for url, path in targets:
        # Download to a temporary name and rename only on success, so an interrupted
        # download never leaves a truncated file that later looks like a cached copy.
        tmp_path = path + ".part"
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, path)
    print("Done")

In [None]:
dl_mnist(prefix="train", ds=training)

In [None]:
import struct
import gzip

class MnistDataset(Dataset):
    """MNIST split read from a pair of gzip-compressed IDX files.

    Each item is a dict ``{'image': float32 ndarray (rows, cols), 'label': int64}``.
    If ``transform`` is given, it is applied to that dict before it is returned.

    Args:
        imgfile: path to the gzipped IDX3 image file.
        labelfile: path to the gzipped IDX1 label file.
        transform: optional callable that takes and returns a sample dict.
    """

    # Magic numbers that open IDX image / label files.
    IMAGE_MAGIC = 2051
    LABEL_MAGIC = 2049

    def __init__(self, imgfile, labelfile, transform=None):
        self.imgfile = imgfile
        self.labelfile = labelfile
        self.images, self.labels = self.read_dataset(imgfile, labelfile)
        self.transform = transform

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, idx):
        # Samplers may hand over tensor indices; numpy indexing wants plain ints/lists.
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample = {'image': self.images[idx], 'label': self.labels[idx]}

        if self.transform:
            sample = self.transform(sample)
        return sample

    def read_dataset(self, imgname, labname):
        """Parse the IDX files and return ``(images, labels)``.

        Returns:
            images: float32 array of shape ``(n, rows, cols)`` holding raw 0-255 pixels;
                ``rows``/``cols`` come from the file header.
            labels: int64 array of shape ``(n,)``.

        Raises:
            ValueError: wrong magic number, image/label count mismatch, or truncated data.
        """
        with gzip.open(imgname, "rb") as img_f, gzip.open(labname, "rb") as lab_f:
            img_magic, n_images, rows, cols = struct.unpack(">4i", img_f.read(16))
            lab_magic, n_labels = struct.unpack(">2i", lab_f.read(8))
            if img_magic != self.IMAGE_MAGIC:
                raise ValueError(f"{imgname}: not an IDX image file (magic {img_magic})")
            if lab_magic != self.LABEL_MAGIC:
                raise ValueError(f"{labname}: not an IDX label file (magic {lab_magic})")
            if n_images != n_labels:
                raise ValueError(f"{n_images} images but {n_labels} labels")

            # Bulk-read the uint8 payloads instead of unpacking image by image.
            n_pixels = n_images * rows * cols
            pixels = np.frombuffer(img_f.read(n_pixels), dtype=np.uint8)
            labels = np.frombuffer(lab_f.read(n_labels), dtype=np.uint8)

        if pixels.size != n_pixels or labels.size != n_labels:
            raise ValueError("IDX data is truncated")
        # astype copies, so the returned arrays are writable (frombuffer's are not).
        images = pixels.reshape(n_images, rows, cols).astype(np.float32)
        return images, labels.astype(np.int64)

In [None]:
class normalize:
    """Transform that divides the sample's image by 255, mapping 0-255 pixels to [0, 1].

    The label is passed through untouched; a new dict is returned.
    """

    def __call__(self, sample):
        scaled = sample['image'] / 255
        return dict(image=scaled, label=sample['label'])

class ToTensor:
    """Transform that wraps the sample's numpy image as a torch tensor.

    Uses ``torch.from_numpy``, so the tensor shares memory with the array.
    The label is passed through unchanged.
    """

    def __call__(self, sample):
        img = sample['image']
        lbl = sample['label']
        return dict(image=torch.from_numpy(img), label=lbl)
    
# NOTE(review): `normalize()` is defined in this cell but not part of the pipeline,
# so images keep their raw 0-255 values; if it is added, apply the same pipeline
# to the test dataset too.
compose = transforms.Compose(transforms=[ToTensor()])

In [None]:
def print_data(x):
    """Display one dataset sample's image, with its label as the x-axis caption.

    Args:
        x: sample dict with ``'image'`` (2-D tensor or ndarray) and ``'label'``
           (tensor, numpy scalar, or int). Works for samples with or without the
           ``ToTensor`` transform applied.
    """
    # np.asarray accepts both CPU tensors and ndarrays, unlike `.numpy()`.
    img = np.asarray(x['image'])
    label = np.asarray(x['label']).item()
    plt.imshow(img, cmap='Greys')
    plt.xlabel(f"Label: {label}")

In [None]:
ds = MnistDataset(imgfile="train-img", labelfile="train-lab", transform=compose)

In [None]:
print_data(x=ds[5])

In [None]:
dataloader = DataLoader(dataset=ds, shuffle=True, batch_size=1000)

In [None]:
dl_mnist(prefix="test", ds=testing)

In [None]:
# Use the same transform pipeline as the training set, so train and test samples
# are preprocessed identically.
test = MnistDataset("test-img", "test-lab", transform=compose)

In [None]:
# Accuracy does not depend on sample order, so there is no need to shuffle; large
# batches keep each of the many evaluation passes during training cheap.
test_dataloader = DataLoader(test, shuffle=False, batch_size=1000)

In [None]:
class Net(nn.Module):
    """Two-layer MLP classifier: 28*28 inputs -> 120 sigmoid units -> 10 classes.

    ``forward`` returns raw, unnormalised logits of shape ``(batch, 10)``.
    ``nn.CrossEntropyLoss`` applies log-softmax itself, so adding a softmax here
    would normalise twice and shrink the gradients. Apply ``F.softmax`` outside
    the model when probabilities are needed.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 120)
        self.fc2 = nn.Linear(120, 10)

    def forward(self, x):
        # Flatten each image; reshape also tolerates non-contiguous inputs.
        x = x.reshape(-1, 28 * 28)
        x = torch.sigmoid(self.fc1(x))
        return self.fc2(x)

net = Net()

In [None]:
import torch.optim as optim

# SGD with momentum over the MLP's parameters, scored with cross-entropy.
optimizer = optim.SGD(params=net.parameters(), momentum=0.9, lr=0.1)
criterion = nn.CrossEntropyLoss()

In [None]:
def accuracy(dataloader, model=None):
    """Return the fraction of samples in ``dataloader`` that ``model`` classifies correctly.

    Args:
        dataloader: iterable of batches; each is a dict with ``'image'`` and
            ``'label'`` tensors.
        model: module to evaluate. Defaults to the global ``net`` (the MLP).

    Returns:
        Accuracy in ``[0, 1]`` (argmax prediction vs. label); 0.0 if the loader is empty.

    The model runs in eval mode under ``torch.no_grad()``. Its previous
    train/eval mode is restored afterwards, even if an error is raised.
    """
    if model is None:
        model = net
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for data in dataloader:
                images, labels = data['image'], data['label']
                preds = torch.argmax(model(images), dim=1)
                total += labels.shape[0]
                correct += (preds == labels).sum().item()
    finally:
        model.train(was_training)
    acc = correct / total if total else 0.0
    print("Accuracy:", acc)
    return acc

In [None]:
epoch = 10  # number of passes over the training set

mlp_accuracy = []  # test accuracy sampled during training, plotted below
for e in range(epoch):
    net.train()
    running_loss = 0.0
    n_batches = 0
    for i, data in enumerate(dataloader):
        images, labels = data['image'], data['label']
        optimizer.zero_grad()

        out = net(images)
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1
        # Periodically sample test accuracy to build the learning curve.
        if i % 50 == 0:
            mlp_accuracy.append(accuracy(test_dataloader))
    if n_batches:
        print(f"[epoch {e + 1}] mean training loss: {running_loss / n_batches:.4f}")
print("Training complete")

In [None]:
SAMPLE_IDX = 112  # arbitrary training sample to inspect
data = ds[SAMPLE_IDX]

In [None]:
print_data(x=data)

In [None]:
# Call the module itself (not `.forward`) so hooks run, and skip autograd for inspection.
with torch.no_grad():
    prediction = net(data['image'])
prediction

In [None]:
class CNN(nn.Module):
    """Minimal CNN: one 3x3 conv (1 -> 3 channels) + ReLU, then a linear layer to 10 classes.

    ``forward`` accepts ``(batch, 28, 28)``, ``(batch, 1, 28, 28)`` or a single
    ``(28, 28)`` image. It returns raw logits of shape ``(batch, 10)``, suitable
    for ``nn.CrossEntropyLoss``.
    """

    def __init__(self):
        super(CNN, self).__init__()
        # padding=1 keeps the 3x3 conv output at 28x28, matching fc's 3*28*28 inputs.
        self.conv = nn.Conv2d(1, 3, 3, padding=1)
        self.fc = nn.Linear(3 * 28 * 28, 10)

    def forward(self, x):
        # Conv2d needs an explicit channel dimension; dataset batches have none.
        x = x.reshape(-1, 1, 28, 28)
        x = F.relu(self.conv(x))
        x = x.reshape(-1, 3 * 28 * 28)
        return self.fc(x)

cnn = CNN()

In [None]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
# Must be built over the CNN's parameters; `net.parameters()` would keep
# training the MLP while the CNN's weights never change.
optimizer = optim.SGD(cnn.parameters(), lr=0.1, momentum=0.9)

In [None]:
epoch = 10  # number of passes over the training set

# Optimizer bound to the CNN's own parameters, so this cell always trains `cnn`
# whatever model the shared `optimizer` global was created for.
cnn_optimizer = optim.SGD(cnn.parameters(), lr=0.1, momentum=0.9)


def _test_accuracy(model, loader):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly.

    Runs under ``torch.no_grad()`` in eval mode and restores the model's previous
    train/eval mode afterwards. Returns 0.0 for an empty loader.
    """
    was_training = model.training
    model.eval()
    correct = total = 0
    try:
        with torch.no_grad():
            for batch in loader:
                preds = torch.argmax(model(batch['image']), dim=1)
                correct += (preds == batch['label']).sum().item()
                total += batch['label'].shape[0]
    finally:
        model.train(was_training)
    return correct / total if total else 0.0


cnn_accuracy = []  # test accuracy sampled during training, plotted below
for e in range(epoch):
    cnn.train()
    running_loss = 0.0
    n_batches = 0
    for i, data in enumerate(dataloader):
        images, labels = data['image'], data['label']
        cnn_optimizer.zero_grad()

        out = cnn(images)
        loss = criterion(out, labels)
        loss.backward()
        cnn_optimizer.step()

        running_loss += loss.item()
        n_batches += 1
        # Periodically sample test accuracy to build the learning curve.
        if i % 50 == 0:
            acc = _test_accuracy(cnn, test_dataloader)
            print("CNN test accuracy:", acc)
            cnn_accuracy.append(acc)
    if n_batches:
        print(f"[epoch {e + 1}] mean training loss: {running_loss / n_batches:.4f}")
print("Training complete")

In [None]:
from bokeh.plotting import show, figure

In [None]:
p = figure(
    title="MNIST test accuracy during training",
    tools="wheel_zoom,pan,hover,reset",
    x_axis_label="Evaluation index",
    # Values are fractions in [0, 1], not percentages.
    y_axis_label="Test accuracy (fraction correct)",
)
# `legend_label` replaces the deprecated/removed `legend=` keyword.
p.line(np.arange(len(mlp_accuracy)), mlp_accuracy, line_color="blue", line_width=3, legend_label="MLP")
p.line(np.arange(len(cnn_accuracy)), cnn_accuracy, line_color="red", line_width=3, legend_label="CNN")
p.legend.location = "top_left"
p.legend.click_policy = "hide"
show(p)