# A basic training loop

## MNIST data setup

In [1]:
from pathlib import Path

DATA_PATH = Path('data')
PATH = DATA_PATH/'mnist'

PATH.mkdir(parents=True, exist_ok=True)

In [3]:
import requests

URL='http://deeplearning.net/data/mnist/'
FILENAME='mnist.pkl.gz'

if not (PATH/FILENAME).exists():
    content = requests.get(URL+FILENAME).content
    (PATH/FILENAME).open('wb').write(content)

In [4]:
import pickle, gzip

((x, y), (x_valid, y_valid), _) = pickle.load(gzip.open(PATH/FILENAME, 'rb'), encoding='latin-1')

In [5]:
import torch 

x,y,x_valid,y_valid = map(torch.tensor, (x,y,x_valid,y_valid))

In [6]:
n,c = x.shape
x, x.shape, y.min(), y.max()

(tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]),
 torch.Size([50000, 784]),
 tensor(0),
 tensor(9))

## Basic model and training loop

In [7]:
import math

weights = torch.randn(784,10)/math.sqrt(784)
weights.requires_grad_()
bias = torch.zeros(10, requires_grad=True)

In [34]:
import torch.nn.functional as F

def model(xb):
    xb = (xb @ weights) + bias
    return F.log_softmax(xb, dim=-1)

In [35]:
bs=64

In [36]:
preds = model(x[0:bs])
preds[0], preds.shape

(tensor([-2.3212, -2.1283, -2.3868, -2.3388, -2.5165, -2.7557, -2.4033,
         -2.3795, -1.9538, -2.0764]), torch.Size([64, 10]))

In [37]:
loss_fn = F.nll_loss
loss_fn(preds, y[0:bs])

tensor(2.4042)

In [95]:
lr = 0.5
epochs = 2

In [39]:
from IPython.core.debugger import set_trace

In [40]:
for epoch in range(epochs):
    for i in range((n-1)//bs + 1):
#         set_trace()
        start_i = i*bs
        end_i = min(start_i+bs, n)
        xb = x[start_i:end_i]
        yb = y[start_i:end_i]
        pred = model(xb)
        loss = loss_fn(pred, yb)

        loss.backward()
        with torch.no_grad():
            weights -= weights.grad * lr
            bias -= bias.grad * lr
            weights.grad.zero_()
            bias.grad.zero_()

In [41]:
loss_fn(model(x[0:bs]), y[0:bs])

tensor(0.2186)

## Refactor using nn.Module/Parameter

In [42]:
from torch import nn

class Mnist_Logistic(nn.Module):
    def __init__(self):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(784,10)/math.sqrt(784))
        self.bias = nn.Parameter(torch.zeros(10))

    def forward(self, xb):
        xb = (xb @ self.weights) + self.bias
        return F.log_softmax(xb, dim=-1)

In [43]:
model = Mnist_Logistic()

In [44]:
loss_fn(model(x[0:bs]), y[0:bs])

tensor(2.2875)

