In [6]:
import pandas as pd
import numpy as np
from codebase.file_utils import (
    save_obj,
    load_obj,
)
from codebase.plot import plot_density, plot_line, get_post_df
import altair as alt
from codebase.classes_data import Data
from codebase.ibis import exp_and_normalise
from run_ibis_lvm import run_ibis_lvm
from run_mcmc import run_mcmc

from codebase.file_utils import (
    save_obj,
    load_obj,
    make_folder,
    path_backslash
)
from pdb import set_trace
from copy import copy


alt.data_transformers.disable_max_rows()


DataTransformerRegistry.enable('default')

In [2]:

task_handle = 'mcmc_smc2'
gen_model = 0
existing_directory = None

if existing_directory is None:
    log_dir = make_folder(task_handle)  
    print("\n\nCreating new directory: %s" % log_dir)

else:
    log_dir = existing_directory
    log_dir = path_backslash(log_dir)
    print("\n\nReading from existing directory: %s" % log_dir)




Creating new directory: ./log/20210125_204443_mcmc_smc2/


##

In [3]:

# generate data
exp_data = Data(
    name = task_handle, 
    model_num = 1,  
    size = 100,
    random_seed = 2
    )

exp_data.generate()
save_obj(exp_data, 'complete_data', log_dir)


## Load or Run MCMC for the first 50 steps

In [5]:
num_warmup = 2000

param_names = ['beta', 'alpha']
latent_names = ['z', 'y_latent']
ps = run_mcmc(
    stan_data=exp_data.get_stan_data_upto_t(50),
    nsim_mcmc=10000,
    num_warmup = num_warmup,
    model_num=7,
    bundle_size=1000,
    gen_model=gen_model,
    param_names=param_names,
    latent_names=latent_names,
    log_dir=log_dir
)

save_obj(ps, 'mcmc_post_samples', log_dir)


In [45]:

beta_particles = ps['beta'][::10]
alpha_particles = ps['alpha'][::10]


In [46]:
from codebase.classes_ibis_lvm import ParticlesLVM
from codebase.ibis import model_phonebook, essl
from tqdm import tqdm
from scipy.special import logsumexp

In [47]:
num_warmup = 200
model_num = 7
gen_model = False
size = 100
bundle_size = 100


param_names = model_phonebook(model_num)["param_names"]
latent_names = model_phonebook(model_num)["latent_names"]
jitter_corrs = dict()
for p in param_names:
    jitter_corrs[p] = np.zeros(exp_data.size)
particles = ParticlesLVM(
    name="ibis_lvm",
    model_num=model_num,
    size=size,
    bundle_size=bundle_size,
    param_names=param_names,
    latent_names=latent_names,
    latent_model_num=1,
)
particles.set_log_dir(log_dir)
if gen_model:
    particles.compile_prior_model()
    particles.compile_model()
else:
    particles.load_prior_model()
    particles.load_model()

log_lklhds = np.empty(exp_data.size)
degeneracy_limit = 0.5


In [48]:
particles.sample_prior_particles(exp_data.get_stan_data())  # sample prior particles
particles.particles['beta'] = beta_particles
particles.particles['alpha'] = alpha_particles


In [49]:

particles.reset_weights()  # set weights to 0
particles.initialize_bundles(exp_data.get_stan_data())
particles.initialize_latent_var_given_theta(exp_data.get_stan_data())
particles.initialize_counter(exp_data.get_stan_data())

for t in tqdm(range(51, exp_data.size)):
    particles.sample_latent_bundle_at_t(t, exp_data.get_stan_data_at_t(t))
    particles.get_theta_incremental_weights_at_t(t, exp_data.get_stan_data_at_t(t))
    log_lklhds[t] = particles.get_loglikelihood_estimate()

    particles.update_weights()

    if (essl(particles.weights) < degeneracy_limit * particles.size) and (
        t + 1
    ) < exp_data.size:
        particles.add_ess(t)
        particles.resample_particles_bundles()
        particles.jitter_bundles_and_pick_one(exp_data.get_stan_data_upto_t(t + 1))

        ## add corr of param before jitter
        pre_jitter = dict()
        for p in param_names:
            pre_jitter[p] = particles.particles[p].flatten()
        ####

        particles.jitter(t + 1, exp_data.get_stan_data_upto_t(t + 1))

        ## add corr of param
        for p in param_names:
            jitter_corrs[p][t] = np.corrcoef(
                pre_jitter[p], particles.particles[p].flatten()
            )[0, 1]
        ####

        particles.reset_weights()
    else:
        pass

    save_obj(t, "t", log_dir)
    save_obj(particles, "particles", log_dir)
    save_obj(jitter_corrs, "jitter_corrs", log_dir)
    save_obj(log_lklhds, "log_lklhds", log_dir)

print("\n\n")
marg_lklhd = np.exp(logsumexp(log_lklhds))
print("Marginal Likelihood %.5f" % marg_lklhd)
save_obj(marg_lklhd, "marg_lklhd", log_dir)

output = dict()
output["particles"] = particles
output["log_lklhds"] = log_lklhds
output["marg_lklhd"] = marg_lklhd
output["jitter_corrs"] = jitter_corrs


100%|██████████| 49/49 [11:50<00:00, 14.50s/it]




Marginal Likelihood 103.87861





In [56]:
particles = load_obj('particles', log_dir)

In [57]:
particles.resample_particles()
ps = particles.particles.copy()

## Post process loadings for sign flips

In [61]:
nsim = ps['beta'].shape[0]
nrows = ps['beta'].shape[1]
for n in range(nsim):
    for i in range(nrows):
        sign = np.sign(ps['beta'][n,0])
        ps['beta'][n] = sign * ps['beta'][n,]

## Plot MCMC samples

In [62]:
param = 'beta'
df = get_post_df(ps[param])
df_quant = df.groupby(['row', 'col'])[['value']].quantile(0.025).reset_index()
df_quant.rename({'value':'q1'}, axis=1, inplace=True)
df_quant2 = df.groupby(['row', 'col'])[['value']].quantile(0.975).reset_index()
df_quant2.rename({'value':'q2'}, axis=1, inplace=True)

df = df_quant.merge(df_quant2, on=['row', 'col'])

# simple quantile chart
df['source'] = 'smc2'
c1 = alt.Chart(df).mark_bar(opacity=0.6).encode(
        alt.X('q1', title=None),
        alt.X2('q2', title=None),
        alt.Row('row'),
        alt.Column('col'),
        alt.Color('source')
)
c1

In [63]:
df['index'] = 'r_' + df.row.astype(str)+'.c_'+df.col.astype(str)
df = df.loc[:,['index', 'q1', 'q2', 'source']]

dd = pd.DataFrame(exp_data.raw_data['beta'], columns=['data'])
dd['col'] = 0
dd['row'] = np.arange(6)
dd['index'] = 'r_' + dd.row.astype(str)+'.c_'+dd.col.astype(str)
dd = dd.loc[:,['index', 'data']]
plot_data = df.merge(dd, on=['index'])

In [64]:
c1 = alt.Chart(plot_data).mark_bar(opacity=0.6).encode(
    alt.X('q1', title=None,  scale=alt.Scale(domain=[-2,2])),
    alt.X2('q2', title=None),
    alt.Color('source'),    
)
    

c2 = alt.Chart(plot_data).mark_point(opacity=1, color='red').encode(
        alt.X('data', title=None),
)
(c1+c2).facet(
       'index',
    columns=1
    )
