In [1]:
import os
import sys

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

sys.path.append('..')
from utils.data import Subject, load_participant_list

In [2]:
# Root directory of the LearningHabits dev sample. Set the
# LEARNING_HABITS_DATA_DIR environment variable to point at your local copy;
# the fallback is the original author's machine-specific path.
base_dir = os.environ.get(
    'LEARNING_HABITS_DATA_DIR',
    '/Users/hugofluhr/phd_local/data/LearningHabits/dev_sample',
)
sub_ids = load_participant_list(base_dir)

In [3]:
# One Subject per participant, with modeling outputs loaded and imaging skipped.
subjects = list(
    Subject(base_dir, participant, include_modeling=True, include_imaging=False)
    for participant in sub_ids
)



In [4]:
# Stack every subject's extended trial table into one frame indexed by
# (sub_id, original per-subject row label).
trials = (
    pd.concat([s.extended_trials.assign(sub_id=s.sub_id) for s in subjects])
    .set_index('sub_id', append=True)
    .reorder_levels(['sub_id', None])
)

In [5]:
# Net choice value of the first- and second-presented stimulus:
# beta_rl20 * RL value + beta_ck20 * `_ck` value (presumably a choice-kernel
# term -- confirm against the model definition).
for _slot in ('first', 'second'):
    trials[f'{_slot}_stim_choice_val'] = (
        trials['beta_rl20'].mul(trials[f'{_slot}_stim_value_rl'])
        + trials['beta_ck20'].mul(trials[f'{_slot}_stim_value_ck'])
    )
del _slot

In [6]:
def _stim_value_lookup(df, stim_col, kind):
    """Per-row lookup of column ``stim{k}_value_{kind}`` where ``k = df[stim_col]``.

    Vectorised replacement for a row-wise ``apply``. The distinct stimulus ids
    are factorised once and the matching value columns are pulled into one
    array. Each row's entry is then picked by fancy indexing.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame containing ``stim_col`` and the ``stim{k}_value_{kind}`` columns.
    stim_col : str
        Column holding the stimulus id shown in each row (e.g. ``'left_stim'``).
    kind : str
        Value-column suffix, e.g. ``'rl'`` or ``'ck'``.

    Returns
    -------
    pandas.Series
        The looked-up value for each row, aligned to ``df.index``.

    Raises
    ------
    KeyError
        If ``df[stim_col]`` has missing ids, or an id has no matching
        ``stim{k}_value_{kind}`` column.
    """
    codes, stim_ids = pd.factorize(df[stim_col])
    if (codes < 0).any():
        # factorize encodes missing values as -1, which would silently pick the
        # last column; fail loudly instead, as the row-wise lookup did.
        raise KeyError(f'{stim_col} contains missing stimulus ids')
    value_cols = [f'stim{stim}_value_{kind}' for stim in stim_ids]
    values = df[value_cols].to_numpy()
    return pd.Series(values[np.arange(len(df)), codes], index=df.index)


# Net choice value of the stimulus shown on each side: beta-weighted RL value
# plus beta-weighted `_ck` value of whichever stimulus id is on that side.
for _side in ('left', 'right'):
    trials[f'{_side}_stim_choice_val'] = (
        trials['beta_rl20'] * _stim_value_lookup(trials, f'{_side}_stim', 'rl')
        + trials['beta_ck20'] * _stim_value_lookup(trials, f'{_side}_stim', 'ck')
    )
del _side

In [7]:
# NOTE(review): disabled diagnostic plot. It overlays beta_rl20 * RL value,
# beta_ck20 * CK value and first_stim_choice_val for positional rows 200-299
# of the concatenated `trials` frame. Those rows can straddle a subject
# boundary, because `.iloc` ignores the sub_id level. The title says
# "RL vs. Raw", but the traces are the RL term, the CK term and their sum.
# TODO: re-enable with a per-subject slice and a corrected title, or delete
# this cell.
# t = np.arange(200, 300)
# plt.figure(figsize=(12, 6))
# (trials.first_stim_value_rl.iloc[t] * trials.beta_rl20.iloc[t]).plot(label='first_stim_value_rl * beta_rl20')
# (trials.first_stim_value_ck.iloc[t] * trials.beta_ck20.iloc[t]).plot(label='first_stim_value_ck * beta_ck20')
# trials.first_stim_choice_val.iloc[t].plot(label='first_stim_value')
# plt.xlabel('Trial')
# plt.ylabel('Value')
# plt.title('First Stimulus Value (RL vs. Raw)')
# plt.legend()
# plt.grid(True)
# plt.tight_layout()
# plt.show()

In [8]:
def choice_prob(right, left):
    """Softmax choice probabilities for a two-option (right vs. left) choice.

    Parameters
    ----------
    right, left : array-like
        Choice values (logits) of the right and left options. Accepts pandas
        Series, numpy arrays or scalars; they must be broadcast-compatible and
        are combined positionally (any pandas index is ignored).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(2, *broadcast_shape)``. Row 0 is P(choose right) and
        row 1 is P(choose left); the two rows sum to 1 elementwise.

    Notes
    -----
    Both logits are shifted by their elementwise maximum before
    exponentiating. This is mathematically a no-op for the ratio. It stops
    ``exp`` from overflowing to ``inf`` for large choice values, which
    previously turned the probability into ``inf / inf = nan``.
    """
    right, left = np.broadcast_arrays(np.asarray(right), np.asarray(left))
    # Subtracting the larger logit keeps both exponents <= 0, i.e. exp() <= 1.
    shift = np.maximum(right, left)
    exp_right = np.exp(right - shift)
    exp_left = np.exp(left - shift)
    return np.stack([exp_right, exp_left]) / (exp_left + exp_right)


In [9]:
# Recompute choice probabilities from the side-specific choice values:
# row 0 = P(right), row 1 = P(left).
choice_probs = choice_prob(
    right=trials['right_stim_choice_val'],
    left=trials['left_stim_choice_val'],
)

In [10]:
# Row 0 of the (2, n_trials) array: recomputed P(choose right).
choice_probs[0, :]

array([0.78684377, 0.9829903 , 0.55720718, ..., 0.6441164 , 0.92987013,
       0.89573982], shape=(20336,))

In [11]:
# Side-by-side table of stored vs. recomputed choice probabilities.
comp_df = trials[['action', 'choice_prob_left', 'choice_prob_right']].assign(
    calc_choice_prob_left=choice_probs[1],
    calc_choice_prob_right=choice_probs[0],
)

In [12]:
# Drop trials whose `action` is missing (presumably no-response trials).
comp_df = comp_df.dropna(subset=['action'])

In [13]:
# Trials where stored and recomputed P(right) differ by more than 1e-10.
comp_df.loc[lambda d: d['choice_prob_right'].sub(d['calc_choice_prob_right']).abs().gt(1e-10)]

Unnamed: 0_level_0,Unnamed: 1_level_0,action,choice_prob_left,choice_prob_right,calc_choice_prob_left,calc_choice_prob_right
sub_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
sub-14,10,1.0,0.036114,0.963886,0.963886,0.036114
sub-58,276,1.0,0.888124,0.111876,0.999391,0.000609
sub-72,73,1.0,0.112327,0.887673,0.876125,0.123875


In [14]:
# Trials where stored and recomputed P(left) differ by more than 1e-10.
comp_df.loc[lambda d: d['choice_prob_left'].sub(d['calc_choice_prob_left']).abs().gt(1e-10)]

Unnamed: 0_level_0,Unnamed: 1_level_0,action,choice_prob_left,choice_prob_right,calc_choice_prob_left,calc_choice_prob_right
sub_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
sub-14,10,1.0,0.036114,0.963886,0.963886,0.036114
sub-58,276,1.0,0.888124,0.111876,0.999391,0.000609
sub-72,73,1.0,0.112327,0.887673,0.876125,0.123875


In [46]:
# First 15 trials of one subject (the sub_id index level is dropped).
trials.xs('sub-02', level='sub_id').head(15)

Unnamed: 0,block,left_stim,right_stim,left_value,right_value,shift,action,rt,chosen_stim,reward,...,reward1C,stim_diff_prop_trainingS,stim_diff_prop_training_chosenS,diff_valS,diff_val_chosenS,score_EHI,first_stim_choice_val,second_stim_choice_val,left_stim_choice_val,right_stim_choice_val
0,learning1,7,5,4,3,1,1.0,0.615436,7.0,4.0,...,,,,,,,6.859204,3.070175,6.859204,3.070175
1,learning1,5,7,3,4,0,2.0,0.478543,7.0,4.0,...,,,,,,,11.785468,8.823197,8.823197,11.785468
2,learning1,7,5,4,3,1,1.0,0.430684,7.0,4.0,...,,,,,,,11.806611,8.823197,11.806611,8.823197
3,learning1,1,3,1,2,1,1.0,0.434324,1.0,1.0,...,,,,,,,6.15895,1.424579,6.15895,1.424579
4,learning1,7,5,4,3,1,1.0,0.592224,7.0,4.0,...,,,,,,,11.827689,8.823197,11.827689,8.823197
5,learning1,5,3,3,2,0,2.0,0.792481,3.0,2.0,...,,,,,,,5.882131,8.823197,8.823197,5.882131
6,learning1,5,7,3,4,0,2.0,0.443101,7.0,4.0,...,,,,,,,11.848705,8.823197,8.823197,11.848705
7,learning1,2,4,2,3,0,1.0,0.719878,2.0,2.0,...,,,,,,,5.275427,6.827474,6.827474,5.275427
8,learning1,1,2,1,2,1,2.0,0.537363,2.0,2.0,...,,,,,,,2.962272,5.903337,2.962272,5.903337
9,learning1,4,2,3,2,0,1.0,0.510279,4.0,3.0,...,,,,,,,5.924479,8.823197,8.823197,5.924479
