In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
#export
from exp.nb_02 import *
import torch.nn.functional as F

In [3]:
mpl.rcParams['image.cmap'] = 'gray'

In [4]:
x_train, y_train, x_valid, y_valid = get_data()

In [63]:
n_data,input_dim = x_train.shape
n_class = y_train.max()+1
hid_dim = 50

In [6]:
class Model(nn.Module):
    def __init__(self, input_dim, hid_dim, output_dim):
        super().__init__()
        self.layers = [nn.Linear(input_dim, hid_dim), nn.ReLU(), nn.Linear(hid_dim, output_dim)]
        
    def __call__(self, x):
        for l in self.layers: x=l(x)
        return x

In [7]:
model = Model(input_dim, hid_dim, 10)

In [8]:
x_train.shape

torch.Size([50000, 784])

In [9]:
pred = model(x_train)

In [40]:
pred

tensor([[ 6.5276e-03,  5.5869e-02, -5.2176e-02,  ...,  1.2981e-01,
          1.0999e-01,  7.7809e-02],
        [ 9.8875e-02,  1.3690e-01,  4.9177e-02,  ...,  9.8473e-02,
          1.6890e-01,  1.8723e-01],
        [ 1.4474e-02,  1.1824e-01,  1.3496e-02,  ..., -2.6546e-05,
          1.6337e-02,  5.1659e-02],
        ...,
        [-1.0649e-01,  4.6565e-02,  9.4658e-02,  ...,  4.5425e-02,
         -1.1333e-02, -5.1588e-02],
        [-4.0015e-02,  1.2427e-01, -7.3041e-02,  ..., -9.8888e-03,
          2.4177e-02,  8.2370e-03],
        [-5.6798e-02,  3.8407e-03, -9.6393e-03,  ...,  1.5783e-02,
          4.8957e-02, -2.3656e-02]], grad_fn=<AddmmBackward>)

## Cross entropy loss

First, we will need to compute the `softmax` of our activations. This is defined by:

$$\hbox{softmax(x)}_{i} = \frac{e^{x_{i}}}{e^{x_{0}} + e^{x_{1}} + \cdots + e^{x_{n-1}}}$$

or more concisely:

$$\hbox{softmax(x)}_{i} = \frac{e^{x_{i}}}{\sum_{0 \leq j \leq n-1} e^{x_{j}}}$$

In practice, we will need the **log of the softmax** when we calculate the loss.

In [41]:
def log_softmax(x): return (x.exp()/x.exp().sum(dim=-1,keepdim=True)).log()

In [42]:
sm_pred = log_softmax(pred)

In [43]:
sm_pred

tensor([[-2.3348, -2.2855, -2.3935,  ..., -2.2115, -2.2313, -2.2635],
        [-2.2881, -2.2501, -2.3378,  ..., -2.2885, -2.2181, -2.1997],
        [-2.2946, -2.1909, -2.2956,  ..., -2.3091, -2.2928, -2.2574],
        ...,
        [-2.3868, -2.2338, -2.1857,  ..., -2.2349, -2.2917, -2.3319],
        [-2.3234, -2.1592, -2.3565,  ..., -2.2933, -2.2592, -2.2752],
        [-2.3319, -2.2713, -2.2848,  ..., -2.2593, -2.2262, -2.2988]],
       grad_fn=<LogBackward>)

In [44]:
def nll(input, target): 
    
    return -input[range(target.shape[0]), target].mean()

In [45]:
loss = nll(sm_pred, y_train)

In [46]:
loss

tensor(2.2962, grad_fn=<NegBackward>)

**Note:** The formula

$$\log \left ( \frac{a}{b} \right ) = \log(a) - \log(b)$$

gives a simplification when we compute the log softmax, which was previously defined as 

```py
(x.exp()/(x.exp().sum(-1,keepdim=True))).log()
```

In [47]:
def log_softmax(x): return x - x.exp().sum(-1,keepdim=True).log()

