# Notebook containing code to produce the results and stats in the supplement


Running this notebook regenerates the figures and statistics reported in the supplement of the paper.



## Setup

The following cells import the libraries and set the parameters used throughout the rest of the notebook.

They also load the experimental data and the model fits used to generate the figures and statistics.

In [1]:
import pandas as pd
import numpy as np

import pingouin as pg
from psifr import fr
import sys
sys.path.append("../../")
from pymer4 import Lmer, Lm

from src.data.process_strat import *


# Colours for incorrect (red) and correct (green) outcomes.
red, green = '#e41a1c', '#4daf4a'

# Plotting palette ordered [incorrect, correct].
pal = [red, green]

  return warn(


In [2]:
# Load, in order: the trial-level WRIT data with RPEs, the RL model-fit results,
# and the free-recall data.
# NOTE(review): reading the first file emitted a pandas DtypeWarning (mixed
# column types); consider an explicit dtype= or low_memory=False once the
# offending column is identified.
strat_df, rl_df, fr_df = (
    pd.read_csv(path)
    for path in (
        "../../data/processed/DecayFeatureRL_strat_data_rpe.csv",
        "../../data/processed/DecayFeatureRL_no_resp_st_results.csv",
        "../../data/interim/mem_df.csv",
    )
)

  strat_df = pd.read_csv("../../data/processed/DecayFeatureRL_strat_data_rpe.csv")


In [3]:
# Exclude subjects who recalled fewer than MIN_RECALLS_PER_LIST studied words
# on at least one list. With the threshold at 1, this drops anyone who recalled
# nothing on some list.
MIN_RECALLS_PER_LIST = 1

fr_df = fr_df.dropna(subset=['word'])
# Only words that were both studied and recalled count as recalled.
fr_df['recalled'] = fr_df['recall'] & fr_df['study']
recalls = fr_df.groupby(['subject', 'list'])['recalled'].sum()

bad_subs = set(recalls[recalls < MIN_RECALLS_PER_LIST].reset_index()['subject'])
print(len(bad_subs))

# .copy() so later column assignments modify independent frames rather than
# possible views of the pre-filter data (avoids SettingWithCopyWarning).
fr_df = fr_df[~fr_df['subject'].isin(bad_subs)].copy()
strat_df = strat_df[~strat_df['sona_id'].isin(bad_subs)].copy()
rl_df = rl_df[~rl_df['sona_id'].isin(bad_subs)].copy()
strat_df['word'] = strat_df['word'].str.lower()

1


In [4]:
# Save the IDs of every subject retained in strat_df, one per line.
with open('../../data/processed/strat_subs.txt', 'w') as f:
    f.writelines(sub + '\n' for sub in strat_df.sona_id.unique().tolist())
    




In [5]:

# Per-subject RT normalisation: centre each RT on that subject's mean RT and
# divide by that subject's RT standard deviation. rt_shift holds the RT of the
# next trial within the same subject and run.
# NOTE(review): rows with rt == -1 are filtered out later (subset_df) but are
# included in these per-subject mean/SD; confirm this is intended.
rt_by_subject = strat_df.groupby(['sona_id'])['rt']
strat_df['rt_centered'] = strat_df['rt'] - rt_by_subject.transform('mean')
strat_df['rt_z_score'] = strat_df['rt_centered'] / rt_by_subject.transform('std')
strat_df['rt_shift'] = strat_df.groupby(['sona_id', 'run'])['rt'].shift(-1)

In [6]:
# Attach each subject's `eta` from the RL model-fit results. This is an inner
# merge, so only subjects present in both frames are kept.
eta_by_subject = rl_df[['sona_id', 'eta']]
strat_df = strat_df.merge(eta_by_subject, on='sona_id', how='inner')
strat_df.head()

Unnamed: 0.1,Unnamed: 0,index,view_history,rt,trial_type,trial_index,time_elapsed,internal_node_id,run_id,condition,...,item_rule_idx1,inv_item_rule_idx0,inv_item_rule_idx1,rpe,trial_by_trial_loglik,uncertainty,rt_centered,rt_z_score,rt_shift,eta
0,0,19047,,,html-keyboard-response,14,1124352,0.0-3.0-0.0-0.0,132,1,...,1,0,0,,0.0,-0.0,,,1559.0,0.089219
1,1,19051,,1559.0,html-keyboard-response,18,1127922,0.0-3.0-0.1-0.1,132,1,...,1,1,0,0.0,-5.048029,-1.0,501.381166,1.189103,2684.0,0.089219
2,2,19055,,2684.0,html-keyboard-response,22,1134070,0.0-3.0-0.2-0.2,132,1,...,0,0,1,1.0,-0.006443,-1.0,1626.381166,3.857214,1643.0,0.089219
3,3,19059,,1643.0,html-keyboard-response,26,1138050,0.0-3.0-0.3-0.3,132,1,...,0,0,1,0.821561,-0.002625,-0.578259,585.381166,1.388322,1722.0,0.089219
4,4,19063,,1722.0,html-keyboard-response,30,1143147,0.0-3.0-0.4-0.4,132,1,...,1,0,0,-0.162519,-0.006443,-1.0,664.381166,1.575683,628.0,0.089219


In [7]:

# Attach trial-level WRIT variables to each recall row, matching on subject and
# word. A left merge keeps recall rows that have no match in strat_df.
strat_cols = [
    "sona_id", "word", "rpe", "trial_within_block", "correct_rule",
    "rt_z_score", "rt_shift", 'disc', 'item_rule', 'within_across', 'eta',
]
fr_df = fr_df.merge(strat_df[strat_cols], on=['sona_id', 'word'], how="left")

# Label each row by rel_subj_boundary. Any value other than -1/0/1 (including
# NaN) becomes 'Non-Boundary'.
boundary_labels = {0: 'Boundary', 1: 'Post-Boundary', -1: 'Pre-Boundary'}
fr_df['boundary_label'] = fr_df['rel_subj_boundary'].apply(
    lambda rel: boundary_labels.get(rel, 'Non-Boundary')
)

## RT in WRIT

In [8]:
# Build the trial-level frame for the next-trial RT analyses:
#   * drop trials with rt == -1;
#   * add the next trial's RT (z-scored and raw) within the same subject/run;
#   * drop rows whose next-trial z-scored RT is missing (this includes the last
#     trial of every run).
subset_df = strat_df[strat_df['rt'] != -1].reset_index(drop=True)
by_run = subset_df.groupby(['sona_id', 'run'])
subset_df['rt_z_score_shift'] = by_run['rt_z_score'].shift(-1)
subset_df['rt_shift'] = by_run['rt'].shift(-1)
subset_df = subset_df.dropna(subset=['rt_z_score_shift'])
# 1 where trial_within_block == 1, else 0.
subset_df['ispost'] = (subset_df['trial_within_block'] == 1).astype('int64')


def cumulative_count(s):
    """Length of the current run of non-zero values at each position of ``s``.

    The running count goes up by one on every entry that is not equal to 0
    (NaN counts too, since NaN != 0). It resets to 0 on any entry equal to 0.

    Parameters
    ----------
    s : pd.Series
        Values to scan in order (e.g. per-trial points within one subject/run).

    Returns
    -------
    pd.Series
        Integer run lengths, aligned to ``s.index``.
    """
    streaks = []
    streak = 0
    for value in s:
        streak = 0 if value == 0 else streak + 1
        streaks.append(streak)
    return pd.Series(streaks, index=s.index)

# Streak of consecutive non-zero-point trials within each subject/run, reset to
# 0 after every zero-point trial. `transform` keeps the result aligned to
# subset_df's index. `groupby(...).apply` prepends the group keys to the index
# under pandas >= 2.0, which breaks this column assignment.
subset_df['cumulative_points'] = (
    subset_df.groupby(['sona_id', 'run'])['points'].transform(cumulative_count)
)

To preserve the previous behavior, use

	>>> .groupby(..., group_keys=False)


	>>> .groupby(..., group_keys=True)
  subset_df['cumulative_points'] = subset_df.groupby(['sona_id','run'])['points'].apply(cumulative_count)


In [9]:
# 1 where rel_subj_boundary == 1, else 0. (Not used by the model below.)
subset_df['isPost'] = (subset_df['rel_subj_boundary'] == 1).astype('int64')

# Does the current streak of non-zero-point trials predict the next trial's raw RT?
model = Lmer("rt_shift ~ cumulative_points + (1|sona_id) + (1|word)", data=subset_df)
model.fit()



Linear mixed model fit by REML [’lmerMod’]
Formula: rt_shift~cumulative_points+(1|sona_id)+(1|word)

Family: gaussian	 Inference: parametric

Number of observations: 15790	 Groups: {'word': 432.0, 'sona_id': 73.0}

Log-likelihood: -119798.320 	 AIC: 239606.640

Random effects:

                 Name         Var      Std
word      (Intercept)     469.084   21.658
sona_id   (Intercept)   39073.213  197.669
Residual               223684.683  472.953

No random effect correlations specified

Fixed effects:



Unnamed: 0,Estimate,2.5_ci,97.5_ci,SE,DF,T-stat,P-val,Sig
(Intercept),1265.646,1219.228,1312.065,23.683,75.018,53.441,0.0,***
cumulative_points,-27.327,-31.097,-23.557,1.923,15741.579,-14.207,0.0,***


## Within- vs. across-dimensional shifts

In [10]:
# Effects of within- vs across-dimensional shifts on WRIT performance.
# Mean points per subject for each shift type, pivoted to one row per subject.
# The paired t-test then compares the two conditions for the same subject
# explicitly, instead of relying on two separately filtered slices lining up.
# pingouin drops incomplete pairs (NaN) listwise.
gb_df = strat_df.groupby(['sona_id', 'within_across'])['points'].mean().reset_index()
shift_means = gb_df.pivot(index='sona_id', columns='within_across', values='points')
pg.ttest(shift_means['within'], shift_means['across'], paired=True)


Unnamed: 0,T,dof,alternative,p-val,CI95%,cohen-d,BF10,power
T-test,-4.054895,72,two-sided,0.000125,"[-0.05, -0.02]",0.462143,166.123,0.973529


## Recall performance

In [11]:
# Mean recall per subject on zero-point vs one-point trials, pivoted to one row
# per subject so the paired t-test pairs each subject with itself explicitly.
# pingouin drops incomplete pairs (NaN) listwise.
gb_df = fr_df.groupby(['subject', 'points'])['recalled'].mean().reset_index()
recall_by_points = gb_df.pivot(index='subject', columns='points', values='recalled')
pg.ttest(recall_by_points[0], recall_by_points[1], paired=True)

Unnamed: 0,T,dof,alternative,p-val,CI95%,cohen-d,BF10,power
T-test,-2.453467,72,two-sided,0.01657,"[-0.03, -0.0]",0.202827,2.099,0.401365


In [12]:
# Bring each word's within-run streak (cumulative_points) and raw RT in from
# subset_df. This is an inner merge: fr_df keeps only rows with a matching
# (sona_id, word) in subset_df, and later cells use this reduced fr_df.
streak_cols = subset_df[['sona_id', 'word', 'cumulative_points', 'rt']]
fr_df = fr_df.merge(streak_cols, on=['sona_id', 'word'], how='inner')

# Model recall as a function of streak length, restricted to streaks of 0-7.
streak_df = fr_df[fr_df['cumulative_points'].isin(range(8))]
model = Lmer('recalled ~ cumulative_points + (1|sona_id) + (1|word)', data=streak_df, family="binomial")
model.fit()



Linear mixed model fit by maximum likelihood  ['lmerMod']
Formula: recalled~cumulative_points+(1|sona_id)+(1|word)

Family: binomial	 Inference: parametric

Number of observations: 16285	 Groups: {'word': 432.0, 'sona_id': 73.0}

Log-likelihood: -7005.721 	 AIC: 14019.442

Random effects:

                Name    Var    Std
word     (Intercept)  0.197  0.444
sona_id  (Intercept)  0.247  0.497

No random effect correlations specified

Fixed effects:



Unnamed: 0,Estimate,2.5_ci,97.5_ci,SE,OR,OR_2.5_ci,OR_97.5_ci,Prob,Prob_2.5_ci,Prob_97.5_ci,Z-stat,P-val,Sig
(Intercept),-1.859,-1.996,-1.723,0.07,0.156,0.136,0.179,0.135,0.12,0.151,-26.725,0.0,***
cumulative_points,0.054,0.03,0.078,0.012,1.056,1.031,1.081,0.514,0.508,0.519,4.462,0.0,***


In [13]:
# Effects of points and shift type (within vs across) on memory success,
# restricted to rows that have a within_across label.
labelled_df = fr_df.dropna(subset=['within_across'])
model = Lmer("recalled ~ points + within_across + (1|subject)", data=labelled_df, family='binomial')
model.fit()

Linear mixed model fit by maximum likelihood  ['lmerMod']
Formula: recalled~points+within_across+(1|subject)

Family: binomial	 Inference: parametric

Number of observations: 15991	 Groups: {'subject': 73.0}

Log-likelihood: -6987.423 	 AIC: 13982.845

Random effects:

                Name    Var    Std
subject  (Intercept)  0.241  0.491

No random effect correlations specified

Fixed effects:



Unnamed: 0,Estimate,2.5_ci,97.5_ci,SE,OR,OR_2.5_ci,OR_97.5_ci,Prob,Prob_2.5_ci,Prob_97.5_ci,Z-stat,P-val,Sig
(Intercept),-1.708,-1.845,-1.571,0.07,0.181,0.158,0.208,0.153,0.136,0.172,-24.429,0.0,***
points,0.088,-0.001,0.177,0.045,1.092,0.999,1.194,0.522,0.5,0.544,1.938,0.053,.
within_acrosswithin,-0.124,-0.218,-0.03,0.048,0.883,0.804,0.971,0.469,0.446,0.493,-2.582,0.01,**
