In [None]:
%load_ext autoreload
%autoreload 2

## What is MNIST? 

MNIST is a standard dataset from the late 1990s that was used to experiment with convolutional neural networks, and it is still used today as a toy dataset for research. It contains 70,000 labeled handwritten digits (60,000 for training and 10,000 for testing). This makes it very convenient for experiments: there are 10 well-defined and roughly balanced classes, each with enough variability to evaluate the generalization abilities of our model.

In [None]:
import torch
from torchvision.datasets import MNIST
from torchvision import transforms
from dandb import plot_image_grid

val_set = MNIST(root='datasets/MNIST', train=False, download=True, transform = transforms.ToTensor())
plot_image_grid(torch.stack([val_set[i][0] for i in range(32)]), display=True)

## Let's bend
### First step : how does a Variational Auto-Encoder (VAE) work?

Before bending the VAE, we have to load the pretrained model.

In [None]:
from dandb.networks import make_mnist_vae
vae = make_mnist_vae()
state_dict = torch.load('models/original/mnist_vae/final.ckpt', map_location=torch.device('cpu'))
vae.load_state_dict(state_dict)
# inference mode: batch normalization layers use their stored running statistics
vae.eval();

Now, let's reconstruct a batch of examples with our little auto-encoder. [Variational auto-encoders](https://www.ee.bgu.ac.il/~rrtammy/DNN/StudentPresentations/2018/AUTOEN~2.PDF) are based on two modules:
- an *encoder*, that outputs a normal distribution over the latent space: $q(\mathbf{z|x}) = \mathcal{N}(\boldsymbol{\mu}_\theta(\mathbf{x}), \boldsymbol{\sigma}^2_\theta(\mathbf{x}))$
- a *decoder*, that outputs a distribution over the data space: in general a normal $p(\mathbf{x|z}) = \mathcal{N}(\boldsymbol{\mu}_\phi(\mathbf{z}), \boldsymbol{\sigma}^2_\phi(\mathbf{z}))$, or, for (nearly) binary data such as MNIST, a Bernoulli $p(\mathbf{x|z}) = \mathrm{Bernoulli}(\boldsymbol{\mu}_\phi(\mathbf{z}))$ giving the probability of each pixel being on.

While decoders were initially probabilistic, many modern implementations drop the output variance and only predict the mean. Such a deterministic output can be seen as the limit of a normal distribution whose variance tends to zero (an infinitely narrow Gaussian).
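Sampling from the encoder's normal distribution is usually done with the *reparameterization trick*: draw noise $\boldsymbol{\epsilon} \sim \mathcal{N}(0, I)$ and compute $\mathbf{z} = \boldsymbol{\mu} + \boldsymbol{\sigma}\odot\boldsymbol{\epsilon}$, so that gradients can flow back to $\boldsymbol{\mu}$ and $\boldsymbol{\sigma}$. Below is a minimal PyTorch sketch. The helper name `sample_latent` is hypothetical, and we assume the encoder returns a variance rather than a log-variance; this is not necessarily how `dandb`'s `vae.reparametrize` is implemented.

```python
import torch

def sample_latent(mu: torch.Tensor, var: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: draw z ~ N(mu, diag(var)) with the reparameterization trick.

    Assumes `var` holds variances, not log-variances.
    """
    eps = torch.randn_like(mu)          # noise from N(0, I), independent of the parameters
    return mu + var.sqrt() * eps        # differentiable w.r.t. mu and var

torch.manual_seed(0)
mu_demo = torch.zeros(4, 8, requires_grad=True)   # 4 examples, 8 latent dimensions
var_demo = torch.ones(4, 8, requires_grad=True)
z_demo = sample_latent(mu_demo, var_demo)
z_demo.sum().backward()                 # gradients reach the distribution parameters
print(z_demo.shape, mu_demo.grad.shape)

# With zero variance the "sample" is just the mean: the deterministic limit of the decoder discussion
z_mean = sample_latent(mu_demo, torch.zeros_like(var_demo))
print(torch.equal(z_mean, mu_demo))
```

Writing the sample as a deterministic function of external noise is what makes the sampling step trainable by backpropagation.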

In [None]:
from dandb import plot_image_reconstructions
# Load a batch of digits from MNIST

batch_size = 64
digit_loader = torch.utils.data.DataLoader(val_set, batch_size, shuffle=False, num_workers=0)
x = next(iter(digit_loader))[0]

# Encode it into its latent representation
# Remember, it is equivalent to knobs of a synth (here, a digit synth)
mu, var = vae.encode(x)
# sample the latent distribution
latent_representation, _ = vae.reparametrize(mu, var)

# Use this latent representation to generate the output
y = vae.decode(latent_representation)

plot_image_reconstructions(x, y)

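Notice that the reconstruction is slightly blurrier than the input. This is largely due to the probabilistic latent space: the VAE objective regularizes the latent distributions of all inputs toward a common prior, so they overlap and the decoder learns to produce smoother, averaged outputs.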
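In the standard VAE objective, the term that makes the latent space probabilistic is a Kullback-Leibler (KL) divergence pulling each encoder distribution toward a standard normal prior. For diagonal Gaussians it has a closed form, $\mathrm{KL} = \frac{1}{2}\sum_j \left(\mu_j^2 + \sigma_j^2 - \log\sigma_j^2 - 1\right)$. A minimal sketch, assuming an $\mathcal{N}(0, I)$ prior and a variance (not log-variance) parameterization; we have not checked that the `dandb` model is trained exactly this way, and `kl_to_standard_normal` is a hypothetical name.

```python
import torch

def kl_to_standard_normal(mu: torch.Tensor, var: torch.Tensor) -> torch.Tensor:
    """KL( N(mu, diag(var)) || N(0, I) ), summed over latent dimensions (one value per example)."""
    return 0.5 * (mu.pow(2) + var - var.log() - 1).sum(dim=-1)

# Zero when the encoder output matches the prior exactly...
kl_prior = kl_to_standard_normal(torch.zeros(1, 8), torch.ones(1, 8))
# ...and positive as codes drift from the origin or become over-confident (small variance)
kl_far = kl_to_standard_normal(torch.full((1, 8), 2.0), torch.full((1, 8), 0.1))
print(kl_prior, kl_far)
```

Because this term penalizes every code that moves away from the prior, the codes of different digits tend to overlap, and the decoder averages over them; this is one common explanation for the blur.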

### Second step : bending and tracing our digit synth using torchbend

As we did with RAVE, let us summarize the activations of our little VAE. 

In [None]:
import torchbend as tb
tb.set_output('notebook')

bended_vae = tb.BendedModule(vae)
bended_vae.trace(x=x)

activation_names = bended_vae.activation_names()
activation_shapes = list(map(bended_vae.activation_shape, activation_names))
print('forward method : ')
bended_vae.print_graph();
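To get an intuition for what "tracing" means, PyTorch ships its own tracer, `torch.fx`, which records a module's forward pass as a graph of named operations. The sketch below traces a hypothetical toy block (not our VAE) and lists its nodes. It is only an analogue of the kind of graph summarized above; we make no claim that torchbend is built this way.

```python
import torch
import torch.fx
from torch import nn

# Hypothetical toy block standing in for one stage of an encoder (not the dandb architecture)
toy_block = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=5, stride=2, padding=2),
    nn.BatchNorm2d(8),
    nn.ReLU(),
)

# Symbolic tracing records the forward pass as a graph instead of computing values
traced_block = torch.fx.symbolic_trace(toy_block)
for node in traced_block.graph.nodes:
    # op: kind of node (placeholder = input, call_module = submodule call, output = return value)
    print(f"{node.op:<12} {node.name:<8} {node.target}")

# The traced GraphModule is still a regular module that can be called
print(traced_block(torch.rand(2, 1, 28, 28)).shape)
```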

The graph of the encoder (resp. decoder) is much simpler: it consists of a series of convolutional (resp. transposed convolutional) operations, followed by batch normalization and a non-linearity. The interesting activations are therefore much easier to retrieve, as we do here by plotting all the activations produced while encoding a single example:

In [None]:
from dandb import plot_image_activations

# change the below number to change the displayed example!
batch_idx = 4
act_names = [f'encoder_net_act{i}' for i in range(4)]
outs = bended_vae.get_activations(*act_names, x=x, fn='forward')

for act in act_names:
    plot_act = outs[act][batch_idx]
    plot_image_activations(plot_act, display=True, name=act, height=400)
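To see why the activation maps shrink through the encoder (and grow back in the decoder), here is a hypothetical stack of the kind described above: strided convolutions followed by batch normalization and a non-linearity, mirrored by transposed convolutions. The channel counts, stride, padding and ReLU are our assumptions for illustration, not the exact `dandb` architecture.

```python
import torch
from torch import nn

def down_block(n_in: int, n_out: int) -> nn.Sequential:
    # strided convolution halves the spatial size, then batch normalization and non-linearity
    return nn.Sequential(
        nn.Conv2d(n_in, n_out, kernel_size=5, stride=2, padding=2),
        nn.BatchNorm2d(n_out),
        nn.ReLU(),
    )

def up_block(n_in: int, n_out: int) -> nn.Sequential:
    # transposed convolution doubles the spatial size (output_padding=1 makes it exact)
    return nn.Sequential(
        nn.ConvTranspose2d(n_in, n_out, kernel_size=5, stride=2, padding=2, output_padding=1),
        nn.BatchNorm2d(n_out),
        nn.ReLU(),
    )

sketch_encoder = nn.Sequential(down_block(1, 16), down_block(16, 32))
sketch_decoder = nn.Sequential(
    up_block(32, 16),
    nn.ConvTranspose2d(16, 1, kernel_size=5, stride=2, padding=2, output_padding=1),
    nn.Sigmoid(),                       # pixel intensities in [0, 1]
)

fake_digits = torch.rand(8, 1, 28, 28)  # random stand-in for a batch of MNIST images
features = sketch_encoder(fake_digits)  # 28x28 -> 14x14 -> 7x7
rebuilt = sketch_decoder(features)      # 7x7 -> 14x14 -> 28x28
print(features.shape, rebuilt.shape)
```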

Conversely, with the decoder:

In [None]:
from dandb import plot_image_activations

# change the below number to change the displayed example!
batch_idx = 4
act_names = [f'decoder_net_act{i}' for i in range(2, 5)]
outs = bended_vae.get_activations(*act_names, x=x, fn='forward')

for act in act_names:
    plot_act = outs[act][batch_idx]
    plot_image_activations(plot_act, display=True, name=act, height=400)

By doing so, we can see all the intermediate values computed while encoding and decoding a given example.
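For comparison, the same kind of inspection can be done in plain PyTorch with forward hooks, which are called with each submodule's output. The sketch below collects the outputs of every `nn.ReLU` in a hypothetical toy network; the helper `collect_activations` is ours, and we make no claim about how torchbend retrieves activations internally.

```python
import torch
from torch import nn

def collect_activations(model: nn.Module, inputs: torch.Tensor, layer_types=(nn.ReLU,)) -> dict:
    """Hypothetical helper: run `model` once and return the outputs of submodules of the given types."""
    activations, handles = {}, []
    for name, module in model.named_modules():
        if isinstance(module, layer_types):
            def hook(mod, args, output, name=name):   # bind `name` now, not when the hook fires
                activations[name] = output.detach()
            handles.append(module.register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(inputs)
    finally:
        for handle in handles:          # always remove hooks so the model is left untouched
            handle.remove()
    return activations

hook_demo_net = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=5, stride=2, padding=2), nn.ReLU(),
    nn.Conv2d(4, 8, kernel_size=5, stride=2, padding=2), nn.ReLU(),
)
hook_acts = collect_activations(hook_demo_net, torch.rand(2, 1, 28, 28))
for name, act_value in hook_acts.items():
    print(name, tuple(act_value.shape))
```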

### Dissecting the weights

Let's summarize the weights of our model : 

In [None]:
bended_vae.print_weights(r".*conv\d.weight")

We see that, for each layer, the weight tensor contains $n_{out}\times n_{in}$ small kernels of size $5\times5$, which we can plot individually:

In [None]:
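max_kernels = 128
for i in range(4):
    kernels = bended_vae.state_dict()[f'encoder.net.conv{i}.weight']
    # (n_out, n_in, k, k) -> (n_out * n_in, k, k): one small patch per input/output channel pair
    kernels = kernels.reshape(-1, *kernels.shape[-2:])[:max_kernels]
    # normalize each kernel by its own maximum absolute value for visualization
    kernels = kernels / kernels.abs().amax(dim=(1, 2), keepdim=True).clamp_min(1e-8)
    plot_image_activations(kernels, display=True, name=f"kernels for layer {i}", height=None)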
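The reshape in the cell above relies on the layout of PyTorch convolution weights, `(n_out, n_in, k, k)`: flattening the first two axes gives one $k\times k$ patch per (output, input) channel pair, in row-major order. A small sketch on a randomly initialized, hypothetical `nn.Conv2d` (not one of the VAE's layers) makes this ordering explicit and shows per-kernel normalization:

```python
import torch
from torch import nn

toy_conv = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
toy_weight = toy_conv.weight.detach()                  # shape (n_out, n_in, 5, 5) = (6, 3, 5, 5)

# Flatten (n_out, n_in) into one axis: patch k comes from output channel k // n_in, input channel k % n_in
patches = toy_weight.reshape(-1, *toy_weight.shape[-2:])   # shape (18, 5, 5)
n_in = toy_weight.shape[1]
k = 4
print(torch.equal(patches[k], toy_weight[k // n_in, k % n_in]))

# Per-kernel normalization: divide each patch by its own largest absolute value
scale = patches.abs().amax(dim=(1, 2), keepdim=True).clamp_min(1e-8)
normalized = patches / scale
print(normalized.shape, normalized.abs().amax(dim=(1, 2)))
```

Normalizing each patch separately keeps weak kernels visible next to strong ones, at the cost of hiding their relative magnitudes.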


None of these weights has an obvious meaning, yet the model manages to perform the task it was trained for. This is typical of deep learning, and makes it hard to assign a clear responsibility to any given part of a network. This is expected: every unit influences the whole network, and this tight entanglement is part of what makes the network efficient. In the next notebook, we will dissect an even more complex network: RAVE.