In [None]:
# If you run this notebook on Colab, uncomment the following line to install pythae
#!pip install git+https://github.com/clementchadebec/benchmark_VAE.git

In [None]:
import torch
import torchvision.datasets as datasets

%load_ext autoreload
%autoreload 2

In [None]:
mnist_trainset = datasets.MNIST(root='../../data', train=True, download=True, transform=None)

# Number of trailing training images held out as the evaluation split.
N_EVAL = 10000
# Guard against N_EVAL == 0: `data[:-0]` would silently yield an EMPTY training set.
assert 0 < N_EVAL < len(mnist_trainset.data), "N_EVAL must leave images in both splits"

# Reshape to (N, 1, 28, 28) (explicit channel axis) and scale pixels to [0, 1]
# (assumes the raw pixels are integers in 0..255 — TODO confirm dtype if the dataset changes).
train_dataset = mnist_trainset.data[:-N_EVAL].reshape(-1, 1, 28, 28) / 255.
eval_dataset = mnist_trainset.data[-N_EVAL:].reshape(-1, 1, 28, 28) / 255.

In [None]:
from pythae.models import WAE_MMD, WAE_MMD_Config
from pythae.trainers import BaseTrainingConfig
from pythae.pipelines.training import TrainingPipeline
from pythae.models.nn.benchmarks.mnist import Encoder_AE_MNIST, Decoder_AE_MNIST

In [None]:
# Training hyper-parameters; trained models are written under `output_dir`.
config = BaseTrainingConfig(
    output_dir='my_model',
    learning_rate=1e-3,
    batch_size=100,
    num_epochs=100,
)


# WAE-MMD on a 16-dimensional latent space, using the 'imq' kernel for the MMD penalty.
model_config = WAE_MMD_Config(
    input_dim=(1, 28, 28),
    latent_dim=16,
    kernel_choice='imq',
    reg_weight=1.0,
    kernel_bandwidth=2
)

# Seed torch's global RNG before the encoder/decoder are built so that weight
# initialisation is reproducible under Restart & Run All.
SEED = 8
torch.manual_seed(SEED)

model = WAE_MMD(
    model_config=model_config,
    encoder=Encoder_AE_MNIST(model_config),
    decoder=Decoder_AE_MNIST(model_config)
)

In [None]:
# Bundle the model with its training configuration into one runnable pipeline.
pipeline = TrainingPipeline(model=model, training_config=config)

In [None]:
# Run training; the held-out split is supplied as evaluation data.
pipeline(eval_data=eval_dataset, train_data=train_dataset)

In [None]:
import os

In [None]:
# Root folder of the training runs; must match `output_dir` in the training config.
output_root = 'my_model'

# Keep only runs that finished, i.e. that contain a `final_model` folder. An interrupted
# run or a stray file would otherwise be picked by a blind `sorted(...)[-1]`.
# NOTE(review): relies on run-folder names sorting chronologically (presumably timestamped
# by pythae) — verify against the pythae version in use.
finished_runs = sorted(
    entry for entry in os.listdir(output_root)
    if os.path.isdir(os.path.join(output_root, entry, 'final_model'))
)
if not finished_runs:
    raise FileNotFoundError(f"No completed training run (with a 'final_model' folder) found in {output_root!r}")

last_training = finished_runs[-1]
trained_model = WAE_MMD.load_from_folder(os.path.join(output_root, last_training, 'final_model'))

In [None]:
from pythae.samplers import NormalSampler

In [None]:
# Wrap the trained model in a NormalSampler (presumably drawing latent codes from a
# standard normal prior — see pythae sampler docs). The variable keeps its original
# spelling because the sampling cell refers to it by that name.
normal_samper = NormalSampler(model=trained_model)

In [None]:
# sample
gen_data = normal_samper.sample(
    num_samples=25
)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Display the normal-sampler images on a 5x5 grid, filled row by row.
fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))

for idx, ax in enumerate(axes.flat):
    # drop the leading channel axis so imshow receives a 2-D image
    image = gen_data[idx].cpu().squeeze(0)
    ax.imshow(image, cmap='gray')
    ax.axis('off')

plt.tight_layout(pad=0.)

In [None]:
from pythae.samplers import GaussianMixtureSampler, GaussianMixtureSamplerConfig

In [None]:
# Mixture with 10 components (presumably one per MNIST digit class).
gmm_sampler_config = GaussianMixtureSamplerConfig(n_components=10)

# Attach the sampler to the trained model.
gmm_sampler = GaussianMixtureSampler(
    model=trained_model,
    sampler_config=gmm_sampler_config,
)

# Fit the mixture using the training images (presumably via their latent codes —
# see the pythae GaussianMixtureSampler docs).
gmm_sampler.fit(train_dataset)

In [None]:
# Generate 25 images from the fitted mixture, one per cell of the 5x5 grid.
gen_data = gmm_sampler.sample(num_samples=25)

In [None]:
# Display the GMM-sampler images on a 5x5 grid, filled row by row.
fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))

for idx, ax in enumerate(axes.flat):
    # drop the leading channel axis so imshow receives a 2-D image
    image = gen_data[idx].cpu().squeeze(0)
    ax.imshow(image, cmap='gray')
    ax.axis('off')

plt.tight_layout(pad=0.)

## The other samplers are used in the same way