In [1]:
import sys
print('Notebook is running:', sys.executable)

# Confirm which Python version the kernel is using
from platform import python_version

print('The current Python version is', python_version())

# If conda is installed, the installed packages can also be listed with:
#!conda list

import hddm, IPython, kabuki, pymc
import numpy as np
import pandas as pd
import seaborn as sns

# Report the version of each key package, in a fixed order.
# This notebook was last run with HDDM 0.8.0, Kabuki 0.6.3 and PyMC 2.3.8.
# Warning: `IPython.parallel` has been deprecated since IPython 4.0.
_version_report = [
    ('The current HDDM version is', hddm),
    ('The current Kabuki version is', kabuki),
    ('The current PyMC version is', pymc),
    ('The current IPython version is', IPython),
    ('The current Numpy version is', np),
    ('The current Pandas version is', pd),
    ('The current seaborn version is', sns),
]
for _label, _module in _version_report:
    print(_label, _module.__version__)

Notebook is running: /opt/conda/bin/python
The current Python version is 3.7.6
The current HDDM version is 0.8.0
The current Kabuki version is 0.6.3
The current PyMC version is 2.3.8
The current IPython version is 7.15.0
The current Numpy version is 1.19.4
The current Pandas version is 1.0.5
The current seaborn version is 0.11.1




In [2]:
# Preparation
import os, hddm, time, csv
import glob
import datetime
from datetime import date

import pymc as pm
import hddm
import kabuki

import arviz as az
import numpy as np
import pandas as pd
import feather
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns
from patsy import dmatrix

from p_tqdm import p_map
from functools import partial

# Set the default colour cycle for plotted series.
# White ('w'), present in the earlier 'bgrcmykw' cycle, is deliberately left out:
# on matplotlib's default white axes background an eighth series drawn in white
# would be invisible.
from cycler import cycler
plt.rcParams['axes.prop_cycle'] = cycler(color='bgrcmyk')

In [3]:
# NOTE: I modified ("hacked") `post_pred_gen`;
# for more details see: https://groups.google.com/g/hddm-users/c/Is6AM7eN0fo
from post_pred_gen_redifined import _parents_to_random_posterior_sample
from post_pred_gen_redifined import _post_pred_generate
from post_pred_gen_redifined import post_pred_gen

from pointwise_loglik_gen import _pointwise_like_generate
from pointwise_loglik_gen import pointwise_like_gen

# import self-defined functions
from SimData import SimData
from run_models import run_m1, run_m2, run_m4, run_m5, run_m7

# Candidate models, in a fixed order: model_func[i] is fitted and stored under
# the key m_keys[i]. df_keys lists the simulated datasets, numbered the same way
# as the models, so both key lists are generated from one set of model numbers.
model_func = [run_m1, run_m2, run_m4, run_m5, run_m7]

m_keys = ["ms%d" % model_no for model_no in (1, 2, 4, 5, 7)]

df_keys = ["sim_df%d" % model_no for model_no in (1, 2, 4, 5, 7)]


In [4]:
def _resolve_setting(name, value):
    """Return ``value``, or the notebook-level global ``name`` when ``value`` is None.

    Keeps ``model_recov`` backward compatible with callers that rely on the
    global ``samples`` / ``burn`` / ``chains`` set in a configuration cell.

    Raises:
        NameError: if ``value`` is None and no global ``name`` exists.
    """
    if value is not None:
        return value
    try:
        return globals()[name]
    except KeyError:
        raise NameError(
            "%r was neither passed to model_recov() nor defined globally" % name
        ) from None


def _observed_to_xr(model):
    """Return the model's observed ``rt``/``response`` columns as an xarray Dataset.

    The row index is renamed to ``trial_idx``. NOTE(review): assumes
    ``model.data`` has a single-level index -- confirm.
    """
    observed = model.data.copy()
    observed.index.names = ['trial_idx']
    return xr.Dataset.from_dataframe(observed[['rt', 'response']])


