Generate plots

In [None]:
import numpy as np
import pandas as pd
import pickle
import joypy
from pathlib import Path

from isttc.scripts.cfg_global import project_folder_path

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
import seaborn as sns

# Keep text as real text (not outlined glyphs) in exported vector figures so the
# labels stay editable in vector-graphics software:
#   - fonttype 42 embeds TrueType fonts in PDF / PostScript output
#   - 'none' makes the SVG backend write text elements instead of paths
mpl.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'svg.fonttype': 'none',
})

In [None]:
# Folder holding the outputs of the bin-size runs. The path is assembled with pathlib
# instead of hard-coded '\\' separators so the notebook is not tied to Windows.
# Both names stay plain strings ending in a separator because later cells build file
# names by string concatenation (fig_folder + '<file name>'); a trailing '/' is
# accepted by Windows as well as POSIX systems.
results_folder = str(Path(project_folder_path, 'results', 'synthetic', 'results', 'bin_size_runs')) + '/'
# Sub-folder that receives the saved figures.
fig_folder = results_folder + 'figs/'

In [None]:
# Toggle for the plotting cells below: when True each figure is written to fig_folder
# as both PNG and SVG; set to False to only display the figures.
save_fig = True

### Load summary dataframes

In [None]:
def inspect(df, name):
    """Print a two-line summary of a dataframe.

    The first line shows the number of rows, labelled with ``name``; the second
    line shows the column labels as a Python list.
    """
    row_count = len(df)
    print("len {}: {}".format(name, row_count))
    column_labels = df.columns.tolist()
    print(column_labels)

# Wrap the results folder in a Path so the pickles can be addressed with the '/' operator.
# Sub-folder and file name are joined as separate path components (no embedded '\\'),
# which keeps the paths valid on non-Windows systems too.
results_folder = Path(results_folder)

# NOTE(review): pd.read_pickle executes whatever the pickle contains while loading;
# only point this at files written by this project's own pipeline.

# only units for which tau was computed for all bin sizes
print('ACF, iSTTC, PearsonR, iSTTC trials (concat)')
df_full = pd.read_pickle(results_folder / "full_signal" / "summary_tau_full_long_lags_df.pkl")
inspect(df_full, "df_full")

df_trials = pd.read_pickle(results_folder / "trials" / "summary_tau_trials_long_lags_df.pkl")
inspect(df_trials, "df_trials")

# all units
print('ACF, iSTTC, PearsonR, iSTTC trials (concat)')
df_full_all_units = pd.read_pickle(results_folder / "full_signal" / "summary_tau_full_long_lags_df_all_units.pkl")
inspect(df_full_all_units, "df_full_all_units")

df_trials_all_units = pd.read_pickle(results_folder / "trials" / "summary_tau_trials_long_lags_df_all_units.pkl")
inspect(df_trials_all_units, "df_trials_all_units")

### Prepare extra dataframes for plotting 

In [None]:
# Reference estimate for every (unit_id, method) pair: the tau_ms of the n_lags == 20 run.
# NOTE(review): the plots below call this the 50 ms reference; presumably the n_lags == 20
# rows are the 50 ms bin-size runs -- confirm against the pipeline that wrote df_full.
ref = (
    df_full.loc[df_full['n_lags'] == 20, ['unit_id', 'method', 'tau_ms']]
    .rename(columns={'tau_ms': 'tau_ms_ref'})
)

merged_full = (
    df_full
    # attach the reference to every row of the same unit and method
    .merge(ref, on=['unit_id', 'method'], how='left')
    # signed deviation of each estimate from its reference
    .assign(tau_ms_diff=lambda d: d['tau_ms'] - d['tau_ms_ref'])
    .assign(
        # log10(|diff| + 1): the +1 maps a zero difference to 0 instead of -inf
        tau_ms_diff_log10=lambda d: np.log10(np.abs(d['tau_ms_diff']) + 1),
        # 'tau_diff_rel' has to be present in df_full already; it is not computed here
        tau_diff_rel_log10=lambda d: np.log10(d['tau_diff_rel']),
    )
)

merged_full

