# Fit model parameters

In [1]:
import numpy as np
import pandas as pd
import seaborn as sns

from data_analysis import download_data, download_condition_counts, calc_condition_counts, ExperimentDataLoader
# Loader over the raw trial-data CSV; later cells iterate the completed
# participants it exposes.
exp_data = ExperimentDataLoader(trialdata_file="rawtrialdata-anon.csv")

In [2]:
# Summary table loaded from JSON; its `sessionId` column is what the fitting
# code uses to decide which participants are included.
summary_df = pd.read_json('./data/summary_df.json')

In [4]:
from itertools import product
from frozendict import frozendict
from tqdm.notebook import tqdm
from functools import lru_cache
from construal_shifting.task_modeling.participant_model import ParticipantModel
from construal_shifting.task_modeling.construal_trial_model import ConstrualTrialModel

# Session ids of the participants that enter the likelihood computation.
included_sessionIds = list(summary_df.sessionId)


@lru_cache(maxsize=None)
def evaluate_model_params(model_params):
    """Return the total trial log-probability across included participants.

    The result is memoized with ``lru_cache``, so ``model_params`` must be
    hashable (callers pass a ``frozendict``).

    Args:
        model_params: Hashable mapping of keyword arguments forwarded
            unchanged to ``ParticipantModel.trials_log_probability``.

    Returns:
        The sum of ``trials_log_probability(**model_params)`` over every
        completed participant whose ``sessionId`` is in
        ``included_sessionIds`` (``0`` if there are none).
    """
    # Build the lookup set once per call so each membership test is O(1)
    # instead of a linear scan of the list for every participant.
    included = set(included_sessionIds)
    total_log_prob = 0
    for participant_data in exp_data.completed_participant_data():
        if participant_data.sessionId not in included:
            continue
        pmod = ParticipantModel(participant_data)
        total_log_prob += pmod.trials_log_probability(**model_params)
    return total_log_prob

In [5]:
import random
from dataclasses import dataclass
from scipy.optimize import fmin_l_bfgs_b
    
@dataclass
class FitResult:
    """Outcome of a single optimizer run started from one initial point."""

    # Complete parameter dict at the solution (fixed values plus fitted ones).
    model_params: dict
    # Objective value at the solution, i.e. the negative log-likelihood.
    neg_log_like: float
    # Starting vector given to the optimizer.
    x0: np.ndarray
    # Diagnostics dictionary returned by the optimizer.
    info: dict
    
def param_dict_to_str(params):
    """Format a parameter dict as a short, name-sorted, one-line summary.

    Each entry is rendered as ``<abbreviated name>:<value to 2 decimals>``
    and entries are joined with ``', '``. An empty dict yields ``''``.
    """
    # Applied to every name in exactly this order.
    abbreviations = (
        ('construal', 'c'),
        ('cost_weight', 'cost'),
        ('inverse_temp', 'invt'),
        ('random_choice', 'eps'),
        ('action', 'a'),
        ('set_stickiness', 'stick'),
    )

    def shorten(name):
        for long_form, short_form in abbreviations:
            name = name.replace(long_form, short_form)
        return name

    return ', '.join(
        f'{shorten(name)}:{params[name]:.2f}' for name in sorted(params)
    )

def fit_params_once(
    fixed_params,
    initial_params,
    param_bounds,
    maxfun,
):
    """Run one bounded L-BFGS-B minimization of the negative log-likelihood.

    Args:
        fixed_params: Parameter values held constant during the fit. They are
            merged with the fitted values; on a name clash the fitted value
            wins.
        initial_params: Mapping from each free parameter name to its starting
            value. Its keys define which parameters are optimized.
        param_bounds: Mapping from parameter name to a ``(min, max)`` pair;
            it must contain every key of ``initial_params``.
        maxfun: Maximum number of function evaluations passed to
            ``fmin_l_bfgs_b``; also used as the progress bar total.

    Returns:
        A ``FitResult`` holding the merged parameter dict at the solution,
        the objective value there, the starting vector and the optimizer's
        info dict.

    Raises:
        KeyError: If ``param_bounds`` lacks a free parameter.
    """
    # A fixed (sorted) order maps vector positions to parameter names.
    param_names = sorted(initial_params.keys())
    # Resolve every input before opening the progress bar, so a bad
    # parameter name cannot leave an orphaned bar behind.
    x0 = np.array([initial_params[n] for n in param_names])
    bounds = [param_bounds[n] for n in param_names]

    # The context manager closes the bar even if the optimizer or the
    # objective raises, or the run is interrupted from the notebook.
    with tqdm(total=maxfun) as progress:
        def func(x):
            progress.update(n=1)
            model_params = {**fixed_params, **dict(zip(param_names, x))}
            # frozendict makes the parameters hashable for the memoized
            # evaluator.
            log_prob = evaluate_model_params(frozendict(model_params))
            progress.set_description(
                f"NLL: {-log_prob:.0f}; {param_dict_to_str(model_params)}"
            )
            return -log_prob

        x, nll, info = fmin_l_bfgs_b(
            func=func,
            x0=x0,
            bounds=bounds,
            maxfun=maxfun,
            approx_grad=True  # no analytic gradient; use finite differences
        )

    model_params = {**fixed_params, **dict(zip(param_names, x))}
    return FitResult(
        model_params=model_params,
        neg_log_like=nll,
        x0=x0,
        info=info
    )

