# A basic training loop

## MNIST data setup

In [270]:
from pathlib import Path

DATA_PATH = Path('data')
PATH = DATA_PATH/'mnist'

PATH.mkdir(parents=True, exist_ok=True)

In [271]:
import requests

URL='http://deeplearning.net/data/mnist/'
FILENAME='mnist.pkl.gz'

# Download the pickled MNIST archive once and cache it under PATH.
# NOTE(review): deeplearning.net may no longer serve this file; if the
# request fails, point URL at a current mirror of mnist.pkl.gz.
if not (PATH/FILENAME).exists():
    response = requests.get(URL+FILENAME)
    # Fail loudly on an HTTP error. Otherwise an error page would be saved as
    # mnist.pkl.gz, and because the file then exists it would never be
    # re-downloaded (see the exists() check above).
    response.raise_for_status()
    # write_bytes opens, writes and closes the file in one call.
    (PATH/FILENAME).write_bytes(response.content)

In [272]:
import pickle, gzip

# Unpack the archive into (train, valid, <discarded>) splits of
# (inputs, labels); `with` closes the gzip handle once loading is done.
# NOTE: pickle.load can execute arbitrary code - only load this file from a
# source you trust.
with gzip.open(PATH/FILENAME, 'rb') as f:
    ((x, y), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')

In [273]:
import torch 

x,y,x_valid,y_valid = map(torch.tensor, (x,y,x_valid,y_valid))

In [274]:
n,c = x.shape
x, x.shape, y.min(), y.max()

(tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]),
 torch.Size([50000, 784]),
 tensor(0),
 tensor(9))

## Basic model and training loop

In [275]:
import math

weights = torch.randn(784,10)/math.sqrt(784)
weights.requires_grad_()
bias = torch.zeros(10, requires_grad=True)

In [276]:
import torch.nn.functional as F

def model(xb):
    """Logistic-regression forward pass using the global ``weights`` and ``bias``.

    Returns log-probabilities over the last dimension (one entry per class).
    """
    logits = xb @ weights + bias
    return F.log_softmax(logits, dim=-1)

In [277]:
bs=64

In [278]:
preds = model(x[0:bs])
preds[0], preds.shape

(tensor([-1.9563, -2.3752, -2.0437, -2.4484, -2.5666, -2.6892, -2.4609,
         -2.1520, -2.0369, -2.6191]), torch.Size([64, 10]))

In [279]:
loss_fn = F.nll_loss
loss_fn(preds, y[0:bs])

tensor(2.3245)

In [280]:
lr = 0.5
epochs = 2

In [281]:
from IPython.core.debugger import set_trace

