In [2]:
import pandas as pd
import numpy as np
from scipy import stats
from pathlib import Path
import matplotlib.pyplot as plt
import holoviews as hv
from holoviews import opts
import hvplot.pandas
import panel as pn
from bokeh.io import output_notebook

output_notebook()
hv.extension('bokeh')
pn.extension('bokeh')

font_dict = {'title': 16, 'labels': 14, 'ticks': 12, 'legend': 12}
hv.opts.defaults(
    hv.opts.Curve(width=600, height=400, tools=['hover'], fontsize=font_dict),
    hv.opts.Scatter(width=600, height=400, size=8, tools=['hover'], fontsize=font_dict),
    hv.opts.Histogram(width=600, height=400, fontsize=font_dict),
    hv.opts.Bars(width=600, height=400, fontsize=font_dict),
)




In [3]:
monkey = 'fiona' # 'yasmin'  or 'fiona' 
base_path = Path.cwd().parent / 'data' / f'{monkey}_sst'
# filepath = base_path.parent / 'csst_trials_pkls' / f'{monkey}_csst_trials_df.pkl'
filepath = base_path.parent / 'csst_trials_pkls' / f'all_{monkey}_CSST_trials_df.pkl'
df = pd.read_pickle(filepath)

In [4]:
print(df.info())
# df.iloc[:2]
df.head()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 110358 entries, 0 to 110357
Data columns (total 28 columns):
 #   Column                  Non-Null Count   Dtype  
---  ------                  --------------   -----  
 0   blinks                  21339 non-null   object 
 1   dir                     110358 non-null  int64  
 2   direction               110358 non-null  object 
 3   filename                110358 non-null  object 
 4   first_relevant_saccade  103093 non-null  object 
 5   flags                   110358 non-null  int64  
 6   go_cue                  110358 non-null  int64  
 7   hPos                    110358 non-null  object 
 8   hVel                    110358 non-null  object 
 9   neural_data             102343 non-null  object 
 10  reaction_time           103093 non-null  float64
 11  saccades                110206 non-null  object 
 12  screen_rotation         110358 non-null  float64
 13  segs_durations          110358 non-null  object 
 14  segs_times          

Unnamed: 0,blinks,dir,direction,filename,first_relevant_saccade,flags,go_cue,hPos,hVel,neural_data,...,ssd_number,stop_cue,trial_failed,trial_length,trial_name,trial_number,trial_session,type,vPos,vVel
0,,180,L,fi210824a.0614,"[1382, 1457]",8206,1054,"[11.275, 11.275, 11.275, 11.275, 11.275, 11.27...","[0.0, 0.0, 0.4594490287247533, 1.3783470861742...","{0: [884.4], 1: [154.42, 329.17999999999995, 1...",...,2.0,1186.0,False,2205,CONT_L_SSD2,614,fi210824a,CONT,"[-0.05, -0.05, -0.05, -0.05, -0.05, -0.05, 0.0...","[-3.1242533953283225, -3.1242533953283225, -4...."
1,,180,L,fi210824a.0520,"[1100, 1172]",8206,914,"[-11.15, -11.15, -11.15, -11.15, -11.15, -11.1...","[-2.7566941723485194, -2.7566941723485194, -3....","{1: [48.52, 264.55, 585.97, 1032.3], 2: [385.2...",...,,,False,2065,GO_L,520,fi210824a,GO,"[-1.15, -1.15, -1.15, -1.15, -1.15, -1.175, -1...","[-0.8270082517045559, -0.8270082517045559, 0.1..."
2,,180,L,fi210824a.1193,"[1099, 1179]",8206,938,"[2.25, 2.25, 2.3, 2.3, 2.3, 2.25, 2.225, 2.225...","[-188.19032216565896, -188.19032216565896, -18...","{1: [574.7, 853.02, 1403.57], 3: [8.35, 301.92...",...,3.0,1118.0,False,2089,CONT_L_SSD3,1193,fi210824a,CONT,"[-27.45, -27.45, -27.45, -27.45, -27.45, -27.0...","[0.0, 0.0, 0.0, 0.0, 0.0, 9.280870380240016, 4..."
3,,180,L,fi210824a.1013,"[1213, 1289]",13326,1081,"[8.425, 8.425, 8.375, 8.375, 8.375, 8.4, 8.4, ...","[1.286457280429309, 1.286457280429309, -0.9188...","{2: [1954.73], 5: [434.05, 484.13, 547.7, 851....",...,3.0,1261.0,True,1961,STOP_L_SSD3,1013,fi210824a,STOP,"[0.275, 0.275, 0.275, 0.275, 0.275, 0.275, 0.2...","[0.9188980574495066, 0.9188980574495066, 0.0, ..."
4,,0,R,fi210824a.1257,,8194,1024,"[-11.475, -11.475, -11.475, -11.475, -11.475, ...","[3.767482035542977, 3.767482035542977, 3.85937...","{0: [923.35], 1: [1125.0], 5: [1026.4, 1359.92...",...,2.0,1156.0,True,1476,CONT_R_SSD2,1257,fi210824a,CONT,"[-1.825, -1.825, -1.8, -1.8, -1.825, -1.825, -...","[0.27566941723485194, 0.27566941723485194, 1.0..."


In [5]:
# remove trials with no relevant saccade
# df = df[~pd.isna(df['first_relevant_saccade'])]

# remove trials with blinks
# df = df[pd.isna(df['blinks'])]

# print(f"Total trials after cleaning: {len(df)}")

In [6]:
# def check_flag_consistency(row):
#     bit = 2 if (row['type'] != 'STOP') else 11
#     return not bool(row['flags'] & (1 << bit))

# df['trial_failed'] = df.apply(check_flag_consistency, axis=1)

In [7]:
# Neural Population Analysis - Behavioral Exploratory Data Analysis

## Cell 1: Data Loading and Basic Info
print(f"Data loaded for {monkey}")
print(f"Total trials: {len(df):,}")
print(f"Date range: {df['trial_session'].str[:8].min()} to {df['trial_session'].str[:8].max()}")


Data loaded for fiona
Total trials: 110,358
Date range: fi210628 to fi211125


In [8]:
## Cell 2: Basic Data Overview

# Basic overview of behavioral data
print("=== BEHAVIORAL DATA OVERVIEW ===")
print(f"\nTrial Types:")
print(df['type'].value_counts().sort_index())

print(f"\nDirections:")
print(df['direction'].value_counts())

print(f"\nTrial Outcomes:")
success_rate = (1 - df['trial_failed'].mean()) * 100
print(f"Overall Success Rate: {success_rate:.1f}%")
print(df['trial_failed'].value_counts())

print(f"\nExperimental Set:")
print(df['set'].value_counts())




=== BEHAVIORAL DATA OVERVIEW ===

Trial Types:
type
CONT    25294
GO      61460
STOP    23604
Name: count, dtype: int64

Directions:
direction
R    55453
L    54905
Name: count, dtype: int64

Trial Outcomes:
Overall Success Rate: 84.6%
trial_failed
False    93361
True     16997
Name: count, dtype: int64

Experimental Set:
set
CSST    110358
Name: count, dtype: int64


