In [1]:
import numpy as np
import pandas as pd
import pickle
from datetime import datetime
import matplotlib.pyplot as plt
import seaborn as sns

# import from scripts
import os

# The project helpers live in a sibling "scripts" folder and are imported by
# temporarily switching the working directory to it.
# NOTE(review): this relies on the current working directory being on sys.path
# (presumably true for the notebook kernel) - confirm.
current_wd = os.getcwd()
# os.path.join keeps the relative path valid on any OS; on Windows it resolves
# to the same folder as the previous backslash-separated literal.
os.chdir(os.path.abspath(os.path.join("..", "..", "..", "isttc", "scripts")))
try:
    from calculate_tau import fit_single_exp, func_single_exp, func_single_exp_monkey
    from cfg_global import project_folder_path
finally:
    # Always go back to the notebook folder, even when an import above fails,
    # so later relative paths are not silently resolved against the scripts folder.
    os.chdir(current_wd)

In [2]:
# Input dataset folder and output figure folder, both under the project root.
dataset_folder = f'{project_folder_path}results\\allen_mice\\dataset\\cut_30min\\'
fig_folder = f'{project_folder_path}results\\allen_mice\\fig_draft_paper\\'

#### Load data

In [3]:
# Load the unit table and keep a copy of only the columns needed later to
# attach a brain structure to every unit id.
unit_info_columns = ['unit_id', 'ecephys_structure_acronym']
units_info_df = pd.read_pickle(dataset_folder + 'sua_list_constrained_units_df.pkl')
units_info_df_subset = units_info_df.loc[:, unit_info_columns].copy()
units_info_df_subset.head(2)

Unnamed: 0,unit_id,ecephys_structure_acronym
66,950913540,VISam
67,950915005,VISam


In [12]:
# Locations of the two per-unit autocorrelation tables (binned ACF and
# non-binned iSTTC).
acf_full_df_file = dataset_folder + 'binned\\acf\\acf_full_50ms_20lags_df.pkl'
acf_isttc_full_df_file = dataset_folder + 'non_binned\\acf\\acf_isttc_full_50ms_20lags_df.pkl'

# Read both tables, in the same order as the paths above.
acf_full_df, acf_isttc_full_df = (
    pd.read_pickle(df_file) for df_file in (acf_full_df_file, acf_isttc_full_df_file)
)

#### Calculate tau per unit

In [4]:
# Column names 'acf_0' ... 'acf_<n_lags>' used to select the autocorrelation
# values from the loaded tables (lag 0 included, hence n_lags + 1 names).
n_lags = 20
acf_cols = [f'acf_{lag}' for lag in range(n_lags + 1)]
print(f'acf_cols {acf_cols}')

acf_cols ['acf_0', 'acf_1', 'acf_2', 'acf_3', 'acf_4', 'acf_5', 'acf_6', 'acf_7', 'acf_8', 'acf_9', 'acf_10', 'acf_11', 'acf_12', 'acf_13', 'acf_14', 'acf_15', 'acf_16', 'acf_17', 'acf_18', 'acf_19', 'acf_20']


In [13]:
# Autocorrelation values as a 2D array (one row per unit, one column per
# entry of acf_cols) and the unit ids in the same row order.
acf_full_2d = acf_full_df[acf_cols].values
print(f'acf_2d shape {acf_full_2d.shape}')
acf_full_unit_ids = acf_full_df['unit_id'].values
print(f'acf_full_unit_ids shape {acf_full_unit_ids.shape}')

# Fit every unit's ACF with fit_single_exp and store the fit summary together
# with the ACF itself, keyed by unit id.
acf_full_dict = {}
for unit_id_idx, (unit_id, unit_acf) in enumerate(zip(acf_full_unit_ids, acf_full_2d)):
    # Progress message every 100 units.
    if unit_id_idx % 100 == 0:
        print(f'#####\nProcessing unit {unit_id}, {unit_id_idx+1}/{len(acf_full_unit_ids)}, {datetime.now()}')
    (fit_popt, fit_pcov, tau, tau_ci,
     fit_r_squared, explained_var, log_message) = fit_single_exp(unit_acf, start_idx_=1,
                                                                 exp_fun_=func_single_exp_monkey)
    # tau_ci is unpacked into its first (lower) and second (upper) entry.
    tau_lower, tau_upper = tau_ci[0], tau_ci[1]
    taus = {
        'tau': tau,
        'tau_lower': tau_lower,
        'tau_upper': tau_upper,
        'fit_r_squared': fit_r_squared,
        'explained_var': explained_var,
        'popt': fit_popt,
        'pcov': fit_pcov,
        'log_message': log_message,
    }
    acf_full_dict[unit_id] = {'taus': taus, 'acf': unit_acf}

