# Tutorial

In this notebook, we will see how to pass your own encoder and decoder architectures to a VAE model using pythae!

In [None]:
# If you run on colab uncomment the following line
#!pip install git+https://github.com/clementchadebec/benchmark_VAE.git

In [None]:
import torch
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import numpy as np
import os

%matplotlib inline
%load_ext autoreload
%autoreload 2

### Get the data

In [None]:
mnist_trainset = datasets.MNIST(root='../data', train=True, download=True, transform=None)

# Keep only the first `n_samples` images of the digit 2, reshaped to (N, 1, 28, 28)
# and divided by 255 (presumably mapping 0-255 pixel values into [0, 1]).
n_samples = 200
is_two = np.array(mnist_trainset.targets) == 2
dataset = mnist_trainset.data[is_two][:n_samples].reshape(-1, 1, 28, 28) / 255.

In [None]:
fig, axes = plt.subplots(2, 10, figsize=(10, 2))

# ``axes.flat`` walks the 2x10 grid row by row, so panel k shows dataset[k]
# (i.e. the first 20 selected training images).
for k, ax in enumerate(axes.flat):
    ax.matshow(dataset[k].reshape(28, 28), cmap='gray')
    ax.axis('off')

plt.tight_layout(pad=0.8)

## Let's build a custom auto-encoding architecture!

### First, import the ``BaseEncoder`` and ``BaseDecoder`` classes, as well as the ``ModelOutput`` class, from pythae:

In [None]:
from pythae.models.nn import BaseEncoder, BaseDecoder
from pythae.models.base.base_utils import ModelOutput

### Then build your own architectures

In [None]:
import torch.nn as nn


