# Initial setup

In [None]:
from pathlib import Path
import pickle
import gzip

from fastcore.test import test_close
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch

# Seed torch's global RNG so the nn.Linear initialisation below is repeatable.
torch.manual_seed(1103)

# Display settings: grayscale images, compact tensor/array printing.
plt.rcParams["image.cmap"] = "gray"
np.set_printoptions(linewidth=160, precision=2)
torch.set_printoptions(sci_mode=False, linewidth=160, precision=2)

data_path = Path("data")
mnist_path = data_path / "mnist.pkl.gz"
# NOTE: pickle.load can execute arbitrary code -- only load a trusted copy of this file.
# The pickle holds three (x, y) pairs; the third (presumably test) split is discarded.
with gzip.open(mnist_path, "rb") as fh:
    (x_train, y_train), (x_valid, y_valid), _ = pickle.load(fh, encoding="bytes")
# Convert every split to a torch tensor.
x_train, y_train, x_valid, y_valid = [
    torch.tensor(arr) for arr in (x_train, y_train, x_valid, y_valid)
]

In [None]:
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class Model(nn.Module):
    """Two-layer MLP: Linear(n_in, n_h) -> ReLU -> Linear(n_h, n_out).

    Args:
        n_in: number of input features.
        n_h: hidden-layer width.
        n_out: number of output logits.
    """

    def __init__(self, n_in, n_h, n_out):
        super().__init__()
        # nn.ModuleList (not a plain list) so the layers are registered as
        # submodules: model.parameters(), .to(device) and state_dict() see them.
        self.layers = nn.ModuleList(
            [nn.Linear(n_in, n_h), nn.ReLU(), nn.Linear(n_h, n_out)]
        )

    def forward(self, x, target=None):
        """Return raw logits for the batch `x`.

        `target` is accepted and ignored so existing `model(x, y)` calls keep
        working; the loss is computed outside the model. Implementing
        `forward` (rather than overriding `__call__`) keeps nn.Module hooks working.
        """
        for layer in self.layers:
            x = layer(x)
        return x

In [None]:
nh = 50  # hidden-layer width
# Number of classes, assuming labels run 0..max.
c = int(y_train.max()) + 1
c

10

In [None]:
model = Model(n_in=x_train.shape[1], n_h=nh, n_out=c)

In [None]:
# Pass the labels that belong to x_train. The model ignores `target`, but
# x_valid (a different split of *inputs*) was the wrong argument.
preds = model(x_train, y_train)

In [None]:
preds.size()

torch.Size([50000, 10])

## Cross entropy loss

In [None]:
def log_softmax(x):
    """Naive log-softmax over dim 1: log(exp(x) / sum(exp(x), dim=1)).

    Exponentiates raw values directly, so large logits can overflow;
    kept as the straightforward reference formulation.
    """
    e = x.exp()
    probs = e / e.sum(dim=1, keepdim=True)
    return probs.log()

In [None]:
res = log_softmax(preds)
(res.size(), res)

(torch.Size([50000, 10]),
 tensor([[-2.40, -2.41, -2.43,  ..., -2.20, -2.25, -2.14],
         [-2.30, -2.51, -2.23,  ..., -2.18, -2.30, -2.26],
         [-2.34, -2.43, -2.34,  ..., -2.23, -2.20, -2.20],
         ...,
         [-2.37, -2.42, -2.38,  ..., -2.18, -2.27, -2.23],
         [-2.31, -2.40, -2.42,  ..., -2.21, -2.22, -2.25],
         [-2.41, -2.43, -2.38,  ..., -2.17, -2.30, -2.19]], grad_fn=<LogBackward0>))

In [None]:
def log_softmax(x):
    """Log-softmax over dim 1 using log(a / b) = log(a) - log(b).

    Equivalent to x - log(sum(exp(x), dim=1)), with the sum kept as a
    broadcastable column.
    """
    log_normalizer = x.exp().sum(dim=1, keepdim=True).log()
    return x - log_normalizer

In [None]:
res = log_softmax(preds)
(res.size(), res)

(torch.Size([50000, 10]),
 tensor([[-2.40, -2.41, -2.43,  ..., -2.20, -2.25, -2.14],
         [-2.30, -2.51, -2.23,  ..., -2.18, -2.30, -2.26],
         [-2.34, -2.43, -2.34,  ..., -2.23, -2.20, -2.20],
         ...,
         [-2.37, -2.42, -2.38,  ..., -2.18, -2.27, -2.23],
         [-2.31, -2.40, -2.42,  ..., -2.21, -2.22, -2.25],
         [-2.41, -2.43, -2.38,  ..., -2.17, -2.30, -2.19]], grad_fn=<SubBackward0>))

In [None]:
test_close(preds.log_softmax(dim=1), res)