In [None]:
# Reference estimate for every (unit_id, method) pair: the tau_ms of the n_lags == 20 run.
# NOTE(review): the plots below call this the 50 ms reference; presumably the n_lags == 20
# rows are the 50 ms bin-size runs -- confirm against the pipeline that wrote df_trials.
ref = (
    df_trials.loc[df_trials['n_lags'] == 20, ['unit_id', 'method', 'tau_ms']]
    .rename(columns={'tau_ms': 'tau_ms_ref'})
)

merged_trials = (
    df_trials
    # attach the reference to every row of the same unit and method
    .merge(ref, on=['unit_id', 'method'], how='left')
    # signed deviation of each estimate from its reference
    .assign(tau_ms_diff=lambda d: d['tau_ms'] - d['tau_ms_ref'])
    .assign(
        # log10(|diff| + 1): the +1 maps a zero difference to 0 instead of -inf
        tau_ms_diff_log10=lambda d: np.log10(np.abs(d['tau_ms_diff']) + 1),
        # 'tau_diff_rel' has to be present in df_trials already; it is not computed here
        tau_diff_rel_log10=lambda d: np.log10(d['tau_diff_rel']),
    )
)

merged_trials

### Plots

In [None]:
# Hex colors, one per estimation method (the names keep this file's "trail" spelling).
# Full-signal methods; used for "acf_full" and "acf_isttc_full" in the exclusion plot.
color_acf_full = '#718190'
color_isttc_full = '#1ba9e2' 
# Trial-based methods; pearson / sttc concat are used for "pearsonr_trial_avg" and
# "sttc_trial_concat" in the exclusion plot.
color_pearson_trail_avg = '#f4a91c' 
# color_sttc_trail_avg is not referenced by any other cell shown in this notebook.
color_sttc_trail_avg =  '#a49fce' 
color_sttc_trail_concat = '#955da2' 
# color_abctau is not referenced by any other cell shown in this notebook.
color_abctau = '#B4464B'

# Palettes handed to seaborn in the violin plots (trial-based frame / full-signal frame).
# NOTE(review): a list palette is assigned to hue levels by position -- confirm that the
# order of the 'method' levels in each dataframe matches the order of these lists.
color_trails = [color_pearson_trail_avg, color_sttc_trail_concat] 
color_full = [color_acf_full, color_isttc_full] 

#### Difference from 50ms time bins

In [None]:
# Violin plots of tau_ms_diff_log10 = log10(|tau_ms - tau_ms_ref| + 1) per bin size,
# split by method: left panel = full-signal dataframe, right panel = trial dataframe.
fig, axes = plt.subplots(1, 2, figsize=(10, 3), sharey=True)
fig.subplots_adjust(hspace=0.4, wspace=0.4)

# (axis, dataframe, palette, panel title) -- one entry per panel
panels = (
    (axes[0], merged_full, color_full, 'Full signal'),
    (axes[1], merged_trials, color_trails, 'Trials'),
)

for ax, panel_df, panel_palette, panel_title in panels:
    sns.violinplot(x='bin_size', y='tau_ms_diff_log10', hue='method', data=panel_df, cut=0,
                   density_norm='width', legend=True, palette=panel_palette, ax=ax)
    ax.set_title(panel_title)
    ax.set_ylabel('IT diff, log10')
    # The x variable is 'bin_size'; label it like the other bin-size figures in this
    # notebook (the previous 'Number of lags (a.u.)' label did not match the data plotted).
    ax.set_xlabel('Bin size (ms)')
    # Guide lines: y = 2 corresponds to |diff| = 99 (10**2 - 1), y = 0 to no difference.
    ax.axhline(y=2, lw=0.5, c='k')
    ax.axhline(y=0, lw=0.5, c='k')
    ax.legend(frameon=False)
    sns.despine(ax=ax)

if save_fig:
    fig.savefig(fig_folder + 'it_diff_50ms_ref_vert.png', bbox_inches='tight', dpi=300)
    fig.savefig(fig_folder + 'it_diff_50ms_ref_vert.svg', bbox_inches='tight')

