# Simulations for Model/Parameter Recovery

## Generate Simulations

In [1]:
import random
import pickle
import pandas as pd
import numpy as np
from itertools import product
from typing import Sequence, Dict, Any, Optional, Tuple, List

from construal_shifting.task_modeling.simulated_participant_model import SimulatedParticipantData
from construal_shifting.task_modeling.base_dataclasses import ParticipantDataBase
from data_analysis import ExperimentDataLoader

In [2]:
# Load raw trial data and keep only participants that survived into the
# summary table (summary_df is assumed to list the retained sessionIds).
exp_data = ExperimentDataLoader(
    trialdata_file="rawtrialdata-anon.csv"
)
summary_df = pd.read_json('./data/summary_df.json')
# Build the lookup set once: `x in ndarray` performs an element-wise scan on
# every membership test, whereas a set lookup is constant time.
included_session_ids = set(summary_df.sessionId)
all_participant_data = [
    p for p in exp_data.completed_participant_data()
    if p.sessionId in included_session_ids
]

In [3]:
from tqdm.notebook import tqdm

def generate_simulated_participants(
    all_participant_data : Sequence[ParticipantDataBase],
    construal_cost_weight : float,
    construal_inverse_temp : float,
    action_inverse_temp : float,
    action_random_choice : float,
    construal_set_stickiness : float,
    rng : random.Random
) -> Sequence[SimulatedParticipantData]:
    """Create one simulated participant for each real participant.

    Each simulated participant is built with
    ``SimulatedParticipantData.from_real_participant`` from the corresponding
    real participant's data. Every participant gets the same model parameters
    and its own seed drawn from ``rng``.

    Args:
        all_participant_data: Real participants to base the simulations on.
        construal_cost_weight: Forwarded unchanged to the simulator.
        construal_inverse_temp: Forwarded unchanged to the simulator.
        action_inverse_temp: Forwarded unchanged to the simulator.
        action_random_choice: Forwarded unchanged to the simulator.
        construal_set_stickiness: Forwarded unchanged to the simulator.
        rng: Source of per-participant seeds. Exactly one draw is made per
            participant, in order, so a freshly seeded ``rng`` reproduces
            the same seeds.

    Returns:
        A list of simulated participants in the same order as
        ``all_participant_data``.
    """
    return [
        SimulatedParticipantData.from_real_participant(
            participant_data=participant_data,
            construal_cost_weight=construal_cost_weight,
            construal_inverse_temp=construal_inverse_temp,
            action_inverse_temp=action_inverse_temp,
            action_random_choice=action_random_choice,
            construal_set_stickiness=construal_set_stickiness,
            # randrange(2**32) yields 0..2**32-1. The previous
            # randint(0, 2**32) is inclusive and could return 2**32, which
            # 32-bit seeders (e.g. numpy.random.RandomState) reject.
            seed=rng.randrange(2**32),
        )
        for participant_data in all_participant_data
    ]

# Full factorial grid over the generating parameters. A single-element list
# pins that parameter to one value across every simulation.
param_space = {
    "construal_cost_weight": [0, 2, 4, 6, 8],
    "construal_inverse_temp": [1.0],
    "action_inverse_temp": [float('inf')],
    "action_random_choice": [.1],
    "construal_set_stickiness": [0, 2, 4, 6, 8],
}
param_names = tuple(param_space.keys())
params_values = list(product(*param_space.values()))

# One seeded RNG feeds every simulation call below, so the whole grid is
# reproducible as long as it is iterated in the same order.
rng = random.Random(399)
simulated_participants = {}
for params_value in tqdm(params_values):
    params = {name: value for name, value in zip(param_names, params_value)}
    sim_participants = generate_simulated_participants(
        all_participant_data=all_participant_data,
        rng=rng,
        **params,
    )
    # Sorted (name, value) pairs give a hashable key that does not depend
    # on the insertion order of param_space.
    simulated_participants[tuple(sorted(params.items()))] = sim_participants

  0%|          | 0/25 [00:00<?, ?it/s]

  f"Values for these states will be set using self.undefined_value={self.undefined_value}"
  f"Values for these states will be set using self.undefined_value={self.undefined_value}"
  f"Values for these states will be set using self.undefined_value={self.undefined_value}"
  f"Values for these states will be set using self.undefined_value={self.undefined_value}"
  f"Values for these states will be set using self.undefined_value={self.undefined_value}"
  f"Values for these states will be set using self.undefined_value={self.undefined_value}"
  f"Values for these states will be set using self.undefined_value={self.undefined_value}"
  f"Values for these states will be set using self.undefined_value={self.undefined_value}"
  f"Values for these states will be set using self.undefined_value={self.undefined_value}"
  f"Values for these states will be set using self.undefined_value={self.undefined_value}"
  f"Values for these states will be set using self.undefined_value={self.undefined_value}"