In [45]:
for epoch in range(epochs):
    for i in range((n-1)//bs + 1):
        start_i = i*bs
        end_i = min(start_i+bs, n)
        xb = x[start_i:end_i]
        yb = y[start_i:end_i]
        pred = model(xb)
        loss = loss_fn(pred, yb)

        loss.backward()
        with torch.no_grad():
            for p in model.parameters(): p -= p.grad * lr
            model.zero_grad()

In [46]:
loss_fn(model(x[0:bs]), y[0:bs])

tensor(0.2182)

## Refactor using nn.Linear

In [49]:
class Mnist_Logistic(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(784,10)

    def forward(self, xb):
        return F.log_softmax(self.lin(xb), dim=-1)

In [50]:
model = Mnist_Logistic()
loss_fn(model(x[0:bs]), y[0:bs])

tensor(2.2816)

In [51]:
for epoch in range(epochs):
    for i in range((n-1)//bs + 1):
        start_i = i*bs
        end_i = min(start_i+bs, n)
        xb = x[start_i:end_i]
        yb = y[start_i:end_i]
        pred = model(xb)
        loss = loss_fn(pred, yb)

        loss.backward()
        with torch.no_grad():
            for p in model.parameters(): p -= p.grad * lr
            model.zero_grad()

In [52]:
loss_fn(model(x[0:bs]), y[0:bs])

tensor(0.2182)

## Refactor using optim

In [53]:
from torch import optim

In [54]:
model = Mnist_Logistic()
opt = optim.SGD(model.parameters(), lr=1.)

loss_fn(model(x[0:bs]), y[0:bs])

tensor(2.3345)

In [55]:
for epoch in range(epochs):
    for i in range((n-1)//bs + 1):
        start_i = i*bs
        end_i = min(start_i+bs, n)
        xb = x[start_i:end_i]
        yb = y[start_i:end_i]
        pred = model(xb)
        loss = loss_fn(pred, yb)

        loss.backward()
        opt.step()
        opt.zero_grad()

In [56]:
loss_fn(model(x[0:bs]), y[0:bs])

tensor(0.2204)

## Refactor using Dataset

In [57]:
from torch.utils.data import TensorDataset

In [58]:
model = Mnist_Logistic()
opt = optim.SGD(model.parameters(), lr=1.)

In [59]:
train_ds = TensorDataset(x, y)

In [60]:
loss_fn(model(x[0:bs]), y[0:bs])

tensor(2.3900)

In [61]:
for epoch in range(epochs):
    for i in range((n-1)//bs + 1):
        xb,yb = train_ds[i*bs : i*bs+bs]
        pred = model(xb)
        loss = loss_fn(pred, yb)

        loss.backward()
        opt.step()
        opt.zero_grad()

In [62]:
loss_fn(model(x[0:bs]), y[0:bs])

tensor(0.2182)

## Refactor using DataLoader

In [63]:
from torch.utils.data import DataLoader

In [64]:
model = Mnist_Logistic()
opt = optim.SGD(model.parameters(), lr=1.)

In [65]:
train_ds = TensorDataset(x, y)
train_dl = DataLoader(train_ds, batch_size=bs)

In [66]:
loss_fn(model(x[0:bs]), y[0:bs])

tensor(2.3333)

In [67]:
for epoch in range(epochs):
    for xb,yb in train_dl:
        pred = model(xb)
        loss = loss_fn(pred, yb)

        loss.backward()
        opt.step()
        opt.zero_grad()

In [68]:
loss_fn(model(x[0:bs]), y[0:bs])

tensor(0.2184)

# Add validation

## First try

In [69]:
model = Mnist_Logistic()
opt = optim.SGD(model.parameters(), lr=1.)

In [70]:
train_ds = TensorDataset(x, y)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)

In [71]:
valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs*2)

In [72]:
loss_fn(model(x_valid[0:bs]), y_valid[0:bs])

tensor(2.3428)

In [73]:
for epoch in range(epochs):
    model.train()
    for xb,yb in train_dl:
        pred = model(xb)
        loss = loss_fn(pred, yb)

        loss.backward()
        opt.step()
        opt.zero_grad()
        
    model.eval()
    valid_loss = 0.
    with torch.no_grad():
        for xb,yb in valid_dl:
            valid_loss += loss_fn(model(xb), yb)

    print(epoch, valid_loss/len(valid_dl))

0 tensor(0.3758)
1 tensor(0.3508)


## Create fit()

In [75]:
def fit(epochs, model, loss_fn, opt, train_dl, valid_dl):
    for epoch in range(epochs):
        model.train()
        for xb,yb in train_dl:
            pred = model(xb)
            loss = loss_fn(pred, yb)

            loss.backward()
            opt.step()
            opt.zero_grad()

        model.eval()
        val_loss = 0.
        with torch.no_grad():
            for xb,yb in valid_dl:
                loss += loss_fn(model(xb), yb)

        print(epoch, loss/len(valid_dl))

In [76]:
model = Mnist_Logistic()
opt = optim.SGD(model.parameters(), lr=0.5)

In [77]:
train_ds = TensorDataset(x, y)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)

valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs*2)

In [78]:
fit(epochs, model, loss_fn, opt, train_dl, valid_dl)

0 tensor(0.3347)
1 tensor(0.2960)


## Move opt creation into fit()

In [79]:
def fit(epochs, model, loss_fn, opt_fn, lr, train_dl, valid_dl):
    opt = opt_fn(model.parameters(), lr=lr)
    
    for epoch in range(epochs):
        model.train()
        for xb,yb in train_dl:
            pred = model(xb)
            loss = loss_fn(pred, yb)

            loss.backward()
            opt.step()
            opt.zero_grad()

        model.eval()
        val_loss = 0.
        with torch.no_grad():
            for xb,yb in valid_dl:
                loss += loss_fn(model(xb), yb)

        print(epoch, loss/len(valid_dl))

In [80]:
model = Mnist_Logistic()

In [81]:
train_ds = TensorDataset(x, y)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)

valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs*2)

In [82]:
fit(epochs, model, loss_fn, optim.SGD, lr, train_dl, valid_dl)

0 tensor(0.5208)
1 tensor(0.3869)


## Refactor to ModelData

In [210]:
class ModelData():
    def __init__(self, train_ds, valid_ds, bs):
        self.train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)
        self.valid_dl = DataLoader(valid_ds, batch_size=bs*2)

In [163]:
def fit(epochs, model, loss_fn, opt_fn, lr, data):
    opt = opt_fn(model.parameters(), lr=lr)
    
    for epoch in range(epochs):
        model.train()
        for xb,yb in data.train_dl:
            pred = model(xb)
            loss = loss_fn(pred, yb)

            loss.backward()
            opt.step()
            opt.zero_grad()

        model.eval()
        val_loss = 0.
        with torch.no_grad():
            for xb,yb in data.valid_dl:
                loss += loss_fn(model(xb), yb)

        print(epoch, loss/len(valid_dl))

In [98]:
model = Mnist_Logistic()
data = ModelData(TensorDataset(x, y), TensorDataset(x_valid, y_valid), bs)

In [99]:
fit(epochs, model, loss_fn, optim.SGD, lr, data)

0 tensor(0.3433)
1 tensor(0.3326)


## Switch to CNN

In [164]:
class Mnist_CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1,  16, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(16, 10,  kernel_size=3, stride=2, padding=1)

    def forward(self, xb):
        xb = xb.view(-1,1,28,28)
        xb = F.relu(self.conv1(xb))
        xb = F.relu(self.conv2(xb))
        xb = F.relu(self.conv3(xb))
        xb = F.avg_pool2d(xb, 4)
        xb = xb.view(-1,xb.size(1))
        return F.log_softmax(xb, dim=-1)

In [165]:
model = Mnist_CNN()

In [166]:
fit(epochs, model, loss_fn, optim.SGD, lr*2, data)

0 tensor(0.6594)
1 tensor(0.3322)


In [167]:
fit(epochs, model, loss_fn, optim.SGD, lr/2, data)

0 tensor(0.2187)
1 tensor(0.2157)


## Transformation

In [246]:
def mnist2image(v): return v.view(1,28,28)

In [247]:
from torch.utils.data import Dataset

class TransformedDataset(Dataset):
    def __init__(self, ds, x_tfms=None, y_tfms=None):
        self.ds,self.x_tfms,self.y_tfms = ds,x_tfms,y_tfms
        
    def __len__(self): return len(self.ds)

    def __getitem__(self,i):
        x,y = self.ds[i]
        if self.x_tfms is not None: x = self.x_tfms(x)
        if self.y_tfms is not None: y = self.y_tfms(y)
        return x,y

In [248]:
train_tfm_ds = TransformedDataset(train_ds, mnist2image)
valid_tfm_ds = TransformedDataset(valid_ds, mnist2image)

In [249]:
train_dl = DataLoader(train_tfm_ds, batch_size=bs, shuffle=True)
valid_dl = DataLoader(valid_tfm_ds, batch_size=bs*2)

In [250]:
data = ModelData(train_tfm_ds, valid_tfm_ds, bs)

In [251]:
class Mnist_CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1,  16, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(16, 10,  kernel_size=3, stride=2, padding=1)

    def forward(self, xb):
        xb = F.relu(self.conv1(xb))
        xb = F.relu(self.conv2(xb))
        xb = F.relu(self.conv3(xb))
        xb = F.avg_pool2d(xb, 4)
        xb = xb.view(-1,xb.size(1))
        return F.log_softmax(xb, dim=-1)

In [252]:
model = Mnist_CNN()

