# Checkpointing

Your task is to implement checkpointing for an MLP using NumPy.

You are free to use the MLP implementation and the backpropagation algorithm that you developed during the lab sessions.

The key takeaway from this task is that checkpointing lets us trade extra computation in the forward pass for a lower memory requirement in the backward pass, and memory is often a major bottleneck when training large networks. In plain English, we slightly increase the training time to save some of our GPU's precious memory.

## What is checkpointing?

The aim of checkpointing is to save the forward result of only every $n$-th layer (e.g. every 2nd layer), instead of saving every layer's forward result as in plain backpropagation, and to use these checkpoints to recompute the missing parts of the forward pass during the backward pass. Checkpointed layers are kept in memory after the forward pass, while each of the remaining activations is recomputed at most once. Once recomputed, a non-checkpointed activation is kept in memory until it is no longer required.
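
As a rough illustration, the sketch below lists which layer outputs would be kept after the forward pass and which would have to be recomputed. The helper name `checkpoint_slots` and the convention that the output of layer `i` is kept when `i % k == 0` are assumptions made for this example; the task itself does not prescribe which layers become checkpoints.

```python
def checkpoint_slots(num_layers, k):
    """Split layer outputs into kept (checkpointed) and recomputed ones.

    Assumed convention: the output of layer i (0-based) is kept when i % k == 0.
    """
    kept = [i for i in range(num_layers) if i % k == 0]
    recomputed = [i for i in range(num_layers) if i % k != 0]
    return kept, recomputed


kept, recomputed = checkpoint_slots(num_layers=6, k=2)
print("kept after forward pass:", kept)            # [0, 2, 4]
print("recomputed in backward pass:", recomputed)  # [1, 3, 5]
```

With `k=1` every output is kept, which corresponds to plain backpropagation.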

# What should be done

1. Take the implementation of an MLP trained with backpropagation. Analyze how the memory used by the algorithm depends on the number of hidden layers.

2. Implement a class `NetworkWithCheckpointing` that inherits from the `Network` class defined during lab sessions by:
    a) implementing a method `forward_between_checkpoints` that recomputes part of the forward pass starting from one of the checkpointed layers,
    b) overriding the method `backprop` so that it stores only the checkpointed layers, recomputes the other activations with `forward_between_checkpoints`, and keeps them in memory only until they are no longer needed.

3. Train your network with checkpointing on MNIST. Compare its running time and memory usage with those of the network without checkpointing.
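
For step 1, a back-of-the-envelope estimate can help before measuring anything. The sketch below assumes that plain backpropagation keeps one activation matrix of shape `(layer_size, batch_size)` per layer and that activations are stored as 8-byte floats (`float64`); the helper name `activation_bytes` is hypothetical.

```python
def activation_bytes(sizes, batch_size, itemsize=8):
    """Bytes needed to store the outputs of all layers after the input layer.

    Assumes one (layer_size, batch_size) matrix per layer, `itemsize` bytes per value.
    """
    return sum(size * batch_size * itemsize for size in sizes[1:])


# Memory grows linearly with the number of hidden layers of a fixed width.
for num_hidden in (2, 4, 8):
    sizes = [784] + [200] * num_hidden + [10]
    mib = activation_bytes(sizes, batch_size=1000) / (1024 * 1024)
    print(f"{num_hidden} hidden layers: {mib:.2f} MiB")
```

For the layer sizes used in the experiments in this notebook (`[784, 100, 200, ..., 200, 10]` with nine layers of width 200) and a batch size of 1000, this estimate gives 15,280,000 bytes, roughly 14.57 MiB, which agrees with the mean memory per batch reported for the run without checkpointing.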


# Implement Checkpointing for a MLP

In [None]:
import random
import numpy as np
from torchvision import datasets, transforms

In [None]:
!wget -O mnist.npz https://s3.amazonaws.com/img-datasets/mnist.npz

--2022-11-03 08:30:09--  https://s3.amazonaws.com/img-datasets/mnist.npz
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.76.230
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.217.76.230|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 11490434 (11M) [application/octet-stream]
Saving to: ‘mnist.npz’


2022-11-03 08:30:09 (52.6 MB/s) - ‘mnist.npz’ saved [11490434/11490434]



In [None]:
# Let's read the mnist dataset

def load_mnist(path='mnist.npz'):
    with np.load(path) as f:
        x_train, _y_train = f['x_train'], f['y_train']
        x_test, _y_test = f['x_test'], f['y_test']

    x_train = x_train.reshape(-1, 28 * 28) / 255.
    x_test = x_test.reshape(-1, 28 * 28) / 255.

    y_train = np.zeros((_y_train.shape[0], 10))
    y_train[np.arange(_y_train.shape[0]), _y_train] = 1

    y_test = np.zeros((_y_test.shape[0], 10))
    y_test[np.arange(_y_test.shape[0]), _y_test] = 1

    return (x_train, y_train), (x_test, y_test)

(x_train, y_train), (x_test, y_test) = load_mnist()

In [None]:
from scipy.special import softmax
from timeit import default_timer as timer

def sigmoid(z):
    return 1.0/(1.0+np.exp(-z))

def relu(z):
    return np.maximum(0, z)

class Network(object):
    def __init__(self, sizes, momentum=0.9, l2=0.01, drop_rate_input=0.9, drop_rate_hidden=0.5):
        self.num_layers = len(sizes)
        self.sizes = sizes
        self.momentum = momentum
        self.l2=l2

        assert 0 <= drop_rate_input <= 1
        assert 0 <= drop_rate_hidden <= 1

        self.drop_rate_input = drop_rate_input
        self.drop_rate_hidden = drop_rate_hidden

        self.biases = [np.random.randn(y, 1) for y in sizes[1:]]
        self.weights = [np.random.randn(y, x) 
                        for x, y in zip(sizes[:-1], sizes[1:])]
        
        # momentum
        self.weights_momentum = [np.zeros_like(w) for w in self.weights]
        self.biases_momentum = [np.zeros_like(b) for b in self.biases]

    def feedforward(self, a):
        a = a.T
        for layer_index, (b, w) in enumerate(zip(self.biases, self.weights)):
            a = np.matmul(w, a)+b
            if layer_index < self.num_layers-2:
              a = sigmoid(a)
            else:
              a = softmax(a, axis=0)
        return a
    
    def update_mini_batch(self, mini_batch, eta):
        nabla_b, nabla_w, time_used, mem_used, max_mem_objects = self.backprop(mini_batch[0].T,mini_batch[1].T)
        nabla_b = [b / len(mini_batch[0]) for b in nabla_b]
        nabla_w = [w / len(mini_batch[0]) for w in nabla_w]

        # l2 regularization
        self.weights = [w*(1-self.l2*eta) for w in self.weights]

        self.biases_momentum = [self.momentum*b + (1 - self.momentum)*np.multiply(eta, nb) for b, nb in zip(self.biases_momentum, nabla_b)]
        self.weights_momentum = [self.momentum*w + (1 - self.momentum)*np.multiply(eta, nw) for w, nw in zip(self.weights_momentum, nabla_w)]
            
        self.weights = [w-nw 
                        for w, nw in zip(self.weights, self.weights_momentum)]
        self.biases = [b-nb 
                       for b, nb in zip(self.biases, self.biases_momentum)]

        self.time_used.append(time_used)
        self.mem_used.append(mem_used)
        self.max_mem_objects.append(max_mem_objects)

        
    def backprop(self, x, y):
        start_time = timer()

        g = x
        
        # dropout input (note: `drop_rate_input` acts as the keep probability here)
        mask = (np.random.rand(*g.shape) < self.drop_rate_input).astype(np.float32)
        g = mask * g / self.drop_rate_input


        gs = [g] # list to store all the gs, layer by layer


        for layer_index, (b, w) in enumerate(zip(self.biases, self.weights)):
            f = np.dot(w, g)+b
            if layer_index < self.num_layers-2:
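              # dropout hidden layers
              # NOTE: the mask is applied to the previous activation `g`, which is
              # overwritten by `sigmoid(f)` right below, so as written this dropout
              # has no effect on the forward or backward pass.
              mask = (np.random.rand(*g.shape) < self.drop_rate_hidden).astype(np.float32)
              g = mask * g / self.drop_rate_hidden
              
              g = sigmoid(f)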
            else:
              g = softmax(f, axis=0)

            gs.append(g)

        dLdg = self.cost_derivative(gs[-1], y)
        dLdfs = []
        for layer_index, (w,g) in reversed(list(enumerate(zip(self.weights,gs[1:])))):
            if layer_index < self.num_layers-2:
              dLdf = np.multiply(dLdg,np.multiply(g,1-g))
            else:
              dLdf = dLdg
            dLdfs.append(dLdf)
            dLdg = np.matmul(w.T, dLdf)
        
        dLdWs = [np.matmul(dLdf,g.T) for dLdf,g in zip(reversed(dLdfs),gs[:-1])] 
        dLdBs = [np.sum(dLdf,axis=1).reshape(dLdf.shape[0],1) for dLdf in reversed(dLdfs)] 

        mem_used =  sum([x.nbytes if x is not None else 0 for x in gs[1:]])
        max_mem_objects = len([x for x in gs[1:] if x is not None])
        end_time = timer()
        
        return dLdBs, dLdWs, end_time-start_time, mem_used, max_mem_objects

    def evaluate(self, test_data):
        pred = np.argmax(self.feedforward(test_data[0]),axis=0)
        corr = np.argmax(test_data[1],axis=1).T
        return np.mean(pred==corr)
    
    def cost_derivative(self, output_activations, y):
        return (output_activations-y) 
    
    def SGD(self, training_data, epochs, mini_batch_size, eta, test_data=None):
        self.time_used = []
        self.mem_used = []
        self.max_mem_objects = []

        best_acc = -1
        x_train, y_train = training_data
        if test_data:
            x_test, y_test = test_data
        for j in range(epochs):
            for i in range(x_train.shape[0] // mini_batch_size):
                x_mini_batch = x_train[(mini_batch_size*i):(mini_batch_size*(i+1))]
                y_mini_batch = y_train[(mini_batch_size*i):(mini_batch_size*(i+1))]
                self.update_mini_batch((x_mini_batch, y_mini_batch), eta)
            if test_data:
                accuracy = self.evaluate((x_test, y_test))
                best_acc = max(best_acc, accuracy)
                print("Epoch: {0}, Accuracy: {1}".format(j, accuracy))
            else:
                print("Epoch: {0}".format(j))
        print(f"\nBest accuracy: {best_acc}")
        avg_time = sum(self.time_used)/epochs
        avg_mem = sum(self.mem_used)/(len(self.mem_used) * 1024 * 1024)
        max_mem_objects = max(self.max_mem_objects)

        print(f"Mean time / epoch  : {avg_time:.3f}s")
        print(f"Mean memory / batch: {avg_mem:.2f}Mb")
        print(f"Max no. of objects in memory: {max_mem_objects}")




In [None]:
class NetworkWithCheckpointing(Network):

    def __init__(self, sizes, checkpoint_every_nth_layer: int = 0, *args, **kwargs):
        super().__init__(sizes, *args, **kwargs)
        self.checkpoint_every_nth_layer = max(checkpoint_every_nth_layer, 1)

    def forward_between_checkpoints(self, a, checkpoint_idx_start, layer_idx_end):
        gs = []
        g = a
        for layer_index in range(checkpoint_idx_start, layer_idx_end):
            w = self.weights[layer_index]
            b = self.biases[layer_index]

            f = np.dot(w, g)+b

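            # dropout hidden layers
            # NOTE: as in `Network.backprop`, the masked `g` is overwritten by
            # `sigmoid(f)` right below, so this dropout has no effect. This also
            # keeps the recomputed activations identical to the original ones.
            mask = (np.random.rand(*g.shape) < self.drop_rate_hidden).astype(np.float32)
            g = mask * g / self.drop_rate_hidden
          
            g = sigmoid(f)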
            gs.append(g)

        return gs

    def backprop(self, x, y):
        start_time = timer()

        g = x
        
        # dropout input
        mask = (np.random.rand(*g.shape) < self.drop_rate_input).astype(np.float32)
        g = mask * g / self.drop_rate_input


        gs = [None] * (len(self.weights) + 1) # list to store every nth gs
        gs[0] = g

        for layer_index, (b, w) in enumerate(zip(self.biases, self.weights)):
            f = np.dot(w, g)+b
            if layer_index < self.num_layers-2:
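              # dropout hidden layers
              # NOTE: the masked `g` is overwritten by `sigmoid(f)` right below,
              # so as written this dropout has no effect.
              mask = (np.random.rand(*g.shape) < self.drop_rate_hidden).astype(np.float32)
              g = mask * g / self.drop_rate_hidden
      
              g = sigmoid(f)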
            else:
              g = softmax(f, axis=0)

            if layer_index % self.checkpoint_every_nth_layer == 0:
              gs[layer_index+1] = g

        # backward pass
        dLdg = self.cost_derivative(g, y)
        dLdf = dLdg

        dLdWs, dLdBs = [], []
        last_checkpoint_index = len(self.weights) - ((len(self.weights) - 1) % self.checkpoint_every_nth_layer)
        was_checkpoint_propagated = False

        max_mem = -1
        max_mem_objects = -1
        for layer_index, w in reversed(list(enumerate(self.weights))):
            if layer_index < last_checkpoint_index:
              # clear unnecessary memory
              for gs_index in range(last_checkpoint_index, min(len(gs)-1, last_checkpoint_index+self.checkpoint_every_nth_layer)+1):
                gs[gs_index] = None

              # update checkpoint
              last_checkpoint_index = max(0, last_checkpoint_index - self.checkpoint_every_nth_layer)
              was_checkpoint_propagated = False


            if not was_checkpoint_propagated:
              checkpoint_out = gs[last_checkpoint_index]

              gs_between_checkpoints = self.forward_between_checkpoints(checkpoint_out, last_checkpoint_index, layer_index)
              for index, g in enumerate(gs_between_checkpoints):
                gs_index = last_checkpoint_index + 1 + index
                gs[gs_index] = g
                
              was_checkpoint_propagated = True
            
            g = gs[layer_index]

            dLdWs.append(np.matmul(dLdf,g.T))
            dLdBs.append(np.sum(dLdf,axis=1).reshape(dLdf.shape[0],1))

            dLdg = np.matmul(w.T, dLdf)
            dLdf = np.multiply(dLdg,np.multiply(g,1-g))

            current_mem = sum([x.nbytes if x is not None else 0 for x in gs[1:]])
            max_mem = max(current_mem, max_mem)

            mem_objects = len([x for x in gs[1:] if x is not None])
            if mem_objects > max_mem_objects:
              # print(['X' if x is not None else 'O' for x in gs[1:]])
              max_mem_objects = mem_objects

        end_time = timer()
        
        # return lists (not one-shot iterators) ordered from the first layer to the last
        return dLdBs[::-1], dLdWs[::-1], end_time-start_time, max_mem, max_mem_objects


In [None]:
layers = [784, 100, 200, 200, 200, 200, 200, 200, 200, 200, 200, 10]
epochs = 1

network = Network(layers, momentum=0.9, l2=1e-3, drop_rate_input=0.9, drop_rate_hidden=0.5)
network.SGD((x_train, y_train), epochs=epochs, mini_batch_size=1000, eta=0.5, test_data=(x_test, y_test))

Epoch: 0, Accuracy: 0.1605

Best accuracy: 0.1605
Mean time / epoch  : 18.519s
Mean memory / batch: 14.57Mb
Max no. of objects in memory: 11


In [None]:
network = NetworkWithCheckpointing(layers, momentum=0.9, l2=1e-3, drop_rate_input=0.9, drop_rate_hidden=0.5, checkpoint_every_nth_layer=2)
network.SGD((x_train, y_train), epochs=epochs, mini_batch_size=1000, eta=0.5, test_data=(x_test, y_test))

Epoch: 0, Accuracy: 0.2721

Best accuracy: 0.2721
Mean time / epoch  : 23.933s
Mean memory / batch: 8.39Mb
Max no. of objects in memory: 6


In [None]:
network = NetworkWithCheckpointing(layers, momentum=0.9, l2=1e-3, drop_rate_input=0.9, drop_rate_hidden=0.5, checkpoint_every_nth_layer=3)
network.SGD((x_train, y_train), epochs=epochs, mini_batch_size=1000, eta=0.5, test_data=(x_test, y_test))

Epoch: 0, Accuracy: 0.2116

Best accuracy: 0.2116
Mean time / epoch  : 24.390s
Mean memory / batch: 6.87Mb
Max no. of objects in memory: 5


Backprop with checkpointing, with $n$ hidden layers and $\text{checkpoint\_every\_nth\_layer}=k$, needs in the worst case approximately $\lfloor \frac{n}{k} \rfloor + k$ objects stored in memory at peak: $\lfloor \frac{n}{k} \rfloor$ for the checkpoints and $k$ for the outputs between two consecutive checkpoints, whereas the default solution stores all $n$ layers. Therefore, increasing $k$ decreases memory consumption, but only up to a point: once $k$ exceeds roughly $\sqrt{n}$, the $k$ term dominates and the estimate grows again.
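
The estimate can be tabulated with a few lines of Python. This is only a sketch of the approximate formula above, not a measurement; the function name `peak_objects` and the choice $n = 10$ are made up for the illustration.

```python
def peak_objects(n, k):
    """Approximate worst-case number of stored activations: floor(n / k) + k."""
    return n // k + k


n = 10  # number of hidden layers, chosen for illustration
for k in range(1, n + 1):
    print(f"k={k:2d}: about {peak_objects(n, k)} objects")

# First k that attains the smallest estimate.
best_k = min(range(1, n + 1), key=lambda k: peak_objects(n, k))
print("smallest estimate at k =", best_k)
```

For $n = 10$ the estimate is smallest at $k = 3$ (and equally small at $k = 4$), close to $\sqrt{10} \approx 3.2$, and it rises again for larger $k$.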

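However, backprop with checkpointing needs to perform approximately $2n - n/k$ forward steps instead of $n$, because every non-checkpointed layer is computed twice: once in the forward pass and once more during the backward pass. Increasing $k$ therefore increases the running time.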
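
As a quick check of the formula, the sketch below evaluates the approximate number of forward steps for a few values of $k$. The function name `forward_steps` and the value $n = 10$ are assumptions made for this example.

```python
def forward_steps(n, k):
    """Approximate number of layer forward computations: 2n - n/k."""
    return 2 * n - n / k


n = 10  # number of hidden layers, chosen for illustration
for k in (1, 2, 3, 5):
    steps = forward_steps(n, k)
    print(f"k={k}: about {steps:.1f} forward steps ({steps / n:.2f}x plain backprop)")
```

In the runs reported above, the measured time grows by a smaller factor than this count suggests (23.933 s and 24.390 s against 18.519 s), presumably because the timed section also includes the backward computations, which are not repeated.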