In [None]:
# load librairies
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from os.path import join
import seaborn as sns
import statsmodels.api as sm
import scipy
import json
import statsmodels.formula.api as smf
from collections import defaultdict
from scipy.stats import wilcoxon, ttest_rel

from functions import utils
from functions import plotting
from functions import io

# 1. Load subjects #

In [None]:
# Resolve the OneDrive data root and the folder where behavioural results are written.
onedrive_path = utils._get_onedrive_path()

# Results live under <parent of the current working directory>/results/behav_results.
working_path = os.path.dirname(os.getcwd())
results_path = join(working_path, "results")
behav_results_saving_path = join(results_path, "behav_results")
os.makedirs(behav_results_saving_path, exist_ok=True)

In [None]:
# Read the list of included subjects/sessions from final_included_subjects.json
# in the behavioural results folder.
included_excluded_file = join(behav_results_saving_path, 'final_included_subjects.json')
# JSON is UTF-8 by specification; be explicit so the read does not depend on the OS locale.
with open(included_excluded_file, 'r', encoding='utf-8') as file:
    included_subjects = json.load(file)

In [None]:
# Display the included subject/session labels read from the JSON file.
included_subjects

In [None]:
# Per-subject colour palette (presumably subject label -> colour; confirm in utils),
# passed to the per-subject plotting functions below.
subject_colors = utils.create_color_palette(included_subjects)

In [None]:
# Show the subject palette; the results path is passed too (presumably to save the figure).
plotting.plot_color_palette(subject_colors, behav_results_saving_path)

In [None]:
# Load the behavioural data of the included subjects from the OneDrive data root.
data = io.load_behav_data(included_subjects, onedrive_path)

In [None]:
# Plot colour per condition/group. The 'Session 1' / 'Session 2' entries are
# presumably used by the session-order plots (confirm in plotting).
color_dict = {
    'DBS OFF': '#20a39e', 
    'DBS ON': '#ef5b5b', 
    'control': '#ffba49', 
    'preop': '#8E7DBE',
    'Session 1': "#206ea1", 
    'Session 2': "#5FA363", 
    }

In [None]:
# Testing-session number (1 or 2) of each DBS patient's ON and OFF session.
# Keys are full session labels; they are mapped onto df_merged's 'Subject'
# column in section 7, so they must match those labels exactly.
session_dict = {
    'sub006 DBS ON': 1,
    'sub006 DBS OFF': 2,
    'sub008 DBS ON': 1,
    'sub008 DBS OFF': 2,
    'sub009 DBS ON': 1,
    'sub009 DBS OFF': 2,
    'sub011 DBS ON': 2,
    'sub011 DBS OFF': 1,
    'sub015 DBS ON': 2,
    'sub015 DBS OFF': 1,
    'sub017 DBS ON': 1,
    'sub017 DBS OFF': 2,
    'sub019 DBS ON': 1,
    'sub019 DBS OFF': 2,
    'sub021 DBS ON': 2,
    'sub021 DBS OFF': 1,
    'sub023 DBS ON': 2,
    'sub023 DBS OFF': 1,
    'sub025 DBS ON': 1,
    'sub025 DBS OFF': 2,
    'sub027 DBS ON': 2,
    'sub027 DBS OFF': 1,
    'sub028 DBS OFF': 1,
    'sub028 DBS ON': 2
}

In [None]:
# Questionnaire scores per DBS patient; a subject has the same value for its ON
# and OFF session. BIS / BDI are presumably the Barratt Impulsiveness Scale and
# Beck Depression Inventory - confirm.
# Only a subset of patients is listed: other session labels become NaN when
# these dicts are applied with Series.map in section 7.
bis_score_dict = {
    'sub006 DBS ON': 29,
    'sub006 DBS OFF': 29,
    'sub011 DBS ON': 36,
    'sub011 DBS OFF': 36,
    'sub015 DBS ON': 24,
    'sub015 DBS OFF': 24,
    'sub019 DBS ON': 27,
    'sub019 DBS OFF': 27,
    'sub023 DBS ON': 38,
    'sub023 DBS OFF': 38,
    'sub025 DBS ON': 31,
    'sub025 DBS OFF': 31,
    'sub027 DBS ON': 32,
    'sub027 DBS OFF': 32
}

bdi_score_dict = {
    'sub006 DBS ON': 6,
    'sub006 DBS OFF': 6,
    'sub011 DBS ON': 27,
    'sub011 DBS OFF': 27,
    'sub015 DBS ON': 12,
    'sub015 DBS OFF': 12,
    'sub019 DBS ON': 10,
    'sub019 DBS OFF': 10,
    'sub023 DBS ON': 12,
    'sub023 DBS OFF': 12,
    'sub025 DBS ON': 17,
    'sub025 DBS OFF': 17,
    'sub027 DBS ON': 14,
    'sub027 DBS OFF': 14
}

