In [None]:
%env XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/local/cuda/
%load_ext autoreload
%autoreload 2
import os
import numpy as np
from math import *
from jax import random
import jax.numpy as jnp
import numpyro as npr
import numpyro.distributions as dist
import tqdm as tqdm

# npr.set_platform('gpu')

### Models

#### True model

In [None]:
# Hyperparams
# L is the number of LogNormal terms summed to form one observation; it is read
# as a module-level global by the model functions defined in this notebook.
L = 10

def model_true(y=None, theta=None, rng_key=random.PRNGKey(1)):
    """True generative model: Y is (almost exactly) a sum of ``L`` LogNormal draws.

    Args:
        y: Optional observed value for the ``'Y'`` site. ``None`` means sample it.
        theta: Optional ``(mu, sigma_sq)`` pair. When ``None``, both parameters
            are drawn from their priors, ``Normal(0, 1)`` and ``Gamma(1, 1)``.
        rng_key: JAX PRNG key used to derive one sub-key per sample site.

    Returns:
        ``(mu, sigma_sq)`` when ``theta`` is ``None`` (the parameters that were
        sampled), otherwise the value of the ``'Y'`` site.
    """
    # Split into 5 and drop the first key so every site keeps the sub-key it
    # has always used (mu, sigma_sq, X, Y in that order).
    _, mu_key, sigma_key, x_key, y_key = random.split(rng_key, 5)

    if theta is None:
        # Sample from priors \pi(\mu, \sigma^2)
        mu = npr.sample('mu', dist.Normal(0, 1), rng_key=mu_key)
        sigma_sq = npr.sample('sigma_sq', dist.Gamma(1, 1), rng_key=sigma_key)
    else:
        mu, sigma_sq = theta

    # The true likelihood, sum of LogNormal rvs.
    # LogNormal is parameterised by (loc, scale) where scale is the standard
    # deviation of the underlying normal, so the variance must be square-rooted.
    with npr.plate('L', L):
        x = npr.sample('X', dist.LogNormal(mu, jnp.sqrt(sigma_sq)), rng_key=x_key)

    # A very narrow Normal around the sum acts as a (near-)deterministic site
    # that can still be conditioned on through ``obs``.
    data = npr.sample('Y', dist.Normal(x.sum(0), 1e-6), rng_key=y_key, obs=y)

    if theta is None:
        return (mu, sigma_sq)
    return data
    
    
# Test: smoke-run with all defaults. Because theta is None, the call samples
# the parameters from their priors and returns the (mu, sigma_sq) pair.
model_true()

#### Approximate model

In [None]:
def model_abc(y=None, theta=None, rng_key=random.PRNGKey(1)):
    """Approximate model: the sum of ``L`` LogNormals is replaced by one LogNormal.

    The single LogNormal is moment-matched to the sum: ``alpha`` and ``beta_sq``
    are the location and *variance* (on the log scale) of the approximating
    distribution.

    Args:
        y: Optional observed value for the ``'Y'`` site. ``None`` means sample it.
        theta: Optional ``(mu, sigma_sq)`` pair. When ``None``, both parameters
            are drawn from their priors, ``Normal(0, 1)`` and ``Gamma(1, 1)``.
        rng_key: JAX PRNG key used to derive one sub-key per sample site.

    Returns:
        ``(mu, sigma_sq)`` when ``theta`` is ``None`` (the parameters that were
        sampled), otherwise the value of the ``'Y'`` site.
    """
    # Split into 4 and drop the first key so every site keeps the sub-key it
    # has always used (mu, sigma_sq, Y in that order).
    _, mu_key, sigma_key, y_key = random.split(rng_key, 4)

    if theta is None:
        # Sample from priors \pi(\mu, \sigma^2)
        mu = npr.sample('mu', dist.Normal(0, 1), rng_key=mu_key)
        sigma_sq = npr.sample('sigma_sq', dist.Gamma(1, 1), rng_key=sigma_key)
    else:
        mu, sigma_sq = theta

    # Approximate likelihood
    beta_sq = jnp.log((jnp.exp(sigma_sq) - 1) / L + 1)
    alpha = mu + jnp.log(L) + 0.5 * (sigma_sq - beta_sq)

    # LogNormal takes (loc, scale) with scale a standard deviation, so the
    # matched variance beta_sq has to be square-rooted before it is passed in.
    data = npr.sample('Y', dist.LogNormal(alpha, jnp.sqrt(beta_sq)), rng_key=y_key, obs=y)

    if theta is None:
        return (mu, sigma_sq)
    return data
    
    
# Test: smoke-run with all defaults. Because theta is None, the call samples
# the parameters from their priors and returns the (mu, sigma_sq) pair.
model_abc()

### Methods

In [None]:
from numpyro.infer import SVI, Trace_ELBO, MCMC, NUTS
from numpyro.infer.autoguide import *


