# Quick MNIST

*by Jeremy Howard, fast.ai.*

In this notebook we'll see how easy it is to use the simple functions we created along with `torch.nn` and friends to train MNIST. We're going to create the same convolution net that we created at the end of the previous notebook, but this time we'll use the functions we've already written, and we'll skip the explanatory text so you can see just the final code.

You can use this same notebook to train other neural nets on other datasets with minimal changes.

In [2]:
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn,optim,tensor
from torch.utils.data import TensorDataset, DataLoader
import pickle, gzip


def loss_batch(model, loss_func, xb, yb, opt=None):
    loss = loss_func(model(xb), yb)

    if opt is not None:
        loss.backward()
        opt.step()
        opt.zero_grad()

    return loss.item(), len(xb)


def fit(epochs, model, loss_func, opt, train_dl, valid_dl):
    for epoch in range(epochs):
        model.train()
        for xb,yb in train_dl: loss_batch(model, loss_func, xb, yb, opt)

        model.eval()
        with torch.no_grad():
            losses,nums = zip(*[loss_batch(model, loss_func, xb, yb)
                                for xb,yb in valid_dl])
        val_loss = np.sum(np.multiply(losses,nums)) / np.sum(nums)

        print(epoch, val_loss)


class Lambda(nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func=func

    def forward(self, x): return self.func(x)


class WrappedDataLoader():
    def __init__(self, dl, func):
        self.dl = dl
        self.func = func

    def __len__(self): return len(self.dl)

    def __iter__(self):
        batches = iter(self.dl)
        for b in batches: yield(self.func(*b))



In [4]:


with gzip.open('data/mnist.pkl.gz', 'rb') as f:
    ((train_x, train_y), (valid_x, valid_y), _) = pickle.load(f, encoding='latin-1')

In [6]:
bs=64
lr=0.1
epochs=20

dev = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [7]:
def preprocess(x,y): return x.view(-1,1,28,28).to(dev),y.to(dev)

def get_dataloader(x,y,bs,shuffle):
    ds = TensorDataset(*map(tensor, (x,y)))
    dl = DataLoader(ds, batch_size=bs, shuffle=shuffle)
    return WrappedDataLoader(dl, preprocess)

In [9]:
train_dl = get_dataloader(train_x, train_y, bs,   shuffle=False)
valid_dl = get_dataloader(valid_x, valid_y, bs*2, shuffle=True )

In [10]:
model = nn.Sequential(
    nn.Conv2d(1,  16, kernel_size=3, stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    Lambda(lambda x: x.view(x.size(0),-1))
).to(dev)

opt = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

In [11]:
fit(epochs, model, F.cross_entropy, opt, train_dl, valid_dl)

0 0.35200357556343076
1 0.23578783464431763
2 0.20812358429431915
3 0.18592815330028534
4 0.17274532527923583
5 0.16054231069684027
6 0.15549320149421691
7 0.1520364279359579
8 0.14865893067568542
9 0.14952965597659348
10 0.146713733959198
11 0.1522614889740944
12 0.14569349290132522
13 0.13937673519700766
14 0.1439363737821579
15 0.13884912676736713
16 0.13784348073005676
17 0.13778409123420715
18 0.13393967863321304
19 0.13233190182447432