In [None]:
# NOTE(review): this whole cell is disabled (every line is commented out). It is the
# relative-error counterpart of the violin plot above and plots 'tau_diff_rel_log10';
# either re-enable it or delete the cell before sharing the notebook.
# fig, axes = plt.subplots(1,2, figsize=(10,3))
# plt.subplots_adjust(hspace=0.4, wspace=0.4)

# sns.violinplot(x='bin_size', y='tau_diff_rel_log10', hue='method', data=merged_full , cut=0, density_norm='width', 
#                legend=True, palette=color_full,  ax=axes[0])
# axes[0].set_title('Full')

# sns.violinplot(x='bin_size', y='tau_diff_rel_log10', hue='method', data=merged_trials , cut=0, density_norm='width', 
#                legend=True, palette=color_trails, ax=axes[1])
# axes[1].set_title('Trials')

# for ax in axes.flat:
#     ax.axhline(y=2, lw=0.5, c='k')
#     ax.axhline(y=0, lw=0.5, c='k')
#     ax.set_ylabel('REE, log10')
#     ax.set_xlabel('Number of lags (a.u.)')
#     ax.legend(frameon=False)
#     sns.despine(ax=ax)

# if save_fig:
#     plt.savefig(fig_folder + 'ree_vs_bin_size.png' , bbox_inches='tight', dpi=300)
#     plt.savefig(fig_folder + 'ree_vs_bin_size.svg' , bbox_inches='tight')

#### % of failed estimates (failed estimation and negative R-squared)

In [None]:
def compute_exclusion(group):
    """Summarise how many rows of ``group`` count as failed estimates.

    A row is flagged when its 'tau' is NaN or its 'fit_r_squared' is below zero.

    Parameters
    ----------
    group : pandas.DataFrame
        Must provide the columns 'tau' and 'fit_r_squared'.

    Returns
    -------
    pandas.Series
        Counts ('tau_nans', 'fit_r_squared_neg', 'exclusion') followed by the same
        three quantities as a percentage of the number of rows ('percent_tau_nans',
        'percent_fit_r_squared_neg', 'exclusion_perc'). Percentages are 0 for an
        empty group.
    """
    n_rows = len(group)
    tau_missing = group['tau'].isna()
    r_squared_negative = group['fit_r_squared'] < 0

    n_tau_missing = tau_missing.sum()
    n_r_squared_negative = r_squared_negative.sum()
    # rows hit by at least one of the two criteria (not the sum of the two counts)
    n_excluded = (tau_missing | r_squared_negative).sum()

    def as_percent(count):
        # guard against division by zero for an empty group
        if n_rows > 0:
            return 100 * count / n_rows
        return 0

    summary = {
        'tau_nans': n_tau_missing,
        'fit_r_squared_neg': n_r_squared_negative,
        'exclusion': n_excluded,
        'percent_tau_nans': as_percent(n_tau_missing),
        'percent_fit_r_squared_neg': as_percent(n_r_squared_negative),
        'exclusion_perc': as_percent(n_excluded),
    }
    return pd.Series(summary)

In [None]:
# Failed-estimate statistics per (method, bin_size), for the full-signal and the
# trial-based summaries.
# Only the two columns that compute_exclusion reads are passed to apply(). Selecting
# them explicitly keeps the grouping columns out of each group, which silences the
# pandas >= 2.2 deprecation warning about apply() operating on grouping columns and
# leaves the resulting table unchanged.
group_keys = ['method', 'bin_size']
exclusion_inputs = ['tau', 'fit_r_squared']

exclusion_df = (
    df_full_all_units
    .groupby(group_keys)[exclusion_inputs]
    .apply(compute_exclusion)
    .reset_index()
)
exclusion_df_trials = (
    df_trials_all_units
    .groupby(group_keys)[exclusion_inputs]
    .apply(compute_exclusion)
    .reset_index()
)
exclusion_df_trials