In [None]:
def logsumexp(x):
    """Numerically stable log(sum(exp(x))) over dim 1, keeping that dim.

    Subtracts each row's own maximum before exponentiating, so the largest
    term in every row is exp(0) = 1. This avoids overflow and also avoids a
    whole row underflowing to log(0) = -inf, which can happen when a single
    global max is used and rows have very different scales.

    Returns a (rows, 1) tensor that broadcasts against `x`.
    """
    # Per-row max (not torch.max(x), which is the global max); also avoids
    # shadowing the builtin `max`.
    row_max = x.max(dim=1, keepdim=True).values
    return (x - row_max).exp().sum(dim=1, keepdim=True).log() + row_max

In [None]:
def log_softmax(x):
    """Log-softmax over dim 1: subtract logsumexp(x) from every entry."""
    return x.sub(logsumexp(x))

In [None]:
test_close(log_softmax(preds), res)

In [None]:
ll = preds.log_softmax(dim=1)

In [None]:
ll.size(), y_train.size()

(torch.Size([50000, 10]), torch.Size([50000]))

In [None]:
# Pick, for every row, the log-probability at that row's true label.
ll[range(len(y_train)), y_train].size()

torch.Size([50000])

In [None]:
def nll(x, target):
    """Mean negative log-likelihood of integer labels `target` under softmax(x).

    `x` holds raw logits (log_softmax is applied here), so the result matches
    F.cross_entropy(x, target) rather than F.nll_loss on log-probabilities.
    """
    log_probs = F.log_softmax(x, dim=1)
    rows = range(target.size(0))
    picked = log_probs[rows, target]  # log-prob of the true class, per row
    return -picked.mean()

In [None]:
res = nll(x=preds, target=y_train)

In [None]:
test_close(F.cross_entropy(preds, y_train), res)

## Basic training loop

In [None]:
# First 512 training examples as one minibatch.
xb, yb = x_train[:512], y_train[:512]
preds_b = model(xb, yb)

In [None]:
loss = F.cross_entropy(input=preds_b, target=yb)

In [None]:
cls = torch.argmax(preds_b, dim=1)  # predicted class per row

In [None]:
cls.size()

torch.Size([512])

In [None]:
def accuracy(out, yb):
    """Fraction of positions where predicted labels `out` equal `yb`.

    Returns a 0-d float tensor.
    """
    hits = out.eq(yb).float()
    return hits.mean()

In [None]:
accuracy(out=cls, yb=yb)

tensor(0.14)

In [None]:
def accuracy(preds, yb):
    """Fraction of rows whose argmax over dim 1 equals `yb`, as a Python float."""
    predicted = torch.argmax(preds, dim=1)
    return predicted.eq(yb).float().mean().item()

In [None]:
accuracy(preds=preds_b, yb=yb)

0.13671875

In [None]:
# Hyperparameters: SGD learning rate, passes over x_train, minibatch size.
lr, epochs, bs = 0.1, 5, 512

In [None]:
# NOTE(review): leftover scratch probe -- `epochs` is the int 5, which has no
# attribute named "1", so this always evaluates to False. Candidate for removal.
hasattr(epochs, "1")

False

In [None]:
# Manual minibatch SGD over model.layers.
# NOTE: weights are updated in place, so re-running this cell keeps training
# the same `model` rather than starting from scratch.

# Baseline on the first minibatch, before any update. Compute its loss here
# instead of reusing the stale `loss` left over from an earlier cell.
xb = x_train[:bs]
yb = y_train[:bs]
with torch.no_grad():
    predb = model(xb, yb)
print(F.cross_entropy(predb, yb).item(), accuracy(predb, yb))

for _ in range(epochs):
    for start in range(0, len(x_train), bs):
        xb = x_train[start:start + bs]
        yb = y_train[start:start + bs]
        predb = model(xb, yb)
        loss = F.cross_entropy(predb, yb)
        loss.backward()

        # Plain SGD step; no_grad so the update itself is not tracked by autograd.
        with torch.no_grad():
            for layer in model.layers:
                if hasattr(layer, "weight"):  # skip parameter-free layers such as ReLU
                    layer.weight -= layer.weight.grad * lr
                    layer.bias -= layer.bias.grad * lr
                    layer.weight.grad.zero_()
                    layer.bias.grad.zero_()
    # Loss/accuracy of the LAST minibatch only (which may be smaller than bs),
    # measured before its update -- a noisy progress signal, not a full-epoch metric.
    print(loss.item(), accuracy(predb, yb))

2.295273542404175 0.13671875
0.7081261277198792 0.836309552192688
0.5039593577384949 0.8720238208770752
0.43487557768821716 0.9017857313156128
0.39776724576950073 0.9107142686843872
0.37305158376693726 0.9196428656578064