In [40]:
# Manual mini-batch SGD: slice the training set into batches of `bs`,
# backprop the loss, and update the global `weights`/`bias` in place.
for epoch in range(epochs):
    # (n-1)//bs + 1 == ceil(n / bs), so a final partial batch is included.
    for i in range((n-1)//bs + 1):
        # Uncomment the next line to step through a batch in the IPython debugger.
#         set_trace()
        start_i = i*bs
        end_i = min(start_i+bs, n)
        xb = x[start_i:end_i]
        yb = y[start_i:end_i]
        pred = model(xb)
        loss = loss_fn(pred, yb)

        loss.backward()
        # Update under no_grad so the step itself is not tracked by autograd,
        # then zero the gradients so the next backward() starts fresh.
        with torch.no_grad():
            weights -= weights.grad * lr
            bias -= bias.grad * lr
            weights.grad.zero_()
            bias.grad.zero_()

In [41]:
loss_fn(model(x[0:bs]), y[0:bs])

tensor(0.2186)

## Refactor using nn.Module

In [284]:
from torch import nn

class Mnist_Logistic(nn.Module):
    """Logistic regression with hand-declared parameters.

    ``weights`` is a 784x10 matrix scaled by 1/sqrt(784) (the same
    initialisation as the manual version above); ``bias`` starts at zero.
    """
    def __init__(self):
        super().__init__()
        n_in, n_out = 784, 10
        self.weights = nn.Parameter(torch.randn(n_in, n_out) / math.sqrt(n_in))
        self.bias = nn.Parameter(torch.zeros(n_out))

    def forward(self, xb):
        """Return per-class log-probabilities for a batch of flat inputs."""
        return F.log_softmax(xb @ self.weights + self.bias, dim=-1)

In [285]:
model = Mnist_Logistic()

In [44]:
loss_fn(model(x[0:bs]), y[0:bs])

tensor(2.2875)

In [45]:
for epoch in range(epochs):
    for i in range((n-1)//bs + 1):
        start_i = i*bs
        end_i = min(start_i+bs, n)
        xb = x[start_i:end_i]
        yb = y[start_i:end_i]
        pred = model(xb)
        loss = loss_fn(pred, yb)

        loss.backward()
        with torch.no_grad():
            for p in model.parameters(): p -= p.grad * lr
            model.zero_grad()

In [46]:
loss_fn(model(x[0:bs]), y[0:bs])

tensor(0.2182)

## Refactor using nn.Linear

In [286]:
class Mnist_Logistic(nn.Module):
    """Logistic regression built on ``nn.Linear`` (784 inputs -> 10 outputs)."""
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(784, 10)

    def forward(self, xb):
        """Map a batch of flat inputs to class log-probabilities."""
        logits = self.lin(xb)
        return F.log_softmax(logits, dim=-1)

In [287]:
model = Mnist_Logistic()
loss_fn(model(x[0:bs]), y[0:bs])

tensor(2.3395)

In [51]:
for epoch in range(epochs):
    for i in range((n-1)//bs + 1):
        start_i = i*bs
        end_i = min(start_i+bs, n)
        xb = x[start_i:end_i]
        yb = y[start_i:end_i]
        pred = model(xb)
        loss = loss_fn(pred, yb)

        loss.backward()
        with torch.no_grad():
            for p in model.parameters(): p -= p.grad * lr
            model.zero_grad()

In [52]:
loss_fn(model(x[0:bs]), y[0:bs])

tensor(0.2182)

## Refactor using optim

In [288]:
from torch import optim

In [289]:
model = Mnist_Logistic()
opt = optim.SGD(model.parameters(), lr=1.)

loss_fn(model(x[0:bs]), y[0:bs])

tensor(2.3157)

In [55]:
for epoch in range(epochs):
    for i in range((n-1)//bs + 1):
        start_i = i*bs
        end_i = min(start_i+bs, n)
        xb = x[start_i:end_i]
        yb = y[start_i:end_i]
        pred = model(xb)
        loss = loss_fn(pred, yb)

        loss.backward()
        opt.step()
        opt.zero_grad()

In [56]:
loss_fn(model(x[0:bs]), y[0:bs])

tensor(0.2204)

## Refactor using Dataset

In [290]:
from torch.utils.data import TensorDataset

In [291]:
model = Mnist_Logistic()
opt = optim.SGD(model.parameters(), lr=1.)

In [292]:
train_ds = TensorDataset(x, y)

In [293]:
loss_fn(model(x[0:bs]), y[0:bs])

tensor(2.2903)

In [61]:
for epoch in range(epochs):
    for i in range((n-1)//bs + 1):
        xb,yb = train_ds[i*bs : i*bs+bs]
        pred = model(xb)
        loss = loss_fn(pred, yb)

        loss.backward()
        opt.step()
        opt.zero_grad()

In [62]:
loss_fn(model(x[0:bs]), y[0:bs])

tensor(0.2182)

## Refactor using DataLoader

In [294]:
from torch.utils.data import DataLoader

In [295]:
model = Mnist_Logistic()
opt = optim.SGD(model.parameters(), lr=1.)

In [296]:
train_ds = TensorDataset(x, y)
train_dl = DataLoader(train_ds, batch_size=bs)

In [297]:
loss_fn(model(x[0:bs]), y[0:bs])

tensor(2.2988)

In [67]:
for epoch in range(epochs):
    for xb,yb in train_dl:
        pred = model(xb)
        loss = loss_fn(pred, yb)

        loss.backward()
        opt.step()
        opt.zero_grad()

In [68]:
loss_fn(model(x[0:bs]), y[0:bs])

tensor(0.2184)

# Add validation

## First try

In [310]:
model = Mnist_Logistic()
opt = optim.SGD(model.parameters(), lr=1.)

In [311]:
train_ds = TensorDataset(x, y)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)

In [312]:
valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs*2)

In [316]:
loss_fn(model(x_valid[0:bs]), y_valid[0:bs])

tensor(2.3045)

In [73]:
# Train for `epochs` epochs, reporting the mean validation loss after each.
for epoch in range(epochs):
    # Put the model in training mode (matters for layers such as dropout or
    # batchnorm; the nn.Linear model here has none, but it is a good habit).
    model.train()
    for xb,yb in train_dl:
        pred = model(xb)
        loss = loss_fn(pred, yb)

        loss.backward()
        opt.step()
        opt.zero_grad()
        
    model.eval()
    valid_loss = 0.
    # Evaluation only: no gradients are needed.
    with torch.no_grad():
        for xb,yb in valid_dl:
            valid_loss += loss_fn(model(xb), yb)

    # This is the mean of per-batch mean losses, which is slightly biased if
    # the last batch is smaller than bs*2. fit() below weights each batch by
    # its size instead.
    print(epoch, valid_loss/len(valid_dl))

0 tensor(0.3758)
1 tensor(0.3508)


## Create fit()

In [305]:
def loss_batch(model, xb, yb, opt=None):
    """Compute the loss on one batch; if ``opt`` is given, also take an optimizer step.

    The loss is computed with the module-level ``loss_fn``.

    Returns:
        ``(loss_value, batch_size)``: the loss as a Python float and
        ``len(xb)``, so callers can form a size-weighted average.
    """
    loss = loss_fn(model(xb), yb)
    if opt is None:
        # Evaluation path: nothing to update.
        return loss.item(), len(xb)

    # Training path: backprop, update, then clear gradients for the next batch.
    loss.backward()
    opt.step()
    opt.zero_grad()
    return loss.item(), len(xb)

In [318]:
import numpy as np

def fit(epochs, model, loss_fn, opt, train_dl, valid_dl):
    """Train ``model`` with ``opt``, printing ``epoch val_loss`` after each epoch.

    ``val_loss`` is the batch-size-weighted mean loss over ``valid_dl``.
    NOTE: ``loss_fn`` is accepted but not forwarded; ``loss_batch`` reads
    the module-level ``loss_fn`` instead.
    """
    for epoch in range(epochs):
        model.train()
        for xb, yb in train_dl:
            loss_batch(model, xb, yb, opt)

        model.eval()
        with torch.no_grad():
            batch_stats = [loss_batch(model, xb, yb) for xb, yb in valid_dl]
        losses, nums = zip(*batch_stats)
        # Weight each batch mean by its size so a short final batch is not
        # over-counted.
        weighted_total = np.sum(np.multiply(losses, nums))
        val_loss = weighted_total / np.sum(nums)

        print(epoch, val_loss)

In [324]:
model = Mnist_Logistic()
opt = optim.SGD(model.parameters(), lr=0.5)

In [325]:
train_ds = TensorDataset(x, y)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)

valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs*2)

In [329]:
fit(epochs, model, loss_fn, opt, train_dl, valid_dl)

0 0.26180660036802295
1 0.26372072726488116


## Move opt creation into fit()

In [330]:
def fit(epochs, model, loss_fn, opt_fn, lr, train_dl, valid_dl):
    """Build an optimizer via ``opt_fn(model.parameters(), lr=lr)`` and train.

    After each epoch, prints the epoch index and the batch-size-weighted mean
    validation loss. NOTE: ``loss_fn`` is not forwarded; ``loss_batch`` uses
    the module-level ``loss_fn``.
    """
    optimizer = opt_fn(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        for xb, yb in train_dl:
            loss_batch(model, xb, yb, optimizer)

        model.eval()
        with torch.no_grad():
            batch_stats = [loss_batch(model, xb, yb) for xb, yb in valid_dl]
        losses, nums = zip(*batch_stats)
        weighted_total = np.sum(np.multiply(losses, nums))
        val_loss = weighted_total / np.sum(nums)

        print(epoch, val_loss)

In [331]:
model = Mnist_Logistic()

In [332]:
train_ds = TensorDataset(x, y)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)

valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs*2)

