In [None]:
# If you run on Colab, uncomment the following line:
#!pip install git+https://github.com/clementchadebec/benchmark_VAE.git

In [None]:
import torch
import torchvision.datasets as datasets

# Automatically reload edited Python modules (e.g. a local pythae checkout) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
# Hold out the last N_EVAL images of the MNIST training split for evaluation and
# train on the remaining ones, so the two sets are disjoint and complementary.
N_EVAL = 10_000
N_TRAIN = None  # set e.g. 10_000 to train on a subset for a quicker run

mnist_trainset = datasets.MNIST(root='../../data', train=True, download=True, transform=None)

# Add a channel axis -> (N, 1, 28, 28) and scale raw pixel values to [0, 1].
mnist_images = mnist_trainset.data.reshape(-1, 1, 28, 28) / 255.

train_dataset = mnist_images[:-N_EVAL][:N_TRAIN]  # slicing with None keeps everything
eval_dataset = mnist_images[-N_EVAL:]

assert len(eval_dataset) == N_EVAL
assert len(train_dataset) + len(eval_dataset) <= len(mnist_images)

In [None]:
from pythae.models import VQVAE, VQVAEConfig
from pythae.trainers import BaseTrainingConfig
from pythae.pipelines.training import TrainingPipeline
from pythae.models.nn.benchmarks.mnist import Encoder_AE_MNIST, Decoder_AE_MNIST

In [None]:
# Training hyper-parameters. The trained model is later loaded back from a run
# folder created under `output_dir`.
config = BaseTrainingConfig(
    output_dir='my_model',
    learning_rate=1e-3,
    batch_size=100,
    num_epochs=10,
)


# VQ-VAE architecture: 16-dim latent, MNIST-shaped input, codebook of
# `num_embeddings` vectors.
model_config = VQVAEConfig(
    latent_dim=16,
    input_dim=(1, 28, 28),
    num_embeddings=10
)

# Seed torch before the networks are built so their random initialisation
# is reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

model = VQVAE(
    model_config=model_config,
    encoder=Encoder_AE_MNIST(model_config),
    decoder=Decoder_AE_MNIST(model_config)
)

In [None]:
# Bundle the training configuration and the model into a pythae training pipeline.
pipeline = TrainingPipeline(training_config=config, model=model)

In [None]:
# Train on `train_dataset`, evaluating on the held-out `eval_dataset`.
pipeline(train_data=train_dataset, eval_data=eval_dataset)

In [None]:
import os

In [None]:
# Each training run is saved in its own sub-folder of the output directory.
# The lexicographically last entry is taken as the most recent run.
# NOTE(review): relies on pythae's timestamped run-folder names — confirm.
last_training = sorted(os.listdir(config.output_dir))[-1]

# Load with the same model class that was trained above.
trained_model = VQVAE.load_from_folder(
    os.path.join(config.output_dir, last_training, 'final_model')
)

In [None]:
# Reconstruct the first 50 evaluation images. Switch to inference mode so
# train-only behaviour (e.g. BatchNorm/Dropout, if present) is disabled, and
# skip autograd bookkeeping since no gradients are needed.
trained_model.eval()
with torch.no_grad():
    recon = trained_model({'data': eval_dataset[:50]}).recon_x.cpu()

In [None]:
import matplotlib.pyplot as plt

# Display the first 25 reconstructions on a 5x5 grid, filled row by row.
fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))

for k, ax in enumerate(axes.flat):
    # Move the channel axis last, the layout imshow expects.
    ax.imshow(recon[k].cpu().permute(1, 2, 0), cmap='gray')
    ax.axis('off')
plt.tight_layout(pad=0.)

In [None]:
import matplotlib.pyplot as plt

# Display the original images that were reconstructed above (the first 25 of
# `eval_dataset`) on the same 5x5 grid, for a side-by-side comparison.
fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))

for i in range(5):
    for j in range(5):
        # Move the channel axis last, the layout imshow expects.
        axes[i][j].imshow(eval_dataset[i*5 + j].cpu().permute(1, 2, 0), cmap='gray')
        axes[i][j].axis('off')
plt.tight_layout(pad=0.)