acf_2d shape (5775, 21)
acf_full_unit_ids shape (5775,)
#####
Processing unit 950913540, 1/5775, 2025-04-07 18:32:02.979273
#####
Processing unit 950929874, 101/5775, 2025-04-07 18:32:04.465285
#####
Processing unit 950943620, 201/5775, 2025-04-07 18:32:07.000454
#####
Processing unit 950917982, 301/5775, 2025-04-07 18:32:08.477402
#####
Processing unit 950932317, 401/5775, 2025-04-07 18:32:10.679057
#####
Processing unit 950925550, 501/5775, 2025-04-07 18:32:13.230908
#####
Processing unit 950950121, 601/5775, 2025-04-07 18:32:15.063452
#####
Processing unit 950912481, 701/5775, 2025-04-07 18:32:16.476815
#####
Processing unit 950918988, 801/5775, 2025-04-07 18:32:18.146716
#####
Processing unit 950940942, 901/5775, 2025-04-07 18:32:20.519924
#####
Processing unit 950993240, 1001/5775, 2025-04-07 18:32:22.024851
#####
Processing unit 951010492, 1101/5775, 2025-04-07 18:32:23.527307
#####
Processing unit 950991745, 1201/5775, 2025-04-07 18:32:25.431305
#####
Processing unit 951016626, 

In [None]:
# iSTTC autocorrelation values as a 2D array (one row per unit, one column per
# entry of acf_cols) and the unit ids in the same row order.
isttc_full_2d = acf_isttc_full_df[acf_cols].values
print(f'isttc_full_2d shape {isttc_full_2d.shape}')
isttc_full_unit_ids = acf_isttc_full_df['unit_id'].values
print(f'isttc_full_unit_ids shape {isttc_full_unit_ids.shape}')

# Fit every unit's iSTTC curve with fit_single_exp and store the fit summary
# together with the curve itself, keyed by unit id.
isttc_full_dict = {}
for unit_id_idx, (unit_id, unit_isttc) in enumerate(zip(isttc_full_unit_ids, isttc_full_2d)):
    # Progress message every 1000 units.
    if unit_id_idx % 1000 == 0:
        print(f'#####\nProcessing unit {unit_id}, {unit_id_idx+1}/{len(isttc_full_unit_ids)}, {datetime.now()}')
    (fit_popt, fit_pcov, tau, tau_ci,
     fit_r_squared, explained_var, log_message) = fit_single_exp(unit_isttc, start_idx_=1,
                                                                 exp_fun_=func_single_exp_monkey)
    # tau_ci is unpacked into its first (lower) and second (upper) entry.
    tau_lower, tau_upper = tau_ci[0], tau_ci[1]
    taus = {
        'tau': tau,
        'tau_lower': tau_lower,
        'tau_upper': tau_upper,
        'fit_r_squared': fit_r_squared,
        'explained_var': explained_var,
        'popt': fit_popt,
        'pcov': fit_pcov,
        'log_message': log_message,
    }
    isttc_full_dict[unit_id] = {'taus': taus, 'acf': unit_isttc}

In [None]:
# Save the per-unit iSTTC fits so they can be reloaded without refitting.
isttc_full_dict_file = dataset_folder + 'non_binned\\acf\\acf_isttc_full_50ms_20lags_dict.pkl'
with open(isttc_full_dict_file, "wb") as pkl_out:
    pickle.dump(isttc_full_dict, pkl_out)

In [14]:
# Save the per-unit ACF fits so they can be reloaded without refitting.
acf_full_dict_file = dataset_folder + 'binned\\acf\\acf_full_50ms_20lags_dict.pkl'
with open(acf_full_dict_file, "wb") as pkl_out:
    pickle.dump(acf_full_dict, pkl_out)

