<a href="https://colab.research.google.com/github/mattbarrett98/mikit-learn/blob/main/MyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
# fast linear algebra, useful for every algorithm
import numpy as np

# Base network

In [2]:
class BaseNetwork:
    """Base network from which all our networks will inherit since they
    all have some shared functionality. They each have a classification accuracy
    associated with them and need a method to calculate that accuracy.

    We also have a magic method to allow us to compare our implementation to
    PyTorch's. We consider the implementations to be equivalent if their
    classification accuracies are within 0.5 percentage points of each other.
    Note that ``__eq__`` returns a descriptive message rather than a bool, so
    ``a == b`` is truthy whenever a message is returned; read the message's
    leading "True"/"False" to get the verdict.
    """

    # Defining __eq__ without __hash__ already makes instances unhashable;
    # state it explicitly so the behaviour is visible.
    __hash__ = None

    def __init__(self):
        # Classification accuracy in percent; set by evaluate_predictions.
        self.accuracy = None

    def softmax(self, x):
        """Softmax activation function for the final layer of networks.

        Normalises over axis 0, i.e. each column of ``x`` is one sample's
        logits. The column-wise maximum is subtracted before exponentiating:
        this leaves the result mathematically unchanged but prevents np.exp
        from overflowing to inf (and the ratio becoming nan) for large logits.
        """
        p = np.exp(x - np.max(x, axis=0, keepdims=True))
        return p / np.sum(p, axis=0)

    def relu(self, x):
        """Relu activation function. Note that max(x, 0) = (x + |x|)/2."""
        return (x + abs(x)) / 2

    def evaluate_predictions(self, predictions, true_classes):
        """Store in ``self.accuracy`` the percentage of correct predictions.

        Parameters
        ----------
        predictions : array-like of shape (n_samples,) of predicted labels.

        true_classes : array-like of shape (n_samples,) of true labels.
        """
        predictions = np.asarray(predictions)
        n_correct_predictions = np.count_nonzero(predictions == np.asarray(true_classes))
        n_predictions = predictions.shape[0]
        self.accuracy = 100 * n_correct_predictions / n_predictions

    def __eq__(self, pytorch_model):
        """Compare accuracies with another model (typically a PyTorch one).

        Returns a message starting with "True" when the accuracies, rounded to
        two decimals, differ by at most 0.5 percentage points, and "False"
        otherwise. Both models must already have an ``accuracy``.
        """
        diff = round(self.accuracy - pytorch_model.accuracy, 2)
        if diff == 0:
            return "True, MyTorch's accuracy is the same as PyTorch's."
        if 0 < diff <= 0.5:
            return f"True, MyTorch's accuracy is just {diff}% higher than PyTorch's."
        if diff > 0.5:
            return f"False, MyTorch's accuracy is {diff}% higher than PyTorch's."
        if -0.5 <= diff < 0:
            return f"True, MyTorch's accuracy is just {-diff}% lower than PyTorch's."
        if diff < -0.5:
            return f"False, MyTorch's accuracy is {-diff}% lower than PyTorch's."

# Multilayer perceptron

