# Tutorial 3

In this notebook, we will see how to pass your own encoder and decoder architectures to your VAE model using pythae! This is an illustration of the section ``Making your own autoencoder model`` of the documentation.

In [None]:
# If you run on colab uncomment the following lines
#!git clone https://github.com/clementchadebec/pythae.git
#!pip install pythae

In [None]:
import torch
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import numpy as np
import os

%matplotlib inline
%load_ext autoreload
%autoreload 2

### Get the data

In [None]:
# Fetch the CIFAR-10 training split (downloaded into ../data when missing).
cifar_trainset = datasets.CIFAR10(
    root='../data',
    train=True,
    download=True,
    transform=None,
)
n_samples = 200
# Keep the first `n_samples` images whose target label equals 2.
is_label_2 = np.array(cifar_trainset.targets) == 2
_dataset_to_augment = cifar_trainset.data[is_label_2][:n_samples]

In [None]:
# Preview the first 20 selected training images on a 2 x 10 grid.
fig, axes = plt.subplots(2, 10, figsize=(10, 2))
for idx, ax in enumerate(axes.flat):
    # `axes.flat` walks the grid row by row, so idx == row * 10 + column.
    ax.matshow(_dataset_to_augment[idx].reshape(32, 32, 3), cmap='gray')
    ax.axis('off')

plt.tight_layout(pad=0.8)

## Let's build a custom auto-encoding architecture!

### First thing, you need to import the ``BaseEncoder`` and ``BaseDecoder`` classes from pythae by running

In [None]:
from pythae.models.nn import BaseEncoder, BaseDecoder
from pythae.models.base.base_utils import ModelOuput

### Then build your own architectures

In [None]:
import torch.nn as nn

class Encoder_VAE_CIFAR(BaseEncoder):
    """Convolutional VAE encoder for 3 x 32 x 32 images.

    Four stride-2 convolutions (each followed by batch-norm and ReLU) halve
    the spatial size at every stage (32 -> 16 -> 8 -> 4 -> 2) while widening
    the channels up to 1024. The flattened 1024 * 2 * 2 feature vector is then
    projected by two linear heads, one for the latent embedding and one for
    the log-covariance.

    Args:
        args: configuration object; only ``args.latent_dim`` is read.
    """

    def __init__(self, args):
        BaseEncoder.__init__(self)

        self.input_dim = (3, 32, 32)
        self.latent_dim = args.latent_dim
        self.n_channels = 3

        # Channel widths of the successive convolutional stages.
        widths = (self.n_channels, 128, 256, 512, 1024)
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
        self.conv_layers = nn.Sequential(*stages)

        # Size of the flattened output of `conv_layers` for a 32 x 32 input.
        flat_features = 1024 * 2 * 2
        self.embedding = nn.Linear(flat_features, args.latent_dim)
        self.log_var = nn.Linear(flat_features, args.latent_dim)

    def forward(self, x: torch.Tensor):
        """Encode a batch ``x`` into the latent embedding and log-covariance.

        Returns:
            A ``ModelOuput`` (name as imported from pythae in this notebook)
            with fields ``embedding`` and ``log_covariance``.
        """
        batch_size = x.shape[0]
        features = self.conv_layers(x).reshape(batch_size, -1)
        return ModelOuput(
            embedding=self.embedding(features),
            log_covariance=self.log_var(features),
        )