# 2. Extract the main statistics for each subject and compile them in a dictionary #

In [None]:
# Per-session summary statistics; the next cell splits its keys into groups by substring.
stats = utils.extract_stats(data)

In [None]:
# Route each session's stats into one of four groups by a substring of its key.
# Markers are tested in this order and the first match wins ("OFF" before "ON",
# "C" before "preop"); keys matching none of them are dropped.
stats_OFF, stats_ON, stats_CONTROL, stats_PREOP = {}, {}, {}, {}
_group_routes = (
    ("OFF", stats_OFF),
    ("ON", stats_ON),
    ("C", stats_CONTROL),
    ("preop", stats_PREOP),
)
for session_key, session_stats in stats.items():
    bucket = next((target for marker, target in _group_routes if marker in session_key), None)
    if bucket is not None:
        bucket[session_key] = session_stats

### 2.1. Plot reaction times in the same order as trials ##
(this can help reveal slowing over the course of the task or within a block)

In [None]:
# Colour per trial type; keys are the trial IDs found in stats[...]['trial IDs'].
trial_color_dict = {'go_trial': "#4877D5",
                    'stop_trial': "#d3075c",
                    'go_continue_trial': "#FB9D05",
                    'go_fast_trial': "#28B628"}

In [None]:
# Pooled ECDF inputs for each group, computed in the order OFF, ON, CONTROL, PREOP.
group_values_OFF, group_values_ON, group_values_CONTROL, group_values_PREOP = (
    utils.get_group_values_ecdf(group_stats)
    for group_stats in (stats_OFF, stats_ON, stats_CONTROL, stats_PREOP)
)

In [None]:
# One group-average cumulative RT figure per group.
for group_values, group_label in (
    (group_values_OFF, 'DBS OFF'),
    (group_values_ON, 'DBS ON'),
    (group_values_CONTROL, 'CONTROL'),
    (group_values_PREOP, 'PREOP'),
):
    plotting.plot_cumulative_rt_distributions_group_average(
        group_values, group_label, trial_color_dict, behav_results_saving_path
    )

In [None]:
# For each trial-type code, overlay the four groups' cumulative RT distributions.
group_value_sets = (group_values_OFF, group_values_ON, group_values_CONTROL, group_values_PREOP)
for trial_name in ('gf', 'go', 'gc', 'gs'):
    plotting.plot_cumulative_rt_distributions_across_groups(
        *group_value_sets, trial_name, color_dict, behav_results_saving_path
    )

In [None]:
# Cumulative RT distributions from the full per-session stats, coloured by trial type.
plotting.plot_cumulative_rt_distributions(stats, trial_color_dict, behav_results_saving_path)

In [None]:
# Preparation cost across blocks for every included session (not saved as PDF).
plotting.plot_all_sessions_prep_cost_across_blocks(
        included_subjects, stats, subject_colors, behav_results_saving_path, 
        save_as_pdf=False
)

In [None]:
# Preparation cost per block grouped by DBS condition; the returned dict feeds the
# bar and whisker plots below.
dict_prep_cost_per_cond = plotting.plot_prep_cost_per_block_per_DBS_group(
    included_subjects, stats, subject_colors, behav_results_saving_path, 
    color_dict, with_average_plot = True, save_as_pdf=False
    )

In [None]:
# Same as above but grouped by session order (session_dict) instead of DBS condition.
dict_prep_cost_per_session = plotting.plot_prep_cost_per_block_per_session_order(
    included_subjects, stats, subject_colors, session_dict, behav_results_saving_path, 
    color_dict, with_average_plot = True, save_as_pdf=False    
)

There seems to be a genuine learning effect across sessions. Time elapsed between the two sessions for each subject:
sub006 : 3 days between ON and OFF
sub011 : 4 days between OFF and ON
sub015 : 1 day between OFF and ON
sub019 : 9 days between ON and OFF
sub023 : 7 days between OFF and ON

In [None]:
# Bar plot of preparation cost per block and condition.
plotting.bar_plot_prep_cost_per_block_per_condition(
    dict_prep_cost_per_cond, color_dict, behav_results_saving_path,
    save_as_pdf = False
    )

In [None]:
# Whisker plot of preparation cost per block and condition.
plotting.whisker_plot_prep_cost_per_block_per_condition(
    dict_prep_cost_per_cond, color_dict, behav_results_saving_path,
    save_as_pdf = False
)

