# Checkpointing

Your task is to implement checkpointing for an MLP using NumPy.

You are free to use the MLP and backpropagation implementation that you developed during the lab sessions.

The key takeaway from this task is that checkpointing lets us trade extra computation in the forward pass for a smaller memory footprint in the backward pass, and that memory is often a major bottleneck when training large networks. In plain English, we accept a somewhat longer training time in exchange for saving some of our GPU's precious memory.

## What is checkpointing?

The aim of checkpointing is to store the forward result (activation) of only every $n$-th layer (e.g. every 2nd layer), instead of every layer's result as in plain backpropagation, and to use these checkpoints to recompute the missing activations during the backward pass. Checkpointed activations are kept in memory after the forward pass, while each of the remaining activations is recomputed at most once. After being recomputed, a non-checkpoint activation is kept in memory only until it is no longer required.
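
As an illustration, the hypothetical helper `checkpoint_layout` below lists which layers are kept as checkpoints and which layers have to be recomputed from each checkpoint. It assumes the same convention as `NetworkWithCheckpointing` further down: layer 0 is the input, and layer `l` is a checkpoint when `l % n == 0`.

```python
def checkpoint_layout(num_layers, step):
    """Return (checkpoints, segments, recomputed) for layers 0..num_layers-1.

    checkpoints: indices of layers whose activations are kept after the forward pass
    segments:    for each checkpoint, the non-checkpoint layers recomputed from it
    recomputed:  total number of layer activations recomputed per backward pass
    """
    checkpoints = [l for l in range(num_layers) if l % step == 0]
    segments = {
        c: list(range(c + 1, min(c + step, num_layers)))
        for c in checkpoints
    }
    # Drop empty segments (e.g. with step == 1 every layer is a checkpoint).
    segments = {c: s for c, s in segments.items() if s}
    recomputed = sum(len(s) for s in segments.values())
    return checkpoints, segments, recomputed


# A network with sizes [784, 100, 100, 10] has layers 0..3.
print(checkpoint_layout(4, 2))  # ([0, 2], {0: [1], 2: [3]}, 2)
```

With `step=1` every layer is a checkpoint and nothing is recomputed, which amounts to plain backpropagation; larger steps store fewer layers but recompute more of them.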

# What should be done

1. Take the implementation of an MLP trained with backpropagation and analyze how much memory the algorithm uses as a function of the number of hidden layers.

2. Implement a class `NetworkWithCheckpointing` that inherits from the `Network` class defined during the lab sessions by:
    a) implementing a method `forward_between_checkpoints` that recomputes the forward pass of the network starting from one of the checkpointed layers;
    b) overriding the method `backprop` so that it keeps only the checkpointed layers and computes every other activation with `forward_between_checkpoints`, keeping it in memory only until it is no longer needed.

3. Train your network with checkpointing on MNIST. Compare running times and memory usage with those of the network without checkpointing.
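
For item 1, a rough accounting of what `Network.backprop` below keeps alive at its peak (while it builds the weight gradients, just before returning) could look like the following. `plain_backprop_peak_floats` is a hypothetical helper; it counts only the stored lists `gs`, `fs` and `dLdfs` plus the parameters, and ignores temporaries such as the last `dLdg`.

```python
def plain_backprop_peak_floats(sizes, batch_size):
    """Floats held by Network.backprop at its peak for a mini-batch of batch_size."""
    gs = batch_size * sum(sizes)         # input + one activation per layer
    fs = batch_size * sum(sizes[1:])     # pre-activations, one per layer
    dLdfs = batch_size * sum(sizes[1:])  # dL/df, one per layer
    activations = gs + fs + dLdfs
    # Weights and biases (and their gradients) do not depend on the batch size.
    params = sum(x * y + y for x, y in zip(sizes[:-1], sizes[1:]))
    return activations, params


print(plain_backprop_peak_floats([784, 100, 100, 10], 100))  # (141400, 89610)
```

Each extra hidden layer of width `k` adds `3 * batch_size * k` activation floats, so the activation memory of plain backpropagation grows linearly with the number of hidden layers.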


# Implement Checkpointing for a MLP

In [1]:
import numpy as np


# Download the MNIST dataset (the leading `!` runs a shell command in Jupyter)
!wget -O mnist.npz https://s3.amazonaws.com/img-datasets/mnist.npz

def load_mnist(path='mnist.npz'):
    with np.load(path) as f:
        x_train, _y_train = f['x_train'], f['y_train']
        x_test, _y_test = f['x_test'], f['y_test']

    # Flatten each 28x28 image into a 784-dim vector and scale pixels to [0, 1]
    x_train = x_train.reshape(-1, 28 * 28) / 255.
    x_test = x_test.reshape(-1, 28 * 28) / 255.

    # One-hot encode the digit labels (10 classes)
    y_train = np.zeros((_y_train.shape[0], 10))
    y_train[np.arange(_y_train.shape[0]), _y_train] = 1

    y_test = np.zeros((_y_test.shape[0], 10))
    y_test[np.arange(_y_test.shape[0]), _y_test] = 1

    return (x_train, y_train), (x_test, y_test)

(x_train, y_train), (x_test, y_test) = load_mnist()

--2022-11-14 23:38:34--  https://s3.amazonaws.com/img-datasets/mnist.npz
Loaded CA certificate '/etc/ssl/certs/ca-certificates.crt'
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.110.118
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.217.110.118|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 11490434 (11M) [application/octet-stream]
Saving to: ‘mnist.npz’


2022-11-14 23:38:36 (5.10 MB/s) - ‘mnist.npz’ saved [11490434/11490434]



In [2]:
def sigmoid(z):
    return 1.0/(1.0+np.exp(-z))

class Network(object):
    
    def __init__(self, sizes):
        # Initialize biases and weights from a standard normal distribution.
        # Weights are indexed by target node first: weights[i] has shape (sizes[i+1], sizes[i]).
        self.num_layers = len(sizes)
        self.sizes = sizes
        self.biases = [np.random.randn(y, 1) for y in sizes[1:]]
        self.weights = [np.random.randn(y, x) 
                        for x, y in zip(sizes[:-1], sizes[1:])]

    def feedforward(self, a):
        # Run the network on a batch with one sample per row; returns shape (sizes[-1], batch_size)
        a = a.T
        for b, w in zip(self.biases, self.weights):
            a = sigmoid(np.matmul(w, a)+b)
        return a
    
    def update_mini_batch(self, mini_batch, eta):
        # Update the network's weights and biases by applying a single step
        # of gradient descent, using backpropagation to compute the gradient.
        # mini_batch is a pair (x, y) of arrays with one sample per row;
        # eta is the learning rate.
        nabla_b, nabla_w = self.backprop(mini_batch[0].T,mini_batch[1].T)
            
        self.weights = [w-(eta/len(mini_batch[0]))*nw 
                        for w, nw in zip(self.weights, nabla_w)]
        self.biases = [b-(eta/len(mini_batch[0]))*nb 
                       for b, nb in zip(self.biases, nabla_b)]

    def backprop(self, x, y):
        # x and y hold one sample per column. Returns a pair of lists: the first
        # contains the gradients w.r.t. the biases, the second w.r.t. the weights,
        # both summed over the mini-batch.
        g = x
        gs = [g] # list to store all the gs, layer by layer
        fs = [] # list to store all the fs, layer by layer
        for b, w in zip(self.biases, self.weights):
            f = np.dot(w, g)+b
            fs.append(f)
            g = sigmoid(f)
            gs.append(g)
        # backward pass <- both steps at once
        dLdg = self.cost_derivative(gs[-1], y)
        dLdfs = []
        for w,g in reversed(list(zip(self.weights,gs[1:]))):
            dLdf = np.multiply(dLdg,np.multiply(g,1-g))
            dLdfs.append(dLdf)
            dLdg = np.matmul(w.T, dLdf)
        
        dLdWs = [np.matmul(dLdf,g.T) for dLdf,g in zip(reversed(dLdfs),gs[:-1])] 
        dLdBs = [np.sum(dLdf,axis=1).reshape(dLdf.shape[0],1) for dLdf in reversed(dLdfs)] 
        return (dLdBs,dLdWs)

    def evaluate(self, test_data):
        # Count the number of correct answers for test_data
        pred = np.argmax(self.feedforward(test_data[0]),axis=0)
        corr = np.argmax(test_data[1],axis=1).T
        return np.mean(pred==corr)
    
    def cost_derivative(self, output_activations, y):
        return (output_activations-y) 
    
    def SGD(self, training_data, epochs, mini_batch_size, eta, test_data=None):
        x_train, y_train = training_data
        if test_data:
            x_test, y_test = test_data
        for j in range(epochs):
            for i in range(x_train.shape[0] // mini_batch_size):
                x_mini_batch = x_train[(mini_batch_size*i):(mini_batch_size*(i+1))]
                y_mini_batch = y_train[(mini_batch_size*i):(mini_batch_size*(i+1))]
                self.update_mini_batch((x_mini_batch, y_mini_batch), eta)
            if test_data:
                print("Epoch: {0}, Accuracy: {1}".format(j, self.evaluate((x_test, y_test))))
            else:
                print("Epoch: {0}".format(j))

In [3]:
class NetworkWithCheckpointing(Network):

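    def __init__(self, sizes, checkpoint_every_nth_layer: int = 0, *args, **kwargs):
        # checkpoint_every_nth_layer = n keeps the activations of layers 0, n, 2n, ...
        # (layer 0 is the input); n = 0 disables checkpointing.
        super().__init__(sizes, *args, **kwargs)
        assert checkpoint_every_nth_layer >= 0
        self.step = checkpoint_every_nth_layer
        self.checkpoints = []          # activations of layers l with l % step == 0
        self.cached_activations = {}   # recomputed activations, keyed by layer index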

    def forward_between_checkpoints(self, checkpoint_idx_start):
        # Recompute the activations of the layers that follow the checkpoint at
        # layer checkpoint_idx_start, up to (but excluding) the next checkpoint,
        # and cache them until the backward pass consumes them.
        assert self.step > 0
        assert checkpoint_idx_start % self.step == 0
        g = self.checkpoints[checkpoint_idx_start // self.step]
        for l in range(checkpoint_idx_start+1, min(checkpoint_idx_start + self.step, self.num_layers)):
            g = sigmoid(np.dot(self.weights[l-1], g)+self.biases[l-1])
            self.cached_activations[l] = g

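    def get_activation(self, layer, pop=True):
        # Return the activation of `layer`: a checkpoint, a cached recomputed value
        # (removed from the cache when pop=True), or recompute its segment first.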
        if layer % self.step == 0:
            return self.checkpoints[layer // self.step]
        elif layer in self.cached_activations:
            res = self.cached_activations[layer]
            if pop:
                del self.cached_activations[layer]
            return res
        else:
            self.forward_between_checkpoints(layer - (layer % self.step))
            return self.get_activation(layer, pop)

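    def backprop(self, x, y):
        # Without checkpointing, fall back to plain backpropagation.
        if self.step == 0:
            return super().backprop(x, y)

        # Forward pass: store only the activations of every step-th layer.
        self.checkpoints = [x]
        self.cached_activations = {}
        g = x
        for l, b, w in zip(range(1, self.num_layers), self.biases, self.weights):
            g = sigmoid(np.dot(w, g)+b)
            if l % self.step == 0:
                self.checkpoints.append(g)

        # Backward pass, from the output layer towards the input. The activation of
        # layer l-1 is needed both for dL/dW of layer l and for the sigmoid derivative
        # at the next step, so each activation is fetched exactly once and every
        # segment between checkpoints is recomputed at most once. Gradients are
        # written layer by layer instead of first collecting every dL/df in a list,
        # which would again need memory linear in the number of layers.
        dLdBs = [None] * (self.num_layers - 1)
        dLdWs = [None] * (self.num_layers - 1)
        g = self.get_activation(self.num_layers-1)
        dLdg = self.cost_derivative(g, y)
        for l in range(self.num_layers-1, 0, -1):
            dLdf = np.multiply(dLdg, np.multiply(g, 1-g))
            g_prev = self.get_activation(l-1)
            dLdWs[l-1] = np.matmul(dLdf, g_prev.T)
            dLdBs[l-1] = np.sum(dLdf, axis=1).reshape(dLdf.shape[0], 1)
            dLdg = np.matmul(self.weights[l-1].T, dLdf)
            g = g_prev

        # Release the checkpoints once the gradients are computed.
        self.checkpoints = []
        return (dLdBs, dLdWs)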
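
To sanity-check the approach, the following self-contained sketch rewrites both backward passes as plain functions with the same sigmoid activations and quadratic cost as above. The names `plain_backprop`, `checkpointed_backprop` and `peak_traced_bytes` are hypothetical and not part of the lab code, and `plain_backprop` is leaner than `Network.backprop` (it keeps neither `fs` nor `dLdfs`), so the comparison favours the baseline. The sketch checks that checkpointing leaves the gradients unchanged and compares peak memory with Python's `tracemalloc`, relying on NumPy registering its array buffers with `tracemalloc`. The byte counts depend on the platform, so only their ordering is meaningful.

```python
import tracemalloc

import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def plain_backprop(weights, biases, x, y):
    """Keep every activation of the forward pass (samples are columns of x)."""
    gs = [x]
    for w, b in zip(weights, biases):
        gs.append(sigmoid(w @ gs[-1] + b))
    nabla_b, nabla_w = [None] * len(weights), [None] * len(weights)
    dLdg = gs[-1] - y  # derivative of the quadratic cost
    for l in range(len(weights), 0, -1):
        dLdf = dLdg * gs[l] * (1 - gs[l])
        nabla_w[l - 1] = dLdf @ gs[l - 1].T
        nabla_b[l - 1] = dLdf.sum(axis=1, keepdims=True)
        dLdg = weights[l - 1].T @ dLdf
    return nabla_b, nabla_w


def checkpointed_backprop(weights, biases, x, y, step):
    """Keep only layers l with l % step == 0; recompute each segment at most once."""
    num_layers = len(weights) + 1  # layer 0 is the input
    checkpoints, cache = {0: x}, {}
    g = x
    for l in range(1, num_layers):
        g = sigmoid(weights[l - 1] @ g + biases[l - 1])
        if l % step == 0:
            checkpoints[l] = g

    def get(l):
        if l in checkpoints:
            return checkpoints[l]
        if l not in cache:  # recompute the segment that starts at the previous checkpoint
            start = l - l % step
            a = checkpoints[start]
            for j in range(start + 1, min(start + step, num_layers)):
                a = sigmoid(weights[j - 1] @ a + biases[j - 1])
                cache[j] = a
        return cache.pop(l)  # free the cached copy as soon as it is consumed

    nabla_b, nabla_w = [None] * len(weights), [None] * len(weights)
    g = get(num_layers - 1)
    dLdg = g - y
    for l in range(num_layers - 1, 0, -1):
        dLdf = dLdg * g * (1 - g)
        g_prev = get(l - 1)
        nabla_w[l - 1] = dLdf @ g_prev.T
        nabla_b[l - 1] = dLdf.sum(axis=1, keepdims=True)
        dLdg = weights[l - 1].T @ dLdf
        g = g_prev
    return nabla_b, nabla_w


def peak_traced_bytes(fn, *args):
    """Peak memory traced by tracemalloc while fn(*args) runs."""
    tracemalloc.start()
    try:
        fn(*args)
        return tracemalloc.get_traced_memory()[1]  # (current, peak) -> peak
    finally:
        tracemalloc.stop()


rng = np.random.default_rng(0)
sizes = [64] * 33  # 32 weight layers: deep and narrow so activations dominate
weights = [rng.standard_normal((o, i)) * 0.1 for i, o in zip(sizes[:-1], sizes[1:])]
biases = [rng.standard_normal((o, 1)) * 0.1 for o in sizes[1:]]
x = rng.random((sizes[0], 512))
y = rng.random((sizes[-1], 512))

nb0, nw0 = plain_backprop(weights, biases, x, y)
nb1, nw1 = checkpointed_backprop(weights, biases, x, y, 6)
print(all(np.allclose(a, b) for a, b in zip(nw0 + nb0, nw1 + nb1)))  # True
print("plain peak bytes:       ", peak_traced_bytes(plain_backprop, weights, biases, x, y))
print("checkpointed peak bytes:", peak_traced_bytes(checkpointed_backprop, weights, biases, x, y, 6))
```

The step of 6 is close to `sqrt(32)`; the depth, width and batch size are arbitrary choices made so that the savings are clearly visible next to the gradient arrays, which both versions allocate.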

In [4]:
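sizes = [784,100,100,10]   # layers 0 (input) to 3 (output)

network0 = Network(sizes)                                                 # plain backpropagation
network1 = NetworkWithCheckpointing(sizes, checkpoint_every_nth_layer=1)  # every layer is a checkpoint
network2 = NetworkWithCheckpointing(sizes, checkpoint_every_nth_layer=2)  # checkpoints: layers 0 and 2
network4 = NetworkWithCheckpointing(sizes, checkpoint_every_nth_layer=4)  # checkpoint: the input only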

%time network0.SGD((x_train, y_train), epochs=5, mini_batch_size=100, eta=3.0, test_data=(x_test, y_test))
%time network1.SGD((x_train, y_train), epochs=5, mini_batch_size=100, eta=3.0, test_data=(x_test, y_test))
%time network2.SGD((x_train, y_train), epochs=5, mini_batch_size=100, eta=3.0, test_data=(x_test, y_test))
%time network4.SGD((x_train, y_train), epochs=5, mini_batch_size=100, eta=3.0, test_data=(x_test, y_test))

Epoch: 0, Accuracy: 0.5527
Epoch: 1, Accuracy: 0.592
Epoch: 2, Accuracy: 0.6714
Epoch: 3, Accuracy: 0.7347
Epoch: 4, Accuracy: 0.7471
CPU times: user 2min 29s, sys: 2min 2s, total: 4min 31s
Wall time: 26 s
Epoch: 0, Accuracy: 0.6938
Epoch: 1, Accuracy: 0.7314
Epoch: 2, Accuracy: 0.7483
Epoch: 3, Accuracy: 0.8177
Epoch: 4, Accuracy: 0.8273
CPU times: user 2min 49s, sys: 2min 16s, total: 5min 5s
Wall time: 30 s
Epoch: 0, Accuracy: 0.525
Epoch: 1, Accuracy: 0.5465
Epoch: 2, Accuracy: 0.5706
Epoch: 3, Accuracy: 0.6488
Epoch: 4, Accuracy: 0.7341
CPU times: user 4min 35s, sys: 3min 51s, total: 8min 27s
Wall time: 50.7 s
Epoch: 0, Accuracy: 0.1028
Epoch: 1, Accuracy: 0.1731
Epoch: 2, Accuracy: 0.1652
Epoch: 3, Accuracy: 0.2076
Epoch: 4, Accuracy: 0.2324
CPU times: user 5min 40s, sys: 4min 55s, total: 10min 36s
Wall time: 1min 3s


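In the run above, training time grows with `checkpoint_every_nth_layer`: the wall time goes from 26 s without checkpointing to 30 s with `n = 1`, 50.7 s with `n = 2` and 1 min 3 s with `n = 4`, i.e. up to roughly 2.4 times slower. The four networks start from different random initializations, so the differences in accuracy between them are not caused by checkpointing, which changes only what is kept in memory and not the computed gradients.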

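When it comes to memory usage, checkpointing decreases the peak memory usage. For a network with `l` layers of `k` neurons each, the peak memory needed for the activations during backpropagation (per training example; multiply by the mini-batch size) is:
- with checkpointing every `n`-th layer: `O(kl/n + k(n-1))` (the checkpointed layers plus a full cache of the at most `n-1` recomputed activations between two checkpoints);
- without checkpointing: `O(kl)`.

The weights and their gradients take `O(k^2 l)` memory in both cases, so checkpointing reduces only the activation part.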

We can deduce that the optimal strategy for memory usage is checkpointing with a step of roughly `sqrt(l)`: minimizing `l/n + n` over `n` gives `-l/n^2 + 1 = 0`, i.e. `n = sqrt(l)`. The peak activation memory is then of the order of `O(sqrt(l)k)`.
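
A quick numeric check of this argument, using the estimate `l/n + (n-1)` from the list above in units of `k` (`peak_activation_units` and `best_step` are hypothetical helper names):

```python
import math


def peak_activation_units(l, n):
    # Estimate from the text, in units of k: about l/n checkpoints
    # plus a cache of at most n - 1 recomputed activations.
    return l / n + (n - 1)


def best_step(l):
    # Brute-force the step n in 1..l that minimizes the estimate.
    return min(range(1, l + 1), key=lambda n: peak_activation_units(l, n))


for l in (16, 100):
    print(l, best_step(l), math.isqrt(l))  # prints "16 4 4", then "100 10 10"
```

Because the estimate is convex in `n`, its integer minimum lies at `sqrt(l)` when `l` is a perfect square and otherwise at one of the two integers next to it.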