In [1]:
%env XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/local/cuda/
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
%env XLA_PYTHON_CLIENT_ALLOCATOR=platform
%load_ext autoreload
%autoreload 2
import os
import numpy as np
from jax import random
import jax.numpy as jnp
from math import *
import numpyro as npr
import numpyro.distributions as dist
import tqdm

from numpyro.infer import MCMC, NUTS

# npr.set_platform('gpu')

env: XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/local/cuda/
env: XLA_PYTHON_CLIENT_PREALLOCATE=false
env: XLA_PYTHON_CLIENT_ALLOCATOR=platform


### Model

In [2]:
# Hyperparams
sigma = 0.09  # scale of the GaussianRandomWalk prior on the log-volatility path
nu = 12  # degrees of freedom (df) of the StudentT likelihood for the returns
NUM_STEPS = 100  # series length used by `model` when no observations `y` are given


def model(y, theta=None, rng_key=random.PRNGKey(1)):
    """Stochastic volatility model.

    Cf. the NumPyro example at
    http://num.pyro.ai/en/0.6.0/examples/stochastic_volatility.html

    Sites:
        'theta': log-volatility path, GaussianRandomWalk(scale=sigma), only
            sampled when `theta` is None.
        'y': returns, StudentT(df=nu, loc=0, scale=exp(log-volatility)),
            observed when `y` is not None.

    Args:
        y: Observed returns, or None to leave the 'y' site unobserved. When
            given, its length sets the number of random-walk steps; otherwise
            NUM_STEPS is used.
        theta: Optional fixed log-volatility path. When given, the 'theta'
            site is skipped and this value is used directly.
        rng_key: JAX PRNG key; it is split so that each sample site gets its
            own explicit key.

    Returns:
        The log-volatility path when `theta` is None (this holds even when
        `y` is also None), otherwise the value of the 'y' site.
    """
    latent_is_sampled = theta is None
    horizon = NUM_STEPS if y is None else len(y)

    # Two successive splits: the first subkey drives 'theta', the second 'y'.
    carry_key, theta_key = random.split(rng_key)
    _, obs_key = random.split(carry_key)

    if latent_is_sampled:
        walk_prior = dist.GaussianRandomWalk(scale=sigma, num_steps=horizon)
        log_vol = npr.sample('theta', walk_prior, rng_key=theta_key)
    else:
        log_vol = theta

    likelihood = dist.StudentT(df=nu, loc=0., scale=jnp.exp(log_vol))
    returns = npr.sample('y', likelihood, rng_key=obs_key, obs=y)

    # Latent path when it was sampled here; otherwise the returns.
    return log_vol if latent_is_sampled else returns

### Methods

#### VB and Laplace

In [4]:
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import *


def train_vb_diag(rng_key, y, pbar=True, lr=1e-3, n_iter=5000, num_particles=100):
    """Fit an AutoDiagonalNormal guide to `model` with SVI.

    Args:
        rng_key: JAX PRNG key passed to `SVI.run`.
        y: Observations, forwarded to `model` as the `y` keyword.
        pbar: Whether `SVI.run` shows its progress bar.
        lr: Step size of the ClippedAdam optimizer.
        n_iter: Number of SVI steps.
        num_particles: Number of particles used by Trace_ELBO.

    Returns:
        Tuple `(guide, params)`: the guide object and the optimized
        parameters from the SVI run.
    """
    guide = AutoDiagonalNormal(model)

    optimizer = npr.optim.ClippedAdam(step_size=lr)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO(num_particles=num_particles))
    svi_result = svi.run(rng_key, n_iter, y=y, progress_bar=pbar)

    return guide, svi_result.params


def train_vb_full(rng_key, y, pbar=True, lr=5e-4, n_iter=10000, num_particles=100):
    """Fit an AutoMultivariateNormal guide to `model` with SVI.

    Args:
        rng_key: JAX PRNG key passed to `SVI.run`.
        y: Observations, forwarded to `model` as the `y` keyword.
        pbar: Whether `SVI.run` shows its progress bar.
        lr: Step size of the ClippedAdam optimizer. The default is kept small
            because training was found unstable with a large step size.
        n_iter: Number of SVI steps. The default is large to compensate for
            the small step size.
        num_particles: Number of particles used by Trace_ELBO.

    Returns:
        Tuple `(guide, params)`: the guide object and the optimized
        parameters from the SVI run.
    """
    guide = AutoMultivariateNormal(model)

    optimizer = npr.optim.ClippedAdam(step_size=lr)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO(num_particles=num_particles))
    svi_result = svi.run(rng_key, n_iter, y=y, progress_bar=pbar)

    return guide, svi_result.params


def train_laplace(rng_key, y, pbar=True, lr=1e-3, n_iter=5000, num_particles=1):
    """Fit an AutoLaplaceApproximation guide to `model` with SVI.

    Args:
        rng_key: JAX PRNG key passed to `SVI.run`.
        y: Observations, forwarded to `model` as the `y` keyword.
        pbar: Whether `SVI.run` shows its progress bar.
        lr: Step size of the ClippedAdam optimizer.
        n_iter: Number of SVI steps.
        num_particles: Number of particles used by Trace_ELBO.

    Returns:
        Tuple `(guide, params)`: the guide object and the optimized
        parameters from the SVI run.
    """
    guide = AutoLaplaceApproximation(model)

    optimizer = npr.optim.ClippedAdam(step_size=lr)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO(num_particles=num_particles))
    svi_result = svi.run(rng_key, n_iter, y=y, progress_bar=pbar)

    return guide, svi_result.params

