<a href="https://colab.research.google.com/github/ParticleEM/ParEM_neural_latent_variable_model/blob/master/notebooks/CelebA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import modules

In [None]:
# Install missing modules
%%capture
!pip install torchtyping

In [None]:
# Import standard modules
import torch
import numpy as np
import sys
import matplotlib.pyplot as plt
import argparse
#from pathlib import Path
from google.colab import drive

In [None]:
# Import custom modules
# Remove any previous checkout first so that the clone below does not fail on an
# already existing directory when this cell is re-run.
!rm -rf ParEM_neural_latent_variable_model
!git clone https://github.com/ParticleEM/ParEM_neural_latent_variable_model.git
# Make the cloned repository importable; this has to happen before the `parem`
# imports below. The absolute path assumes the clone above landed in /content/
# (presumably the Colab working directory -- TODO confirm).
sys.path.append("/content/ParEM_neural_latent_variable_model/")
from parem.model import G
from parem.pga import PGA
from parem.dataset_loaders import get_celeba

# Set config variables

In [None]:
# Gather all configuration values in one attribute-style (dictionary-like) object.
args = argparse.Namespace(
    # --- Data settings ---
    n_images=40000,  # M: training set size
    # --- Training settings ---
    n_batch=128,  # M_b: batch size for theta updates
    n_epochs=30,  # n_epochs = K * M_b / M where K = total number of iterations
    seed=1,  # Seed for PRNG
    # Device on which to carry out computations (GPU when one is available)
    device="cuda" if torch.cuda.is_available() else "cpu",
    # --- Model settings ---
    x_dim=64,  # d_x: dimension of latent space
    likelihood_var=0.3 ** 2,  # sigma^2
    # --- PGA settings ---
    h=1e-4,  # h: step size
    n_particles=10,  # N: number of particles
)

# lambda depends on the step size and the training set size, so it is derived
# once those two values are stored.
args.lambd = 1e-3 / (args.h * args.n_images)

# Load dataset

In [None]:
%%capture
from pathlib import Path
drive.mount("/content/gdrive", force_remount=False)
GDRIVE_CELEBA_PATH = Path("/content/gdrive/MyDrive/celeba/celeba", force_remount=False)
LOCAL_CELEBA_DIR_PATH = Path("/content/") / "celeba"
assert GDRIVE_CELEBA_PATH.is_dir()
if not LOCAL_CELEBA_DIR_PATH.is_dir():
  !cp -r $GDRIVE_CELEBA_PATH -d /content/
  img_aligned_zip_path = LOCAL_CELEBA_DIR_PATH / "img_align_celeba.zip"
  !unzip $img_aligned_zip_path -d $LOCAL_CELEBA_DIR_PATH

dataset = get_celeba(LOCAL_CELEBA_DIR_PATH / "img_align_celeba", args.n_images)  # Load dataset

# Define and train model

In [None]:
# Apply the configured seed before anything random happens in this cell.
# `args.seed` was declared in the config but never used, so e.g. the shuffled
# batch order below differed between runs.
torch.manual_seed(args.seed)

# Define model:
model = G(args.x_dim, sigma2=args.likelihood_var, nc=3).to(args.device)

# Define training algorithm:
pga = PGA(model, dataset, args.h, args.lambd, args.n_particles)

# Split dataset into batches for training:
training_batches = torch.utils.data.DataLoader(dataset, batch_size=args.n_batch,
                                               shuffle=True, pin_memory=True)

# Train:
losses = []  # One entry per epoch: the loss averaged over that epoch's batches
for epoch in range(args.n_epochs):
  avg_loss = 0
  for imgs, idx in training_batches:
    # One PGA step on the current batch; `idx` is presumably the dataset index of
    # each image in the batch (TODO confirm against `parem.pga.PGA.step`).
    loss = pga.step(imgs.to(device=args.device), idx)
    avg_loss += loss
    print(".", end='')  # Progress marker: one dot per processed batch
  avg_loss = avg_loss / len(training_batches)
  losses.append(avg_loss)
  print(f"Epoch {epoch}: Loss {avg_loss}")

# Show particle cloud

In [None]:
#@title Load auxiliary functions
from torchvision.utils import make_grid
import torchvision.transforms.functional as F

def to_range_0_1(x):
    """Affinely rescale `x` as (x + 1) / 2, so that -1 maps to 0 and 1 maps to 1."""
    return (x + 1.) / 2.