In [None]:
class MyMLP(BaseNetwork):
    """Classifier based on a multilayer perceptron. We use 2 hidden layers, relu
    activations for the hidden layers and a softmax activation for the output
    layer. We use the negative log likelihood as our loss function and use the
    Adam optimisation algorithm to minimise it.

    Attributes
    ----------
    epochs : int. The number of times we loop through the whole dataset during
    training.

    batch_size : int. The number of training samples we use to calculate the
    loss for each stochastic optimisation step.

    layer_sizes : list of ints. The number of neurons in each layer of the
    network, including the input and output layer. ``fit`` and ``_mlp_grad``
    assume exactly four entries (input, two hidden layers, output).

    learning_rate : float. Controls step sizes in Adam.

    beta_1 : float, must be in [0,1). Decay rate for the first moments in Adam.

    beta_2 : float, in [0,1). Decay rate for second moments in Adam.

    optimal_par : flattened vector of trained parameters, set by ``fit``
    (None before training).
    """
    def __init__(
        self,
        epochs,
        batch_size,
        layer_sizes,
        learning_rate,
        beta_1,
        beta_2
    ):
        # initialise the attributes shared by all networks (self.accuracy)
        super().__init__()
        self.epochs = epochs
        self.batch_size = batch_size
        self.layer_sizes = layer_sizes
        self.learning_rate = learning_rate
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.optimal_par = None

    def _unpack(self, par):
        """Split a flattened parameter vector into per-layer weights and biases.

        The layout (as produced by ``_pack(w0, b0, w1, b1, ...)``) is, for each
        pair of consecutive layers, the weight matrix of shape
        (n_neurons_next, n_neurons) in row-major order followed by the bias
        vector of length n_neurons_next.

        Returns
        -------
        weights, biases : lists of arrays, one entry per weight layer.
        """
        weights = []
        biases = []
        start = 0
        for n_in, n_out in zip(self.layer_sizes[:-1], self.layer_sizes[1:]):
            w_end = start + n_in * n_out
            weights.append(par[start:w_end].reshape(n_out, n_in))
            biases.append(par[w_end:w_end + n_out])
            start = w_end + n_out
        return weights, biases

    def _pack(self, *parameters):
        """Flatten any number of parameter arrays and concatenate them, in the
        order given, into one 1D vector (the inverse of ``_unpack``).
        """
        # a single concatenate avoids re-copying the growing vector per array
        return np.concatenate([np.ravel(p) for p in parameters])

    def _mlp_grad(self, batch, truth, parameters):
        """This function first makes a forward pass through the network to find
        the outputs. Using the outputs and backpropagation we make a backward
        pass through the network to find all the gradients.

        Parameters
        ----------
        batch : numpy array of shape (n_features, batch_size) containing batch
        of training data.

        truth : numpy array of shape (n_classes, batch_size) containing the one
        hot encoded true classes of each batch observation.

        parameters : vector containing all network parameter values.

        Returns
        -------
        flattened vector containing gradients of the loss (summed over the
        batch) wrt all parameters, in the same layout as ``parameters``.
        """
        weights, biases = self._unpack(parameters)
        z1 = np.matmul(weights[0], batch) + biases[0][:, np.newaxis]
        a1 = self.relu(z1)
        z2 = np.matmul(weights[1], a1) + biases[1][:, np.newaxis]
        a2 = self.relu(z2)
        # output of neural network, one column per sample
        a3 = self.softmax(np.matmul(weights[2], a2) + biases[2][:, np.newaxis])
        # calculate gradients using the chain rule with the negative log
        # likelihood as our loss function; for softmax + NLL the gradient wrt
        # the output logits is simply a3 - truth
        inter1 = a3 - truth
        grad_b2 = np.sum(inter1, axis=1)
        grad_w2 = np.matmul(inter1, a2.T)
        # relu outputs are >= 0, so sign(a) is the relu derivative (0 or 1)
        inter2 = np.matmul(weights[2].T, inter1) * np.sign(a2)
        grad_b1 = np.sum(inter2, axis=1)
        grad_w1 = np.matmul(inter2, a1.T)
        inter3 = np.matmul(weights[1].T, inter2) * np.sign(a1)
        grad_b0 = np.sum(inter3, axis=1)
        grad_w0 = np.matmul(inter3, batch.T)
        return self._pack(grad_w0, grad_b0, grad_w1, grad_b1, grad_w2, grad_b2)

    def _adam(self, batch, truth, par, learning_rate, beta_1, beta_2, m, v, t):
        """This function performs one optimisation step given by Adam. The
        implementation and notation follows https://arxiv.org/abs/1412.6980 .

        Parameters
        ----------
        batch : numpy array of shape (n_features, batch_size) containing batch
        of training data.

        truth : numpy array of shape (n_classes, batch_size) containing the one
        hot encoded true classes of each batch observation.

        par : vector containing all network parameter values. Updated in place.

        learning_rate : float. Controls the size of the optimisation step taken.

        beta_1 : float, must be in [0,1). Decay rate for the first moment
        estimate, 'm'.

        beta_2 : float, in [0,1). Decay rate for second moment estimate, 'v'.

        m : vector of first moment estimates (ignored when t == 0).

        v : vector of second moment estimates (ignored when t == 0).

        t : int. The current timestep; 0 means no step has been taken yet.

        Returns
        -------
        par, m, v, t : the updated values of the parameters, first moment
        estimates, second moment estimates and timestep respectively.
        """
        if t == 0:
            # fresh, independent moment buffers on the first step
            m = np.zeros(par.shape[0])
            v = np.zeros(par.shape[0])
        t += 1
        g = self._mlp_grad(batch, truth, par)
        m = beta_1*m + (1 - beta_1)*g
        v = beta_2*v + (1 - beta_2)*g*g
        # bias corrections folded into the step size (Section 2 of the paper)
        alpha = learning_rate * np.sqrt(1 - beta_2**t) / (1 - beta_1**t)
        par -= alpha * m / (np.sqrt(v) + 1e-8)
        return par, m, v, t

    def fit(self, X, y):
        """This function initialises the weights of the network and performs one
        Adam optimisation step for each batch in each epoch. The trained weights
        are stored in the attribute optimal_par.

        Parameters
        ----------
        X : numpy array of shape (n_features, n_training_obs) containing the
        training data.

        y : numpy array of shape (n_classes, n_training_obs) containing the one
        hot encoded true class of each training observation.

        Returns
        -------
        The trained, flattened parameter vector (also stored in optimal_par).
        """
        # He initialisation since we are using relu activations. Layers whose
        # inputs are relu outputs use variance 2 / fan_in; the first layer sees
        # the raw features (no relu), so He et al. prescribe 1 / fan_in there.
        n0, n1, n2, n3 = self.layer_sizes
        w0 = np.random.normal(0, np.sqrt(1 / n0), (n1, n0))
        b0 = np.zeros(n1)
        w1 = np.random.normal(0, np.sqrt(2 / n1), (n2, n1))
        b1 = np.zeros(n2)
        w2 = np.random.normal(0, np.sqrt(2 / n2), (n3, n2))
        b2 = np.zeros(n3)
        par = self._pack(w0, b0, w1, b1, w2, b2)
        # t == 0 tells _adam to create its moment buffers
        t = m = v = 0
        n_batches = int(np.ceil(X.shape[1] / self.batch_size))
        for _ in range(self.epochs):
            # in each epoch perform an optimisation step on each mini batch;
            # the last batch may be smaller than batch_size
            for j in range(n_batches):
                batch = X[:, self.batch_size * j:self.batch_size * (j + 1)]
                truth = y[:, self.batch_size * j:self.batch_size * (j + 1)]
                par, m, v, t = self._adam(batch,
                                          truth,
                                          par,
                                          self.learning_rate,
                                          self.beta_1,
                                          self.beta_2,
                                          m,
                                          v,
                                          t
                                          )
        self.optimal_par = par
        return self.optimal_par

    def predict(self, test_x):
        """This returns the predicted classes of the test data given by the MLP.

        Parameters
        ----------
        test_x : numpy array of shape (n_test_samples, n_features) containing
        the test data. Note this is the transpose of the layout ``fit`` takes.

        Returns
        -------
        numpy array of shape (n_test_samples,) with the predicted class indices.
        """
        weights, biases = self._unpack(self.optimal_par)
        z1 = np.matmul(weights[0], test_x.T) + biases[0][:, np.newaxis]
        a1 = self.relu(z1)
        z2 = np.matmul(weights[1], a1) + biases[1][:, np.newaxis]
        a2 = self.relu(z2)
        z3 = np.matmul(weights[2], a2) + biases[2][:, np.newaxis]
        a3 = self.softmax(z3)
        predictions = np.argmax(a3, axis=0)
        return predictions

