### Load Test Data and Trained APG Sampler

In [None]:
%matplotlib inline
import os
import torch
import numpy as np
from apgs.bmnist.apg_training import init_models
from apgs.bmnist.affine_transformer import Affine_Transformer

# Resolve the compute device once; the flag and the device are both handed to
# the project helpers below.
CUDA = torch.cuda.is_available()
if CUDA:
    # Keep the original preference for the second GPU, but fall back to the
    # first one on single-GPU machines where 'cuda:1' is not a valid ordinal.
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')

data_dir = '../../data/bmnist/'

# Experiment dimensions. The values are taken verbatim from the original cell;
# their exact meaning is defined by the apgs.bmnist helpers that consume them.
timesteps = 10
num_digits = 3
frame_pixels = 96
mnist_pixels = 28
num_hidden_digit = 400
num_hidden_coor = 400
z_where_dim = 2
z_what_dim = 10

# Collect the path of every file in the 'train' sub-directory.
# NOTE(review): the section heading says "Test Data" but this lists 'train/'
# — confirm which split is intended.
train_dir = os.path.join(data_dir, 'train')
data_paths = [os.path.join(train_dir, fname) for fname in os.listdir(train_dir)]

# Identifier passed to init_models; presumably selects the saved checkpoint
# to load — verify against apgs.bmnist.apg_training.
model_version = 'apg-bmnist-num_sweeps=6-num_samples=16'
models = init_models(
    frame_pixels,
    mnist_pixels,
    num_hidden_digit,
    num_hidden_coor,
    z_where_dim,
    z_what_dim,
    CUDA,
    device,
    model_version,
    lr=None,
)
AT = Affine_Transformer(frame_pixels, mnist_pixels, CUDA, device)

### Visualize Samples

In [None]:
from apgs.resampler import Resampler
from apgs.bmnist.objectives import apg_objective
from apgs.bmnist.evaluation import viz_samples
from random import shuffle
batch_size, num_sweeps = 5, 10

# Shuffle in place so that a random file ends up first; note this also
# reorders data_paths for any later cell that uses it.
shuffle(data_paths)

# Take the first batch_size entries of the chosen file and add a leading
# singleton dimension.
data = torch.from_numpy(np.load(data_paths[0])).float()[:batch_size].unsqueeze(0)

# Tile the mean image over (batch_size, num_digits) and add the same leading
# singleton dimension. Assumes mnist_mean.npy holds a single 2-D image —
# TODO confirm.
mnist_mean = torch.from_numpy(np.load('mnist_mean.npy')).float()
mnist_mean = mnist_mean.repeat(batch_size, num_digits, 1, 1).unsqueeze(0)

if CUDA:
    # Move straight to the target device. The previous `.cuda().to(device)`
    # first copied to the current default GPU and then again to `device`.
    data = data.to(device)
    mnist_mean = mnist_mean.to(device)

# Only the 'mode_required' output is requested from the objective.
result_flags = {'loss_required' : False, 'ess_required' : False, 'mode_required' : True, 'density_required': False}
resampler = Resampler('systematic', 1, CUDA, device)
trace = apg_objective(
    models,
    AT,
    data,
    num_digits,
    result_flags,
    num_sweeps,
    resampler,
    mnist_mean,
)

# Drop the leading singleton dimension and move to host memory for plotting.
viz_samples(data.squeeze(0).cpu(), trace, num_sweeps, num_digits)

### Compute Log Joint Across All Methods

In [None]:
from apgs.bmnist.evaluation import density_all_instances
from random import shuffle
# Settings for the log-joint comparison.
sample_size = 1000
num_sweeps = 10

# Extra settings forwarded to density_all_instances. The lf_* names presumably
# refer to leapfrog integration (step size and step counts to try) — confirm
# against apgs.bmnist.evaluation.
lf_step_size = 1e-4
lf_num_steps = [1, 5, 10]
bpg_factor = 100

density_all_instances(
    models,
    AT,
    data_paths,
    sample_size,
    num_digits,
    z_where_dim,
    z_what_dim,
    num_sweeps,
    lf_step_size,
    lf_num_steps,
    bpg_factor,
    CUDA,
    device,
)