This notebook was inspired by neural network & machine learning labs led by [GMUM](https://gmum.net/).

See also [Chapter 14](https://www.deeplearningbook.org/contents/autoencoders.html) of the Deep Learning book and Lilian Weng's [From Autoencoder to Beta-VAE](https://lilianweng.github.io/lil-log/2018/08/12/from-autoencoder-to-beta-vae.html).

Utils and imports (run and hide).

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn.decomposition import PCA

import torch
from torch import nn
from torch.utils.data import Subset
from torchvision.datasets import MNIST, FashionMNIST
from torchvision.transforms import ToTensor, Lambda, Compose

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def plot_dataset(train_data, model):
    """Plot the first 5 training images (top row) and their reconstructions (bottom row)."""
    # Scale raw uint8 pixels to [0, 1], matching the ToTensor() transform used during training.
    view_data = train_data.data[:5].view(-1, 28 * 28).float() / 255.
    with torch.no_grad():
        _, decoded_data = model(view_data.to(device))
    decoded_data = decoded_data.cpu().numpy()

    n_cols = len(view_data)
    fig, axes = plt.subplots(2, n_cols, figsize=(n_cols * 2, 4))
    # Set the title on the figure we just created (calling plt.suptitle before
    # plt.subplots would put it on a different figure).
    fig.suptitle("Reconstruction")

    for i in range(n_cols):
        axes[0][i].imshow(view_data[i].numpy().reshape(28, 28), cmap='gray')
        axes[1][i].imshow(decoded_data[i].reshape(28, 28), cmap='gray')
        for ax in (axes[0][i], axes[1][i]):
            ax.set_xticks(())
            ax.set_yticks(())

    plt.show()
    
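def plot_pca(data, model):
    """Project the latent codes of `data` onto their first 2 principal components."""
    labels = data.classes
    # Encode the dataset passed as an argument (not the global `train_data`),
    # scaled to [0, 1] as during training.
    x = data.data.view(-1, 28 * 28).float().to(device) / 255.
    with torch.no_grad():
        z = model.encode(x)
    reduced_z = PCA(2).fit_transform(z.cpu().numpy())
    targets = np.asarray(data.targets)

    plt.figure(figsize=(10, 6))
    plt.suptitle("Reduction of latent space")
    for class_idx in range(len(labels)):
        indices = (targets == class_idx)
        plt.scatter(
            reduced_z[indices, 0], reduced_z[indices, 1],
            s=2., label=labels[class_idx])

    plt.legend()
    plt.show()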

    
torch.manual_seed(1337) 
batch_size = 128 
transforms = Compose([ToTensor(), Lambda(lambda x: x.flatten())])

train_data = MNIST(root='.', 
                   train=True, 
                   transform=transforms,    
                   download=True)

train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)

## Discriminative vs. generative models

There are (generally) two approaches to statistical classification:
- *generative*, where we model the joint probability distribution $P(X,Y)$,
- *discriminative*, where we model the conditional probability of the target $Y$ given an observation $x$, $P(Y\vert X=x)$ (classifiers computed without using a probability model are also loosely referred to as "discriminative").

One can also think of generative models as learning the distribution of individual classes and discriminative models as learning (hard or soft) boundaries between classes.

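Generative models allow you to generate new data similar to the training data, whereas discriminative models might be easier to learn, because they only need to capture the boundary between classes rather than the full distribution of $X$.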
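To make the distinction concrete, here is a small sketch on synthetic 2-D data. The data, the per-class diagonal-Gaussian model, and the plain gradient-descent logistic regression are all assumptions chosen for illustration; they are not part of the labs. The generative model estimates $P(Y)$ and $P(X\vert Y)$, classifies with Bayes' rule, and can also sample new points; the discriminative model fits $P(Y\vert X=x)$ directly and can only classify.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic 2-D data: one Gaussian blob per class (an assumption for illustration).
n = 500
X = np.vstack([rng.normal([-2.0, 0.0], 1.0, size=(n, 2)),
               rng.normal([2.0, 0.0], 1.0, size=(n, 2))])
y = np.concatenate([np.zeros(n), np.ones(n)])

# --- Generative: model P(X, Y) = P(Y) P(X | Y) with a diagonal Gaussian per class.
priors = np.array([np.mean(y == c) for c in (0, 1)])
means = np.array([X[y == c].mean(axis=0) for c in (0, 1)])
stds = np.array([X[y == c].std(axis=0) for c in (0, 1)])

def log_joint(X):
    """log P(x, y=c) for every class c; returns an array of shape (len(X), 2)."""
    log_lik = -0.5 * (((X[:, None, :] - means) / stds) ** 2
                      + np.log(2 * np.pi * stds ** 2)).sum(axis=2)
    return log_lik + np.log(priors)

gen_pred = log_joint(X).argmax(axis=1)          # Bayes rule: argmax_c P(x, y=c)

# Because we modelled P(X | Y), we can also generate new points of class 1.
new_class1 = rng.normal(means[1], stds[1], size=(5, 2))

# --- Discriminative: model P(Y=1 | x) = sigmoid(w.x + b) directly (logistic regression).
w, b = np.zeros(2), 0.0
for _ in range(500):
    p = 1 / (1 + np.exp(-(X @ w + b)))
    w -= 0.1 * X.T @ (p - y) / len(y)           # gradient of the mean cross-entropy
    b -= 0.1 * np.mean(p - y)

disc_pred = (X @ w + b > 0).astype(int)

print("generative accuracy:    ", np.mean(gen_pred == y))
print("discriminative accuracy:", np.mean(disc_pred == y))
```

On such well-separated data both models classify about equally well; the difference is that only the generative one can produce `new_class1`.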

See the [Wikipedia article](https://en.wikipedia.org/wiki/Generative_model) on generative models, the [CrossValidated question](https://stats.stackexchange.com/questions/12421/generative-vs-discriminative), and the classic ML paper [On Discriminative vs. Generative Classifiers: A comparison of logistic regression and naive Bayes](https://ai.stanford.edu/~ang/papers/nips01-discriminativegenerative.pdf) for further discussion.

Today we will work with the autoencoder model, first showing how to use it for semi-supervised learning, and later building a generative model.

## Vanilla autoencoder

An *autoencoder* is a neural network that is trained to copy its input to its output. The network may be viewed as consisting of two parts: an *encoder* $g_\phi$, which takes an input $\mathbf{x}$ and produces a *code* (also called a *hidden representation* or *latent vector*) $\mathbf{z}=g_\phi(\mathbf{x})$, and a *decoder* $f_\theta$, which produces a reconstruction $\mathbf{x'}=f_\theta(\mathbf{z})$.

![auto-encoder](figures/ae.png)
<center>Source: <a href="https://lilianweng.github.io/lil-log/2018/08/12/from-autoencoder-to-beta-vae.html">From Autoencoder to Beta-VAE</a>.</center>

If an autoencoder succeeds in simply learning to set $f_\theta(g_\phi(\mathbf{x}))=\mathbf{x}$ everywhere, then it is not especially useful. Instead, autoencoders are designed to be unable to learn to copy perfectly, for example by making the code $\mathbf{z}$ much lower-dimensional than the input. Because the model is forced to prioritize which aspects of the input should be copied, it often learns useful properties of the data.

The loss function for the vanilla autoencoder is the MSE between the input and output:
$$L_{AE} =\frac{1}{n}\sum_i \lVert\mathbf{x}_i-f_\theta(g_\phi(\mathbf{x}_i))\rVert_2^2.$$
The encoder and the decoder can be arbitrary neural networks, but usually the decoder mirrors the encoder, i.e. it uses the same layer sizes in reverse order.
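Note that the formula above sums the squared error over pixels and averages over examples, whereas `torch.nn.MSELoss` with its default `reduction='mean'` (used in the training loop below) averages over pixels as well. A quick check on random stand-in tensors (an assumption, not real reconstructions) that the two differ only by the constant factor $d$, the number of pixels:

```python
import torch
from torch import nn

torch.manual_seed(0)
n, d = 8, 784                 # batch size and number of pixels in an MNIST image
x = torch.rand(n, d)          # stand-in for a batch of images in [0, 1]
x_rec = torch.rand(n, d)      # stand-in for the reconstructions f_theta(g_phi(x))

# L_AE exactly as in the formula: mean over examples of the squared L2 norm.
l_ae = ((x - x_rec) ** 2).sum(dim=1).mean()

# nn.MSELoss averages over examples *and* pixels.
l_mse = nn.MSELoss()(x_rec, x)

print(l_ae.item(), (l_mse * d).item())   # equal up to floating-point error
```

The minimizer is the same either way, but the scale matters as soon as the reconstruction term is weighted against another loss term.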

## Task 1 (0.25p)
Implement the encoder and the decoder for a vanilla autoencoder. 

The layers of the encoder should have the following numbers of neurons: `(784, 128, 128, 64, latent_dim)`, and analogously for the decoder: `(latent_dim, 64, 128, 128, 784)`. (The input and output dimensionality equals the number of pixels in an MNIST image, $28 \cdot 28 = 784$.) `latent_dim` should be a parameter of the constructor.

Use no activation function after the encoder's output layer, a sigmoid after the decoder's output layer, and ReLU after every hidden layer.
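One possible way to build such stacks of layers is a small helper around `nn.Sequential`. The helper `make_mlp` below is hypothetical (it is neither part of PyTorch nor of the notebook); it is a sketch of the layer layout described above under those assumptions, not a reference solution.

```python
import torch
from torch import nn

def make_mlp(dims, final_activation=None):
    """Hypothetical helper: nn.Linear layers with ReLU after every hidden layer.

    dims: layer widths, e.g. (784, 128, 128, 64, 10).
    final_activation: optional module applied after the last Linear layer.
    """
    layers = []
    for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
        layers.append(nn.Linear(d_in, d_out))
        if i < len(dims) - 2:          # hidden layers only, not the output layer
            layers.append(nn.ReLU())
    if final_activation is not None:
        layers.append(final_activation)
    return nn.Sequential(*layers)

latent_dim = 10
encoder = make_mlp((784, 128, 128, 64, latent_dim))               # no output activation
decoder = make_mlp((latent_dim, 64, 128, 128, 784), nn.Sigmoid())  # pixels in (0, 1)

x = torch.rand(32, 784)
z = encoder(x)
x_rec = decoder(z)
print(z.shape, x_rec.shape)   # torch.Size([32, 10]) torch.Size([32, 784])
```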

In [None]:
class AutoEncoder(torch.nn.Module):
    def __init__(self, latent_dim):
        
        super(AutoEncoder, self).__init__()
        
        self.latent_dim = latent_dim
        
        self.encoder = ???
        self.decoder = ???
    
    def encode(self, x):
        return ???
    
    def decode(self, encoded):
        return ???

    def forward(self, x):
        encoded = ???
        decoded = ???
        return encoded, decoded

In [None]:
n_epochs = 25
lr = 5e-3        

autoencoder = AutoEncoder(latent_dim=10).to(device)

optimizer = torch.optim.Adam(autoencoder.parameters(), lr=lr)
criterion = torch.nn.MSELoss()

for epoch in range(n_epochs):
    epoch_losses = []  
    for step, (x, _) in enumerate(train_loader):
        
        x = x.to(device)
        
        optimizer.zero_grad()        

        _, decoded = autoencoder(x)
        loss = criterion(decoded, x)   
        loss.backward()          
        optimizer.step()            
        
        epoch_losses.append(loss.item())

    print(f'Epoch: {epoch+1}  |  train loss: {np.mean(epoch_losses):.4f}')

    if epoch % 10 == 0:
        plot_dataset(train_data, autoencoder)
        plot_pca(train_data, autoencoder)

## Semi-supervised learning

In practice, building a fully labeled dataset can be very costly. If we want to train an image classifier, gathering a large amount of data is usually not a problem (we can scrape images from the internet, for example). Labelling them, however, is: it requires human annotators. In some tasks labelling is even more expensive; in segmentation, where we want to assign a class to each pixel of an image, labelling a single picture can take many hours.

Thus, we'd like to have methods which are able to utilize data for which we don't have labels. In the following task we'll build a simple semi-supervised model using an autoencoder.

## Task 2 (0.5p)

Assume that, of the 60k training examples in MNIST, only 100 are labeled. These 100 labeled examples are in the variable `labeled_data`.

1. Implement a classifier and train it on the 100 labeled examples. Report the accuracy on the test set. (The net should be relatively simple: max. 4 layers, max. 128 neurons in a layer).
2. Implement a classifier with an architecture similar to the one from the previous subtask and train it on the same 100 labeled examples, but this time use as its input the hidden representation $\mathbf{z}=g_\phi(\mathbf{x})$ produced by the autoencoder from Task 1 instead of the raw pixels. Report the accuracy on the test set.
3. Compare the results of both models. Which model performed better? Do you have any hypotheses as to why?
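For subtask 2, the key mechanical point is that the autoencoder serves only as a fixed feature extractor: its weights stay frozen while the small classifier is trained on $\mathbf{z}=g_\phi(\mathbf{x})$. The sketch below uses a randomly initialized stand-in encoder and random stand-in data (assumptions, for illustration only); in the notebook the features would instead come from the Task 1 model, e.g. `autoencoder.encode(x)`, with inputs scaled to $[0, 1]$ as during training.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-in for a trained encoder g_phi (assumption: any module mapping 784 -> latent_dim).
latent_dim = 10
encoder = nn.Sequential(nn.Linear(784, 64), nn.ReLU(), nn.Linear(64, latent_dim))
encoder.requires_grad_(False)          # freeze the encoder's weights

classifier = nn.Sequential(nn.Linear(latent_dim, 64), nn.ReLU(), nn.Linear(64, 10))
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)  # classifier params only
criterion = nn.CrossEntropyLoss()

# Stand-in for the 100 labeled examples (random here, for illustration only).
x = torch.rand(100, 784)
y = torch.randint(0, 10, (100,))

encoder_before = [p.clone() for p in encoder.parameters()]
losses = []
for _ in range(50):
    with torch.no_grad():
        z = encoder(x)                 # z = g_phi(x), computed without gradients
    loss = criterion(classifier(z), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```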

In [None]:
labeled_data = Subset(train_data, range(100))
labeled_loader = torch.utils.data.DataLoader(dataset=labeled_data, batch_size=32, shuffle=True)

test_data = MNIST(root='.',
                  train=False,
                  transform=transforms,
                  download=True)

test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=5000, shuffle=True)

In [1]:
# implement and train the baseline model here

In [2]:
# implement and train the model based on the representation produced by the autoencoder here

[your answer here]

## Generative models

Neural-net based generative models allow us to [generate new faces](https://thispersondoesnotexist.com/) or [generate text](https://transformer.huggingface.co/doc/gpt2-large). The next task will involve creating a generative autoencoder and training it on the FashionMNIST dataset. 

In [None]:
transforms = Compose([ToTensor(), Lambda(lambda x: x.flatten())])
train_data = FashionMNIST(root='.', 
                          train=True, 
                          transform=transforms,
                          download=True)

batch_size = 256
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True) 

fig, axes = plt.subplots(2, 5, figsize=(13, 7))
for im, ax in zip(train_data.data[:10], axes.reshape(-1)):
    ax.imshow(im, cmap='gray')
    ax.set_xticks(())
    ax.set_yticks(())
fig.tight_layout()
plt.show()

## Wasserstein Autoencoder
The Wasserstein Autoencoder (WAE) has the same architecture as the vanilla one, but its loss contains an additional term that pushes the codes in the latent space to follow the standard normal distribution $\mathcal{N}(0, I)$. Thanks to this, we'll be able to generate new examples by sampling noise from $\mathcal{N}(0, I)$ and passing it through the decoder.

The loss function consists of two parts, the reconstruction loss and a distance between probability distributions:
$$L_{\text{WAE-MMD}} =\frac{1}{n}\sum_{i=1}^{n} \lVert\mathbf{x}_i-f_\theta(g_\phi(\mathbf{x}_i))\rVert_2^2+C\cdot \text{MMD}\big((g_\phi(\mathbf{x}_i))_{i=1}^{n},(\mathbf{y}_j)_{j=1}^{n}\big),$$
where $\mathbf{y}_1,\dots,\mathbf{y}_n$ are samples from the normal distribution $\mathcal{N}(0, I)$, and $C > 0$ is a hyperparameter that weights the two components of the cost function.

The Maximum Mean Discrepancy between the two samples is estimated as follows (strictly speaking, this is the standard biased estimator of the *squared* MMD):
$$\text{MMD}((\mathbf{y}_i),(\mathbf{z}_j))=\frac{1}{n^2}\sum_{i,i'}k(\mathbf{y}_i,\mathbf{y}_{i'})+\frac{1}{n^2}\sum_{j,j'}k(\mathbf{z}_j,\mathbf{z}_{j'})-\frac{2}{n^2}\sum_{i,j}k(\mathbf{y}_i,\mathbf{z}_j),$$
where $k$ is a kernel function and $\mathbf{z}_j=g_\phi(\mathbf{x}_j)$.

MMD measures the distance between the distribution of the hidden representations $\mathbf{z}=g_\phi(\mathbf{x})$, obtained by passing the training examples through the encoder, and the distribution $\mathcal{N}(0, I)$ from which the samples $\mathbf{y}_j$ are drawn. Minimizing this term pushes the distribution of codes produced by the encoder towards the normal distribution, which is exactly what we need in order to generate new examples from codes sampled from $\mathcal{N}(0, I)$.

We will use the IMQ (inverse multi-quadratic) kernel:
$$k(\mathbf{y}, \mathbf{z})=\frac{\sigma}{\sigma+\lVert\mathbf{y} - \mathbf{z} \rVert_2^2},$$
where $\sigma > 0$ is a hyperparameter you need to find.
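The kernel and the MMD estimator above translate almost line by line into tensor operations. The following is a sketch, assuming both samples have the same size $n$, so that every $\frac{1}{n^2}\sum$ becomes a `.mean()` over an $n\times n$ kernel matrix. The function names `imq_kernel` and `mmd` are our own, and the toy inputs only illustrate that codes far from $\mathcal{N}(0, I)$ receive a much larger penalty than Gaussian-looking ones.

```python
import torch

def imq_kernel(a, b, sigma):
    """Kernel matrix K[i, j] = sigma / (sigma + ||a_i - b_j||_2^2) for a: (n, D), b: (m, D)."""
    # Squared distances via broadcasting; unlike a sqrt-based distance this is
    # differentiable everywhere, including on the diagonal where a_i == b_j.
    sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(dim=-1)
    return sigma / (sigma + sq_dists)

def mmd(y, z, sigma):
    """Biased estimate of the squared MMD between two equally sized samples y and z."""
    return (imq_kernel(y, y, sigma).mean()
            + imq_kernel(z, z, sigma).mean()
            - 2 * imq_kernel(y, z, sigma).mean())

torch.manual_seed(0)
D = 50
sigma = 2.0 * D                            # starting value suggested in Task 3
y = torch.randn(256, D)                    # samples from the prior N(0, I)
z_good = torch.randn(256, D)               # codes that already look Gaussian
z_bad = 0.1 * torch.randn(256, D) + 3.0    # codes concentrated far from the prior

print(mmd(y, z_good, sigma).item())        # small
print(mmd(y, z_bad, sigma).item())         # much larger
```

In a WAE training step one would, under these assumptions, draw fresh `y = torch.randn_like(encoded)` for every batch and add `C * mmd(encoded, y, sigma)` to the reconstruction loss.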

## Task 3 (1p)
Implement the Wasserstein Autoencoder with the Maximum Mean Discrepancy loss component.

1. Implement the autoencoder architecture (encoder + decoder) as in Task 1. The architecture should take into account that FashionMNIST is more complicated than MNIST (e.g. use 50 dimensions for the latent space).
2. Implement a training loop for WAE, where we minimize the loss function $L_{WAE-MMD}$.
3. Find hyperparameters (learning rate, number of training epochs, $C$, $\sigma$, etc.) such that both the reconstructions and the generated samples look decent (use the `plot_samples` function below). Start from $C=1$ and $\sigma=2D$, where $D$ is the dimensionality of the latent space.
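One way to see why $\sigma=2D$ is a sensible starting point: for independent $\mathbf{y},\mathbf{y}'\sim\mathcal{N}(0, I_D)$ we have $\mathbf{y}-\mathbf{y}'\sim\mathcal{N}(0, 2I_D)$, so $\mathbb{E}\lVert\mathbf{y}-\mathbf{y}'\rVert_2^2 = 2D$. With $\sigma = 2D$, the IMQ kernel evaluated at this typical squared distance equals $\frac{2D}{2D+2D}=\frac{1}{2}$, i.e. it is neither saturated nor vanishing. A quick numerical check (a sketch, with $D=50$ as suggested in subtask 1):

```python
import numpy as np

rng = np.random.default_rng(0)
D = 50                                      # latent dimensionality suggested in Task 3
y1 = rng.standard_normal((10_000, D))
y2 = rng.standard_normal((10_000, D))

# For independent y, y' ~ N(0, I_D): y - y' ~ N(0, 2 I_D), so E||y - y'||^2 = 2D.
mean_sq_dist = np.mean(np.sum((y1 - y2) ** 2, axis=1))
print(mean_sq_dist, 2 * D)                  # approximately 100 vs exactly 100

sigma = 2 * D
# IMQ kernel at the typical squared distance: close to its midpoint value 0.5.
print(sigma / (sigma + mean_sq_dist))
```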

In [None]:
def plot_samples(model):
    sampled_z = torch.randn(20, model.latent_dim).to(device)
    generated = model.decode(sampled_z)

    generated = generated.cpu().detach().numpy()
    fig, axes = plt.subplots(2, 10, figsize=(15, 4))
    for gen_im, ax in zip(generated, axes.reshape(-1)):
        ax.imshow(gen_im.reshape(28, 28), cmap="gray")
        ax.set_xticks(())
        ax.set_yticks(())
    fig.tight_layout()
    fig.suptitle("Generated samples")
    plt.show()

In [None]:
class WAEMMD(nn.Module):
    
    def __init__(self, latent_dim):
        
        super(WAEMMD, self).__init__()
        
        self.latent_dim = latent_dim
        
        self.encoder = ???
        self.decoder = ???
    
    def encode(self, x):
        return ???
    
    def decode(self, encoded):
        return ???
    
    def forward(self, x):
        encoded = ???
        decoded = ???
        return encoded, decoded
    
    def mmd_loss(self, y, sigma):
        ???

In [None]:
n_epochs = ???
lr = ???
latent_dim = ???

wae = WAEMMD(latent_dim).to(device)

optimizer = torch.optim.Adam(wae.parameters(), lr=lr)

sigma = ???
C = ???

criterion = torch.nn.MSELoss()

for epoch in range(n_epochs):
    epoch_losses = []
    for step, (x, _) in enumerate(train_loader):
        
        optimizer.zero_grad()
        
        x = x.to(device)
        
        encoded, decoded = wae(x)

        rec_loss = ???
        latent_loss = ???
        loss = ???

        loss.backward()
        optimizer.step()

        epoch_losses += [loss.item()]

    print(f'Epoch: {epoch+1} | train loss: {np.mean(epoch_losses):.5f}')

    if epoch % 10 == 0:
        plot_dataset(train_data, wae)
        plot_pca(train_data, wae)
        plot_samples(wae)