In [333]:
fit(epochs, model, loss_fn, optim.SGD, lr, train_dl, valid_dl)

0 0.29762732087373733
1 0.2890522579073906


## Refactor to ModelData

In [365]:
class ModelData():
    """Bundle the training and validation DataLoaders for a pair of datasets.

    The training loader shuffles and uses batch size ``bs``. The validation
    loader keeps the dataset order and uses ``2*bs`` (presumably affordable
    because evaluation runs under ``torch.no_grad``).
    """
    def __init__(self, train_ds, valid_ds, bs):
        valid_bs = bs * 2
        self.train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)
        self.valid_dl = DataLoader(valid_ds, batch_size=valid_bs)

In [366]:
def fit(epochs, model, loss_fn, opt_fn, lr, data):
    """Train ``model`` on ``data.train_dl`` and report loss on ``data.valid_dl``.

    The optimizer is ``opt_fn(model.parameters(), lr=lr)``. After each epoch,
    prints the epoch index and the batch-size-weighted mean validation loss.
    NOTE: ``loss_fn`` is not forwarded; ``loss_batch`` uses the module-level
    ``loss_fn``.
    """
    optimizer = opt_fn(model.parameters(), lr=lr)
    train_dl, valid_dl = data.train_dl, data.valid_dl

    for epoch in range(epochs):
        model.train()
        for xb, yb in train_dl:
            loss_batch(model, xb, yb, optimizer)

        model.eval()
        with torch.no_grad():
            batch_stats = [loss_batch(model, xb, yb) for xb, yb in valid_dl]
        losses, nums = zip(*batch_stats)
        weighted_total = np.sum(np.multiply(losses, nums))
        val_loss = weighted_total / np.sum(nums)

        print(epoch, val_loss)