In [None]:
# Whisker plot of preparation cost per block, grouped by session order.
plotting.whisker_plot_prep_cost_per_block_per_session(
    dict_prep_cost_per_session, color_dict, behav_results_saving_path,
    save_as_pdf = False
)

## Session-Based Analysis: Whisker Plot with Statistical Comparisons

This plot shows the same whisker plot analysis but differentiates DBS subjects by **session** rather than DBS condition:

- **Control** (yellow): Healthy control subjects
- **Preop** (purple): Preoperative subjects  
- **Session 1** (blue): DBS subjects tested in session 1
- **Session 2** (green): DBS subjects tested in session 2

### Key findings (from the run at the time of writing; re-check against the current figure):
- **Session 1** subjects show lower preparation costs initially, with gradual increases across blocks
- **Session 2** subjects have consistently higher preparation costs across all blocks
- **Control** subjects show the highest preparation costs, especially in later blocks
- Statistical comparisons within each group reveal significant changes across blocks (shown by significance bars)

This analysis helps identify whether session-related factors (e.g., learning effects, fatigue, medication timing) influence preparation costs more than the DBS stimulation itself.

In [None]:
def _plot_subject_rt_timeline(sub, sub_stats, trial_color_dict):
    """Plot one session's RTs in trial order, with per-trial-type RT densities on the right.

    Parameters
    ----------
    sub : str
        Session label, used in the figure title.
    sub_stats : dict
        Stats entry of that session; reads 'trial IDs', 'RTs (ms)' and
        'block number' (one value per trial).
    trial_color_dict : dict
        Trial-type name -> colour; every trial ID must be a key.

    Missing RTs (NaN) are drawn at 0 ms on the timeline and excluded from the
    densities (only RTs > 0 enter the KDE).
    """
    trial_types = ['go_trial', 'stop_trial', 'go_continue_trial', 'go_fast_trial']
    trial_ids = np.asarray(sub_stats['trial IDs'])
    # NaN -> 0 so every trial still gets a point on the timeline.
    rts = np.nan_to_num(np.asarray(sub_stats['RTs (ms)'], dtype=float), nan=0.0)
    blocks = np.asarray(sub_stats['block number'])
    trial_numbers = np.arange(len(rts))

    fig = plt.figure(figsize=(15, 10))
    # 3x3 grid: timeline in the centre cell, RT densities in the cell to its right.
    gs = fig.add_gridspec(3, 3, width_ratios=[1, 4, 1], height_ratios=[1, 4, 1],
                          hspace=0.1, wspace=0.1)
    ax_main = fig.add_subplot(gs[1, 1])
    ax_right = fig.add_subplot(gs[1, 2], sharey=ax_main)

    # One scatter call per trial type rather than one per trial.
    for trial_type in dict.fromkeys(trial_ids.tolist()):
        is_type = trial_ids == trial_type
        ax_main.scatter(trial_numbers[is_type], rts[is_type],
                        color=trial_color_dict[trial_type], s=15)

    ax_main.set_xlabel('Trial number')
    ax_main.set_ylabel('RT (ms)')
    ax_main.set_ylim(-50, 1400)
    ax_main.set_title(f'Reaction times of {sub}')

    # Dashed grey line at the first trial of every new block.
    for index in np.flatnonzero(blocks[1:] != blocks[:-1]) + 1:
        ax_main.axvline(x=index, color='grey', linestyle='--', linewidth=0.5)

    # Horizontal KDE per trial type (density on x, RT on the shared y axis).
    for trial_type in trial_types:
        type_rts = rts[(trial_ids == trial_type) & (rts > 0)]
        if type_rts.size > 0:
            sns.kdeplot(y=type_rts, ax=ax_right,
                        color=trial_color_dict[trial_type],
                        alpha=0.7, linewidth=2,
                        label=trial_type.replace('_', ' ').title())

    ax_right.set_xlabel('Density')
    ax_right.tick_params(labelleft=False)

    legend_handles = [plt.Line2D([0], [0], marker='o', color='w',
                                 label=trial_type.replace('_', ' ').title(),
                                 markerfacecolor=trial_color_dict[trial_type],
                                 markersize=10)
                      for trial_type in trial_types]
    fig.legend(handles=legend_handles, loc='upper right', bbox_to_anchor=(0.95, 0.95))

    plt.show()
    plt.close(fig)  # one figure per session; release it once shown


for sub in included_subjects:
    _plot_subject_rt_timeline(sub, stats[sub], trial_color_dict)

