In [1]:
#!pip install pythae

In [2]:
import os

import torch
import torchvision.datasets as datasets

%load_ext autoreload
%autoreload 2

In [3]:
# Download (if needed) the MNIST training split into ../../data. The raw pixel tensor is
# read from `.data` directly below, so no torchvision transform is applied.
mnist_trainset = datasets.MNIST(root='../../data', train=True, download=True, transform=None)

# Number of images used for training in this demo. It is used as a slice bound, so it
# must be an int (a float such as `1000 / 255.` makes the slice raise a TypeError).
n_samples = 1000

# Add a channel axis -> (N, 1, 28, 28) and divide by 255 to rescale the raw pixel values.
# Training uses the first `n_samples` images; evaluation uses the last 100 (no overlap).
train_dataset = mnist_trainset.data[:n_samples].reshape(-1, 1, 28, 28) / 255.
eval_dataset = mnist_trainset.data[-100:].reshape(-1, 1, 28, 28) / 255.

In [4]:
from pythae.models import VAMP, VAMPConfig
from pythae.trainers import BaseTrainingConfig
from pythae.pipelines.training import TrainingPipeline
from pythae.models.nn.benchmarks.mnist import Encoder_VAE_MNIST, Decoder_AE_MNIST

In [7]:
# Training hyper-parameters; batch size and epoch count are reduced for a quick demo.
config = BaseTrainingConfig(
    output_dir='my_model',
    learning_rate=1e-3,
    batch_size=200,  # demo value chosen for speed (default: 50)
    num_epochs=500,  # demo value; increase it to obtain a better generative model (default: 20000)
)

# VAMP model for single-channel 28x28 inputs, a 10-dimensional latent space and 100 components.
model_config = VAMPConfig(input_dim=(1, 28, 28), latent_dim=10, number_components=100)

model = VAMP(
    model_config=model_config,
    encoder=Encoder_VAE_MNIST(model_config),  # pythae MNIST benchmark encoder
    decoder=Decoder_AE_MNIST(model_config),   # pythae MNIST benchmark decoder
)

In [8]:
# Bundle the model with its training configuration into a ready-to-run pipeline.
pipeline = TrainingPipeline(model=model, training_config=config)

In [10]:
# Launch training; the run folder (config, checkpoints, final model) is created under 'my_model'.
pipeline(eval_data=eval_dataset, train_data=train_dataset)

Preprocessing train data...
Preprocessing eval data...
Model passed sanity check !

Created my_model/VAMP_training_2021-10-14_12-51-40. 
Training config, checkpoints and final model will be saved here.

Successfully launched training !
Training of epoch 1:   0%|          | 0/5 [00:01<?, ?batch/s]


KeyboardInterrupt: 

In [None]:
from pythae.pipelines.generation import GenerationPipeline

In [None]:
# Reload the most recent run. Run folders are named like `VAMP_training_<date>_<time>`,
# so a lexicographic sort puts the latest run last.
# The model trained above is a VAMP, so it must be reloaded with `VAMP` (not `VAE`).
# NOTE(review): `final_model` only exists if training ran to completion.
last_training = sorted(os.listdir('my_model'))[-1]
trained_model = VAMP.load_from_folder(os.path.join('my_model', last_training, 'final_model'))

In [None]:
from pythae.samplers import NormalSampler

In [None]:
normal_samper = NormalSampler(
    model=trained_model
)

In [None]:
gen_data = normal_samper.sample(
    num_samples=10
)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Show every generated sample side by side in grayscale.
# `.cpu()` is a no-op for CPU tensors but is required if sampling ran on a GPU,
# because `.numpy()` cannot read CUDA tensors.
# Assumes samples are laid out (N, 1, 28, 28), matching `input_dim` above.
samples = gen_data.detach().cpu()
fig, axes = plt.subplots(1, len(samples), figsize=(2 * len(samples), 2), squeeze=False)
for ax, sample in zip(axes[0], samples):
    ax.imshow(sample[0].numpy(), cmap='gray')
    ax.axis('off')
fig.suptitle('Samples generated by the trained VAMP model');