def _posterior_to_xr(chain_models):
    """Stack the traces of every chain into one Dataset indexed by (chain, draw)."""
    traces = []
    for chain_idx, chain_model in enumerate(chain_models):
        trace = chain_model.get_traces()
        trace['chain'] = chain_idx
        trace['draw'] = np.arange(len(trace), dtype=int)
        traces.append(trace)
    stacked = pd.concat(traces).set_index(["chain", "draw"])
    return xr.Dataset.from_dataframe(stacked)


def _stack_chains_to_xr(per_chain):
    """Concatenate one DataFrame per chain under an outer ``chain`` level; convert to xarray.

    The level directly below ``chain`` (the outermost index level of each
    per-chain frame) is dropped. NOTE(review): presumably this is the node
    label produced by post_pred_gen / pointwise_like_gen -- confirm.
    """
    stacked = pd.concat(per_chain, names=['chain'],
                        keys=list(range(len(per_chain))))
    stacked = stacked.reset_index(level=1, drop=True)
    return xr.Dataset.from_dataframe(stacked)


def model_recov(data=None, m_keys=None, model_func=None,
                samples=None, burn=None, chains=None):
    """Fit every candidate model to one dataset and package the fits for model comparison.

    For each (key, fitting function) pair, ``chains`` fits are run in parallel
    with ``p_map``. An ``az.InferenceData`` is then built from the posterior
    traces, the observed data, the posterior-predictive samples and the
    point-wise log-likelihoods.

    Parameters
    ----------
    data : pandas.DataFrame
        Input data (simulated or real), passed to each fitting function as ``df``.
    m_keys : sequence of str
        Identifier of each model; used as the keys of both returned dicts.
    model_func : sequence of callables
        Fitting function for each entry of ``m_keys`` (same order and length).
        Each is called as ``f(chain_index, df=data, samples=..., burn=..., save_name=...)``
        once per chain.
    samples, burn, chains : int, optional
        MCMC settings. When omitted, the notebook-level globals of the same
        names are used (the previous behaviour); passing them is preferred.

    Returns
    -------
    models : dict
        ``m_key -> list`` of fitted model objects, one per chain.
    InfData : dict
        ``m_key -> az.InferenceData`` with posterior, observed_data,
        posterior_predictive and log_likelihood groups.

    Raises
    ------
    TypeError
        If ``m_keys`` or ``model_func`` is None.
    ValueError
        If ``m_keys`` and ``model_func`` differ in length, or ``data`` is None
        while there is at least one model to fit.
    NameError
        If an MCMC setting is neither passed nor defined globally.
    """
    if m_keys is None or model_func is None:
        raise TypeError("model_recov() needs both m_keys and model_func")
    m_keys = list(m_keys)
    model_func = list(model_func)
    if len(m_keys) != len(model_func):
        # Fail before hours of sampling, not with an IndexError (or a silently
        # ignored extra function) partway through the loop.
        raise ValueError(
            "m_keys (%d) and model_func (%d) must have the same length"
            % (len(m_keys), len(model_func)))

    InfData = {}
    models = {}
    if not m_keys:
        return models, InfData
    if data is None:
        raise ValueError("model_recov() needs data to fit")

    samples = _resolve_setting("samples", samples)
    burn = _resolve_setting("burn", burn)
    chains = _resolve_setting("chains", chains)

    for m_key, fit_func in zip(m_keys, model_func):
        ### Run models: one fit per chain, in parallel
        save_name = "./tmp/" + m_key + "_tmp"
        print("start model fitting for ", m_key)
        ms_tmp = p_map(partial(fit_func,
                               df=data,
                               samples=samples,
                               burn=burn,
                               save_name=save_name),
                       range(chains))

        ### Observations and posteriors
        xdata_observed = _observed_to_xr(ms_tmp[0])
        xdata_posterior = _posterior_to_xr(ms_tmp)

        ### PPC
        print("start PPC for ", m_key)
        start_time = time.time()
        post_pred_per_chain = p_map(post_pred_gen, ms_tmp)
        print("Running PPC for ", m_key, " costs %f seconds" % (time.time() - start_time))
        xdata_post_pred = _stack_chains_to_xr(post_pred_per_chain)

        ### Point-wise log likelihood
        print("start calculating loglik for ", m_key)
        start_time = time.time()
        loglik_per_chain = p_map(pointwise_like_gen, ms_tmp)
        print("Generating loglik costs %f seconds" % (time.time() - start_time))
        xdata_loglik = _stack_chains_to_xr(loglik_per_chain)

        ### convert to InfData
        InfData[m_key] = az.InferenceData(posterior=xdata_posterior,
                                          observed_data=xdata_observed,
                                          posterior_predictive=xdata_post_pred,
                                          log_likelihood=xdata_loglik)
        models[m_key] = ms_tmp
    return models, InfData

