In [1]:
import torch
import torchvision.datasets as datasets
import logging
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

%load_ext autoreload
%autoreload 2
    

In [2]:
# Notebook-level logger. Its threshold is set to ERROR, so lower-severity
# records are dropped. No handler is attached here; add a
# logging.StreamHandler() to the logger if console output is wanted.
logger = logging.getLogger(__name__)
logger.setLevel(logging.ERROR)


In [3]:
# Fetch the MNIST training split (downloaded into ../../data if missing);
# no transform is applied by the dataset object itself.
mnist_trainset = datasets.MNIST(
    root='../../data',
    train=True,
    download=True,
    transform=None,
)

# Everything but the last 10,000 images is used for training and the last
# 10,000 for evaluation. Each split is reshaped to (N, 1, 28, 28) and
# divided by 255.
pixels = mnist_trainset.data
train_dataset = pixels[:-10000].reshape(-1, 1, 28, 28) / 255.
eval_dataset = pixels[-10000:].reshape(-1, 1, 28, 28) / 255.
train_dataset.shape, eval_dataset.shape

(torch.Size([50000, 1, 28, 28]), torch.Size([10000, 1, 28, 28]))

In [4]:
from pythae.models import AE, AEConfig
from pythae.trainers import BaseTrainerConfig
from pythae.pipelines.training import TrainingPipeline
from pythae.models.nn.benchmarks.mnist import Encoder_ResNet_AE_MNIST, Decoder_ResNet_AE_MNIST

In [5]:
from polcanet import PolcaNetPythae, PolcaNetConfig, LinearDecoderPythae

In [6]:
# Model settings: (1, 28, 28) inputs, a 16-dimensional latent space, and the
# alpha / beta / gamma values handed to PolcaNetConfig.
model_config = PolcaNetConfig(input_dim=(1, 28, 28), latent_dim=16, alpha=0.1, beta=1.0, gamma=1.0)

# Trainer settings; training artefacts are written under 'polca_mnist'.
config = BaseTrainerConfig(
    output_dir='polca_mnist',
    learning_rate=1e-4,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_epochs=10,  # Change this to train the model a bit more
)

# The decoder is built first and the encoder inline afterwards, so the two
# networks are constructed in the same order as before.
decoder = LinearDecoderPythae(args=model_config, hidden_dim=256, num_layers=3)
model = PolcaNetPythae(
    model_config=model_config,
    encoder=Encoder_ResNet_AE_MNIST(model_config),
    decoder=decoder,
)
model