In [367]:
model = Mnist_Logistic()
data = ModelData(TensorDataset(x, y), TensorDataset(x_valid, y_valid), bs)

In [368]:
fit(epochs, model, loss_fn, optim.SGD, lr, data)

0 0.4939220434188843
1 0.302523850440979


# Switch to CNN

## First try

In [369]:
class Mnist_CNN(nn.Module):
    """Three stride-2 3x3 ReLU convolutions followed by a 4x4 average pool.

    Spatial size goes 28 -> 14 -> 7 -> 4, so the final 4x4 pool leaves one
    value per channel. conv3's 10 channels are used directly as class scores.
    """
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1)

    def forward(self, xb):
        """Reshape flat 784-element rows to 1x28x28 images and classify them."""
        out = xb.view(-1, 1, 28, 28)
        for conv in (self.conv1, self.conv2, self.conv3):
            out = F.relu(conv(out))
        out = F.avg_pool2d(out, 4)
        return F.log_softmax(out.view(-1, out.size(1)), dim=-1)

In [370]:
model = Mnist_CNN()

In [340]:
fit(epochs, model, loss_fn, optim.SGD, lr*2, data)

0 0.5376164931297303
1 0.48154172859191896


In [341]:
fit(epochs, model, loss_fn, optim.SGD, lr/2, data)

0 0.41251860780715943
1 0.4114646418571472


## Transformation

In [371]:
def mnist2image(v):
    """View a flat 784-element MNIST row as a 1x28x28 (channel, height, width) tensor."""
    side = 28
    return v.view(1, side, side)

In [372]:
from torch.utils.data import Dataset

class TransformedDataset(Dataset):
    """Wrap a dataset of ``(x, y)`` items and apply optional transforms on access.

    Args:
        ds: indexable dataset whose items unpack to ``(x, y)``.
        x_tfms: callable applied to ``x``, or ``None`` to leave it unchanged.
        y_tfms: callable applied to ``y``, or ``None`` to leave it unchanged.
    """
    def __init__(self, ds, x_tfms=None, y_tfms=None):
        self.ds = ds
        self.x_tfms = x_tfms
        self.y_tfms = y_tfms

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, i):
        x, y = self.ds[i]
        x = x if self.x_tfms is None else self.x_tfms(x)
        y = y if self.y_tfms is None else self.y_tfms(y)
        return x, y

In [373]:
train_tfm_ds = TransformedDataset(train_ds, mnist2image)
valid_tfm_ds = TransformedDataset(valid_ds, mnist2image)

In [374]:
train_dl = DataLoader(train_tfm_ds, batch_size=bs, shuffle=True)
valid_dl = DataLoader(valid_tfm_ds, batch_size=bs*2)

In [375]:
data = ModelData(train_tfm_ds, valid_tfm_ds, bs)