In [9]:
## Cell 3: Trial Type Distribution and Success Rates

# Create summary statistics for plotting
trial_summary = df.groupby(['type', 'trial_failed']).size().reset_index(name='count')
trial_summary['outcome'] = trial_summary['trial_failed'].map({False: 'Success', True: 'Failed'})

# Calculate success rates by trial type
success_rates = df.groupby('type').agg({
    'trial_failed': ['count', 'sum', 'mean']
}).round(3)
success_rates.columns = ['total_trials', 'failed_trials', 'failure_rate']
success_rates['success_rate'] = (1 - success_rates['failure_rate']) * 100
success_rates['failure_rate'] *= 100
print("Success rates by trial type:")
print(success_rates)

# Create the main visualization
plot1 = trial_summary.hvplot.bar(
    x='type', y='count', by='outcome',
    stacked=True,
    title=f'{monkey.title()} - Trial Distribution by Type and Outcome',
    xlabel='Trial Type',
    ylabel='Number of Trials',
    width=600, height=400,
    color=['#2E8B57', '#CD5C5C'],  # Green for success, red for failed
    legend='top_right'
)

plot1.opts(
    fontsize={'title': 16, 'labels': 14, 'ticks': 12, 'legend': 12},
)
# success_rates

Success rates by trial type:
      total_trials  failed_trials  failure_rate  success_rate
type                                                         
CONT         25294           3933          15.5          84.5
GO           61460           2884           4.7          95.3
STOP         23604          10180          43.1          56.9


In [10]:
## Cell 4: Success Rates by Trial Type (Percentage View)
# Create percentage view of success rates
trial_pct = df.groupby('type').apply(
    lambda x: pd.Series({
        'Success': (1 - x['trial_failed'].mean()) * 100,
        'Failed': x['trial_failed'].mean() * 100
    })
).reset_index()

trial_pct_melted = trial_pct.melt(id_vars='type', var_name='outcome', value_name='percentage')

plot2 = trial_pct_melted.hvplot.bar(
    x='type', y='percentage', by='outcome',
    stacked=True,
    title=f'{monkey.title()} - Success Rate by Trial Type (%)',
    xlabel='Trial Type',
    ylabel='Percentage of Trials',
    width=600, height=400,
    color=['#2E8B57', '#CD5C5C'],
    legend='top',
    ylim=(0, 100)
)