class Decoder_AE_CIFAR(BaseDecoder):
    """Transposed-convolution decoder producing 3 x 32 x 32 images.

    A linear layer maps the latent code to a 1024 x 8 x 8 feature map, which
    three transposed convolutions turn into an image; the spatial size goes
    8 -> 16 -> 33 -> 32. A final sigmoid keeps the output values in [0, 1].

    Args:
        args: configuration object; only ``args.latent_dim`` is read.
    """

    def __init__(self, args):
        BaseDecoder.__init__(self)
        # Shape of the reconstructed images. This used to read (3, 28, 28),
        # which matched neither the layers below nor the encoder / model
        # configuration, all of which work with 3 x 32 x 32 images.
        self.input_dim = (3, 32, 32)
        self.latent_dim = args.latent_dim
        self.n_channels = 3

        self.fc = nn.Linear(args.latent_dim, 1024*8*8)
        self.deconv_layers = nn.Sequential(
            # 8 x 8 -> 16 x 16
            nn.ConvTranspose2d(1024, 512, 4, 2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            # 16 x 16 -> 33 x 33 (output_padding adds one extra row/column)
            nn.ConvTranspose2d(512, 256, 4, 2, padding=1, output_padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            # 33 x 33 -> 32 x 32
            nn.ConvTranspose2d(256, self.n_channels, 4, 1, padding=2),
            nn.Sigmoid()
        )

    def forward(self, z: torch.Tensor):
        """Decode a batch of latent codes ``z`` into images.

        Returns:
            A ``ModelOuput`` (name as imported from pythae in this notebook)
            with a single field ``reconstruction``.
        """
        h1 = self.fc(z).reshape(z.shape[0], 1024, 8, 8)
        output = ModelOuput(
            reconstruction=self.deconv_layers(h1)
        )
        return output

### Define a model configuration (in which the latent dimension is set). Here, we use the VAE model.

In [None]:
from pythae.models import VAEConfig

# Model hyper-parameters: 3 x 32 x 32 inputs and a 10-dimensional latent space.
model_config = VAEConfig(latent_dim=10, input_dim=(3, 32, 32))

### Build your encoder and decoder

In [None]:
# Both networks read `latent_dim` from the shared model configuration.
encoder = Encoder_VAE_CIFAR(args=model_config)
decoder = Decoder_AE_CIFAR(args=model_config)

### Last but not least, build your VAE model by passing the ``encoder`` and ``decoder`` arguments

In [None]:
from pythae.models import VAE

# Assemble the VAE from the configuration and the two custom networks.
model = VAE(model_config=model_config, encoder=encoder, decoder=decoder)

### Now you can see that the model you've just built contains the custom encoder and decoder

In [None]:
# Display the model's repr to check that it holds the custom encoder and decoder.
model

### *note*: If you want to launch a training of such a model, try to ensure that the provided architectures are suited for the data. pythae performs a model sanity check before launching training and raises an error if the model cannot encode and decode an input data point

## Train the model !

### As in tutorial 1, you can now train your model.

In [None]:
from pythae.trainers import BaseTrainingConfig
from pythae.pipelines import TrainingPipeline

### Build the training pipeline with your ``TrainingConfig`` instance

In [None]:
# Training hyper-parameters. The batch size equals the number of selected
# training images (200), so each epoch is a single batch.
training_config = BaseTrainingConfig(
    output_dir='my_model_with_custom_archi',
    num_epochs=200,
    batch_size=200,
    learning_rate=1e-3,
    steps_saving=None,
)

In [None]:
# `data_processor` and `optimizer` are passed as None
# (presumably the pipeline then falls back to its defaults; confirm in the pythae docs).
pipeline = TrainingPipeline(
    model=model,
    training_config=training_config,
    data_processor=None,
    optimizer=None,
)

### Launch the ``Pipeline``

In [None]:
# Put the channel axis first: (N, H, W, C) -> (N, C, W, H).
# NOTE(review): this permutation also swaps height and width; the plotting
# cell for the generated samples applies the reverse (2, 1, 0) permutation.
dataset_to_augment = _dataset_to_augment.transpose(0, 3, 2, 1)

# Seed torch's CPU and CUDA generators before training.
torch.manual_seed(8)
torch.cuda.manual_seed(8)

pipeline(train_data=dataset_to_augment)

### *note 1*: You will now see that ``encoder.pkl`` and ``decoder.pkl`` files appear in the folder ``my_model_with_custom_archi/training_YYYY_MM_DD_hh_mm_ss/final_model``, allowing the model to be rebuilt with your own architectures ``Encoder_VAE_CIFAR`` and ``Decoder_AE_CIFAR``.

### *note 2*: Model rebuilding is based on the [dill](https://pypi.org/project/dill/) library, which allows the classes to be reloaded without importing them. Hence, you should still be able to reload the model even if the classes ``Encoder_VAE_CIFAR`` or ``Decoder_AE_CIFAR`` were not imported.

### You can now reload the model easily using the classmethod ``VAE.load_from_folder``

In [None]:
# Pick the last entry, in sorted order, of the training output folder
# (presumably the most recent run, given the timestamped folder names).
training_runs = sorted(os.listdir('my_model_with_custom_archi'))
last_training = training_runs[-1]
print(last_training)

In [None]:
# Reload the trained model from the `final_model` folder of the last run.
# The model built and trained in this notebook is a `VAE` (imported above),
# so it is reloaded through `VAE.load_from_folder`; `RHVAE` is never imported
# here and would raise a NameError.
final_model_dir = os.path.join('my_model_with_custom_archi', last_training, 'final_model')
model_rec = VAE.load_from_folder(final_model_dir)
model_rec

## As in tutorial 1, the model can then be used to generate new samples!

In [None]:
from pythae.models.rhvae import RHVAESamplerConfig
from pythae.models.rhvae.rhvae_sampler import RHVAESampler

# Sampler settings; the generated data is later read back from `output_dir`.
# NOTE(review): `mcmc_steps_nbr` and `eps_lf` look like RHVAE sampling
# parameters — confirm they are meaningful for the model used here.
sampler_config = RHVAESamplerConfig(
        output_dir='my_generated_data_with_custom_archi',
        mcmc_steps_nbr=50,
        eps_lf=0.001,
        batch_size=100,
        no_cuda=False
        )

# NOTE(review): this is the RHVAE sampler, whereas the model built and trained
# in this notebook is a plain `VAE` — confirm that this sampler accepts such a
# model, or switch to the sampler that matches the VAE.
sampler = RHVAESampler(
    model=model_rec,
    sampler_config=sampler_config
)

In [None]:
from pythae.pipelines import GenerationPipeline

# NOTE(review): the pipeline receives `model` while the sampler wraps the
# reloaded `model_rec` — confirm that mixing the two is intentional.
generation_pipe = GenerationPipeline(model=model, sampler=sampler)

In [None]:
# Seed torch's CPU and CUDA generators before sampling.
torch.manual_seed(8)
torch.cuda.manual_seed(8)

# Run the generation pipeline with argument 100 (presumably the number of
# samples: the file loaded afterwards is named `generated_data_100_0.pt`).
generation_pipe(100)

In [None]:
# Pick the last entry, in sorted order, of the generation output folder.
last_generation = sorted(os.listdir('my_generated_data_with_custom_archi'))[-1]

In [None]:
# Load the generated samples saved by the generation run selected above.
generated_file = os.path.join(
    'my_generated_data_with_custom_archi',
    last_generation,
    'generated_data_100_0.pt',
)
generated_data = torch.load(generated_file)

In [None]:
# Show the first 20 generated samples on a 2 x 10 grid.
fig, axes = plt.subplots(2, 10, figsize=(10, 2))
for idx, ax in enumerate(axes.flat):
    # `axes.flat` walks the grid row by row, so idx == row * 10 + column.
    sample = generated_data[idx].detach().cpu().reshape(3, 32, 32)
    # (C, W, H) -> (H, W, C): reverses the axis order given to the training
    # data. The tensor is converted to NumPy explicitly instead of relying on
    # `np.transpose` falling back to an implicit conversion.
    ax.matshow(sample.permute(2, 1, 0).numpy(), cmap='gray')
    ax.axis('off')

plt.tight_layout(pad=0.8)

In [None]:
# Sanity check: no two generated samples may be exactly identical.
# Each sample is moved to the CPU once up front rather than on every comparison.
samples_cpu = [generated_data[k].cpu() for k in range(len(generated_data))]
for i in range(len(samples_cpu)):
    # `.float()` replaces `torch.tensor(<tensor>).type(torch.float)`, which
    # copy-constructs a tensor from a tensor and triggers a UserWarning.
    reference = samples_cpu[i].float()
    for j in range(i + 1, len(samples_cpu)):
        assert not torch.equal(samples_cpu[j], reference), (i, j)