<a href="https://colab.research.google.com/github/garycll/blogs/blob/main/sgd_mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Neural network from scratch with Python and PyTorch's gradient calculation

Deep learning libraries such as PyTorch and TensorFlow make it easy for us to build a deep neural network without knowing its implementation details. However, to gain a better understanding of how a neural network works, it is useful to build one from scratch.

This notebook implements stochastic gradient descent (SGD), the optimization algorithm used to train neural networks, in plain Python and applies it to train a neural network. So that the same SGD code works for different network architectures, we use PyTorch to calculate gradients (apart from basic tensor operations, this is the only PyTorch functionality we use). We illustrate this advantage by applying SGD to two network architectures: a linear model and a simple two-layer network.
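A minimal sketch of that division of labour, on a toy one-parameter problem chosen purely for illustration (the objective f(w) = (w - 3)^2 and the learning rate 0.1 are assumptions, not part of this notebook): autograd supplies the gradient through `backward()`, and the update step is written by hand. The update here is wrapped in `torch.no_grad()`, a common alternative to the `p.data` update used later in the notebook.

```python
import torch

# Toy objective f(w) = (w - 3)^2, minimised at w = 3 (illustrative assumption)
w = torch.tensor(0.0, requires_grad=True)
lr = 0.1

for step in range(100):
    loss = (w - 3) ** 2
    loss.backward()          # autograd fills w.grad with df/dw = 2 * (w - 3)
    with torch.no_grad():
        w -= lr * w.grad     # hand-written gradient descent step, not tracked by autograd
    w.grad.zero_()           # gradients accumulate, so reset before the next step

print(round(w.item(), 4))    # close to 3.0
```

Only the gradient comes from PyTorch; swapping in a different `loss` expression would reuse the same update loop unchanged, which is the property the notebook relies on.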

To demonstrate the code in this notebook, we use the well-known [MNIST](http://yann.lecun.com/exdb/mnist/) dataset of handwritten digits.

# Reading the MNIST data

First, we read the MNIST data and split it into training and validation sets.

In [1]:
import torch
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split

In [2]:
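# as_frame=False returns NumPy arrays; newer scikit-learn versions otherwise return pandas objects,
# which torch.from_numpy cannot convert
orig_x, orig_y = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False)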

In [3]:
x = (orig_x/255).astype('float32') # divide by 255 to normalize the pixel values between 0 and 1
y = orig_y.astype('int64') # labels are stored as strings; int64 gives torch.long tensors for indexing

In [4]:
train_x, valid_x, train_y, valid_y = train_test_split(x, y, test_size=0.15, random_state=42)

In [5]:
train_x = torch.from_numpy(train_x)
valid_x = torch.from_numpy(valid_x)
train_y = torch.from_numpy(train_y)
valid_y = torch.from_numpy(valid_y)

In [6]:
print(f'There are {train_x.shape[0]} training samples and {valid_x.shape[0]} validation samples')

There are 59500 training samples and 10500 validation samples


# Neural Network and SGD on MNIST

A key component of a neural network is the loss function. Since MNIST has 10 classes, this is a multi-class classification problem, so we use cross entropy as the loss function.
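For reference, this is the standard definition. With logits $z_i \in \mathbb{R}^{10}$ for sample $i$, true label $y_i$ and a batch of $N$ samples, the mean cross-entropy computed by the functions defined below is:

$$
\mathcal{L} = -\frac{1}{N}\sum_{i=1}^{N} \log \frac{e^{z_{i,y_i}}}{\sum_{j=1}^{10} e^{z_{i,j}}}
= -\frac{1}{N}\sum_{i=1}^{N} \left( z_{i,y_i} - \log \sum_{j=1}^{10} e^{z_{i,j}} \right)
$$

The second form maps onto the code: `logsumexp` computes $\log\sum_j e^{z_{i,j}}$, `log_softmax` subtracts it from every logit, and `nll` picks out entry $y_i$ of each row, averages over the batch and negates.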

We implement cross entropy from scratch instead of using PyTorch's F.cross_entropy. To improve numerical stability, we compute it with the LogSumExp trick. The code is copied from Jeremy Howard's [fast.ai course notebook](https://github.com/fastai/course-v3/blob/master/nbs/dl2/03_minibatch_training.ipynb); see the Cross Entropy Loss section of that notebook for details.
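The trick rests on the identity $\log\sum_j e^{x_j} = m + \log\sum_j e^{x_j - m}$ with $m = \max_j x_j$: after subtracting $m$, the largest exponent is $e^0 = 1$, so nothing overflows. The sketch below uses made-up logits (an assumption for illustration) to contrast the direct formula with the shifted one. The names `naive_logsumexp` and `stable_logsumexp` are hypothetical; the stable version has the same body as the notebook's `logsumexp`, and PyTorch's built-in `torch.logsumexp` is used only as a reference.

```python
import torch

def naive_logsumexp(x):
    # Direct formula: exp() overflows to inf for large logits in float32
    return x.exp().sum(-1).log()

def stable_logsumexp(x):
    # Subtract the row-wise max m so the largest exponent is exp(0) = 1
    m = x.max(-1)[0]
    return m + (x - m[:, None]).exp().sum(-1).log()

x = torch.tensor([[1000.0, 1001.0, 1002.0],
                  [1.0, 2.0, 3.0]])

print(naive_logsumexp(x))           # first entry is inf because exp(1000) overflows
print(stable_logsumexp(x))          # finite for both rows
print(torch.logsumexp(x, dim=-1))   # PyTorch's built-in, for comparison
```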

In [7]:
def logsumexp(x):
    # LogSumExp trick: subtract the row-wise max m before exponentiating to avoid overflow
    m = x.max(-1)[0]
    return m + (x - m[:, None]).exp().sum(-1).log()

In [8]:
def log_softmax(x):
    return x - logsumexp(x).unsqueeze(1)

In [9]:
def nll(predictions, target):
    # pick each sample's log-probability for its true class, average over the batch, and negate
    return -predictions[range(target.shape[0]), target].mean()

In [10]:
def mnist_loss(predictions, targets):
    # equivalent to F.cross_entropy(predictions, targets) in PyTorch
    # return F.cross_entropy(predictions, targets)
    return nll(log_softmax(predictions), targets)
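
A quick sanity check, as a sketch: on a random batch (the seed, batch size and logits are assumptions for illustration), the hand-written loss should agree with PyTorch's F.cross_entropy, which also takes raw logits and integer class labels, and so should the gradients that autograd derives from it. The notebook's functions are repeated so the block runs on its own.

```python
import torch
import torch.nn.functional as F

# Same definitions as in the cells above, repeated so this block is self-contained
def logsumexp(x):
    m = x.max(-1)[0]
    return m + (x - m[:, None]).exp().sum(-1).log()

def log_softmax(x):
    return x - logsumexp(x).unsqueeze(1)

def nll(predictions, target):
    return -predictions[range(target.shape[0]), target].mean()

def mnist_loss(predictions, targets):
    return nll(log_softmax(predictions), targets)

torch.manual_seed(0)
logits = torch.randn(8, 10, requires_grad=True)  # random batch: 8 samples, 10 classes
targets = torch.randint(0, 10, (8,))

ours = mnist_loss(logits, targets)
ref = F.cross_entropy(logits, targets)
print(torch.allclose(ours, ref))  # expected: True

# The gradients with respect to the logits should agree as well
g_ours, = torch.autograd.grad(ours, logits)
g_ref, = torch.autograd.grad(ref, logits)
print(torch.allclose(g_ours, g_ref, atol=1e-6))
```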

Below is the code for the neural network and SGD.

In [11]:
def init_params(size, std=1.0):
    return (torch.randn(size)*std).requires_grad_()

In [12]:
def linear1(xb, params):
    weights, bias = params
    return xb@weights + bias

In [13]:
def batch_accuracy(preds, yb):
    # a prediction is correct when the highest-scoring class matches the label
    correct = (torch.argmax(preds, dim=1) == yb)
    return correct.float().mean()

In [14]:
def validate_epoch(model, params, dataloader):
    accs = [batch_accuracy(model(xb, params), yb) for xb, yb in dataloader]
    return round(torch.stack(accs).mean().item(), 4)

In [15]:
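def train_epoch(model, lr, params, dataloader):
    for xb, yb in dataloader:
        preds = model(xb, params)
        loss = mnist_loss(preds, yb)
        loss.backward()              # PyTorch computes d(loss)/dp for every p in params
        for p in params:
            p.data -= p.grad * lr    # SGD step on .data, so autograd does not record the update
            p.grad.zero_()           # gradients accumulate, so reset them before the next batch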
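
In symbols, each pass through the inner loop of `train_epoch` performs one mini-batch SGD step for every parameter tensor $p$ in `params`:

$$
p \leftarrow p - \eta \, \frac{\partial \mathcal{L}_B}{\partial p}
$$

where $\eta$ is `lr` and $\mathcal{L}_B$ is `mnist_loss` evaluated on the mini-batch $B$ = (`xb`, `yb`). `loss.backward()` stores $\partial \mathcal{L}_B / \partial p$ in `p.grad`, and `p.grad.zero_()` is needed because PyTorch accumulates gradients across `backward()` calls rather than overwriting them.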

In [16]:
def train_model(model, epochs, lr, params, train_dl, valid_dl):
    for i in range(epochs):
        train_epoch(model, lr, params, train_dl)
        print(f'{i+1}: accuracy={validate_epoch(model, params, valid_dl):.4f}')

Below I implement a simple data loader so that we can train the model with mini-batches.

In [17]:
class SimpleDataLoader:
    def __init__(self, x, y, batch_size=256):
        self.x = x
        self.y = y
        self.batch_size = batch_size
    
    def __iter__(self):
        self.idx = 0
        return self
        
    def __next__(self):
        if self.idx >= len(self.y):
            raise StopIteration
        # slicing past the end is safe, so the last batch simply holds the remaining samples
        x = self.x[self.idx : self.idx + self.batch_size]
        y = self.y[self.idx : self.idx + self.batch_size]
        self.idx += self.batch_size
        return (x, y)

In [18]:
train_dl = iter(SimpleDataLoader(train_x, train_y, 256))
valid_dl = iter(SimpleDataLoader(valid_x, valid_y, 256))
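
`SimpleDataLoader` always visits the samples in the same order, whereas in common SGD practice the training set is reshuffled every epoch. A minimal variant that does this might look like the sketch below. `ShuffledDataLoader` is a hypothetical name, not part of the original notebook, and it was not used to produce the accuracies reported here.

```python
import torch

class ShuffledDataLoader:
    """Hypothetical mini-batch loader that reshuffles the data at the start of every epoch."""
    def __init__(self, x, y, batch_size=256, shuffle=True):
        self.x, self.y = x, y
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        n = len(self.y)
        # A fresh random permutation per epoch (or the original order if shuffle=False)
        order = torch.randperm(n) if self.shuffle else torch.arange(n)
        for start in range(0, n, self.batch_size):
            idx = order[start : start + self.batch_size]
            yield self.x[idx], self.y[idx]

# Usage on dummy data: 10 samples in batches of 4 gives batch sizes 4, 4, 2
x = torch.arange(10).float().unsqueeze(1)
y = torch.arange(10)
dl = ShuffledDataLoader(x, y, batch_size=4)
print([len(yb) for _, yb in dl])  # [4, 4, 2]
```

Because `__iter__` is a generator function, every `for` loop over the loader starts a fresh, reshuffled pass, so an instance could presumably be passed to `train_model` in place of `train_dl`.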

# A linear model using SGD

In [19]:
lr = 1
weights = init_params((28*28, 10))
bias = init_params(10)
params = weights, bias

In [20]:
train_model(linear1, 20, lr, params, train_dl, valid_dl)

1: accuracy=0.8483
2: accuracy=0.8742
3: accuracy=0.8825
4: accuracy=0.8878
5: accuracy=0.8906
6: accuracy=0.8942
7: accuracy=0.8964
8: accuracy=0.8986
9: accuracy=0.8996
10: accuracy=0.9011
11: accuracy=0.9019
12: accuracy=0.9028
13: accuracy=0.9036
14: accuracy=0.9045
15: accuracy=0.9053
16: accuracy=0.9058
17: accuracy=0.9065
18: accuracy=0.9081
19: accuracy=0.9089
20: accuracy=0.9094


# A simple 2-layer network using SGD

In [21]:
def simple_net(xb, params):
    w1, b1, w2, b2 = params
    res = xb@w1 + b1                   # first linear layer
    res = res.max(torch.tensor(0.0))   # ReLU: element-wise max(res, 0)
    res = res@w2 + b2                  # second linear layer
    return res
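
The line `res.max(torch.tensor(0.0))` is an element-wise ReLU, max(x, 0), obtained by broadcasting a scalar tensor. A quick check on random input (an assumption for illustration) that it matches PyTorch's built-in `torch.relu` and `clamp(min=0.0)`:

```python
import torch

torch.manual_seed(0)
res = torch.randn(4, 5)
relu_via_max = res.max(torch.tensor(0.0))  # element-wise max against a broadcast scalar tensor

print(torch.equal(relu_via_max, torch.relu(res)))
print(torch.equal(relu_via_max, res.clamp(min=0.0)))
```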

In [22]:
lr = 1
w1 = init_params((28*28, 256))
b1 = init_params(256)
w2 = init_params((256, 10))
b2 = init_params(10)
params = w1, b1, w2, b2
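
For scale, a quick arithmetic check (not part of the original notebook) of the parameter counts implied by the shapes passed to `init_params`: the linear model has 7,850 parameters, while the two-layer network has 203,530, roughly 26 times as many.

```python
from math import prod

# Shapes copied from the two init_params cells
linear_shapes = [(28*28, 10), (10,)]                     # weights, bias of linear1
net_shapes = [(28*28, 256), (256,), (256, 10), (10,)]    # w1, b1, w2, b2 of simple_net

linear_total = sum(prod(s) for s in linear_shapes)
net_total = sum(prod(s) for s in net_shapes)
print(linear_total, net_total)  # 7850 203530
```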

In [23]:
train_model(simple_net, 20, lr, params, train_dl, valid_dl)

1: accuracy=0.8092
2: accuracy=0.8690
3: accuracy=0.8871
4: accuracy=0.8980
5: accuracy=0.9043
6: accuracy=0.9091
7: accuracy=0.9122
8: accuracy=0.9156
9: accuracy=0.9192
10: accuracy=0.9217
11: accuracy=0.9244
12: accuracy=0.9276
13: accuracy=0.9289
14: accuracy=0.9310
15: accuracy=0.9329
16: accuracy=0.9355
17: accuracy=0.9356
18: accuracy=0.9373
19: accuracy=0.9370
20: accuracy=0.9373


References:
*   [Neural Network From Scratch with NumPy and MNIST](https://mlfromscratch.com/neural-network-tutorial/#/)
*   [The fastai book chapter 4](https://github.com/fastai/fastbook/blob/master/04_mnist_basics.ipynb)