In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append("../") # go to parent dir
import gym
from gym import spaces
import numpy as np
import sciunit
import scipy
import pandas as pd

In [20]:
# Debug check: evaluates to the file path of the imported sciunit package.
# NOTE(review): this cell carries execution count 20, so it was run out of
# order, and its saved output exposes an absolute local path -- consider
# removing the cell before sharing the notebook.
sciunit.__file__

'/anaconda3/envs/testing-ldmm/lib/python3.6/site-packages/sciunit/__init__.py'

## Env for decision making

In [3]:
from src.ldmunit.env import BanditEnv, BanditAssociateEnv
from src.ldmunit.models import decision_making, associative_learning
from src.ldmunit.models.utils import loglike, train_with_obs, simulate, multi_from_single
from src.ldmunit.tests import NLLTest, AICTest, BICTest

In [4]:
# Wrap each single decision-making model class with multi_from_single.
MultiRWCKModel, MultiNWSLSModel, MultiRandomRespondModel = [
    multi_from_single(single_cls)
    for single_cls in (
        decision_making.RWCKModel,
        decision_making.NWSLSModel,
        decision_making.RandomRespondModel,
    )
]

# Same wrapping for the associative-learning model classes.
(
    MultiRwNormModel,
    MultiKrwNormModel,
    MultiBetaBinomialModel,
    MultiLSSPDModel,
    MultiRandomRespondALModel,
) = [
    multi_from_single(single_cls)
    for single_cls in (
        associative_learning.RwNormModel,
        associative_learning.KrwNormModel,
        associative_learning.BetaBinomialModel,
        associative_learning.LSSPDModel,
        associative_learning.RandomRespondModel,
    )
]

## Decision Making

In [5]:
def from_np_array(array_string):
    """Parse the printed form of a numpy array back into a numpy array.

    Accepts whitespace-separated text such as ``"[[  1   2]\\n [100   3]]"``
    (as found in a CSV cell that stored ``str(array)``) as well as
    comma-separated text such as ``"[1, 2]"``.

    Parameters
    ----------
    array_string : str
        Text representation of a (possibly nested) array.

    Returns
    -------
    numpy.ndarray
        Array built from the parsed literal.

    Raises
    ------
    ValueError, SyntaxError
        If the normalised text is not a valid Python literal.
    """
    import ast
    import re

    text = array_string.strip()
    # Drop any padding right after '[' and right before ']'; numpy pads
    # entries to a common width, so there may be several spaces, not one.
    text = re.sub(r'\[\s+', '[', text)
    text = re.sub(r'\s+\]', ']', text)
    # Every remaining run of whitespace and/or commas separates two items.
    text = re.sub(r'[\s,]+', ',', text)
    return np.array(ast.literal_eval(text))

def get_observation_from_idx(df, idx=None, np_array=False):
    """Return one subject's rewards, actions and stimuli as lists.

    Parameters
    ----------
    df : pandas.DataFrame or str
        Data with columns 'sub', 'rewards', 'actions' and 'stimuli', or the
        path of a CSV file to read them from.
    idx : optional
        Value of the 'sub' column identifying the subject.
    np_array : bool
        When ``df`` is a path, parse the 'stimuli' column with
        ``from_np_array`` while reading.

    Returns
    -------
    dict
        ``{'rewards': [...], 'actions': [...], 'stimuli': [...]}`` holding the
        rows whose 'sub' equals ``idx``.

    Raises
    ------
    AssertionError
        If ``idx`` does not occur in the 'sub' column.
    """
    if isinstance(df, str):
        if np_array:
            df = pd.read_csv(df, converters={'stimuli': from_np_array})
        else:
            df = pd.read_csv(df)

    assert idx in df['sub'].values, "subject index not in dataframe"

    subject_rows = df[df['sub'] == idx]
    wanted = subject_rows[['rewards', 'actions', 'stimuli']]
    return wanted.to_dict(orient='list')