#### Load calculated taus

In [17]:
def _read_pickle_file(file_path):
    """Return the object stored in the pickle file at file_path."""
    # NOTE(review): pickle.load executes arbitrary code from the file; these
    # are presumably the project's own result files - only load trusted files.
    with open(file_path, "rb") as pkl_in:
        return pickle.load(pkl_in)


# Previously computed fit results, one dict per estimation method.
isttc_full_dict = _read_pickle_file(dataset_folder + 'non_binned\\acf\\acf_isttc_full_50ms_20lags_dict.pkl')

acf_full_dict = _read_pickle_file(dataset_folder + 'binned\\acf\\acf_full_50ms_20lags_dict.pkl')

pearsonr_trial_avg_dict = _read_pickle_file(dataset_folder + 'binned\\acf\\pearsonr_trial_avg_50ms_20lags_dict_0_2000.pkl')

sttc_trial_concat_dict = _read_pickle_file(dataset_folder + 'non_binned\\acf\\sttc_trial_concat_50ms_20lags_dict100_2000.pkl')

#### Prep data for plots

In [18]:
def calculate_acf_decline_flag(acf_, start_idx=3, end_idx=5):
    """Tell whether the ACF never increases inside acf_[start_idx:end_idx].

    :param acf_: 1D sequence of autocorrelation values, indexed by lag.
    :param start_idx: first index of the inspected window (inclusive).
    :param end_idx: end of the inspected window (exclusive); with the defaults
        only the step from index 3 to index 4 is checked.
    :return: True when every consecutive difference in the window is <= 0
        (also True when the window holds fewer than two values).
    """
    lag_window = acf_[start_idx:end_idx]
    consecutive_steps = np.diff(lag_window)
    return np.all(consecutive_steps <= 0)

In [22]:
# Peek at the stored fit results of one example unit. In this dict the unit id
# key is a string, and 'taus' holds a list of fit-result dicts (see the cell
# output) rather than the single dict stored in acf_full_dict / isttc_full_dict.
pearsonr_trial_avg_dict['950913540']['taus']