In [5]:
# MCMC settings read by model_recov(): `samples` and `burn` are forwarded to
# each run_m* fitting function, and `chains` fits are run in parallel via p_map.
samples, burn, chains = 2000, 500, 4

In [6]:
%%time

conf_mat_dic2 = pd.DataFrame(0, index=m_keys, columns=df_keys)
conf_mat_loo2 = pd.DataFrame(0, index=m_keys, columns=df_keys)
conf_mat_waic2 = pd.DataFrame(0, index=m_keys, columns=df_keys)

for sim in range (3):   
    for df_key in df_keys:
        ### simulate data
        data = SimData(df_key)

        ### fit the sim data
        print("Start model recovery for ", df_key)
        models, InfData = model_recov(data=data, m_keys=m_keys, model_func=model_func)

        ### compare models
        tmp_loo_comp = az.compare(InfData, ic="loo")
        tmp_loo_comp = tmp_loo_comp.reset_index()
        tmp_waic_comp = az.compare(InfData, ic="waic")
        tmp_waic_comp = tmp_waic_comp.reset_index()
        
        tmp_dic = []
        indx_name = []

        for m_key, model in models.items():
            m_tmp = kabuki.utils.concat_models(model)
            tmp_dic.append(m_tmp.dic)
            indx_name.append(m_key)
            
        tmp_dic_comp = pd.DataFrame(tmp_dic, index=indx_name, columns=['dic'])
        tmp_dic_comp = tmp_dic_comp.sort_values(by=['dic'])
        tmp_dic_comp = tmp_dic_comp.reset_index()
        #conf_mat_dic.rename(columns={'index':'rank'}, inplace=True)

        ### record the best models
        conf_mat_dic2.loc[tmp_dic_comp.loc[0, 'index'], df_key] += 1
        conf_mat_loo2.loc[tmp_loo_comp.loc[0, 'index'], df_key] += 1
        conf_mat_waic2.loc[tmp_waic_comp.loc[0, 'index'], df_key] += 1

        conf_mat_dic2.to_csv('conf_mat_dic2.csv')
        conf_mat_loo2.to_csv('conf_mat_loo2.csv')
        conf_mat_waic2.to_csv('conf_mat_waic2.csv')