def get_observation(df, n_sub, np_array=False):
    """Collect the observations of subjects ``0 .. n_sub - 1``.

    Parameters
    ----------
    df : pandas.DataFrame or str
        Data (or the path of a CSV file) as accepted by
        ``get_observation_from_idx``.
    n_sub : int
        Number of subjects; subject indices ``range(n_sub)`` are looked up.
    np_array : bool
        When ``df`` is a path, parse the 'stimuli' column with
        ``from_np_array`` while reading.

    Returns
    -------
    dict
        ``{'rewards': [...], 'actions': [...], 'stimuli': [...]}`` where each
        list holds one per-subject list, in subject order.
    """
    if isinstance(df, str):
        # Read and parse the file once here; passing the path on would make
        # get_observation_from_idx re-read it for every subject.
        converters = {'stimuli': from_np_array} if np_array else None
        df = pd.read_csv(df, converters=converters)

    res = {'rewards': [], 'actions': [], 'stimuli': []}
    for sub_idx in range(n_sub):
        subject_obs = get_observation_from_idx(df, sub_idx, np_array=np_array)
        for key, values in subject_obs.items():
            res[key].append(values)

    return res

def read_prior_csv(filename):
    """Read a CSV of parameter values into a list of dicts.

    The first row supplies the parameter names; every following non-blank row
    is converted to floats and paired with those names.

    Parameters
    ----------
    filename : str
        Path of the CSV file.

    Returns
    -------
    list of dict
        One ``{name: float_value}`` dict per data row, in file order. Empty
        when the file has no data rows.

    Raises
    ------
    ValueError
        If a cell in a data row cannot be converted to float.
    """
    import csv

    # newline='' is what the csv module asks for so it can handle line
    # endings itself; the with-block closes the file, no explicit close needed.
    with open(filename, mode='rt', newline='') as f:
        reader = csv.reader(f, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
        paras_keys = next(reader, None)
        if paras_keys is None:
            return []
        return [
            dict(zip(paras_keys, map(float, row)))
            for row in reader
            if row  # blank lines come back as [] and would yield an empty dict
        ]

In [7]:
# For each simulated decision-making data set (3 subjects each), build an
# NLL, an AIC and a BIC test from the same observation.
ck = get_observation('../data/multi-ck.csv', 3)
nll_ck, aic_ck, bic_ck = [
    test_cls(name="CK sim", observation=ck)
    for test_cls in (NLLTest, AICTest, BICTest)
]

rw = get_observation('../data/multi-rw.csv', 3)
nll_rw, aic_rw, bic_rw = [
    test_cls(name="RW sim", observation=rw)
    for test_cls in (NLLTest, AICTest, BICTest)
]

rwck = get_observation('../data/multi-rwck.csv', 3)
nll_rwck, aic_rwck, bic_rwck = [
    test_cls(name='RWCK sim', observation=rwck)
    for test_cls in (NLLTest, AICTest, BICTest)
]

# NOTE(review): the RR tests are not added to any suite in the visible
# cells -- confirm whether that is intended.
rr = get_observation('../data/multi-random-responding.csv', 3)
nll_rr, aic_rr, bic_rr = [
    test_cls(name='RR sim', observation=rr)
    for test_cls in (NLLTest, AICTest, BICTest)
]

nwsls = get_observation('../data/multi-nwsls.csv', 3)
nll_nwsls, aic_nwsls, bic_nwsls = [
    test_cls(name='NWSLS sim', observation=nwsls)
    for test_cls in (NLLTest, AICTest, BICTest)
]

In [8]:
# One suite per score type, each grouping the CK, RW, RWCK and NWSLS tests.
# NOTE(review): the RR tests (nll_rr / aic_rr / bic_rr) are left out of all
# three suites -- confirm that this is intended.
nll_suite = sciunit.TestSuite([nll_ck, nll_rw, nll_rwck, nll_nwsls], name="NLL suite")
aic_suite = sciunit.TestSuite([aic_ck, aic_rw, aic_rwck, aic_nwsls], name="AIC suite")
bic_suite = sciunit.TestSuite([bic_ck, bic_rw, bic_rwck, bic_nwsls], name="BIC suite")

In [9]:
# Passed as n_action / n_obs to every decision-making model constructed below.
n_action, n_obs = 3, 3

In [10]:
# Build one model per prior file; `param_list` is rebound for each model and
# ends the cell holding the NWSLS priors.
# The ck, rw and rwck models all use MultiRWCKModel and differ only in the
# prior CSV they are loaded from.
param_list = read_prior_csv('../data/multi-ck_prior.csv')
multi_ck = MultiRWCKModel(param_list, n_action=n_action, n_obs=n_obs)
multi_ck.name = "ck"

param_list = read_prior_csv('../data/multi-rw_prior.csv')
multi_rw = MultiRWCKModel(param_list, n_action=n_action, n_obs=n_obs)
multi_rw.name = "rw"

param_list = read_prior_csv('../data/multi-rwck_prior.csv')
multi_rwck = MultiRWCKModel(param_list, n_action=n_action, n_obs=n_obs)
multi_rwck.name = "rwck"

# NOTE(review): multi_rr is not passed to any judge() call in the visible
# cells -- confirm whether it is still needed.
param_list = read_prior_csv('../data/multi-random-responding_prior.csv')
multi_rr = MultiRandomRespondModel(param_list, n_action=n_action, n_obs=n_obs)
multi_rr.name = "rr"

param_list = read_prior_csv('../data/multi-nwsls_prior.csv')
multi_nwsls = MultiNWSLSModel(param_list, n_action=n_action, n_obs=n_obs)
multi_nwsls.name = 'nwsls'

In [11]:
# Judge the four decision-making models against every test in the NLL suite;
# the resulting model-by-test score table is the cell output.
nll_suite.judge([multi_ck, multi_rw, multi_rwck, multi_nwsls])

Unnamed: 0,CK sim,RW sim,RWCK sim,NWSLS sim
ck,333,329,330,297
rw,332,323,327,308
rwck,340,328,325,282
nwsls,533,501,476,184


In [12]:
# Same four models judged against the AIC suite; the score table is the cell output.
aic_suite.judge([multi_ck, multi_rw, multi_rwck, multi_nwsls])

Unnamed: 0,CK sim,RW sim,RWCK sim,NWSLS sim
ck,696.0,688,690,624
rw,695.0,675,683,645
rwck,710.0,685,681,595
nwsls,1070.0,994,963,383


In [13]:
# Same four models judged against the BIC suite; the score table is the cell output.
bic_suite.judge([multi_ck, multi_rw, multi_rwck, multi_nwsls])

Unnamed: 0,CK sim,RW sim,RWCK sim,NWSLS sim
ck,2170.0,2160.0,2160.0,2090.0
rw,2160.0,2150.0,2150.0,2120.0
rwck,2180.0,2160.0,2150.0,2060.0
nwsls,1370.0,1280.0,1270.0,681.0


In [14]:

# For each simulated associative-learning data set (3 subjects each, stimuli
# parsed via np_array=True), build an NLL, an AIC and a BIC test from the
# same observation.
rr_al = get_observation('../data/multi-rr_al.csv', 3, np_array=True)
nll_rr_al, aic_rr_al, bic_rr_al = [
    test_cls(name='rr_al sim', observation=rr_al)
    for test_cls in (NLLTest, AICTest, BICTest)
]

rw_norm = get_observation('../data/multi-rw_norm.csv', 3, np_array=True)
nll_rw_norm, aic_rw_norm, bic_rw_norm = [
    test_cls(name="rw_norm sim", observation=rw_norm)
    for test_cls in (NLLTest, AICTest, BICTest)
]

krw_norm = get_observation('../data/multi-krw_norm.csv', 3, np_array=True)
nll_krw_norm, aic_krw_norm, bic_krw_norm = [
    test_cls(name="krw_norm sim", observation=krw_norm)
    for test_cls in (NLLTest, AICTest, BICTest)
]

lsspd = get_observation('../data/multi-lsspd.csv', 3, np_array=True)
nll_lsspd, aic_lsspd, bic_lsspd = [
    test_cls(name='lsspd sim', observation=lsspd)
    for test_cls in (NLLTest, AICTest, BICTest)
]

bb = get_observation('../data/multi-beta_binomial.csv', 3, np_array=True)
nll_bb, aic_bb, bic_bb = [
    test_cls(name='Beta Binomial sim', observation=bb)
    for test_cls in (NLLTest, AICTest, BICTest)
]

In [15]:
# One suite per score type for the associative-learning data sets; unlike the
# decision-making suites, these include the random-responding (rr_al) tests.
nll_al_suite = sciunit.TestSuite([nll_rr_al, nll_rw_norm, nll_krw_norm, nll_lsspd, nll_bb], name="NLL suite for learning")
aic_al_suite = sciunit.TestSuite([aic_rr_al, aic_rw_norm, aic_krw_norm, aic_lsspd, aic_bb], name="AIC suite for learning")
bic_al_suite = sciunit.TestSuite([bic_rr_al, bic_rw_norm, bic_krw_norm, bic_lsspd, bic_bb], name="BIC suite for learning")

In [16]:
# Build one associative-learning model per prior file; `param_list` is
# rebound for each model.
# NOTE(review): n_obs=4 is hard-coded here, while the decision-making models
# use the n_obs variable (3) -- confirm 4 matches the stimuli in the AL data.
# NOTE(review): no model is built from MultiRandomRespondALModel in the
# visible cells, although the suites contain rr_al tests -- confirm intended.
param_list = read_prior_csv('../data/multi-rw_norm_prior.csv')
multi_rw_norm = MultiRwNormModel(param_list, n_obs=4)
multi_rw_norm.name = "rw_norm"

param_list = read_prior_csv('../data/multi-krw_norm_prior.csv')
multi_krw_norm = MultiKrwNormModel(param_list, n_obs=4)
multi_krw_norm.name = "krw_norm"

param_list = read_prior_csv('../data/multi-lsspd_prior.csv')
multi_lsspd = MultiLSSPDModel(param_list, n_obs=4)
multi_lsspd.name = "lsspd"

param_list = read_prior_csv('../data/multi-beta_binomial_prior.csv')
multi_bb = MultiBetaBinomialModel(param_list, n_obs=4)
multi_bb.name = "bb"

In [17]:
# Judge the four associative-learning models against the NLL suite; the
# model-by-test score table is the cell output.
# NOTE(review): the saved output shows scores around 1e19 (the whole rw_norm
# row and the rw_norm column) -- this looks like a numerical blow-up; verify.
nll_al_suite.judge([multi_rw_norm, multi_krw_norm, multi_lsspd, multi_bb])

Unnamed: 0,rr_al sim,rw_norm sim,krw_norm sim,lsspd sim,Beta Binomial sim
rw_norm,8.54e+18,1.41e+19,8.54e+18,8.66e+18,8.54e+18
krw_norm,135000.0,8.15e+19,274.0,7.43e+16,1760.0
lsspd,15700.0,2.3e+19,222.0,1.3e+16,366.0
bb,18200.0,2.67e+19,319.0,2e+16,286.0


In [18]:
# Same four models judged against the AIC suite for learning; the score table is the cell output.
aic_al_suite.judge([multi_rw_norm, multi_krw_norm, multi_lsspd, multi_bb])

Unnamed: 0,rr_al sim,rw_norm sim,krw_norm sim,lsspd sim,Beta Binomial sim
rw_norm,1.71e+19,2.82e+19,1.71e+19,1.73e+19,1.71e+19
krw_norm,270000.0,1.63e+20,596.0,1.49e+17,3560.0
lsspd,31500.0,4.6e+19,492.0,2.59e+16,780.0
bb,36400.0,5.34e+19,662.0,4e+16,595.0


In [19]:
# Same four models judged against the BIC suite for learning; the score table is the cell output.
bic_al_suite.judge([multi_rw_norm, multi_krw_norm, multi_lsspd, multi_bb])

Unnamed: 0,rr_al sim,rw_norm sim,krw_norm sim,lsspd sim,Beta Binomial sim
rw_norm,1.71e+19,2.82e+19,1.71e+19,1.73e+19,1.71e+19
krw_norm,272000.0,1.63e+20,2950.0,1.49e+17,5910.0
lsspd,33900.0,4.6e+19,2840.0,2.59e+16,3130.0
bb,37600.0,5.34e+19,1840.0,4e+16,1770.0