In [None]:
def plot_one_threshold_jittered_lines(ax, df, field_to_plot, methods, jitter=0.15, seed=None, plot_lines=False):
    """Scatter one value per (method, bin_size) with horizontal jitter.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on; ``plot``, ``scatter``, ``set_xticks`` and
        ``set_xticklabels`` are called on it.
    df : pandas.DataFrame
        Needs the columns 'bin_size', 'method' and ``field_to_plot``.
    field_to_plot : str
        Column plotted on the y axis.
    methods : iterable of (method, color, label)
        ``method`` selects the rows with that value in the 'method' column,
        ``color`` is used for its markers/line and ``label`` is attached to the
        scatter artist.
    jitter : float, default 0.15
        Half-width of the uniform horizontal offset added to every point.
    seed : int or None, default None
        Seed for the jitter. With a seed the offsets are reproducible and are the
        same numbers that ``np.random.seed(seed)`` followed by
        ``np.random.uniform`` would yield. With ``None`` the global NumPy random
        state is used.
    plot_lines : bool, default False
        If True, additionally connect the points of each method with a line,
        ordered by category position.

    Returns
    -------
    None
    """
    # Use a private generator when a seed is given so that drawing a figure does not
    # reset the notebook-wide NumPy random state as a side effect. RandomState is the
    # same legacy generator that backs np.random.seed, so the jitter is unchanged.
    rng = np.random.RandomState(seed) if seed is not None else np.random

    # categories in order of first appearance in df; position i becomes x tick i
    cats = list(df["bin_size"].unique())
    idx_map = {cat: i for i, cat in enumerate(cats)}

    for method, col, label in methods:
        sub = df[df["method"] == method]
        # map every bin_size value to its integer tick position
        x0 = np.array([idx_map[s] for s in sub["bin_size"]])
        y = sub[field_to_plot].values

        # one independent horizontal offset per point
        offsets = rng.uniform(-jitter, jitter, size=len(x0))
        xj = x0 + offsets

        if plot_lines:
            # connect points from left to right by category, not by row order
            order = np.argsort(x0)
            ax.plot(xj[order], y[order],
                    color=col,
                    linewidth=0.8,
                    alpha=1)
        # jittered points
        ax.scatter(xj, y,
                   color=col,
                   s=40,
                   marker="o",
                   label=label)

    ax.set_xticks(np.arange(len(cats)))
    ax.set_xticklabels(cats)

In [None]:
# (value in the 'method' column, marker color, label) for the trial-based estimators
methods = [
    ("pearsonr_trial_avg", color_pearson_trail_avg, "pearsonr_avg"),
    ("sttc_trial_concat", color_sttc_trail_concat, "sttc_concat"),
]

# single panel: percentage of excluded units per bin size
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(4, 3))

plot_one_threshold_jittered_lines(
    axes, exclusion_df_trials, 'exclusion_perc', methods, jitter=0.3, seed=42
)
axes.set(
    ylabel='Excluded units (%)',
    ylim=(0, 51),
    xlabel='Bin size (ms)',
)
# dashed horizontal guide lines only
axes.grid(True, which="both", axis='y', linestyle="--", linewidth=0.5, alpha=0.7)
sns.despine(fig=fig)

if save_fig:
    fig.savefig(fig_folder + 'inclusion_trials.png', bbox_inches='tight', dpi=300)
    fig.savefig(fig_folder + 'inclusion_trials.svg', bbox_inches='tight')

In [None]:
# (value in the 'method' column, marker color, label) for the full-signal estimators
methods = [
    ("acf_full", color_acf_full, "acf_full"),
    ("acf_isttc_full", color_isttc_full, "acf_isttc_full"),
]

# single panel: percentage of excluded units per bin size
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(4, 3))

plot_one_threshold_jittered_lines(
    axes, exclusion_df, 'exclusion_perc', methods, jitter=0.3, seed=42
)
axes.set(
    ylabel='Excluded units (%)',
    ylim=(0, 1.5),
    xlabel='Bin size (ms)',
)
# dashed horizontal guide lines only
axes.grid(True, which="both", axis='y', linestyle="--", linewidth=0.5, alpha=0.7)
sns.despine(fig=fig)

if save_fig:
    fig.savefig(fig_folder + 'inclusion_full.png', bbox_inches='tight', dpi=300)
    fig.savefig(fig_folder + 'inclusion_full.svg', bbox_inches='tight')