In [4]:
# Persist the simulated datasets (keyed by sorted generating-parameter pairs).
# A context manager guarantees the file is flushed and closed immediately
# rather than whenever the anonymous handle happens to be garbage-collected.
with open('./data/exp-1-simulated_participants.pkl', 'wb') as f:
    pickle.dump(simulated_participants, f, protocol=pickle.HIGHEST_PROTOCOL)

## Fit simulated data

In [5]:
import pickle
import random
from construal_shifting.task_modeling.model_fitter import ModelFitter

In [6]:
# Reload the simulated datasets so fitting can run without regenerating them.
# NOTE: pickle.load can execute arbitrary code; only load files this
# notebook wrote itself.
with open('./data/exp-1-simulated_participants.pkl', 'rb') as f:
    simulated_participants = pickle.load(f)

In [7]:
# Display one row per generating-parameter setting (each key is a tuple of
# (name, value) pairs, so dict() turns it back into a row).
pd.DataFrame([dict(param_key) for param_key in simulated_participants])

Unnamed: 0,action_inverse_temp,action_random_choice,construal_cost_weight,construal_inverse_temp,construal_set_stickiness
0,inf,0.1,0,1.0,0
1,inf,0.1,0,1.0,2
2,inf,0.1,0,1.0,4
3,inf,0.1,0,1.0,6
4,inf,0.1,0,1.0,8
5,inf,0.1,2,1.0,0
6,inf,0.1,2,1.0,2
7,inf,0.1,2,1.0,4
8,inf,0.1,2,1.0,6
9,inf,0.1,2,1.0,8


In [8]:

# Values used for any parameter that is NOT freed in a given candidate model.
fixed_params = dict(
    construal_cost_weight=.0,
    construal_set_stickiness=.0,
    construal_inverse_temp=1.0,
    action_inverse_temp=float('inf'),
    action_random_choice=.1,
)
# Optimizer bounds for parameters that ARE being fit.
param_bounds = dict(
    construal_set_stickiness=(0., 10.),
    construal_cost_weight=(0., 10.),
    construal_inverse_temp=(0., 10),
    action_inverse_temp=(1e-2, 10),
    action_random_choice=(.01, 1.), #we need to lower bound to avoid numerical issues
)
# Each entry defines one candidate model by the parameters it frees.
param_combos_to_fit = [
    ["action_random_choice"],
    ["construal_cost_weight", "construal_set_stickiness", "action_random_choice"],
    ["construal_set_stickiness", "action_random_choice"],
    ["construal_cost_weight", "action_random_choice"],
]

results = []
rng = random.Random(51191)
for params, sim_parts in simulated_participants.items():
    print(dict(params))
    fitter = ModelFitter(sim_parts[:])
    for params_to_fit in param_combos_to_fit:
        fit_res = fitter.fit_params(
            fixed_params={p: v for p, v in fixed_params.items() if p not in params_to_fit},
            params_to_fit=params_to_fit,
            param_bounds=param_bounds,
            maxfun=200,
            runs=2,
            seed=rng.randint(0, int(1e7))
        )
        # Presumably one result per optimizer restart (runs=2); record each
        # together with the generating ("_true") parameters for recovery
        # analysis.
        for i, res in enumerate(fit_res):
            results.append({
                "run": i,
                **res.model_params,
                **{'fit_'+p: True for p in params_to_fit},
                'NLL': res.neg_log_like,
                **{p+"_true": v for p, v in params},
            })
    # Checkpoint after every generating setting so partial results survive
    # an interrupted run. The context manager closes (and flushes) the file
    # each time instead of leaking one handle per iteration.
    with open('./data/exp-1-simulated_participants-fitting.pkl', 'wb') as f:
        pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL)
            

{'action_inverse_temp': inf, 'action_random_choice': 0.1, 'construal_cost_weight': 0, 'construal_inverse_temp': 1.0, 'construal_set_stickiness': 0}


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

{'action_inverse_temp': inf, 'action_random_choice': 0.1, 'construal_cost_weight': 0, 'construal_inverse_temp': 1.0, 'construal_set_stickiness': 2}


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

{'action_inverse_temp': inf, 'action_random_choice': 0.1, 'construal_cost_weight': 0, 'construal_inverse_temp': 1.0, 'construal_set_stickiness': 4}


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