In [None]:
# Per session: mean GC RT minus mean GO RT, and SSRT; then their Pearson correlation.
GC_minus_GO_all = []
SSRT_all = []
for sub in included_subjects:
    GC_minus_GO_all.append(
        stats[sub]['go_continue_trial mean RT (ms)'] - stats[sub]['go_trial mean RT (ms)']
    )
    SSRT_all.append(stats[sub]['SSRT (ms)'])

print(GC_minus_GO_all)
print(SSRT_all)

ssrt_values = np.asarray(SSRT_all, dtype=float)
gc_go_values = np.asarray(GC_minus_GO_all, dtype=float)
# pearsonr/linregress do not skip NaN: keep only sessions where both values exist.
is_valid = ~(np.isnan(ssrt_values) | np.isnan(gc_go_values))
ssrt_values, gc_go_values = ssrt_values[is_valid], gc_go_values[is_valid]

# Pearson correlation
corr, p_value = scipy.stats.pearsonr(ssrt_values, gc_go_values)
print(f'Pearson correlation: {corr}, p-value: {p_value}')

# Least-squares line for display; its result is kept separate so the Pearson
# p-value shown in the title is not overwritten.
fit = scipy.stats.linregress(ssrt_values, gc_go_values)
x_line = np.sort(ssrt_values)

fig, ax = plt.subplots()
ax.scatter(ssrt_values, gc_go_values, color='black')
ax.plot(x_line, fit.slope * x_line + fit.intercept, color='red', label='Fit line')
ax.set_xlabel('SSRT (ms)')
ax.set_ylabel('GC - GO RT (ms)')
ax.set_title(f'Correlation: r={corr:.2f}, p={p_value:.3f}')
ax.legend()
plt.show()

# 3. Plot the inhibition functions #

## 3.1. For each subject and condition separately: a first impression of stopping difficulty and of the DBS effect per subject ##

In [None]:
# Table for the inhibitory functions, then one inhibitory-function plot per subject.
grouped_df = utils.create_grouped_df_for_inhibitory_functions(
    included_subjects,
    stats
)

plotting.plot_inhibitory_function_per_subject(grouped_df, color_dict, behav_results_saving_path)

## 3.2. Using the ZRFT method (Z-score) ##

In [None]:
# Per-subject inhibitory functions on the z-scored (ZRFT) axis.
plotting.plot_inhibitory_function_per_subject_zscored(
    grouped_df,
    stats,
    color_dict,
    behav_results_saving_path
)

## 3.3. All groups plotted using the ZRFT method to compare across groups ##

In [None]:
# Group-level inhibitory functions (ZRFT) for comparison across groups.
plotting.plot_inhibitory_functions_per_groups(
        grouped_df,
        stats,
        color_dict,
        behav_results_saving_path
)


# 4. Looking at reaction times on unsuccessful stop trials depending on SSD #

In [None]:
# Table used to relate unsuccessful-stop RTs to SSD (see the next cell).
rt_inhibition_df = utils.create_inhibition_df(
    included_subjects,
    stats
)

In [None]:
# Reaction time on stop trials as a function of SSD.
plotting.plot_reaction_time_relative_to_SSD(
        rt_inhibition_df,
        color_dict,
        behav_results_saving_path
)


# 5. Proactive inhibition #

## 5.1. Test if proactive inhibition is induced in all participants by comparing the reaction times for GO trials and GF trials ##

### 5.1.a. At the single subject and single session level ###

In [None]:
# GO vs GF reaction times per subject and session, for all four groups.
plotting.plot_go_gf_rt_single_sub(stats_OFF,
        stats_ON,
        stats_CONTROL,
        stats_PREOP,
        color_dict,
        behav_results_saving_path
        )

### 5.1.b. At the group level (just out of interest) ###

In [None]:
# GO vs GF reaction times at the group level.
plotting.plot_go_gf_rt_group(
    stats_OFF,
    stats_ON,
    stats_CONTROL,
    stats_PREOP,
    color_dict,
    behav_results_saving_path,
    show_fig = True
)

## 5.2. Assess proactive inhibition in included subjects ##

### 5.2.1. Assess the effect of STN-DBS on proactive inhibition ###

In [None]:
# ON vs OFF preparation cost for patients with both sessions. The returned frame
# has 'DBS OFF' / 'DBS ON' columns (they are read in 5.2.2).
df_reshaped = plotting.plot_prep_cost_on_vs_off_only_sub_with_2_sessions(
        stats_OFF,
        stats_ON,
        subject_colors,
        behav_results_saving_path,
        show_fig = True
)

In [None]:
# ON vs OFF preparation cost including all DBS patients.
plotting.plot_prep_cost_on_vs_off_all_sub(
        stats_OFF,
        stats_ON,
        subject_colors,
        behav_results_saving_path)

