# Checkpointing

Your task is to implement checkpointing for a MLP using NumPy.

You are free to use the implementation of a MLP and the backpropagation algorithm that you have developed during lab sessions.

The key takeaway from this task is that with checkpointing we can trade off the computational resources needed to compute the forward pass of the network for the memory requirement needed to perform a backward pass in the network, which is often a major bottleneck when training large networks. In plain english, we can slightly increase the time required for training our network to save some of our GPU's precious memory.

## What is checkpointing?

The aim of checkpointing is to save every $n$-th layer's (e.g. every 2-nd layer's) forward result (instead of saving every layer's forward result as in plain backpropagation) and use these checkpoints for recomputing the forward pass of the network upon doing a backward pass. Checkpoint layers are kept in memory after the forward pass, while the remaining activations are recomputed at most once. After being recomputed, the non-checkpoint layers are kept in memory until they are no longer required.

# What should be done

1. Take the implementation a MLP trained with backpropagation. Analyze the algorithm with respect to the memory that is used by the algorithm with respect to the number of hidden layers.

2. Implement a class NetworkWithCheckpointing that inherits from the Network class defined during lab sessions by:
    a) implementing a method `forward_between_checkpoints` that will recompute the forward pass of the network using one of the checkpointed layers
    b) override the method `backprop` to use only checkpointed layers and otherwise compute the activations using `forward_between_checkpoints` method and keep it in memory until no longer needed.

3. Train your network with checkpoinintg on MNIST. Compare running times and memory usage with respect to the network without checkpointing.


# Implement Checkpointing for a MLP

In [1]:
import random
import numpy as np
from tqdm import tqdm
from torchvision import datasets, transforms

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
!wget -O mnist.npz https://s3.amazonaws.com/img-datasets/mnist.npz

--2022-11-19 11:03:14--  https://s3.amazonaws.com/img-datasets/mnist.npz
Translacja s3.amazonaws.com (s3.amazonaws.com)... 52.217.18.62, 52.216.207.69, 52.216.108.229, ...
Łączenie się z s3.amazonaws.com (s3.amazonaws.com)|52.217.18.62|:443... połączono.
Żądanie HTTP wysłano, oczekiwanie na odpowiedź... 200 OK
Długość: 11490434 (11M) [application/octet-stream]
Zapis do: `mnist.npz'


2022-11-19 11:03:16 (7,57 MB/s) - zapisano `mnist.npz' [11490434/11490434]



In [2]:
# Let's read the mnist dataset

def load_mnist(path='mnist.npz'):
    with np.load(path) as f:
        x_train, _y_train = f['x_train'], f['y_train']
        x_test, _y_test = f['x_test'], f['y_test']

    x_train = x_train.reshape(-1, 28 * 28) / 255.
    x_test = x_test.reshape(-1, 28 * 28) / 255.

    y_train = np.zeros((_y_train.shape[0], 10))
    y_train[np.arange(_y_train.shape[0]), _y_train] = 1

    y_test = np.zeros((_y_test.shape[0], 10))
    y_test[np.arange(_y_test.shape[0]), _y_test] = 1

    return (x_train, y_train), (x_test, y_test)

(x_train, y_train), (x_test, y_test) = load_mnist()

In [3]:
# Utils functions
def sigmoid(z):
    return 1.0/(1.0+np.exp(-z))

def sigmoid_prime(z):
    # Derivative of the sigmoid
    return sigmoid(z)*(1-sigmoid(z))

