<a href="https://colab.research.google.com/github/juankuntz/ParEM/blob/main/torch/notebooks/MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

## Colab setup

In [None]:
%%capture
# Set paths
from google.colab import drive
drive.mount("/content/gdrive", force_remount=False) # Mount drive to VM in colab
DATASET_PATH = '/content/ParEM_neural_latent_variable_model_dev/datasets/MNIST'
CHECKPOINTS_PATH = '/content/gdrive/MyDrive/ParEM_neural_latent_variable_model_dev/checkpoints'

# Install missing packages
!pip install torchtyping
!pip install torchmetrics
!pip install wandb

# Import standard modules
import sys

# Import custom modules
!rm -rf ParEM
!git clone https://github.com/juankuntz/ParEM.git
!cd ParEM; git checkout torch_code
REPOSITORY_PATH = '/content/ParEM/torch'
if REPOSITORY_PATH not in sys.path:
    sys.path.append(REPOSITORY_PATH)

## General setup

In [None]:
# Import standard modules
import torch
import numpy as np
import matplotlib.pyplot as plt
import pickle

# Import custom modules
from parem.models import NLVM
from parem.algorithms import PGD
from parem.utils import get_mnist, load_checkpoint

# Set config variables

In [None]:
# Data settings
N_IMAGES = 10000  # M: training set size

# Training settings
N_BATCH = 128  # M_b: batch size for theta updates
N_EPOCHS = 100  # n_epochs = K * M_b / M, where K = total number of iterations
SEED = 1  # Seed for the PRNGs (applied at the end of this cell)
# Device on which to carry out computations:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
OPTIMIZER = 'rmsprop'  # Theta optimizer

# Model settings
X_DIM = 64  # d_x: dimension of latent space
LIKELIHOOD_VAR = 0.01 ** 2  # sigma^2

# PGD settings
STEP_SIZE = 1e-4  # h: step size
LAMBDA = 1e-3 / (STEP_SIZE * N_IMAGES)  # lambda
N_PARTICLES = 10  # N: number of particles

# Seed the global PRNGs before the model and particles are created in later
# cells, so that runs are reproducible. torch.manual_seed seeds the CPU and all
# CUDA generators; some CUDA kernels may still be nondeterministic.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Load dataset

In [None]:
# Build the MNIST training set of size N_IMAGES (M in the config cell) from
# DATASET_PATH. NOTE(review): whether get_mnist downloads the data when it is
# missing, and which N_IMAGES images it keeps, is decided in parem.utils — confirm.
mnist = get_mnist(DATASET_PATH, N_IMAGES)  # Load dataset

# Define model

In [None]:
# Build the neural latent variable model and move it to DEVICE.
# NOTE(review): nc=1 is presumably the number of image channels (MNIST is
# greyscale) — confirm against parem.models.NLVM.
model = NLVM(x_dim=X_DIM, sigma2=LIKELIHOOD_VAR, nc=1)
model = model.to(DEVICE)

# Build the PGD training algorithm over the MNIST training set.
pgd = PGD(
    model=model,
    dataset=mnist,
    device=DEVICE,
    # theta (model parameter) updates
    train_batch_size=N_BATCH,
    theta_optimizer=OPTIMIZER,
    # particle updates
    n_particles=N_PARTICLES,
    particle_step_size=STEP_SIZE,
    lambd=LAMBDA,
)

# Load checkpoint (optional)

In [None]:
# Optionally replace the freshly constructed `pgd` with one restored from a
# checkpoint saved under CHECKPOINTS_PATH. Off by default; set to True to resume.
LOAD_CHECKPOINT = False
CHECKPOINT_TO_LOAD = 'mnist_working.pt'
if LOAD_CHECKPOINT:
    pgd = load_checkpoint(CHECKPOINTS_PATH + '/' + CHECKPOINT_TO_LOAD)

# Train

In [None]:
# Run PGD for N_EPOCHS epochs with wandb and image logging disabled.
# NOTE(review): the second positional argument is presumably where checkpoints
# are written — confirm against PGD.run.
checkpoint_file = f"{CHECKPOINTS_PATH}/mnist_small_batchother.pt"
pgd.run(N_EPOCHS, checkpoint_file, wandb_log=False, log_images=False)

# Show particle cloud

In [None]:
# NOTE(review): the positional arguments are presumably (number of images to
# show, number of particles per image) — confirm against PGD.sample_image_posterior.
pgd.sample_image_posterior(10, N_PARTICLES)

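## Inpainting

Hide a square block of pixels in each of the first 20 training images and use the trained model to reconstruct the missing region.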

In [None]:
# Hide a square block of pixels in the first n_missing_img images and ask the
# model to reconstruct them.
# NOTE(review): assumes `mnist[:k][0]` yields the images of the first k samples
# and that `reconstruct` treats mask == False as missing pixels — confirm in parem.
n_missing_img = 20
images = mnist[:n_missing_img][0]

# Rows and columns in [mask_start, mask_stop) are hidden, i.e. 10..21 inclusive.
mask_start, mask_stop = 10, 22
# Slice assignment silently clips out-of-range indices, so check the block fits.
assert mask_stop <= mnist.height and mask_stop <= mnist.width, \
    "masked block does not fit inside the image"
mask = torch.ones(mnist.height, mnist.width, dtype=torch.bool)
mask[mask_start:mask_stop, mask_start:mask_stop] = False

pgd.reconstruct(images, mask)

## Generate synthetic samples

In [None]:
# Generate 64 synthetic images.
# NOTE(review): approx_type='gmm' with n_components=3 presumably fits a
# 3-component Gaussian mixture to the latent particles — confirm in PGD.
pgd.synthesize_images(
    n=64,
    approx_type='gmm',
    n_components=3,
)