In [1]:
from torch import torch
from torchvision import datasets
import torch.nn as nn

import matplotlib.pyplot as plt


In [17]:

train_data = datasets.MNIST("../", train=True, download=True)
test_data = datasets.MNIST("../", train=False, download=True)

x_train = train_data.data.reshape(-1, 784).float() / 255
x_test = test_data.data.reshape(-1, 784).float() / 255
y_train = train_data.targets
y_test = test_data.targets


In [18]:
weights = nn.init.xavier_uniform_(torch.empty(784, 10, requires_grad=True))
bias = torch.zeros(10, requires_grad=True)
    

In [19]:
def log_softmax(x):
    """
    softmax(x) = x.exp() / x.exp().sum(-1).unsqueeze(-1)
    """
    return x - x.exp().sum(-1).log().unsqueeze(-1)

def model(x):
    return log_softmax(x @ weights + bias)


In [20]:
def nll(input, target):
    """
    negative log likelihood
    select the prediction for the target, and then take the mean
    use negative because we want to minimize the loss
    """
    return -input[range(target.shape[0]), target].mean()

loss_func = nll

In [21]:
def accuracy(input, target):
    preds = torch.argmax(input, dim=1)
    return (preds == target).float().mean()

In [22]:
batch_size = 64
x_batch = x_train[0:batch_size]
preds = model(x_batch)
y_batch = y_train[0:batch_size]
loss = loss_func(preds, y_batch)
print(loss)
print(accuracy(preds, y_batch))

tensor(2.3988, grad_fn=<NegBackward0>)
tensor(0.0781)


In [23]:
lr = 0.5
epochs = 2

for epoch in range(epochs):
    # run n // batch_size times if n == k * batch_size
    # run n // batch_size + 1 times if n != k * batch_size
    n = x_train.shape[0]
    for i in range((n - 1) // batch_size + 1):
        index = slice(i * batch_size, (i + 1) * batch_size)
        x_batch = x_train[index]
        y_batch = y_train[index]
        preds = model(x_batch)
        loss = loss_func(preds, y_batch)
        print(loss)
        print(accuracy(preds, y_batch))

        loss.backward()
        with torch.no_grad():
                weights -= lr * weights.grad
                bias -= lr * bias.grad
                # gradient will be accumulated, so we need to zero it
                weights.grad.zero_()
                bias.grad.zero_()

tensor(2.3988, grad_fn=<NegBackward0>)
tensor(0.0781)
tensor(1.7880, grad_fn=<NegBackward0>)
tensor(0.3750)
tensor(1.7949, grad_fn=<NegBackward0>)
tensor(0.2969)
tensor(1.5901, grad_fn=<NegBackward0>)
tensor(0.4531)
tensor(1.2725, grad_fn=<NegBackward0>)
tensor(0.6094)
tensor(0.9806, grad_fn=<NegBackward0>)
tensor(0.7812)
tensor(0.8826, grad_fn=<NegBackward0>)
tensor(0.7969)
tensor(1.0609, grad_fn=<NegBackward0>)
tensor(0.7031)
tensor(0.9911, grad_fn=<NegBackward0>)
tensor(0.7188)
tensor(1.1831, grad_fn=<NegBackward0>)
tensor(0.6406)
tensor(0.9230, grad_fn=<NegBackward0>)
tensor(0.7031)
tensor(0.9204, grad_fn=<NegBackward0>)
tensor(0.7031)
tensor(0.7825, grad_fn=<NegBackward0>)
tensor(0.7188)
tensor(0.9434, grad_fn=<NegBackward0>)
tensor(0.6562)
tensor(1.0560, grad_fn=<NegBackward0>)
tensor(0.6875)
tensor(0.7380, grad_fn=<NegBackward0>)
tensor(0.8438)
tensor(1.0216, grad_fn=<NegBackward0>)
tensor(0.6875)
tensor(0.9540, grad_fn=<NegBackward0>)
tensor(0.7188)
tensor(0.5185, grad_fn=<NegB

In [24]:
print(loss_func(model(x_batch), y_batch), accuracy(model(x_batch), y_batch))

tensor(0.0512, grad_fn=<NegBackward0>) tensor(1.)