### 5.2.2. Assess if the value in DBS OFF can predict the change with DBS ON ###

In [None]:
# Does the DBS OFF preparation cost predict its change with DBS ON?
# Patients with a missing value in either condition are dropped first.
df_reshaped_cleaned = df_reshaped.dropna()
pre_treatment = df_reshaped_cleaned["DBS OFF"].values
post_treatment = df_reshaped_cleaned["DBS ON"].values

# Change = ON - OFF
change = post_treatment - pre_treatment

df = pd.DataFrame({'Pre_treatment': pre_treatment, 'Change': change})

# OLS: Change ~ const + Pre_treatment
X = sm.add_constant(df['Pre_treatment'])
y = df['Change']
model = sm.OLS(y, X).fit()
print(model.summary())
# NOTE(review): Pre_treatment also appears inside Change (= ON - OFF). If OFF values
# are measured with noise, the slope is biased toward negative values even without
# a true DBS effect (regression to the mean), so interpret it with care.

fig, ax = plt.subplots()
ax.scatter(df['Pre_treatment'], df['Change'], label="Data")

x_vals = np.linspace(df['Pre_treatment'].min(), df['Pre_treatment'].max(), 100)
# Look coefficients up by name: positional Series indexing (params[0]) is
# deprecated since pandas 2.1.
y_vals = model.params['const'] + model.params['Pre_treatment'] * x_vals
ax.plot(x_vals, y_vals, color='red', label="Regression Line")

ax.axhline(0, linestyle='--', color='gray')  # no-change reference
ax.set_xlabel("Pre-treatment Value")
ax.set_ylabel("Change (Post - Pre)")
ax.legend()
plt.show()


### 5.2.3. Compare all conditions ###

In [None]:
# Preparation cost compared across all groups; the returned table is merged in section 7.
df_proactive_all = plotting.plot_prep_cost_all_groups(
    stats_OFF,
    stats_ON,
    stats_CONTROL,
    stats_PREOP,
    color_dict,
    behav_results_saving_path,
    show_fig = True
)

# 6. Reactive inhibition #

In [None]:
# SSRT, DBS ON vs OFF, for all DBS patients.
plotting.plot_SSRT_on_vs_off_all_sub(
        stats_OFF,
        stats_ON,
        subject_colors,
        behav_results_saving_path,
        show_fig= True
)

In [None]:
# SSRT compared across all groups; the returned table is merged in section 7.
df_reactive_all = plotting.plot_SSRT_all_groups(
    stats_OFF,
    stats_ON,
    stats_CONTROL,
    stats_PREOP,
    color_dict,
    behav_results_saving_path,
    show_fig = True        
)

In [None]:
# Inspect the proactive-inhibition table.
df_proactive_all

In [None]:
# Inspect the reactive-inhibition (SSRT) table.
df_reactive_all

# 7. Correlation between proactive and reactive inhibition? #

In [None]:
# Merge the proactive and reactive tables with the per-group stats into one frame.
# The results path is passed too (presumably so the merged table is saved - confirm).
df_merged = utils.prepare_merged_dataframe(
    df_proactive_all,
    df_reactive_all,
    stats_OFF,
    stats_ON,
    stats_CONTROL,
    stats_PREOP,
    behav_results_saving_path
)

In [None]:
# Inspect the merged proactive/reactive table.
df_merged

In [None]:
# Copy df_merged but keep only subjects starting with "sub":
df_merged_subs = df_merged[df_merged['Subject'].str.startswith('sub')].copy()
df_merged_subs

# Add a new column 'Session' using a session_dict  :
df_merged_subs['Session'] = df_merged_subs['Subject'].map(session_dict)

df_merged_subs['BIS'] = df_merged_subs['Subject'].map(bis_score_dict)

df_merged_subs['BDI'] = df_merged_subs['Subject'].map(bdi_score_dict)

# Add a new column 'DBS' to df_merged_subs based on the 'Subject' column:
df_merged_subs['DBS'] = df_merged_subs['Subject'].apply(
    lambda x: 'ON' if 'ON' in x else 'OFF')

# in the Subject column, only keep th efirst part of the string before the space:
df_merged_subs['Subject'] = df_merged_subs['Subject'].apply(lambda x: x.split(' ')[0])

In [None]:
# Inspect the patient-session table with Session, BIS, BDI and DBS columns.
df_merged_subs

In [None]:
# Optional: center the scores
# Mean-centre the questionnaire scores (NaN scores are skipped by .mean()).
for score_col in ('BIS', 'BDI'):
    df_merged_subs[f'{score_col}_c'] = df_merged_subs[score_col] - df_merged_subs[score_col].mean()

