# **Mini Batch Training for MLP**

In [40]:
import pickle, gzip, math, os, time, shutil
import numpy as np
import matplotlib as mpl
import torch
from torch import tensor, nn
import torch.nn.functional as F
from fastcore.test import  test_close
from pathlib import Path

# Configs
# Seed PyTorch's RNG so that the randomly initialised weights are repeatable between runs
torch.manual_seed(1)
# Default colormap used by matplotlib when displaying images
mpl.rcParams['image.cmap'] = 'gray'
# Compact printing: 2 decimals, 125-character lines (and no scientific notation for tensors)
torch.set_printoptions(precision=2, linewidth=125, sci_mode=False)
np.set_printoptions(precision=2, linewidth=125)

# Path setup
path_data = Path('data')
path_gz = path_data/'mnist.pkl.gz'
# NOTE(review): pickle.load can execute arbitrary code -- only open archives from a trusted source
with gzip.open(path_gz, 'rb') as f:
    # Unpack the train and validation (inputs, labels) pairs; the third top-level element is discarded
    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')
# Loading MNIST data as tensors
x_train, y_train, x_valid, y_valid = map(tensor, [x_train, y_train, x_valid, y_valid])

## Initial Setup

### Data

Copying over the starting cells from the previous notebook.

In [41]:
# n = number of training rows, m = number of input features per row
n, m = x_train.shape
# Largest label + 1; note this is a 0-dim tensor, not a Python int.
# Presumably the number of classes, assuming labels start at 0 -- TODO confirm
c = y_train.max() + 1
# Width of the hidden layer
nh = 50

n, m

(50000, 784)

In [42]:
class Model(nn.Module):
    """Small multi-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        n_in:  number of input features per sample.
        nh:    number of hidden units.
        n_out: number of outputs per sample.

    The layers are exposed as `self.layers` and can be iterated over directly.
    """
    def __init__(self, n_in, nh, n_out):
        super().__init__()
        # nn.ModuleList (unlike a plain Python list) registers every layer as a
        # submodule, so `parameters()`, `state_dict()`, `.to()` etc. can see them.
        self.layers = nn.ModuleList([nn.Linear(n_in, nh), nn.ReLU(), nn.Linear(nh, n_out)])

    # Implement `forward` rather than overriding `__call__`: nn.Module.__call__
    # dispatches to `forward` and keeps the hook machinery working.
    def forward(self, x):
        "Pass `x` through every layer in order and return the final activations."
        for l in self.layers: x = l(x)
        return x

In [43]:
# NOTE(review): the output size is hard-coded to 10 here rather than taken from `c`
model = Model(m, nh,  10)
# Forward pass over the entire training set in one go
pred = model(x_train)
pred.shape

torch.Size([50000, 10])

In [44]:
pred

tensor([[-0.09, -0.21, -0.08,  ..., -0.03,  0.01,  0.06],
        [-0.07, -0.14, -0.14,  ...,  0.03,  0.04,  0.14],
        [-0.19, -0.04,  0.02,  ..., -0.01, -0.00,  0.02],
        ...,
        [-0.03, -0.22, -0.04,  ..., -0.01,  0.09,  0.14],
        [-0.10, -0.09, -0.05,  ..., -0.01,  0.02,  0.11],
        [-0.03, -0.25, -0.06,  ...,  0.00,  0.03,  0.14]], grad_fn=<AddmmBackward0>)

### Cross Entropy Loss

We need to improve our loss function from before. Instead of outputting 1 number per image, the model now outputs 10 numbers per image (one score per class), which are compared against a one-hot-encoded target.

In [45]:
def log_softmax(x):
    """Log of the softmax of `x`, taken over the last dimension.

    Direct transcription of the definition: exponentiate, normalise by the
    sum over the last dimension, then take the log.
    """
    numer = x.exp()
    denom = x.exp().sum(-1, keepdim=True)
    return (numer / denom).log()

In [46]:
log_softmax(pred)

tensor([[-2.37, -2.49, -2.36,  ..., -2.31, -2.28, -2.22],
        [-2.37, -2.44, -2.44,  ..., -2.27, -2.26, -2.16],
        [-2.48, -2.33, -2.28,  ..., -2.30, -2.30, -2.27],
        ...,
        [-2.33, -2.52, -2.34,  ..., -2.31, -2.21, -2.16],
        [-2.38, -2.38, -2.33,  ..., -2.29, -2.26, -2.17],
        [-2.33, -2.55, -2.36,  ..., -2.29, -2.27, -2.16]], grad_fn=<LogBackward0>)

Using the formula: $$ \log\left(\frac{a}{b}\right) = \log(a) - \log(b)$$ allows us to simplify the `log_softmax()` function further.

In [47]:
def log_softmax(x):
    """Log-softmax over the last dimension, using log(a/b) = log(a) - log(b).

    The log of the numerator exp(x) is just `x`, so only the log of the
    normalising sum has to be computed.
    """
    log_denom = x.exp().sum(-1, keepdim=True).log()
    return x - log_denom

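We can also make this computation numerically stable by using the [LogSumExp](https://en.wikipedia.org/wiki/LogSumExp) trick: subtracting the row maximum before exponentiating keeps `exp()` from overflowing, and the maximum is added back after taking the log.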

In [48]:
def logsumexp(x):
    """Numerically stable log(sum(exp(x))) over the last dimension.

    Args:
        x: tensor of any rank >= 1.

    Returns:
        Tensor with the last dimension of `x` reduced away.
    """
    # Taking the max on the last dim; keepdim=True keeps a trailing axis of
    # size 1 so the subtraction broadcasts correctly whatever the rank of `x`
    x_max = x.max(-1, keepdim=True)[0]
    # Shifting by the max keeps every exponent <= 0, so exp() cannot overflow;
    # the shift is added back once we are in log space
    return x_max.squeeze(-1) + (x - x_max).exp().sum(-1).log()

In [49]:
# Final log_softmax(): the normaliser now comes from PyTorch's own logsumexp()
def log_softmax(x):
    "Log-softmax over the last dimension, built on PyTorch's `logsumexp`."
    normaliser = torch.logsumexp(x, dim=-1, keepdim=True)
    return x - normaliser

In [50]:
# The hand-written logsumexp() should agree with PyTorch's built-in method
test_close(logsumexp(pred), pred.logsumexp(-1))

# Log-softmax of the model outputs for the whole training set
sm_pred = log_softmax(pred)
sm_pred

tensor([[-2.37, -2.49, -2.36,  ..., -2.31, -2.28, -2.22],
        [-2.37, -2.44, -2.44,  ..., -2.27, -2.26, -2.16],
        [-2.48, -2.33, -2.28,  ..., -2.30, -2.30, -2.27],
        ...,
        [-2.33, -2.52, -2.34,  ..., -2.31, -2.21, -2.16],
        [-2.38, -2.38, -2.33,  ..., -2.29, -2.26, -2.17],
        [-2.33, -2.55, -2.36,  ..., -2.29, -2.27, -2.16]], grad_fn=<SubBackward0>)

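Because the targets are effectively one-hot encoded, the loss for each image only needs the log-probability at the index of its target class. We can pick out those values using PyTorch's (and NumPy's) advanced (integer array) indexing.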

In [51]:
# Let's pick a sample
y_train[:3]

tensor([5, 0, 4])

In [52]:
# Note the positioning of the indices
sm_pred[0, 5], sm_pred[1, 0], sm_pred[2, 4]

(tensor(-2.20, grad_fn=<SelectBackward0>),
 tensor(-2.37, grad_fn=<SelectBackward0>),
 tensor(-2.36, grad_fn=<SelectBackward0>))

In [53]:
# The indexing method allows us to get these values as follows
sm_pred[[0, 1, 2], y_train[:3]]

tensor([-2.20, -2.37, -2.36], grad_fn=<IndexBackward0>)

In [54]:
# Negative log likelihood loss
def nll(input, target):
    """Mean negative log-likelihood.

    For every row of `input`, select the entry at the column given by the
    matching element of `target`, then return the negated mean of those entries.
    """
    rows = range(target.shape[0])
    picked = input[rows, target]
    return -picked.mean()

In [55]:
# Mean negative log-likelihood of the (still untrained) model over the whole training set
loss = nll(sm_pred, y_train)
loss

tensor(2.30, grad_fn=<NegBackward0>)

In [56]:
# PyTorch's version: F.log_softmax followed by F.nll_loss should match our
# hand-written loss to within the 1e-3 tolerance
test_close(F.nll_loss(F.log_softmax(pred, -1), y_train), loss, 1e-3)

`F.log_softmax` and `F.nll_loss` are combined in one function called `F.cross_entropy`.

In [57]:
test_close(F.cross_entropy(pred, y_train), loss, 1e-3)

## Basic Training Loop

In [58]:
# Our loss function
loss_func = F.cross_entropy

In [59]:
bs = 50               # batch size

xb = x_train[0 : bs]  # Mini batch from training data
preds = model(xb)     # Forward pass on that mini batch only
# First row of predictions, plus the shape of the whole batch of predictions
preds[0], preds.shape

(tensor([-0.09, -0.21, -0.08,  0.10, -0.04,  0.08, -0.04, -0.03,  0.01,  0.06], grad_fn=<SelectBackward0>),
 torch.Size([50, 10]))

In [60]:
# Picking our target, matching the mini-batch size
yb = y_train[0 : bs]
yb

tensor([5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7, 2, 8, 6, 9, 4, 0, 9, 1, 1, 2, 4, 3, 2, 7, 3, 8, 6, 9, 0, 5, 6, 0, 7,
        6, 1, 8, 7, 9, 3, 9, 8, 5, 9, 3])

In [61]:
# Apply the loss function
loss_func(preds, yb)

tensor(2.30, grad_fn=<NllLossBackward0>)

In [62]:
preds.argmax(dim=1)

tensor([3, 9, 3, 8, 5, 9, 3, 9, 3, 9, 5, 3, 9, 9, 3, 9, 9, 5, 8, 7, 9, 5, 3, 8, 9, 5, 9, 5, 5, 9, 3, 5, 9, 7, 5, 7, 9, 9, 3,
        9, 3, 5, 3, 8, 3, 5, 9, 5, 9, 5])

In [63]:
# Accuracy of a batch of predictions
def accuracy(out, yb):
    """Fraction of rows of `out` whose largest entry sits at the index given by `yb`.

    Returns a 0-dim float tensor.
    """
    hits = out.argmax(dim=1) == yb
    return hits.float().mean()

accuracy(preds, yb)

tensor(0.08)

In [64]:
# Setting a learning rate and number of epochs
lr = 0.5
epochs = 3

In [65]:
def report(loss, preds, yb):
    "Print `loss` and the accuracy of `preds` against `yb`, each formatted to two decimals."
    summary = f'{loss:.2f}, {accuracy(preds, yb):.2f}'
    print(summary)

In [66]:
# Baseline: loss and accuracy on the first mini batch, for comparison with the per-epoch reports from training
xb, yb = x_train[:bs], y_train[:bs]
preds = model(xb)
report(loss_func(preds, yb), preds, yb)

2.30, 0.08


In [67]:
# Our training loop: plain mini-batch gradient descent, one report line per epoch
for epoch in range(epochs):
    for i in range(0, n, bs):
        # The end of the slice is clipped so it never runs past the last row
        s = slice(i, min(n, i+bs))
        xb = x_train[s]
        yb = y_train[s]
        preds = model(xb)
        loss = loss_func(preds, yb)
        loss.backward()

        # The parameter updates themselves must not be tracked by autograd
        with torch.no_grad():
            for l in model.layers:
                # Layers without a `weight` attribute (the ReLU) have nothing to update
                if not hasattr(l, 'weight'): continue
                for p in (l.weight, l.bias):
                    p -= p.grad * lr
                    # Clear the gradient so the next backward() starts from zero
                    p.grad.zero_()
    # Loss and accuracy of the last mini batch seen in this epoch
    report(loss, preds, yb)

0.12, 0.98
0.12, 0.94
0.08, 0.96
