# Checkpointing

Your task is to implement checkpointing for an MLP using NumPy.

You are free to use the implementation of an MLP and the backpropagation algorithm that you have developed during lab sessions.

The key takeaway from this task is that checkpointing trades extra computation for lower memory use: part of the forward pass is recomputed during the backward pass, so fewer activations have to be kept in memory. Activation memory is often a major bottleneck when training large networks. In plain English, we slightly increase the time required to train our network in order to save some of our GPU's precious memory.

## What is checkpointing?

The aim of checkpointing is to store only every $n$-th layer's forward result (e.g. every 2nd layer's), instead of every layer's forward result as in plain backpropagation, and to use these checkpoints to recompute the missing parts of the forward pass during the backward pass. Checkpointed activations are kept in memory after the forward pass, while the remaining activations are recomputed at most once. Once recomputed, the non-checkpointed activations are kept in memory only until they are no longer required.
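
The scheme described above can be sketched on a toy network. This is a minimal illustration, not the lab's `Network` class: the helper names `forward_with_checkpoints` and `recompute_segment`, the layer sizes, and the use of a sigmoid in every layer are assumptions made for the example.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def forward_with_checkpoints(x, weights, biases, n):
    """Run a sigmoid MLP forward, keeping only every n-th activation."""
    checkpoints = {0: x}  # activation index -> stored activation
    a = x
    for l, (w, b) in enumerate(zip(weights, biases), start=1):
        a = sigmoid(w @ a + b)
        if l % n == 0:
            checkpoints[l] = a
    return a, checkpoints

def recompute_segment(checkpoints, weights, biases, start, stop):
    """Rebuild activations start..stop from the checkpoint stored at `start`."""
    a = checkpoints[start]
    acts = [a]
    for l in range(start, stop):
        a = sigmoid(weights[l] @ a + biases[l])
        acts.append(a)
    return acts

# Hypothetical toy setup: 5 weight layers, batch of 2 examples (one per column)
rng = np.random.default_rng(0)
sizes = [4, 5, 5, 5, 5, 3]
weights = [rng.standard_normal((y, x)) for x, y in zip(sizes[:-1], sizes[1:])]
biases = [rng.standard_normal((y, 1)) for y in sizes[1:]]
x = rng.standard_normal((4, 2))

out, checkpoints = forward_with_checkpoints(x, weights, biases, n=2)
stored = sorted(checkpoints)  # indices of the activations kept in memory
segment = recompute_segment(checkpoints, weights, biases, start=2, stop=4)
```

Here only activations 0, 2 and 4 survive the forward pass; activation 3 is rebuilt from checkpoint 2 when the backward pass reaches that segment, and can be discarded again afterwards.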

# What should be done

1. Take the implementation of an MLP trained with backpropagation. Analyze how the memory used by the algorithm depends on the number of hidden layers.

2. Implement a class `NetworkWithCheckpointing` that inherits from the `Network` class defined during lab sessions by:
    a) implementing a method `forward_between_checkpoints` that recomputes the forward pass of the network starting from one of the checkpointed layers
    b) overriding the method `backprop` so that it stores only the checkpointed layers, recomputes the other activations with `forward_between_checkpoints`, and keeps them in memory only until they are no longer needed.

3. Train your network with checkpointing on MNIST. Compare its running time and memory usage with those of the network without checkpointing.
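
For the comparison in step 3, one possible way to collect both numbers is the standard-library `tracemalloc` module together with `time.perf_counter`. The helper `measure` below is a hypothetical name, and the `np.ones` call is only a stand-in workload; to my understanding recent NumPy versions report their array buffers to `tracemalloc`, but treat the reported peak as an estimate rather than an exact figure.

```python
import time
import tracemalloc

import numpy as np

def measure(fn, *args, **kwargs):
    """Return (result, elapsed_seconds, peak_bytes) for a single call of fn."""
    tracemalloc.start()
    t0 = time.perf_counter()
    try:
        result = fn(*args, **kwargs)
        elapsed = time.perf_counter() - t0
        # get_traced_memory() returns (current, peak) in bytes
        _, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return result, elapsed, peak

# Stand-in workload: allocate a 1000 x 1000 float64 array
result, seconds, peak = measure(np.ones, (1000, 1000))
print(f"{seconds:.4f} s, peak {peak / 1e6:.1f} MB")
```

With the classes from this notebook, the same helper could wrap a single `backprop` call on one mini-batch for each of the two networks, which isolates the activation memory better than timing a whole training run.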


# Implement Checkpointing for a MLP

In [7]:
import random
import numpy as np
import time

In [8]:
!wget -O mnist.npz https://s3.amazonaws.com/img-datasets/mnist.npz

--2022-11-19 22:06:05--  https://s3.amazonaws.com/img-datasets/mnist.npz
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.162.64, 52.217.89.70, 52.217.165.24, ...
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.217.162.64|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 11490434 (11M) [application/octet-stream]
Saving to: 'mnist.npz'


2022-11-19 22:06:07 (8.80 MB/s) - 'mnist.npz' saved [11490434/11490434]



In [9]:
# Let's read the mnist dataset

def load_mnist(path='mnist.npz'):
    with np.load(path) as f:
        x_train, _y_train = f['x_train'], f['y_train']
        x_test, _y_test = f['x_test'], f['y_test']

    x_train = x_train.reshape(-1, 28 * 28) / 255.
    x_test = x_test.reshape(-1, 28 * 28) / 255.

    y_train = np.zeros((_y_train.shape[0], 10))
    y_train[np.arange(_y_train.shape[0]), _y_train] = 1

    y_test = np.zeros((_y_test.shape[0], 10))
    y_test[np.arange(_y_test.shape[0]), _y_test] = 1

    return (x_train, y_train), (x_test, y_test)

(x_train, y_train), (x_test, y_test) = load_mnist()

In [10]:
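def softmax(z):
    # Numerically stable softmax over axis 0 (one column per example):
    # subtracting each column's maximum leaves the result unchanged
    exps = np.exp(z - np.max(z, axis=0, keepdims=True))
    return exps / np.sum(exps, axis=0)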

def sigmoid(z):
    return 1.0/(1.0+np.exp(-z))

def sigmoid_prime(z):
    # Derivative of the sigmoid
    return sigmoid(z)*(1-sigmoid(z))

In [11]:
class Network(object):
    def __init__(self, sizes):
        # initialize biases and weights with random normal distr.
        # weights are indexed by target node first
        self.num_layers = len(sizes)
        self.sizes = sizes
        self.biases = [np.random.randn(y, 1) for y in sizes[1:]]
        self.weights = [np.random.randn(y, x) for x, y in zip(sizes[:-1], sizes[1:])]
        
    def predict(self, a):
        # Run the network on a batch
        a = a.T
        for l in range(self.num_layers - 1):
            b = self.biases[l]
            w = self.weights[l]
            if (l != self.num_layers - 2):
                a = sigmoid(np.matmul(w, a) + b)
            else:
                a = softmax(np.matmul(w, a) + b)
        return a
    
    def update_mini_batch(self, mini_batch, eta):
        # Update the network's weights and biases by applying a single step
        # of gradient descent, using backpropagation to compute the gradient.
        # mini_batch is a pair (x, y) of arrays with one example per row;
        # eta is the learning rate.
        nabla_b, nabla_w = self.backprop(mini_batch[0].T,mini_batch[1].T)
            
        self.weights = [w-(eta/len(mini_batch[0]))*nw 
                        for w, nw in zip(self.weights, nabla_w)]
        self.biases = [b-(eta/len(mini_batch[0]))*nb 
                       for b, nb in zip(self.biases, nabla_b)]
        
    def backprop(self, x, y):
        # For a mini-batch (x, y) with one example per column, return a pair of lists.
        # The first contains gradients over biases, the second over weights.
        delta_nabla_b = [np.zeros_like(p) for p in self.biases]
        delta_nabla_w = [np.zeros_like(p) for p in self.weights]
        
        z_list = []
        a = x
        a_list = [a]
        for l in range(self.num_layers - 1):
            b = self.biases[l]
            w = self.weights[l]
            z = np.dot(w, a) + b
            z_list.append(z)
            if (l != self.num_layers - 2):
                a = sigmoid(z)
            else:
                a = softmax(z)
            a_list.append(a)
        
        # backward pass <- both steps at once
        dZ = a - y
        for l in reversed(range(self.num_layers - 1)):
            dW = np.dot(dZ,  a_list[l].T)
            db = np.sum(dZ, axis = 1, keepdims = True)
            delta_nabla_b[l] = db
            delta_nabla_w[l] = dW
            if l != 0:
                dA = np.matmul(self.weights[l].T, dZ)
                dZ = np.multiply(dA, sigmoid_prime(z_list[l - 1]))
                
        return delta_nabla_b, delta_nabla_w

    def evaluate(self, test_data):
        # Count the number of correct answers for test_data
        pred = np.argmax(self.predict(test_data[0]),axis=0)
        corr = np.argmax(test_data[1],axis=1).T
        return np.mean(pred == corr)
    
    def cost_derivative(self, output_activations, y):
        return (output_activations-y) 
    
    def SGD(self, training_data, epochs, mini_batch_size, eta, test_data=None):
        x_train, y_train = training_data
        if test_data:
            x_test, y_test = test_data
        for j in range(epochs):
            for i in range(x_train.shape[0] // mini_batch_size):
                x_mini_batch = x_train[(mini_batch_size*i):(mini_batch_size*(i+1))]
                y_mini_batch = y_train[(mini_batch_size*i):(mini_batch_size*(i+1))]
                self.update_mini_batch((x_mini_batch, y_mini_batch), eta)
            if test_data:
                print("Epoch: {0}, Accuracy: {1}".format(j, self.evaluate((x_test, y_test))))
            else:
                print("Epoch: {0}".format(j))

In [13]:
class NetworkWithCheckpointing(Network):

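    def __init__(self, sizes, every_nth_layer: int = 1, *args, **kwargs):
        super().__init__(sizes, *args, **kwargs)
        # every_nth_layer is used as a modulus in backprop, so it must be >= 1;
        # with the default of 1 every layer's activation is checkpointed
        if every_nth_layer < 1:
            raise ValueError("every_nth_layer must be a positive integer")
        self.every_nth_layer = every_nth_layer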

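    def forward_between_checkpoints(self, a, idx_start, idx_end):
        # Recompute layers idx_start..idx_end-1 starting from the activation a.
        # Returns the activations (including a itself) and the pre-activations.
        z_list = []
        a_list = [a]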
        for l in range(idx_start, idx_end):
            b = self.biases[l]
            w = self.weights[l]
            z = np.dot(w, a) + b
            if (l != self.num_layers - 2):
                a = sigmoid(z)
            else:
                a = softmax(z)
            a_list.append(a)
            z_list.append(z)
        return a_list, z_list
    
    def backprop(self, x, y):
        # For a mini-batch (x, y) with one example per column, return a pair of lists.
        # The first contains gradients over biases, the second over weights.
        delta_nabla_b = [np.zeros_like(p) for p in self.biases]
        delta_nabla_w = [np.zeros_like(p) for p in self.weights]
        
        a = x
        checkpoint_list = [a]
        for l in range(self.num_layers - 1):
            b = self.biases[l]
            w = self.weights[l]
            z = np.dot(w, a) + b
            if (l != self.num_layers - 2):
                a = sigmoid(z)
            else:
                a = softmax(z)
            if (((l + 1) % self.every_nth_layer == 0)):
                checkpoint_list.append(a)
        
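        # backward pass: walk the checkpoints from last to first, recompute
        # each segment's activations and backpropagate through that segment
        dA = 0
        dZ = 0
        for l in reversed(range(len(checkpoint_list))):
            # segments before the last one continue from the dZ left by the
            # segment that follows them
            if ((l + 1) * self.every_nth_layer < self.num_layers - 1):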
                start = l * self.every_nth_layer
                stop = (l + 1) * self.every_nth_layer
                a_list, z_list = self.forward_between_checkpoints(checkpoint_list[l],
                                            start, stop)
                for k in reversed(range(start, stop)):
                    dA = np.matmul(self.weights[k + 1].T, dZ)
                    dZ = np.multiply(dA, sigmoid_prime(z_list[k - start]))
                    dW = np.dot(dZ,  a_list[k - start].T)
                    db = np.sum(dZ, axis = 1, keepdims = True)
                    delta_nabla_b[k] = db
                    delta_nabla_w[k] = dW
            else:
                # last segment: start from the output-layer error a - y
                start = l * self.every_nth_layer
                stop = self.num_layers - 1
                a_list, z_list = self.forward_between_checkpoints(checkpoint_list[l],
                                            start, stop)
                dZ = a_list[-1] - y
                for k in reversed(range(start, stop)):
                    dW = np.dot(dZ,  a_list[k - start].T)
                    db = np.sum(dZ, axis = 1, keepdims = True)
                    delta_nabla_b[k] = db
                    delta_nabla_w[k] = dW
                    if (k - start - 1 >= 0):
                        dA = np.matmul(self.weights[k].T, dZ)
                        dZ = np.multiply(dA, sigmoid_prime(z_list[k - start - 1]))
        return delta_nabla_b, delta_nabla_w


In [15]:
network = NetworkWithCheckpointing([784,100,100, 100, 100, 100, 10], 2)
start = time.time()
network.SGD((x_train, y_train), epochs=50, mini_batch_size=100, eta=0.1, test_data=(x_test, y_test))
end = time.time()
print(end - start)

Epoch: 0, Accuracy: 0.663
Epoch: 1, Accuracy: 0.7486
Epoch: 2, Accuracy: 0.7864
Epoch: 3, Accuracy: 0.8115
Epoch: 4, Accuracy: 0.8284
Epoch: 5, Accuracy: 0.845
Epoch: 6, Accuracy: 0.8554
Epoch: 7, Accuracy: 0.8641
Epoch: 8, Accuracy: 0.8701
Epoch: 9, Accuracy: 0.8738
Epoch: 10, Accuracy: 0.8774
Epoch: 11, Accuracy: 0.8811
Epoch: 12, Accuracy: 0.8841
Epoch: 13, Accuracy: 0.8862
Epoch: 14, Accuracy: 0.8886
Epoch: 15, Accuracy: 0.892
Epoch: 16, Accuracy: 0.8947
Epoch: 17, Accuracy: 0.8965
Epoch: 18, Accuracy: 0.8981
Epoch: 19, Accuracy: 0.8994
Epoch: 20, Accuracy: 0.9011
Epoch: 21, Accuracy: 0.9027


KeyboardInterrupt: 

In [22]:
network = Network([784,100, 100, 100, 100 ,10])
start = time.time()
network.SGD((x_train, y_train), epochs=50, mini_batch_size=100, eta=3.0, test_data=(x_test, y_test))
end = time.time()
print(end - start)

Epoch: 0, Accuracy: 0.1135
Epoch: 1, Accuracy: 0.1009
Epoch: 2, Accuracy: 0.101
Epoch: 3, Accuracy: 0.0982
Epoch: 4, Accuracy: 0.1009
Epoch: 5, Accuracy: 0.0958
Epoch: 6, Accuracy: 0.1135
Epoch: 7, Accuracy: 0.1028
Epoch: 8, Accuracy: 0.0974
Epoch: 9, Accuracy: 0.0974
Epoch: 10, Accuracy: 0.0958
Epoch: 11, Accuracy: 0.1135
Epoch: 12, Accuracy: 0.101
Epoch: 13, Accuracy: 0.1135
Epoch: 14, Accuracy: 0.098
Epoch: 15, Accuracy: 0.1009
Epoch: 16, Accuracy: 0.098
Epoch: 17, Accuracy: 0.101
Epoch: 18, Accuracy: 0.0974
Epoch: 19, Accuracy: 0.0958
Epoch: 20, Accuracy: 0.101
Epoch: 21, Accuracy: 0.1009
Epoch: 22, Accuracy: 0.0982
Epoch: 23, Accuracy: 0.0974
Epoch: 24, Accuracy: 0.101
Epoch: 25, Accuracy: 0.1145
Epoch: 26, Accuracy: 0.101
Epoch: 27, Accuracy: 0.098
Epoch: 28, Accuracy: 0.0974
Epoch: 29, Accuracy: 0.1323
Epoch: 30, Accuracy: 0.2629
Epoch: 31, Accuracy: 0.2892
Epoch: 32, Accuracy: 0.5994
Epoch: 33, Accuracy: 0.8273
Epoch: 34, Accuracy: 0.9
Epoch: 35, Accuracy: 0.9064
Epoch: 36, Acc

# Analysis

For the plain network with $L$ layers (including the input and output layers), we store:
1. $L-1$ weight matrices
2. $L-1$ bias vectors
3. $L$ matrices of activation-function outputs (the input counts as the first one)
4. $L-1$ matrices of activation-function inputs

For a network with a checkpoint at every $n$-th layer, we store:
1. $L-1$ weight matrices
2. $L-1$ bias vectors
3. $O(L/n)$ matrices of activation-function outputs (the checkpoints)
4. During backpropagation, an additional $O(n)$ inputs and outputs of the activation functions for the segment currently being processed (computed with `forward_between_checkpoints`)

In exchange, checkpointing has to recompute the activations of the $O(L(1 - 1/n))$ non-checkpointed layers, which costs at most one additional forward pass per mini-batch.
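
The two activation terms in the checkpointing list pull in opposite directions as $n$ grows: fewer checkpoints are stored, but each recomputed segment gets longer. As a rough sketch, assuming every layer's activation matrix has the same size and ignoring constant factors, the number of activation matrices held at once is about

$$
M(n) \approx \frac{L}{n} + n, \qquad
\frac{dM}{dn} = -\frac{L}{n^2} + 1 = 0 \;\Rightarrow\; n = \sqrt{L}, \qquad
M\left(\sqrt{L}\right) \approx 2\sqrt{L}.
$$

Under these assumptions, choosing $n$ close to $\sqrt{L}$ reduces the activation memory from $O(L)$ to $O(\sqrt{L})$. The MLP used here has layers of different widths (784, 100 and 10 units), so this is only a guide to how the memory scales, not an exact count.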