# Lab exercises 3

For these lab exercises, please submit 2 notebooks or Python scripts and 2 reports, one for each part. The deadline is 22 December.

It is important that you **read the documentation** to understand how to use Pytorch functions, what kind of transformations they apply, etc. You have to **take time to read it carefully** to understand what you are doing.

* https://pytorch.org/docs/stable/nn.html
* https://pytorch.org/docs/stable/torch.html

## 1. Part one: MNIST classification with Pytorch

The goal of the first part is to learn how to use Pytorch and to observe the impact of regularization during training. You should test different network architectures, e.g. with hidden layers of size 128-128, 128-64-32-16, 256-128-64-32-16, 512-256-128-64-32-16, 800-800, and different activation functions (tanh, relu, sigmoid).
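
To compare several architectures without rewriting the network each time, one option is a small builder function. The sketch below is only an illustration: the name `build_mlp` and its arguments are hypothetical and not part of the course material.

```python
import torch


def build_mlp(input_dim, hidden_sizes, n_classes, activation=torch.nn.ReLU):
    """Stack Linear + activation layers, then a final Linear producing class logits."""
    layers = []
    previous_dim = input_dim
    for hidden_dim in hidden_sizes:
        layers.append(torch.nn.Linear(previous_dim, hidden_dim))
        layers.append(activation())
        previous_dim = hidden_dim
    # no activation on the output: torch.nn.CrossEntropyLoss expects raw logits
    layers.append(torch.nn.Linear(previous_dim, n_classes))
    return torch.nn.Sequential(*layers)


# one of the suggested architectures, with tanh activations
network = build_mlp(784, [128, 64, 32, 16], 10, activation=torch.nn.Tanh)
logits = network(torch.rand(5, 784))  # batch of 5 random "images"
print(logits.shape)  # torch.Size([5, 10])
```

Passing the activation as a class (e.g. `torch.nn.Tanh`, `torch.nn.Sigmoid`) lets you switch between the activation functions listed above with a single argument.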

Remember that Pytorch expects data in a different format than in the previous lab exercise: the first dimension is always the batch dimension.

In [5]:
# import libs that we will use
import os
import math
import numpy as np
import matplotlib.pyplot as plt
import torch

# To load the data we will use the script of Gaetan Marceau Caron
# You can download it from the course website and move it to the same directory that contains this ipynb file
import dataset_loader

%matplotlib inline

In [6]:
# Download mnist dataset 
mnist_path = "./mnist.pkl.gz"

# load the 3 splits
train_data, dev_data, test_data = dataset_loader.load_mnist(mnist_path)

In [8]:
image = torch.from_numpy(train_data[0][1])
print(image.shape) # flat image of dim (784,)

# reshape the tensor so it is represented as a batch containing a single image
# -1 means "all remaining elements", here it would be equivalent to image.reshape(1, 784)
image = image.reshape(1, -1)
print(image.shape) # batch containing one flat image, dim (1, 784)

torch.Size([784])
torch.Size([1, 784])


In [9]:
# Constructing a batched input
batch_size = 10
first = 20

# the cat() function concatenates a list of tensors along a dimension
batch_input = torch.cat(
    [
        # we reshape the image tensor so it has dimension (1, 784)
        torch.from_numpy(image).reshape(1, -1)
        for image in train_data[0][first:first + batch_size]
    ],
    # we want to concatenate on the batch dimension
    dim=0
)
print(batch_input.shape)  # batch of ten flat images (10, 784)

torch.Size([10, 784])


### 1.2. Layer initialization

By default, Pytorch will apply Kaiming initialization to linear layers. However, I recommend that you always explicitly initialize your network by hand in the constructor.

In [10]:
bias = True  # set to False to build a layer without a bias term
linear = torch.nn.Linear(10, 20, bias=bias)

# initializations are always in-place operations
# linear.weight is a Parameter, linear.weight.data is the tensor containing the parameter values
# both schemes are shown for illustration; in practice pick one, since the second call overwrites the first
torch.nn.init.xavier_uniform_(linear.weight.data)  # Xavier/Glorot init, usually paired with tanh
torch.nn.init.kaiming_uniform_(linear.weight.data)  # Kaiming/He init, usually paired with relu

if bias:
    torch.nn.init.zeros_(linear.bias.data)

### 1.3. Regularization

You can try two types of regularization (they can be combined together):

* weight decay: it is a parameter of the optimizer
* dropout: see slides
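
As a rough illustration of where each regularizer goes, the sketch below puts a `torch.nn.Dropout` layer in a small network and passes `weight_decay` to the optimizer. The layer sizes, dropout probability, learning rate and decay value are arbitrary assumptions, not recommended settings.

```python
import torch

network = torch.nn.Sequential(
    torch.nn.Linear(784, 128),
    torch.nn.ReLU(),
    torch.nn.Dropout(p=0.5),  # randomly zeroes activations, only in training mode
    torch.nn.Linear(128, 10),
)

# weight_decay adds an L2 penalty on the parameters inside the update rule
optimizer = torch.optim.SGD(network.parameters(), lr=0.1, weight_decay=1e-4)

x = torch.rand(4, 784)
network.train()  # dropout active
train_out = network(x)
network.eval()  # dropout disabled: the forward pass is deterministic
eval_out_1 = network(x)
eval_out_2 = network(x)
print(torch.equal(eval_out_1, eval_out_2))  # True
```

Remember to call `network.train()` before each training epoch and `network.eval()` before evaluating on the dev or test data, otherwise dropout stays active during evaluation.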

### 1.4. Gradient clipping

A common trick for training neural networks is gradient clipping: if the norm of the gradient is too big, we rescale the gradient. This trick can be used to prevent exploding gradients and also to avoid taking "too big steps" in the wrong direction due to the use of approximate gradient computation in SGD.

In [None]:
batch_loss.backward()  # compute gradient
torch.nn.utils.clip_grad_norm_(network.parameters(), 5.)  # rescale the gradient if its norm exceeds 5
optimizer.step()  # update parameters
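
Pytorch offers two clipping utilities that behave differently, and it is easy to confuse them. The toy example below sets the gradients by hand (an artificial setup, used only to make the numbers easy to check) to contrast `clip_grad_norm_`, which rescales the whole gradient, with `clip_grad_value_`, which clamps each component independently.

```python
import torch

w = torch.nn.Parameter(torch.zeros(2))
w.grad = torch.tensor([30.0, 40.0])  # gradient of norm 50

# returns the norm measured before clipping; the gradient is rescaled in place
norm_before = torch.nn.utils.clip_grad_norm_([w], max_norm=5.0)
print(norm_before)  # tensor(50.)
print(w.grad)       # approximately tensor([3., 4.]): same direction, norm 5

# clip_grad_value_ instead clamps every component to [-5, 5]
v = torch.nn.Parameter(torch.zeros(2))
v.grad = torch.tensor([30.0, -2.0])
torch.nn.utils.clip_grad_value_([v], clip_value=5.0)
print(v.grad)       # tensor([ 5., -2.]): the direction has changed
```

Norm clipping preserves the direction of the update, which matches the description above; value clipping does not.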

### 1.5. Bonus: Convolutional Neural Network

You can try to rely on a CNN instead of an MLP to classify MNIST images (you can still have a single layer MLP on top of convolutions, after pooling!). Note that this will require you to reshape the input images!

https://pytorch.org/docs/stable/nn.html#torch.nn.Conv2d

In [None]:
t = torch.rand((10, 100))  # t is a batch of 10 "flat" pictures
# Conv2d expects (batch, channels, height, width): each batch element becomes a 10x10 picture with 1 channel
t = t.reshape(10, 1, 10, 10)
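
To see how the shapes evolve through a convolutional pipeline, here is a minimal sketch on random data. `torch.nn.Conv2d` expects input of shape (batch, channels, height, width); the number of channels, kernel size and padding below are arbitrary assumptions chosen for illustration.

```python
import torch

flat_batch = torch.rand(10, 784)            # 10 flat MNIST-sized images
images = flat_batch.reshape(10, 1, 28, 28)  # (batch, channels, height, width)

conv = torch.nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1)
pool = torch.nn.MaxPool2d(kernel_size=2)
classifier = torch.nn.Linear(8 * 14 * 14, 10)  # single linear layer on top

features = pool(torch.relu(conv(images)))           # padding keeps 28x28, pooling halves it
logits = classifier(features.flatten(start_dim=1))  # flatten everything except the batch dim
print(features.shape)  # torch.Size([10, 8, 14, 14])
print(logits.shape)    # torch.Size([10, 10])
```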

## 2. Part two: Variational Auto-Encoder

To build a new Variational Auto-Encoder, you need two networks:

* An encoder that will take as input an image and compute the parameters of a list of Normal distributions
* A decoder that will take a sample from each Normal distribution and will output an image

For simplicity we will assume that:

* each network has a single hidden layer of size 100
* the latent space has only 2 dimensions

To understand exactly what a VAE is, you can:

* check the slides of Michèle Sebag
* check this tutorial: https://arxiv.org/abs/1606.05908

### 2.1. Encoder

* Compute a hidden representation: $h = \mathrm{relu}(W_1 x + b_1)$
* Compute the means of the normal distributions: $\mu = W_2 h + b_2$
* Compute the log variance of the normal distributions: $\log \sigma^2 = W_3 h + b_3$
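
One possible way to organize these three steps is a module with a shared hidden layer and two output heads. This is a sketch, assuming that both the means and the log variances are computed from the hidden representation; the class and attribute names are hypothetical.

```python
import torch


class Encoder(torch.nn.Module):
    def __init__(self, input_dim=784, hidden_dim=100, latent_dim=2):
        super().__init__()
        self.hidden = torch.nn.Linear(input_dim, hidden_dim)
        self.mu = torch.nn.Linear(hidden_dim, latent_dim)
        self.log_sigma_squared = torch.nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        h = torch.relu(self.hidden(x))
        # two heads sharing the same hidden representation
        return self.mu(h), self.log_sigma_squared(h)


encoder = Encoder()
mu, log_sigma_squared = encoder(torch.rand(10, 784))
print(mu.shape, log_sigma_squared.shape)  # torch.Size([10, 2]) torch.Size([10, 2])
```

Predicting the log variance rather than the variance itself means the output layer needs no constraint: any real number is a valid log variance.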

### 2.2. Decoder

This is a simple MLP, nothing new here!

### 2.3. Training loss

To compute the training loss, you must compute two terms:

* a Monte-Carlo estimation of the reconstruction loss
* the KL divergence between the distributions computed by the encoder and the prior

To sample values, you can use the reparameterization trick as follows:

In [None]:
e = torch.normal(0, 1., mu.shape)
z = mu + e * torch.sqrt(torch.exp(log_sigma_squared))

For the reconstruction loss, use the Binary Cross Entropy loss:

In [None]:
loss_builder = torch.nn.BCEWithLogitsLoss(reduction="sum")

The formula of the KL divergence with the prior is as follows:

In [None]:
-0.5 * torch.sum(1 + log_sigma_squared - mu.pow(2) - log_sigma_squared.exp())
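
For reference, this expression matches the closed-form KL divergence between a diagonal Gaussian and a standard normal prior (assuming the prior is $\mathcal{N}(0, I)$, which is what the code above implies), summed over latent dimensions $j$ and, in the code, over the batch as well:

$$
\mathrm{KL}\big(\mathcal{N}(\mu, \mathrm{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\big) = -\frac{1}{2} \sum_{j} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
$$

Here $\log \sigma_j^2$ corresponds to `log_sigma_squared` and $\sigma_j^2$ to `log_sigma_squared.exp()`.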

### 2.4. Recommended hyperparameters

* Optimizer: Adam
* N. epochs: 50
* Use gradient clipping!
* Large batch size, e.g. 128

In [None]:
import itertools

# use itertools.chain to join parameters of the two networks
optimizer = torch.optim.Adam(itertools.chain(encoder.parameters(), decoder.parameters()))

# in the training loop, call this after backward() and before optimizer.step()
torch.nn.utils.clip_grad_value_(itertools.chain(encoder.parameters(), decoder.parameters()), 5.)
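
To show how the pieces of this part fit together and in which order the calls happen, here is a sketch of a single training step. The `encoder` and `decoder` below are deliberately simplistic stand-ins (single linear layers, not the architecture requested above), and the batch is random data; only the sequence of operations is the point.

```python
import itertools
import torch

# hypothetical stand-ins for the real networks
encoder = torch.nn.Linear(784, 4)  # outputs mu and log_sigma_squared, 2 values each
decoder = torch.nn.Linear(2, 784)  # outputs one logit per pixel

params = list(itertools.chain(encoder.parameters(), decoder.parameters()))
optimizer = torch.optim.Adam(params)
loss_builder = torch.nn.BCEWithLogitsLoss(reduction="sum")

batch = torch.rand(128, 784)  # random stand-in for a batch of images in [0, 1]

optimizer.zero_grad()
mu, log_sigma_squared = encoder(batch).chunk(2, dim=1)
e = torch.randn_like(mu)
z = mu + e * torch.exp(0.5 * log_sigma_squared)  # reparameterization trick
reconstruction = loss_builder(decoder(z), batch)
kl = -0.5 * torch.sum(1 + log_sigma_squared - mu.pow(2) - log_sigma_squared.exp())
batch_loss = reconstruction + kl
batch_loss.backward()
torch.nn.utils.clip_grad_value_(params, 5.)  # after backward(), before step()
optimizer.step()
```

Storing the chained parameters in a list lets you reuse them for both the optimizer and the clipping call; an `itertools.chain` object can only be iterated once.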

### 2.5. Generate new images

Note: they will be blurry, but that's ok!

In [None]:
e = torch.normal(0, 1., (10, 2))
images = decoder(e).sigmoid()

for i in range(10):
    picture = images[i].clone().detach().numpy()
    plt.imshow(picture.reshape(28,28), cmap='Greys')
    plt.show()