# Gaussian VAE for CIFAR-10

This notebook showcases the application of a Gaussian-based VAE to the CIFAR-10 data set. As in the slightly simpler MNIST example, a generative model can be trained by calling the main script with a suitable config file, for example: `python scripts/main.py fit --config config/cifar10.yaml`. Once training has completed, the final model is imported and tested below.

In [None]:
# Load the autoreload extension and use mode 2, which re-imports modules
# before each cell execution so edits to the package are picked up
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the cell output
%matplotlib inline

In [None]:
import matplotlib.pyplot as plt
import torch
from lightning.pytorch import seed_everything

from varautoenc import(
    CIFAR10DataModule,
    ConvVAE,
    generate,
    reconstruct,
    encode_loader
)

In [None]:
# Fix the random seed so that the stochastic parts of the notebook are repeatable;
# the return value is bound to `_` so that it is not echoed as cell output
_ = seed_everything(111111)  # set random seeds manually

## CIFAR-10 data

In [None]:
# Settings of the CIFAR-10 data module
# (mean/std are given per channel; presumably used for input normalization — confirm in CIFAR10DataModule)
data_config = dict(
    data_dir='../run/data/',
    mean=(0.5, 0.5, 0.5),
    std=(0.5, 0.5, 0.5),
    batch_size=32,
)

cifar = CIFAR10DataModule(**data_config)

# fetch the data set unless it is already available locally
cifar.prepare_data()

# only the test split is needed in this notebook
cifar.setup(stage='test')

In [None]:
# Build the test data loader and pull a single batch from it:
# x_batch holds the images shown below, y_batch the labels used to look up class names
test_loader = cifar.test_dataloader()
x_batch, y_batch = next(iter(test_loader))

In [None]:
# Preview the first 12 images of the batch in a 3x4 grid, titled with their class names
fig, axes = plt.subplots(nrows=3, ncols=4, figsize=(5, 4.5))

for idx, ax in enumerate(axes.flat):
    # map the tensor back for display (presumably undoing the normalization) and
    # move the leading axis last, as imshow expects channels in the last position
    pixels = cifar.renormalize(x_batch[idx])
    pixels = pixels.permute(1, 2, 0).numpy()
    ax.imshow(pixels)

    # look up the class name that belongs to the label
    class_name = cifar.test_set.classes[y_batch[idx]]
    ax.set_title(class_name)

    # hide ticks and axis labels
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel('')
    ax.set_ylabel('')

fig.tight_layout()

## Model import

In [None]:
# Checkpoint of the trained model
ckpt_file = '../run/cifar10/version_0/checkpoints/last.ckpt'

# Use the first GPU if one is available and fall back to the CPU otherwise
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# map_location puts the stored tensors onto `device` while loading, so that a
# checkpoint written during a GPU run can also be restored on a CPU-only machine
vae = ConvVAE.load_from_checkpoint(ckpt_file, map_location=device)

vae = vae.eval()  # inference mode for layers such as dropout or batch norm
vae = vae.to(device)  # make sure the whole module lives on the chosen device

In [None]:
# Summarize the settings stored in the imported model
print(f'Likelihood: {vae.likelihood_type}')
print(f'Beta: {vae.beta}')

# The decoder only carries a `logsigma` for some model configurations,
# so the lookup is best-effort: a missing attribute is reported, not raised
try:
    sigma = vae.decoder.dist_params.logsigma.exp().item()
except AttributeError:
    print('Sigma: not available')
else:
    print(f'Sigma: {sigma:.4f}')

## Latent space

In [None]:
# Encode everything the test loader provides
# (z_mu / z_sigma are presumably the posterior means and std. deviations — confirm in encode_loader)
z_mu, z_sigma = encode_loader(vae, test_loader)

# Both summary statistics are computed from z_mu; z_sigma is not used here
print(f'Latent mean: {z_mu.mean():.2f}')
print(f'Latent std.: {z_mu.std():.2f}')

## Image reconstruction

In [None]:
# Draw a batch from a fresh iterator; the labels are not needed here
x_batch, _ = next(iter(test_loader))

# Reconstruct the batch with sample_mode=False
# (presumably a deterministic reconstruction without sampling — confirm in `reconstruct`)
x_recon = reconstruct(vae, x_batch, sample_mode=False)

# Clamp to [-1, 1] before plotting. The plotting cell calls .numpy(), which needs
# a CPU tensor that does not track gradients; detach/cpu are no-ops if `reconstruct`
# already returns such a tensor
x_recon = x_recon.clamp(-1, 1).detach().cpu()

In [None]:
fig, axes = plt.subplots(nrows=2, ncols=6, figsize=(8, 3.5))

# Top row: images from the batch, bottom row: their reconstructions.
# Each entry pairs a row of axes with its image source and title template
panel_rows = (
    (axes[0], x_batch, '$x^{{({})}}$'),
    (axes[1], x_recon, '$\\hat{{x}}^{{({})}}$'),
)

for row_axes, images, title_template in panel_rows:
    for idx, ax in enumerate(row_axes):
        image = cifar.renormalize(images[idx]).permute(1, 2, 0).numpy()
        ax.imshow(image)
        ax.set_title(title_template.format(idx + 1))  # titles count from 1
        ax.set(xticks=[], yticks=[], xlabel='', ylabel='')

fig.tight_layout()

## Random generation

In [None]:
num_samples = 25  # one sample per panel of the 5x5 grid plotted next

# The input width of the decoder's first dense layer is used as the latent dimension
num_latents = vae.decoder.dense_layers[0][0].in_features
x_gen = generate(vae, sample_shape=(num_latents,), num_samples=num_samples)

# Clamp to [-1, 1] before plotting. The plotting cell calls .numpy(), which needs
# a CPU tensor that does not track gradients; detach/cpu are no-ops if `generate`
# already returns such a tensor
x_gen = x_gen.clamp(-1, 1).detach().cpu()

In [None]:
# Show the generated samples in a 5x5 grid
fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(5, 5))

for idx, ax in enumerate(axes.flat):
    # map the tensor back for display and move the leading axis last for imshow
    sample = cifar.renormalize(x_gen[idx])
    ax.imshow(sample.permute(1, 2, 0).numpy())

    # hide ticks and axis labels
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel('')
    ax.set_ylabel('')

fig.tight_layout()