{'action_inverse_temp': inf, 'action_random_choice': 0.1, 'construal_cost_weight': 0, 'construal_inverse_temp': 1.0, 'construal_set_stickiness': 6}


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

{'action_inverse_temp': inf, 'action_random_choice': 0.1, 'construal_cost_weight': 0, 'construal_inverse_temp': 1.0, 'construal_set_stickiness': 8}


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

{'action_inverse_temp': inf, 'action_random_choice': 0.1, 'construal_cost_weight': 2, 'construal_inverse_temp': 1.0, 'construal_set_stickiness': 0}


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

{'action_inverse_temp': inf, 'action_random_choice': 0.1, 'construal_cost_weight': 2, 'construal_inverse_temp': 1.0, 'construal_set_stickiness': 2}


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

{'action_inverse_temp': inf, 'action_random_choice': 0.1, 'construal_cost_weight': 2, 'construal_inverse_temp': 1.0, 'construal_set_stickiness': 4}


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

{'action_inverse_temp': inf, 'action_random_choice': 0.1, 'construal_cost_weight': 2, 'construal_inverse_temp': 1.0, 'construal_set_stickiness': 6}


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

{'action_inverse_temp': inf, 'action_random_choice': 0.1, 'construal_cost_weight': 2, 'construal_inverse_temp': 1.0, 'construal_set_stickiness': 8}


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

{'action_inverse_temp': inf, 'action_random_choice': 0.1, 'construal_cost_weight': 4, 'construal_inverse_temp': 1.0, 'construal_set_stickiness': 0}


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

{'action_inverse_temp': inf, 'action_random_choice': 0.1, 'construal_cost_weight': 4, 'construal_inverse_temp': 1.0, 'construal_set_stickiness': 2}


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

{'action_inverse_temp': inf, 'action_random_choice': 0.1, 'construal_cost_weight': 4, 'construal_inverse_temp': 1.0, 'construal_set_stickiness': 4}


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

{'action_inverse_temp': inf, 'action_random_choice': 0.1, 'construal_cost_weight': 4, 'construal_inverse_temp': 1.0, 'construal_set_stickiness': 6}


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

{'action_inverse_temp': inf, 'action_random_choice': 0.1, 'construal_cost_weight': 4, 'construal_inverse_temp': 1.0, 'construal_set_stickiness': 8}


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

{'action_inverse_temp': inf, 'action_random_choice': 0.1, 'construal_cost_weight': 6, 'construal_inverse_temp': 1.0, 'construal_set_stickiness': 0}


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

{'action_inverse_temp': inf, 'action_random_choice': 0.1, 'construal_cost_weight': 6, 'construal_inverse_temp': 1.0, 'construal_set_stickiness': 2}


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

{'action_inverse_temp': inf, 'action_random_choice': 0.1, 'construal_cost_weight': 6, 'construal_inverse_temp': 1.0, 'construal_set_stickiness': 4}


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

{'action_inverse_temp': inf, 'action_random_choice': 0.1, 'construal_cost_weight': 6, 'construal_inverse_temp': 1.0, 'construal_set_stickiness': 6}


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

{'action_inverse_temp': inf, 'action_random_choice': 0.1, 'construal_cost_weight': 6, 'construal_inverse_temp': 1.0, 'construal_set_stickiness': 8}


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

{'action_inverse_temp': inf, 'action_random_choice': 0.1, 'construal_cost_weight': 8, 'construal_inverse_temp': 1.0, 'construal_set_stickiness': 0}


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

{'action_inverse_temp': inf, 'action_random_choice': 0.1, 'construal_cost_weight': 8, 'construal_inverse_temp': 1.0, 'construal_set_stickiness': 2}


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

{'action_inverse_temp': inf, 'action_random_choice': 0.1, 'construal_cost_weight': 8, 'construal_inverse_temp': 1.0, 'construal_set_stickiness': 4}


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

{'action_inverse_temp': inf, 'action_random_choice': 0.1, 'construal_cost_weight': 8, 'construal_inverse_temp': 1.0, 'construal_set_stickiness': 6}


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

{'action_inverse_temp': inf, 'action_random_choice': 0.1, 'construal_cost_weight': 8, 'construal_inverse_temp': 1.0, 'construal_set_stickiness': 8}


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

In [10]:
# Final save of all fitting results. The context manager ensures the file is
# flushed and closed deterministically.
with open('./data/exp-1-simulated_participants-fitting.pkl', 'wb') as f:
    pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL)