# Notebook containing code to produce the results and stats in the supplement


Running this notebook will regenerate the figures and stats in the supplement of the paper.



## Setup

The following blocks load some libraries and set some parameters for the remainder of the notebook.

They also load the data from the experiment and the model fits to generate the figures and stats.

In [1]:
import pandas as pd
import numpy as np

import pingouin as pg
from psifr import fr
import sys
sys.path.append("../../")
from pymer4 import Lmer, Lm

from src.data.process_strat import *


red = '#e41a1c'
green = '#4daf4a'

# make a palette for incorrect and correct
pal = [red,green]

In [3]:
strat_df = pd.read_csv("../../data/processed/exp1/DecayFeatureRL_strat_data_rpe.csv")
rl_df = pd.read_csv("../../data/processed/exp1/DecayFeatureRL_no_resp_st_results.csv")
fr_df = pd.read_csv("../../data/interim/exp1/mem_df.csv")

In [6]:
# get rid of people with fewer than 15 words recalled
fr_df = fr_df.dropna(subset=['word'])
fr_df['recalled'] = fr_df['recall'] & fr_df['study']
recalls = fr_df.groupby(['subject', 'list'])['recalled'].sum()

# Find subjects with less than 5 recalls per list
bad_subs = set(recalls[recalls < 1].reset_index()['subject'])
print(len(bad_subs))
fr_df = fr_df[~fr_df['subject'].isin(bad_subs)]
strat_df = strat_df[~strat_df['subid'].isin(bad_subs)]
rl_df = rl_df[~rl_df['subid'].isin(bad_subs)]
strat_df['word'] = strat_df['word'].str.lower()

0


In [None]:
# with open('../../data/processed/strat_subs.txt', 'w') as f:
#     for sub in strat_df.subid.unique().tolist():
#         f.write(sub + '\n')
    




In [10]:

strat_df['rt_centered'] = strat_df['rt'] - strat_df.groupby(['subid'])['rt'].transform('mean') 
strat_df['rt_z_score'] = strat_df['rt_centered'] / strat_df.groupby(['subid'])['rt'].transform('std')
strat_df['rt_shift'] = strat_df.groupby(['subid','run'])['rt'].shift(-1)

In [11]:
strat_df = strat_df.merge(rl_df[['subid','eta',]],on='subid')
strat_df.head()

Unnamed: 0.1,Unnamed: 0,index,rt,trial_type,trial_index,time_elapsed,internal_node_id,run_id,condition,source_code_version,...,within_across,prev_block_size,disc,rpe,trial_by_trial_loglik,uncertainty,rt_centered,rt_z_score,rt_shift,eta
0,0,19047,,html-keyboard-response,14,1124352,0.0-3.0-0.0-0.0,132,1,7049bc8ab37debb1d4e5dbca0544092c,...,,,False,,0.0,0.0,,,1559.0,0.636722
1,1,19051,1559.0,html-keyboard-response,18,1127922,0.0-3.0-0.1-0.1,132,1,7049bc8ab37debb1d4e5dbca0544092c,...,,,True,0.0,-0.693147,0.0,501.381166,1.189103,2684.0,0.636722
2,2,19055,2684.0,html-keyboard-response,22,1134070,0.0-3.0-0.2-0.2,132,1,7049bc8ab37debb1d4e5dbca0544092c,...,,,False,1.0,-0.693147,0.04822,1626.381166,3.857214,1643.0,0.636722
3,3,19059,1643.0,html-keyboard-response,26,1138050,0.0-3.0-0.3-0.3,132,1,7049bc8ab37debb1d4e5dbca0544092c,...,,,False,-0.273444,-0.040173,0.026052,585.381166,1.388322,1722.0,0.636722
4,4,19063,1722.0,html-keyboard-response,30,1143147,0.0-3.0-0.4-0.4,132,1,7049bc8ab37debb1d4e5dbca0544092c,...,,,True,-0.462614,-0.693147,0.036781,664.381166,1.575683,628.0,0.636722


In [12]:

fr_df = fr_df.merge(strat_df[["subid","word","rpe","trial_within_block","correct_rule","rt_z_score","rt_shift",'disc','item_rule','within_across','eta']], on=['subid','word'], how="left")

boundary_labels = {0:'Boundary',1:'Post-Boundary',-1:'Pre-Boundary'}
fr_df['boundary_label'] = fr_df['rel_subj_boundary'].apply(lambda x: boundary_labels[x] if x in boundary_labels else 'Non-Boundary') 

## RT in WRIT

In [29]:
# does current rpe on trial_t predict rt_z_score on next trial?
subset_df = strat_df.copy()
subset_df = subset_df[subset_df['rt'] != -1].reset_index(drop=True)
subset_df['rt_z_score_shift'] = subset_df.groupby(['subid','run'])['rt_z_score'].shift(-1)
subset_df['rt_shift'] = subset_df.groupby(['subid','run'])['rt'].shift(-1)
subset_df = subset_df.dropna(subset=['rt_z_score_shift'])
# subset_df['boundary_label'] = subset_df['rel_subj_boundary'].apply(lambda x: boundary_labels[x] if x in boundary_labels else 'Non-Boundary')
subset_df['ispost'] = subset_df['trial_within_block'].apply(lambda x: 1 if x == 1 else 0)


def vectorized_cumulative_count(group):
    # Create a mask for reset points (where points == 0)
    resets = group == 0
    
    # Create groups based on cumulative sum of reset points
    # Each time we encounter a 0, we start a new group
    groups = resets.cumsum()
    
    # For each group, create a sequence of numbers, but set to 0 where the original value is 0
    counts = group.groupby(groups).cumcount()
    counts = counts.where(group != 0, 0)
    
    return counts
# Assume df is your DataFrame and 'subject' and 'points' are the column names
subset_df['cumulative_points'] = subset_df.groupby(['subid','run'])['points'].transform(vectorized_cumulative_count)

In [39]:
subset_df['isPost'] = subset_df['rel_subj_boundary'].apply(lambda x: 1 if x ==1 else 0)
model = Lmer("rt_shift ~ cumulative_points + (1|subid)", data=subset_df)
model.fit()



Linear mixed model fit by REML [’lmerMod’]
Formula: rt_shift~cumulative_points+(1|subid)

Family: gaussian	 Inference: parametric

Number of observations: 14274	 Groups: {'subid': 66.0}

Log-likelihood: -108468.210 	 AIC: 216944.421

Random effects:

                 Name         Var      Std
subid     (Intercept)   36173.752  190.194
Residual               229723.549  479.295

No random effect correlations specified

Fixed effects:



  ran_vars = ran_vars.applymap(


Unnamed: 0,Estimate,2.5_ci,97.5_ci,SE,DF,T-stat,P-val,Sig
(Intercept),1268.464,1221.416,1315.511,24.004,67.779,52.843,0.0,***
cumulative_points,-32.624,-36.629,-28.62,2.043,14234.371,-15.968,0.0,***


## Within vs across-dimensional shifts

In [33]:
## Effects of within vs across-dimensional shifts on WRIT performance
gb_df= strat_df.groupby(['subid','within_across'])['points'].mean().reset_index()
pg.ttest(gb_df[gb_df['within_across']=='within']['points'],gb_df[gb_df['within_across']=='across']['points'],paired=True)


Unnamed: 0,T,dof,alternative,p-val,CI95%,cohen-d,BF10,power
T-test,-4.286701,65,two-sided,6.1e-05,"[-0.059679428841140934, -0.021744640402523246]",0.54406,331.8,0.991663


## Recall performance

In [34]:
gb_df = fr_df.groupby(['subject','points'])['recalled'].mean().reset_index()
pg.ttest(gb_df[gb_df['points']==0]['recalled'],gb_df[gb_df['points']==1]['recalled'],paired=True)

Unnamed: 0,T,dof,alternative,p-val,CI95%,cohen-d,BF10,power
T-test,-2.90791,65,two-sided,0.004973,"[-0.03416349475874578, -0.006343494429175171]",0.294758,6.211,0.655121


In [35]:
# merge in cumulative points from subset_df to fr_df 
fr_df = fr_df.merge(subset_df[['subid','word','cumulative_points','rt']],on=['subid','word'])
model = Lmer('recalled ~ cumulative_points + (1|subid) + (1|word)',data = fr_df[fr_df['cumulative_points'].isin(range(8))],family="binomial")
model.fit()



Linear mixed model fit by maximum likelihood  ['lmerMod']
Formula: recalled~cumulative_points+(1|subid)+(1|word)

Family: binomial	 Inference: parametric

Number of observations: 14735	 Groups: {'word': 432.0, 'subid': 66.0}

Log-likelihood: -6335.039 	 AIC: 12678.078

Random effects:

              Name    Var    Std
word   (Intercept)  0.214  0.463
subid  (Intercept)  0.199  0.446

No random effect correlations specified

Fixed effects:



  ran_vars = ran_vars.applymap(


Unnamed: 0,Estimate,2.5_ci,97.5_ci,SE,OR,OR_2.5_ci,OR_97.5_ci,Prob,Prob_2.5_ci,Prob_97.5_ci,Z-stat,P-val,Sig
(Intercept),-1.887,-2.02,-1.754,0.068,0.152,0.133,0.173,0.132,0.117,0.148,-27.738,0.0,***
cumulative_points,0.066,0.041,0.09,0.013,1.068,1.042,1.095,0.516,0.51,0.523,5.23,0.0,***


In [36]:
# effects on memory success

model = Lmer("recalled ~ points + within_across + (1|subject)", data=fr_df[~fr_df['within_across'].isna()], family='binomial')
model.fit()

Linear mixed model fit by maximum likelihood  ['lmerMod']
Formula: recalled~points+within_across+(1|subject)

Family: binomial	 Inference: parametric

Number of observations: 14464	 Groups: {'subject': 66.0}

Log-likelihood: -6308.614 	 AIC: 12625.227

Random effects:

                Name    Var    Std
subject  (Intercept)  0.199  0.446

No random effect correlations specified

Fixed effects:



  ran_vars = ran_vars.applymap(


Unnamed: 0,Estimate,2.5_ci,97.5_ci,SE,OR,OR_2.5_ci,OR_97.5_ci,Prob,Prob_2.5_ci,Prob_97.5_ci,Z-stat,P-val,Sig
(Intercept),-1.737,-1.874,-1.601,0.07,0.176,0.154,0.202,0.15,0.133,0.168,-24.911,0.0,***
points,0.131,0.036,0.226,0.048,1.14,1.036,1.253,0.533,0.509,0.556,2.7,0.007,**
within_acrosswithin,-0.137,-0.235,-0.039,0.05,0.872,0.791,0.962,0.466,0.442,0.49,-2.732,0.006,**
