In [95]:
import numpyro
numpyro.enable_x64()
numpyro.util.set_host_device_count(6)
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive, init_to_median, init_to_value


import tensorflow_probability.substrates.jax as tfp

from jax import numpy as jnp
from jax import random, vmap

import numpy as np
import pandas as pd
import seaborn as sns
import arviz as az

from dfply import *

from plotnine import *

import pickle

%run -i 'model_helpers.py'
%run -i 'models.py'
# %run -i 'reloo-compare.py'

def load_data_exp2_trials():
    """Load experiment-2 trials with a condition code and 0-based participant IDs.

    The condition is derived from the ``querydetail`` text. "windy"/"cloudy" maps
    to 0, "cold"/"rainy" to 1 and "warm"/"snowy" to 2. The first matching pattern
    wins, and rows that match none of them also receive 0. Participant IDs are
    replaced by their position among the sorted unique original IDs (0, 1, 2, ...)
    so they can index per-subject parameter vectors directly.

    Returns:
        The frame returned by ``load_raw_data(2)`` with an added ``condition``
        column and a re-coded ``ID`` column.
    """
    trials = load_raw_data(2)

    condition_patterns = ("windy|cloudy", "cold|rainy", "warm|snowy")
    pattern_hits = [trials.querydetail.str.contains(pattern) for pattern in condition_patterns]
    trials["condition"] = np.select(pattern_hits, list(range(len(condition_patterns))), default=0)

    # np.unique returns the IDs sorted, so enumerate gives each its rank.
    id_to_index = {original: index for index, original in enumerate(np.unique(trials.ID))}
    return trials.assign(ID=trials.ID.map(id_to_index))

In [111]:
# Load the experiment-2 trials. load_data_exp2_trials is defined earlier in this
# notebook; the underlying load_raw_data comes from a %run helper script.
df = load_data_exp2_trials()
df = df[df["condition"] != 2]  # filter out "warm/snowy" as per paper

N_SUBJECTS = 15      # keep re-indexed participant IDs 0..14 only (presumably to keep fitting tractable)
RESPONSE_STEPS = 20  # round estimates to 1/20 = 5% steps -> categories 0..20 (assumes estimates in [0, 1])

df = (
    df
    >> s.mutate(block=_.block - 1)  # 0-based block numbers
    >> s.filter(_.ID < N_SUBJECTS)
    >> s.mutate(estimate=np.round(_.estimate * RESPONSE_STEPS))
).astype({"estimate": "int64"})  # integer categories for the Categorical likelihood
df.head(5)

Unnamed: 0,ID,block,trial,query,querydetail,querytype,estimate,starttime,endtime,RT,condition
2,0,0,3,If the weather in England is cloudy on a rando...,windy given cloudy,AgB,16,76763.706192,76772.946467,9.240275,0
3,0,0,4,If the weather in England is windy on a random...,not cloudy given windy,notBgA,8,76773.07275,76782.666277,9.593527,0
4,0,0,5,What is the probability that the weather will ...,cold and rainy,AandB,14,76782.789357,76794.954048,12.164691,1
8,0,0,9,What is the probability that the weather will ...,not windy or not cloudy,notAornotB,15,76847.388571,76878.392191,31.003619,0
9,0,0,10,What is the probability that the weather will ...,windy and cloudy,AandB,14,76878.521995,76895.023852,16.501857,0


In [112]:
# X_data is the `data` mapping the models index ("trial", "subj", "cond");
# y_data holds the observed response categories. make_model_data comes from one
# of the %run helper scripts at the top of the notebook.
X_data, y_data = make_model_data(df)

n_observations = len(y_data)
print(n_observations, "observations")

1800 observations


In [98]:
## rounding stuff
from jax import vmap

def spread_vec(x, step_size, n_out=21):
    """Spread ``x`` onto a finer grid with zeros in between, truncate, and normalize.

    ``x[i]`` is placed at index ``i * step_size`` of a zero vector of length
    ``len(x) * step_size``. The result is cut to its first ``n_out`` entries and
    divided by its sum. This maps probabilities on a coarse response grid (e.g.
    11 points in 10% steps) onto the 21 categories 0..20.

    Args:
        x: 1-D array of non-negative weights.
        step_size: Python int spacing between consecutive entries of ``x`` in the
            output (1 means no spreading). It sets an array shape, so it must be
            concrete. Plain ints from shape arithmetic are fine under jit/vmap.
        n_out: Python int, maximum output length (default 21, i.e. categories 0..20).

    Returns:
        1-D array of length ``min(n_out, len(x) * step_size)`` that sums to 1.
    """
    n_in = x.shape[0]
    # One strided scatter instead of splitting x into n_in pieces and
    # concatenating each with a zero pad. The old approach unrolled into
    # n_in separate ops at trace time.
    spread = jnp.zeros(n_in * step_size).at[::step_size].set(x)
    probs = spread[:n_out]
    return probs / jnp.sum(probs)


def f(mu, k, responses):
    """Discretized beta response probabilities for a single trial.

    The latent judgment is Beta(a, b) with mean ``mu`` and concentration ``k``
    (a = mu * k, b = (1 - mu) * k). Each grid point ``responses / n_resps`` on
    [0, 1] receives the beta mass within half a grid step of it. That vector is
    then spread onto the 21 observed response categories (0..20) and renormalized.

    Args:
        mu: mean judged probability for the trial.
        k: beta concentration for the trial.
        responses: float array 0..n_resps (e.g. ``responses_5`` or
            ``responses_10``). n_resps must divide 20.

    Returns:
        Length-21 probability vector over response categories 0..20.
    """
    a = mu * k
    b = (1. - mu) * k

    n_resps = responses.shape[0] - 1
    # Observed responses are integer categories 0..20 (estimates rounded to 5%),
    # so grid point i belongs at category i * (20 / n_resps), as in f_bs.
    # This used to be 100 / n_resps (a 0..100 scale). After spread_vec truncates
    # to 21 entries, that put the mass on the wrong categories and discarded most
    # of the grid.
    step = int(20 / n_resps)
    rnd_unit_scaled = 1 / n_resps

    # Bin edges, clipped to [1e-8, 1 - 1e-8].
    lower = jnp.clip((responses / n_resps) - rnd_unit_scaled / 2., 1e-8, 1 - 1e-8)
    upper = jnp.clip((responses / n_resps) + rnd_unit_scaled / 2., 1e-8, 1 - 1e-8)

    prob_resps = tfp.math.betainc(a, b, upper) - tfp.math.betainc(a, b, lower)
    # The tiny floor keeps every category's probability strictly positive.
    prob_resps = spread_vec(prob_resps, step) + 1e-30
    return prob_resps / jnp.sum(prob_resps)


# Vectorize `f` over trials: mu and k are mapped along axis 0 (one value per
# trial, so both must have rank >= 1), while the response grid is shared (None).
lbeta_cat_probs = vmap(f, in_axes=(0, 0, None))

# Response grids as float arrays: 0..20 (5% steps once divided by 20) and
# 0..10 (10% steps once divided by 10).
responses_5 = jnp.arange(21.0)
responses_10 = jnp.arange(11.0)

In [99]:
## models

def ptn_simplecond_mlm_trial_level_disc(data, y=None):
    """Hierarchical PT+N model with a single population-level noise parameter ``k``.

    Parameterized in terms of d and d' for comparison of model fit. Each subject
    has a d, and on conjunction/disjunction trials (``is_conjdisj``) a positive
    logit-scale increment is added, giving d'. Responses are a mixture (weights
    ``rnd_policy``) of a uniform guess over 21 categories and discretized beta
    distributions on 5%- and 10%-step grids. See
    ``ptn_simplecond_mlm_trial_level_disc_noise`` for the per-subject-``k`` variant.

    Args:
        data: mapping with integer arrays "trial", "subj" and "cond", one entry per
            observation. Subject and condition codes are assumed to be contiguous
            and 0-based, because they index parameter vectors below.
        y: observed response categories (0..20), or None to sample.

    Returns:
        The observed or sampled ``yhat`` array.
    """
    # Data processing
    trial, subj, cond = data["trial"], data["subj"], data["cond"]
    n_Ps, n_conds = np.unique(subj).shape[0], np.unique(cond).shape[0]

    # setup "design matrix" (of sorts)
    X_num = jnp.stack([num_vecs[i] for i in trial])
    X_denom = jnp.stack([denom_vecs[i] for i in trial])
    conjdisj = jnp.array([is_conjdisj(i) for i in trial])

    # d and d' are sigmoid(.) / d_divisor, i.e. bounded to (0, 1/2).
    # NOTE(review): an older comment claimed [0, 1/3], and "d_subj" was reported
    # with /3 while the likelihood used /2. The likelihood, "d_subj" and
    # "d_prime_subj" now share one divisor so the reported d is the fitted d.
    # Confirm the intended bound.
    d_divisor = 2.0

    # population level parameters/priors
    k = numpyro.sample("k", dist.HalfCauchy(20))  # noise parameter (scalar)
    rnd_policy = numpyro.sample("rnd_policy", dist.Dirichlet(jnp.ones(3)))

    d_base_pop = numpyro.sample("d_base_pop", dist.Normal(-1.0, 1.0))
    d_delta_pop = numpyro.sample("d_delta_pop", dist.Normal(0, .5))  # bias toward lower values for non conj/disj trials
    d_base_sd = numpyro.sample("d_base_sd", dist.LogNormal(-1., 1.))  # was halfcauchy(1)
    d_delta_sd = numpyro.sample("d_delta_sd", dist.LogNormal(-1., 1.))  # approx uniform altogether we hope

    # subject-level parameters/priors (non-centered: standard-normal offsets)
    with numpyro.plate("subj", n_Ps):
        d_bases = numpyro.sample("d_base_r", dist.Normal(0, 1))
        d_deltas = numpyro.sample("d_delta_r", dist.Normal(0, 1))

    # subject/condition-level parameters/priors
    with numpyro.plate("cond", n_Ps * n_conds):
        thetas = numpyro.sample("theta", dist.Dirichlet(jnp.ones(4)))

    d_lin = (d_base_pop
             + d_bases[subj] * d_base_sd
             + jnp.exp(d_delta_pop + d_delta_sd * d_deltas[subj]) * conjdisj)  # exp() keeps the d' increment positive
    d = sigmoid(d_lin) / d_divisor

    numpyro.deterministic("d_subj", sigmoid(d_base_pop + d_bases * d_base_sd) / d_divisor)
    numpyro.deterministic(
        "d_prime_subj",
        sigmoid(d_base_pop
                + d_bases * d_base_sd
                + jnp.exp(d_delta_pop + d_deltas * d_delta_sd)) / d_divisor,
    )

    theta_ind = (subj * n_conds) + cond
    theta = thetas[theta_ind, :]

    p_bs = prob_judge_BS_d(theta, X_num, X_denom, d)
    # lbeta_cat_probs vmaps k along axis 0, which fails for a rank-0 array.
    # Give every trial its own (identical) copy of the shared k.
    k_trial = jnp.broadcast_to(k, jnp.shape(p_bs))

    # Mixture over response policies: uniform guess over 21 categories, or a
    # discretized beta on the 5%- or 10%-step grid.
    resp_probs = (
        1. / 21. * rnd_policy[0]
        + lbeta_cat_probs(p_bs, k_trial, responses_5) * rnd_policy[1]
        + lbeta_cat_probs(p_bs, k_trial, responses_10) * rnd_policy[2]
    )

    # Likelihood
    with numpyro.plate("data", len(trial)):
        yhat = numpyro.sample("yhat", dist.Categorical(probs=resp_probs), obs=y)  # rounded
    return yhat
    
    
def ptn_simplecond_mlm_trial_level_disc_noise(data, y=None):
    """Hierarchical PT+N model with a separate noise parameter ``k`` per subject.

    Same structure as ``ptn_simplecond_mlm_trial_level_disc``, except that ``k``
    is drawn inside the subject plate (one HalfCauchy(20) draw per subject) and
    indexed per trial. Each subject has a d, and on conjunction/disjunction
    trials (``is_conjdisj``) a positive logit-scale increment gives d'.
    NOTE: an earlier comment says the saved fit of this model is named "mymodel".

    Args:
        data: mapping with integer arrays "trial", "subj" and "cond", one entry per
            observation. Subject and condition codes are assumed to be contiguous
            and 0-based, because they index parameter vectors below.
        y: observed response categories (0..20), or None to sample.

    Returns:
        The observed or sampled ``yhat`` array.
    """
    # Data processing
    trial, subj, cond = data["trial"], data["subj"], data["cond"]
    n_Ps, n_conds = np.unique(subj).shape[0], np.unique(cond).shape[0]

    # setup "design matrix" (of sorts)
    X_num = jnp.stack([num_vecs[i] for i in trial])
    X_denom = jnp.stack([denom_vecs[i] for i in trial])
    conjdisj = jnp.array([is_conjdisj(i) for i in trial])

    # d and d' are sigmoid(.) / d_divisor, i.e. bounded to (0, 1/2).
    # NOTE(review): an older comment claimed [0, 1/3], and "d_subj" was reported
    # with /3 while the likelihood used /2. The likelihood, "d_subj" and
    # "d_prime_subj" now share one divisor so the reported d is the fitted d.
    # Confirm the intended bound.
    d_divisor = 2.0

    # population level parameters/priors
    rnd_policy = numpyro.sample("rnd_policy", dist.Dirichlet(jnp.ones(3)))

    d_base_pop = numpyro.sample("d_base_pop", dist.Normal(-1.0, 1.0))
    d_delta_pop = numpyro.sample("d_delta_pop", dist.Normal(0, .5))  # bias toward lower values for non conj/disj trials
    d_base_sd = numpyro.sample("d_base_sd", dist.LogNormal(-1., 1.))  # was halfcauchy(1)
    d_delta_sd = numpyro.sample("d_delta_sd", dist.LogNormal(-1., 1.))  # approx uniform altogether we hope

    # subject-level parameters/priors (d offsets are non-centered)
    with numpyro.plate("subj", n_Ps):
        d_bases = numpyro.sample("d_base_r", dist.Normal(0, 1))
        d_deltas = numpyro.sample("d_delta_r", dist.Normal(0, 1))
        ks = numpyro.sample("k", dist.HalfCauchy(20))  # noise parameter, one per subject

    # subject/condition-level parameters/priors
    with numpyro.plate("cond", n_Ps * n_conds):
        thetas = numpyro.sample("theta", dist.Dirichlet(jnp.ones(4)))

    d_lin = (d_base_pop
             + d_bases[subj] * d_base_sd
             + jnp.exp(d_delta_pop + d_delta_sd * d_deltas[subj]) * conjdisj)  # exp() keeps the d' increment positive
    d = sigmoid(d_lin) / d_divisor

    numpyro.deterministic("d_subj", sigmoid(d_base_pop + d_bases * d_base_sd) / d_divisor)
    numpyro.deterministic(
        "d_prime_subj",
        sigmoid(d_base_pop
                + d_bases * d_base_sd
                + jnp.exp(d_delta_pop + d_deltas * d_delta_sd)) / d_divisor,
    )

    theta_ind = (subj * n_conds) + cond
    theta = thetas[theta_ind, :]

    p_bs = prob_judge_BS_d(theta, X_num, X_denom, d)
    k = ks[subj]  # per-trial noise, as lbeta_cat_probs maps k along axis 0

    # Mixture over response policies: uniform guess over 21 categories, or a
    # discretized beta on the 5%- or 10%-step grid.
    resp_probs = (
        1. / 21. * rnd_policy[0]
        + lbeta_cat_probs(p_bs, k, responses_5) * rnd_policy[1]
        + lbeta_cat_probs(p_bs, k, responses_10) * rnd_policy[2]
    )

    # Likelihood
    with numpyro.plate("data", len(trial)):
        yhat = numpyro.sample("yhat", dist.Categorical(probs=resp_probs), obs=y)  # rounded
    return yhat

    

In [120]:
def bs_dist(p, beta, N):
    """Map a probability ``p`` to (p * N + beta) / (N + 2 * beta), as two terms.

    The result is computed as ``p*N/denom + beta/denom`` with denom = N + 2*beta.
    """
    denom = N + 2 * beta
    scaled_p = (p * N) / denom
    offset = beta / denom
    return scaled_p + offset


def bs_dist_inv(x, beta, N):
    """Invert ``bs_dist``: recover p from x = p*N/(N+2*beta) + beta/(N+2*beta)."""
    scale = N + 2. * beta
    centered = x - beta / scale
    return centered * scale / N


def bs_dist_cdf(N, beta, a, b, x):
    """CDF of the Bayesian-sampler estimate at ``x`` (untransformed probability scale).

    ``x`` is mapped back through ``bs_dist_inv``. Where that latent value is at or
    beyond 0 or 1, the CDF is exactly 0 or 1. Otherwise it is the regularized
    incomplete beta I_latent(a, b).
    """
    latent = bs_dist_inv(x, beta, N)
    out_of_support = (latent <= 0.) | (latent >= 1.)
    # jnp.where evaluates both branches. Clipping keeps betainc's argument
    # strictly inside (0, 1), presumably so the discarded branch cannot yield
    # NaN/inf values or gradients.
    interior = tfp.math.betainc(a, b, jnp.clip(latent, 1e-8, 1 - 1e-8))
    return jnp.where(out_of_support, jnp.clip(latent, 0., 1.), interior)


def f_bs(mu, N, beta, responses):
    """Discretized Bayesian-sampler response probabilities for a single trial.

    Uses Beta(mu * N, (1 - mu) * N) pushed through ``bs_dist``. Each grid point
    ``responses / n_steps`` gets the mass within half a grid step of it. The
    vector is then spread onto 21 categories (0..20) and renormalized.
    """
    n_steps = responses.shape[0] - 1
    half_bin = (1 / n_steps) / 2.
    centers = responses / n_steps
    lo = jnp.clip(centers - half_bin, 1e-8, 1 - 1e-8)
    hi = jnp.clip(centers + half_bin, 1e-8, 1 - 1e-8)

    alpha = mu * N
    beta_shape = (1. - mu) * N
    mass = bs_dist_cdf(N, beta, alpha, beta_shape, hi) - bs_dist_cdf(N, beta, alpha, beta_shape, lo)

    # Grid point i lands on category i * (20 / n_steps); the tiny floor keeps
    # every category strictly positive.
    mass = spread_vec(mass, int(20 / n_steps)) + 1e-30
    return mass / jnp.sum(mass)


# Vectorize f_bs over trials: mu, N and beta are mapped along axis 0; the response grid is shared.
bs_cat_probs = vmap(f_bs, in_axes=(0, 0, 0, None))


def bs_complex_mlm_trial_level(data, y=None):
    """Hierarchical Bayesian-sampler model with discretized responses.

    Each subject has a ``beta`` (sigmoid-constrained to (0, 10)) and two
    concentration terms. N' = 1 + exp(.) is used on conjunction/disjunction
    trials, and N = N' + exp(.), so N > N', is used on all other trials.
    Mean judgments come from ``calc_prob``. Responses are a mixture (weights
    ``rnd_policy``) of a uniform guess over 21 categories and the
    Bayesian-sampler distribution discretized to 5% and 10% steps.

    Args:
        data: mapping with integer arrays "trial", "subj" and "cond", one entry per
            observation. Subject and condition codes are assumed to be contiguous
            and 0-based, because they index parameter vectors below.
        y: observed response categories (0..20), or None to sample.

    Returns:
        The observed or sampled ``yhat`` array.
    """
    # Data processing
    trial, subj, cond = data["trial"], data["subj"], data["cond"]
    n_Ps, n_conds = np.unique(subj).shape[0], np.unique(cond).shape[0]

    # setup "design matrix" (of sorts)
    X_num = jnp.stack([num_vecs[i] for i in trial])
    X_denom = jnp.stack([denom_vecs[i] for i in trial])
    not_conjdisj = abs(1 - jnp.array([is_conjdisj(i) for i in trial]))

    # population level parameters/priors
    beta_pop = numpyro.sample("beta_pop", dist.Normal(-2.75, .9))  # skewed after sigmoid
    beta_sd = numpyro.sample("beta_sd", dist.HalfCauchy(1))

    N_prime_pop = numpyro.sample("N_prime_pop", dist.Normal(0, 2))  # mildly informative
    N_delta_pop = numpyro.sample("N_delta_pop", dist.Normal(0, 2))
    N_prime_sd = numpyro.sample("N_prime_sd", dist.HalfCauchy(2))
    N_delta_sd = numpyro.sample("N_delta_sd", dist.HalfCauchy(2))

    rnd_policy = numpyro.sample("rnd_policy", dist.Dirichlet(jnp.ones(3)))

    # subject-level offsets, already non-centered: standard-normal draws scaled
    # by the population sd
    with numpyro.plate("subj", n_Ps):
        betas = numpyro.sample("beta_r", dist.Normal(0, 1)) * beta_sd
        N_deltas = numpyro.sample("N_delta_r", dist.Normal(0, 1)) * N_delta_sd
        N_primes = numpyro.sample("N_prime_r", dist.Normal(0, 1)) * N_prime_sd

    # subject/condition-level parameters/priors
    with numpyro.plate("cond", n_Ps * n_conds):
        thetas = numpyro.sample("theta", dist.Dirichlet(jnp.ones(4)))

    # Per-subject parameters are computed once and used by both the likelihood
    # and the reported deterministic sites, so the two cannot drift apart.
    # Previously "beta_subj" reported exp(.) while the likelihood used
    # sigmoid(.) * 10. The "N_subj"/"N_prime_subj" values were also swapped
    # relative to the N_prime_* parameters they are built from (and to the
    # d/d' naming of the PT+N models).
    beta_by_subj = sigmoid(beta_pop + betas) * 10  # constrains beta to (0, 10)
    # exp() keeps both terms positive; the paper also required N be at least 1
    N_prime_by_subj = 1 + jnp.exp(N_prime_pop + N_primes)  # N': conj/disj trials
    N_delta_by_subj = jnp.exp(N_delta_pop + N_deltas)
    N_by_subj = N_prime_by_subj + N_delta_by_subj  # N > N': all other trials

    numpyro.deterministic("beta_subj", beta_by_subj)
    numpyro.deterministic("N_subj", N_by_subj)
    numpyro.deterministic("N_prime_subj", N_prime_by_subj)

    beta = beta_by_subj[subj]
    N = N_prime_by_subj[subj] + N_delta_by_subj[subj] * not_conjdisj

    theta_ind = (subj * n_conds) + cond
    theta = thetas[theta_ind, :]

    pi = calc_prob(theta, X_num, X_denom)

    # Likelihood
    with numpyro.plate("data", len(trial)):
        # Mixture over response policies: uniform guess over 21 categories, or
        # the Bayesian-sampler distribution on the 5%- or 10%-step grid.
        resp_probs = (
            1. / 21. * rnd_policy[0]
            + bs_cat_probs(pi, N, beta, responses_5) * rnd_policy[1]
            + bs_cat_probs(pi, N, beta, responses_10) * rnd_policy[2]
        )
        yhat = numpyro.sample("yhat", dist.Categorical(probs=resp_probs), obs=y)  # rounded

    return yhat

In [121]:
# NUTS on the Bayesian-sampler model, initialized at prior medians estimated
# from 30 prior draws.
bs_init = init_to_median(num_samples=30)
kernel_bs = NUTS(
    bs_complex_mlm_trial_level,
    target_accept_prob=0.80,
    init_strategy=bs_init,
)

mcmc_bs = MCMC(kernel_bs, num_warmup=500, num_samples=500, num_chains=1)
mcmc_bs.run(random.PRNGKey(0), X_data, y_data)

warmup:   2%|▏         | 15/1000 [2:18:33<151:38:24, 554.22s/it, 1023 steps of size 4.05e-05. acc. prob=0.53]


KeyboardInterrupt: 

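It samples, but far too slowly to be usable yet. Warmup was interrupted after 15 of 1000 iterations (about 554 s per iteration), with NUTS taking 1023 steps per iteration at a step size of about 4e-05.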

### Neural transport (NeuTra) reparameterization

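A test follows. It "works" (SVI runs to completion and NUTS starts), but it is not yet clear that it is any improvement. Warmup was again interrupted, after 9 of 1000 iterations (about 512 s per iteration, still 1023 steps per iteration).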

In [118]:
## neural transport test

from jax import lax
from numpyro.infer import ELBO, MCMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoBNAFNormal
from numpyro.infer.reparam import NeuTraReparam
import numpyro.optim as optim

# Fit a block neural autoregressive flow (BNAF) guide with SVI. The next cell
# reuses its learned transform to reparameterize the model for NUTS.
guide = AutoBNAFNormal(bs_complex_mlm_trial_level)
svi_optimizer = optim.Adam(0.003)
svi = SVI(bs_complex_mlm_trial_level, guide, svi_optimizer, Trace_ELBO())
svi_result = svi.run(random.PRNGKey(2), 5_000, X_data, y_data)

100%|██████████| 5000/5000 [1:03:21<00:00,  1.32it/s, init loss: 5706.3414, avg. loss [4751-5000]: 4616.9065]


In [119]:
# Run NUTS in the flow's latent space, then map the draws back to model parameters.
neutra = NeuTraReparam(guide, svi_result.params)
neutra_model = neutra.reparam(bs_complex_mlm_trial_level)
kernel = NUTS(neutra_model, target_accept_prob=0.80)

mcmc = MCMC(
    kernel,
    num_warmup=500,
    num_samples=500,
    num_chains=1,
    progress_bar=True,
)
mcmc.run(random.PRNGKey(2), X_data, y_data)

latent_draws = mcmc.get_samples(group_by_chain=True)
zs = latent_draws["auto_shared_latent"]
samples = neutra.transform_sample(zs)

warmup:   1%|          | 9/1000 [1:16:52<141:04:39, 512.49s/it, 1023 steps of size 1.71e-03. acc. prob=0.49]


KeyboardInterrupt: 