In [None]:
# Mixed model: SSRT ~ DBS state x centred BIS, random intercept per subject.
# missing='drop': MixedLM.from_formula defaults to missing='none', and BIS_c is NaN
# for patients without a BIS score, which would otherwise reach the fit.
model = smf.mixedlm(
    formula="Q('SSRT (ms)') ~ DBS * BIS_c",
    data=df_merged_subs,
    groups="Subject",
    missing='drop',
)
result = model.fit()
print(result.summary())

In [None]:
# Preparation cost vs BDI in the DBS OFF condition.
df_off = df_merged_subs[df_merged_subs['DBS'] == 'OFF']

# Centred BDI as the only fixed effect, random intercept per subject.
# missing='drop': MixedLM.from_formula defaults to missing='none', and BDI_c is NaN
# for patients without a BDI score.
# NOTE(review): presumably each subject has a single OFF row here, so the random
# intercept cannot be separated from the residual (expect boundary/singular
# warnings); an OLS fit would be the equivalent, simpler model - confirm.
model_off = smf.mixedlm(
    "Q('preparation cost (ms)') ~ BDI_c",
    data=df_off,
    groups="Subject",
    missing='drop',
)
result_off = model_off.fit()
print(result_off.summary())

# pearsonr does not skip NaN: keep only sessions that have both values.
df_off_scored = df_off.dropna(subset=['BDI', 'preparation cost (ms)'])
x = df_off_scored['BDI']
y = df_off_scored['preparation cost (ms)']
r, p_val = scipy.stats.pearsonr(x, y)

# Scatter with least-squares line (drawn by regplot).
fig, ax = plt.subplots(figsize=(6, 4))
sns.regplot(x=x, y=y, ci=None, scatter_kws={'s': 60, 'color': 'black'},
            line_kws={'color': 'red'}, ax=ax)
ax.set_title(f'Preparation cost vs BDI (DBS OFF)\n'
             f'r = {r:.2f}, p = {p_val:.3f}')
ax.set_xlabel('BDI Score')
ax.set_ylabel('Preparation cost in DBS OFF (ms)')
ax.grid(True)
fig.tight_layout()
plt.show()

In [None]:
# SSRT vs BIS in the DBS ON condition.
df_on = df_merged_subs[df_merged_subs['DBS'] == 'ON']

# Centred BIS as the only fixed effect, random intercept per subject.
# missing='drop': MixedLM.from_formula defaults to missing='none', and BIS_c is NaN
# for patients without a BIS score.
# NOTE(review): presumably each subject has a single ON row here, so the random
# intercept cannot be separated from the residual (expect boundary/singular
# warnings); an OLS fit would be the equivalent, simpler model - confirm.
model_on = smf.mixedlm(
    "Q('SSRT (ms)') ~ BIS_c",
    data=df_on,
    groups="Subject",
    missing='drop',
)
result_on = model_on.fit()
print(result_on.summary())

# pearsonr does not skip NaN: keep only sessions that have both values.
df_on_scored = df_on.dropna(subset=['BIS', 'SSRT (ms)'])
x = df_on_scored['BIS']
y = df_on_scored['SSRT (ms)']
r, p_val = scipy.stats.pearsonr(x, y)

# Scatter with least-squares line (drawn by regplot).
fig, ax = plt.subplots(figsize=(6, 4))
sns.regplot(x=x, y=y, ci=None, scatter_kws={'s': 60, 'color': 'black'},
            line_kws={'color': 'red'}, ax=ax)
ax.set_title(f'SSRT vs BIS (DBS ON)\n'
             f'r = {r:.2f}, p = {p_val:.3f}')
ax.set_xlabel('BIS Score')
ax.set_ylabel('SSRT in DBS ON (ms)')
ax.grid(True)
fig.tight_layout()
plt.show()

In [None]:
# Make sure DBS and Session are treated as categorical variables
# DBS state and session number enter the model as categorical factors.
df_merged_subs = df_merged_subs.astype({'DBS': 'category', 'Session': 'category'})

# Preparation cost with fixed effects DBS + Session and a random intercept per subject.
model = smf.mixedlm(
    "Q('preparation cost (ms)') ~ DBS + Session",
    df_merged_subs,
    groups="Subject",
)
result = model.fit()
print(result.summary())


Fixed-effect estimates from the run at the time of writing (re-check against the summary printed above):

