In [1]:
import os

# Ask XLA not to preallocate accelerator memory. The flag has to be in the
# environment before the JAX backend is initialised, and `jax.devices()` below
# initialises it, so the assignment must come first.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']='false'

import numpyro
numpyro.enable_x64()
# numpyro.set_platform('gpu')

import jax
print(jax.devices())

import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive, init_to_median, init_to_value, init_to_sample, Trace_ELBO, TraceGraph_ELBO

from numpyro.infer.svi import SVI
from numpyro.infer.autoguide import AutoDelta, AutoNormal, AutoLaplaceApproximation, AutoDAIS, AutoBNAFNormal, AutoIAFNormal, AutoMultivariateNormal
from numpyro.optim import Adam, ClippedAdam

import tensorflow_probability.substrates.jax as tfp

from jax import numpy as jnp
from jax import random, vmap

import numpy as np
import pandas as pd
import arviz as az

import siuba as s
from siuba import _

from plotnine import *

from matplotlib import pyplot as plt

import pickle

from lib.helpers import *
from lib.models import *

[CpuDevice(id=0)]


In [2]:
## Define functions

def fit_svi(model, x_data, y_data, optimizer=None, loss=None, n_steps=10_000, full_rank = False, filename=None):
    """Fit `model` with SVI and an automatic guide, optionally caching the fit on disk.

    Parameters
    ----------
    model : callable
        NumPyro model; it is called as ``model(x_data, y_data)`` during fitting.
    x_data, y_data
        Positional arguments forwarded unchanged to the model.
    optimizer : optional
        Optimizer handed to ``SVI``. ``None`` (the default) means ``Adam(1e-3)``.
    loss : optional
        ELBO objective handed to ``SVI``. ``None`` (the default) means
        ``TraceGraph_ELBO()``.
    n_steps : int
        Number of optimisation steps to run.
    full_rank : bool
        If true use ``AutoMultivariateNormal`` as the guide, otherwise ``AutoNormal``.
    filename : str or None
        Pickle cache path. If the file exists it is loaded and returned without
        fitting; otherwise the new fit is written there.

    Returns
    -------
    dict
        ``{"params": ..., "guide": ..., "losses": ...}`` for a fresh fit, or
        whatever object was stored in ``filename`` on a cache hit.

    Notes
    -----
    The cache is keyed on ``filename`` only: a cached fit is returned even if
    the model, data or settings differ from the ones that produced it.
    """
    # Cache hit: return early, before building any guide/optimizer objects.
    if filename is not None and os.path.exists(filename):
        # SECURITY: unpickling can execute arbitrary code. Only point
        # `filename` at cache files that this notebook wrote itself.
        with open(filename, "rb") as cache_file:
            return pickle.load(cache_file)

    # Build the defaults per call rather than once at definition time.
    if optimizer is None:
        optimizer = Adam(1e-3)
    if loss is None:
        loss = TraceGraph_ELBO()

    guide = AutoMultivariateNormal(model) if full_rank else AutoNormal(model)

    svi = SVI(model, guide, optimizer, loss)
    # Honour the caller's `n_steps` (previously a hard-coded 10_000 was used).
    result = svi.run(random.PRNGKey(1), n_steps, x_data, y_data)
    output = {"params": result.params, "guide": guide, "losses": result.losses}

    if filename is not None:
        # Create the cache directory if needed so a long fit is not lost to a
        # FileNotFoundError at save time.
        cache_dir = os.path.dirname(filename)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        with open(filename, "wb") as cache_file:
            pickle.dump(output, cache_file)

    return output