In [376]:
class Mnist_CNN(nn.Module):
    """Three stride-2 3x3 ReLU convolutions followed by a 4x4 average pool.

    This version takes image-shaped input directly. For 28x28 images the
    spatial size goes 28 -> 14 -> 7 -> 4, and the 4x4 pool leaves one value
    per channel. conv3's 10 channels are the class scores.
    """
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1)

    def forward(self, xb):
        """Classify a batch of single-channel images; returns log-probabilities."""
        out = xb
        for conv in (self.conv1, self.conv2, self.conv3):
            out = F.relu(conv(out))
        out = F.avg_pool2d(out, 4)
        return F.log_softmax(out.view(-1, out.size(1)), dim=-1)

In [377]:
model = Mnist_CNN()

In [378]:
fit(epochs, model, loss_fn, optim.SGD, lr*2, data)

0 0.6518395250320435
1 0.5133248091697693


In [379]:
fit(epochs, model, loss_fn, optim.SGD, lr/2, data)

0 0.4676215552330017
1 0.46685396366119386


## Refactor network

In [380]:
class Simple_CNN(nn.Module):
    """Configurable stack of ReLU conv layers, globally average-pooled to log-probs.

    Args:
        actns: channel counts; layer ``i`` maps ``actns[i]`` -> ``actns[i+1]``.
            ``actns[len(strides)]`` is therefore the output size.
        kernel_szs: kernel size of each layer (padding is ``k // 2``).
        strides: stride of each layer; its length sets the number of layers.
    """
    def __init__(self, actns, kernel_szs, strides):
        super().__init__()
        layers = []
        for i, stride in enumerate(strides):
            k = kernel_szs[i]
            layers.append(nn.Conv2d(actns[i], actns[i+1], k,
                                    stride=stride, padding=k//2))
        self.convs = nn.ModuleList(layers)

    def forward(self, xb):
        """Run all convs, pool each channel to one value, and return log-probs."""
        out = xb
        for conv in self.convs:
            out = F.relu(conv(out))
        pooled = F.adaptive_avg_pool2d(out, 1)
        return F.log_softmax(pooled.view(-1, pooled.size(1)), dim=-1)

In [381]:
def get_model():
    """Build a three-layer Simple_CNN (1->16->16->10 channels, 3x3 kernels, stride 2)."""
    channels = [1, 16, 16, 10]
    return Simple_CNN(channels, [3] * 3, [2] * 3)

In [382]:
model = get_model()

In [383]:
fit(epochs, model, loss_fn, optim.SGD, lr*2, data)

0 0.4564737708568573
1 0.47677493288517


In [355]:
fit(epochs, model, loss_fn, optim.SGD, lr/2, data)

0 0.18453329709768296
1 0.1730178252518177


## CUDA

In [384]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# cells below still run (torch.device('cuda') fails on CPU-only machines).
default_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [388]:
class DeviceDataLoader():
    """Wrap a loader so every batch it yields is moved to ``device``.

    Args:
        dl: iterable of batches, each batch a sequence of tensors
            (e.g. the ``(xb, yb)`` pairs produced by a ``DataLoader``).
        device: target passed to ``Tensor.to`` for every tensor in a batch.
    """
    def __init__(self, dl, device):
        self.dl,self.device = dl,device

    def __len__(self):
        # Delegate so len() behaves like it does on the wrapped DataLoader.
        return len(self.dl)

    def __iter__(self):
        # Use the stored device. The bare name `device` is not defined in this
        # notebook and only worked if such a global happened to exist.
        for b in self.dl: yield [o.to(self.device) for o in b]

class ModelData():
    def __init__(self, train_ds, valid_ds, bs, device):
        train_dl = DataLoader(train_tfm_ds, batch_size=bs, shuffle=True)
        self.train_dl = DeviceDataLoader(train_dl, device)
        valid_dl = DataLoader(valid_tfm_ds, batch_size=bs*2)
        self.valid_dl = DeviceDataLoader(valid_dl, device)

In [389]:
data = ModelData(TensorDataset(x, y), TensorDataset(x_valid, y_valid),
                 bs, default_device)

In [390]:
model = Mnist_CNN().to(default_device)

In [391]:
fit(epochs, model, loss_fn, optim.SGD, lr*2, data)

0 0.7703350713729858
1 0.4241752480506897


In [392]:
fit(epochs, model, loss_fn, optim.SGD, lr/2, data)

0 0.22181436491012574
1 0.21518816499710083