def show(imgs):
    """Draw one or more images next to each other and return the figure.

    Args:
        imgs: A single image or a list of images. Anything that is not a list
            is wrapped in a one-element list. Every image is detached and
            converted with `F.to_pil_image` before it is drawn.

    Returns:
        The matplotlib figure, containing a single row with one axis per image
        and all ticks and tick labels removed.
    """
    images = imgs if isinstance(imgs, list) else [imgs]
    fig, axs = plt.subplots(ncols=len(images), squeeze=False, dpi=400)
    # `squeeze=False` keeps `axs` two-dimensional, so `axs[0]` is the only row.
    for ax, image in zip(axs[0], images):
        pil_image = F.to_pil_image(image.detach())
        ax.imshow(np.asarray(pil_image))
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    return fig

In [None]:
model.eval()  # Turn on evaluation mode
i = 0  # Image index

with torch.no_grad():
  # Use the seed from the config instead of a hard-coded literal, so this cell
  # follows `args.seed` when that setting is changed.
  torch.random.manual_seed(args.seed)
  # Training image i with a leading batch dimension, rescaled for display.
  original_img = to_range_0_1(dataset[i][0].unsqueeze(0))
  # Images generated from the entries of `pga._particles` stored under index i
  # (presumably the particles of image i -- TODO confirm), moved to the device
  # of the original image so that both can be concatenated.
  particle_img = to_range_0_1(model(pga._particles[i, :].to(args.device))).to(original_img.device)
  # Original image first, followed by the generated images.
  grid = make_grid(torch.concat([original_img, particle_img], dim=0))
  show(grid)

## Generate synthetic samples

In [None]:
with torch.no_grad():
  n_cols = 8
  n_rows = 8
  # Fit a Gaussian to the particles: the mean is taken over every axis except
  # axis 2, and the covariance is computed after collapsing the 5-D particle
  # tensor to 2-D and transposing it, because `torch.cov` expects one variable
  # per row. (Presumably axis 2 is the latent dimension d_x -- TODO confirm.)
  mean = torch.mean(pga._particles, [0, 1, 3, 4])
  cov = torch.cov(pga._particles.flatten(0,1).flatten(1, 3).transpose(0, 1))
  normal_approx = torch.distributions.multivariate_normal.MultivariateNormal(loc = mean, covariance_matrix=cov)
  # Draw one latent vector per grid cell and append two trailing singleton axes.
  z = normal_approx.sample(sample_shape=torch.Size([n_cols * n_rows])).unsqueeze(-1).unsqueeze(-1)
  samples = to_range_0_1(model(z.to(args.device)))
  # `nrow` is the number of images per grid row, i.e. the number of columns.
  # Passing it explicitly keeps the layout tied to `n_cols`; before, the grid
  # only matched because `make_grid` defaults to 8 images per row.
  grid = make_grid(samples, nrow=n_cols)
  fig = show(grid)

## Inpainting 

In [None]:
n_missing_img = 10
# Images to inpaint, stacked into a single batch tensor.
missing_imgs = torch.stack(dataset[:n_missing_img][0], dim=0)
# One latent per image, drawn from the Gaussian fitted to the particles
# (`normal_approx` comes from the synthetic-samples cell). These latents are the
# only values handed to the optimizer below.
init_x = normal_approx.sample(sample_shape=torch.Size([n_missing_img])).unsqueeze(-1).unsqueeze(-1).requires_grad_(True)
opt = torch.optim.Adam([init_x], 1e-2)
mse = torch.nn.MSELoss()
# True marks the pixels that are treated as missing: a 12 x 12 square covering
# rows and columns 10..21 of every image and channel.
missing_mask = torch.zeros_like(missing_imgs, dtype=torch.bool)
missing_mask[..., 10:22, 10:22] = True

# Fit the latents by matching the generated images to the observed pixels only.
for step in range(1000):
  opt.zero_grad()
  filled_imgs = model.forward(init_x.to(args.device)).to('cpu')
  loss = mse(filled_imgs[~missing_mask], missing_imgs[~missing_mask])
  loss.backward()
  opt.step()

# Regenerate the images from the final latents. The images left over from the
# loop were computed before the last optimizer step and still carry the graph.
with torch.no_grad():
  filled_imgs = model.forward(init_x.to(args.device)).to('cpu')

filled_imgs = to_range_0_1(filled_imgs)
missing_imgs = to_range_0_1(missing_imgs)
# NOTE(review): `input` shadows the Python builtin of the same name; the name is
# kept here so that anything else referring to it keeps working.
input = missing_imgs.detach().clone()
input[missing_mask] = 0.2  # Paint the masked square with a constant value

# One figure per image: masked input, reconstruction, original.
for i in range(n_missing_img):
  grid = make_grid(torch.concat([input[[i]], filled_imgs[[i]], missing_imgs[[i]]], dim=0))
  fig = show(grid)