# Residual neural networks

## Objectives:
- implementing a residual layer
- using batch normalization layers
- observing how residual layers, batch normalization layers and network depth affect the gradients at initialization
- observing how residual layers, batch normalization layers and network depth affect overfitting

## Contents

1. Implementing a Residual Network  
2. Observing vanishing, shattering and exploding gradients
3. Effects of residual layers, batch normalization and depth on overfitting


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import random_split

torch.manual_seed(123)

## 1. Implementing a Residual Network

#### TODO

1. Define a class ``MyResidualLayer`` implementing a residual layer, taking as parameters ``n_hid`` and ``use_skip``, such that:
- ``out = x + relu(fc(x))`` if ``use_skip = True``
- ``out = relu(fc(x))`` if ``use_skip = False``

where ``fc`` is a fully connected layer with as many inputs as outputs ``n_hid``.

2. Define a class ``MyNet``, implementing a neural network with ``L`` (blocks of) trainable layers such that:
- The first layer is a fully connected layer with ``n_in=1`` inputs and ``n_hid`` outputs
- The next ``L-2`` layers are ``MyResidualLayer`` layers with ``n_hid`` hidden units. The boolean parameter ``use_skip`` decides whether they act as residual layers or as plain fully connected layers. You can set ``n_hid`` to ``128`` for example.
- The last layer is a fully connected layer with ``n_hid`` inputs and ``n_out=1`` outputs.
- A batch normalization layer is inserted every ``batchnorm_frequency`` layers.
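
The two classes described above could be sketched as follows. This is one possible reading, not a reference solution. It assumes that "every ``batchnorm_frequency`` layers" means after every ``batchnorm_frequency``-th hidden layer, that ``batchnorm_frequency=None`` disables batch normalization, and that no activation follows the first layer. The default values and the argument names ``n_in`` and ``n_out`` are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyResidualLayer(nn.Module):
    """Fully connected layer with an optional identity skip connection."""

    def __init__(self, n_hid, use_skip=True):
        super().__init__()
        self.fc = nn.Linear(n_hid, n_hid)
        self.use_skip = use_skip

    def forward(self, x):
        out = F.relu(self.fc(x))
        # x + relu(fc(x)) with the skip, relu(fc(x)) without it
        return x + out if self.use_skip else out


class MyNet(nn.Module):
    """L trainable layers: input layer, L-2 hidden layers, output layer."""

    def __init__(self, L, n_in=1, n_hid=128, n_out=1,
                 use_skip=True, batchnorm_frequency=None):
        super().__init__()
        layers = [nn.Linear(n_in, n_hid)]
        for i in range(1, L - 1):
            layers.append(MyResidualLayer(n_hid, use_skip))
            # Assumption: insert batch normalization after hidden layers
            # number k, 2k, 3k, ... where k = batchnorm_frequency.
            if batchnorm_frequency and i % batchnorm_frequency == 0:
                layers.append(nn.BatchNorm1d(n_hid))
        layers.append(nn.Linear(n_hid, n_out))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
```

With ``L=2`` the loop body never runs and the network reduces to the input and output layers, which matches the smallest depth used in the experiments.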


## 2. Observing vanishing, shattering and exploding gradients

### Computing gradients with respect to the input 


Rather than computing the gradients with respect to the network's parameters, as you are used to, we will compute the gradients with respect to the input data. This is meaningful because, by the chain rule, the derivatives with respect to the input and the derivatives with respect to the parameters of the early layers share most of their layer-wise Jacobian factors, so they tend to vanish or explode together as depth grows. Since we are interested in how network depth affects the gradients, we will use a simple grid of uniformly spaced points ranging from -3 to 3 as input to the network.
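
As a minimal illustration of the mechanics, the sketch below computes input gradients for a small stand-in model (an ``nn.Sequential`` chosen only for this example, not the ``MyNet`` of the exercise). The parameters are frozen and the input is marked as requiring gradients, so ``backward()`` fills ``X.grad`` instead of the parameter gradients.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Stand-in model for illustration only.
model = nn.Sequential(nn.Linear(1, 8), nn.ReLU(), nn.Linear(8, 1))

# Track gradients for the input instead of the parameters.
for p in model.parameters():
    p.requires_grad_(False)
X = torch.linspace(-3, 3, 5).unsqueeze(1).requires_grad_(True)

# This model has no batch normalization, so each output depends only on
# its own input row and summing gives X.grad[i] = d out[i] / d X[i].
model(X).sum().backward()
input_grad = X.grad.detach().clone()
print(input_grad.shape)  # torch.Size([5, 1])
```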

#### TODO

1. Write a function that computes the mean gradient values with respect to an input ``X`` consisting of 256 points uniformly spaced in ``[-3, 3]``, averaged over ``n_iter = 30`` random initializations of your ``MyNet`` model. To do so:
1. Seed PyTorch.
1. Repeat ``n_iter`` times:
    1. Instantiate a ``MyNet`` model. (You might need to add parameters to your function so that you can instantiate your model with different values of ``batchnorm_frequency``, ``use_skip``, etc.)
    1. Emulate the beginning of a training phase by calling ``train_one_epoch``.
    1. Disable autograd for all the network's parameters. Enable autograd for a uniformly spaced input ``X``.
    1. Perform a forward pass with ``X`` as input.
    1. Run a backward pass on the output. (You might need to sum all the components of the output before calling ``backward()``.)
    1. Store the gradient with respect to the input ``X``. You might need to use the ``.clone()`` and/or ``.detach()`` methods.
1. For each element of ``X``, compute the mean gradient value over the different model initializations.
1. Return the mean gradient values.
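
The steps above could be sketched as follows. The function name ``mean_input_gradients`` and its arguments ``make_model`` and ``warmup`` are hypothetical: ``make_model`` stands for any zero-argument callable returning a fresh model (for instance a lambda building a ``MyNet``), and ``warmup`` for a function such as ``train_one_epoch``. The usage lines use a small stand-in model so that the block runs on its own.

```python
import torch
import torch.nn as nn


def mean_input_gradients(make_model, n_iter=30, n_points=256, seed=123,
                         warmup=None):
    """Average d(sum of outputs)/dX over n_iter random initializations."""
    torch.manual_seed(seed)
    grads = []
    for _ in range(n_iter):
        model = make_model()
        if warmup is not None:
            warmup(model)  # e.g. train_one_epoch
        # Freeze the parameters and track gradients for the input only.
        for p in model.parameters():
            p.requires_grad_(False)
        X = torch.linspace(-3, 3, n_points).unsqueeze(1).requires_grad_(True)
        model(X).sum().backward()
        grads.append(X.grad.detach().clone())
    # Stack to (n_iter, n_points, 1), then average over initializations.
    return torch.stack(grads).mean(dim=0)


# Usage with a stand-in model (illustration only).
make_model = lambda: nn.Sequential(nn.Linear(1, 16), nn.ReLU(), nn.Linear(16, 1))
mean_grad = mean_input_gradients(make_model, n_iter=3, n_points=8)
```

The sketch leaves each model in the mode it was created in or that ``warmup`` set (train mode after ``train_one_epoch``), so any batch normalization layer normalizes with the statistics of ``X`` itself. Whether to switch to ``model.eval()`` before the forward pass is a choice the exercise leaves open.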

In [None]:
def train_one_epoch(model):
    """
    Performs one small training iteration to emulate an early training situation
    """
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=1e-1)
    loss_fn = nn.MSELoss()
    optimizer.zero_grad(set_to_none=True)
    
    X = torch.linspace(-3, 3, 10).unsqueeze(1)  

    outputs = model(X)
    loss = loss_fn(outputs, torch.cos(X))
    loss.backward()
    
    optimizer.step()
    optimizer.zero_grad()



### Experiment

#### TODO

1. Plot the mean gradient values with respect to the input ``X`` for different values of depth and ``batchnorm_frequency``, with and without residual skip connections (``use_skip`` set to ``True`` or ``False``). You can keep the values suggested in the cell below.
1. Comment on your results.
1. Do you observe shattering gradients? If so, under which conditions?
1. Do you observe vanishing gradients? If so, under which conditions?
1. Do you observe exploding gradients? If so, under which conditions?
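
To back up what the plots show, two rough numbers can be computed for each gradient curve. The helper below is a hypothetical addition, not something the exercise requires: the mean absolute gradient hints at vanishing (very small) or exploding (very large) gradients, and the lag-1 autocorrelation between neighbouring inputs hints at shattering when it drops towards 0.

```python
import torch


def gradient_summary(grad):
    """Return (mean absolute gradient, lag-1 autocorrelation).

    grad: tensor of shape (N,) or (N, 1), ordered by increasing input value.
    """
    g = grad.flatten().double()
    magnitude = g.abs().mean().item()
    centered = g - g.mean()
    denom = (centered ** 2).sum()
    if denom == 0:
        # Flat curve: the correlation is undefined, report 0 by convention.
        return magnitude, 0.0
    # Near 1 for smooth curves, near 0 for noise-like ones.
    smoothness = ((centered[:-1] * centered[1:]).sum() / denom).item()
    return magnitude, smoothness


# Illustration on synthetic curves rather than real network gradients.
torch.manual_seed(0)
smooth_stats = gradient_summary(torch.sin(torch.linspace(-3, 3, 256)))
noisy_stats = gradient_summary(torch.randn(256))
```

These statistics are heuristics for comparing configurations with each other. They do not define thresholds for vanishing, exploding or shattering gradients.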

In [None]:
N = 256
X = torch.linspace(-3, 3, N).unsqueeze(1) 
list_depth = [2, 5, 10, 25, 50, 75]
list_frequency = [None, 10, 5, 3, 1]
list_use_skip = [False, True]


        

## 3. Effects of residual layers, batch normalization and depth on overfitting


### Experiment

1. Load and preprocess the CIFAR-10 dataset. Split it into 3 datasets: training, validation and test. Take a subset of these datasets by keeping only 2 labels: bird and airplane.
1. Modify your residual network so that its input and output layers match the dataset: now ``n_in=32*32*3`` and ``n_out=2``. You can also set ``n_hid`` to ``64`` for example, to reduce computations.
1. Plot the training loss and the validation loss for different values of depth and ``batchnorm_frequency``, with and without residual skip connections. You can keep the values suggested in the cell below.
1. Comment on your results.
1. Select and evaluate the best model among the different values of depth, ``batchnorm_frequency`` and ``use_skip`` used.
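
The class filtering and the train/validation split of the first step could be sketched as below. ``TwoClassSubset`` and ``split_train_val`` are hypothetical helper names, and the 20% validation fraction is an assumption. The sketch works on any dataset yielding ``(image, label)`` pairs, so it can be tried without downloading anything.

```python
import torch
from torch.utils.data import Dataset, random_split

# CIFAR-10 class indices in torchvision: 0 is "airplane", 2 is "bird".
LABEL_MAP = {0: 0, 2: 1}


class TwoClassSubset(Dataset):
    """Keeps only the samples whose label is in label_map and remaps labels."""

    def __init__(self, dataset, label_map=LABEL_MAP):
        self.samples = [(img, label_map[label])
                        for img, label in dataset if label in label_map]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


def split_train_val(dataset, val_fraction=0.2, seed=123):
    """Random train/validation split with a fixed seed."""
    n_val = int(len(dataset) * val_fraction)
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [len(dataset) - n_val, n_val],
                        generator=generator)
```

With torchvision, ``dataset`` would typically be ``datasets.CIFAR10(root, train=True, download=True, transform=...)`` for the training and validation data and the same call with ``train=False`` for the test data. Since the first layer is fully connected, each image batch has to be flattened to shape ``(batch_size, 32*32*3)`` before the forward pass, for example with ``x.view(x.shape[0], -1)``.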

In [None]:
n_epochs = 20
batch_size = 512

list_depth = [3, 5, 10, 20]
list_frequency = [None, 9, 7, 3, 1]
list_use_skip = [False, True]
