In [2]:
import torch
import torchvision.datasets as datasets

device = torch.device('mps')

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
# Number of images at the end of the MNIST training split held out for evaluation.
N_EVAL = 10_000

mnist_trainset = datasets.MNIST(root='../../data', train=True, download=True, transform=None)

# Reshape to (N, 1, 28, 28) and divide by 255 (assumes `.data` holds raw 0-255 pixel
# values — so the result is in [0, 1]).
train_dataset = mnist_trainset.data[:-N_EVAL].reshape(-1, 1, 28, 28) / 255.
eval_dataset = mnist_trainset.data[-N_EVAL:].reshape(-1, 1, 28, 28) / 255.

# The two splits must partition the original training images exactly.
assert len(train_dataset) + len(eval_dataset) == len(mnist_trainset.data)
assert tuple(eval_dataset.shape[1:]) == (1, 28, 28)

In [4]:
import numpy as np

## Corrupt data: overwrite one random square in every image.
def corrupt_data(data, fill_value=-1, seed=None, inplace=False):
  """Overwrite one randomly placed square in every image of a batch.

  For each image a side length ``s`` is drawn uniformly from
  ``[0, min(H, W) // 2)`` and a centre pixel uniformly over the image. The
  square is centred on that pixel, then shifted just enough to lie fully
  inside the image, and every channel inside it is set to ``fill_value``.
  An image whose ``s`` is 0 is left unchanged.

  Args:
    data: batch indexed as ``data[n, c, h, w]`` that supports slice
      assignment (e.g. a torch tensor or a numpy array).
    fill_value: value written into the square. The default -1 is negative,
      so masked pixels stand out from non-negative pixel intensities.
    seed: optional int. When given, draws come from a private
      ``np.random.RandomState(seed)`` so the corruption is reproducible;
      when None the global ``np.random`` state is used.
    inplace: if False (default), corrupt a copy and leave ``data`` untouched;
      if True, modify ``data`` itself.

  Returns:
    The corrupted batch (``data`` itself when ``inplace`` is True).
  """
  rng = np.random if seed is None else np.random.RandomState(seed)
  # torch tensors copy via .clone(), numpy arrays via .copy().
  out = data if inplace else (data.clone() if hasattr(data, "clone") else data.copy())

  n, height, width = out.shape[0], out.shape[-2], out.shape[-1]
  rows = rng.randint(0, height, n)
  cols = rng.randint(0, width, n)
  # max(..., 1) keeps randint's range non-empty for images smaller than 2x2
  # (every size is then 0, i.e. no corruption).
  sizes = rng.randint(0, max(min(height, width) // 2, 1), n)

  # Clamp the top-left corner to [0, image_side - s] so the whole square fits.
  # The upper bound must be the image side minus s, not s itself; clamping to s
  # would confine every square to the top-left corner of the image.
  tops = np.clip(rows - sizes // 2, 0, height - sizes)
  lefts = np.clip(cols - sizes // 2, 0, width - sizes)

  for i in range(n):
    s, top, left = int(sizes[i]), int(tops[i]), int(lefts[i])
    out[i, :, top:top + s, left:left + s] = fill_value
  return out

# Corrupt each split only once. Assuming the scaled pixels are non-negative, a -1
# (corrupt_data's fill value) can only be present if this cell already ran, so
# re-running it no longer stacks a second round of squares onto the same images.
if not bool((train_dataset == -1).any()):
    train_dataset = corrupt_data(train_dataset)
if not bool((eval_dataset == -1).any()):
    eval_dataset = corrupt_data(eval_dataset)

In [5]:
from pythae.models import vAE, AEConfig
from pythae.trainers import BaseTrainerConfig
from pythae.pipelines.training import TrainingPipeline
from pythae.models.nn.benchmarks.mnist import Encoder_vAE_MNIST, Decoder_AE_MNIST
#from pythae.models.nn import Encoder_vAE

In [6]:
# Model: 16-dimensional latent space for 1x28x28 images, using pythae's benchmark
# MNIST encoder/decoder networks.
model_config = AEConfig(input_dim=(1, 28, 28), latent_dim=16)

model = vAE(
    model_config=model_config,
    encoder=Encoder_vAE_MNIST(model_config),
    decoder=Decoder_AE_MNIST(model_config),
)

# Optimisation settings; outputs go under ./my_model.
config = BaseTrainerConfig(
    output_dir='my_model',
    num_epochs=10,  # raise this to train the model a bit more
    batch_size=100,
    learning_rate=1e-4,
)

In [11]:
# NOTE(review): `nU` is set directly on the model rather than passed through AEConfig;
# what it controls is not visible here — confirm. This cell was executed out of order
# (In [11] versus In [7] for the pipeline cell), so check Restart & Run All still works.
setattr(model, 'nU', 2)

In [7]:
# Bundle the model with its trainer settings into a pythae training pipeline.
pipeline = TrainingPipeline(model=model, training_config=config)

In [None]:
# Train on the corrupted training split and validate on the corrupted eval split.
pipeline(eval_data=eval_dataset, train_data=train_dataset)

Preprocessing train data...
Preprocessing eval data...

Using Base Trainer

Model passed sanity check !

Created my_model/vAE_training_2022-11-17_15-50-46. 
Training config, checkpoints and final model will be saved here.

Successfully launched training !



Training of epoch 1/10:   0%|                        | 0/500 [00:00<?, ?batch/s][A[A


Eval of epoch 1/10:   0%|                            | 0/100 [00:00<?, ?batch/s][A[A[A

Training of epoch 1/10:   0%|                | 1/500 [00:05<49:04,  5.90s/batch][A[A

In [None]:
import os
from pythae.models import AutoModel

In [None]:
# Runs are saved as my_model/<model>_training_<YYYY-MM-DD_HH-MM-SS>, so the
# lexicographically largest name is the most recent one. Only consider runs that
# contain a final_model folder: a run interrupted before finishing may not have one.
# Filtering on directories also skips stray files such as .DS_Store.
finished_runs = sorted(
    run for run in os.listdir('my_model')
    if os.path.isdir(os.path.join('my_model', run, 'final_model'))
)
if not finished_runs:
    raise FileNotFoundError("No training run with a 'final_model' folder found under 'my_model'")
last_training = finished_runs[-1]
trained_model = AutoModel.load_from_folder(os.path.join('my_model', last_training, 'final_model'))

In [None]:
from pythae.samplers import NormalSampler

In [None]:
# Build a NormalSampler around the trained model (presumably it draws latents from a
# standard normal prior and decodes them — see pythae docs).
# (The variable name keeps its original spelling because later cells use it.)
normal_samper = NormalSampler(model=trained_model)

In [None]:
# Draw 25 samples: enough to fill a 5x5 grid.
gen_data = normal_samper.sample(num_samples=25)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Lay out the normal-sampler images on a 5x5 grayscale grid (row-major order).
fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))

for idx, ax in enumerate(axes.flat):
    ax.imshow(gen_data[idx].cpu().squeeze(0), cmap='gray')
    ax.axis('off')
plt.tight_layout(pad=0.)

In [None]:
from pythae.samplers import GaussianMixtureSampler, GaussianMixtureSamplerConfig

In [None]:
# Gaussian-mixture sampler with 10 components, fitted on the (corrupted) training images
# before it can be sampled from.
gmm_sampler_config = GaussianMixtureSamplerConfig(n_components=10)
gmm_sampler = GaussianMixtureSampler(model=trained_model, sampler_config=gmm_sampler_config)
gmm_sampler.fit(train_dataset)

In [None]:
# Draw 25 images from the fitted mixture: one 5x5 grid's worth.
gen_data = gmm_sampler.sample(num_samples=25)

In [None]:
# Lay out the GMM-sampler images on a 5x5 grayscale grid (row-major order).
fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))

for idx, ax in enumerate(axes.flat):
    ax.imshow(gen_data[idx].cpu().squeeze(0), cmap='gray')
    ax.axis('off')
plt.tight_layout(pad=0.)

## The other samplers are used the same way: construct them with the trained model, `fit` them if required, then `sample`.

## Visualizing reconstructions

In [None]:
# Nothing in this notebook moves `trained_model` onto `device`, so send the inputs to
# whatever device the model's parameters are on to avoid a device-mismatch error
# (e.g. model on CPU, inputs on 'mps'). Inference only, so skip autograd tracking.
model_device = next(trained_model.parameters()).device
with torch.no_grad():
    reconstructions = trained_model.reconstruct(eval_dataset[:25].to(model_device)).detach().cpu()

In [None]:
# Show the 25 reconstructions on a 5x5 grid. Pin the colour range to [-1, 1], the range
# of the corrupted inputs (masked squares are -1). Otherwise imshow rescales every panel
# to its own min/max, and brightness could not be compared across panels or with the inputs.
fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))

for idx, ax in enumerate(axes.flat):
    ax.imshow(reconstructions[idx].cpu().squeeze(0), cmap='gray', vmin=-1, vmax=1)
    ax.axis('off')
plt.tight_layout(pad=0.)

In [None]:
# Show the first 25 (corrupted) eval images on a 5x5 grid. Pin the colour range to
# [-1, 1]: the masked squares are -1, and with imshow's per-image autoscaling a 0
# background renders mid-gray in corrupted images but black in untouched ones.
fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))

for idx, ax in enumerate(axes.flat):
    ax.imshow(eval_dataset[idx].cpu().squeeze(0), cmap='gray', vmin=-1, vmax=1)
    ax.axis('off')
plt.tight_layout(pad=0.)

## Visualizing interpolations

In [None]:
# Interpolate between eval images 0-4 and 5-9 in 10 steps. Nothing in this notebook
# moves `trained_model` onto `device`, so place the inputs on the model's own device
# to avoid a device-mismatch error. Inference only, so skip autograd tracking.
model_device = next(trained_model.parameters()).device
with torch.no_grad():
    interpolations = trained_model.interpolate(
        eval_dataset[:5].to(model_device),
        eval_dataset[5:10].to(model_device),
        granularity=10,
    ).detach().cpu()

In [None]:
# One row per interpolated pair, one column per step
# (assumes interpolate returns (pairs, steps, C, H, W) — matches the [row, col] indexing).
fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))

for row, row_axes in enumerate(axes):
    for col, ax in enumerate(row_axes):
        ax.imshow(interpolations[row, col].cpu().squeeze(0), cmap='gray')
        ax.axis('off')
plt.tight_layout(pad=0.)