def mcmc(rng_key, model, y, n_samples=20, n_warmup=20, pbar=False):
    """Draw one posterior sample of ``(mu, sigma_sq)`` with a short NUTS chain.

    Args:
        rng_key: JAX PRNG key for the chain.
        model: NumPyro model; it is called with ``y`` as its first positional
            argument and must expose ``'mu'`` and ``'sigma_sq'`` sample sites.
        y: Observed data passed to the model.
        n_samples: Number of post-warmup draws.
        n_warmup: Number of warmup (adaptation) steps.
        pbar: Whether to show NumPyro's progress bar.

    Returns:
        Tuple ``(mu, sigma_sq)`` taken from the last draw of the chain.
    """
    sampler = MCMC(
        NUTS(model),
        num_samples=n_samples,
        num_warmup=n_warmup,
        progress_bar=pbar,  # previously hard-coded to False, ignoring ``pbar``
    )

    # Use the second half of the split, as before, so results for a given
    # ``rng_key`` are unchanged.
    _, run_key = random.split(rng_key)
    sampler.run(run_key, y)

    draws = sampler.get_samples()
    # Only the final state of the chain is used as the returned sample.
    return draws['mu'][-1], draws['sigma_sq'][-1]


def vb_diag(rng_key, model, y, pbar=False):
    """Fit a mean-field (diagonal) Gaussian guide by SVI and sample from it.

    Args:
        rng_key: JAX PRNG key; split into separate keys for fitting and for
            each of the two posterior draws.
        model: NumPyro model accepting ``y`` as a keyword argument and exposing
            ``'mu'`` and ``'sigma_sq'`` sample sites.
        y: Observed data forwarded to the model.
        pbar: Whether to show the SVI progress bar.

    Returns:
        Tuple ``(mu, sigma_sq)``; each component comes from its own draw of the
        fitted guide.
    """
    step_size, num_steps = 1e-3, 5000

    # Four keys are produced; the first one is intentionally left unused.
    _, fit_key, mu_key, sigma_sq_key = random.split(rng_key, 4)

    diag_guide = AutoDiagonalNormal(model)
    engine = SVI(
        model,
        diag_guide,
        npr.optim.ClippedAdam(step_size=step_size),
        loss=Trace_ELBO(num_particles=100),
    )
    fitted_params = engine.run(fit_key, num_steps, y=y, progress_bar=pbar).params

    mu_draw = diag_guide.sample_posterior(mu_key, fitted_params)
    sigma_sq_draw = diag_guide.sample_posterior(sigma_sq_key, fitted_params)

    return mu_draw['mu'], sigma_sq_draw['sigma_sq']


def vb_full(rng_key, model, y, pbar=False):
    """Fit a full-covariance Gaussian guide by SVI and return one joint draw.

    Args:
        rng_key: JAX PRNG key; split into independent keys for fitting and for
            the posterior draw.
        model: NumPyro model accepting ``y`` as a keyword argument and exposing
            ``'mu'`` and ``'sigma_sq'`` sample sites.
        y: Observed data forwarded to the model.
        pbar: Whether to show the SVI progress bar.

    Returns:
        Tuple ``(mu, sigma_sq)`` taken from a single joint sample of the guide.
    """
    # Never reuse ``rng_key`` itself after splitting it: fitting gets its own
    # sub-key, matching what ``vb_diag`` does.
    _, svi_key, draw_key, _ = random.split(rng_key, 4)

    guide = AutoMultivariateNormal(model)
    lr = 5e-4
    n_iter = 5000

    optimizer = npr.optim.ClippedAdam(step_size=lr)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO(num_particles=100))
    svi_result = svi.run(svi_key, n_iter, y=y, progress_bar=pbar)

    # Take mu and sigma_sq from the SAME draw. Two independent draws would
    # sample the product of the marginals and throw away the correlation that
    # a full-covariance guide exists to capture.
    theta_draw = guide.sample_posterior(draw_key, svi_result.params)

    return theta_draw['mu'], theta_draw['sigma_sq']


def laplace(rng_key, model, y, pbar=False):
    """Fit a Laplace approximation via SVI and return one joint draw from it.

    Args:
        rng_key: JAX PRNG key; split into independent keys for fitting and for
            the posterior draw.
        model: NumPyro model accepting ``y`` as a keyword argument and exposing
            ``'mu'`` and ``'sigma_sq'`` sample sites.
        y: Observed data forwarded to the model.
        pbar: Whether to show the SVI progress bar.

    Returns:
        Tuple ``(mu, sigma_sq)`` taken from a single joint sample of the guide.
    """
    # Never reuse ``rng_key`` itself after splitting it: fitting gets its own
    # sub-key, matching what ``vb_diag`` does.
    _, svi_key, draw_key, _ = random.split(rng_key, 4)

    guide = AutoLaplaceApproximation(model)
    lr = 1e-3
    n_iter = 5000

    optimizer = npr.optim.ClippedAdam(step_size=lr)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO(num_particles=1))
    svi_result = svi.run(svi_key, n_iter, y=y, progress_bar=pbar)

    # Take mu and sigma_sq from the SAME draw so the returned pair is a sample
    # from the joint approximation rather than from the product of its
    # marginals; this also avoids sampling the posterior twice.
    theta_draw = guide.sample_posterior(draw_key, svi_result.params)

    return theta_draw['mu'], theta_draw['sigma_sq']

