In [0]:
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms

In [2]:
from torchvision.datasets import MNIST

# MNIST training split, downloaded into ./data on first use. ToTensor() turns each
# image into a float tensor so it can be fed to the network directly.
dataset = MNIST('data', train=True, download=True, transform=transforms.ToTensor())

# shuffle=True re-draws the batch order every epoch; with the default (shuffle=False)
# the optimizer would see exactly the same sequence of mini-batches in every epoch.
loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw
Processing...
Done!


In [0]:
class MNISTNet(nn.Module):
    """Stack of 3x3 conv + ReLU stages followed by one linear classifier.

    With the default arguments the network is: one conv (1 -> 16 channels) + ReLU,
    four conv (16 -> 16 channels) + ReLU stages, and a linear layer mapping the
    flattened 28*28*16 feature map to 10 logits. Every convolution uses a 3x3
    kernel with padding=1, so the spatial size is unchanged through the conv stack.

    Parameter names and creation order do not depend on the arguments' values, so a
    state_dict produced by the default configuration keeps the same keys.

    Args:
        in_channels: number of channels of the input images.
        num_classes: number of output logits.
        width: number of feature channels produced by every conv layer.
        depth: number of conv + ReLU stages after the first layer.
        image_size: height and width of the (square) input images; determines the
            input size of the final linear layer.
    """

    def __init__(self, in_channels=1, num_classes=10, width=16, depth=4, image_size=28):
        super().__init__()
        self.first_layer = nn.Sequential(nn.Conv2d(in_channels, width, 3, padding=1), nn.ReLU())
        self.conv_layers = nn.Sequential(
            *[nn.Sequential(nn.Conv2d(width, width, 3, padding=1), nn.ReLU()) for _ in range(depth)]
        )
        self.last_layer = nn.Linear(image_size * image_size * width, num_classes)

    def forward(self, x):
        """Return the class logits for a batch of images.

        Args:
            x: tensor of shape (batch, in_channels, image_size, image_size). A single
                unbatched image of shape (in_channels, image_size, image_size) is
                treated as a batch of one.

        Returns:
            Tensor of shape (batch, num_classes) holding unnormalised logits.

        Raises:
            RuntimeError: if the spatial size of ``x`` does not match ``image_size``
                (raised by the final linear layer).
        """
        if x.dim() == 3:
            # Promote a single image to a batch of one so dim 0 is always the batch.
            x = x.unsqueeze(0)
        x = self.first_layer(x)
        x = self.conv_layers(x)
        # Flatten everything except the batch dimension. Unlike reshape(-1, features),
        # this can never fold a size mismatch into the batch dimension; a wrongly
        # sized input fails loudly in the linear layer instead.
        x = x.flatten(start_dim=1)
        x = self.last_layer(x)
        return x
        

In [6]:
# Training hyper-parameters.
NUM_EPOCHS = 10
LEARNING_RATE = 1e-3

net = MNISTNet()

optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)
loss_fun = nn.CrossEntropyLoss()

# Per-epoch metrics are normalised by the dataset size, so it must be non-zero.
num_elements = len(loader.dataset)
if num_elements == 0:
    raise ValueError("cannot train on an empty dataset")

net.train()
for e in range(NUM_EPOCHS):
    # Running totals for this epoch, kept as plain Python numbers.
    correct = 0
    loss_sum = 0.
    for im, label in loader:
        optimizer.zero_grad()

        output = net(im)
        batch_loss = loss_fun(output, label)

        batch_loss.backward()
        optimizer.step()

        # The loss is a per-batch mean; weight it by the batch size so the epoch
        # average stays correct even when the last batch is smaller.
        loss_sum += batch_loss.item() * im.shape[0]

        # Predicted class = index of the largest logit.
        _, pred = output.detach().max(1)
        correct += (pred == label).sum().item()

    acc = correct / num_elements
    loss = loss_sum / num_elements

    print("Epoch {:d}: \nLoss: {:.2e}\nAccuracy: {:.2%}\n\n".format(e+1, loss, acc))

# Persist only the learned parameters (not the Python class).
torch.save(net.state_dict(), 'mnist_net.model')

Epoch 1: 
Loss: 1.57e-01
Accuracy: 95.33%


Epoch 2: 
Loss: 5.88e-02
Accuracy: 98.25%


Epoch 3: 
Loss: 3.84e-02
Accuracy: 98.87%


Epoch 4: 
Loss: 2.77e-02
Accuracy: 99.14%


Epoch 5: 
Loss: 2.11e-02
Accuracy: 99.33%


Epoch 6: 
Loss: 1.71e-02
Accuracy: 99.44%


Epoch 7: 
Loss: 1.28e-02
Accuracy: 99.60%


Epoch 8: 
Loss: 1.20e-02
Accuracy: 99.64%


Epoch 9: 
Loss: 1.01e-02
Accuracy: 99.63%


Epoch 10: 
Loss: 9.13e-03
Accuracy: 99.66%