| Term | Coef. | p-value | Interpretation |
|---|---|---|---|
| Intercept | 45.60 | 0.121 | Baseline prep cost (DBS = OFF, Session = 1); not significantly different from 0. |
| DBS[T.ON] | 4.17 | 0.797 | DBS ON increases prep cost by ~4.2 ms vs OFF; not significant. |
| Session[T.2] | 51.74 | 0.001 | Session 2 increases prep cost by ~52 ms vs Session 1; statistically significant. |

✅ Only Session has a significant effect on the preparation cost.

Preparation cost is higher in the second session, regardless of DBS status.

DBS has a negligible and non-significant effect.

In [None]:
# Fit the model
model = smf.mixedlm(
    formula="Q('SSRT (ms)') ~ DBS + Session",  # fixed effects
    data=df_merged_subs,
    groups="Subject"  # random intercept per subject
)
result = model.fit()

# Print the summary
print(result.summary())

Interpretation (drafted with ChatGPT; values from the run at the time of writing): We fitted a linear mixed model with SSRT as the dependent variable, DBS and Session as fixed effects, and Subject as a random effect. Neither DBS (p = 0.66) nor Session (p = 0.69) significantly affected SSRT. However, subject-level variance was high, suggesting large individual differences in SSRT performance.

In [None]:
# Fit the model
model = smf.mixedlm(
    formula="Q('mean SSD (ms)') ~ DBS + Session",  # fixed effects
    data=df_merged_subs,
    groups="Subject"  # random intercept per subject
)
result = model.fit()

# Print the summary
print(result.summary())

In [None]:
# Correlation between preparation cost (proactive) and SSRT (reactive) across sessions.
plotting.plot_corr_prep_cost_SSRT(df_merged, behav_results_saving_path, show_fig=True)

# 8. Correlation SSD and SSRT in reactive inhibition? #

In [None]:
# Correlation between SSD and SSRT across sessions.
plotting.plot_corr_SSD_SSRT(df_merged, behav_results_saving_path, show_fig=True)

# 9. Effect of STN-DBS on Success Rate #

In [None]:
# Per-subject DBS ON vs OFF success rates.
plotting.plot_dbs_effect_success_rate_single_sub(
        stats_OFF,
        stats_ON,
        behav_results_saving_path,
        show_fig=True
)

In [None]:
# Percent success, DBS ON vs OFF: one figure per trial type.
success_trial_labels = ['GO', 'GC', 'GF', 'Go-STOP']
for trial_label in success_trial_labels:
    plotting.plot_percent_success_on_vs_off(
        stats_OFF=stats_OFF,
        stats_ON=stats_ON,
        trial_type=trial_label,
        subject_colors=subject_colors,
        behav_results_saving_path=behav_results_saving_path,
        show_fig=True,
    )

In [None]:
# Pair each DBS patient's OFF and ON session stats under the subject ID.
stats_dbs = {}
subject_ids = [key.split(' ')[0] for key in stats_OFF.keys() if 'OFF' in key]

for sub in subject_ids:
    on_key = f'{sub} DBS ON mSST'
    if on_key not in stats_ON:
        # A patient with only an OFF session cannot enter the paired ON/OFF comparison.
        print(f'Skipping {sub}: no DBS ON session found.')
        continue
    stats_dbs[sub] = {
        'OFF': stats_OFF[f'{sub} DBS OFF mSST'],
        'ON': stats_ON[on_key],
    }

In [None]:
# Success rate per trial type, DBS OFF vs ON: split violins for the group
# distribution, one jittered OFF-ON line per subject, and a paired Wilcoxon test
# per trial type.
from matplotlib.patches import Patch

trial_types = ['go_trial', 'stop_trial', 'go_fast_trial', 'go_continue_trial']
condition_palette = {'OFF': '#20a39e', 'ON': '#ef5b5b'}

# One row per (subject, trial type) with both conditions side by side. The frame is
# built from a list of records because DataFrame.append was removed in pandas 2.0.
success_records = [
    {
        'subject_id': subject_id,
        'trial_type': trial_type,
        'off_percent': sessions['OFF'][f'percent correct {trial_type}'],
        'on_percent': sessions['ON'][f'percent correct {trial_type}'],
    }
    for subject_id, sessions in stats_dbs.items()
    for trial_type in trial_types
]
success_df = pd.DataFrame(
    success_records, columns=['subject_id', 'trial_type', 'off_percent', 'on_percent']
)

# Long format for seaborn: one row per (subject, trial type, condition).
plot_data = success_df.melt(
    id_vars=['subject_id', 'trial_type'],
    value_vars=['off_percent', 'on_percent'],
    var_name='condition',
    value_name='percent_correct',
)
plot_data['condition'] = plot_data['condition'].map({'off_percent': 'OFF', 'on_percent': 'ON'})

