In [None]:
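# If you run on Colab, uncomment the following line to install pythae from GitHub
# (`%pip` installs into the running kernel's environment, unlike `!pip`):
#%pip install git+https://github.com/clementchadebec/benchmark_VAE.git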

In [None]:
import torch
import torchvision.datasets as datasets

%load_ext autoreload
%autoreload 2

In [None]:
# Raw MNIST training split; the official test split is not used in this notebook.
DATA_DIR = '../../data'
EVAL_SIZE = 10_000  # number of training images held out for evaluation

mnist_trainset = datasets.MNIST(root=DATA_DIR, train=True, download=True, transform=None)

# Add a channel axis and divide pixel values by 255 (raw `.data` is presumably uint8 in [0, 255]).
mnist_images = mnist_trainset.data.reshape(-1, 1, 28, 28) / 255.

# Split by index rather than `[:-EVAL_SIZE]`, which would return an empty train set for EVAL_SIZE = 0.
n_train = len(mnist_images) - EVAL_SIZE
train_dataset = mnist_images[:n_train]
eval_dataset = mnist_images[n_train:]

assert len(train_dataset) + len(eval_dataset) == len(mnist_trainset.data)
assert train_dataset.shape[1:] == (1, 28, 28)

In [None]:
from pythae.models import RAE_L2, RAE_L2_Config
from pythae.trainers import CoupledOptimizerTrainerConfig
from pythae.pipelines.training import TrainingPipeline
from pythae.models.nn.benchmarks.mnist import Encoder_AE_MNIST, Decoder_AE_MNIST

In [None]:
SEED = 42  # arbitrary; fixed so the network initialisation below is reproducible

# Trainer settings; the trained model is written under `output_dir` and loaded back later.
config = CoupledOptimizerTrainerConfig(
    output_dir='my_model',
    learning_rate=1e-4,
    batch_size=100,
    num_epochs=100,
)

# RAE-L2 hyper-parameters: (1, 28, 28) MNIST images encoded into a 10-dimensional latent space.
# embedding_weight / reg_weight are the RAE regularisation strengths (see pythae RAE_L2_Config docs).
model_config = RAE_L2_Config(
    input_dim=(1, 28, 28),
    latent_dim=10,
    embedding_weight=1e-4,
    reg_weight=1e-4
)

# Seed torch's global RNG right before the networks are built so their initial weights are reproducible.
torch.manual_seed(SEED)

model = RAE_L2(
    model_config=model_config,
    encoder=Encoder_AE_MNIST(model_config),
    decoder=Decoder_AE_MNIST(model_config)
)

In [None]:
# Bundle the model with its trainer configuration into a single callable training pipeline.
pipeline = TrainingPipeline(model=model, training_config=config)

In [None]:
# Run training (long: num_epochs is set in the trainer config); eval_data is the held-out slice of the train split.
pipeline(train_data=train_dataset, eval_data=eval_dataset)

In [None]:
import os

In [None]:
# Each training run gets its own sub-folder of the trainer's output_dir; only runs that finished
# contain a `final_model` folder. Assumes run-folder names sort chronologically — TODO confirm.
output_dir = config.output_dir
finished_runs = sorted(
    name for name in os.listdir(output_dir)
    if os.path.isdir(os.path.join(output_dir, name, 'final_model'))
)
if not finished_runs:
    raise FileNotFoundError(f'No finished training run (with a final_model folder) found in {output_dir!r}')
last_training = finished_runs[-1]
trained_model = RAE_L2.load_from_folder(os.path.join(output_dir, last_training, 'final_model'))

In [None]:
from pythae.samplers import NormalSampler

In [None]:
normal_samper = NormalSampler(
    model=trained_model
)

In [None]:
gen_data = normal_samper.sample(
    num_samples=10
)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Show every generated sample side by side in grayscale. `gen_data` is indexed as
# (sample, channel, ...) above; assumes single-channel images, as configured for MNIST.
samples = gen_data.detach().cpu().numpy()
fig, axes = plt.subplots(1, len(samples), figsize=(1.5 * len(samples), 1.8), squeeze=False)
for ax, sample in zip(axes[0], samples):
    ax.imshow(sample[0], cmap='gray')
    ax.axis('off')
fig.suptitle('RAE_L2 samples (NormalSampler)')
plt.show()