In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm

import scipy
from scipy import stats
from scipy import optimize
from scipy import integrate
from scipy.special import erf
import scikits.bootstrap as bootstrap
import scipy.optimize

from matplotlib import pyplot as plt
from matplotlib.patches import Circle
import seaborn as sns

import psychofit as psy
import functools
from psychofit import mle_fit_psycho, erf_psycho_2gammas, neg_likelihood


  __import__("pkg_resources").declare_namespace(__name__)


In [2]:
# select only easy trials: 
def easy_trials(trials_1): 
    """
    Select only the easy trials, the trials with high contrast >= 0.5. 
     ----------
    trials_1 : pd.DataFrame
        DataFrame containing 'contrastLeft' and 'contrastRight' columns, where contrastLeft is negative. 

    Returns
    -------
    pd.DataFrame
        Subset of trials_1 with only easy trials.
    """
    trials_df = trials_1[(trials_1['contrastLeft'].isin([-1, -0.5])) | (trials_1['contrastRight'].isin([1, 0.5]))]
    return trials_df


In [3]:
# calculate the proprtion correct for each session across all mice
def proportion_correct(trials_df): 
    """
    calculate the proprtion correct for each session across all mice
    ----------
    Parameter: 
    trials_df: panda_df with the columns ['subject'], ['session_num'], [feedbackType'], ['stage'], ['NM']

    Output: 
    result: panda_df with the columns: ['NM'], ['subject'], ['session_num'], ['stage'], ['n_trials'], 
                                       ['n_correct'], ['proportion_correct']
    """
        
    df_subset = trials_df[['subject', 'session_num', 'feedbackType', 'stage','NM']]
    
    grouped = df_subset.groupby(['NM','subject', 'session_num','stage',])

    result = grouped.agg(
        n_trials=('feedbackType', 'count'),
        n_correct=('feedbackType', lambda x: (x == 1).sum())
    )
    
    result['proportion_correct'] = result['n_correct'] / result['n_trials']
    result = result.reset_index()
    
    return result