[{'tau': np.float64(0.8490503974373463),
  'tau_lower': np.float64(-0.7344096897128075),
  'tau_upper': np.float64(2.4325104845875),
  'fit_r_squared': 0.3352365401634976,
  'explained_var': 0.3352365401634977,
  'popt': array([ 0.52402486,  0.8490504 , -0.01975556]),
  'pcov': array([[ 0.37301901, -0.43697194,  0.01790533],
         [-0.43697194,  0.5632812 , -0.02338785],
         [ 0.01790533, -0.02338785,  0.00158821]]),
  'log_message': 'ok'},
 {'tau': np.float64(0.0933013304370974),
  'tau_lower': np.float64(-196.6233487351656),
  'tau_upper': np.float64(196.8099513960398),
  'fit_r_squared': 0.5271176778979456,
  'explained_var': 0.5271176815025742,
  'popt': array([1.19222951e+04, 9.33013304e-02, 1.63888382e-06]),
  'pcov': array([[ 1.63065158e+16, -1.19063107e+10, -2.24151816e+06],
         [-1.19063107e+10,  8.69347183e+03,  1.63665936e+00],
         [-2.24151816e+06,  1.63665936e+00,  3.08122456e-04]]),
  'log_message': 'ok'},
 {'tau': np.float64(0.4096623685387619),
  'tau_

In [19]:
# One record per unit: fit summary plus the ACF-decline flag.
acf_full_records = [
    {
        'unit_id': unit_id,
        'tau': unit_data['taus']['tau'],
        'tau_lower': unit_data['taus']['tau_lower'],
        'tau_upper': unit_data['taus']['tau_upper'],
        'fit_r_squared': unit_data['taus']['fit_r_squared'],
        'decline_150_250': calculate_acf_decline_flag(unit_data['acf']),
    }
    for unit_id, unit_data in acf_full_dict.items()
]
# Tag the method, scale tau by 50 (presumably the 50 ms bin size named in the
# input files - confirm) and attach the brain structure of every unit.
acf_full_plot_df = (
    pd.DataFrame(acf_full_records)
    .assign(method='acf_full', tau_ms=lambda records_df: records_df['tau'] * 50)
    .merge(units_info_df_subset, on='unit_id', how='left')
)
acf_full_plot_df.head(10)

Unnamed: 0,unit_id,tau,tau_lower,tau_upper,fit_r_squared,decline_150_250,method,tau_ms,ecephys_structure_acronym
0,950913540,0.580319,0.336549,0.824088,0.9361658,True,acf_full,29.01593,VISam
1,950915005,0.034201,0.03420111,0.03420111,-5.534868e-10,True,acf_full,1.710055,VISam
2,950915018,0.037646,0.03764569,0.03764569,-2.423335e-10,False,acf_full,1.882284,VISam
3,950913798,0.02219,0.02218997,0.02218997,-2.805223e-10,False,acf_full,1.109498,VISam
4,950915049,0.025056,0.02505569,0.02505569,-3.396461e-11,False,acf_full,1.252785,VISam
5,950913944,6.892935,6.262411,7.523458,0.9968823,True,acf_full,344.6467,VISam
6,950913961,68518.000388,-488865200.0,489002300.0,0.7475119,False,acf_full,3425900.0,VISam
7,950913991,0.040627,-34352180.0,34352180.0,-3.705791e-11,True,acf_full,2.031341,VISam
8,950913984,8.470514,7.013852,9.927176,0.9918832,True,acf_full,423.5257,VISam
9,950915073,3.018096,1.386233,4.649959,0.8158901,False,acf_full,150.9048,VISam


In [20]:
# One record per unit: fit summary plus the ACF-decline flag.
isttc_full_records = [
    {
        'unit_id': unit_id,
        'tau': unit_data['taus']['tau'],
        'tau_lower': unit_data['taus']['tau_lower'],
        'tau_upper': unit_data['taus']['tau_upper'],
        'fit_r_squared': unit_data['taus']['fit_r_squared'],
        'decline_150_250': calculate_acf_decline_flag(unit_data['acf']),
    }
    for unit_id, unit_data in isttc_full_dict.items()
]
# Tag the method, scale tau by 50 (presumably the 50 ms bin size named in the
# input files - confirm) and attach the brain structure of every unit.
acf_isttc_full_plot_df = (
    pd.DataFrame(isttc_full_records)
    .assign(method='isttc_full', tau_ms=lambda records_df: records_df['tau'] * 50)
    .merge(units_info_df_subset, on='unit_id', how='left')
)
acf_isttc_full_plot_df.head(10)

Unnamed: 0,unit_id,tau,tau_lower,tau_upper,fit_r_squared,decline_150_250,method,tau_ms,ecephys_structure_acronym
0,950913540,0.480634,0.2827016,0.678566,0.9566893,True,isttc_full,24.03169,VISam
1,950915005,0.016986,0.01698587,0.01698587,-3.143068e-10,True,isttc_full,0.8492934,VISam
2,950915018,0.018591,0.01859093,0.01859093,-2.220446e-16,False,isttc_full,0.9295463,VISam
3,950913798,0.009842,0.009841583,0.009841583,-1.244016e-10,False,isttc_full,0.4920792,VISam
4,950915049,0.015463,0.01546313,0.01546313,-2.220446e-16,True,isttc_full,0.7731563,VISam
5,950913944,8.010364,7.412793,8.607934,0.9983261,True,isttc_full,400.5182,VISam
6,950913961,78189.066362,-436511500.0,436667900.0,0.7746628,False,isttc_full,3909453.0,VISam
7,950913991,0.037963,0.03796291,0.03796291,-1.767142e-11,True,isttc_full,1.898145,VISam
8,950913984,9.785418,8.158991,11.41184,0.9939657,True,isttc_full,489.2709,VISam
9,950915073,5.075232,-2.502886,12.65335,0.4539619,True,isttc_full,253.7616,VISam


In [None]:
# Stack both methods into one long table with a fresh 0..n-1 index and add
# the base-10 logarithm of tau_ms for the log-scale plots.
summary_df = pd.concat([acf_full_plot_df, acf_isttc_full_plot_df], ignore_index=True)
summary_df['tau_ms_log10'] = np.log10(summary_df['tau_ms'])
summary_df.head(3)

In [None]:
# Per method: number of units (count) and number of units whose decline flag
# is True (sum of the boolean flag).
decline_by_method = summary_df.groupby('method', as_index=False)['decline_150_250']
total_counts_df = decline_by_method.count().rename(columns={'decline_150_250': 'total_count'})
units_count_df = decline_by_method.sum().rename(columns={'decline_150_250': 'true_count'})

# Combine both counts and express the flagged units as a percentage.
units_acf_decline_df = units_count_df.merge(total_counts_df, on='method')
units_acf_decline_df['percentage'] = (
    units_acf_decline_df['true_count'] / units_acf_decline_df['total_count']
) * 100

units_acf_decline_df

In [None]:
# Per method and brain structure: number of units (count) and number of units
# whose decline flag is True (sum of the boolean flag).
area_group_keys = ['method', 'ecephys_structure_acronym']
decline_by_method_area = summary_df.groupby(by=area_group_keys, as_index=False)['decline_150_250']
total_counts_per_area_df = decline_by_method_area.count().rename(columns={'decline_150_250': 'total_count'})
units_count_per_area_df = decline_by_method_area.sum().rename(columns={'decline_150_250': 'true_count'})

# Combine both counts and express the flagged units as a percentage.
units_acf_decline_per_area_df = total_counts_per_area_df.merge(units_count_per_area_df, on=area_group_keys)
units_acf_decline_per_area_df['percentage'] = (
    units_acf_decline_per_area_df['true_count'] / units_acf_decline_per_area_df['total_count']
) * 100

units_acf_decline_per_area_df

#### Plots

##### Taus

In [None]:
# One color per estimation method.
color_acf_full, color_isttc_full = '#4783B4', '#E2552A'
color_pearson_trail_avg = 'slategray'
color_sttc_trail_avg = '#E97451'
color_sttc_trail_concat = '#B94E48'

# Palette used by the plots below: ACF first, iSTTC second.
colors = [color_acf_full, color_isttc_full]

# (structure acronym as stored in the data, short display name) pairs.
brain_area_labels = [
    ('LGd', 'LGN'),
    ('VISp', 'V1'),
    ('VISl', 'LM'),
    ('VISrl', 'RL'),
    ('LP', 'LP'),
    ('VISal', 'AL'),
    ('VISpm', 'PM'),
    ('VISam', 'AM'),
]
brain_areas_axes_ticks = [acronym for acronym, _ in brain_area_labels]
brain_areas_names = [display_name for _, display_name in brain_area_labels]

In [None]:
fig, axes = plt.subplots(1,2, figsize=(10, 3))
log_tau_ax, linear_tau_ax = axes

# Settings shared by both violin plots.
violin_settings = dict(x='method', hue='method', cut=0, density_norm='width', palette=colors)

# Left: log10(tau) of all units. Right: tau in ms, restricted to (10, 1000].
sns.violinplot(ax=log_tau_ax, y='tau_ms_log10', data=summary_df, legend=False, **violin_settings)
tau_range_df = summary_df.query('tau_ms <= 1000 and tau_ms > 10')
sns.violinplot(ax=linear_tau_ax, y='tau_ms', data=tau_range_df, **violin_settings)

sns.despine()

In [None]:
fig, axes = plt.subplots(2,4, figsize=(16, 6))
plt.subplots_adjust(hspace=0.4, wspace=0.4)

# One panel per brain area, filled row by row; tau is restricted to (10, 1000] ms.
for area_ax, area in zip(axes.flat, brain_areas_axes_ticks):
    area_tau_df = summary_df.query('tau_ms <= 1000 and tau_ms > 10 and ecephys_structure_acronym == @area')
    sns.violinplot(ax=area_ax, x='method', y='tau_ms', hue='method', data=area_tau_df,
                   cut=0, density_norm='width', palette=colors)
    area_ax.set_title(area)

sns.despine()

In [None]:
fig, axes = plt.subplots(2,4, figsize=(16, 6))
plt.subplots_adjust(hspace=0.4, wspace=0.4)

# One panel per brain area, filled row by row, showing log10(tau) of all units.
for area_ax, area in zip(axes.flat, brain_areas_axes_ticks):
    area_df = summary_df.query('ecephys_structure_acronym == @area')
    sns.violinplot(ax=area_ax, x='method', y='tau_ms_log10', hue='method', data=area_df,
                   cut=0, density_norm='width', palette=colors)
    # Reference line at tau = 10 ms on the log10 axis.
    area_ax.axhline(y=np.log10(10), lw=0.5, color='k')
    area_ax.set_title(area)

sns.despine()

##### Quality metrics

In [None]:
fig, axes = plt.subplots(1, 1, figsize=(5, 3))

# Display name for every method id that may show up in units_acf_decline_df.
method_display_names = {
    'acf_full': 'ACF full',
    'isttc_full': 'iSTTC full',
    'pearsonr_trial_avg': 'Pearson trial avg',
    'sttc_trial_avg': 'STTC trial avg',
    'sttc_trial_concat': 'STTC trial concat',
}
# Only the methods present in the table get a bar and a tick. Labelling a fixed
# list of five methods would add empty categories to the axis and would tie
# the labels to position instead of to the method actually plotted.
methods_in_plot = units_acf_decline_df['method'].tolist()

# hue='method' (with the legend off) is what makes `palette` apply per bar
# without the seaborn deprecation warning; dodge=False keeps every bar centred
# on its category position 0, 1, ...
sns.barplot(ax=axes, data=units_acf_decline_df, x='method', y='percentage', order=methods_in_plot,
            hue='method', hue_order=methods_in_plot, dodge=False, palette=colors, legend=False)

axes.set_ylabel('ACF decline 150-250ms (%)')
axes.set_xlabel('')
axes.set_ylim(0, 100)

# Fix the tick positions first, then label them, so labels and ticks always match.
axes.set_xticks(np.arange(len(methods_in_plot)))
axes.set_xticklabels([method_display_names.get(method, method) for method in methods_in_plot],
                     rotation=45, ha='right')

# Annotate from the table rows (not from axes.patches) so the text cannot get
# out of sync with the bar it describes.
bar_rows = zip(units_acf_decline_df['percentage'],
               units_acf_decline_df['true_count'],
               units_acf_decline_df['total_count'])
for bar_idx, (percentage, true_count, total_count) in enumerate(bar_rows):
    axes.annotate(f'{percentage:.1f}%\n({true_count}/{total_count})',
                  (bar_idx, percentage),
                  ha='center', va='bottom', fontsize=8, color='black')

sns.despine()

# if save_fig:
#     fig.savefig(fig_folder + 'criteria2_all_units.png', bbox_inches='tight', dpi=300)
#     fig.savefig(fig_folder + 'criteria2_all_units.svg', bbox_inches='tight')

In [None]:
fig, axes = plt.subplots(1, 1, figsize=(12, 3))

# Percentage of units with a declining ACF: one bar group per brain structure,
# one bar per method.
sns.barplot(ax=axes, data=units_acf_decline_per_area_df,
            x='ecephys_structure_acronym', y='percentage',
            hue='method', palette=colors)

axes.set(ylabel='ACF decline 150-250ms (%)', ylim=(0, 100))
# Legend above the axes, both methods on one row.
axes.legend(frameon=False, loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=2)

sns.despine()

# if save_fig:
#     fig.savefig(fig_folder + 'criteria2_all_units.png', bbox_inches='tight', dpi=300)
#     fig.savefig(fig_folder + 'criteria2_all_units.svg', bbox_inches='tight')

In [None]:
fig, axes = plt.subplots(1, 1, figsize=(5, 3))

# Display name for every method id that may show up in summary_df.
method_display_names = {
    'acf_full': 'ACF full',
    'isttc_full': 'iSTTC full',
    'pearsonr_trial_avg': 'Pearson trial avg',
    'sttc_trial_avg': 'STTC trial avg',
    'sttc_trial_concat': 'STTC trial concat',
}
# Only units whose ACF declines are shown; the methods are plotted in their
# order of appearance in the table.
decline_units_df = summary_df.query('decline_150_250 == True')
methods_in_plot = list(decline_units_df['method'].unique())

# hue='method' (with the automatic legend off) is what makes `palette` apply
# per violin without the seaborn deprecation warning.
sns.violinplot(ax=axes, x='method', y='fit_r_squared', data=decline_units_df, order=methods_in_plot,
               hue='method', hue_order=methods_in_plot, dodge=False,
               cut=0, density_norm='width', palette=colors, legend=False)
axes.set_ylabel('Coefficient of determination \n (R-squared)')
axes.set_xlabel('')
axes.set_ylim(0, 1)
# Label only the methods that are plotted: fix the tick positions first, then
# label them, so labels and ticks always match.
axes.set_xticks(np.arange(len(methods_in_plot)))
axes.set_xticklabels([method_display_names.get(method, method) for method in methods_in_plot],
                     rotation=45, ha='right')

# Median R-squared for each method; sort=False keeps the order of appearance,
# i.e. the plotting order, so every legend label lines up with its color.
median_r2 = decline_units_df.groupby('method', sort=False)['fit_r_squared'].median()

# Create legend labels
legend_labels = [f"{method}: {method_median:.2f}" for method, method_median in median_r2.items()]

# Add legend on top
handles = [plt.Line2D([0], [0], color=color, lw=4) for color in colors]
legend = axes.legend(handles, legend_labels, title="Median R-squared (ACF decline 150-250ms)", loc='upper center',
                     bbox_to_anchor=(0.5, 1.35), fontsize=8, title_fontsize=9, ncol=1, frameon=False)

sns.despine()

# if save_fig:
#     fig.savefig(fig_folder + 'rsquared_acf_decline_units.png', bbox_inches='tight', dpi=300)
#     fig.savefig(fig_folder + 'rsquared_acf_decline_units.svg', bbox_inches='tight')

In [None]:
# R-squared of both methods. The two arrays are compared element by element,
# so both methods must list the same units in the same order.
acf_fit_df = summary_df.query('method == "acf_full"')
isttc_fit_df = summary_df.query('method == "isttc_full"')
assert np.array_equal(acf_fit_df['unit_id'].values, isttc_fit_df['unit_id'].values), \
    'acf_full and isttc_full rows are not aligned by unit_id'
acf_r_squared = acf_fit_df['fit_r_squared'].values
isttc_r_squared = isttc_fit_df['fit_r_squared'].values

# The difference has to be computed before the summary numbers that use it
# (it used to be assigned only further down, which fails on a fresh kernel).
r_squared_diff = isttc_r_squared - acf_r_squared
n_sttc_better = int(np.sum(r_squared_diff > 0))
n_sttc_better_perc = n_sttc_better / len(r_squared_diff) * 100

fig, axes = plt.subplots(1,3, figsize=(12,3))
plt.subplots_adjust(hspace=0.4, wspace=0.4)

# Panel 1: 2D histogram of ACF vs iSTTC R-squared with the identity line.
sns.histplot(ax=axes[0], x=acf_r_squared, y=isttc_r_squared, bins=200)
axes[0].plot([0, 1], [0, 1], c='k', transform=axes[0].transAxes)
axes[0].set_aspect('equal', adjustable='box')
axes[0].set_xlabel('ACF R-squared')
axes[0].set_ylabel('iSTTC R-squared')
axes[0].set_title('binned')

# Panel 2: the same comparison as a scatter plot.
sns.scatterplot(ax=axes[1], x=acf_r_squared, y=isttc_r_squared, s=2)
axes[1].plot([0, 1], [0, 1], c='k', transform=axes[1].transAxes)
axes[1].set_aspect('equal', adjustable='box')
axes[1].set_xlabel('ACF R-squared')
axes[1].set_ylabel('iSTTC R-squared')
axes[1].set_title('scatter')

# Panel 3: distribution of the per-unit difference (iSTTC minus ACF).
sns.histplot(ax=axes[2], x=r_squared_diff, stat='probability', bins=20, kde=False, color='steelblue')
axes[2].axvline(x=0, lw=1, c='k')
axes[2].set_xlabel('STTC R-squared - \nACF R-squared')
axes[2].set_title('{}% STTC fits \nhave higher R-squared'.format(np.round(n_sttc_better_perc,2)))

fig.suptitle('sttc vs acf, allen, n_units = ' + str(len(r_squared_diff)), y=1.15)

sns.despine()

#fig.savefig(isttc_results_folder_path + 'allen_sttc_vs_pearson.png' , bbox_inches='tight')