def fit_params(
    fixed_params,
    params_to_fit,
    param_bounds,
    maxfun,
    runs,
    seed=None
):
    """Fit the named parameters from several random starting points.

    For each of ``runs`` restarts, a starting value is drawn uniformly inside
    the bounds of every parameter in ``params_to_fit`` and handed to
    ``fit_params_once``. Entries of ``fixed_params`` that are being fitted are
    dropped from the fixed set. The starting point is also attached to each
    result as ``initial_params``.

    Returns:
        A list with one ``FitResult`` per run, in run order.
    """
    rng = random.Random(seed)

    def sample_start(name):
        # Uniform draw between the parameter's lower and upper bound.
        lower, upper = param_bounds[name]
        return lower + rng.random()*(upper - lower)

    fits = []
    for _ in range(runs):
        start_point = {name: sample_start(name) for name in params_to_fit}
        held_fixed = {
            name: value
            for name, value in fixed_params.items()
            if name not in params_to_fit
        }
        fit = fit_params_once(
            fixed_params=held_fixed,
            initial_params=start_point,
            param_bounds=param_bounds,
            maxfun=maxfun
        )
        fit.initial_params = start_point
        fits.append(fit)
        # Clear the trial-model caches between restarts.
        ConstrualTrialModel.clear_instance_method_caches()
    return fits

In [7]:
# Value used for each parameter whenever it is not among the fitted ones.
fixed_params = {
    "construal_cost_weight": .0,
    "construal_set_stickiness": .0,
    "construal_inverse_temp": 1.0,
    "action_inverse_temp": float('inf'),
    "action_random_choice": .1,
}
# NOTE(review): not referenced by the fitting loop below, which draws random
# starting points inside `param_bounds` instead.
initial_params = {
    "construal_cost_weight": 5,
    "construal_set_stickiness": 5,
    "construal_inverse_temp": 1.0,
    "action_inverse_temp": 1.,
    "action_random_choice": .1,
}
# (min, max) box constraints per parameter.
param_bounds = {
    "construal_set_stickiness": (0., 10.),
    "construal_cost_weight": (0., 10.),
    "construal_inverse_temp": (0., 10),
    "action_inverse_temp": (1e-2, 10),
    "action_random_choice": (5e-2, 1.),
}
# Each entry is one model variant: the list of parameters left free.
param_combos_to_fit = [
    ["action_random_choice"],
    ["construal_cost_weight", "construal_set_stickiness", "action_random_choice"],
    ["construal_set_stickiness", "action_random_choice"],
    ["construal_cost_weight", "action_random_choice"],
]

results = []
# Master generator that hands a separate seed to every model variant.
rng = random.Random(12490)
for params_to_fit in param_combos_to_fit:
    fit_res = fit_params(
        fixed_params={p: v for p, v in fixed_params.items() if p not in params_to_fit},
        params_to_fit=params_to_fit,
        param_bounds=param_bounds,
        maxfun=200,
        runs=3,
        seed=rng.randint(0, int(1e7)),
    )
    # Only the fitted parameters get a `fit_<name>` flag; the others are
    # simply absent from the record.
    fit_flags = dict.fromkeys(['fit_'+p for p in params_to_fit], True)
    for i, res in enumerate(fit_res):
        results.append({
            "run": i,
            **res.model_params,
            **fit_flags,
            'NLL': res.neg_log_like,
        })
        

In [67]:
import pandas as pd
# One row per (model variant, run) collected by the fitting loop.
res_df = pd.DataFrame(results)

# Degrees of freedom = number of parameters flagged as fitted in that row.
# Flags that were never set show up as missing values and count as not fitted.
fit_flag_columns = [
    'fit_construal_cost_weight',
    'fit_construal_set_stickiness',
    'fit_action_random_choice',
]
res_df['df'] = res_df[fit_flag_columns].eq(True).sum(axis=1)

# AIC = 2k - 2*ln(L). 'NLL' holds -ln(L), so the likelihood term is 2*NLL.
res_df['AIC'] = 2*res_df['df'] + 2*res_df['NLL']
res_df['dAIC'] = res_df['AIC'] - res_df['AIC'].min()

# NOTE(review): hard-coded row labels, presumably the best run of each model
# variant for the seed used above. Re-check them whenever the fits are rerun.
optimal_results = res_df.loc[[0, 6, 9, 4]].reset_index(drop=True)

In [70]:
# Persist the selected fits so later cells can reload them from disk.
optimal_results.to_pickle('./data/model_fits.pkl')

### Calculate and save joint model statistics

In [71]:
import pandas as pd
# Reload the saved fits, presumably so this section can run without refitting.
# Unpickling executes arbitrary code, so only load files produced locally.
optimal_results = pd.read_pickle('./data/model_fits.pkl')

In [72]:
# First saved fit in which both the construal cost weight and the construal
# set stickiness were fitted, converted to a plain dict.
joint_fit = optimal_results[
    optimal_results['fit_construal_cost_weight']
    & optimal_results['fit_construal_set_stickiness']
].iloc[0].to_dict()
# Keep only the five model parameters; the bookkeeping columns are dropped.
joint_fit = {
    name: joint_fit[name]
    for name in (
        "action_inverse_temp",
        "action_random_choice",
        "construal_inverse_temp",
        "construal_set_stickiness",
        "construal_cost_weight",
    )
}

In [74]:
# Collect the records returned by `trials_model_stats` for every completed
# participant under the joint-fit parameters, then tabulate them.
# NOTE(review): unlike `evaluate_model_params`, this loop does not filter on
# `included_sessionIds`. Confirm that excluded sessions are removed downstream.
pmod_stats_rec = []
for pdata in tqdm(list(exp_data.completed_participant_data())):
    pmod = ParticipantModel(pdata)
    pmod_stats_rec += pmod.trials_model_stats(**joint_fit)
pmod_stats = pd.DataFrame(data=pmod_stats_rec)

  0%|          | 0/419 [00:00<?, ?it/s]

In [75]:
# Persist the joint-model statistics table to disk.
pmod_stats.to_pickle('./data/participantmodel_stats.pkl')