# Convolutional neural network

In [5]:
class MyCNN(BaseNetwork):
    """A classifier based on a convolutional neural network. We also make use of
    a technique known as batch normalisation, details in:
    https://arxiv.org/abs/1502.03167 .

    CNN architecture
    -----------------
    - convolutional layer with 'n_filters_1' 2D filters of size 'filter_size'
    followed by relu activation,
    - 2x2 max pooling layer,
    - batch normalisation layer,
    - convolutional layer with 'n_filters_2' 2D filters of size 'filter_size'
    followed by relu activation,
    - 2x2 max pooling layer,
    - batch normalisation layer,
    - flatten,
    - fully connected layer with 'n_dense' neurons and relu activation,
    - batch normalisation layer,
    - fully connected output layer with n_classes neurons and softmax activation

    Convolutions are 'same' convolutions with stride 1, so 'filter_size' must be
    odd. The two 2x2 poolings require image heights and widths divisible by 4
    (e.g. 28x28 -> 14x14 -> 7x7).

    Attributes
    ----------
    epochs : int. The number of times we loop through the whole dataset during
    training.

    batch_size : int. The number of training samples we use to calculate the
    loss for each stochastic optimisation step.

    filter_size : int. The size of the array of weights in each convolutional
    layer will be filter_size x filter_size.

    n_filters_1, n_filters_2 : int. Number of filters in the first and second
    conv layers.

    n_dense : int. Number of neurons in the fully connected layer.

    learning_rate : float. Controls step sizes in stochastic gradient descent.

    momentum : float, default=0.9. Must be in (0,1). Controls the momentum used
    to compute the moving averages of means and variances in the batch
    normalisation layers.

    par : to store the parameters found from training.

    mu, var : store the final approximations of the means and
    variances for the batch norm layers.
    """

    # Added to batch-norm variances for numerical stability. The same value is
    # used in training (_bn_forward) and prediction (_bn_inference) so that both
    # normalise identically.
    _BN_EPS = 0.001

    def __init__(
        self,
        epochs,
        batch_size,
        filter_size,
        n_filters_1,
        n_filters_2,
        n_dense,
        learning_rate,
        momentum=0.9
    ):
        # initialise the attributes shared by all networks (self.accuracy)
        super().__init__()
        self.epochs = epochs
        self.batch_size = batch_size
        self.filter_size = filter_size
        self.n_filters_1 = n_filters_1
        self.n_filters_2 = n_filters_2
        self.n_dense = n_dense
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.par = None
        self.mu = None
        self.var = None

    def _patches(self, x):
        """Stack the filter_size**2 shifted windows used by a 'same' convolution.

        The last two axes of ``x`` are zero-padded by (filter_size - 1) // 2 on
        each side; slice i of the result is the window offset by
        (i // filter_size, i % filter_size) with the same trailing shape as
        ``x``. The result has shape (filter_size**2,) + x.shape.
        """
        k = self.filter_size
        n_pad = (k - 1) // 2
        h, w = x.shape[-2:]
        padded = np.pad(x, [(0, 0)] * (x.ndim - 2) + [(n_pad, n_pad), (n_pad, n_pad)])
        return np.stack([padded[..., i // k:i // k + h, i % k:i % k + w]
                         for i in range(k * k)])

    def _maxpool(self, x):
        """2x2 max pooling with stride 2 over the last two axes of x = (C, N, H, W).

        Returns the pooled array of shape (C, N, H // 2, W // 2) and, for each
        2x2 window (in C-order of the pooled array), the index 0..3 of its
        maximum, which _maxpool_backward uses to route gradients.
        """
        c, n, h, w = x.shape
        windows = np.swapaxes(x.reshape(c, n, h // 2, 2, w // 2, 2), 3, 4).reshape(-1, 4)
        arg = np.argmax(windows, axis=1)
        pooled = windows[np.arange(windows.shape[0]), arg].reshape(c, n, h // 2, w // 2)
        return pooled, arg

    def _maxpool_backward(self, grad, arg):
        """Route the pooled-output gradient ``grad`` (C, N, H/2, W/2) back to the
        position of each window's maximum; returns an array (C, N, H, W)."""
        c, n, h2, w2 = grad.shape
        routed = np.zeros((grad.size, 4))
        routed[np.arange(grad.size), arg] = grad.ravel()
        return np.swapaxes(routed.reshape(c, n, h2, w2, 2, 2), 3, 4).reshape(c, n, 2 * h2, 2 * w2)

    def _bn_forward(self, x, gamma, beta):
        """Batch normalisation in training mode.

        Axis 0 of ``x`` indexes the normalised features (filters or neurons);
        statistics are taken over all remaining axes.

        Returns
        -------
        out, mu, var, xmu, invsd, x_hat : the normalised output, the batch mean
        and (biased) variance per feature, and the intermediates needed by
        _bn_backward.
        """
        axes = tuple(range(1, x.ndim))
        shape = (-1,) + (1,) * (x.ndim - 1)
        mu = np.mean(x, axis=axes)
        xmu = x - mu.reshape(shape)
        var = np.mean(xmu ** 2, axis=axes)
        invsd = 1 / np.sqrt(var.reshape(shape) + self._BN_EPS)
        x_hat = xmu * invsd
        out = gamma.reshape(shape) * x_hat + beta.reshape(shape)
        return out, mu, var, xmu, invsd, x_hat

    def _bn_backward(self, dout, xmu, invsd, x_hat, gamma):
        """Backpropagate ``dout`` through a training-mode batch norm layer.

        Follows the backprop equations of Ioffe & Szegedy (2015). Returns the
        gradients wrt the layer input, gamma and beta.
        """
        axes = tuple(range(1, dout.ndim))
        m = dout.size // dout.shape[0]
        shape = invsd.shape
        dbeta = np.sum(dout, axis=axes)
        dgamma = np.sum(dout * x_hat, axis=axes)
        dxhat = dout * gamma.reshape(shape)
        # d/dvar of (x - mu) * (var + eps)**-0.5 is -0.5 * (x - mu) * (var + eps)**-1.5
        dvar = -0.5 * np.sum(dxhat * xmu, axis=axes) * invsd.ravel() ** 3
        dmu = np.sum(-dxhat * invsd, axis=axes) + dvar * np.sum(-2 * xmu, axis=axes) / m
        dx = dxhat * invsd + dvar.reshape(shape) * 2 * xmu / m + dmu.reshape(shape) / m
        return dx, dgamma, dbeta

    def _bn_inference(self, x, mu, var, gamma, beta):
        """Batch normalisation in inference mode, using the moving averages
        ``mu`` and ``var`` (scalars or one value per feature on axis 0)."""
        shape = (-1,) + (1,) * (x.ndim - 1)
        x_hat = (x - np.reshape(mu, shape)) / np.sqrt(np.reshape(var, shape) + self._BN_EPS)
        return gamma.reshape(shape) * x_hat + beta.reshape(shape)

    def _cnn_grad(self, batch, truth, f1, b1, gamma1, beta1, mu1_MA, var1_MA,
                  f2, b2, gamma2, beta2, mu2_MA, var2_MA, w1, b3, gamma3, beta3,
                  mu3_MA, var3_MA, w2, b4):
        """Computes the gradient of our negative log likelihood loss with
        respect to all network parameters. Our implementation differs from the
        frameworks TensorFlow and PyTorch, since they use automatic
        differentiation whereas we have calculated the gradients directly.

        Parameters
        ----------
        batch : array of shape (n, H, W) containing a batch of training images.
        n may be smaller than batch_size (e.g. the last batch of an epoch).

        truth : array of shape (n_classes, n) containing the one hot encoded
        true classes of the training samples in batch.

        f1 : filter weights for first conv layer of shape
        (n_filters_1, filter_size**2).

        b1 : biases for first conv layer of shape (n_filters_1,).

        gamma1 : scale factors for the first batch norm layer, following the
        notation from the original paper. Shape (n_filters_1,).

        beta1 : shift factors for first batch norm layer of shape (n_filters_1,)

        mu1_MA, var1_MA : moving averages of the mean and variance for first
        batch norm layer (scalars initially, then arrays of shape (n_filters_1,)).

        f2 : filter weights for second conv layer of shape
        (n_filters_2, n_filters_1, filter_size**2).

        b2 : biases for second conv layer of shape (n_filters_2,).

        gamma2 : scale factors for the second batch norm layer of shape
        (n_filters_2,).

        beta2 : shift factors for second batch norm layer of shape
        (n_filters_2,).

        mu2_MA, var2_MA : moving averages of the mean and variance for
        second batch norm layer.

        w1 : weights for the fully connected layer of shape
        (n_dense, n_filters_2 * (H // 4) * (W // 4)), i.e. n_filters_2 * 49
        for 28x28 images.

        b3 : biases for the fully connected layer, shape (n_dense,).

        gamma3 : scale factors for the third batch norm layer of shape
        (n_dense,).

        beta3 : shift factors for third batch norm layer of shape
        (n_dense,).

        mu3_MA, var3_MA : moving averages of the mean and variance for
        third batch norm layer.

        w2 : weights for the output layer, shape (n_classes, n_dense).

        b4 : biases for output layer, shape (n_classes,).

        Returns
        -------
        grads : list of the gradients of the loss (summed over the batch) in the
        order f1, b1, gamma1, beta1, f2, b2, gamma2, beta2, w1, b3, gamma3,
        beta3, w2, b4.

        mu1_MA, var1_MA, mu2_MA, var2_MA, mu3_MA, var3_MA : updated moving
        averages.
        """
        # use the actual number of samples: the final mini-batch of an epoch
        # can be shorter than self.batch_size
        n = batch.shape[0]
        mom = self.momentum

        # ---------------- forward pass ----------------
        # first convolutional layer + relu: (n_filters_1, n, H, W)
        arr1 = self._patches(batch)
        c1 = self.relu(np.tensordot(f1, arr1, axes=(1, 0))
                       + b1.reshape(self.n_filters_1, 1, 1, 1))
        # first max pooling + batch normalisation: (n_filters_1, n, H/2, W/2)
        mp1, arg1 = self._maxpool(c1)
        bn1, mu1, var1, xmu1, invsd1, x_hat1 = self._bn_forward(mp1, gamma1, beta1)
        mu1_MA = mom * mu1_MA + (1 - mom) * mu1
        var1_MA = mom * var1_MA + (1 - mom) * var1
        # second convolutional layer + relu: (n_filters_2, n, H/2, W/2)
        arr2 = self._patches(bn1)
        c2 = self.relu(np.tensordot(f2, arr2, axes=((1, 2), (1, 0)))
                       + b2.reshape(self.n_filters_2, 1, 1, 1))
        # second max pooling + batch normalisation: (n_filters_2, n, H/4, W/4)
        mp2, arg2 = self._maxpool(c2)
        bn2, mu2, var2, xmu2, invsd2, x_hat2 = self._bn_forward(mp2, gamma2, beta2)
        mu2_MA = mom * mu2_MA + (1 - mom) * mu2
        var2_MA = mom * var2_MA + (1 - mom) * var2
        # flatten to one row per sample
        flat = np.swapaxes(bn2, 0, 1).reshape(n, -1)
        # dense layer + relu + third batch normalisation: (n_dense, n)
        a1 = self.relu(np.matmul(w1, flat.T) + b3[:, np.newaxis])
        bn3, mu3, var3, xmu3, invsd3, x_hat3 = self._bn_forward(a1, gamma3, beta3)
        mu3_MA = mom * mu3_MA + (1 - mom) * mu3
        var3_MA = mom * var3_MA + (1 - mom) * var3
        # output layer: (n_classes, n)
        a2 = self.softmax(np.matmul(w2, bn3) + b4[:, np.newaxis])

        # ---------------- backward pass ----------------
        # softmax + negative log likelihood: d loss / d logits = a2 - truth
        d_out = a2 - truth
        db4 = np.sum(d_out, axis=1)
        dw2 = np.matmul(d_out, bn3.T)
        # third batch norm, then relu derivative (a1 >= 0 so sign(a1) is 0 or 1)
        da1, dgam3, dbeta3 = self._bn_backward(np.matmul(w2.T, d_out), xmu3, invsd3, x_hat3, gamma3)
        dd1 = da1 * np.sign(a1)
        db3 = np.sum(dd1, axis=1)
        dw1 = np.matmul(dd1, flat)
        # undo the flatten: (n, features) -> (n_filters_2, n, H/4, W/4)
        dflat = np.matmul(w1.T, dd1).T
        dbn2 = np.swapaxes(dflat.reshape(n, self.n_filters_2, *bn2.shape[2:]), 0, 1)
        dmp2, dgam2, dbeta2 = self._bn_backward(dbn2, xmu2, invsd2, x_hat2, gamma2)
        dc2 = self._maxpool_backward(dmp2, arg2) * np.sign(c2)
        db2 = np.sum(dc2, axis=(1, 2, 3))
        df2 = np.swapaxes(np.tensordot(dc2, arr2, axes=((1, 2, 3), (2, 3, 4))), 1, 2)
        # gradient wrt bn1: correlate the zero-padded dc2 with the second-layer
        # filters rotated by 180 degrees
        k = self.filter_size
        f2_rot = np.rot90(f2.reshape(self.n_filters_2, self.n_filters_1, k, k), 2, axes=(2, 3))
        dbn1 = np.tensordot(f2_rot.reshape(self.n_filters_2, self.n_filters_1, k * k),
                            self._patches(dc2), axes=((0, 2), (1, 0)))
        dmp1, dgam1, dbeta1 = self._bn_backward(dbn1, xmu1, invsd1, x_hat1, gamma1)
        dc1 = self._maxpool_backward(dmp1, arg1) * np.sign(c1)
        db1 = np.sum(dc1, axis=(1, 2, 3))
        df1 = np.tensordot(dc1, arr1, axes=((1, 2, 3), (1, 2, 3)))
        return ([df1, db1, dgam1, dbeta1, df2, db2, dgam2, dbeta2,
                 dw1, db3, dgam3, dbeta3, dw2, db4],
                mu1_MA, var1_MA, mu2_MA, var2_MA, mu3_MA, var3_MA)

    def _sgd(self, batch, truth, p, mu1, var1, mu2, var2, mu3, var3, learning_rate):
        """Implements stochastic gradient descent and returns updated parameters.

        Parameters
        ----------
        batch : numpy array of shape (n, H, W) containing batch of training
        data (n <= batch_size).

        truth : numpy array of shape (n_classes, n) containing the one
        hot encoded true classes of each batch observation.

        p : list containing all network parameter values, in the order used by
        _cnn_grad.

        mu1,...,var3 : current estimates of the batch norm means and variances.

        learning_rate : float. Controls the size of the optimisation step taken.

        Returns
        -------
        p, mu1,...,var3 : the updated values of the parameters, means and variances.
        """
        grad, *stats = self._cnn_grad(batch, truth, *p[0:4], mu1, var1,
                                      *p[4:8], mu2, var2,
                                      *p[8:12], mu3, var3, *p[12:14])
        p = [param - learning_rate * g for param, g in zip(p, grad)]
        return (p, *stats)

    def _glorot_uniform(self, fan_in, fan_out, shape):
        """Sample weights of the given shape from U(-l, l), l = sqrt(6 / (fan_in + fan_out))."""
        limit = np.sqrt(6 / (fan_in + fan_out))
        return np.random.uniform(-limit, limit, shape)

    def fit(self, X, y):
        """This function implements Glorot uniform initialisation of the weights
        of the network and performs one gradient descent step for each batch in
        each epoch. The trained weights, means and variances are stored in the
        attributes par, mu and var.

        Parameters
        ----------
        X : numpy array of shape (n_training_obs, H, W) containing the
        training images (e.g. 28x28); H and W must be divisible by 4.

        y : numpy array of shape (n_classes, n_training_obs) containing the one
        hot encoded true class of each training observation.

        Returns
        -------
        p, mu1, var1, mu2, var2, mu3, var3 : trained parameters and batch norm
        moving averages."""
        k2 = self.filter_size ** 2
        n_classes = y.shape[0]
        # two 2x2 poolings shrink each side by a factor 4 (28x28 -> 7x7 = 49)
        n_flat = self.n_filters_2 * (X.shape[1] // 4) * (X.shape[2] // 4)
        p = [
            self._glorot_uniform(k2, self.n_filters_1, (self.n_filters_1, k2)),
            np.zeros(self.n_filters_1), np.ones(self.n_filters_1), np.zeros(self.n_filters_1),
            self._glorot_uniform(self.n_filters_1 * k2, self.n_filters_2,
                                 (self.n_filters_2, self.n_filters_1, k2)),
            np.zeros(self.n_filters_2), np.ones(self.n_filters_2), np.zeros(self.n_filters_2),
            self._glorot_uniform(n_flat, self.n_dense, (self.n_dense, n_flat)),
            np.zeros(self.n_dense), np.ones(self.n_dense), np.zeros(self.n_dense),
            self._glorot_uniform(self.n_dense, n_classes, (n_classes, self.n_dense)),
            np.zeros(n_classes),
        ]
        mu1, var1, mu2, var2, mu3, var3 = 0, 1, 0, 1, 0, 1
        n_batches = int(np.ceil(X.shape[0] / self.batch_size))
        for _ in range(self.epochs):
            # shuffle, then perform gradient descent on each mini batch; the
            # last batch may be smaller than batch_size
            perm = np.random.permutation(X.shape[0])
            X = X[perm]
            y = y[:, perm]
            for j in range(n_batches):
                rows = slice(self.batch_size * j, self.batch_size * (j + 1))
                p, mu1, var1, mu2, var2, mu3, var3 = self._sgd(X[rows], y[:, rows], p,
                                                               mu1, var1, mu2, var2,
                                                               mu3, var3, self.learning_rate)
        self.par, self.mu, self.var = p, [mu1, mu2, mu3], [var1, var2, var3]
        return p, mu1, var1, mu2, var2, mu3, var3

    def predict(self, test_X):
        """This returns the predicted classes of the test data given by the CNN.

        Test images are processed in chunks of 1000; every image, including a
        final partial chunk, receives a prediction. Batch norm layers use the
        moving averages learned in ``fit`` and the same epsilon as training.

        Parameters
        ----------
        test_X : numpy array of shape (n_test_samples, H, W) containing
        the test data.

        Returns
        -------
        numpy float array of shape (n_test_samples,) with predicted class indices.
        """
        n_test = test_X.shape[0]
        pred = np.empty(n_test)
        chunk = 1000
        p = self.par
        for start in range(0, n_test, chunk):
            batch = test_X[start:start + chunk]
            n = batch.shape[0]
            c1 = self.relu(np.tensordot(p[0], self._patches(batch), axes=(1, 0))
                           + p[1].reshape(self.n_filters_1, 1, 1, 1))
            bn1 = self._bn_inference(self._maxpool(c1)[0], self.mu[0], self.var[0], p[2], p[3])
            c2 = self.relu(np.tensordot(p[4], self._patches(bn1), axes=((1, 2), (1, 0)))
                           + p[5].reshape(self.n_filters_2, 1, 1, 1))
            bn2 = self._bn_inference(self._maxpool(c2)[0], self.mu[1], self.var[1], p[6], p[7])
            flat = np.swapaxes(bn2, 0, 1).reshape(n, -1)
            a1 = self.relu(np.matmul(p[8], flat.T) + p[9][:, np.newaxis])
            bn3 = self._bn_inference(a1, self.mu[2], self.var[2], p[10], p[11])
            a2 = self.softmax(np.matmul(p[12], bn3) + p[13][:, np.newaxis])
            pred[start:start + n] = np.argmax(a2, axis=0)
        return pred