see https://github.com/pytorch/examples/blob/master/mnist/main.py for an example

Things to figure out:
* weight initializations (use https://pytorch.org/docs/master/nn.html#torch-nn-init)
* live graphing of loss

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.tensor as T
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

import matplotlib
import matplotlib.pyplot as plt

import numpy as np

In [3]:
# import MNIST dataset
mnist_train = torchvision.datasets.MNIST(root='./mnist_data', train=True, transform=transforms.ToTensor(), download=True)
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=4, shuffle=True)

mnist_test = torchvision.datasets.MNIST(root='./mnist_data', train=False, transform=transforms.ToTensor(), download=True)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=4, shuffle=True)

## Torch.nn methods
Convolution layers https://pytorch.org/docs/stable/nn.html?#convolution-layers

Pooling layers https://pytorch.org/docs/stable/nn.html?#pooling-layers

Padding layers https://pytorch.org/docs/stable/nn.html?#padding-layers

Non-linear activations https://pytorch.org/docs/stable/nn.html?#non-linear-activations-weighted-sum-nonlinearity

Normalization layers https://pytorch.org/docs/stable/nn.html?#normalization-layers

Recurrent layers https://pytorch.org/docs/stable/nn.html?#normalization-layers

Linear layers https://pytorch.org/docs/stable/nn.html?#linear-layers

Dropout layers https://pytorch.org/docs/stable/nn.html?#dropout-layers

Loss functions https://pytorch.org/docs/stable/nn.html?#loss-functions

Vision layers https://pytorch.org/docs/stable/nn.html?#vision-layers



In [79]:
# define network architecture
class mnist_net(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2))
        
        self.layer2 = nn.Sequential(
            nn.Conv2d(10, 20, kernel_size=7, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2))
        
        self.fc1 = nn.Linear(6*6*20, 121)
        self.fc2 = nn.Linear(121, 10)
        
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(out.size(0), -1)
        out = self.fc1(out)
        return self.fc2(out)

In [83]:
# train network
n_epochs = 5

net = mnist_net()

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

for e in range(n_epochs):
    print('Training epoch: ' + str(e+1))
    L = 0
    for n, (im, im_label) in enumerate(train_loader):        
        optimizer.zero_grad()
        
        # forward pass through network
        output = net(im)

        loss = criterion(output, im_label)

        loss.backward()

        optimizer.step()
        
        L += loss.item()
        if(n%2000 == 1999):
            print('loss: %f' %(L/2000))
            L = 0

print('Done training.')

Training epoch: 1
loss: 0.784215
loss: 0.188390
loss: 0.163224
loss: 0.125645
loss: 0.106898
loss: 0.088044
loss: 0.097614
Training epoch: 2
loss: 0.074930
loss: 0.087394
loss: 0.070652
loss: 0.068146
loss: 0.062035
loss: 0.067421
loss: 0.067272
Training epoch: 3
loss: 0.051382
loss: 0.048713
loss: 0.060203
loss: 0.046884
loss: 0.058429
loss: 0.045207
loss: 0.047194
Training epoch: 4
loss: 0.037979
loss: 0.040030
loss: 0.036117
loss: 0.044277
loss: 0.040929
loss: 0.045510
loss: 0.041777
Training epoch: 5
loss: 0.034922
loss: 0.034496
loss: 0.029339
loss: 0.035563
loss: 0.039827
loss: 0.027901
loss: 0.036755
Training epoch: 6
loss: 0.027691
loss: 0.030956
loss: 0.033826
loss: 0.026702
loss: 0.026709
loss: 0.027927
loss: 0.036674
Training epoch: 7
loss: 0.022913
loss: 0.029878
loss: 0.025263
loss: 0.027349
loss: 0.026655
loss: 0.024728
loss: 0.026758
Training epoch: 8
loss: 0.021379
loss: 0.015413
loss: 0.020381
loss: 0.020706
loss: 0.022676
loss: 0.030986
loss: 0.023343
Training epoch: 

In [84]:
# evaluate performance on test set
n_correct = 0
n_total = 0
with torch.no_grad():
    for im, im_label in test_loader:
        # predict using trained net
        out = net(im)
        
        # check correct predictions
        n_correct += torch.sum(im_label == out.max(dim=1)[1])
        n_total += 4.0
        
print('Correctly classified %.1f%% of digits' %(100*n_correct.item()/n_total))

Correctly classified 99.2% of digits
