<a href="https://colab.research.google.com/github/juankuntz/ParEM/blob/torch_code/torch/notebooks/CelebA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import modules

In [None]:
# Install missing modules
%%capture
# Install missing packages
!pip install torchtyping
!pip install wandb
!pip install torchmetrics[image]

In [None]:
# Imports: standard library first, then third-party, then Colab-specific.
import argparse
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch

from google.colab import drive

In [None]:
%%capture
# Import custom modules
from pathlib import Path
CHECKPOINTS_PATH = Path('/content/gdrive/MyDrive/ParEM/celeba/checkpoints')

!rm -rf ParEM
!git clone https://github.com/juankuntz/ParEM.git
REPOSITORY_PATH = '/content/ParEM/torch'
if REPOSITORY_PATH not in sys.path:
    sys.path.append(REPOSITORY_PATH)

# Import custom modules
from parem.models import NLVM
from parem.algorithms import (PGA,
                              ShortRun,
                              VI,
                              AlternatingBackprop)
from parem.utils import get_celeba, load_checkpoint

# Set config variables

In [None]:
# Data settings
N_IMAGES = 10000  # M: training set size

# Training settings
N_BATCH = 128  # M_b: batch size for theta updates
N_EPOCHS = 500  # n_epochs = K * M_b / M where K = total number of iterations
SEED = 1  # Seed for PRNG
# Device on which to carry out computations:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
OPTIMIZER = 'rmsprop'  # Theta optimizer

# Seed the global PRNGs so model initialisation and training are reproducible
# (previously SEED was only logged, never applied).
torch.manual_seed(SEED)  # seeds the RNG on all devices, CPU and CUDA
np.random.seed(SEED)

# Model Settings
X_DIM = 64  # d_x: dimension of latent space
LIKELIHOOD_VAR = 0.01 ** 2  # sigma^2

# PGA Settings
STEP_SIZE = 1e-4  # h: step size
LAMBDA = 1e-3 / (STEP_SIZE * N_IMAGES)  # lambda, scaled by 1 / (h * M)
N_PARTICLES = 10  # N: number of particles

constants_to_be_logged = {'Number of training images': N_IMAGES,
                          'Batch sizes': N_BATCH,
                          'PRNG seed': SEED,
                          'Latent variable dimension': X_DIM,
                          'sigma^2': LIKELIHOOD_VAR,
                          'step size': STEP_SIZE,
                          'lambda': LAMBDA,
                          'Number of particles': N_PARTICLES}

# Load dataset

In [None]:
%%capture
from pathlib import Path
drive.mount("/content/gdrive", force_remount=False)
GDRIVE_CELEBA_PATH = Path("/content/gdrive/MyDrive/datasets/celeba")
LOCAL_CELEBA_DIR_PATH = Path("/content/") / "celeba"
assert GDRIVE_CELEBA_PATH.is_dir()
if not LOCAL_CELEBA_DIR_PATH.is_dir():
  !cp -r $GDRIVE_CELEBA_PATH -d /content/
  img_aligned_zip_path = LOCAL_CELEBA_DIR_PATH / "img_align_celeba.zip"
  !unzip $img_aligned_zip_path -d $LOCAL_CELEBA_DIR_PATH

dataset = get_celeba(LOCAL_CELEBA_DIR_PATH / "img_align_celeba", N_IMAGES)  # Load dataset

# Define and train model

In [None]:
# Define model:
# NOTE(review): nc=3 presumably is the number of image channels (RGB) — confirm against NLVM.
model = NLVM(x_dim=X_DIM, sigma2=LIKELIHOOD_VAR, nc=3).to(DEVICE)

# Define training algorithm:
# NOTE(review): N_PARTICLES is configured and logged but not passed to PGA here —
# confirm PGA's particle-count argument / default matches N_PARTICLES.
pga = PGA(model=model, lambd=LAMBDA, dataset=dataset, train_batch_size=N_BATCH,
          particle_step_size=STEP_SIZE, device=DEVICE,
          theta_optimizer=OPTIMIZER)

# The checkpoint filename must be *relative*: joining a Path with an absolute
# string discards the left operand (Path('/a') / '/b.pt' == Path('/b.pt')),
# which previously wrote the checkpoint to the filesystem root.
CHECKPOINTS_PATH.mkdir(parents=True, exist_ok=True)
pga.run(N_EPOCHS, CHECKPOINTS_PATH / 'celeba_pga.pt', compute_stats=False,
        wandb_log=False, log_images=False)

# Show particle cloud

In [None]:
#@title Load auxiliary functions
from torchvision.utils import make_grid
import torchvision.transforms.functional as F


def to_range_0_1(x):
    """Affinely map values via (x + 1) / 2, sending -1 -> 0 and 1 -> 1."""
    return (x + 1.) / 2.


def show(imgs):
    """Plot one image tensor, or a list of them, side by side in a single row.

    Each entry is detached, converted with ``F.to_pil_image`` and drawn without
    ticks or tick labels. Returns the matplotlib Figure.
    """
    img_list = imgs if isinstance(imgs, list) else [imgs]
    fig, axs = plt.subplots(ncols=len(img_list), squeeze=False, dpi=400)
    for ax, tensor in zip(axs[0], img_list):
        pil_img = F.to_pil_image(tensor.detach())
        ax.imshow(np.asarray(pil_img))
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    return fig

In [None]:
# Presumably visualises samples from PGA's particle approximation of the posterior.
# NOTE(review): the positional arguments (10, 1) are not documented here — confirm
# their meaning against PGA.sample_image_posterior and consider passing them by keyword.
pga.sample_image_posterior(10, 1)

## Generate synthetic samples

In [None]:
# Presumably generates `n` synthetic images from the trained model.
# NOTE(review): approx_type='gmm' with n_components suggests the latent particle
# cloud is summarised by a Gaussian mixture before decoding — confirm against
# PGA.synthesize_images.
n_synthetic_images = 64
gmm_components = 3
pga.synthesize_images(n=n_synthetic_images, approx_type='gmm', n_components=gmm_components)

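## Inpainting

Hide a 12×12 square of pixels (rows and columns 10–21) in each of the first 20 training images and let the trained model reconstruct them.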

In [None]:
n_missing_img = 20
images = torch.stack(dataset[:n_missing_img][0], dim=0)

# Boolean pixel mask of shape (height, width). The square covering rows and
# columns patch_start..patch_stop-1 (10..21, i.e. 12x12 pixels) is set to False.
# NOTE(review): assumes False marks the pixels to be reconstructed — confirm
# against PGA.reconstruct.
patch_start, patch_stop = 10, 22
mask = torch.ones(dataset.height, dataset.width, dtype=torch.bool)
mask[patch_start:patch_stop, patch_start:patch_stop] = False

pga.reconstruct(images, mask)