# VAEs

## Defining a VAE model

`vschaos` makes it easy to define single-layered or multi-layered variational auto-encoding models by specifying only their *signature*; the modules described in the previous section are then created automatically. Each model is defined by a set of three dictionaries, defining respectively:
* the *input parameters*, that can be accessed with `vae.input_params`
* the *latent parameters*, that can be accessed with `vae.latent_params`
* the *hidden parameters*, that can be accessed with `vae.hidden_params`



In [None]:
import numpy as np
import torch
import vschaos
import vschaos.distributions as dist
import vschaos.vaes as vaes
from vschaos.modules import *
from vschaos.data import dataset_from_torchvision, Flatten

from torchvision.transforms import Lambda, ToTensor
from vschaos.vaes import VanillaVAE

# import MNIST
transforms = [Lambda(lambda x: x / 255. + 0.01*torch.randn_like(x.float())), Flatten(-2)]
dataset = dataset_from_torchvision('MNIST', transforms=transforms)

# make VAE parameters
input_params = {'dim':784, 'dist': dist.Normal}
class_params = {'dim':10, 'dist':dist.Categorical}
# different architectures for encoders and decoders can be specified using "encoder"
#   and "decoder" keywords
encoder_params = {'dim':800, 'nlayers':2}
decoder_params = {'dim':800, 'nlayers':3, 'normalization':None}
hidden_params = {'encoder':encoder_params, 'decoder':decoder_params}

latent_params = {'dim':8, "dist":dist.Normal}

# as simple as that!
vae = vaes.VanillaVAE(input_params, latent_params, hidden_params=hidden_params)
# set `cuda` to a GPU index (e.g. 0) to run on GPU, or to -1 to stay on CPU
cuda = -1
device = torch.device(cuda) if cuda >= 0 else -1
if cuda >= 0:
    vae = vae.cuda(cuda)

with torch.cuda.device(device):
    x, y = dataset[:64]
    out = vae(x, y=y)
print("latent parameters : ", out['z_params_enc'][0])
print("latent samples : ", out['z_enc'][0].shape)
print("data parameters : ", out['x_params'])


Every VAE in the `vschaos` package derives from the abstract class `vschaos.vaes.AbstractVAE`, which defines several high-level functions such as model saving / loading, or registering a set of projected points using invertible dimensionality reduction methods, called *manifolds*.

In [None]:
# the vae can be saved using the `save` method, along with arbitrary data given as keywords
vae.save('test_save.pth', transforms=transforms)

# the saved file also stores the model class (under the 'class' key), so that
#   the object can be initialized again when loading
loaded_data = torch.load('test_save.pth')
vae = loaded_data['class'].load(loaded_data)
transforms = loaded_data['transforms']
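
To make the notion of an invertible dimensionality reduction more concrete, here is a minimal NumPy sketch of a PCA-style projection that can map latent codes to a few coordinates and back. The `LinearManifold` class and its method names are hypothetical and only illustrate the idea; they are not the `vschaos` manifold API.

```python
import numpy as np

class LinearManifold:
    """Hypothetical stand-in for an invertible dimensionality reduction (plain PCA)."""

    def __init__(self, n_components=2):
        self.n_components = n_components

    def fit(self, z):
        self.mean_ = z.mean(axis=0)
        # rows of vt are the principal directions, sorted by decreasing variance
        _, _, vt = np.linalg.svd(z - self.mean_, full_matrices=False)
        self.components_ = vt[:self.n_components]
        return self

    def transform(self, z):
        # project latent codes onto the retained directions
        return (z - self.mean_) @ self.components_.T

    def inverse_transform(self, points):
        # map projected points back to the latent space
        return points @ self.components_ + self.mean_

rng = np.random.default_rng(0)
z = rng.normal(size=(200, 8))                # stand-in for 8-dimensional latent codes
manifold = LinearManifold(n_components=2).fit(z)
points = manifold.transform(z)               # shape (200, 2)
z_back = manifold.inverse_transform(points)  # shape (200, 8)
```

With fewer components than latent dimensions the round trip is only approximate; it becomes exact when all dimensions are kept.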

## Training a VAE model

Training a VAE involves first initializing its optimizer, then choosing an appropriate loss, and then repeating the training routine for a given number of epochs. While custom training routines can be defined, a high-level object called `vschaos.train.SimpleTrainer` can be used to automatically perform common tracking operations such as plotting, generating, early stopping, and model saving.
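
Independently of `vschaos`, the general shape of such a routine, including early stopping, can be sketched in plain Python. The `train`, `step_fn`, and `validate_fn` names, the `patience` and `min_delta` parameters, and the toy quadratic problem are all assumptions made for illustration; this is not how `SimpleTrainer` is implemented.

```python
def train(step_fn, validate_fn, epochs=100, patience=5, min_delta=1e-8):
    """Generic loop: one optimization pass per epoch, stop when validation stalls."""
    best_loss, best_epoch = float("inf"), -1
    history = []
    for epoch in range(epochs):
        step_fn(epoch)                      # one pass of parameter updates
        val_loss = validate_fn()            # monitor a held-out loss
        history.append(val_loss)
        if val_loss < best_loss - min_delta:
            best_loss, best_epoch = val_loss, epoch
        elif epoch - best_epoch >= patience:
            break                           # early stopping: no recent improvement
    return history, best_epoch

# toy problem: minimise (w - 3)^2 with plain gradient descent
state = {"w": 0.0}

def step_fn(epoch, lr=0.1):
    grad = 2.0 * (state["w"] - 3.0)
    state["w"] -= lr * grad

def validate_fn():
    return (state["w"] - 3.0) ** 2

history, best_epoch = train(step_fn, validate_fn, epochs=100, patience=5)
```

In a real VAE setting, `step_fn` would iterate over mini-batches and call the optimizer, and `validate_fn` would evaluate the loss on a held-out split.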

With `vschaos`, these three steps proceed as follows:

In [None]:
from vschaos.criterions import ELBO
from vschaos.monitor.visualize_dimred import PCA
from vschaos.train import SimpleTrainer, train_model

# initializing optimizer
optim_params = {'optimizer':'Adam', 'optimArgs':{'lr':1e-3}, 'scheduler':'ReduceLROnPlateau'}
vae.init_optimizer(optim_params)

# defining a loss (see next notebook)
loss = ELBO(beta=4.0, warmup=20)

# The Trainer object performs training, monitoring, and automatic saving during the training process.
dataset = dataset.retrieve(np.random.permutation(len(dataset.data))[:1000])
plots = {}
plots['reconstructions'] = {'preprocess': False, "transforms":transforms, "n_points":15, 'plot_multihead':True, 'label':['class']}
plots['latent_space'] = {'preprocess':False, 'transformation':PCA, 'tasks':'class', 'balanced':True, 'n_points':3000, 'label':['class'], 'batch_size':512}

trainer = SimpleTrainer(vae, dataset, loss, tasks=["class"], plots=plots, use_tensorboard="runs/")
device = torch.device(cuda) if cuda >= 0 else -1

train_options = {'epochs':100, 'save_epochs':20, 'results_folder':'tutorial_3',  'batch_size':64}
with torch.cuda.device(device):
    train_model(trainer, train_options, save_with={'transforms':dataset.classes})
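
To give some intuition for the `beta` and `warmup` arguments passed to `ELBO` above, the following sketch shows one common way a weighted KL term can be ramped up over the first epochs. The linear schedule, the standard-normal prior, and the function names `beta_schedule`, `gaussian_kl`, and `negative_elbo` are assumptions for illustration; the actual behaviour of `vschaos.criterions.ELBO` may differ.

```python
import math

def beta_schedule(epoch, beta=4.0, warmup=20):
    """Assumed schedule: ramp the KL weight linearly from 0 to `beta` over `warmup` epochs."""
    if warmup <= 0:
        return beta
    return beta * min(1.0, epoch / warmup)

def gaussian_kl(mean, std):
    """KL( N(mean, std^2) || N(0, 1) ), summed over latent dimensions."""
    return sum(0.5 * (m * m + s * s - 1.0) - math.log(s) for m, s in zip(mean, std))

def negative_elbo(rec_nll, mean, std, epoch, beta=4.0, warmup=20):
    """Reconstruction negative log-likelihood plus the weighted KL regularizer."""
    return rec_nll + beta_schedule(epoch, beta, warmup) * gaussian_kl(mean, std)

# the same posterior is penalized more strongly as training progresses
early = negative_elbo(rec_nll=10.0, mean=[1.0, 0.0], std=[1.0, 1.0], epoch=0)
late = negative_elbo(rec_nll=10.0, mean=[1.0, 0.0], std=[1.0, 1.0], epoch=20)
```

Starting with a small KL weight lets the model first learn to reconstruct, before the latent regularization reaches its full strength.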