class Encoder_VAE_MNIST(BaseEncoder):
    """Convolutional VAE encoder for (1, 28, 28) images.

    Four Conv2d(kernel=4, stride=2, padding=1) -> BatchNorm2d -> ReLU blocks
    shrink the spatial size 28 -> 14 -> 7 -> 3 -> 1 while widening the channels
    1 -> 128 -> 256 -> 512 -> 1024. The flattened 1024 features feed two linear
    heads whose outputs are returned as ``embedding`` and ``log_covariance``.

    Args:
        args: configuration object; only ``args.latent_dim`` is read.
    """

    def __init__(self, args):
        BaseEncoder.__init__(self)

        self.input_dim = (1, 28, 28)
        self.latent_dim = args.latent_dim
        self.n_channels = 1

        # Build the conv stack from consecutive channel pairs. Layers are created
        # in exactly the order a hand-written nn.Sequential would create them.
        widths = [self.n_channels, 128, 256, 512, 1024]
        blocks = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            blocks.extend([
                nn.Conv2d(c_in, c_out, 4, 2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ])
        self.conv_layers = nn.Sequential(*blocks)

        # Flattened 1024x1x1 feature map -> latent-sized outputs.
        self.embedding = nn.Linear(1024, args.latent_dim)
        self.log_var = nn.Linear(1024, args.latent_dim)

    def forward(self, x: torch.Tensor):
        """Encode ``x`` (assumed (B, 1, 28, 28)) into ``embedding`` and ``log_covariance``."""
        batch_size = x.shape[0]
        flat = self.conv_layers(x).reshape(batch_size, -1)
        return ModelOutput(
            embedding=self.embedding(flat),
            log_covariance=self.log_var(flat),
        )


class Decoder_AE_MNIST(BaseDecoder):
    """Transposed-convolution decoder producing (1, 28, 28) images.

    A linear layer lifts the latent code to a 1024x4x4 map. Three
    ConvTranspose2d(kernel=3, stride=2, padding=1) layers then upsample it
    4 -> 7 -> 14 -> 28 (``output_padding`` 0, 1, 1). A final Sigmoid keeps the
    output values in (0, 1).

    Args:
        args: configuration object; only ``args.latent_dim`` is read.
    """

    def __init__(self, args):
        BaseDecoder.__init__(self)
        self.input_dim = (1, 28, 28)
        self.latent_dim = args.latent_dim
        self.n_channels = 1

        self.fc = nn.Linear(args.latent_dim, 1024 * 4 * 4)

        # (in_channels, out_channels, output_padding) of each hidden upsampling block.
        hidden_blocks = [(1024, 512, 0), (512, 256, 1)]
        layers = []
        for c_in, c_out, out_pad in hidden_blocks:
            layers.extend([
                nn.ConvTranspose2d(c_in, c_out, 3, 2, padding=1, output_padding=out_pad),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ])
        # Output block: back to image channels, squashed into (0, 1).
        layers.extend([
            nn.ConvTranspose2d(256, self.n_channels, 3, 2, padding=1, output_padding=1),
            nn.Sigmoid(),
        ])
        self.deconv_layers = nn.Sequential(*layers)

    def forward(self, z: torch.Tensor):
        """Decode latent codes ``z`` (assumed (B, latent_dim)) into ``reconstruction``."""
        seed_map = self.fc(z).reshape(z.shape[0], 1024, 4, 4)
        return ModelOutput(reconstruction=self.deconv_layers(seed_map))

### Define a model configuration (in which the input shape and latent dimension are specified). Here, we use the VAE model.

In [None]:
from pythae.models import VAEConfig

# Single-channel 28x28 inputs, encoded into a 10-dimensional latent space.
model_config = VAEConfig(input_dim=(1, 28, 28), latent_dim=10)

### Build your encoder and decoder

In [None]:
# Both networks read ``latent_dim`` from the shared configuration.
encoder = Encoder_VAE_MNIST(model_config)
decoder = Decoder_AE_MNIST(model_config)

### Last but not least, build your VAE model by passing the ``encoder`` and ``decoder`` arguments

In [None]:
from pythae.models import VAE

# Hand the custom networks to the VAE through its ``encoder`` / ``decoder`` arguments.
model = VAE(model_config=model_config, encoder=encoder, decoder=decoder)

### You can now see that the model you've just built contains the custom encoder and decoder

In [None]:
# Display the model; its printed structure should list the custom encoder/decoder layers.
model

### *note*: If you want to train such a model, make sure the provided architectures are suited to the data. pythae performs a model sanity check before launching training and raises an error if the model cannot encode and decode an input data point.

## Train the model!

In [None]:
from pythae.trainers import BaseTrainingConfig
from pythae.pipelines import TrainingPipeline

### Build the training pipeline with your ``BaseTrainingConfig`` instance

In [None]:
# Training runs are written under ``output_dir``. ``steps_saving=None`` is passed
# explicitly; presumably this disables intermediate checkpoints (confirm in pythae docs).
training_config = BaseTrainingConfig(
    output_dir='my_model_with_custom_archi',
    num_epochs=200,
    batch_size=200,
    learning_rate=1e-3,
    steps_saving=None,
)

In [None]:
# Bundle the custom-architecture VAE and its training settings into a pipeline.
pipeline = TrainingPipeline(training_config=training_config, model=model)

### Launch the ``Pipeline``

In [None]:
# Seed torch's CPU and CUDA generators with the same value before training starts.
# NOTE(review): the encoder/decoder weights were initialised when they were built in
# an earlier cell, before this seed was set, so they are not covered by it.
for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(8)

pipeline(train_data=dataset)

### *note 1*: You will now see that an ``encoder.pkl`` and a ``decoder.pkl`` appear in the folder ``my_model_with_custom_archi/training_YYYY_MM_DD_hh_mm_ss/final_model``. They allow the model to be rebuilt with your own architectures ``Encoder_VAE_MNIST`` and ``Decoder_AE_MNIST``.

### *note 2*: Model rebuilding relies on the [dill](https://pypi.org/project/dill/) library, which can reload classes without them being imported. Hence, you should still be able to reload the model even if the classes ``Encoder_VAE_MNIST`` or ``Decoder_AE_MNIST`` were not imported.

In [None]:
# Training runs are written to time-stamped sub-folders
# (``training_YYYY_MM_DD_hh_mm_ss``), so a lexicographic sort is also chronological.
# Only runs that produced a ``final_model`` folder are kept. An interrupted run, or any
# stray file in the output directory, is therefore never picked, and the reload
# below (which reads ``<run>/final_model``) cannot fail on it.
training_runs = sorted(
    run for run in os.listdir('my_model_with_custom_archi')
    if os.path.isdir(os.path.join('my_model_with_custom_archi', run, 'final_model'))
)
if not training_runs:
    raise FileNotFoundError(
        "No completed training run (with a 'final_model' folder) found in "
        "'my_model_with_custom_archi'"
    )
last_training = training_runs[-1]
print(last_training)

### You can now reload the model easily using the classmethod ``VAE.load_from_folder``

In [None]:
# Rebuild the trained model (including the custom architectures) from the latest run.
final_model_dir = os.path.join('my_model_with_custom_archi', last_training, 'final_model')
model_rec = VAE.load_from_folder(final_model_dir)
model_rec

## The model can now be used to generate new samples!

In [None]:
from pythae.samplers import NormalSampler


# Generate 25 new images with a NormalSampler built on the reloaded model.
sampler = NormalSampler(model=model_rec)
gen_data = sampler.sample(num_samples=25)

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))

# ``axes.flat`` walks the 5x5 grid row by row, so panel k shows generated sample k.
for k, ax in enumerate(axes.flat):
    ax.imshow(gen_data[k].cpu().reshape(28, 28), cmap='gray')
    ax.axis('off')
plt.tight_layout(pad=0.)