plot2
# trial_pct
# trial_pct_melted

  trial_pct = df.groupby('type').apply(


In [11]:
trial_pct = df.groupby(['type', 'ssd_number']).apply(
    lambda x: pd.Series({
        'Success': (1 - x['trial_failed'].mean()) * 100,
        'Failed': x['trial_failed'].mean() * 100
    })
).reset_index()

trial_pct_melted = trial_pct.melt(id_vars=['ssd_number', 'type'], var_name='outcome', value_name='percentage')

cont_success_rates = trial_pct_melted[trial_pct_melted['type'] == 'CONT'].groupby('ssd_number').apply(
    lambda x: pd.Series({
        'success_rate': x.loc[x['outcome'] == 'Success', 'percentage'].mean()
})).reset_index()

stop_failure_rates = trial_pct_melted[trial_pct_melted['type'] == 'STOP'].groupby('ssd_number').apply(
    lambda x: pd.Series({
        'failure_rate': x.loc[x['outcome'] == 'Failed', 'percentage'].mean()
})).reset_index()

cont_success_rate_plot = cont_success_rates.hvplot.line(
    x='ssd_number', y='success_rate',
    title=f'{monkey.title()} - Success Rate by Trial Type (%)',
    xlabel='Trial Type',
    ylabel='Percentage of Trials',
    width=700, height=400,
    color=["#134FE6", '#CD5C5C'],
    legend='top',
    label='Correct continue',
)

stop_failure_rates_plot = stop_failure_rates.hvplot.line(
    x='ssd_number', y='failure_rate',
    color=["#CD5C5C"],  # Red for failed        
    label='Failed stop',
    marker='o',
    ylim=(0, 100),
    xlim=(0, 5),
)

(cont_success_rate_plot * stop_failure_rates_plot)
# trial_pct_melted
# cont_success_rates


  trial_pct = df.groupby(['type', 'ssd_number']).apply(
  cont_success_rates = trial_pct_melted[trial_pct_melted['type'] == 'CONT'].groupby('ssd_number').apply(
  stop_failure_rates = trial_pct_melted[trial_pct_melted['type'] == 'STOP'].groupby('ssd_number').apply(


In [12]:
trial_pct = df.groupby(['type', 'ssd_len']).apply(
    lambda x: pd.Series({
        'Success': (1 - x['trial_failed'].mean()) * 100,
        'Failed': x['trial_failed'].mean() * 100
    })
).reset_index()

trial_pct_melted = trial_pct.melt(id_vars=['ssd_len', 'type'], var_name='outcome', value_name='percentage')

cont_success_rates = trial_pct_melted[trial_pct_melted['type'] == 'CONT'].groupby('ssd_len').apply(
    lambda x: pd.Series({
        'success_rate': x.loc[x['outcome'] == 'Success', 'percentage'].mean()
})).reset_index()

stop_failure_rates = trial_pct_melted[trial_pct_melted['type'] == 'STOP'].groupby('ssd_len').apply(
    lambda x: pd.Series({
        'failure_rate': x.loc[x['outcome'] == 'Failed', 'percentage'].mean()
})).reset_index()

cont_success_rate_plot = cont_success_rates.hvplot.line(
    x='ssd_len', y='success_rate',
    title=f'{monkey.title()} - Success Rate by Trial Type (%)',
    xlabel='Trial Type',
    ylabel='Percentage of Trials',
    width=700, height=400,
    color=["#134FE6", '#CD5C5C'],
    legend='top',
    label='Correct continue',
)

stop_failure_rates_plot = stop_failure_rates.hvplot.line(
    x='ssd_len', y='failure_rate',
    color=["#CD5C5C"],  # Red for failed        
    label='Failed stop',
    marker='o',
    # ylim=(0, 100),
    # xlim=(0, 5),
)

(cont_success_rate_plot * stop_failure_rates_plot)
trial_pct_melted
# cont_success_rates


  trial_pct = df.groupby(['type', 'ssd_len']).apply(
  cont_success_rates = trial_pct_melted[trial_pct_melted['type'] == 'CONT'].groupby('ssd_len').apply(
  stop_failure_rates = trial_pct_melted[trial_pct_melted['type'] == 'STOP'].groupby('ssd_len').apply(


Unnamed: 0,ssd_len,type,outcome,percentage
0,24,CONT,Success,85.365854
1,48,CONT,Success,84.437350
2,72,CONT,Success,92.857143
3,84,CONT,Success,87.846890
4,108,CONT,Success,89.140271
...,...,...,...,...
65,180,STOP,Failed,60.940325
66,192,STOP,Failed,43.949045
67,204,STOP,Failed,68.201754
68,228,STOP,Failed,85.412631


In [13]:

## Cell 5: Direction and Trial Type Interaction
# Analyze direction effects across trial types
direction_summary = df.groupby(['type', 'direction', 'trial_failed']).size().reset_index(name='count')
direction_summary['outcome'] = direction_summary['trial_failed'].map({False: 'Success', True: 'Failed'})

# Success rates by type and direction
dir_success = df.groupby(['type', 'direction']).agg({
    'trial_failed': ['count', 'mean']
}).round(3)
dir_success.columns = ['total_trials', 'failure_rate']
dir_success['success_rate'] = (1 - dir_success['failure_rate']) * 100
dir_success = dir_success.reset_index()

print("Success rates by trial type and direction:")
print(dir_success.pivot(index='type', columns='direction', values='success_rate'))

# Visualization
plot3 = direction_summary[direction_summary['outcome'] == 'Success'].hvplot.bar(
    x='type', y='count', by='direction',
    title=f'{monkey.title()} - Successful Trials by Type and Direction',
    xlabel='Trial Type',
    ylabel='Number of Successful Trials',
    width=600, height=400,
    legend='top_right'
)

plot3


Success rates by trial type and direction:
direction     L     R
type                 
CONT       89.6  79.6
GO         95.1  95.6
STOP       50.8  63.0


In [14]:
## Cell 6: Trial Length Distribution
# Analyze trial length distributions
plot4 = df.hvplot.hist(
    y='trial_length', by='type',
    bins=50, alpha=0.7,
    title=f'{monkey.title()} - Trial Length Distribution by Type',
    xlabel='Trial Length (ms)',
    ylabel='Frequency',
    width=800, height=400,
    legend='top_right'
)

# Add summary statistics
length_stats = df.groupby('type')['trial_length'].describe()
print("Trial length statistics by type (ms):")
print(length_stats.round(1))

plot4


Trial length statistics by type (ms):
        count    mean    std     min     25%     50%     75%     max
type                                                                
CONT  25294.0  2054.4  243.8  1312.0  2076.0  2134.0  2191.0  2251.0
GO    61460.0  2131.4  119.4  1184.0  2093.0  2145.0  2197.0  2251.0
STOP  23604.0  1816.6  107.4   956.0  1749.0  1823.0  1890.0  2099.0


In [15]:
## Cell 7: Go Cue Timing Analysis
# Analyze go cue timing
plot5 = df.hvplot.hist(
    y='go_cue', by='type',
    bins=50, alpha=0.7,
    title=f'{monkey.title()} - Go Cue Timing Distribution by Type',
    xlabel='Go Cue Time (ms)',
    ylabel='Frequency',
    width=800, height=400,
    legend='top_right'
)

# Summary statistics
go_cue_stats = df.groupby('type')['go_cue'].describe()
print("Go cue timing statistics by type (ms):")
print(go_cue_stats.round(1))

plot5


Go cue timing statistics by type (ms):
        count    mean   std    min    25%     50%     75%     max
type                                                             
CONT  25294.0   999.7  57.6  900.0  950.0   999.0  1049.0  1100.0
GO    61460.0   999.5  58.0  900.0  949.0  1000.0  1049.0  1100.0
STOP  23604.0  1000.0  58.1  900.0  950.0  1000.0  1051.0  1100.0


In [16]:
## Cell 8: Stop/Continue Signal Delay Analysis
# Filter for STOP and CONTINUE trials only (these have SSD parameters)
signal_trials = df[df['type'].isin(['STOP', 'CONT'])].copy()

print("=== STOP/CONTINUE SIGNAL DELAY ANALYSIS ===")
print(f"\nTotal STOP trials: {len(signal_trials[signal_trials['type'] == 'STOP']):,}")
print(f"Total CONTINUE trials: {len(signal_trials[signal_trials['type'] == 'CONT']):,}")

# Check for missing values
print(f"\nMissing ssd_number: {signal_trials['ssd_number'].isna().sum():,}")
print(f"Missing ssd_len: {signal_trials['ssd_len'].isna().sum():,}")

# Overview of SSD conditions
print(f"\nUnique SSD numbers: {sorted(signal_trials['ssd_number'].dropna().unique())}")
print(f"SSD length range: {signal_trials['ssd_len'].min():.0f} - {signal_trials['ssd_len'].max():.0f} ms")

# SSD number distribution by trial type
ssd_dist = signal_trials.groupby(['type', 'ssd_number']).size().reset_index(name='count')
print(f"\nSSD number distribution by trial type:")
print(ssd_dist.pivot(index='ssd_number', columns='type', values='count').fillna(0))



=== STOP/CONTINUE SIGNAL DELAY ANALYSIS ===

Total STOP trials: 23,604
Total CONTINUE trials: 25,294

Missing ssd_number: 0
Missing ssd_len: 0

Unique SSD numbers: [np.float64(1.0), np.float64(2.0), np.float64(3.0), np.float64(4.0)]
SSD length range: 24 - 252 ms

SSD number distribution by trial type:
type        CONT  STOP
ssd_number            
1.0         6328  5882
2.0         6189  5890
3.0         6464  5927
4.0         6313  5905


In [17]:
## Cell 9: SSD Length Distribution by Trial Type and SSD Number
# Create a combined grouping variable for better visualization
signal_trials['type_ssd'] = signal_trials['type'] + '_SSD' + signal_trials['ssd_number'].astype(str)

# Plot SSD length distributions with different colors for each SSD number
plot6 = signal_trials.hvplot.hist(
    y='ssd_len', by='type_ssd',
    bins=30, alpha=0.7,
    title=f'{monkey.title()} - Stop/Continue Signal Delay Distribution by SSD Number',
    xlabel='Signal Delay Length (ms)',
    ylabel='Frequency',
    width=800, height=400,
    legend='top_right'
)

# Summary statistics by trial type and SSD number
ssd_detailed_stats = signal_trials.groupby(['type', 'ssd_number'])['ssd_len'].describe()
print("SSD length statistics by trial type and SSD number (ms):")
print(ssd_detailed_stats.round(1))

# Show the unique SSD length for each SSD number (should be consistent)
ssd_mapping = signal_trials.groupby(['type', 'ssd_number'])['ssd_len'].unique()
print(f"\nSSD number to length mapping:")
for (trial_type, ssd_num), lengths in ssd_mapping.items():
    print(f"{trial_type} SSD{ssd_num}: {lengths}")

plot6


SSD length statistics by trial type and SSD number (ms):
                  count   mean   std    min    25%    50%    75%    max
type ssd_number                                                        
CONT 1.0         6328.0   49.7  13.5   24.0   48.0   48.0   48.0   84.0
     2.0         6189.0  108.1  11.9   72.0  108.0  108.0  108.0  132.0
     3.0         6464.0  167.3  10.4  120.0  168.0  168.0  168.0  192.0
     4.0         6313.0  225.4  11.6  168.0  228.0  228.0  228.0  252.0
STOP 1.0         5882.0   49.5  13.6   24.0   48.0   48.0   48.0   84.0
     2.0         5890.0  108.2  12.1   72.0  108.0  108.0  108.0  132.0
     3.0         5927.0  166.9  11.0  120.0  168.0  168.0  168.0  192.0
     4.0         5905.0  225.3  11.8  168.0  228.0  228.0  228.0  252.0

SSD number to length mapping:
CONT SSD1.0: [84 48 24 72]
CONT SSD2.0: [132 108  72  84]
CONT SSD3.0: [180 168 120 144 192]
CONT SSD4.0: [228 168 204 252]
STOP SSD1.0: [84 48 24 72]
STOP SSD2.0: [132 108  72  84]
STOP SSD3.

In [18]:
## Cell 10: SSD Number vs Length Relationship - Violin Plots
# Examine relationship between ssd_number and ssd_len
ssts_df = df[df['type'].isin(['STOP', 'CONT'])].copy()

plot7 = hv.Violin(
    ssts_df, kdims=['ssd_number', 'type'], vdims='ssd_len'
).opts(
    opts.Violin(
        show_legend=True, height=400, width=600,
        violin_color=hv.dim('type').str(),
        legend_position='top_left',
        split='type',
        title=f'{monkey.title()} - SSD/CSD Length Distribution by Number and Type',
        xlabel='number',
        ylabel='SSD/CSD duration (ms)',
        show_grid=True,
        violin_width=2,
        invert_axes=True,
        tools=['hover'],
    ),
    # tools=['hover'],
)
plot7

In [19]:
## Cell 11: Success Rate by SSD Condition
# Analyze success rates across different SSD conditions
success_by_ssd = signal_trials.groupby(['type', 'ssd_number']).agg({
    'trial_failed': ['count', 'sum', 'mean']
}).round(3)
success_by_ssd.columns = ['total_trials', 'failed_trials', 'failure_rate']
success_by_ssd['success_rate'] = (1 - success_by_ssd['failure_rate']) * 100

print("Success rates by trial type and SSD number:")
print(success_by_ssd)

# Plot success rates as bar plot
success_plot_data = success_by_ssd.reset_index()
plot8 = success_plot_data.hvplot.bar(
    x='ssd_number', y='success_rate', by='type',
    title=f'{monkey.title()} - Success Rate by SSD/CSD Number',
    xlabel='SSD/CSD Number',
    ylabel='Success Rate (%)',
    width=700, height=400,
    alpha=0.8,
    legend='top_right'
)

plot8

Success rates by trial type and SSD number:
                 total_trials  failed_trials  failure_rate  success_rate
type ssd_number                                                         
CONT 1.0                 6328            949         0.150          85.0
     2.0                 6189            688         0.111          88.9
     3.0                 6464           1289         0.199          80.1
     4.0                 6313           1007         0.160          84.0
STOP 1.0                 5882            606         0.103          89.7
     2.0                 5890           1566         0.266          73.4
     3.0                 5927           3119         0.526          47.4
     4.0                 5905           4889         0.828          17.2


In [20]:
## Cell 12: SSD Length vs Performance
# Analyze how SSD length affects performance
# Create SSD length bins for analysis
signal_trials['ssd_bin'] = pd.cut(signal_trials['ssd_len'], bins=8, precision=0)

perf_by_length = signal_trials.groupby(['type', 'ssd_bin']).agg({
    'trial_failed': ['count', 'mean']
}).round(3)
perf_by_length.columns = ['trial_count', 'failure_rate']
perf_by_length['success_rate'] = (1 - perf_by_length['failure_rate']) * 100
perf_by_length = perf_by_length.reset_index()

print("Performance by SSD length bins:")
print(perf_by_length[perf_by_length['trial_count'] >= 10])  # Only show bins with sufficient trials

# Aggregate success rates by SSD length for bar plot
success_by_length = signal_trials.groupby(['type', 'ssd_len']).agg({
    'trial_failed': ['count', 'mean']
}).round(3)
success_by_length.columns = ['trial_count', 'failure_rate']
success_by_length['success_rate'] = (1 - success_by_length['failure_rate']) * 100
success_by_length = success_by_length.reset_index()

# Plot as bar chart
plot9 = success_by_length.hvplot.bar(
    x='ssd_len', y='success_rate', 
    by='type',
    title=f'{monkey.title()} - Success Rate by SSD Length',
    xlabel='SSD Length (ms)',
    ylabel='Success Rate (%)',
    width=900, height=400,
    alpha=0.8,
    legend='top_right',
    rot=90,
)

plot9

Performance by SSD length bins:
    type         ssd_bin  trial_count  failure_rate  success_rate
0   CONT    (24.0, 52.0]         5586         0.155          84.5
1   CONT    (52.0, 81.0]          308         0.071          92.9
2   CONT   (81.0, 110.0]         5907         0.111          88.9
3   CONT  (110.0, 138.0]          834         0.133          86.7
4   CONT  (138.0, 166.0]          428         0.096          90.4
5   CONT  (166.0, 195.0]         6065         0.207          79.3
6   CONT  (195.0, 224.0]          470         0.149          85.1
7   CONT  (224.0, 252.0]         5696         0.161          83.9
8   STOP    (24.0, 52.0]         5197         0.094          90.6
9   STOP    (52.0, 81.0]          316         0.149          85.1
10  STOP   (81.0, 110.0]         5549         0.255          74.5
11  STOP  (110.0, 138.0]          838         0.321          67.9
12  STOP  (138.0, 166.0]          451         0.388          61.2
13  STOP  (166.0, 195.0]         5493       

  perf_by_length = signal_trials.groupby(['type', 'ssd_bin']).agg({


In [21]:
# Cell 13: Compute Reaction Time Measures
import ast

print("=== COMPUTING REACTION TIME MEASURES ===")

# Create working copy of the dataframe
df_rt = df.copy()

# Helper function to safely extract saccade start time
def extract_saccade_start(saccade_array):
    """Extract saccade start time from first_relevant_saccade array"""
    
    # Handle None values
    if saccade_array is None:
        return np.nan
    
    # Handle numpy arrays directly
    if isinstance(saccade_array, np.ndarray):
        if len(saccade_array) >= 2:
            return float(saccade_array[0])
        else:
            return np.nan
    
    # Handle lists and tuples
    if isinstance(saccade_array, (list, tuple)):
        if len(saccade_array) >= 2:
            return float(saccade_array[0])
        else:
            return np.nan
    
    # Handle string representations
    if isinstance(saccade_array, str):
        saccade_array = saccade_array.strip()
        if saccade_array == '' or saccade_array == 'nan':
            return np.nan
        try:
            # Try to evaluate string representation of array
            saccade_data = ast.literal_eval(saccade_array)
            if isinstance(saccade_data, (list, tuple)) and len(saccade_data) >= 2:
                return float(saccade_data[0])
            else:
                return np.nan
        except:
            return np.nan
    
    # Handle pandas NA/NaN values
    try:
        if pd.isna(saccade_array):
            return np.nan
    except:
        pass
    
    # If we get here, we couldn't parse it
    return np.nan

# Let's first explore what types we have in the first_relevant_saccade column
# print("Exploring first_relevant_saccade column types...")
# sample_values = df_rt['first_relevant_saccade'].dropna().iloc[:10]
# print("Sample values and their types:")
# for i, val in enumerate(sample_values):
#     print(f"  {i}: {type(val)} - {str(val)[:100]}...")

# Extract saccade start times
print("\nExtracting saccade start times...")
df_rt['saccade_start'] = df_rt['first_relevant_saccade'].apply(extract_saccade_start)

# Initialize RT columns
df_rt['computed_rt'] = np.nan
df_rt['rt_type'] = ''
df_rt['signal_delay'] = np.nan

# Calculate RTs for each trial
print("Computing reaction times by trial type...")

for idx, row in df_rt.iterrows():
    if pd.notna(row['saccade_start']) and pd.notna(row['go_cue']):
        rt = row['saccade_start'] - row['go_cue']
        
        # Only consider positive RTs (saccade after go cue)
        if rt > 0:
            df_rt.loc[idx, 'computed_rt'] = rt
            
            # Classify RT type based on trial type and outcome
            if row['type'] == 'GO' and not row['trial_failed']:
                df_rt.loc[idx, 'rt_type'] = 'GO_RT'
            elif row['type'] == 'STOP' and row['trial_failed']:
                df_rt.loc[idx, 'rt_type'] = 'Error_Stop_RT'
            elif row['type'] == 'CONT':
                df_rt.loc[idx, 'rt_type'] = 'Continue_RT'
            else:
                df_rt.loc[idx, 'rt_type'] = 'Other'
    
    # Calculate signal delays for STOP and CONT trials
    if row['type'] in ['STOP', 'CONT'] and pd.notna(row['stop_cue']) and pd.notna(row['go_cue']):
        signal_delay = row['stop_cue'] - row['go_cue']
        if signal_delay > 0:  # Ensure signal came after go cue
            df_rt.loc[idx, 'signal_delay'] = signal_delay

# Summary statistics
print("\n=== RT COMPUTATION SUMMARY ===")
rt_summary = df_rt['rt_type'].value_counts()
print("RT types computed:")
print(rt_summary)

print(f"\nValid saccade start times: {df_rt['saccade_start'].notna().sum():,} / {len(df_rt):,}")
print(f"Valid computed RTs: {df_rt['computed_rt'].notna().sum():,}")
print(f"Valid signal delays: {df_rt['signal_delay'].notna().sum():,}")

# RT statistics by type
print("\n=== RT STATISTICS BY TYPE ===")
if df_rt['computed_rt'].notna().sum() > 0:
    rt_stats = df_rt[df_rt['computed_rt'].notna()].groupby('rt_type')['computed_rt'].agg(['count', 'mean', 'std', 'min', 'max']).round(1)
    print(rt_stats)
else:
    print("No valid RTs computed")

# Signal delay statistics
print("\n=== SIGNAL DELAY STATISTICS ===")
if df_rt['signal_delay'].notna().sum() > 0:
    signal_stats = df_rt[df_rt['signal_delay'].notna()].groupby('type')['signal_delay'].agg(['count', 'mean', 'std', 'min', 'max']).round(1)
    print(signal_stats)
else:
    print("No valid signal delays computed")

print("\n=== DATA VALIDATION ===")
# Check for negative RTs (should be rare - anticipatory responses)
negative_rts = ((df_rt['saccade_start'] - df_rt['go_cue']) < 0).sum()
print(f"Negative RTs (anticipatory): {negative_rts}")

# Check for very fast RTs (< 100ms, potentially artifacts)
very_fast_rts = (df_rt['computed_rt'] < 100).sum()
print(f"Very fast RTs (<100ms): {very_fast_rts}")

# Check for very slow RTs (> 1000ms, potentially missed responses)
very_slow_rts = (df_rt['computed_rt'] > 1000).sum()
print(f"Very slow RTs (>1000ms): {very_slow_rts}")

=== COMPUTING REACTION TIME MEASURES ===

Extracting saccade start times...
Computing reaction times by trial type...

=== RT COMPUTATION SUMMARY ===
RT types computed:
rt_type
GO_RT            58576
Continue_RT      22902
Other            11444
Error_Stop_RT    10171
                  7265
Name: count, dtype: int64

Valid saccade start times: 103,093 / 110,358
Valid computed RTs: 103,093
Valid signal delays: 48,898

=== RT STATISTICS BY TYPE ===
               count   mean    std  min    max
rt_type                                       
Continue_RT    22902  250.3  109.0  1.0  703.0
Error_Stop_RT  10171  178.7   82.4  1.0  898.0
GO_RT          58576  208.3   74.2  1.0  518.0
Other          11444  445.1  207.0  1.0  912.0

=== SIGNAL DELAY STATISTICS ===
      count   mean   std   min    max
type                                 
CONT  25294  137.9  66.6  24.0  252.0
STOP  23604  137.6  66.6  24.0  252.0

=== DATA VALIDATION ===
Negative RTs (anticipatory): 0
Very fast RTs (<100ms): 80

In [22]:
## Cell 14: Stop and Continue Performance by Signal Delay (Figure 1b Replication)
# Replicate Figure 1b: Continue and stop performance
print("=== REPLICATING FIGURE 1B: STOP AND CONTINUE PERFORMANCE ===")

# Filter for STOP and CONT trials with valid signal delays
signal_perf_data = df_rt[df_rt['type'].isin(['STOP', 'CONT']) & df_rt['signal_delay'].notna()].copy()

# Calculate performance by signal delay for each trial type
print(f"Analyzing {len(signal_perf_data):,} trials with valid signal delays")

# For STOP trials: Calculate error rate (percentage of failed stops) by signal delay
stop_performance = signal_perf_data[signal_perf_data['type'] == 'STOP'].groupby('signal_delay').agg({
    'trial_failed': ['count', 'sum', 'mean']
}).round(3)
stop_performance.columns = ['total_trials', 'failed_trials', 'error_rate']
stop_performance['error_percentage'] = stop_performance['error_rate'] * 100
stop_performance = stop_performance.reset_index()

# For CONT trials: Calculate correct rate (percentage of successful continues) by signal delay
cont_performance = signal_perf_data[signal_perf_data['type'] == 'CONT'].groupby('signal_delay').agg({
    'trial_failed': ['count', 'sum', 'mean']
}).round(3)
cont_performance.columns = ['total_trials', 'failed_trials', 'failure_rate']
cont_performance['correct_percentage'] = (1 - cont_performance['failure_rate']) * 100
cont_performance = cont_performance.reset_index()

print("STOP trial error rates by signal delay:")
print(stop_performance[['signal_delay', 'total_trials', 'error_percentage']])

print("\nCONT trial success rates by signal delay:")
print(cont_performance[['signal_delay', 'total_trials', 'correct_percentage']])

# Create the plot replicating Figure 1b
stop_plot = stop_performance.hvplot.line(
    x='signal_delay', y='error_percentage',
    color='red', line_width=3, 
    label=f'Error stop ({monkey.title()})',
    markers=True, marker_size=8
)

cont_plot = cont_performance.hvplot.line(
    x='signal_delay', y='correct_percentage', 
    color='blue', line_width=3,
    label=f'Correct continue ({monkey.title()})',
    markers=True, marker_size=8, line_dash='dashed'
)

# Combine plots
plot_fig1b = (stop_plot * cont_plot).opts(
    title=f'{monkey.title()} - Stop and Continue Performance (Figure 1b)',
    xlabel='Stop/continue signal delay (ms)',
    ylabel='Percentage of saccades',
    width=700, height=400,
    ylim=(0, 100),
    legend_position='top',
    show_grid=True,
    fontsize={'title': 16, 'labels': 14, 'ticks': 12, 'legend': 12},
    # show_legend=True
)

print(f"\n=== RACE MODEL PREDICTIONS ===")
print(f"Error stop rates should INCREASE with longer SSDs (race model prediction)")
print(f"Continue success should be relatively STABLE across CSDs")

# Check race model predictions
ssd_range = stop_performance['signal_delay'].max() - stop_performance['signal_delay'].min()
error_range = stop_performance['error_percentage'].max() - stop_performance['error_percentage'].min()

print(f"\nSSD range: {ssd_range:.0f} ms")
print(f"Error rate change: {error_range:.1f} percentage points")

if error_range > 20:  # Arbitrary threshold for "substantial" increase
    print("✓ Error rates show substantial increase with SSD (consistent with race model)")
else:
    print("? Error rates show modest increase with SSD")

cont_range = cont_performance['correct_percentage'].max() - cont_performance['correct_percentage'].min()
print(f"Continue success variability: {cont_range:.1f} percentage points")

if cont_range < 20:  # Arbitrary threshold for "stable"
    print("✓ Continue success rates relatively stable (consistent with preserved saccade generation)")
else:
    print("? Continue success shows notable variability")

plot_fig1b



=== REPLICATING FIGURE 1B: STOP AND CONTINUE PERFORMANCE ===
Analyzing 48,898 trials with valid signal delays
STOP trial error rates by signal delay:
    signal_delay  total_trials  error_percentage
0           24.0           579               7.8
1           48.0          4618               9.6
2           72.0           316              14.9
3           84.0           962              18.5
4          108.0          4587              26.9
5          120.0           128              34.4
6          132.0           710              31.7
7          144.0           451              38.8
8          168.0          4783              53.6
9          180.0           553              60.9
10         192.0           157              43.9
11         204.0           456              68.2
12         228.0          5162              85.4
13         252.0           142              70.4

CONT trial success rates by signal delay:
    signal_delay  total_trials  correct_percentage
0           24.0     

In [23]:
## Cell 14b: Stop and Continue Performance by Signal Delay (Figure 1b Replication)
# Replicate Figure 1b: Continue and stop performance
print("=== REPLICATING FIGURE 1B: STOP AND CONTINUE PERFORMANCE ===")

# Filter for STOP and CONT trials with valid signal delays
signal_perf_data = df_rt[df_rt['type'].isin(['STOP', 'CONT']) & df_rt['signal_delay'].notna()].copy()

# Calculate performance by signal delay for each trial type
# print(f"Analyzing {len(signal_perf_data):,} trials with valid signal delays")

# For STOP trials: Calculate error rate (percentage of failed stops) by signal delay
stop_performance = signal_perf_data[signal_perf_data['type'] == 'STOP'].groupby('ssd_number').agg({
    'trial_failed': ['count', 'sum', 'mean']
}).round(3)
stop_performance.columns = ['total_trials', 'failed_trials', 'error_rate']
stop_performance['error_percentage'] = stop_performance['error_rate'] * 100
stop_performance = stop_performance.reset_index()

# For CONT trials: Calculate correct rate (percentage of successful continues) by signal delay
cont_performance = signal_perf_data[signal_perf_data['type'] == 'CONT'].groupby('ssd_number').agg({
    'trial_failed': ['count', 'sum', 'mean']
}).round(3)
cont_performance.columns = ['total_trials', 'failed_trials', 'failure_rate']
cont_performance['correct_percentage'] = (1 - cont_performance['failure_rate']) * 100
cont_performance = cont_performance.reset_index()

# print("STOP trial error rates by signal delay:")
# print(stop_performance[['ssd_number', 'total_trials', 'error_percentage']])

# print("\nCONT trial success rates by signal delay:")
# print(cont_performance[['ssd_number', 'total_trials', 'correct_percentage']])

# Create the plot replicating Figure 1b
stop_plot = stop_performance.hvplot.line(
    x='ssd_number', y='error_percentage',
    color='red', line_width=3, 
    label=f'Error stop ({monkey.title()})',
    markers=True, marker_size=8
)

cont_plot = cont_performance.hvplot.line(
    x='ssd_number', y='correct_percentage', 
    color='blue', line_width=3,
    label=f'Correct continue ({monkey.title()})',
    markers=True, marker_size=8, line_dash='dashed'
)

# Combine plots
plot_fig1b = (stop_plot * cont_plot).opts(
    title=f'{monkey.title()} - Stop and Continue Performance (Figure 1b)',
    xlabel='Stop/continue ssd_number',
    ylabel='Percentage of saccades',
    width=700, height=400,
    ylim=(0, 100),
    legend_position='top',
    show_grid=True
)

# print(f"\n=== RACE MODEL PREDICTIONS ===")
# print(f"Error stop rates should INCREASE with longer SSDs (race model prediction)")
# print(f"Continue success should be relatively STABLE across CSDs")

# Check race model predictions
ssd_range = stop_performance['ssd_number'].max() - stop_performance['ssd_number'].min()
error_range = stop_performance['error_percentage'].max() - stop_performance['error_percentage'].min()

print(f"\nSSD range: {ssd_range:.0f} ms")
print(f"Error rate change: {error_range:.1f} percentage points")

# if error_range > 20:  # Arbitrary threshold for "substantial" increase
#     print("✓ Error rates show substantial increase with SSD (consistent with race model)")
# else:
#     print("? Error rates show modest increase with SSD")

# cont_range = cont_performance['correct_percentage'].max() - cont_performance['correct_percentage'].min()
# print(f"Continue success variability: {cont_range:.1f} percentage points")

# if cont_range < 20:  # Arbitrary threshold for "stable"
#     print("✓ Continue success rates relatively stable (consistent with preserved saccade generation)")
# else:
#     print("? Continue success shows notable variability")

plot_fig1b
# stop_performance
for g, gdf in df.groupby('ssd_number'):
    print(f"SSD Number: {g}")
    print(gdf['ssd_len'].value_counts().idxmax())

=== REPLICATING FIGURE 1B: STOP AND CONTINUE PERFORMANCE ===





SSD range: 3 ms
Error rate change: 72.5 percentage points
SSD Number: 1.0
48
SSD Number: 2.0
108
SSD Number: 3.0
168
SSD Number: 4.0
228


In [33]:
## Cell 14b: Stop and Continue Performance by Signal Delay (Figure 1b Replication)
# Replicate Figure 1b: Continue and stop performance
print("=== REPLICATING FIGURE 1B: STOP AND CONTINUE PERFORMANCE ===")

# Filter for STOP and CONT trials with valid signal delays
signal_perf_data = df_rt[df_rt['type'].isin(['STOP', 'CONT']) & df_rt['signal_delay'].notna()].copy()

# Calculate performance by signal delay for each trial type
# print(f"Analyzing {len(signal_perf_data):,} trials with valid signal delays")

# For STOP trials: Calculate error rate (percentage of failed stops) by signal delay
stop_performance = signal_perf_data[signal_perf_data['type'] == 'STOP'].groupby('ssd_number').agg({
    'trial_failed': ['count', 'sum', 'mean']
}).round(3)
stop_performance.columns = ['total_trials', 'failed_trials', 'error_rate']
stop_performance['error_percentage'] = stop_performance['error_rate'] * 100
stop_performance = stop_performance.reset_index()

# For CONT trials: Calculate correct rate (percentage of successful continues) by signal delay
cont_performance = signal_perf_data[signal_perf_data['type'] == 'CONT'].groupby('ssd_number').agg({
    'trial_failed': ['count', 'sum', 'mean']
}).round(3)
cont_performance.columns = ['total_trials', 'failed_trials', 'failure_rate']
cont_performance['correct_percentage'] = (1 - cont_performance['failure_rate']) * 100
cont_performance = cont_performance.reset_index()

ssd_dict = {
    1.0 :  48,
    2.0 :  108,
    3.0 :  168,
    4.0 :  228
}
ssd_len_col = [48, 108, 168, 228]
stop_performance['ssd_len'] = ssd_len_col
cont_performance['ssd_len'] = ssd_len_col
# print("STOP trial error rates by signal delay:")
# print(stop_performance[['ssd_number', 'total_trials', 'error_percentage']])

# print("\nCONT trial success rates by signal delay:")
# print(cont_performance[['ssd_number', 'total_trials', 'correct_percentage']])

# Create the plot replicating Figure 1b
stop_plot = stop_performance.hvplot.line(
    x='ssd_len', y='error_percentage',
    color='red', line_width=3, 
    label=f'Error stop ({monkey.title()})',
    markers=True, marker_size=8
)

cont_plot = cont_performance.hvplot.line(
    x='ssd_len', y='correct_percentage', 
    color='blue', line_width=3,
    label=f'Correct continue ({monkey.title()})',
    markers=True, marker_size=8, line_dash='dashed'
)

# Combine plots
plot_fig1b = (stop_plot * cont_plot).opts(
    title=f'{monkey.title()} - Stop and Continue Performance (Figure 1b)',
    xlabel='Stop/continue ssd_number',
    ylabel='Percentage of saccades',
    width=700, height=400,
    ylim=(0, 100),
    legend_position='top',
    show_grid=True
)

# print(f"\n=== RACE MODEL PREDICTIONS ===")
# print(f"Error stop rates should INCREASE with longer SSDs (race model prediction)")
# print(f"Continue success should be relatively STABLE across CSDs")

# Check race model predictions
ssd_range = stop_performance['ssd_number'].max() - stop_performance['ssd_number'].min()
error_range = stop_performance['error_percentage'].max() - stop_performance['error_percentage'].min()

print(f"\nSSD range: {ssd_range:.0f} ms")
print(f"Error rate change: {error_range:.1f} percentage points")

# if error_range > 20:  # Arbitrary threshold for "substantial" increase
#     print("✓ Error rates show substantial increase with SSD (consistent with race model)")
# else:
#     print("? Error rates show modest increase with SSD")

# cont_range = cont_performance['correct_percentage'].max() - cont_performance['correct_percentage'].min()
# print(f"Continue success variability: {cont_range:.1f} percentage points")

# if cont_range < 20:  # Arbitrary threshold for "stable"
#     print("✓ Continue success rates relatively stable (consistent with preserved saccade generation)")
# else:
#     print("? Continue success shows notable variability")

plot_fig1b
len_col_lst = []

for g, gdf in df.groupby('ssd_number'):
    # print(f"SSD Number: {g}")
    len_col_lst.append(gdf['ssd_len'].value_counts().idxmax())
    print(g, ": ",gdf['ssd_len'].value_counts().idxmax())

len_col_lst.sort()
print(len_col_lst)



=== REPLICATING FIGURE 1B: STOP AND CONTINUE PERFORMANCE ===

SSD range: 3 ms
Error rate change: 23.1 percentage points
1.0 :  48
2.0 :  108
3.0 :  168
4.0 :  228
[np.int64(48), np.int64(108), np.int64(168), np.int64(228)]


In [25]:
rt_scatter_data = df_rt[df_rt['computed_rt'].notna() & df_rt['rt_type'].isin(['GO_RT', 'Continue_RT', 'Error_Stop_RT'])].copy()

scatter_df = rt_scatter_data.groupby(['rt_type', 'trial_session']).agg({
    'computed_rt': ['mean']
})
scatter_df.columns = ['mean_rt']  # Flatten the MultiIndex columns
scatter_df = scatter_df.reset_index()  # Reset index to make 'rt_type' and 'trial_session' columns
# Filter the data for 'GO_RT' and 'Continue_RT'
scatter_data = scatter_df[scatter_df['rt_type'].isin(['GO_RT', 'Continue_RT'])]

# Pivot the data to have 'GO_RT' and 'Continue_RT' as columns
scatter_pivot = scatter_data.pivot(index='trial_session', columns='rt_type', values='mean_rt').reset_index()

# Create the scatter plot
scatter_plot = scatter_pivot.hvplot.scatter(
    x='GO_RT', y='Continue_RT',
    # title='Mean RT: GO_RT vs Continue_RT',
    xlabel='Mean RT (GO_RT)',
    # ylabel='Mean RT (Continue_RT)',
    width=700, height=400,
    alpha=0.7,
    color='purple',
    label=f'Continue RT',
    legend=True
)

# Filter the data for 'GO_RT' and 'Error_Stop_RT'
scatter_data_error = scatter_df[scatter_df['rt_type'].isin(['GO_RT', 'Error_Stop_RT'])]

# Pivot the data to have 'GO_RT' and 'Error_Stop_RT' as columns
scatter_pivot_error = scatter_data_error.pivot(index='trial_session', columns='rt_type', values='mean_rt').reset_index()

# Create the scatter plot for Error_Stop_RT
scatter_plot_error = scatter_pivot_error.hvplot.scatter(
    x='GO_RT', y='Error_Stop_RT',
    title=f'{monkey} Session mean RT',
    # xlabel='Mean RT (GO_RT)',
    ylabel='Mean RT (Error_Stop_RT / Continue_RT)',
    width=700, height=400,
    alpha=0.7,
    color='green',
    legend=True, 
    label=f'Error stop RT'
)
scatter_df
# # Add diagonal line y=x for reference
# min_rt = min(scatter_df['GO_RT'].min(), scatter_df['other_rt'].min())
min_rt = min(scatter_df['mean_rt'])
max_rt = max(scatter_df['mean_rt'])
diagonal_line = hv.Curve([(min_rt, min_rt), (max_rt, max_rt)]).opts(color='black', line_dash='solid')
    

# # Combine both scatter plots
scatter_plot = scatter_plot * scatter_plot_error * diagonal_line
scatter_plot.opts(width=600, height=400, legend_position='top_left', show_legend=True)


In [26]:
## Cell 16: RT Distributions by Signal Delay (Figure 1d)
# Replicate Figure 1d: Continue and error stop RT distributions
print("=== FIGURE 1D: RT DISTRIBUTIONS BY SIGNAL DELAY ===")

# Get RT data with signal delays
rt_dist_data = df_rt[
    (df_rt['computed_rt'].notna()) & 
    (df_rt['signal_delay'].notna()) &
    (df_rt['type'].isin(['STOP', 'CONT']))
].copy()

# Create delay bins for better visualization
# rt_dist_data['delay_bin'] = pd.cut(rt_dist_data['signal_delay'], bins=4, precision=0)
rt_dist_data['delay_bin'] = pd.cut(rt_dist_data['computed_rt'], bins=range(0, int(rt_dist_data['computed_rt'].max()) + 20, 20))
rt_dist_data.drop(
    columns=[
        'hPos', 'vPos', 'hVel', 'vVel', 'speed',
        'blinks', 'neural_data', 'saccades', 'first_relevant_saccade',
        'filename', 'direction', 'go_cue', 'segs_durations', 'segs_times', 'set',
        'trial_number', 'trial_session',
    ], 
    inplace=True
)
print(f"Trials for RT distributions: {len(rt_dist_data):,}")
# print("Signal delay bins:")
# print(rt_dist_data['delay_bin'].value_counts().sort_index())

# Create RT histograms by trial type and delay
if len(rt_dist_data) > 0:
    # Separate STOP and CONT trials
    stop_rt_data = rt_dist_data[rt_dist_data['type'] == 'STOP']
    cont_rt_data = rt_dist_data[rt_dist_data['type'] == 'CONT']
    tot_stop = len(stop_rt_data)
    tot_cont = len(cont_rt_data)
    print(f"  STOP trials: N={tot_stop:,}")
    print(f"  CONT trials: N={tot_cont:,}")
    
    cont_df = cont_rt_data.groupby(['delay_bin', 'ssd_number']).agg({
        'trial_failed': ['sum']
    })
    cont_df.columns = ['failed_trials']
    cont_df.reset_index(inplace=True)
    cont_df['failed_trials'] /= tot_cont
    cont_df['failed_trials'] *= 100  # Convert to percentage
    cont_df['bin'] = cont_df['delay_bin'].apply(lambda x: x.left)
    cont_df['ssd_number'] = cont_df.apply(lambda row: f'CSD{int(row['ssd_number'])}', axis=1)

    stop_df = stop_rt_data.groupby(['delay_bin', 'ssd_number']).agg({
        'trial_failed': ['sum']
    })
    stop_df.columns = ['failed_trials']
    stop_df.reset_index(inplace=True)
    stop_df['failed_trials'] /= tot_stop
    stop_df['failed_trials'] *= 100  # Convert to percentage
    stop_df['bin'] = stop_df['delay_bin'].apply(lambda x: x.left)
    stop_df['ssd_number'] = stop_df.apply(lambda row: f'SSD{int(row['ssd_number'])}', axis=1)

    # Create histogram for continue trials
    if len(cont_df) > 0:
        plot_1d_cont = cont_df.hvplot.line(
            x='bin', y='failed_trials', by='ssd_number',
            title=f'{monkey.title()} - Continue RT Distribution by Signal Delay',
            xlabel='Reaction time (ms)',
            ylabel='Percentage of total trials',
            width=600, height=400,
            line_dash='dashed',
            line_width=2,
        )
    
    # Create histogram for error stop trials
    if len(stop_df) > 0:
        plot_1d_stop = stop_df.hvplot.line(
            x='bin', y='failed_trials', by='ssd_number',
            line_width=2,
        )#.opts(legend_position='top')
    
(plot_1d_cont * plot_1d_stop).opts(
    legend_position='top_right',
    xlim=(0, 500),    
)

=== FIGURE 1D: RT DISTRIBUTIONS BY SIGNAL DELAY ===
Trials for RT distributions: 42,161
  STOP trials: N=19,259
  CONT trials: N=22,902


  cont_df = cont_rt_data.groupby(['delay_bin', 'ssd_number']).agg({
  stop_df = stop_rt_data.groupby(['delay_bin', 'ssd_number']).agg({


In [27]:
failed_stop_trials = df_rt[df_rt['type'].isin(['STOP']) & df_rt['trial_failed'].isin([True])]
successful_cont_trials = df_rt[df_rt['type'].isin(['CONT']) & df_rt['trial_failed'].isin([False])]

tot_cont_trials = len(df_rt[df_rt['type'].isin(['CONT'])])
tot_stop_trials = len(df_rt[df_rt['type'].isin(['STOP'])])
tot_trial = tot_cont_trials + tot_stop_trials

print(f"Total CONT trials: {tot_cont_trials:,}")
print(f" Total STOP trials: {tot_stop_trials:,}")

def frame_it(df, normalizer, ssd_prefix):
    tmp_df = df.groupby(['computed_rt', 'ssd_number']).size()
    tmp_df = tmp_df.reset_index().rename(columns={'computed_rt': 'Reaction Time', 'ssd_number': 'SSD Number', 0: 'Count'})
    tmp_df['percentage'] = (tmp_df['Count'] / normalizer) * 100
    tmp_df['SSD Number'] = tmp_df.apply(lambda row: f'{ssd_prefix}{int(row["SSD Number"])}', axis=1)
    tmp_df = tmp_df.groupby(['Reaction Time', 'SSD Number'])['percentage'].sum().reset_index()
    # Group 'Reaction Time' into 20 millisecond bins and sum percentages within each bin
    tmp_df['Reaction Time Bin'] = (tmp_df['Reaction Time'] // 20) * 20
    tmp_df = tmp_df.groupby(['Reaction Time Bin', 'SSD Number'], as_index=False)['percentage'].sum()
    return tmp_df

cont_df = frame_it(successful_cont_trials, tot_cont_trials, 'CSD')
stop_df = frame_it(failed_stop_trials, tot_stop_trials, 'SSD')

cont_plot = cont_df.hvplot.line(
    x='Reaction Time Bin', y='percentage', by='SSD Number',
    title=f'{monkey.title()} - Continue and error stop RT',
    xlabel='Reaction time (ms)',
    ylabel='Percentage of total trials',
    width=800, height=400,
    line_dash='dashed',
    line_width=3,
)

stop_plot = stop_df.hvplot.line(
    x='Reaction Time Bin', y='percentage', by='SSD Number',
    line_width=3, 
    xlim=(0, 600)
)

(cont_plot * stop_plot).opts(legend_position='top_right')

# stop_df

Total CONT trials: 25,294
 Total STOP trials: 23,604


In [28]:
df_rt[~(df_rt['computed_rt'] == df_rt['reaction_time'])][['reaction_time', 'computed_rt']].isna().all().all()
bin(df_rt[
    ~(df_rt['flags'].apply(
        lambda x: len(bin(x))
    ) == 16)
]['flags'].iloc[0])
# (df_rt['flags'].apply(
#     lambda x: len(bin(x))
# ) == 16).value_counts()
# FLAG_IS_ST_OK
bool(df_rt.iloc[0]['flags'] & (1 << 11))

False

In [29]:
row = df_rt.iloc[160]
bit = 3 if (row['type'] != 'STOP') else 11
is_fail = not bool(row['flags'] & (1 << bit))
is_fail, row['trial_failed']

def check_flag_consistency(row):
    bit = 3 if (row['type'] != 'STOP') else 12
    return not bool(row['flags'] & (1 << bit))

df_rt['trial_failed'] = df_rt.apply(check_flag_consistency, axis=1)