### Gibbs-prior sampler

In [None]:
# Dispatch table: name of a posterior-approximation method -> function that
# implements it. Every function shares the signature (rng_key, model, y, ...).
POSTERIOR_FUNCS = {
    'mcmc': mcmc,
    'vb_diag': vb_diag,
    'vb_full': vb_full,
    'laplace': laplace,
}

# Valid method names, derived from the dispatch table so the two can never
# drift apart (the hand-written list contained 'mcmc' twice).
POSTERIORS = list(POSTERIOR_FUNCS)


def sample_gibbs_prior(rng_key, posterior, T=100):
    """Run the Gibbs-prior chain, alternating approximate posterior and true model.

    Each iteration draws ``theta_t`` from the chosen approximate posterior of
    ``model_abc`` given the current data ``y_t``, then simulates a fresh
    ``y_t`` from ``model_true`` given ``theta_t``.

    Args:
        rng_key: JAX PRNG key driving the whole chain.
        posterior: Name of the approximation method; must be one of
            ``POSTERIORS``.
        T: Number of chain iterations.

    Returns:
        NumPy array with one row per iteration holding that iteration's
        ``(mu, sigma_sq)`` draw.

    Raises:
        AssertionError: If ``posterior`` is not a known method name.
    """
    assert posterior in POSTERIORS, f'unknown posterior {posterior!r}; expected one of {POSTERIORS}'
    approx_posterior = POSTERIOR_FUNCS[posterior]

    theta_samples = []

    # Initialise the chain with an actual observation. With theta=None,
    # model_true returns the sampled (mu, sigma_sq) pair rather than data, so
    # first draw theta from the prior and then simulate y given that theta.
    rng_key, prior_key, sim_key = random.split(rng_key, 3)
    theta_init = model_true(y=None, theta=None, rng_key=prior_key)
    y_t = model_true(y=None, theta=theta_init, rng_key=sim_key)

    for t in tqdm.trange(T):
        rng_key, *subkeys = random.split(rng_key, 3)

        # Get q(theta | y_t)
        theta_t = approx_posterior(subkeys[0], model_abc, y=y_t)
        theta_samples.append(np.array(theta_t).copy())

        # Sample y_t, always using the true model
        y_t = model_true(y=None, theta=theta_t, rng_key=subkeys[1])

    return np.array(theta_samples)

# Prior
def sample_prior(L, n_samples, rng_key=random.PRNGKey(1)):
    """Draw ``n_samples`` parameter pairs directly from the prior.

    Args:
        L: Accepted for signature compatibility; not used by this function.
        n_samples: Number of ``(mu, sigma_sq)`` pairs to draw.
        rng_key: JAX PRNG key, split once per parameter.

    Returns:
        NumPy array of shape ``(n_samples, 2)`` whose columns are ``mu`` drawn
        from ``Normal(0, 1)`` and ``sigma_sq`` drawn from ``Gamma(1, 1)``.
    """
    mu_key, sigma_sq_key = random.split(rng_key, 2)

    # Each parameter is drawn as a column vector so the two can be placed
    # side by side along the last axis.
    column_shape = (n_samples, 1)
    mu_column = npr.sample('mu', dist.Normal(0, 1), rng_key=mu_key, sample_shape=column_shape)
    sigma_sq_column = npr.sample(
        'sigma_sq', dist.Gamma(1, 1), rng_key=sigma_sq_key, sample_shape=column_shape
    )

    return np.concatenate([mu_column, sigma_sq_column], axis=-1)

### Sample $\pi_G$ for the ABC model

In [None]:
rng_key = random.PRNGKey(9999)

# Directory this cell has always created; kept in case other cells write there.
os.makedirs('results/abc', exist_ok=True)

# Laplace
# np.save does not create missing parent directories, so make sure the folder
# that actually receives the output exists before the long-running chain starts.
laplace_out_path = '../res/abc/laplace_approx'
os.makedirs(os.path.dirname(laplace_out_path), exist_ok=True)

thetas_laplace = sample_gibbs_prior(rng_key, 'laplace', T=10000)
np.save(laplace_out_path, thetas_laplace)