In [48]:
test_near(nll(log_softmax(pred), y_train), loss)

Then, there is a way to compute the `log` of the sum of exponentials in a more stable way, called the `LogSumExp` trick. The idea is to use the following formula:

$$
\log \left ( \sum_{j=1}^{n} e^{x_{j}} \right ) = \log \left ( e^{a} \sum_{j=1}^{n} e^{x_{j}-a} \right ) = a + \log \left ( \sum_{j=1}^{n} e^{x_{j}-a} \right )
$$

where a is the maximum of the $x_j$

In [50]:
def logsumexp(x):
    m = x.max(1)[0]
    return m + (x-m[:,None]).exp().sum(-1).log()

In [52]:
test_near(logsumexp(pred), pred.logsumexp(1))

In [53]:
def log_softmax(x): return x - x.logsumexp(-1,keepdim=True)

In [54]:
test_near(nll(log_softmax(pred), y_train), loss)

In [55]:
test_near(F.nll_loss(F.log_softmax(pred, -1), y_train), loss)

In PyTorch, `F.log_softmax` and `F.nll_loss` are combined in one optimized function, `F.cross_entropy`.

In [56]:
test_near(F.cross_entropy(pred, y_train), loss)

## Basic training loop
Basically the training loop repeats over the following steps:

- Get the output of the model on a batch of inputs
- Compare the output to the labels we have and compute a loss
- Calculate the gradients of the loss with respect to every parameter of the model
- Update said parameters with those gradients to make them a little bit better

In [57]:
loss_func = F.cross_entropy

In [58]:
#export
def accuracy(out: torch.Tensor, yb:torch.Tensor) -> torch.Tensor:
    return (torch.argmax(out, dim=1)==yb).float().mean()

In [59]:
batch_sz = 64

xb = x_train[:batch_sz]
yb = y_train[:batch_sz]
preds = model(xb)
preds[0], preds.shape

(tensor([ 0.0065,  0.0559, -0.0522,  0.1353, -0.1861,  0.1592, -0.1072,  0.1298,
          0.1100,  0.0778], grad_fn=<SelectBackward>), torch.Size([64, 10]))

In [60]:
loss_func(preds, yb)

tensor(2.2989, grad_fn=<NllLossBackward>)

In [61]:
accuracy(preds, yb)

tensor(0.1250)

In [62]:
lr = 0.5   # learning rate
epochs = 1 # how many epochs to train for

In [72]:
for epoch in range(epochs):
    for i in range((n_data-1)// batch_sz + 1):
#         set_trace()
        start_i = i*batch_sz
        end_i = start_i+batch_sz
        xb = x_train[start_i:end_i]
        # print(xb.shape)
        yb = y_train[start_i:end_i]
        print(model(xb).shape)
        loss = loss_func(model(xb), yb)

        loss.backward()
        with torch.no_grad():
            for l in model.layers:
                if hasattr(l, 'weight'):
                    l.weight -= l.weight.grad * lr
                    l.bias   -= l.bias.grad   * lr
                    l.weight.grad.zero_()
                    l.bias  .grad.zero_()

torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([6

torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([6

In [78]:
for epoch in range(epochs):
    for i in range((n_data-1)// batch_sz + 1):
        start_i = i*batch_sz
        end_i = start_i + batch_sz
        
        xb = x_train[start_i:end_i]
        yb = y_train[start_i:end_i]
        
        preds = model(xb)
        loss = loss_func(preds, yb)
        
        loss.backward()
        with torch.no_grad():
            for l in model.layers:
                if hasattr(l, 'weight'):
                    l.weight -= l.weight.grad * lr
                    l.bias   -= l.bias.grad   * lr
                    l.weight.grad.zero_()
                    l.bias  .grad.zero_()

In [79]:
loss_func(model(xb), yb), accuracy(model(xb), yb)

(tensor(0.0016, grad_fn=<NllLossBackward>), tensor(1.))