In [197]:
class NetworkNoCheckpoints(object):
    def __init__(self, sizes):
        # initialize biases and weights with random normal distr.
        # weights are indexed by target node first
        self.num_layers = len(sizes)
        self.sizes = sizes
        self.biases = [np.random.randn(y, 1) for y in sizes[1:]]
        self.weights = [np.random.randn(y, x) 
                        for x, y in zip(sizes[:-1], sizes[1:])]
    def feedforward(self, a):
        # Run the network on a single case
        for b, w in zip(self.biases, self.weights):
            a = sigmoid(np.dot(w, a)+b)
        return a
    
    def update_mini_batch(self, x_mini_batch, y_mini_batch, eta):
        # Update networks weights and biases by applying a single step
        # of gradient descent using backpropagation to compute the gradient.
        # The gradient is computed for a mini_batch.
        # eta is the learning rate
        nabla_b = [np.zeros(b.shape) for b in self.biases]
        nabla_w = [np.zeros(w.shape) for w in self.weights]
        delta_nabla_b, delta_nabla_w = self.backprop(x_mini_batch.T, y_mini_batch.T)
        nabla_b = [nb+dnb for nb, dnb in zip(nabla_b, delta_nabla_b)]
        nabla_w = [nw+dnw for nw, dnw in zip(nabla_w, delta_nabla_w)]
        self.weights = [w-(eta/len(x_mini_batch))*nw 
                        for w, nw in zip(self.weights, nabla_w)]
        self.biases = [b-(eta/len(x_mini_batch))*nb 
                       for b, nb in zip(self.biases, nabla_b)]
        
    def backprop(self, X, y):
        # Now X is a matrix of inputs
        # For a single input (x,y) return a tuple of lists.
        # First contains gradients over biases, second over weights.
        
        # First initialize the list of gradient arrays
        delta_nabla_b = [np.zeros_like(p) for p in self.biases]
        delta_nabla_w = [np.zeros_like(p) for p in self.weights]
        
        # Then go forward remembering all values before and after activations
        # in two other array lists
        f_ = [np.zeros_like(p) for p in self.sizes]
        g_ = [np.zeros_like(p) for p in self.sizes]
        f_[0] = g_[0] = X

        for layer in range(1, self.num_layers):
          f_[layer] = self.weights[layer - 1] @ g_[layer - 1] + self.biases[layer - 1]
          g_[layer] = sigmoid(f_[layer])
          

        # Now go backward from the final cost applying backpropagation

        f_derivatives = [np.zeros_like(p) for p in self.sizes]
        g_derivatives = [np.zeros_like(p) for p in self.sizes]
        N = len(self.sizes) - 1
        g_derivatives[N] = self.cost_derivative(g_[N], y)
        f_derivatives[N] = g_derivatives[N] * sigmoid(g_[N]) * (1 - sigmoid(g_[N]))        

        for layer in reversed(range(self.num_layers - 1)):
          g_derivatives[layer] = self.weights[layer].T @ f_derivatives[layer + 1]
          f_derivatives[layer] = g_derivatives[layer] * g_[layer] * (1 - g_[layer])
          delta_nabla_b[layer] = f_derivatives[layer + 1].sum(axis = 1).reshape(-1, 1)
          delta_nabla_w[layer] = f_derivatives[layer + 1] @ g_[layer].T
          

        return delta_nabla_b, delta_nabla_w

    def evaluate(self, x_test_data, y_test_data):
        # Count the number of correct answers for test_data
        test_results = [(np.argmax(self.feedforward(x_test_data[i].reshape(784,1))), np.argmax(y_test_data[i]))
                        for i in range(len(x_test_data))]
        # return accuracy
        return np.mean([int(x == y) for (x, y) in test_results])
    
    def cost_derivative(self, output_activations, y):
        return (output_activations-y)
    
    def SGD(self, training_data, epochs, mini_batch_size, eta, test_data=None):
        x_train, y_train = training_data
        if test_data:
            x_test, y_test = test_data
        
        for j in range(epochs):
            for i in range(x_train.shape[0] // mini_batch_size):
                x_mini_batch = x_train[i*mini_batch_size:(i*mini_batch_size + mini_batch_size)] 
                y_mini_batch = y_train[i*mini_batch_size:(i*mini_batch_size + mini_batch_size)]
                self.update_mini_batch(x_mini_batch, y_mini_batch, eta)
            if test_data:
                print("Epoch: {0}, Accuracy: {1}".format(j, self.evaluate(x_test, y_test)))
            else:
                print("Epoch: {0}".format(j))


network_no_checkpoints = NetworkNoCheckpoints([784, 100, 30, 10])

Ref 1.
On each backprop iteration we save the whole foward pass in separate array and we use that when going backwards, thus memory usage is linear to the number of layers used

Ref 2.
Implementing network with checkpoints:

In [198]:
class NetworkWithCheckpointing(NetworkNoCheckpoints):

    def __init__(self, sizes, checkpoint_every_nth_layer: int = 0, *args, **kwargs):
        super().__init__(sizes, *args, **kwargs)
        self.checkpoint_distance = checkpoint_every_nth_layer
            
    def forward_between_checkpoints(self, g, checkpoint_idx_start, layer_idx_end):
        
        f_temp = []
        g_temp = []
        
        for layer in range(checkpoint_idx_start, layer_idx_end):
            f = self.weights[layer] @ g + self.biases[layer]
            f_temp.append(f)
            g = sigmoid(f)
            g_temp.append(g)
            
        return f_temp, g_temp

    def backprop(self, X, y):
        # initialize checkpoints array for checking
        checkpoints = np.full(len(self.sizes), False)
        for i in range(0, self.num_layers, self.checkpoint_distance):
            checkpoints[i] = True
        
        # initialize all the changes needed
        delta_nabla_b = [np.zeros_like(p) for p in self.biases]
        delta_nabla_w = [np.zeros_like(p) for p in self.weights]

        # Then go forward remembering only values on checkpoint layers
        sizes_checkpoints = np.array(self.sizes)[checkpoints]
        f_checkpoints = [np.zeros_like(p) for p in sizes_checkpoints]
        g_checkpoints = [np.zeros_like(p) for p in sizes_checkpoints]
        f_checkpoints[0] = g_checkpoints[0] = X

        # calculating only certain layers
        for layer in range(1, len(sizes_checkpoints)):
            f_, g_ = self.forward_between_checkpoints(
                g = g_checkpoints[layer - 1],
                checkpoint_idx_start = (layer - 1) * self.checkpoint_distance, 
                layer_idx_end = min(self.num_layers, layer * self.checkpoint_distance))
            
            f_checkpoints[layer], g_checkpoints[layer] = f_[-1], g_[-1]
        
        # Now go backward from the final cost applying backpropagation
        f_derivatives = [np.zeros_like(p) for p in self.sizes]
        g_derivatives = [np.zeros_like(p) for p in self.sizes]
        N = len(self.sizes) - 1

        # If the last layer is checkpoint
        if N % self.checkpoint_distance == 0:
            g_derivatives[N] = self.cost_derivative(g_checkpoints[-1], y)
            f_derivatives[N] = g_derivatives[N] * sigmoid(g_checkpoints[-1]) * (1 - sigmoid(g_checkpoints[-1]))
        # If the last layer is not a checkpoint
        else:
            last_checkpoint_index = N - (N % self.checkpoint_distance)
            f_temp, g_temp = self.forward_between_checkpoints(g_checkpoints[-1], last_checkpoint_index, N)
            g_derivatives[N] = self.cost_derivative(g_temp[-1], y)
            f_derivatives[N] = g_derivatives[N] * sigmoid(g_temp[-1]) * (1 - sigmoid(g_temp[-1]))
            for f, g in zip(reversed(f_temp[:-1]), reversed(g_temp[:-1])):
                g_derivatives[N - 1] = self.weights[N - 1].T @ f_derivatives[N]
                f_derivatives[N - 1] = g_derivatives[N - 1] * g * (1 - g)
                delta_nabla_b[N - 1] = f_derivatives[N].sum(axis = 1).reshape(-1, 1)
                delta_nabla_w[N - 1] = f_derivatives[N] @ g.T
                N = N - 1

        for layer in reversed(range(N)):
            if layer % self.checkpoint_distance == 0:
                g_derivatives[layer] = self.weights[layer].T @ f_derivatives[layer + 1]
                f_derivatives[layer] = g_derivatives[layer] * g_checkpoints[layer // self.checkpoint_distance] * (1 - g_checkpoints[layer // self.checkpoint_distance])
                delta_nabla_b[layer] = f_derivatives[layer + 1].sum(axis = 1).reshape(-1, 1)
                delta_nabla_w[layer] = f_derivatives[layer + 1] @ g_checkpoints[layer // self.checkpoint_distance].T
            else:
                last_checkpoint_index = layer - (layer % self.checkpoint_distance)
                f_temp, g_temp = self.forward_between_checkpoints(g_checkpoints[last_checkpoint_index // self.checkpoint_distance], last_checkpoint_index, layer)
                for f, g in zip(reversed(f_temp), reversed(g_temp)):
                    g_derivatives[layer] = self.weights[layer].T @ f_derivatives[layer + 1]
                    f_derivatives[layer] = g_derivatives[layer] * g * (1 - g)
                    delta_nabla_b[layer] = f_derivatives[layer + 1].sum(axis = 1).reshape(-1, 1)
                    delta_nabla_w[layer] = f_derivatives[layer + 1] @ g.T
                    layer = layer - 1
    
        return delta_nabla_b, delta_nabla_w

        
network_with_checkpoints = NetworkWithCheckpointing([784, 100, 30, 10], checkpoint_every_nth_layer=2)

In [199]:
%%time
network_with_checkpoints.SGD((x_train, y_train), epochs=50, mini_batch_size=100, eta=3., test_data=(x_test, y_test))

Epoch: 0, Accuracy: 0.8577
Epoch: 1, Accuracy: 0.8946
Epoch: 2, Accuracy: 0.9118
Epoch: 3, Accuracy: 0.9213
Epoch: 4, Accuracy: 0.9271
Epoch: 5, Accuracy: 0.9311
Epoch: 6, Accuracy: 0.9337
Epoch: 7, Accuracy: 0.9359
Epoch: 8, Accuracy: 0.9381
Epoch: 9, Accuracy: 0.9388
Epoch: 10, Accuracy: 0.9404
Epoch: 11, Accuracy: 0.942
Epoch: 12, Accuracy: 0.9429
Epoch: 13, Accuracy: 0.9434
Epoch: 14, Accuracy: 0.9441
Epoch: 15, Accuracy: 0.945
Epoch: 16, Accuracy: 0.9459
Epoch: 17, Accuracy: 0.9463
Epoch: 18, Accuracy: 0.9468
Epoch: 19, Accuracy: 0.947
Epoch: 20, Accuracy: 0.9473
Epoch: 21, Accuracy: 0.9472
Epoch: 22, Accuracy: 0.9478
Epoch: 23, Accuracy: 0.9481
Epoch: 24, Accuracy: 0.9481
Epoch: 25, Accuracy: 0.9484
Epoch: 26, Accuracy: 0.9482
Epoch: 27, Accuracy: 0.9486
Epoch: 28, Accuracy: 0.9489
Epoch: 29, Accuracy: 0.9489
Epoch: 30, Accuracy: 0.949
Epoch: 31, Accuracy: 0.9493
Epoch: 32, Accuracy: 0.9498
Epoch: 33, Accuracy: 0.9498
Epoch: 34, Accuracy: 0.9497
Epoch: 35, Accuracy: 0.9502
Epoch:

In [200]:
%%time
network_no_checkpoints.SGD((x_train, y_train), epochs=50, mini_batch_size=100, eta=3., test_data=(x_test, y_test))

Epoch: 0, Accuracy: 0.8548
Epoch: 1, Accuracy: 0.8906
Epoch: 2, Accuracy: 0.9056
Epoch: 3, Accuracy: 0.9134
Epoch: 4, Accuracy: 0.9197
Epoch: 5, Accuracy: 0.9253
Epoch: 6, Accuracy: 0.9277
Epoch: 7, Accuracy: 0.9309
Epoch: 8, Accuracy: 0.9341
Epoch: 9, Accuracy: 0.9362
Epoch: 10, Accuracy: 0.9371
Epoch: 11, Accuracy: 0.9384
Epoch: 12, Accuracy: 0.9399
Epoch: 13, Accuracy: 0.941
Epoch: 14, Accuracy: 0.9423
Epoch: 15, Accuracy: 0.9435
Epoch: 16, Accuracy: 0.9438
Epoch: 17, Accuracy: 0.9441
Epoch: 18, Accuracy: 0.9458
Epoch: 19, Accuracy: 0.9465
Epoch: 20, Accuracy: 0.9472
Epoch: 21, Accuracy: 0.9475
Epoch: 22, Accuracy: 0.948
Epoch: 23, Accuracy: 0.9481
Epoch: 24, Accuracy: 0.9486
Epoch: 25, Accuracy: 0.9491
Epoch: 26, Accuracy: 0.9491
Epoch: 27, Accuracy: 0.9496
Epoch: 28, Accuracy: 0.9504
Epoch: 29, Accuracy: 0.9504
Epoch: 30, Accuracy: 0.9507
Epoch: 31, Accuracy: 0.9509
Epoch: 32, Accuracy: 0.9512
Epoch: 33, Accuracy: 0.9515
Epoch: 34, Accuracy: 0.9516
Epoch: 35, Accuracy: 0.9511
Epoc

Ref 3. As we see the time difference is $6.45$ minutes vs $5.57$ minutes in favour of the 'standard' network, if we were to compare memory usage as a separate functions then $F_{standard}(a, g) = 2 * \sum_{L = 1}^{N_L}(a_L + g_L)$
and $F_{checkpoints} = \frac{N_L}{k} + (k - 1) $ where k is checkpoint checking length. When it comes to the complexity, standard implementation only calculate one forward and backward pass, thus making around $N_L$ operations and because we recompute only once each layer then version with checkpoints makes only $2 \cdot N_L$ operations