Start model recovery for  sim_df1
start model fitting for  ms1


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self._setitem_with_indexer(indexer, value)


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

  tmp2 = (x - v) * (fx - fw)
  tmp2 = (x - v) * (fx - fw)
  tmp2 = (x - v) * (fx - fw)
  tmp2 = (x - v) * (fx - fw)


                   1%                  ] 20 of 2000 complete in 27.6 sec                  0%                  ] 2 of 2000 complete in 0.5 sec[                  0%                  ] 2 of 2000 complete in 1.2 sec[                  0%                  ] 4 of 2000 complete in 3.0 sec[                  0%                  ] 3 of 2000 complete in 2.2 sec[                  0%                  ] 2 of 2000 complete in 1.2 sec[                  0%                  ] 3 of 2000 complete in 2.8 sec[                  0%                  ] 5 of 2000 complete in 5.3 sec[                  0%                  ] 4 of 2000 complete in 3.9 sec[                  0%                  ] 4 of 2000 complete in 4.4 sec[                  0%                  ] 3 of 2000 complete in 3.3 sec[                  0%                  ] 6 of 2000 complete in 6.5 sec[                  0%                  ] 5 of 2000 complete in 5.5 sec[                  0%                  ] 5 of 2000 complete in 6.0 sec[                  

                   2%                  ] 41 of 2000 complete in 67.8 secc[                  1%                  ] 31 of 2000 complete in 47.2 sec[                  1%                  ] 34 of 2000 complete in 49.2 sec[                  1%                  ] 30 of 2000 complete in 47.2 sec[                  1%                  ] 31 of 2000 complete in 48.8 sec[                  1%                  ] 32 of 2000 complete in 49.4 sec[                  1%                  ] 35 of 2000 complete in 51.0 sec[                  1%                  ] 31 of 2000 complete in 49.0 sec[                  1%                  ] 32 of 2000 complete in 50.7 sec[                  1%                  ] 33 of 2000 complete in 51.2 sec[                  1%                  ] 36 of 2000 complete in 52.8 sec[                  1%                  ] 32 of 2000 complete in 50.7 sec[                  1%                  ] 34 of 2000 complete in 52.3 sec[                  1%                  ] 33 of 2000 complete in

 [-                 2%                  ] 58 of 2000 complete in 95.8 sec[-                 2%                  ] 59 of 2000 complete in 96.1 sec[-                 3%                  ] 60 of 2000 complete in 95.4 sec[-                 3%                  ] 63 of 2000 complete in 98.5 sec[-                 3%                  ] 60 of 2000 complete in 97.4 sec[-                 2%                  ] 59 of 2000 complete in 97.8 sec[-                 3%                  ] 61 of 2000 complete in 97.1 sec[-                 3%                  ] 64 of 2000 complete in 100.9 sec[-                 3%                  ] 61 of 2000 complete in 99.4 sec[-                 3%                  ] 60 of 2000 complete in 99.8 sec[-                 3%                  ] 62 of 2000 complete in 98.9 sec[-                 3%                  ] 62 of 2000 complete in 100.8 sec[-                 3%                  ] 63 of 2000 complete in 99.9 sec[-                 3%                  ] 61 of 2000 complete 

 [-                 4%                  ] 87 of 2000 complete in 138.8 sec[-                 4%                  ] 92 of 2000 complete in 137.8 sec[-                 4%                  ] 91 of 2000 complete in 137.1 sec[-                 4%                  ] 88 of 2000 complete in 140.0 sec[-                 4%                  ] 84 of 2000 complete in 138.4 sec[-                 4%                  ] 92 of 2000 complete in 138.4 sec[-                 4%                  ] 93 of 2000 complete in 139.6 sec[-                 4%                  ] 89 of 2000 complete in 141.5 sec[-                 4%                  ] 85 of 2000 complete in 140.0 sec[-                 4%                  ] 93 of 2000 complete in 139.4 sec[-                 4%                  ] 94 of 2000 complete in 140.7 sec[-                 4%                  ] 90 of 2000 complete in 142.9 sec[-                 4%                  ] 94 of 2000 complete in 140.8 sec[-                 4%                  ] 86 of 200

Process ForkPoolWorker-16:
Process ForkPoolWorker-6:
Process ForkPoolWorker-5:
Process ForkPoolWorker-8:


Halting at iteration Halting at iteration 

Process ForkPoolWorker-14:
Process ForkPoolWorker-7:


Halting at iteration 

Process ForkPoolWorker-15:


Halting at iteration 

Process ForkPoolWorker-9:


 

Traceback (most recent call last):
Traceback (most recent call last):


 

Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):


  

  File "/opt/conda/lib/python3.7/site-packages/multiprocess/process.py", line 297, in _bootstrap
    self.run()
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/process.py", line 297, in _bootstrap
    self.run()


110

  File "/opt/conda/lib/python3.7/site-packages/multiprocess/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/process.py", line 297, in _bootstrap
    self.run()
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/process.py", line 297, in _bootstrap
    self.run()
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/process.py", line 297, in _bootstrap
    self.run()
Traceback (most recent call last):


97 

Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/process.py", line 297, in _bootstrap
    self.run()