In [5]:
def plot_training_by_genotype(result):
    """
    fig 2 d "Standardized and reproducible measurement of decision-making in mice IBL 2021"
    Displays a matplotlib plot of mean performance across training for each genotype.

    Parameters
    ----------
    result : pd.DataFrame
        Must contain at least the columns ['stage', 'subject', 'genotype',
        'session_num', 'proportion_correct'] and include a 'training' stage.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    # Filter to training stage
    df_train = result[result['stage'] == 'training'].copy()

    # Compute training day number per subject
    df_train['session_final'] = (
        df_train.groupby('subject')['session_num']
                .transform(lambda x: pd.factorize(x)[0] + 1)
    )

    # Plot
    fig, ax = plt.subplots(figsize=(8, 4))
    sns.set_palette(sns.color_palette("hls", 8))

    sns.lineplot(x='session_final', y='proportion_correct', data=df_train[df_train['NM'] == '5HT'], 
                 color='#8866BC', label='Serotonin', ci= None)
    sns.lineplot(x='session_final', y='proportion_correct', data=df_train[df_train['NM'] == 'NE'], 
                 color='#35A2B6', label='Norepinephrine', ci= None)
    sns.lineplot(x='session_final', y='proportion_correct', data=df_train[df_train['NM'] == 'DA'], 
                 color='#DF4661', label='Dopamine', ci= None)
    sns.lineplot(x='session_final', y='proportion_correct', data=df_train[df_train['NM'] == 'ACh'], 
                 color='#009F61', label='Acetylcholine', ci= None)
    sns.lineplot(x='session_final', y='proportion_correct', data=df_train[df_train['NM'] == 'WT'], 
                 color='#999999', label='Wild Type', ci= None)
    
     # Global average line across all subjects/genotypes
    overall_mean = df_train.groupby('session_final')['proportion_correct'].mean().reset_index()
    plt.plot(overall_mean['session_final'], overall_mean['proportion_correct'],
             color='black', linewidth=1.5, linestyle='-', label='Average across genotypes')

    plt.legend(
    title='Genotypes',
    loc='upper left',
    bbox_to_anchor=(1.05, 1),
    borderaxespad=0,
    frameon=False
    )
    # plt.tight_layout(rect=[0, 0, 0.75, 1]) 
    # plt.legend(title='Genotypes')
    plt.ylim(0, 1)
    plt.axhline(0.5, linestyle='--', c='0.5')
    plt.axhline(0.7, linestyle='--', c='0.5')
    plt.ylabel('Proportion Correct', fontsize=15)
    plt.xlabel('Days', fontsize=15)
    plt.tight_layout()


In [7]:
def plot_bias_performance(result):
    """
    Plot the performance across bias sessions for each subject.

    This function filters the input DataFrame `result` to the 'bias' stage,
    computes a consecutive bias day number per subject, and then produces
    a line plot of proportion correct over bias days for each subject and overall trend.

    Parameters
    ----------
    result : pd.DataFrame
        DataFrame containing at least the columns ['stage', 'subject', 'session_num', 'proportion_correct'].
        'stage' should include a 'bias' category.

    Returns
    -------
    Displays a matplotlib plot of performance across bias sessions.
    """
    sert_b = result[result['stage'] == 'bias']
    sert_b['session_num_bias'] = (
        sert_b.groupby('subject')['session_num']
             .transform(lambda x: pd.factorize(x)[0] + 1)
    )
    plt.figure()
    sns.lineplot(x='session_num_bias',y='proportion_correct',hue='subject', data=sert_b, palette = ['0.6'], legend = False)
    sns.lineplot(x='session_num_bias',y='proportion_correct', data=sert_b, color = '0.2', legend = False)

    plt.title('Performance of mice across Bias world')
    plt.ylim(0,1)
    plt.axhline(0.5, linestyle = '--', c = '0.5')
    plt.axhline(0.7, linestyle = '--', c = '0.5')
    plt.ylabel('proportion correct', fontsize='15')
    plt.xlabel('day', fontsize='15')
    plt.tight_layout()
    sns.despine()

In [8]:
def get_last_3_training(trials_df):
    """
    Extract the last three training sessions for each subject prior to their first bias session.

    Parameters
    ----------
    trials_df : pd.DataFrame
        DataFrame containing at least ['subject', 'stage', 'session_num'] columns.

    Returns
    -------
    pd.DataFrame
        Concatenated DataFrame of trials from the last three training sessions before bias for each subject.
    """
    # Ensure proper sorting
    trials_sorted = trials_df.sort_values(by=["subject", "session_num"]).reset_index(drop=True)
    result_list = []
    # Iterate through subjects
    for subject, sub_df in trials_sorted.groupby('subject'):
        # Find first bias session index
        bias_indices = sub_df.index[sub_df['stage'] == 'bias']
        if bias_indices.empty:
            continue
        first_bias_loc = bias_indices[0]
        # Data before bias
        before_bias = sub_df.iloc[:first_bias_loc]
        # Training sessions prior to bias
        training_sessions = before_bias[before_bias['stage'] == 'training']['session_num'].unique()
        last_3 = training_sessions[-3:]
        # Append trials from last three training sessions
        result_list.append(before_bias[before_bias['session_num'].isin(last_3)])
    # Combine and reset index
    if result_list:
        result_df = pd.concat(result_list, ignore_index=True)
    else:
        result_df = pd.DataFrame(columns=trials_df.columns)
    return result_df


In [9]:
def plot_last_3_training(result):
    """
    Plot the mean of the last 3 training days before changing to bias, across Genotype.
   
    Parameters
    ----------
    result : pd.DataFrame
        DataFrame of only the last 3 days of traingng befrore switching to the bias world. 
        Df must contain at least ['subject', 'genotype', 'session_num', 'proportion_correct'].

    Returns
    -------
    Displays a matplotlib figure showing performance per genotype over sessions.
    """
    # Compute session_final per subject
    result['session_final'] = result.groupby('subject').cumcount() + 1

    # Plot the proportion_correct for the last 3 days of training across genotypes
    plt.figure()
    sns.set_palette(sns.color_palette("hls", 8))
    sns.lineplot(x='session_final', y='proportion_correct', data=result[result['NM'] == '5HT'], 
                 color='#8866BC', label='SERT')
    # sns.lineplot(x='session_final', y='proportion_correct', data=result[result['NM'] == 'NE'], 
    #              color='grey', label='TH')
    sns.lineplot(x='session_final', y='proportion_correct', data=result[result['NM'] == 'NE'], 
                 color='#35A2B6', label='Dbh')
    sns.lineplot(x='session_final', y='proportion_correct', data=result[result['NM'] == 'DA'], 
                 color='#DF4661', label='DAT')
    sns.lineplot(x='session_final', y='proportion_correct', data=result[result['NM'] == 'ACh'], 
                 color='#009F61', label='ChAT')
    sns.lineplot(x='session_final', y='proportion_correct', data=result[result['NM'] == 'WT'], 
                 color='black', label='ChAT')

    plt.title('Last 3 days of training before Bias across all genotypes')
    plt.legend(title='Genotypes')
    plt.ylim(0, 1)
    plt.axhline(0.5, linestyle='--', c='0.5')
    plt.axhline(0.7, linestyle='--', c='0.5')
    plt.ylabel('Proportion Correct', fontsize='15')
    plt.xlabel('Days', fontsize='15')
    plt.tight_layout()
    plt.show()


In [10]:
def plot_transition_to_bias(result):
    """
    Plot the number of sessions it takes for mice to transition from training to bias by genotype.

    Parameters
    ----------
    result : pd.DataFrame
    DataFrame of only the last 3 days of traingng befrore switching to the bias world.
    Df must contain at least ['subject', 'genotype', 'session_num'].

    Returns
    -------
    Displays a matplotlib figure with boxplots and stripplots for session transition.
    """
    # Compute session_final per subject
    result['session_final'] = result.groupby('subject').cumcount() + 1
    # plotting the time the mice switch from training to bias across genotypes
    plt.figure()
    sns.set_palette(['#009F61', '#8866BC', '#DF4661', '#35A2B6', 'grey'])

    sns.boxplot(
    data=result[result['session_final'] == 3],
    x='NM',
    y='session_num',
    width=0.2,
    showcaps=True,
    boxprops=dict(facecolor='none', edgecolor='#008CC1', linewidth=1.5, zorder=1),
    whiskerprops=dict(color='#008CC1', linewidth=1, alpha=0.5, zorder=1),
    capprops=dict(color='#008CC1', alpha=0.5, zorder=1),
    medianprops=dict(color='#008CC1', linewidth=2, zorder=2),
    flierprops=dict(marker='o', markersize=4, linestyle='none', color='#008CC1', alpha=0.3)
    )
    
    plt.xlabel('NM', fontsize='12')
    plt.ylabel('Session number', fontsize='12')
    plt.title('Number of sessions for a mice to transition from training to bias task')

    # Scatterplot (overlay)
    sns.stripplot(
        data=result[result['session_final'] == 3],
        x='NM',
        y='session_num',
        hue='NM',
        dodge=True,
        jitter=True,
        alpha=0.8,
        zorder=3
    )
    plt.tight_layout()
    plt.show()

In [11]:
def plot_training_performance(result, color='0.2'):
    """
    Plot the performance across training sessions for each subject.

    Parameters
    ----------
    result : pd.DataFrame
        DataFrame containing at least the columns ['stage', 'subject', 'session_num', 'proportion_correct'].
    mean_color : str
        Color for the mean performance line (default: '0.2' = dark grey)

    Returns
    -------
    Displays a matplotlib plot of performance across training.
    """

    sert_t = result[result['stage'] == 'training'].copy()


    sert_t['session_num_training'] = (
        sert_t.groupby('subject')['session_num']
              .transform(lambda x: pd.factorize(x)[0] + 1)
    )

    # Plotting
    plt.figure(figsize=(4, 5))
    sns.lineplot(
        x='session_num_training', y='proportion_correct', hue='subject',
        data=sert_t, palette=['0.6'], legend=False, alpha= 0.7
    )
    sns.lineplot(
        x='session_num_training', y='proportion_correct',
        data=sert_t, color=color, ci=None, label='WT', alpha=1
    )

    plt.title('Performance of mice across training')
    plt.ylim(0, 1)
    plt.axhline(0.5, linestyle='--', c='0.5')
    plt.axhline(0.7, linestyle='--', c='0.5')
    plt.ylabel('Performance on easy trials', fontsize=15)
    plt.xlabel('Training day', fontsize=15)
    plt.tight_layout()
    sns.despine()

All the functions to peform Psychometric analysis on the IBL photometry data, using the same fit as in the "Standardised and reproducible measurement of decision-making in mice" paper. 

In [29]:
def prob_choose_right(df):
    """
    Summarize right‐choice probabilities, treating 'training' and 'bias' stages differently.

    - In 'training': group by NM (neuromodulator), subject, stage, session_num, contrast (ignore probabilityLeft)
      and set probabilityLeft=NA in the output.
    - In 'bias': group by NM, subject, stage, session_num, contrast, probabilityLeft.

    Returns a DataFrame with columns
    ['genotype','subject','session_num','contrast','probabilityLeft',
     'stage','right_choice_count','total_trials','probability_right'].
     
    """
    # ---- TRAINING ----
    df_train = df[df['stage'] == 'training']

    grp_t = (
        df_train
        .groupby(['NM','subject','stage','session_num','contrast','choice'])
        .size()
        .reset_index(name='trials')
    )

    out_t = (
        grp_t
        .groupby(['NM','subject','stage','session_num','contrast'])
        .apply(lambda g: pd.Series({
            'right_choice_count': g.loc[g['choice']== -1, 'trials'].sum(),
            'total_trials':         g['trials'].sum(),
            'prob_choose_right':    g.loc[g['choice']== - 1, 'trials'].sum() / g['trials'].sum()
        }))
        .reset_index()
    )
    # force probabilityLeft to NA
    out_t['probabilityLeft'] = np.nan

    # ---- BIAS ----
    df_bias = df[df['stage'] == 'bias']
    grp_b = (
        df_bias
        .groupby([
            'NM','subject','stage','session_num',
            'contrast','probabilityLeft','choice'
        ])
        .size()
        .reset_index(name='trials')
    )
    out_b = (
        grp_b
        .groupby([
            'NM','subject','stage','session_num',
            'contrast','probabilityLeft'
        ])
        .apply(lambda g: pd.Series({
            'right_choice_count': g.loc[g['choice']== - 1, 'trials'].sum(),
            'total_trials':         g['trials'].sum(),
            'prob_choose_right':    g.loc[g['choice']== - 1, 'trials'].sum() / g['trials'].sum()
        }))
        .reset_index()
    )

    # Combine
    out = pd.concat([out_t, out_b], ignore_index=True, sort=False)
    out = out[[
        'NM','subject','session_num','contrast','probabilityLeft',
        'stage','right_choice_count','total_trials','prob_choose_right'
    ]]
    return out

In [16]:
def sliding_windows(df, window_size=60, step_size=10):
    """
    Create overlapping windows of size `window_size` stepping by `step_size`.
    Applies sliding windows for one mice acrross the whole df. 

    Parameters
    ----------
    df : pd.DataFrame
        Input DataFrame to window (e.g., time series trials).
    window_size : int
        Number of rows in each window.
    step_size : int
        Step increment between window start indices.

    Returns
    -------
    pd.DataFrame
        Concatenated DataFrame with an added 'group' column indicating window number.
    """
    windows = []
    group_number = 1
    for start in range(0, len(df) - window_size + 1, step_size):
        end = start + window_size
        window = df.iloc[start:end].copy()
        window['group'] = group_number
        windows.append(window)
        group_number += 1
    return pd.concat(windows, ignore_index=True)

In [19]:
def fit_groups(psycho_df: pd.DataFrame) -> pd.DataFrame:
    """
    Fit psychometric parameters for each sliding-window group using the erf_psycho_2gammas model
    and the same parameter settings as compute_psych_pars.

    Parameters
    ----------
    psycho_df : pd.DataFrame
        DataFrame containing columns ['group', 'contrast', 'n_trials', 'prob_choose_right'].

    Returns
    -------
    pd.DataFrame
        DataFrame with columns ['group', 'bias', 'threshold', 'lapse_low', 'lapse_high', 'log_likelihood'].
    """
    results = []
    P_model = 'erf_psycho_2gammas'
    nfits = 5

    for group in psycho_df['group'].unique():
        group_df = psycho_df[psycho_df['group'] == group]

        # Check if group has enough rows
        if group_df.empty:
            print(f"Skipping group {group}: no data")
            continue

        # Check if the y-data has non-NaN values
        if group_df['prob_choose_right'].isnull().all():
            print(f"Skipping group {group}: all prob_choose_right are NaN")
            continue

        # Check if x and y lengths match
        if len(group_df['contrast']) != len(group_df['prob_choose_right']):
            print(f"Skipping group {group}: mismatch between contrast and prob_choose_right lengths")
            continue

        # Build data matrix
        contrasts_pct = group_df['contrast'].values.astype(float) * 100
        n_trials = group_df['n_trials'].values.astype(int)
        p_right = group_df['prob_choose_right'].values.astype(float)
        data_matrix = np.vstack([contrasts_pct, n_trials, p_right])

        # Initialize parameters
        parstart = np.array([np.mean(contrasts_pct), 20.0, 0.05, 0.05])
        parmin   = np.array([np.min(contrasts_pct),     0.0,  0.0,  0.0])
        parmax   = np.array([np.max(contrasts_pct),   100.0,  1.0,  1.0])

        # Fit with mle_fit_psycho
        pars, L = mle_fit_psycho(
            data_matrix,
            P_model=P_model,
            parstart=parstart,
            parmin=parmin,
            parmax=parmax,
            nfits=nfits
        )

        bias, threshold, lapse_low, lapse_high = pars

        results.append({
            'group': group,
            'bias': bias,
            'threshold': threshold,
            'lapse_low': lapse_low,
            'lapse_high': lapse_high,
            'log_likelihood': L
        })

    return pd.DataFrame(results)

In [20]:
def strip_plot_fit_parameters(fit_results_df):
    """
    Plot psychometric fit parameters across groups using strip plots.

    This function takes a DataFrame of fitted psychometric parameters (columns 'group', 'bias', 'threshold', 'lapse_low', 'lapse_high')
    and visualizes each parameter ['bias', 'threshold', 'lapse_low', 'lapse_high'] in its own subplot. Strip plots display individual
    group-level parameter estimates side by side for comparison across groups.

    Parameters
    ----------
    fit_results_df : pd.DataFrame
        DataFrame containing the columns:
          - 'group': identifier for each sliding window or mouse group.
          - 'bias'
          - 'threshold'
          - 'lapse_low'
          - 'lapse_high'

    Returns
    -------
    A matplotlib figure with four strip plots for parameters ['bias', 'threshold', 'lapse_low', 'lapse_high'].
    """
    df_melted = pd.melt(fit_results_df,
                        id_vars=('group'),
                        value_vars=['bias', 'threshold', 'lapse_low', 'lapse_high'],
                        var_name='parameter_type',
                        value_name='parameter_value')

    params = ['bias', 'threshold', 'lapse_low', 'lapse_high']
    fig, axes = plt.subplots(1, 4, figsize=(10, 4), sharey=False)

    for i, param in enumerate(params):
        subset = df_melted[df_melted['parameter_type'] == param]
        
        sns.stripplot(
            data=subset,
            x='parameter_type',
            y='parameter_value',
            hue='group',
            ax=axes[i],
            dodge=True,
            jitter=True,
            alpha=0.8,
            zorder=3  # Ensures it's on top of the boxplot
        )

        axes[i].set_title(f'Parameter: {param}')
        axes[i].set_xlabel("")
        axes[i].set_xticks([])  # Optional, to clean the x-axis
        axes[i].legend().remove()

    # Global legend
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.tight_layout()
    plt.suptitle('Example Mice Psychometric Fit Parameters', y=1.08)
    plt.show()

In [17]:
def get_first_5_training(trials_df):
    """
    Extract the first five training sessions for each subject, regardless of bias.

    Parameters
    ----------
    trials_df : pd.DataFrame
        Must contain at least ['subject', 'stage', 'session_num'] columns.

    Returns
    -------
    pd.DataFrame
        All rows from the first five 'training' sessions of each subject.
    """
    # Sort once so unique() later preserves ascending session_num
    trials_sorted = trials_df.sort_values(['subject', 'session_num']).reset_index(drop=True)
    result_list = []

    for subject, sub_df in trials_sorted.groupby('subject'):
        # Find each subject's unique training session numbers (in order)
        training_sessions = sub_df.loc[
            sub_df['stage'] == 'training', 'session_num'
        ].unique()
        
        # Take the first five of those
        first5 = training_sessions[:5]

        # Keep only rows in those first five training days
        sel = sub_df[
            (sub_df['stage'] == 'training') &
            (sub_df['session_num'].isin(first5))
        ]
        result_list.append(sel)

    # Concatenate (or return empty with same columns)
    if result_list:
        return pd.concat(result_list, ignore_index=True)
    else:
        return pd.DataFrame(columns=trials_df.columns)

In [3]:
def get_last_5_training(trials_df):
    """
    Extract the last three training sessions for each subject prior to their first bias session.

    Parameters
    ----------
    trials_df : pd.DataFrame
        DataFrame containing at least ['subject', 'stage', 'session_num'] columns.

    Returns
    -------
    pd.DataFrame
        Concatenated DataFrame of trials from the last three training sessions before bias for each subject.
    """
    # Ensure proper sorting
    trials_sorted = trials_df.sort_values(by=["subject", "session_num"]).reset_index(drop=True)
    result_list = []
    # Iterate through subjects
    for subject, sub_df in trials_sorted.groupby('subject'):
        # Find first bias session index
        bias_indices = sub_df.index[sub_df['stage'] == 'bias']
        if bias_indices.empty:
            continue
        first_bias_loc = bias_indices[0]
        # Data before bias
        before_bias = sub_df.iloc[:first_bias_loc]
        # Training sessions prior to bias
        training_sessions = before_bias[before_bias['stage'] == 'training']['session_num'].unique()
        last_5 = training_sessions[-5:]
        # Append trials from last three training sessions
        result_list.append(before_bias[before_bias['session_num'].isin(last_5)])
    # Combine and reset index
    if result_list:
        result_df = pd.concat(result_list, ignore_index=True)
    else:
        result_df = pd.DataFrame(columns=trials_df.columns)
    return result_df

In [21]:
def fit_by_subject_session(psycho_df_f: pd.DataFrame) -> pd.DataFrame:
    """
    Fit psychometric parameters for each subject and session using the erf_psycho_2gammas model
    and the same parameter settings as compute_psych_pars.

    Parameters
    ----------
    psycho_df_f : pd.DataFrame
        DataFrame containing columns ['subject', 'session_num', 'contrast', 'n_trials', 'prob_choose_right'].

    Returns
    -------
    pd.DataFrame
        DataFrame with columns ['subject', 'session_num', 'bias', 'threshold', 'lapse_low', 'lapse_high', 'log_likelihood'].
    """
    results = []
    P_model = 'erf_psycho_2gammas'
    nfits = 5
  
    # Group by subject and session
    for (subject, session), group_df in psycho_df_f.groupby(['subject', 'session_num']):
        # Validate data
        if group_df.empty:
          print(f"Skipping {subject}, session {session}: no data")
            continue
        if group_df['prob_choose_right'].isnull().all():
            print(f"Skipping {subject}, session {session}: all prob_choose_right NaN")
            continue
        if len(group_df['contrast']) != len(group_df['prob_choose_right']):
            print(f"Skipping {subject}, session {session}: contrast/prob mismatch")
            continue

        # Prepare data matrix
        contrasts_pct = group_df['contrast'].values.astype(float) * 100
        n_trials = group_df['n_trials'].values.astype(int)
        p_right = group_df['prob_choose_right'].values.astype(float)
        data_matrix = np.vstack([contrasts_pct, n_trials, p_right])

        # Parameter initialization
        parstart = np.array([np.mean(contrasts_pct), 20.0, 0.05, 0.05])
        parmin   = np.array([np.min(contrasts_pct),     0.0,  0.0,  0.0])
        parmax   = np.array([np.max(contrasts_pct),   100.0,  1.0,  1.0])

        # Fit with mle_fit_psycho
        pars, L = mle_fit_psycho(
            data_matrix,
            P_model=P_model,
            parstart=parstart,
            parmin=parmin,
            parmax=parmax,
            nfits=nfits
        )

        bias, threshold, lapse_low, lapse_high = pars

        results.append({
            'subject': subject,
            'session_num': session,
            'bias': bias,
            'threshold': threshold,
            'lapse_low': lapse_low,
            'lapse_high': lapse_high,
            'log_likelihood': L
        })

    return pd.DataFrame(results)

In [37]:
def fit_by_subject(psycho_df_f: pd.DataFrame) -> pd.DataFrame:
    """
    Fit psychometric parameters for each subject using the erf_psycho_2gammas model
    and the same parameter settings as compute_psych_pars. 

    Parameters
    ----------
    psycho_df_f : pd.DataFrame
        DataFrame containing columns ['subject', 'contrast', 'total_trials', 'prob_choose_right'].

    Returns
    -------
    pd.DataFrame
        DataFrame with columns ['subject', 'bias', 'threshold', 'lapse_low', 'lapse_high', 'log_likelihood'].
    """
    results = []

    # Define fixed settings matching compute_psych_pars
    P_model = 'erf_psycho_2gammas'
    nfits = 5

    for subject in psycho_df_f['subject'].unique():
        # Filter the DataFrame for the current subject
        df_sub = psycho_df_f[psycho_df_f['subject'] == subject]
        nm = psycho_df_f[psycho_df_f['NM'].iloc[0]]

        # Build the data matrix: [contrast_pct; n_trials; p_right]
        contrasts_pct = df_sub['contrast'].values.astype(float) * 100
        n_trials = df_sub['total_trials'].values.astype(int)
        p_right = df_sub['prob_choose_right'].values.astype(float)
        data_matrix = np.vstack([contrasts_pct, n_trials, p_right])

        # Set parameter start, min, and max exactly as in compute_psych_pars
        parstart = np.array([np.mean(contrasts_pct), 20.0, 0.05, 0.05])
        parmin   = np.array([np.min(contrasts_pct),     0.0,  0.0,  0.0])
        parmax   = np.array([np.max(contrasts_pct),   100.0,  1.0,  1.0])
        
        # Fit using male_fit_psycho
        pars, L = mle_fit_psycho(
            data_matrix,
            P_model=P_model,
            parstart=parstart,
            parmin=parmin,
            parmax=parmax,
            nfits=nfits
        )

        bias, threshold, lapse_low, lapse_high = pars

        # Collect results
        results.append({
            'subject': subject,
            'NM': nm,
            'bias': bias,
            'threshold': threshold,
            'lapse_low': lapse_low,
            'lapse_high': lapse_high,
            'log_likelihood': L
        })

    # Convert the results list into a DataFrame
    fit_results_df = pd.DataFrame(results)
    return fit_results_df


SyntaxError: invalid syntax (63268868.py, line 54)

In [23]:
def plot_subject_fit_parameters(fit_results_df):
    
    """
    Visualize psychometric fit parameters for each subject using boxplots and strip plots.

    This function melts the input DataFrame of fit parameters, then creates a 1x4 grid of subplots,
    each showing both a boxplot and overlaid stripplot for one parameter ('a','b','g','l').

    Parameters
    ----------
    fit_results_df : pd.DataFrame
        DataFrame containing columns ['subject', 'bias', 'threshold', 'lapse_low', 'lapse_high'].

    Returns
    -------
    Displays a matplotlib figure with boxplots and stripplots for each parameter.
    """
    
    df_melted = pd.melt(fit_results_df, 
                      id_vars=('subject'),
                      value_vars=['bias', 'threshold', 'lapse_low', 'lapse_high'],
                      var_name='parameter_type',
                      value_name='parameter_value')

    df_melted   
    params = ['bias', 'threshold', 'lapse_low', 'lapse_high']
    fig, axes = plt.subplots(1, 4, figsize=(10, 4), sharey=False)

    for i, param in enumerate(params):
        subset = df_melted[df_melted['parameter_type'] == param]

        # Boxplot (background)
        sns.boxplot(
            data=subset,
            x='parameter_type',
            y='parameter_value',
            ax=axes[i],
            width=0.2,
            color='#008CC1',
            showcaps=True,
            boxprops={'zorder': 1},
            whiskerprops={'zorder': 1},
            medianprops={'zorder': 2}
        )

        # Scatterplot (overlay)
        sns.stripplot(
            data=subset,
            x='parameter_type',
            y='parameter_value',
            hue='subject',
            ax=axes[i],
            dodge=True,
            jitter=True,
            alpha=0.8,
            zorder=3  # Ensures it's on top of the boxplot
        )

        axes[i].set_title(f'Parameter: {param}')
        axes[i].set_xlabel("")
        axes[i].set_xticks([])  # Optional, to clean the x-axis
        axes[i].legend().remove()

    # Global legend
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.tight_layout()
    plt.suptitle('SERTCre Mice Psychometric Fit Parameters', y=1.08)
    plt.show()


In [38]:
def genotype_map(final_df):
    """
    Annotate a trials DataFrame with genotype based on subject IDs. Kcénia's data.

    Parameters
    ----------
    final_df : pd.DataFrame
        DataFrame containing a 'subject' column.

    Returns
    -------
    pd.DataFrame
        The input DataFrame with an added 'genotype' column.
    """
    subject_map = {
        "SERT": ["ZFM-03059", "ZFM-05248", "ZFM-05245", "ZFM-05235", "ZFM-03061", "ZFM-03062", "ZFM-04392", "ZFM-05236", "ZFM-03065"],
        "TH": ["ZFM-04534", "ZFM-04533"],
        "ChAT": ["ZFM-06948", "ZFM-06305"],
        "DAT": ["ZFM-04022", "ZFM-04026", "ZFM-03448", "ZFM-03447", "ZFM-04019"],
        "Dbh": ["ZFM-06275", "ZFM-06268", "ZFM-06271", "ZFM-06272", "ZFM-06171"],
    }

    inverted_map = {}
    for genotype, ids in subject_map.items():
        for subject_id in ids:
            inverted_map[subject_id] = genotype

    final_df['genotype'] = final_df['subject'].map(inverted_map)
    return final_df

In [24]:
def plot_mean_psychometric_by_genotype(
    fit_df: pd.DataFrame,
    last5_df: pd.DataFrame,
    contrast_col: str = 'contrast',
    palette: dict = None,
    figsize: tuple = (7, 5),
    title: str = 'Mean psychometric function by genotype'
) -> None:
    """
    Plot one mean psychometric curve per genotype.

    Parameters
    ----------
    fit_df : pd.DataFrame
        DataFrame of fitted parameters with columns ['genotype','bias','threshold','lapse_low','lapse_high'].
    last5_df : pd.DataFrame
        DataFrame containing the contrast column (to define the contrast grid).
    contrast_col : str, optional
        Name of the contrast column in last5_df (default 'contrast').
    palette : dict, optional
        Mapping genotype -> color. If None, uses defaults.
    figsize : tuple, optional
        Figure size (default (7,5)).
    title : str, optional
        Plot title.
    """
   # default palette
    if palette is None:
        palette = {
            'ChAT' : '#009F61',
            'DAT'  : '#DF4661',
            'Dbh'  : '#35A2B6',
            'SERT' : '#8866BC',
            'TH'   : 'grey'
        }

    # define a fixed contrast grid
    xx = np.arange(-100, 100)
    contrast = xx

    plt.figure(figsize=figsize)
    for genotype, subdf in fit_df.groupby('genotype'):
        # stack each subject’s curve, then average
        all_ff = np.vstack([
            erf_psycho_2gammas(
                np.array([r['bias'], r['threshold'], r['lapse_low'], r['lapse_high']]),
                contrast
            )
            for _, r in subdf.iterrows()
        ])
        mean_ff = all_ff.mean(axis=0)

        color = palette.get(genotype, 'k')
        plt.plot(contrast, mean_ff, lw=2, color=color, label=genotype)

    plt.xlabel('Contrast (%)')
    plt.ylabel('Probability of choosing Right')
    plt.title(title)
    plt.legend(bbox_to_anchor=(1.05,1), loc='upper left')
    plt.tight_layout()
    plt.show()


In [25]:
def plot_mean_psychometric_with_sem(
    fit_df: pd.DataFrame,
    last5_df: pd.DataFrame,
    contrast_col: str = 'contrast',
    palette: dict = None,
    figsize: tuple = (7, 5),
    title: str = 'Mean psychometric ± SEM by genotype',
    sem_alpha: float = 0.2
) -> None:
    """
    Plot one mean psychometric curve per genotype with SEM shading.

    Parameters
    ----------
    fit_df : pd.DataFrame
        DataFrame with columns ['genotype','bias','threshold','lapse_low','lapse_high'].
    last5_df : pd.DataFrame
        DataFrame containing the contrast column.
    contrast_col : str, optional
        Name of the contrast column (default 'contrast').
    palette : dict, optional
        Mapping genotype -> color. If None, uses defaults.
    figsize : tuple, optional
        Figure size.
    title : str, optional
        Plot title.
    sem_alpha : float, optional
        Opacity for the SEM band.
        """


    # default palette
    if palette is None:
        palette = {
            'ChAT' : '#009F61',
            'DAT'  : '#DF4661',
            'Dbh'  : '#35A2B6',
            'SERT' : '#8866BC',
            'TH'   : 'grey'
        }

    # define a fixed contrast grid
    xx = np.arange(-100, 100)
    contrast = xx
    

    plt.figure(figsize=figsize)
    for geno, subdf in fit_df.groupby('genotype'):
        # build an array of size (n_subjects, n_contrasts)
        all_ff = np.vstack([
            erf_psycho_2gammas(
                np.array([r['bias'], r['threshold'], r['lapse_low'], r['lapse_high']]),
                contrast
            )
            for _, r in subdf.iterrows()
        ])
        mean_ff = all_ff.mean(axis=0)
        sem_ff  = all_ff.std(axis=0, ddof=1) / np.sqrt(all_ff.shape[0])

        color = palette.get(geno, 'k')
        # plot mean line
        plt.plot(contrast, mean_ff, lw=2, color=color, label=geno)
        # plot SEM shading
        plt.fill_between(
            contrast,
            mean_ff - sem_ff,
            mean_ff + sem_ff,
            color=color,
            alpha=sem_alpha
        )

    plt.xlabel('Contrast (%)')
    plt.ylabel('Probability of choosing Right')
    plt.title(title)
    plt.legend(bbox_to_anchor=(1.05,1), loc='upper left')
    plt.tight_layout()
    plt.show()

In [None]:
def plot_psychometric_functions(df, color='#8866BC'):
    """
    fig 3a "Standardized and reproducible measurement of decision-making in mice IBL 2021"
    
    IMPORTANT _ ONE _ plot psychometric in one genotype in particular. 
    Plot psychometric functions for each subject and the mean curve.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain columns: ['bias', 'threshold', 'lapse_low', 'lapse_high']
    color : str
        Color for the mean curve (default is '#8866BC')
    """
    x_vals = np.linspace(-100, 100, 500)

    all_curves = []

    plt.figure(figsize=(4, 7))

    for _, row in df.iterrows():
        pars = [row['bias'], row['threshold'], row['lapse_low'], row['lapse_high']]
        y_vals = erf_psycho_2gammas(pars, x_vals)  # use your existing function
        all_curves.append(y_vals)
        plt.plot(x_vals, y_vals, color='black', linewidth=1, alpha=0.2)

    # Mean psychometric function
    mean_curve = np.mean(all_curves, axis=0)
    plt.plot(x_vals, mean_curve, color=color, linewidth=2.5, label='WT')

    plt.legend(
        loc='upper left',
        bbox_to_anchor=(0.05, 1),  # (x, y) in axis fraction
        frameon=False
    )

    # Labels and formatting
    plt.xlabel('Signed contrast (%)')
    plt.ylabel('P(choose right)')
    plt.ylim(0, 1)
    plt.title('WT Psychometric Functions per Subject')
    plt.tight_layout()
    plt.show()

In [None]:
def plot_threshold_evolution(df, color='#8866BC'):
    """
    fig 2 b "Standardized and reproducible measurement of decision-making in mice IBL 2021"
    Plot the evolution of the threshold across sessions.
    
    Parameters
    ----------
    df : pd.DataFrame
        Must contain ['subject', 'session_num', 'threshold'] columns.
    color : str
        Color of the mean curve (e.g. '#8866BC' for purple)
    """
    plt.figure(figsize=(4, 5))

    # Plot each subject's line in grey
    for subject, sub_df in df.groupby('subject'):
        plt.plot(
            sub_df['session_num'], sub_df['threshold'],
            color='lightgrey', linewidth=1, alpha=0.8
        )

    # Compute and plot the mean threshold per session
    mean_df = df.groupby('session_num')['threshold'].mean().reset_index()
    plt.plot(
        mean_df['session_num'], mean_df['threshold'],
        color=color, linewidth=2.5, label='WT'
    )

    # Labels and style
    plt.xlabel('Training day')
    plt.ylabel('Contrast Threshold %')
    plt.title('Evolution of Threshold Across Sessions')
    plt.legend()
    plt.tight_layout()
    sns.despine()
    plt.show()

In [None]:
def plot_bias_evolution(df, color='#8866BC'):
    """
    fig 2 c "Standardized and reproducible measurement of decision-making in mice IBL 2021"
    Plot the evolution of the threshold across sessions.
    
    Parameters
    ----------
    df : pd.DataFrame
        Must contain ['subject', 'session_num', 'threshold'] columns.
    color : str
        Color of the mean curve (e.g. '#8866BC' for purple)
    """
    plt.figure(figsize=(4, 5))

    # Plot each subject's line in grey
    for subject, sub_df in df.groupby('subject'):
        plt.plot(
            sub_df['session_num'], sub_df['bias'],
            color='lightgrey', linewidth=1, alpha=0.8
        )

    # Compute and plot the mean threshold per session
    mean_df = df.groupby('session_num')['bias'].mean().reset_index()
    plt.plot(
        mean_df['session_num'], mean_df['bias'],
        color=color, linewidth=2.5, label='5-HT'
    )

    # Labels and style
    plt.xlabel('Training day')
    plt.ylabel('Bias %')
    plt.title('Evolution Bias Across Sessions')
    plt.legend()
    plt.tight_layout()
    sns.despine()
    plt.show()

In [None]:
def plot_bias_by_genotype(df):
    """
    fig 3 e "Standardized and reproducible measurement of decision-making in mice IBL 2021"
    Plot bias (%) for each genotype + combined 'All', with no dots for 'All'.
    """
    fig, ax = plt.subplots(figsize=(5, 6))

    # Ordered NM list
    NM_order = ['5HT', 'ACh', 'DA', 'NE', 'WT']
    data = [df[df['NM'] == nm]['bias'].values for nm in NM_order]
    data.append(df['bias'].values)  # 'All' group

    # Position for each group
    positions = list(range(len(data)))

    # Plot boxplots
    bp = ax.boxplot(data, patch_artist=True, positions=positions)

    # Colors as requested
    colors = ['#8866BC', '#009F61', '#DF4661', '#35A2B6', 'grey', 'black']
    for patch, color in zip(bp['boxes'], colors):
        patch.set_facecolor(color)
        patch.set_alpha(0.4)
        patch.set_linewidth(1)

    # Individual dots (exclude 'All')
    for i, values in enumerate(data):
        if i == len(data) - 1:
            continue  # Skip 'All'
        ax.scatter(
            np.random.normal(loc=i, scale=0.08, size=len(values)),
            values,
            color=colors[i],
            alpha=0.7,
            edgecolor='k',
            s=30
        )

    ax.set_xticks(positions)
    ax.set_xticklabels(NM_order + ['All'], rotation=45)
    ax.set_ylabel('Bias (%)')
    ax.set_title('Bias by genotype')
    # ax.set_ylim(-20, 20)
    plt.tight_layout()
    plt.show()

def plot_threshold_by_genotype(df):
    """
    fig 3 d "Standardized and reproducible measurement of decision-making in mice IBL 2021"
    Plot bias (%) for each genotype + combined 'All', with no dots for 'All'.
    """
    fig, ax = plt.subplots(figsize=(5, 6))

    # Ordered NM list
    NM_order = ['5HT', 'ACh', 'DA', 'NE', 'WT']
    data = [df[df['NM'] == nm]['threshold'].values for nm in NM_order]
    data.append(df['threshold'].values)  # 'All' group

    # Position for each group
    positions = list(range(len(data)))

    # Plot boxplots
    bp = ax.boxplot(data, patch_artist=True, positions=positions)

    # Colors as requested
    colors = ['#8866BC', '#009F61', '#DF4661', '#35A2B6', 'grey', 'black']
    for patch, color in zip(bp['boxes'], colors):
        patch.set_facecolor(color)
        patch.set_alpha(0.4)
        patch.set_linewidth(1)

    # Individual dots (exclude 'All')
    for i, values in enumerate(data):
        if i == len(data) - 1:
            continue  # Skip 'All'
        ax.scatter(
            np.random.normal(loc=i, scale=0.08, size=len(values)),
            values,
            color=colors[i],
            alpha=0.7,
            edgecolor='k',
            s=30
        )

    ax.set_xticks(positions)
    ax.set_xticklabels(NM_order + ['All'], rotation=45)
    ax.set_ylabel('Bias (%)')
    ax.set_title('Contrast threshold by genotype')
    # ax.set_ylim(-20, 20)
    plt.tight_layout()
    plt.show()

plot_bias_by_genotype(fit_results_df)
plot_threshold_by_genotype(fit_results_df)


Applying the logistic regression. 

Bias analysis, finding the change of psychometric curve when there's a change in bias. 

In [93]:
def prob_choose_right_bias (bias_df):
    # Calculate the probability of choosing right, byt filtering by probability Left
    # Step 1: Group and count choices
    grouped = bias_df.groupby(['NM', 'subject', 'session_num', 'contrast', 'probabilityLeft', 'choice']).size().reset_index(name='trials')
    
    # Step 2: Compute the proportion of 'right' choices (assuming right = 1)
    psycho_df_f = grouped.groupby(['NM', 'subject', 'session_num', 'contrast', 'probabilityLeft']).apply(
        lambda g: pd.Series({
            'prob_choose_right': g.loc[g['choice'] == -1, 'trials'].sum() / g['trials'].sum(),
            'rigth_choice': g.loc[g['choice'] == -1, 'trials'].sum(),
            'total_trials': g['trials'].sum()
        })
    ).reset_index()
    return psycho_df_f

In [1]:
def proportion_correct_bias(bias_df, df_trials):
    '''
    Once you have the probability of choosing right, add the proprotion correct of each session to then select only 
    high peformance sessions.
    bias_df: your original loading df, on which you can apply def proportion_correct
    df_trials: Data Frame where you calculated the probability of choosing right
    '''
    df_perf = proportion_correct(bias_df)
    
    df_perf_bias = df_perf[df_perf['stage']=='bias']
    
    # now keep just the columns you need to merge in
    df_perf_bias = df_perf_bias[[
        'NM','subject','session_num','proportion_correct'
    ]]
    
    # merge them together
    df_big = pd.merge(
        df_trials,
        df_perf_bias,
        how='left',
        on=['NM','subject','session_num']
    )
    return df_big

In [27]:
def fit_psychometric_by_subject_and_pL(
    data: pd.DataFrame,
    fit_fn,
    model_fn,
    contrast_col: str = 'contrast',
    response_col: str = 'prob_choose_right',
    genotype_col: str = 'NM',
    subject_col: str = 'subject',
    pL_col: str = 'probabilityLeft',
    bounds=None,
    init=None,
    maxfev: int = 10000
) -> pd.DataFrame:
    """
    Fit a psychometric model separately for each (NM, subject, probabilityLeft) group
    using mle_fit_psycho with erf_psycho_2gammas and fixed parameter settings.
    """
    results = []

    # unique combinations of genotype, subject, pL
    groups = data[[genotype_col, subject_col, pL_col]].drop_duplicates()

    for geno, subj, pL in groups.values:
        # filter down to this group
        subdf = data[
            (data[genotype_col] == geno) &
            (data[subject_col]  == subj) &
            (data[pL_col]       == pL)
        ]
        # get x and y
        x = subdf[contrast_col].to_numpy().astype(float)
        y = subdf[response_col].to_numpy().astype(float)

        # build data matrix: contrast in %, assume one trial per entry
        contrasts_pct = x * 100
        n_trials = np.ones_like(contrasts_pct, dtype=int)
        data_matrix = np.vstack([contrasts_pct, n_trials, y])

        # fixed parameter settings
        P_model = 'erf_psycho_2gammas'
        parstart = np.array([np.mean(contrasts_pct), 20.0, 0.05, 0.05])
        parmin   = np.array([np.min(contrasts_pct),     0.0,  0.0,  0.0])
        parmax   = np.array([np.max(contrasts_pct),   100.0,  1.0,  1.0])
        nfits    = 5

        # fit using mle_fit_psycho
        pars, L = mle_fit_psycho(
            data_matrix,
            P_model=P_model,
            parstart=parstart,
            parmin=parmin,
            parmax=parmax,
            nfits=nfits
        )

                # unpack parameters
        bias, threshold, lapse_low, lapse_high = pars

        results.append({
            genotype_col: geno,
            subject_col: subj,
            pL_col: pL,
            'bias': bias,
            'threshold': threshold,
            'lapse_low': lapse_low,
            'lapse_high': lapse_high,
            'log_likelihood': L
        })

    return pd.DataFrame(results)(results)

In [71]:
def plot_mean_psychometric_by_genotype_and_pL(fit_df,
                                              last5_df,
                                              contrast_col='contrast',
                                              prob_col='probability_left',
                                              p_levels=(0.2, 0.5, 0.8),
                                              palette=None,
                                              style_map=None,
                                              figsize=(7,5),
                                              title='Last 5 days – Mean psychometric by genotype & pL'):
    """
    Plot mean psychometric curves, one per (genotype, probability_left) pair.

    Parameters
    ----------
    fit_df : pd.DataFrame
        Must have columns ['genotype', prob_col, 'a','b','g','l'].
    last5_df : pd.DataFrame
        Must have a contrast column to define the x‐axis.
    contrast_col : str
        Name of the contrast column in last5_df.
    prob_col : str
        Name of the probability‐left column in fit_df.
    p_levels : iterable of float
        Which probability_left values to plot (e.g. [0.2,0.5,0.8]).
    palette : dict, optional
        genotype → color hex. Defaults to your original five colors.
    style_map : dict, optional
        p_level → line style. Defaults to {0.2:'--',0.5:'-',0.8':':' }.
    """
    # defaults
    if palette is None:
        palette = {
            'ChAT': '#009F61',
            'DAT' : '#DF4661',
            'Dbh' : '#35A2B6',
            'SERT': '#8866BC',
            'TH'  : 'grey'
        }
    if style_map is None:
        style_map = {0.2: '--', 0.5: '-', 0.8: ':'}

    # build and sort contrast axis (0–1 → 0–100%)
    contrast = np.sort(last5_df[contrast_col].unique() * 100)

    plt.figure(figsize=figsize)
    for geno, geno_df in fit_df.groupby('NM'):
        col = palette.get(geno, 'k')
        for pL in p_levels:
            sub = geno_df[geno_df[prob_col] == pL]
            if sub.empty:
                continue

            # stack each subject’s curve
            all_ff = np.vstack([
                erf_psycho_2gammas(
                    np.array([r['bias'], r['threshold'], r['lapse_low'], r['lapse_high']]),
                    contrast
                )
                for _, r in sub.iterrows()
            ])
            mean_ff = all_ff.mean(axis=0)

            ls = style_map.get(pL, '-')
            plt.plot(contrast, mean_ff,
                     color=col, linestyle=ls, lw=2,
                     label=f"{geno}, pL={pL}")

    plt.xlabel('Contrast (%)')
    plt.ylabel('Probability of choosing Right')
    plt.title(title)
    plt.legend(bbox_to_anchor=(1.05,1), loc='upper left')
    plt.tight_layout()
    plt.show()

In [74]:
def plot_zoom_mean_psychometric_by_genotype_and_pL(fit_df,
                                              last5_df,
                                              contrast_col='contrast',
                                              prob_col='probability_left',
                                              p_levels=(0.2, 0.5, 0.8),
                                              palette=None,
                                              style_map=None,
                                              figsize=(10,5),
                                              title='Last 5 days – Mean psychometric by genotype & pL',
                                              xlim=None):
    """
    Plot mean psychometric curves, one per (genotype, probability_left) pair,
    optionally zoomed in on a particular contrast range.

    Parameters
    ----------
    fit_df : pd.DataFrame
        Must have columns ['genotype', prob_col, 'bias', 'threshold', 'lapse_low', 'lapse_high'].
    last5_df : pd.DataFrame
        Must have a contrast column to define the x‐axis.
    contrast_col : str
        Name of the contrast column in last5_df.
    prob_col : str
        Name of the probability‐left column in fit_df.
    p_levels : iterable of float
        Which probability_left values to plot (e.g. [0.2,0.5,0.8]).
    palette : dict, optional
        genotype → color hex.
    style_map : dict, optional
        p_level → line‐style.
    figsize : tuple, optional
        Figure size.
    title : str, optional
        Plot title.
    xlim : tuple (xmin, xmax), optional
        If given, set the x‐axis limits to this range (in %).
    """
    # defaults
    if palette is None:
        palette = {
            'ChAT': '#009F61',
            'DAT' : '#DF4661',
            'Dbh' : '#35A2B6',
            'SERT': '#8866BC',
            'TH'  : 'grey'
        }
    if style_map is None:
        style_map = {0.2: '--', 0.5: '-', 0.8: ':'}

    # build and sort contrast axis (0–1 → 0–100%)
    contrast = np.sort(last5_df[contrast_col].unique() * 100)

    plt.figure(figsize=figsize)
    for geno, geno_df in fit_df.groupby('NM'):
        col = palette.get(geno, 'k')
        for pL in p_levels:
            sub = geno_df[geno_df[prob_col] == pL]
            if sub.empty:
                continue

            all_ff = np.vstack([
                erf_psycho_2gammas(
                    np.array([r['bias'], r['threshold'], r['lapse_low'], r['lapse_high']]),
                    contrast
                )
                for _, r in sub.iterrows()
            ])
            mean_ff = all_ff.mean(axis=0)

            ls = style_map.get(pL, '-')
            plt.plot(contrast, mean_ff,
                     color=col, linestyle=ls, lw=2,
                     label=f"{geno}, pL={pL}")

    # apply zoom if requested
    if xlim is not None:
        plt.xlim(xlim)
        

    plt.xlabel('Contrast (%)')
    plt.ylabel('Probability of choosing Right')
    plt.title(title)
    plt.legend(bbox_to_anchor=(1.05,1), loc='upper left')
    plt.tight_layout()
    plt.show()

In [None]:
def transform_and_plot_bias(df_param, color_map=None, figsize=(10, 6)):
    """
    Calculate and plot the bias differences when changing blocks. 
    
    Parameters
    ----------
    df_param : pandas.DataFrame
        Input DataFrame containing columns ['NM', 'subject', 'probabilityLeft', 'bias'].
    color_map : dict, optional
        Mapping of 'NM' categories to hex color strings. Defaults to a predefined palette.
    figsize : tuple, optional
        Figure size for the plot. Default is (10, 6).

    Returns
    -------
    df_melted : pandas.DataFrame
        Melted DataFrame with columns ['NM', 'subject', 'Bias', 'D_a'] containing bias differences.
    fig : matplotlib.figure.Figure
        The matplotlib Figure object.
    ax : matplotlib.axes.Axes
        The matplotlib Axes object with the boxplot.
    """
    # Default color map if none provided
    default_colors = {
        '5HT': '#8866BC',
        'WT':  '#808080',
        'NE':  '#35A2B6',
        'DA':  '#DF4661',
        'ACh': '#009F61'
    }
    cmap = color_map or default_colors

    # Filter relevant columns
    delta_df = df_param.filter(['NM', 'subject', 'probabilityLeft', 'bias'])

    # Pivot to wide format
    df_pivot = (
        delta_df
        .pivot(index=['NM', 'subject'], columns='probabilityLeft', values='bias')
        .reset_index()
    )
    df_pivot.columns.name = None
    df_pivot = df_pivot.rename(columns={0.2: 'bias_0.2', 0.5: 'bias_0.5', 0.8: 'bias_0.8'})

    # Compute bias differences
    df_pivot['bias_R'] = df_pivot['bias_0.2'] - df_pivot['bias_0.5']
    df_pivot['bias_L'] = df_pivot['bias_0.5'] - df_pivot['bias_0.8']

    # Select final columns and melt
    delta_df2 = df_pivot[['NM', 'subject', 'bias_R', 'bias_L']]
    df_melted = pd.melt(
        delta_df2,
        id_vars=('NM', 'subject'),
        value_vars=['bias_R', 'bias_L'],
        var_name='Bias',
        value_name='D_a'
    )

    # Prepare the plot
    fig, ax = plt.subplots(figsize=figsize)
    nm_levels = df_melted['NM'].unique()
    bias_levels = df_melted['Bias'].unique()

    # Compute positions and widths
    n_nm = len(nm_levels)
    box_width = 0.15
    group_width = box_width * n_nm
    x_positions = []
    for i, bias in enumerate(bias_levels):
        base = i * (group_width + 0.1)
        for j, nm in enumerate(nm_levels):
            x_positions.append(base + j * box_width)

    # Plot each box
    data_groups = []
    labels = []
    colors = []
    for bias in bias_levels:
        for nm in nm_levels:
            subset = df_melted[(df_melted['Bias'] == bias) & (df_melted['NM'] == nm)]['D_a']
            data_groups.append(subset)
            labels.append(f"{bias}_{nm}")
            colors.append(cmap.get(nm, '#000000'))

    ax.boxplot(data_groups, positions=x_positions, widths=box_width, patch_artist=True)

    # Apply colors
    for patch, color in zip(ax.artists, colors):
        patch.set_facecolor(color)

    # Ticks and labels
    ax.set_xticks([((i * (group_width + 0.1)) + (group_width - box_width) / 2) for i in range(len(bias_levels))])
    ax.set_xticklabels(bias_levels)
    ax.set_xlabel('Bias Type')
    ax.set_ylabel('D_a')
    ax.set_title('Bias Differences by NM and Bias Type')
    ax.legend(handles=[plt.Line2D([], [], color=c, marker='s', linestyle='') for nm, c in cmap.items()],
              labels=list(cmap.keys()), title='NM', bbox_to_anchor=(1.05, 1), loc='upper left')
    fig.tight_layout()

    return df_melted, fig, ax