In [253]:
fit(epochs, model, loss_fn, optim.SGD, lr*2, data)

0 tensor(0.7328)
1 tensor(0.3983)


In [254]:
fit(epochs, model, loss_fn, optim.SGD, lr/2, data)

0 tensor(0.2415)
1 tensor(0.1799)


## Refactor network

In [256]:
class Simple_CNN(nn.Module):
    def __init__(self, actns, kernel_szs, strides):
        super().__init__()
        self.convs = nn.ModuleList([
            nn.Conv2d(actns[i], actns[i+1], kernel_szs[i], stride=strides[i], padding=kernel_szs[i]//2)
            for i in range(len(strides))])

    def forward(self, xb):
        for conv in self.convs: xb = F.relu(conv(xb))
        xb = F.adaptive_avg_pool2d(xb, 1)
        xb = xb.view(-1,xb.size(1))
        return F.log_softmax(xb, dim=-1)

In [266]:
def get_model(): return Simple_CNN([1,16,16,10], [3,3,3], [2,2,2])

In [257]:
model = get_model()

In [259]:
fit(epochs, model, loss_fn, optim.SGD, lr*2, data)

0 tensor(0.7121)
1 tensor(0.3623)


In [260]:
fit(epochs, model, loss_fn, optim.SGD, lr/2, data)

0 tensor(0.2344)
1 tensor(0.2332)


## CUDA

In [261]:
default_device = torch.device('cuda')

In [262]:
def fit(epochs, model, loss_fn, opt_fn, lr, data):
    opt = opt_fn(model.parameters(), lr=lr)
    
    for epoch in range(epochs):
        model.train()
        for xb,yb in data.train_dl:
            xb,yb = xb.to(default_device),yb.to(default_device)
            pred = model(xb)
            loss = loss_fn(pred, yb)

            loss.backward()
            opt.step()
            opt.zero_grad()

        model.eval()
        val_loss = 0.
        with torch.no_grad():
            for xb,yb in data.valid_dl:
                xb,yb = xb.to(default_device),yb.to(default_device)
                loss += loss_fn(model(xb), yb)

        print(epoch, loss/len(valid_dl))

In [267]:
model = get_model().to(default_device)

In [268]:
fit(epochs, model, loss_fn, optim.SGD, lr*2, data)

0 tensor(0.7301, device='cuda:0')
1 tensor(0.4838, device='cuda:0')


In [269]:
fit(epochs, model, loss_fn, optim.SGD, lr/2, data)

0 tensor(0.2183, device='cuda:0')
1 tensor(0.2172, device='cuda:0')


## Refactor CUDA with 'to()'

In [190]:
class DeviceDataLoader():
    def __init__(self, dl, device):
        self.dl,self.device = dl,device
        
    def __iter__(self):
        for b in self.dl: yield [o.to(device) for o in b]

class ModelData():
    def __init__(self, train_ds, valid_ds, bs, device):
        self.train_dl = DeviceDataLoader(DataLoader(train_ds, batch_size=bs, shuffle=True), device)
        self.valid_dl = DeviceDataLoader(DataLoader(valid_ds, batch_size=bs*2), device)

In [191]:
def fit(epochs, model, loss_fn, opt_fn, lr, data):
    opt = opt_fn(model.parameters(), lr=lr)
    
    for epoch in range(epochs):
        model.train()
        for xb,yb in data.train_dl:
            pred = model(xb)
            loss = loss_fn(pred, yb)

            loss.backward()
            opt.step()
            opt.zero_grad()

        model.eval()
        val_loss = 0.
        with torch.no_grad():
            for xb,yb in data.valid_dl:
                loss += loss_fn(model(xb), yb)

        print(epoch, loss/len(valid_dl))

In [192]:
data = ModelData(TensorDataset(x, y), TensorDataset(x_valid, y_valid), bs, default_device)

In [193]:
model = Mnist_CNN().to(default_device)

In [194]:
fit(epochs, model, loss_fn, optim.SGD, lr*2, data)

0 tensor(0.8275, device='cuda:0')
1 tensor(0.3431, device='cuda:0')


In [195]:
fit(epochs, model, loss_fn, optim.SGD, lr/2, data)

0 tensor(0.1965, device='cuda:0')
1 tensor(0.1830, device='cuda:0')