def arviz_from_svi(model, guide, params, *args, obs_data=None, num_samples = 1_000):
    """Turn an SVI fit into an ArviZ ``InferenceData``-style object via ``az.from_dict``.

    Parameters
    ----------
    model : callable
        NumPyro model used for the prior and posterior predictive draws.
    guide
        Fitted guide; must provide ``sample_posterior``.
    params
        Optimised guide parameters passed to ``guide.sample_posterior``.
    *args
        Positional arguments forwarded to the model for both predictive passes.
        NOTE(review): if these include the observed outcome, the predictive
        draws may simply echo the observations -- confirm against the model's
        signature.
    obs_data : optional
        Observed outcomes. When given they are stored in the ``observed_data``
        group under the (hard-coded) name ``"yhat"``; when ``None`` that group
        is omitted.
    num_samples : int
        Number of draws for the posterior and for the prior predictive.

    Returns
    -------
    The object returned by ``az.from_dict`` with ``posterior``, ``prior``,
    ``posterior_predictive`` and (optionally) ``observed_data`` groups. Every
    array gets a leading axis of length 1 to act as the chain dimension.
    """
    # One independent key per sampling step: JAX PRNG keys should not be reused
    # across different random calls.
    key_posterior, key_post_pred, key_prior_pred = random.split(random.PRNGKey(1), 3)

    posterior_samples = guide.sample_posterior(key_posterior, params=params, sample_shape=(num_samples,))
    samples_posterior_predictive = Predictive(model=model, posterior_samples=posterior_samples)(key_post_pred, *args)
    samples_prior_predictive = Predictive(model=model, params=None, num_samples=num_samples)(key_prior_pred, *args)

    def _with_chain_dim(samples):
        # Prepend a length-1 axis so draws are laid out as (chain, draw, ...).
        return {name: np.expand_dims(values, 0) for name, values in samples.items()}

    groups = {
        "prior": _with_chain_dim(samples_prior_predictive),
        "posterior_predictive": _with_chain_dim(samples_posterior_predictive),
    }
    # Only attach observed data when it was actually supplied; passing
    # {"yhat": None} would create a meaningless group.
    if obs_data is not None:
        groups["observed_data"] = {"yhat": obs_data}

    return az.from_dict(_with_chain_dim(posterior_samples), **groups)

In [3]:
## Load Experiment 1 data

# Experiment 1 trials: ordered by participant ID, with `block` shifted down by
# one and `estimate` multiplied by 20 and rounded to an int64.
df1 = (
    load_data_exp1()
    >> s.arrange(_.ID)
    >> s.mutate(block=_.block - 1)
    >> s.mutate(estimate=np.round(_.estimate * 20).astype("int64"))
)

# Model inputs and outcomes derived from the prepared frame.
X_exp1, y_exp1 = make_model_data(df1)
print(f"{len(y_exp1)} observations")


7080 observations


In [4]:
## Load Experiment 2 data
# Experiment 2 trials: rows with condition 2 are dropped, `block` is shifted
# down by one and `estimate` is multiplied by 20 and rounded to an int64.
df2 = (
    load_data_exp2_trials()
    >> s.filter(_.condition != 2)
    >> s.mutate(block=_.block - 1)
    >> s.mutate(estimate=np.round(_.estimate * 20).astype("int64"))
)

# Model inputs and outcomes derived from the prepared frame.
X_exp2, y_exp2 = make_model_data(df2)
print(f"{len(y_exp2)} observations")

10080 observations


In [5]:
# Fit the trial-level model on the Experiment 2 data; `fit_svi` returns the
# pickled result instead if the cache file below already exists.
res_bs_exp2 = fit_svi(
    bs_complex_mlm_trial_level,
    X_exp2,
    y_exp2,
    filename="local/svi-bs_complex_mlm_trial_level-exp2.p",
)

  0%|          | 10/10000 [00:58<16:12:59,  5.84s/it]


KeyboardInterrupt: 

In [None]:
# Build the ArviZ object from the fit produced above. The previous version
# referenced `model` and `res_ptn_exp2`, neither of which is defined in this
# notebook, so it could not run on a fresh kernel.
az_bs_exp2 = arviz_from_svi(
    bs_complex_mlm_trial_level,
    res_bs_exp2["guide"],
    res_bs_exp2["params"],
    X_exp2,
    y_exp2,
    obs_data=y_exp2,  # populate the observed_data group with the real outcomes
)