In [None]:
# Install the pythae library (the `%pip` magic targets the environment of the running kernel)
%pip install pythae

In [None]:
import torch
import torchvision.datasets as datasets

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Enable the autoreload extension in mode 2 so edited modules are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Load the MNIST training split from `root`, downloading it first if it is missing.
mnist_trainset = datasets.MNIST(root='../../data', train=True, download=True, transform=None)

# Number of trailing images held out for evaluation. Both slices below are derived
# from this single constant so the two splits are disjoint and together cover the
# whole set (no image is silently left unused).
N_EVAL = 10000

# Reshape to (N, 1, 28, 28) and divide by 255 to rescale the raw pixel intensities
# (assumes 8-bit pixel values, as stored by torchvision's MNIST - TODO confirm).
train_dataset = mnist_trainset.data[:-N_EVAL].reshape(-1, 1, 28, 28) / 255.
eval_dataset = mnist_trainset.data[-N_EVAL:].reshape(-1, 1, 28, 28) / 255.

In [None]:
from pythae.models import VQVAE, VQVAEConfig
from pythae.trainers import BaseTrainerConfig
from pythae.pipelines.training import TrainingPipeline
from pythae.models.nn.benchmarks.mnist.resnets import Encoder_ResNet_VQVAE_MNIST, Decoder_ResNet_VQVAE_MNIST

In [None]:
# Training settings; checkpoints and the final model are written under `output_dir`.
config = BaseTrainerConfig(
    num_epochs=10,  # Change this to train the model a bit more
    batch_size=100,
    learning_rate=1e-3,
    output_dir="my_model",
)

# Hyper-parameters of the VQ-VAE itself.
model_config = VQVAEConfig(
    input_dim=(1, 28, 28),
    latent_dim=16,
    num_embeddings=128,
    commitment_loss_factor=0.25,
    quantization_loss_factor=1.0,
    use_ema=True,
    decay=0.99,
)

# Build the MNIST ResNet encoder/decoder from the same model config, then
# assemble the VQ-VAE around them.
encoder = Encoder_ResNet_VQVAE_MNIST(model_config)
decoder = Decoder_ResNet_VQVAE_MNIST(model_config)
model = VQVAE(model_config=model_config, encoder=encoder, decoder=decoder)

In [None]:
# Bundle the model with its training configuration into a runnable pipeline.
pipeline = TrainingPipeline(model=model, training_config=config)

In [None]:
# Launch training on the train split, using the held-out split for evaluation.
pipeline(train_data=train_dataset, eval_data=eval_dataset)

In [None]:
import os
from pythae.models import AutoModel

In [None]:
# Folder that the trainer was configured to write into.
output_dir = 'my_model'

# Keep only the entries that actually contain a `final_model` folder, so stray
# files or interrupted runs are never picked up.
# NOTE(review): relies on run-folder names sorting chronologically - confirm.
finished_runs = sorted(
    entry
    for entry in os.listdir(output_dir)
    if os.path.isdir(os.path.join(output_dir, entry, 'final_model'))
)
if not finished_runs:
    raise FileNotFoundError(
        f"No finished training run (with a 'final_model' folder) found in '{output_dir}'"
    )

last_training = finished_runs[-1]
trained_model = AutoModel.load_from_folder(os.path.join(output_dir, last_training, 'final_model'))

# Later cells feed this model inputs that were moved with `.to(device)`, so place
# the model on that device explicitly instead of relying on another object doing it.
trained_model = trained_model.to(device)

In [None]:
import torch
from pythae.samplers import PixelCNNSampler, PixelCNNSamplerConfig
from pythae.trainers import BaseTrainerConfig
# Configure the PixelCNN sampler and attach it to the trained model.
sampler_config = PixelCNNSamplerConfig(
    n_layers=3,
    kernel_size=5,
)
pixelcnn_sampler = PixelCNNSampler(
    model=trained_model,
    sampler_config=sampler_config,
)

In [None]:
# Training settings used when fitting the sampler.
sampler_training_config = BaseTrainerConfig(num_epochs=30, learning_rate=1e-4)

# The datasets are already torch tensors, so hand them over as-is (exactly like the
# training pipeline call does). Wrapping an existing tensor in `torch.tensor(...)`
# only makes a redundant copy and triggers PyTorch's copy-construct warning.
pixelcnn_sampler.fit(
    train_data=train_dataset,
    eval_data=eval_dataset,
    training_config=sampler_training_config,
)

In [None]:
# Draw 100 samples from the fitted sampler.
# The `output_dir` keyword (e.g. 'generated/mnist/vae_2_stage_mnist') is left
# disabled here; presumably it would also save the samples to disk - TODO confirm.
gen_data = pixelcnn_sampler.sample(num_samples=100)

In [None]:
import matplotlib.pyplot as plt

# Show the first 25 generated samples on a 5x5 grid.
fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))

# `axes.flat` walks the grid row by row, matching the row-major sample index.
for idx, ax in enumerate(axes.flat):
    ax.imshow(gen_data[idx].cpu().squeeze(0), cmap='gray')
    ax.axis('off')
plt.tight_layout(pad=0.)

## ... the other samplers work the same way

## Visualizing reconstructions

In [None]:
# Reconstruct the first 25 evaluation images on `device`, then bring the result
# back to the CPU (detached from the graph) for plotting.
eval_batch = eval_dataset[:25].to(device)
reconstructions = trained_model.reconstruct(eval_batch).detach().cpu()

In [None]:
# Show the 25 reconstructions on a 5x5 grid.
fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))

# `axes.flat` walks the grid row by row, matching the row-major image index.
for idx, ax in enumerate(axes.flat):
    ax.imshow(reconstructions[idx].cpu().squeeze(0), cmap='gray')
    ax.axis('off')
plt.tight_layout(pad=0.)

In [None]:
# Show the first 25 original evaluation images on a 5x5 grid, for comparison.
fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))

# `axes.flat` walks the grid row by row, matching the row-major image index.
for idx, ax in enumerate(axes.flat):
    ax.imshow(eval_dataset[idx].cpu().squeeze(0), cmap='gray')
    ax.axis('off')
plt.tight_layout(pad=0.)

## Visualizing interpolations

In [None]:
# Interpolate between evaluation images 0-4 and images 5-9 with granularity 10,
# then bring the result back to the CPU (detached from the graph) for plotting.
starting_inputs = eval_dataset[:5].to(device)
ending_inputs = eval_dataset[5:10].to(device)
interpolations = trained_model.interpolate(starting_inputs, ending_inputs, granularity=10).detach().cpu()

In [None]:
# Show the interpolations on a 5x10 grid.
# Presumably one row per image pair and one column per interpolation step - TODO confirm.
fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))

for row_idx, row_axes in enumerate(axes):
    for step_idx, ax in enumerate(row_axes):
        ax.imshow(interpolations[row_idx, step_idx].cpu().squeeze(0), cmap='gray')
        ax.axis('off')
plt.tight_layout(pad=0.)