107

  File "/opt/conda/lib/python3.7/site-packages/multiprocess/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/process.py", line 297, in _bootstrap
    self.run()
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/pool.py", line 110, in worker
    task = get()
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/pool.py", line 110, in worker
    task = get()
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)


 of 

  File "/opt/conda/lib/python3.7/site-packages/multiprocess/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)


 

  File "/opt/conda/lib/python3.7/site-packages/multiprocess/queues.py", line 355, in get
    res = self._reader.recv_bytes()
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/queues.py", line 354, in get
    with self._rlock:
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/synchronize.py", line 102, in __enter__
    return self._semlock.__enter__()
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/queues.py", line 354, in get
    with self._rlock:
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/pool.py", line 110, in worker
    task = get()
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/pool.py", line 110, in worker
    task = get()
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/pool.py", line 110, in worker
    task = get()


  of  

  File "/opt/conda/lib/python3.7/site-packages/multiprocess/process.py", line 297, in _bootstrap
    self.run()


102

KeyboardInterrupt
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/queues.py", line 354, in get
    with self._rlock:
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/pool.py", line 110, in worker
    task = get()


 of 2000 


  File "/opt/conda/lib/python3.7/site-packages/multiprocess/synchronize.py", line 102, in __enter__
    return self._semlock.__enter__()
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
KeyboardInterrupt


 

KeyboardInterrupt
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/queues.py", line 354, in get
    with self._rlock:
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/queues.py", line 354, in get
    with self._rlock:
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/connection.py", line 219, in recv_bytes
    buf = self._recv_bytes(maxlength)


2000

KeyboardInterrupt
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/pool.py", line 110, in worker
    task = get()
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/connection.py", line 410, in _recv_bytes
    buf = self._recv(4)





  File "/opt/conda/lib/python3.7/site-packages/multiprocess/synchronize.py", line 102, in __enter__
    return self._semlock.__enter__()
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/connection.py", line 382, in _recv
    chunk = read(handle, remaining)
KeyboardInterrupt
KeyboardInterrupt
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/pool.py", line 110, in worker
    task = get()
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/queues.py", line 354, in get
    with self._rlock:
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/queues.py", line 354, in get
    with self._rlock:
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/synchronize.py", line 102, in __enter__
    return self._semlock.__enter__()
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/synchronize.py", line 102, in __enter__
    return self._semlock.__enter__()
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/synchronize.py", line 102, in __e

  of Could not generate output statistics forCould not generate output statistics for  2000z_subj_trans.5
 
tCould not generate output statistics for2000
 
Could not generate output statistics fort_subj.5 
a_subj.9Could not generate output statistics for
 Could not generate output statistics forz_subj.0Could not generate output statistics for
Could not generate output statistics for z_subj.4 
z_subj.7
Could not generate output statistics forCould not generate output statistics for a_subj.5  t_subj.12z_subj.3



Could not generate output statistics forCould not generate output statistics for  a_subj.7
Could not generate output statistics for Could not generate output statistics forz_subj.0 
a_subj.2Could not generate output statistics for
 Could not generate output statistics forCould not generate output statistics for  a_subj.7v_subj.5a_subj.12


Could not generate output statistics for Could not generate output statistics forCould not generate output statistics forv_subj.1 t_subj.3
Co

Process ForkPoolWorker-11:
Process ForkPoolWorker-12:




 
Could not generate output statistics forCould not generate output statistics for

Traceback (most recent call last):


Could not generate output statistics for

Traceback (most recent call last):


z_subj_trans.9 

  File "/opt/conda/lib/python3.7/site-packages/multiprocess/process.py", line 297, in _bootstrap
    self.run()


 

Process ForkPoolWorker-10:



 

  File "/opt/conda/lib/python3.7/site-packages/multiprocess/process.py", line 297, in _bootstrap
    self.run()


a_subj.5

  File "/opt/conda/lib/python3.7/site-packages/multiprocess/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
Process ForkPoolWorker-13:
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/pool.py", line 110, in worker
    task = get()


t_std
a_subj.10

Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)





  File "/opt/conda/lib/python3.7/site-packages/multiprocess/queues.py", line 354, in get
    with self._rlock:


Could not generate output statistics for


  File "/opt/conda/lib/python3.7/site-packages/multiprocess/process.py", line 297, in _bootstrap
    self.run()
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/pool.py", line 110, in worker
    task = get()


Could not generate output statistics for

Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/synchronize.py", line 102, in __enter__
    return self._semlock.__enter__()


Could not generate output statistics for

  File "/opt/conda/lib/python3.7/site-packages/multiprocess/queues.py", line 354, in get
    with self._rlock:


 

  File "/opt/conda/lib/python3.7/site-packages/multiprocess/process.py", line 297, in _bootstrap
    self.run()


 Could not generate output statistics for

KeyboardInterrupt
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)


 

  File "/opt/conda/lib/python3.7/site-packages/multiprocess/synchronize.py", line 102, in __enter__
    return self._semlock.__enter__()


t_subj.4

  File "/opt/conda/lib/python3.7/site-packages/multiprocess/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)


z_subj.7 

  File "/opt/conda/lib/python3.7/site-packages/multiprocess/pool.py", line 110, in worker
    task = get()




a_subj.3

  File "/opt/conda/lib/python3.7/site-packages/multiprocess/queues.py", line 354, in get
    with self._rlock:


Could not generate output statistics for

  File "/opt/conda/lib/python3.7/site-packages/multiprocess/pool.py", line 110, in worker
    task = get()



Could not generate output statistics for

  File "/opt/conda/lib/python3.7/site-packages/multiprocess/synchronize.py", line 102, in __enter__
    return self._semlock.__enter__()
  File "/opt/conda/lib/python3.7/site-packages/multiprocess/queues.py", line 354, in get
    with self._rlock:


  

KeyboardInterrupt


t_subj.2

  File "/opt/conda/lib/python3.7/site-packages/multiprocess/synchronize.py", line 102, in __enter__
    return self._semlock.__enter__()


v_subj.5

KeyboardInterrupt


Could not generate output statistics for

v_subj.10 

KeyboardInterrupt


Could not generate output statistics for
 a_subj.9a_subj.13
Could not generate output statistics for 
a_std
Could not generate output statistics forCould not generate output statistics for Could not generate output statistics for z_subj.10 Could not generate output statistics fort_subj.12
a_subj.8 
Could not generate output statistics forv_subj.2
 Could not generate output statistics fort_subj.0
Could not generate output statistics for
Could not generate output statistics for Could not generate output statistics for z_subj_trans.4z_subj.11

Could not generate output statistics forCould not generate output statistics for   v_subj.9z_subj_trans.8 v

t_subj.11
Could not generate output statistics forCould not generate output statistics forCould not generate output statistics for
  Could not generate output statistics foraz_subj_trans.10 

 t_stdCould not generate output statistics for
Could not generate output statistics for v_subj.0Could not generate output statistics for t_subj.13
 Coul

KeyboardInterrupt: 

 z_subj_trans.4
Could not generate output statistics for a_subj.11v_subj.1

Could not generate output statistics for Could not generate output statistics for z_subj_trans.2
v_subj.9
Could not generate output statistics for Could not generate output statistics for z_subj_trans.9sv 
Could not generate output statistics for
z_subj_trans.2Could not generate output statistics for 
 v_subj.4Could not generate output statistics fort_subj.9 
a_subj.1Could not generate output statistics for

Could not generate output statistics forCould not generate output statistics for  z_subj.2a_subj.3

Could not generate output statistics forCould not generate output statistics for  t_subj.5
a_subj.12
Could not generate output statistics for z_subj.11
 a_subj.6
Could not generate output statistics for t_subj.9
Could not generate output statistics for sv
Could not generate output statistics for z_subj.12


In [7]:
# LOO comparison table from the last completed iteration of In [6]; raises
# NameError if In [6] never finished an iteration (as in the interrupted run above).
tmp_loo_comp

NameError: name 'tmp_loo_comp' is not defined

In [None]:
# WAIC comparison table from the last completed iteration of In [6].
tmp_waic_comp

In [None]:
# Posterior plots for chain 0 of model ms7, as fitted to the last dataset in In [6].
models['ms7'][0].plot_posteriors()