PolcaNetPythae(
  (decoder): LinearDecoderPythae(
    (decoder): LinearDecoder(
      (decoder): Sequential(
        (0): Linear(in_features=16, out_features=256, bias=True)
        (1): Linear(in_features=256, out_features=256, bias=True)
        (2): Linear(in_features=256, out_features=784, bias=True)
      )
    )
  )
  (encoder): Encoder_ResNet_AE_MNIST(
    (layers): ModuleList(
      (0): Sequential(
        (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      )
      (1): Sequential(
        (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      )
      (2): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      )
      (3): Sequential(
        (0): ResBlock(
          (conv_block): Sequential(
            (0): ReLU()
            (1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): ReLU()
            (3): Conv2d(32, 128, kernel_size=(1, 1), stride=(1, 1)

In [7]:
# Bind the trainer settings and the model into a training pipeline.
pipeline = TrainingPipeline(training_config=config, model=model)

In [9]:
# Launch training on the train split, evaluating on the held-out split; no callbacks.
pipeline(train_data=train_dataset, eval_data=eval_dataset, callbacks=None)

Preprocessing train data...
Checking train dataset...
Preprocessing eval data...

Checking eval dataset...
Using Base Trainer

Model passed sanity check !
Ready for training.

Created polca_mnist/PolcaNet_training_2024-07-10_14-41-20. 
Training config, checkpoints and final model will be saved here.

Training params:
 - max_epochs: 10
 - per_device_train_batch_size: 64
 - per_device_eval_batch_size: 64
 - checkpoint saving every: None
Optimizer: Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.0001
    maximize: False
    weight_decay: 0
)
Scheduler: None

Successfully launched training !



Training of epoch 1/10:   0%|          | 0/782 [00:00<?, ?batch/s]

Eval of epoch 1/10:   0%|          | 0/157 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 0.058
Eval loss: 0.0605
--------------------------------------------------------------------------


Training of epoch 2/10:   0%|          | 0/782 [00:00<?, ?batch/s]

Eval of epoch 2/10:   0%|          | 0/157 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 0.056
Eval loss: 0.0553
--------------------------------------------------------------------------


Training of epoch 3/10:   0%|          | 0/782 [00:00<?, ?batch/s]

Eval of epoch 3/10:   0%|          | 0/157 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 0.0549
Eval loss: 0.0544
--------------------------------------------------------------------------


Training of epoch 4/10:   0%|          | 0/782 [00:00<?, ?batch/s]

KeyboardInterrupt: 

In [None]:
import os
from pythae.models import AutoModel

In [None]:
# Training runs are written under the trainer's output_dir, which is
# 'polca_mnist' in the BaseTrainerConfig cell of this notebook (not 'my_model').
training_root = 'polca_mnist'

# Run folder names end with a timestamp (e.g. PolcaNet_training_2024-07-10_14-41-20),
# so an alphabetical sort puts the most recent run last. Only runs that contain
# a 'final_model' folder are kept, so a run that stopped before writing its
# final model is skipped instead of breaking the load.
finished_runs = sorted(
    run_name
    for run_name in os.listdir(training_root)
    if os.path.isdir(os.path.join(training_root, run_name, 'final_model'))
)
if not finished_runs:
    raise FileNotFoundError(
        f"No training run with a 'final_model' folder found in {training_root!r}"
    )

last_training = finished_runs[-1]
# NOTE(review): confirm that AutoModel can rebuild a PolcaNetPythae checkpoint;
# if it cannot, load through the model class instead. TODO verify.
trained_model = AutoModel.load_from_folder(os.path.join(training_root, last_training, 'final_model'))

In [None]:
from pythae.samplers import NormalSampler

In [None]:
# create normal sampler
# Wrap the trained model in a NormalSampler.
normal_samper = NormalSampler(model=trained_model)

In [None]:
# sample
# Draw 25 samples, enough to fill a 5x5 grid.
gen_data = normal_samper.sample(num_samples=25)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# show results with normal sampler
# 5x5 grid of the samples drawn with the normal sampler.
fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))

# axes.flat walks the grid row by row, so panel k shows sample k.
for k, ax in enumerate(axes.flat):
    ax.imshow(gen_data[k].cpu().squeeze(0), cmap='gray')
    ax.axis('off')
plt.tight_layout(pad=0.)

In [None]:
from pythae.samplers import GaussianMixtureSampler, GaussianMixtureSamplerConfig

In [None]:
# set up GMM sampler config
# Gaussian-mixture sampler configured with 10 components, wrapped around the trained model.
gmm_sampler_config = GaussianMixtureSamplerConfig(n_components=10)
gmm_sampler = GaussianMixtureSampler(sampler_config=gmm_sampler_config, model=trained_model)

# Fit the sampler on the training images.
gmm_sampler.fit(train_dataset)

In [None]:
# sample
# Draw 25 samples from the fitted GMM sampler, enough to fill a 5x5 grid.
gen_data = gmm_sampler.sample(num_samples=25)

In [None]:
# show results with gmm sampler
# 5x5 grid of the samples drawn with the GMM sampler.
fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))

# axes.flat walks the grid row by row, so panel k shows sample k.
for k, ax in enumerate(axes.flat):
    ax.imshow(gen_data[k].cpu().squeeze(0), cmap='gray')
    ax.axis('off')
plt.tight_layout(pad=0.)

## ... the other samplers work the same way

## Visualizing reconstructions

In [None]:
# Reconstruct the first 25 evaluation images on `device`, then detach the
# result and move it back to the CPU for plotting.
reconstructions = (
    trained_model
    .reconstruct(eval_dataset[:25].to(device))
    .detach()
    .cpu()
)

In [None]:
# show reconstructions
# 5x5 grid of the reconstructed images.
fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))

# axes.flat walks the grid row by row, so panel k shows reconstruction k.
for k, ax in enumerate(axes.flat):
    ax.imshow(reconstructions[k].cpu().squeeze(0), cmap='gray')
    ax.axis('off')
plt.tight_layout(pad=0.)

In [None]:
# show the true data
# 5x5 grid of the first 25 evaluation images, for comparison.
fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))

# axes.flat walks the grid row by row, so panel k shows evaluation image k.
for k, ax in enumerate(axes.flat):
    ax.imshow(eval_dataset[k].cpu().squeeze(0), cmap='gray')
    ax.axis('off')
plt.tight_layout(pad=0.)

## Visualizing interpolations

In [None]:
# Interpolate between evaluation images 0-4 and 5-9 with granularity 10,
# then detach the result and move it back to the CPU for plotting.
interpolations = trained_model.interpolate(
    eval_dataset[:5].to(device),
    eval_dataset[5:10].to(device),
    granularity=10,
).detach().cpu()

In [None]:
# show interpolations
# 5x10 grid: the panel at (row, col) shows interpolations[row, col].
fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))

for row, row_axes in enumerate(axes):
    for col, ax in enumerate(row_axes):
        ax.imshow(interpolations[row, col].cpu().squeeze(0), cmap='gray')
        ax.axis('off')
plt.tight_layout(pad=0.)
plt.tight_layout(pad=0.)