fig, ax = plt.subplots(figsize=(12, 10))

# Explicit order/hue_order keep violin positions aligned with the dot positions below.
sns.violinplot(x='trial_type', y='percent_correct', hue='condition', data=plot_data,
               order=trial_types, hue_order=['OFF', 'ON'], split=True,
               palette=condition_palette, alpha=0.2, inner='quart', linewidth=1.25, ax=ax)

# Subject colours are local to this cell, so the notebook-level `subject_colors`
# palette used by later cells is not overwritten.
dot_subjects = list(dict.fromkeys(success_df['subject_id']))
dot_colors = dict(zip(dot_subjects, sns.color_palette("deep", len(dot_subjects))))
# Seeded jitter so the figure is reproducible between runs.
rng = np.random.default_rng(0)
subject_handles = {}

for i, trial_type in enumerate(trial_types):
    trial_rows = success_df[success_df['trial_type'] == trial_type]
    for row in trial_rows.itertuples(index=False):
        jitter = rng.uniform(-0.1, 0.1)
        x_pos = [i - 0.15 + jitter, i + 0.15 + jitter]  # OFF half left, ON half right
        y_pos = [row.off_percent, row.on_percent]
        color = dot_colors[row.subject_id]
        scatter = ax.scatter(x_pos, y_pos, color=color, edgecolors='black', s=100)
        ax.plot(x_pos, y_pos, color=color, alpha=0.7, linestyle='-', linewidth=1)
        subject_handles.setdefault(row.subject_id, scatter)  # one legend entry per subject

# Paired Wilcoxon signed-rank test OFF vs ON for each trial type.
for i, trial_type in enumerate(trial_types):
    trial_rows = success_df[success_df['trial_type'] == trial_type]
    try:
        test_result, p_value = scipy.stats.wilcoxon(trial_rows['off_percent'], trial_rows['on_percent'])
        annotation = f"statistic = {test_result:.3f}\npval = {p_value:.3f}"
    except ValueError:
        # e.g. every subject has identical OFF and ON values
        annotation = "Wilcoxon n/a"
    ax.text(i, 105, annotation,
            horizontalalignment='center', fontsize=12, verticalalignment='bottom')

ax.set_xlabel('Trial Type', fontsize=14)
ax.set_ylabel('Percent Correct', fontsize=14)
ax.set_title('Comparison of Performance Between OFF and ON Conditions', fontsize=16)

# Two legends: conditions (kept via add_artist) and subjects.
condition_legend_handles = [Patch(color=hex_color, label=condition)
                            for condition, hex_color in condition_palette.items()]
legend1 = ax.legend(handles=condition_legend_handles, title="Condition", loc='upper right', fontsize=12)
ax.add_artist(legend1)
ax.legend(handles=list(subject_handles.values()), labels=list(subject_handles.keys()),
          title="Subject ID", bbox_to_anchor=(1.05, 1), loc='upper left')

fig.tight_layout()
plt.show()


# 10. Effect of DBS on Reaction Time #

In [None]:
# Reaction time, DBS ON vs OFF: one figure per trial type.
rt_trial_labels = ['GO', 'GC', 'GF', 'Go-STOP']
for trial_label in rt_trial_labels:
    plotting.plot_reaction_time_on_vs_off(
        stats_OFF=stats_OFF,
        stats_ON=stats_ON,
        trial_type=trial_label,
        subject_colors=subject_colors,
        behav_results_saving_path=behav_results_saving_path,
        show_fig=True,
    )

In [None]:
# Per-subject DBS ON vs OFF reaction times.
plotting.plot_dbs_effect_reaction_time_single_sub(
        stats_OFF,
        stats_ON,
        behav_results_saving_path,
        show_fig=True
)

In [None]:
# DBS effect on RT for patients with both sessions, across all trial types.
plotting.plot_dbs_effect_on_rt_all_sub_with_2_sessions_all_trial_types(
        stats_OFF,
        stats_ON,
        subject_colors,
        behav_results_saving_path,
        show_fig=True
)

# 11. Test whether RT during GO trials correlates with SSRT (it should not!) #

In [None]:
# Correlation between GO-trial RT and SSRT across all sessions in `stats`.
plotting.plot_corr_gort_ssrt(
    stats,    
    behav_results_saving_path,
    show_fig=True
)

In [None]:
# Early presses, DBS ON vs OFF.
plotting.plot_early_press_on_vs_off(
        stats_OFF,
        stats_ON,
        subject_colors,
        behav_results_saving_path,
        show_fig=True
)