### Sampling $\pi_G$

In [5]:
# Fitting routines keyed by method name. Each is called as
# f(rng_key, y=..., pbar=...) and returns a (guide, params) pair.
POSTERIOR_FUNCS = dict(
    vb_diag=train_vb_diag,
    vb_full=train_vb_full,
    laplace=train_laplace,
)
# Every accepted `posterior` name: the fitted approximations above plus
# 'mcmc', which is handled separately (NUTS) and has no entry in the dict.
POSTERIORS = list(POSTERIOR_FUNCS) + ['mcmc']


def sample_gibbs_prior(rng_key, posterior, T=100, mcmc_warmup=20, mcmc_samples=20):
    """Alternate between an (approximate) posterior draw of theta and a fresh y.

    Starting from a prior draw of (theta, y), each of the `T` iterations
    obtains theta_t from the chosen posterior method given the current y_t,
    records it, and then draws a new y from `model` given theta_t.

    Args:
        rng_key: JAX PRNG key driving the whole chain.
        posterior: One of POSTERIORS. Names in POSTERIOR_FUNCS fit a guide
            with SVI and draw once from it; 'mcmc' runs NUTS and keeps the
            last sample of the chain.
        T: Number of iterations (and of returned theta draws).
        mcmc_warmup: NUTS warmup steps per iteration ('mcmc' only).
        mcmc_samples: NUTS samples per iteration ('mcmc' only).

    Returns:
        List of `T` NumPy arrays, one theta draw per iteration.
    """
    assert posterior in POSTERIORS, (
        f"unknown posterior {posterior!r}; expected one of {POSTERIORS}"
    )

    theta_samples = []

    # Initial state. `model(y=None)` returns the latent path (not the
    # observations) whenever theta is None, so y_0 must be drawn in a second
    # call conditioned on theta_0. The keys are derived from the subkey so the
    # key stream consumed by the loop below is unaffected.
    rng_key, rng_subkey = random.split(rng_key)
    theta_key, y_key = random.split(rng_subkey)
    theta_0 = model(y=None, rng_key=theta_key)
    y_t = model(y=None, theta=theta_0, rng_key=y_key)

    for _ in tqdm.trange(T):
        # Four subkeys are always drawn so the key stream does not depend on
        # which posterior method is in use.
        rng_key, fit_key, draw_key, mcmc_key, y_key = random.split(rng_key, 5)

        # Get q(theta | y_t)
        if posterior != 'mcmc':
            guide_t, params_t = POSTERIOR_FUNCS[posterior](fit_key, y=y_t, pbar=False)
            theta_t = guide_t.sample_posterior(draw_key, params_t)['theta']
        else:
            mcmc = MCMC(NUTS(model), num_warmup=mcmc_warmup, num_samples=mcmc_samples,
                        progress_bar=False)
            mcmc.run(mcmc_key, y_t)
            theta_t = mcmc.get_samples()['theta'][-1]

        # np.array already makes an independent host-side copy.
        theta_samples.append(np.array(theta_t))

        # Sample y_t
        y_t = model(y=None, theta=theta_t, rng_key=y_key)

    return theta_samples

#### Get samples

In [8]:
# Run the sampler once per posterior method and persist the theta draws.
# Every run starts from the same seed and uses the same number of iterations.
RESULTS_DIR = '../res/volatility'
GIBBS_SEED = 9999
GIBBS_STEPS = 10000

# Create the output directory up front: each run takes hours, and np.save
# fails if the target directory does not exist.
os.makedirs(RESULTS_DIR, exist_ok=True)

thetas_vb_diag = sample_gibbs_prior(random.PRNGKey(GIBBS_SEED), 'vb_diag', T=GIBBS_STEPS)
np.save(f'{RESULTS_DIR}/vb_diag', thetas_vb_diag)

thetas_vb_full = sample_gibbs_prior(random.PRNGKey(GIBBS_SEED), 'vb_full', T=GIBBS_STEPS)
np.save(f'{RESULTS_DIR}/vb_full', thetas_vb_full)

thetas_laplace = sample_gibbs_prior(random.PRNGKey(GIBBS_SEED), 'laplace', T=GIBBS_STEPS)
np.save(f'{RESULTS_DIR}/laplace', thetas_laplace)

thetas_mcmc = sample_gibbs_prior(random.PRNGKey(GIBBS_SEED), 'mcmc', T=GIBBS_STEPS)
np.save(f'{RESULTS_DIR}/mcmc', thetas_mcmc)

# Short-chain variant: 10 warmup steps and a single NUTS sample per iteration.
thetas_mcmc_short = sample_gibbs_prior(random.PRNGKey(GIBBS_SEED), 'mcmc', T=GIBBS_STEPS,
                                       mcmc_warmup=10, mcmc_samples=1)
np.save(f'{RESULTS_DIR}/mcmc_short_10', thetas_mcmc_short)

100%|██████████| 10000/10000 [4:36:44<00:00,  1.66s/it] 
100%|██████████| 10000/10000 [12:58:32<00:00,  4.67s/it] 
100%|██████████| 10000/10000 [3:34:13<00:00,  1.29s/it] 
100%|██████████| 10000/10000 [6:37:51<00:00,  2.39s/it]  
