In [None]:
# To be able to make edits to repo without having to restart notebook
%load_ext autoreload
%autoreload 2

In [None]:
import os, sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import pearsonr, mannwhitneyu, wilcoxon, ttest_rel, ttest_ind
import seaborn as sns
from matplotlib.patches import FancyArrowPatch
from matplotlib.colors import ColorConverter


# Put the notebook's current working directory on the import path so that
# modules located there can be imported by later cells.
# NOTE(review): this depends on where the kernel was started, not on where the
# notebook file lives - confirm that is the intended project root.
PROJECT_PATH = os.getcwd()
sys.path.append(PROJECT_PATH)

In [None]:
# Location of the merged LEC scores workbook.
# The default is the original author's local path; set the LEC_SCORES_XLSX
# environment variable to run the notebook on another machine without editing
# this cell.
scores_xlsx_path = os.environ.get(
    'LEC_SCORES_XLSX',
    r'C:\Users\aaoun\OneDrive - cumc.columbia.edu\Desktop\HussainiLab\neuroscikit\LEC_full_merged_scores.xlsx',
)
df = pd.read_excel(scores_xlsx_path)
# Show the available columns as the cell output.
df.columns

In [None]:
""" QUALITY CHECK DATA """ 
# A row is considered unusable when its 0-degree object quantile is missing.
is_nan_row = df['obj_q_0'].isna()

# Positional indices are kept under their original names for any later use.
nan_idx = np.where(is_nan_row)[0]
not_nan_idx = np.where(~is_nan_row)[0]

# Select with the boolean mask rather than df['date'][nan_idx]: the latter does
# a *label* lookup with *positional* indices and is only right for a default
# RangeIndex.
nan_dates = df.loc[is_nan_row, 'date'].unique()
nan_names = df.loc[is_nan_row, 'name'].unique()

print('Number of NaN rows: ' + str(len(nan_idx)))
print('Animals with NaN rows: ' + str(nan_names))
print('Dates with NaN rows: ' + str(nan_dates))

# remove rows with NaN values
print('Removing nan rows')
# .copy() so that the column assignments in the next cell act on an independent
# frame instead of a slice (avoids SettingWithCopyWarning).
df = df.loc[~is_nan_row].copy()

In [None]:
""" FILTERING """

""" REMOVE FIELD WITH LOW COVERAGE % """
# remove rows with field_coverage < 0.1
# df = df[df['field_coverage'] >= 0.1]

""" ONLY KEEPING MAIN FIELD """
# remove rows where field_id is not 1 and score is not 'whole' or 'spike_density'
# df = df[(df['field_id'] == 1) | (df['score'] == 'whole') | (df['score'] == 'spike_density')]

""" CHOOSING ANGLE FOR EACH ROW """
# for each row, choose lowest quantile from ['obj_q_0', 'obj_q_90', 'obj_q_180', 'obj_q_270']
angle_quantile_cols = ['obj_q_0', 'obj_q_90', 'obj_q_180', 'obj_q_270']
df['obj_q'] = df[angle_quantile_cols].min(axis=1)
df['obj_a'] = df[angle_quantile_cols].idxmin(axis=1)
# convert obj_a to degrees (e.g. 'obj_q_90' -> 90)
df['obj_a'] = df['obj_a'].apply(lambda x: int(x.split('_')[2]))
# use obj_wass with angle of min quantile
df['obj_w'] = df.apply(lambda x: x['obj_wass_' + str(x['obj_a'])], axis=1)

""" ASSESSING SIG FOR EACH ROW AT EACH ANGLE """
# obj_s_rows = ['obj_s_0', 'obj_s_90', 'obj_s_180', 'obj_s_270']
# obj_q_rows = ['obj_q_0', 'obj_q_90', 'obj_q_180', 'obj_q_270']
# for i in range(len(obj_s_rows)):
#     obj_q_x = obj_q_rows[i]
#     df[obj_s_rows[i]] = df[obj_q_x].apply(lambda x: 1 if x < quantile_threshold else 0)


def _report_threshold(frame, label, column, direction, threshold):
    """Print how many rows of ``frame`` a cut on ``column`` would drop, per group.

    Parameters
    ----------
    frame : pandas.DataFrame
        Must contain ``column`` and a ``'group'`` column.
    label : str
        Human-readable name of the cut, used at the start of the message.
    column : str
        Column the cut is applied to.
    direction : {'<', '>'}
        Comparison that flags a row for removal (``value < threshold`` or
        ``value > threshold``).
    threshold : float
        Cut value; the same value is used for the total, the per-group counts
        and the printed message so they can never disagree.

    Raises
    ------
    ValueError
        If ``direction`` is neither ``'<'`` nor ``'>'``.
    """
    if direction == '<':
        flagged = frame[frame[column] < threshold]
    elif direction == '>':
        flagged = frame[frame[column] > threshold]
    else:
        raise ValueError("direction must be '<' or '>', got {!r}".format(direction))
    # reindex guarantees an entry (0) for groups that have no flagged rows
    cts = flagged['group'].value_counts().reindex(['ANT', 'B6', 'NON'], fill_value=0)
    print('{} of {} would drop {} cells including {} ANT, {} B6 and {} NON'.format(
        label, threshold, str(len(flagged)), cts['ANT'], cts['B6'], cts['NON']))


# df2 = df[df['score'] == 'whole'].copy()
df2 = df.copy()
# group_by_cell = ['group', 'name', 'depth', 'date','tetrode', 'unit_id']
# df2 = df2.groupby(group_by_cell).mean().reset_index()

# Isolation-distance cut actually applied below; shared with the report so the
# printed numbers describe the filter that is really used.
iso_dist_min = 7.5

# Candidate quality cuts: (message label, column, comparison flagging a row, threshold).
# The reports run on the original 'spike_count' column, before it is replaced
# by 'spike_count.1' further down.
qc_cuts = [
    ('Spike count upper', 'spike_count', '>', 30000),
    ('Spike count lower', 'spike_count', '<', 100),
    ('Spatial info', 'information', '<', 0.25),
    ('Selectivity', 'selectivity', '<', 5),
    ('Isolation distance', 'iso_dist', '<', iso_dist_min),
    ('Firing rate', 'firing_rate', '>', 80),
    ('Spike width', 'spike_width', '<', 0.00005),
]
for cut_label, cut_column, cut_direction, cut_value in qc_cuts:
    _report_threshold(df2, cut_label, cut_column, cut_direction, cut_value)

# drop spike count column
df2 = df2.drop(columns=['spike_count'])
# rename spike_count.1 to spike_count
df2 = df2.rename(columns={'spike_count.1': 'spike_count'})

# Only the isolation-distance cut is active; the others are kept as toggles.
# df2 = df2[df2['spike_count'] < 30000]
# df2 = df2[df2['spike_count'] > 100]
# df2 = df2[df2['information'] > 0.25]
# df2 = df2[df2['selectivity'] > 5]
df2 = df2[df2['iso_dist'] > iso_dist_min]
# df2 = df2[df2['firing_rate'] < 80]
# df2 = df2[df2['spike_width'] > 0.00005]

remaining = {gp_name: len(df2[df2['group'] == gp_name]) for gp_name in ['ANT', 'B6', 'NON']}
print('Remaining cells: ' + str(len(df2)) + ' of which ' + str(remaining['ANT']) + ' ANT, '
      + str(remaining['B6']) + ' B6 and ' + str(remaining['NON']) + ' NON')


In [None]:
# Overwrite the min-over-angles quantile computed in the filtering cell with
# the 'obj_q_NO' column.
# NOTE(review): presumably 'NO' is the no-object condition - confirm.
df2['obj_q'] = df2.loc[:, 'obj_q_NO']
# Downstream cells read `df`; hand them an independent (deep) copy of the
# filtered frame.
df = df2.copy(deep=True)

In [None]:
# save df
# df = df[df['object_location'] != 'NO']
# The default is the original author's local path; set the LEC_FILTERED_XLSX
# environment variable to write somewhere else without editing this cell.
filtered_xlsx_path = os.environ.get(
    'LEC_FILTERED_XLSX',
    r'C:\Users\aaoun\OneDrive - cumc.columbia.edu\Desktop\HussainiLab\neuroscikit\filtered_df4.xlsx',
)
df.to_excel(filtered_xlsx_path)

In [None]:
# consecutive_sessions_threshold = 2
# quantile_threshold = 0.25
# consecutive = False
# score = 'field'
# main_field_only = False

# df['obj_q'] = df[['obj_q_0', 'obj_q_90', 'obj_q_180', 'obj_q_270']].min(axis=1)
# df['obj_a'] = df[['obj_q_0', 'obj_q_90', 'obj_q_180', 'obj_q_270']].idxmin(axis=1)
# # convert obj_a to degrees
# df['obj_a'] = df['obj_a'].apply(lambda x: int(x.split('_')[2]))
# # use obj_wass with angle of min quantile
# df['obj_w'] = df.apply(lambda x: x['obj_wass_' + str(x['obj_a'])], axis=1)

# # drop rows where spike count > 30 000
# # df2 = df[df['spike_count'] < 30000]
# # df2 = df2[df2['spike_count'] > 100]
# # drop rows where spatial info < 0.25 
# # df = df[df['information'] > 0.25]
# # df = df[df['selectivity'] > 5]
# df2 = df[df['iso_dist'] > 5]
# df2 = df2[df2['firing_rate'] < 80]
# # df2 = df2[df2['spike_width'] > 0.00005]
# # drop field count > 3


In [None]:
# One histogram panel per group: distribution of obj_q for the 'whole' score,
# coloured by animal name.
panel_groups = ['ANT', 'B6', 'NON']
fig = plt.figure(figsize=(16, 4))

for panel_no, group_name in enumerate(panel_groups, start=1):
    ax = fig.add_subplot(1, 3, panel_no)

    group_rows = df2[df2['group'] == group_name]
    whole_rows = group_rows[group_rows['score'] == 'whole']
    sns.histplot(data=whole_rows, x='obj_q', bins=50, hue='name', kde=False, ax=ax)
    ax.set_title(group_name)

fig.tight_layout()
plt.show()

In [None]:
# Correlation between obj_w and obj_q across all rows belonging to the same
# unit (group / animal / depth / date / tetrode / unit id), broadcast back to
# every row of that unit.
unit_keys = ['group', 'name', 'depth', 'date', 'tetrode', 'unit_id']
df['corr'] = np.nan
for _, unit_rows in df.groupby(unit_keys):
    df.loc[unit_rows.index, 'corr'] = unit_rows['obj_w'].corr(unit_rows['obj_q'])

# Every row is expected to have received a defined correlation.
assert df['corr'].isna().sum() == 0

fig = plt.figure(figsize=(16, 4))
corr_bins = np.arange(0, 1, 0.01)

for panel_no, group_name in enumerate(['ANT', 'B6', 'NON'], start=1):
    ax = fig.add_subplot(1, 3, panel_no)

    # 'spike_density' rows of this group only
    density_rows = df[(df['group'] == group_name) & (df['score'] == 'spike_density')]
    sns.histplot(data=density_rows, x='corr', bins=corr_bins, kde=False, ax=ax)

    # panel title: group name and mean correlation
    ax.set_title(group_name + ' - ' + str(density_rows['corr'].mean()))
    ax.set_xlim(0, 1)

fig.tight_layout()
plt.show()

In [None]:
# save
# df.to_excel(r'C:\Users\aaoun\OneDrive - cumc.columbia.edu\Desktop\HussainiLab\neuroscikit\filtered_df.xlsx')

In [None]:
# density plot of iso_dist for each group ('whole' score rows only)
fig, ax = plt.subplots(figsize=(12, 4))
dft = df[df['score'] == 'whole']

# Drawing order ANT, NON, B6 is kept so each group gets the same colour as before.
for group_name in ('ANT', 'NON', 'B6'):
    sns.kdeplot(dft.loc[dft['group'] == group_name, 'iso_dist'], label=group_name, ax=ax)

# Dashed line at the median over all groups together.
iso_median = np.median(dft['iso_dist'])
ax.axvline(iso_median, color='black', linestyle='--',
           label='median: ' + str(np.round(iso_median, 2)))

ax.set_xlabel('Isolation Distance')
ax.legend()
ax.set_ylabel('Density')
ax.set_title('Isolation Distance Distribution')
plt.show()




In [None]:
# Sanity check: (rows, columns) of the filtered frame used by the cells below.
df.shape

In [None]:
""" Amount of remapping per group """
# NOTE: this ttest_ind shadows the scipy.stats.ttest_ind imported at the top of
# the notebook; neither is called in this cell.
from statsmodels.stats.weightstats import ttest_ind
from statsmodels.stats import multitest
import statsmodels.api as sm
from statsmodels.robust.robust_linear_model import RLM


scores_to_use = ['whole', 'spike_density', 'field', 'binary']
titles_to_use = ['Whole-map', 'Spike Density', 'Field', 'Binary']
gps = ['ANT', 'B6', 'NON']
gp_colors = ['red', 'blue', 'green']

# metric = 'obj_w'
metric = 'obj_q'

# y-axis label and legend unit follow the selected metric, so toggling `metric`
# cannot leave a stale label on the figure.
metric_display = {
    'obj_w': ('Wasserstein distances (cm)', ' cm'),
    'obj_q': ('Wasserstein quantiles', ''),
}
y_label, unit_suffix = metric_display[metric]

# Pairwise Mann-Whitney U comparisons, one tuple per comparison:
# (index of first group in gps, index of second group in gps,
#  `alternative` passed to mannwhitneyu,
#  vertical offset of the bracket above the data maximum,
#  star x position, star y for whole/spike_density panels,
#  star y for field/binary panels)  -- star positions are in axes coordinates.
comparisons = [
    (0, 1, 'greater',   0.01, 0.315,           0.9,   0.85),
    (0, 2, 'two-sided', 0.02, 0.5,             0.925, 0.9),
    (1, 2, 'two-sided', 0.03, 0.5 + 0.315 / 2, 0.95,  0.95),
]
pval_names = ['antvsb6', 'antvsnon', 'b6vsnon']

fig = plt.figure(figsize=(20, 20))

for i, score in enumerate(scores_to_use):
    ax = fig.add_subplot(2, 2, i+1)

    # every row for that score; .copy() because a column is reassigned below
    to_plot_single = df[df['score'] == score]
    to_plot = df[df['score'] == score].copy()
    # Alternative aggregation levels (replace the line above):
    #   per animal : df[df['score'] == score].groupby(['group', 'name']).mean().reset_index()
    #   per session: df[df['score'] == score].groupby(['group', 'name', 'depth', 'date','stim','session_id']).mean().reset_index()
    #   per neuron : df[df['score'] == score].groupby(['group', 'name', 'depth', 'date','tetrode', 'unit_id']).mean().reset_index()

    # group means and standard errors
    by_group = to_plot.groupby('group')[metric]
    means = by_group.mean().round(2)
    stds = by_group.std()
    n = by_group.count()
    sems = (stds / np.sqrt(n)).round(2)

    # metric values of each group, reused by the boxplots and the tests
    samples = {gp: to_plot[to_plot['group'] == gp][metric] for gp in gps}

    # one coloured boxplot per group; the legend carries mean ± SEM and N
    bps = []
    lbls = []
    for k, gp in enumerate(gps):
        bp = ax.boxplot(samples[gp], positions=[k], widths=0.5,
                        notch=False, patch_artist=True,
                        boxprops=dict(facecolor=gp_colors[k], color='k'),
                        capprops=dict(color='k'),
                        whiskerprops=dict(color='k'),
                        flierprops=dict(color='k', markeredgecolor='k'),
                        medianprops=dict(color='k'),
                        showmeans=True,
                        meanprops=dict(markeredgecolor='k', markerfacecolor='k', markersize=10))
        bps.append(bp['boxes'][0])
        # index the summaries by group label, not by position
        lbls.append(str(means[gp]) + ' ± ' + str(sems[gp]) + unit_suffix + ', N = ' + str(n[gp]))

    ax.set_xticks(range(len(gps)))
    ax.set_xticklabels(gps)

    # highest observed value across the three groups; brackets sit just above it
    y_top = np.max([samples[gp].max() for gp in gps])

    pvals = []
    for a, b, alternative, offset, star_x, star_y_map, star_y_field in comparisons:
        u_stat, p = mannwhitneyu(samples[gps[a]], samples[gps[b]], alternative=alternative)
        pvals.append(p)
        # NOTE(review): the star is drawn from the uncorrected p-value; the
        # BH-corrected decisions are only printed below.
        if p <= 0.05:
            mx = y_top + offset
            ax.plot([a, a, b, b], [mx, mx+.05, mx+.05, mx], lw=1.5, c='k')
            star_y = star_y_field if score in ('field', 'binary') else star_y_map
            ax.text(star_x, star_y, '*', transform=ax.transAxes, fontsize=25, color='k')
        print(gps[a] + ' vs ' + gps[b] + ': U = ' + str(u_stat) + ', p = ' + str(p))
    antvsb6, antvsnon, b6vsnon = pvals

    # BH correction
    reject, pvals_corrected, alphacSidak, alphacBonf = multitest.multipletests(pvals, alpha=0.05, method='fdr_bh')
    print(pval_names)
    print('Corrected p-values: ' + str(pvals_corrected))
    print('Reject H0 (FDR 0.05): ' + str(reject))

    ax.legend(bps, lbls, loc='upper right')
    ax.set_xlabel('Group')
    ax.set_title(titles_to_use[i])
    ax.set_ylabel(y_label)

    group_order = ['B6', 'ANT', 'NON']  # 'B6' becomes the reference group
    to_plot['group'] = pd.Categorical(to_plot['group'], categories=group_order, ordered=True)

    # MixedLM is used here only to build the design matrices from the formula;
    # the model that is actually fitted is a robust linear model (Huber's T).
    model = sm.MixedLM.from_formula(metric + ' ~ group', data=to_plot, groups=to_plot['name'])
    robust_model = RLM(model.endog, model.exog, M=sm.robust.norms.HuberT())
    result = robust_model.fit()

    print(score)
    print(result.summary())

# Rows are used as-is (no averaging), so the figure is titled accordingly.
fig.suptitle('All indiv. cell-session appearances')
fig.tight_layout()
plt.show()

In [None]:
# https://stackoverflow.com/questions/40044375/how-to-calculate-the-kolmogorov-smirnov-statistic-between-two-weighted-samples

import numpy as np

def ks_w2(data1, data2, wei1, wei2):
    """Weighted two-sample Kolmogorov-Smirnov statistic.

    Builds the weighted empirical CDF of each sample and returns the largest
    absolute difference between the two CDFs, evaluated at every observed
    value of either sample.

    Parameters
    ----------
    data1, data2 : array-like, 1-D
        The two samples. Lists, NumPy arrays and pandas Series are accepted;
        inputs are converted to NumPy arrays so that sorting is positional.
    wei1, wei2 : array-like, 1-D
        Weights for ``data1`` and ``data2``; must match their sample in length.

    Returns
    -------
    float
        Maximum absolute difference between the weighted empirical CDFs.

    Raises
    ------
    ValueError
        If a weight array does not have the same shape as its sample.
    """
    data1, data2 = np.asarray(data1), np.asarray(data2)
    wei1, wei2 = np.asarray(wei1), np.asarray(wei2)
    if data1.shape != wei1.shape or data2.shape != wei2.shape:
        raise ValueError('each weight array must have the same shape as its sample')

    # sort each sample and carry its weights along
    ix1 = np.argsort(data1)
    ix2 = np.argsort(data2)
    data1, wei1 = data1[ix1], wei1[ix1]
    data2, wei2 = data2[ix2], wei2[ix2]

    # evaluation points: every value from either sample
    data = np.concatenate([data1, data2])

    # cumulative normalised weights with a leading 0, so index k is the CDF
    # value after k sorted observations
    cwei1 = np.hstack([0, np.cumsum(wei1) / np.sum(wei1)])
    cwei2 = np.hstack([0, np.cumsum(wei2) / np.sum(wei2)])

    # side='right' counts observations <= each evaluation point
    cdf1we = cwei1[np.searchsorted(data1, data, side='right')]
    cdf2we = cwei2[np.searchsorted(data2, data, side='right')]
    return np.max(np.abs(cdf1we - cdf2we))

from scipy.stats import distributions

def ks_weighted(data1, data2, wei1, wei2, alternative='two-sided'):
    """Weighted two-sample Kolmogorov-Smirnov test.

    The statistic is the largest absolute difference between the weighted
    empirical CDFs of the two samples. The p-value is computed from the
    *unweighted* sample sizes ``n1`` and ``n2`` through the effective size
    ``en = n1 * n2 / (n1 + n2)``.
    NOTE(review): the weights therefore affect the statistic but not the sample
    size entering the p-value - confirm this is the intended calibration.

    Parameters
    ----------
    data1, data2 : array-like, 1-D
        The two samples. Lists, NumPy arrays and pandas Series are accepted;
        inputs are converted to NumPy arrays so that sorting is positional.
    wei1, wei2 : array-like, 1-D
        Weights for ``data1`` and ``data2``; must match their sample in length.
    alternative : str, default 'two-sided'
        ``'two-sided'`` uses the ``kstwo`` survival function with
        ``round(en)``; any other value uses the one-sided exponential
        approximation below.

    Returns
    -------
    d : float
        The weighted KS statistic.
    prob : float
        The p-value.

    Raises
    ------
    ValueError
        If a weight array does not have the same shape as its sample.
    """
    data1, data2 = np.asarray(data1), np.asarray(data2)
    wei1, wei2 = np.asarray(wei1), np.asarray(wei2)
    if data1.shape != wei1.shape or data2.shape != wei2.shape:
        raise ValueError('each weight array must have the same shape as its sample')

    # sort each sample and carry its weights along
    ix1 = np.argsort(data1)
    ix2 = np.argsort(data2)
    data1, wei1 = data1[ix1], wei1[ix1]
    data2, wei2 = data2[ix2], wei2[ix2]

    # evaluation points: every value from either sample
    data = np.concatenate([data1, data2])

    # cumulative normalised weights with a leading 0, so index k is the CDF
    # value after k sorted observations
    cwei1 = np.hstack([0, np.cumsum(wei1) / np.sum(wei1)])
    cwei2 = np.hstack([0, np.cumsum(wei2) / np.sum(wei2)])
    cdf1we = cwei1[np.searchsorted(data1, data, side='right')]
    cdf2we = cwei2[np.searchsorted(data2, data, side='right')]
    d = np.max(np.abs(cdf1we - cdf2we))

    # calculate p-value
    n1 = data1.shape[0]
    n2 = data2.shape[0]
    # m is the larger sample size, n the smaller
    m, n = sorted([float(n1), float(n2)], reverse=True)
    en = m * n / (m + n)
    if alternative == 'two-sided':
        prob = distributions.kstwo.sf(d, np.round(en))
    else:
        z = np.sqrt(en) * d
        # Use Hodges' suggested approximation Eqn 5.3
        # Requires m to be the larger of (n1, n2)
        expt = -2 * z**2 - 2 * z * (m + 2*n)/np.sqrt(m*n*(m+n))/3.0
        prob = np.exp(expt)
    return d, prob

In [None]:
""" Amount of remapping per group """
# NOTE: this ttest_ind shadows the scipy.stats.ttest_ind imported at the top of
# the notebook; neither is called in this cell.
from statsmodels.stats.weightstats import ttest_ind
from statsmodels.stats import multitest
import statsmodels.api as sm
from statsmodels.robust.robust_linear_model import RLM
from scipy.stats import mannwhitneyu
from scipy import stats

np.random.seed(0)

scores_to_use = ['whole', 'spike_density', 'field', 'binary']
titles_to_use = ['Whole-map', 'Spike Density', 'Field', 'Binary']
gps = ['ANT', 'B6', 'NON']
gp_colors = ['red', 'blue', 'green']

metric = 'obj_w'
# metric = 'obj_q'

# y-axis label and legend unit follow the selected metric, so toggling `metric`
# cannot leave a stale label on the figure.
metric_display = {
    'obj_w': ('Wasserstein distances (cm)', ' cm'),
    'obj_q': ('Wasserstein quantiles', ''),
}
y_label, unit_suffix = metric_display[metric]

# Pairwise one-sided Mann-Whitney U comparisons, one tuple per comparison:
# (index of first group in gps, index of second group in gps,
#  vertical offset of the bracket above the data maximum,
#  star x position, star y for whole/spike_density panels,
#  star y for field/binary panels)  -- star positions are in axes coordinates.
comparisons = [
    (0, 1, 0.01, 0.315,           0.9,   0.85),
    (0, 2, 0.02, 0.5,             0.925, 0.9),
    (1, 2, 0.03, 0.5 + 0.315 / 2, 0.95,  0.95),
]
pval_names = ['antvsb6', 'antvsnon', 'b6vsnon']

# NOTE: earlier drafts of this cell looked up per-group weights in a
# `to_plot_count` frame (row counts per aggregation unit) for weighted
# t-test / rank-sum / KS variants. None of those tests is run here and
# `to_plot_count` is never created in this cell, so the lookups raised
# NameError on a fresh kernel and have been removed. To bring the weighted
# variants back, first build `to_plot_count` with the same groupby keys as the
# aggregation in use, followed by .count().reset_index().

fig = plt.figure(figsize=(20, 20))

for i, score in enumerate(scores_to_use):
    ax = fig.add_subplot(2, 2, i+1)

    # every row for that score; .copy() because columns are assigned below
    to_plot_single = df[df['score'] == score]
    to_plot = df[df['score'] == score].copy()
    # Alternative aggregation levels (replace the line above):
    #   per animal : df[df['score'] == score].groupby(['group', 'name']).mean().reset_index()
    #   per session: df[df['score'] == score].groupby(['group', 'name', 'depth', 'date','stim','session_id']).mean().reset_index()
    #   per neuron : df[df['score'] == score].groupby(['group', 'name', 'depth', 'date','tetrode', 'unit_id']).mean().reset_index()

    # group means and standard errors
    by_group = to_plot.groupby('group')[metric]
    means = by_group.mean().round(2)
    stds = by_group.std()
    n = by_group.count()
    sems = (stds / np.sqrt(n)).round(2)

    # metric values of each group, reused by the boxplots and the tests
    samples = {gp: to_plot[to_plot['group'] == gp][metric] for gp in gps}

    # one coloured boxplot per group; the legend carries mean ± SEM and N
    bps = []
    lbls = []
    for k, gp in enumerate(gps):
        bp = ax.boxplot(samples[gp], positions=[k], widths=0.5,
                        notch=False, patch_artist=True,
                        boxprops=dict(facecolor=gp_colors[k], color='k'),
                        capprops=dict(color='k'),
                        whiskerprops=dict(color='k'),
                        flierprops=dict(color='k', markeredgecolor='k'),
                        medianprops=dict(color='k'),
                        showmeans=True,
                        meanprops=dict(markeredgecolor='k', markerfacecolor='k', markersize=10))
        bps.append(bp['boxes'][0])
        # index the summaries by group label, not by position
        lbls.append(str(means[gp]) + ' ± ' + str(sems[gp]) + unit_suffix + ', N = ' + str(n[gp]))

    ax.set_xticks(range(len(gps)))
    ax.set_xticklabels(gps)

    # highest observed value across the three groups; brackets sit just above it
    y_top = np.max([samples[gp].max() for gp in gps])

    pvals = []
    for a, b, offset, star_x, star_y_map, star_y_field in comparisons:
        g1 = samples[gps[a]]
        g2 = samples[gps[b]]
        u_stat, p = mannwhitneyu(g1, g2, alternative='greater')
        pvals.append(p)
        # NOTE(review): the star is drawn from the uncorrected p-value; the
        # BH-corrected decisions are only printed below.
        if p <= 0.05:
            mx = y_top + offset
            ax.plot([a, a, b, b], [mx, mx+.05, mx+.05, mx], lw=1.5, c='k')
            star_y = star_y_field if score in ('field', 'binary') else star_y_map
            ax.text(star_x, star_y, '*', transform=ax.transAxes, fontsize=25, color='k')
        print(gps[a] + ' vs ' + gps[b] + ': U = ' + str(u_stat) + ', p = ' + str(p))
    antvsb6, antvsnon, b6vsnon = pvals

    # BH correction
    reject, pvals_corrected, alphacSidak, alphacBonf = multitest.multipletests(pvals, alpha=0.05, method='fdr_bh')
    print(pval_names)
    print('Corrected p-values: ' + str(pvals_corrected))
    print('Reject H0 (FDR 0.05): ' + str(reject))

    ax.legend(bps, lbls, loc='upper right')
    ax.set_xlabel('Group')
    ax.set_title(titles_to_use[i])
    ax.set_ylabel(y_label)

    group_order = ['B6', 'ANT', 'NON']  # 'B6' becomes the reference group
    to_plot['group'] = pd.Categorical(to_plot['group'], categories=group_order, ordered=True)

    # Mixed model with one random-effect group per (session, animal) pair.
    to_plot_sub = to_plot
    to_plot_sub['combined_group'] = to_plot_sub['session_id'] + '_' + to_plot_sub['name']
    # formula = metric + ' ~ C(session_id)'
    # NOTE(review): 'C(group):C(group)' is the interaction of group with itself;
    # confirm this is intended rather than plain 'C(group)'.
    formula = metric + ' ~ C(group):C(group)'
    model = sm.MixedLM.from_formula(formula, data=to_plot_sub, groups=to_plot_sub['combined_group'])

    result = model.fit()

    # robust_model = RLM(model.endog, model.exog, M=sm.robust.norms.HuberT())
    # result = robust_model.fit()

    print(score)
    print(result.summary())

fig.suptitle('All indiv. cell-session appearances')
fig.tight_layout()
plt.show()

In [None]:
""" Amount of remapping per group """
from statsmodels.stats.weightstats import ttest_ind
from statsmodels.stats import multitest
import statsmodels.api as sm
from statsmodels.robust.robust_linear_model import RLM


# Per-score comparison of remapping (`metric`) between the ANT, B6 and NON groups,
# using one value per recording session (mean over that session's rows).
scores_to_use = ['whole', 'spike_density', 'field', 'binary']
titles_to_use = ['Whole-map', 'Spike Density', 'Field', 'Binary']
gps = ['ANT', 'B6', 'NON']
gp_colors = ['red', 'blue', 'green']

# Columns that together identify one session; rows are averaged/counted over them.
session_keys = ['group', 'name', 'depth', 'date', 'stim', 'session_id']

# One entry per pairwise comparison:
# (index of 1st group in gps, index of 2nd group, printed label,
#  bracket height offset, '*' position in axes coords, '*' position for 'field'/'binary')
pair_specs = [
    (0, 1, 'ANT vs B6', 0.01, (0.315, 0.9), (0.315, 0.85)),
    (0, 2, 'ANT vs NON', 0.02, (0.5, 0.925), (0.5, 0.9)),
    (1, 2, 'B6 vs NON', 0.03, (0.5 + 0.315 / 2, 0.95), (0.5 + 0.315 / 2, 0.95)),
]
pval_names = ['antvsb6', 'antvsnon', 'b6vsnon']

fig = plt.figure(figsize=(20, 20))

metric = 'obj_w'
# metric = 'obj_q'

for i, score in enumerate(scores_to_use):
    ax = fig.add_subplot(2, 2, i+1)

    # every row for that score
    to_plot_single = df[df['score'] == score]
    # Scores averaged for each session. Built here so the cell no longer depends on a
    # `to_plot` / `to_plot_count` left behind by a previously executed cell.
    # numeric_only=True skips non-numeric columns (e.g. 'score'), which newer pandas
    # refuses to average.
    to_plot = to_plot_single.groupby(session_keys).mean(numeric_only=True).reset_index()
    # Number of rows behind each session average; used as t-test weights below.
    to_plot_count = to_plot_single.groupby(session_keys).count().reset_index()

    # get group means + SEM
    by_group = to_plot.groupby('group')[metric]
    means = by_group.mean().round(2)
    stds = by_group.std()
    n = by_group.count()
    sems = (stds / np.sqrt(n)).round(2)

    # Per-group values and per-session row counts, looked up by group label.
    group_values = {g: to_plot.loc[to_plot['group'] == g, metric] for g in gps}
    group_counts = {g: to_plot_count.loc[to_plot_count['group'] == g, metric] for g in gps}

    # plot boxplot for each group
    bps = []
    lbls = []
    for k, g in enumerate(gps):
        bp = ax.boxplot(group_values[g], positions=[k], widths=0.5,
                    notch=False, patch_artist=True,
                    boxprops=dict(facecolor=gp_colors[k], color='k'),
                    capprops=dict(color='k'),
                    whiskerprops=dict(color='k'),
                    flierprops=dict(color='k', markeredgecolor='k'),
                    medianprops=dict(color='k'),
                    showmeans=True,
                    meanprops=dict(markeredgecolor='k', markerfacecolor='k', markersize=10))

        bps.append(bp['boxes'][0])
        # Look statistics up by group label, not by position, so the legend cannot be
        # mismatched if the groupby order differs from `gps`.
        lbls.append(str(means[g]) + ' ± ' + str(sems[g]) + ' cm, N = ' + str(n[g]))

    ax.set_xticklabels(gps)

    # Highest data point over all three groups; significance brackets sit above it.
    overall_max = np.max([np.max(group_values[g]) for g in gps])

    # Weighted Welch t-tests (one-sided, 'larger') for each pair of groups.
    pvals = []
    for first, second, label, offset, star_xy, star_xy_alt in pair_specs:
        g1, g2 = gps[first], gps[second]
        # Weights: per-session row counts, normalised to mean 1 within each group.
        w1 = group_counts[g1] / np.mean(group_counts[g1])
        w2 = group_counts[g2] / np.mean(group_counts[g2])
        res = ttest_ind(group_values[g1], group_values[g2],
                        usevar='unequal', weights=(w1, w2), alternative='larger')
        t = res[0]
        p = res[1]
        pvals.append(p)
        # NOTE: the bracket/star is drawn from the uncorrected p-value.
        if p <= 0.05:
            mx = overall_max + offset
            ax.plot([first, first, second, second], [mx, mx+.05, mx+.05, mx], lw=1.5, c='k')
            if score != 'field' and score != 'binary':
                star_x, star_y = star_xy
            else:
                star_x, star_y = star_xy_alt
            ax.text(star_x, star_y, '*', transform=ax.transAxes, fontsize=25, color='k')
        print(label + ': t = ' + str(t) + ', p = ' + str(p))
    antvsb6, antvsnon, b6vsnon = pvals

    # BH correction
    reject, pvals_corrected, alphacSidak, alphacBonf = multitest.multipletests(pvals, alpha=0.05, method='fdr_bh')
    print(pval_names)
    print('Corrected p-values: ' + str(pvals_corrected))
    print('Rejected at FDR 0.05: ' + str(reject))

    ax.legend(bps, lbls, loc='upper right')
    ax.set_xlabel('Group')
    ax.set_title(titles_to_use[i])
    ax.set_ylabel('Wasserstein distances (cm)')

    group_order = ['B6', 'ANT', 'NON']  # 'B6' becomes the reference group
    to_plot['group'] = pd.Categorical(to_plot['group'], categories=group_order, ordered=True)

    # The MixedLM object is only used to build the design matrices (endog/exog) from
    # the formula; the model actually fitted is the robust linear model below.
    model = sm.MixedLM.from_formula(metric + ' ~ group', data=to_plot, groups=to_plot['name'])
    # result = model.fit()

    robust_model = RLM(model.endog, model.exog, M=sm.robust.norms.HuberT())
    result = robust_model.fit()

    print(score)
    print(result.summary())

fig.suptitle('Averaged by Session')
fig.tight_layout()
plt.show()

In [None]:
# Inspect the most recently fitted model stored in `result`.
# `result` may be a mixed-model fit or a robust-linear-model fit, depending on which
# cell ran last. Attributes that only some result classes provide are therefore read
# defensively instead of aborting the cell.
def _optional_result_attr(res, attr):
    """Return `res.<attr>`, or None when this result type does not provide it."""
    try:
        return getattr(res, attr)
    except (AttributeError, NotImplementedError):
        return None


result_kind = type(result).__name__

# Access fixed effects coefficients
fixed_effects = result.params
print("Fixed Effects Coefficients:\n", fixed_effects)

# Access random effects variances (only present on mixed-model results)
random_effects_variances = _optional_result_attr(result, 'cov_re')
if random_effects_variances is not None:
    print("Random Effects Variances:\n", random_effects_variances)
else:
    print("Random Effects Variances: not available for " + result_kind)

# Predicted values
predicted_values = result.fittedvalues
print("Predicted Values:\n", predicted_values)

# Second table of the summary: the per-coefficient estimates (not a likelihood ratio test)
print("Coefficient table:\n")
print(result.summary().tables[1])

# Akaike Information Criterion (AIC) and Bayesian Information Criterion (BIC)
for label, attr in (("AIC:", 'aic'), ("BIC:", 'bic')):
    value = _optional_result_attr(result, attr)
    print(label, value if value is not None else "not available for " + result_kind)

In [None]:
""" Amount of remapping per group """
from statsmodels.stats.weightstats import ttest_ind

# Score types to analyse (values of df['score']) and the matching subplot titles.
scores_to_use = ['whole', 'spike_density', 'field', 'binary']
titles_to_use = ['Whole-map', 'Spike Density', 'Field', 'Binary']
# NOTE: in this cell `gps` holds the sorted unique session ids (x-axis), not group names.
gps = np.unique(df['session_id'])
# Group labels (values of the 'group' column) and their box colours, index-aligned.
gp_labels = ['ANT', 'B6', 'NON']
gp_colors = ['r', 'b', 'g']
# Seed NumPy's global RNG so the label shuffles below are repeatable.
np.random.seed(0)
def _single_shuffle(to_plot_shuffle, sesgp, metric, gplbl):
    # to_plot_shuffle['group'] = np.random.permutation(to_plot_shuffle['group'].loc[:].values)
    vals = to_plot_shuffle.loc[to_plot_shuffle['session_id'] == sesgp, 'group'].to_numpy()
    # shuffle the vals
    np.random.shuffle(vals)
    to_plot_shuffle.loc[to_plot_shuffle['session_id'] == sesgp, 'group'] = vals

    use = to_plot_shuffle[to_plot_shuffle['session_id'] == sesgp]
    use = use[use['group'] == gplbl][metric]
    mn = np.mean(use)
    # print(metric, sesgp, mn, gplbl)
    return mn

# Per-score trend of the session-averaged metric across sessions for each group,
# compared against a null distribution obtained by shuffling group labels within
# each session.

# Imported once here instead of on every pass of the score loop.
import pymannkendall as mk  # Mann-Kendall test is currently commented out; import kept so `mk` stays defined
from statsmodels.stats import multitest

fig = plt.figure(figsize=(25, 20))

metric = 'obj_w'
# metric = 'obj_q'

# Columns that together identify one session; rows are averaged/counted over them.
session_keys = ['group', 'name', 'depth', 'date', 'stim', 'session_id']
# Number of label shuffles per (session, group).
shuffle_count = 1000

for i, score in enumerate(scores_to_use):
    ax = fig.add_subplot(2, 2, i+1)

    # scores averaged for each session
    # numeric_only=True skips non-numeric columns (e.g. 'score'), which newer pandas
    # refuses to average.
    score_rows = df[df['score'] == score]
    to_plot = score_rows.groupby(session_keys).mean(numeric_only=True).reset_index()
    to_plot_count = score_rows.groupby(session_keys).count().reset_index()
    # Working copy whose 'group' column is permuted in place by _single_shuffle.
    to_plot_shuffle = to_plot.copy()

    # plot boxplot for each group within each session
    bps = []
    # mns[j]: observed per-session means for group j (sessions without data are skipped)
    mns = [[] for j in range(3)]
    # mns_shuffle[j][sc]: per-session means for group j in shuffle replicate sc
    mns_shuffle = [[[] for sc in range(shuffle_count)] for j in range(3)]
    for k in range(len(gps)):
        for j in range(3):
            to_plot_now = to_plot[to_plot['group'] == gp_labels[j]]
            session_vals = to_plot_now[to_plot_now['session_id'] == gps[k]][metric]
            bp = ax.boxplot(session_vals, positions=[k*3+j*.5], widths=0.5,
                        notch=False, patch_artist=True,
                        boxprops=dict(facecolor=gp_colors[j], color='k'),
                        capprops=dict(color='k'),
                        whiskerprops=dict(color='k'),
                        flierprops=dict(color='k', markeredgecolor='k'),
                        medianprops=dict(color='k'),
                        showmeans=False,
                        meanprops=dict(markeredgecolor='k', markerfacecolor='k', markersize=10))
            # One legend handle per group, taken from the first session.
            if k == 0:
                bps.append(bp['boxes'][0])

            mn = np.mean(session_vals)
            if mn == mn:  # skip NaN (group has no data in this session)
                mns[j].append(mn)

        for j in range(3):
            for sc in range(shuffle_count):
                out = _single_shuffle(to_plot_shuffle, gps[k], metric, gp_labels[j])
                if out == out:  # skip NaN
                    mns_shuffle[j][sc].append(out)

    ax.set_xticks(np.arange(len(gps)) * 3 + 1.25/2)
    ax.set_xticklabels(gps)

    # Slope of a straight-line fit through the per-session means, per group,
    # compared with the slopes obtained from the shuffled labels.
    lbls = []
    empirical = []
    ps = []
    # The nested lists are deliberately NOT converted with np.array: groups missing
    # from some sessions give lists of different lengths, which NumPy rejects.
    print((len(mns_shuffle), shuffle_count, [len(per_group[0]) for per_group in mns_shuffle]))
    print('Metric: ' + score)
    for j in range(3):
        # polyfit
        memp, c = np.polyfit(np.arange(len(mns[j])), mns[j], 1)
        empirical.append(memp)

        shuffled = []
        for sc in range(shuffle_count):
            ses_dist = mns_shuffle[j][sc]
            mshuffled, c = np.polyfit(np.arange(len(ses_dist)), ses_dist, 1)
            shuffled.append(mshuffled)

        # two sided p-value: twice the smaller tail fraction of the shuffled slopes
        pgreater = np.sum(np.array(shuffled) < empirical[j]) / len(shuffled)
        plower = np.sum(np.array(shuffled) > empirical[j]) / len(shuffled)
        pvalue = np.min([pgreater, plower]) * 2
        ps.append(pvalue)
        print('Group: ' + gp_labels[j])
        print('Empirical: ' + str(empirical[j]))
        print('Shuffled: ' + str(np.mean(shuffled)) + ' ± ' + str(np.std(shuffled)))

        if empirical[j] > np.mean(shuffled):
            tag = 'greater'
        elif empirical[j] < np.mean(shuffled):
            tag = 'lower'
        else:
            # Equal (or NaN) slopes: previously `tag` was undefined or left over from
            # the preceding group.
            tag = 'no different'
        lbl = 'Slope is ' + str(tag) + ' than shuffled: ' + str(np.round(memp, 2)) + ' , p = ' + str(np.round(pvalue, 3))
        lbls.append(lbl)

    print(ps)
    # Benjamini-Hochberg correction over the three per-group p-values.
    out = multitest.multipletests(ps, alpha=0.05, method='fdr_bh', is_sorted=False, returnsorted=False)
    for cc, case in enumerate(out[0]):
        if case:
            lbls[cc] = lbls[cc] + ' & is sig after BH correction'
        else:
            lbls[cc] = lbls[cc] + ' & is NOT sig after BH correction'
    print(out)

    ax.legend(bps, lbls, loc='upper right')
    ax.set_xlabel('Session')
    ax.set_title(titles_to_use[i])
    ax.set_ylabel('Wasserstein distances (cm)')

# The plotted values are per-session averages (see `to_plot` above).
fig.suptitle('Averaged by Session')
fig.tight_layout()
plt.show()

In [None]:
""" Amount of remapping per group """
from statsmodels.stats.weightstats import ttest_ind

# Panels to draw; the loop below reads the 'whole' rows for the 'firing_rate' entry.
scores_to_use = ['whole', 'spike_density', 'field', 'binary', 'centroid', 'firing_rate']
# [row, col] position in the 3x2 outer grid for each entry of scores_to_use.
quad_arrange = [[0,0],[0,1],[1,0],[1,1], [2,0], [2,1]]
titles_to_use = ['Whole-map', 'Spike Density', 'Field', 'Binary', 'Centroid', 'Firing Rate']
# NOTE: in this cell `gps` holds the sorted unique session ids, not group names.
gps = np.unique(df['session_id'])
# Group labels (values of the 'group' column) and their box colours, index-aligned.
gp_labels = ['ANT', 'B6', 'NON']
gp_colors = ['r', 'b', 'g']
# Seed NumPy's global RNG so the shuffles/resampling below are repeatable.
np.random.seed(0)
def _single_shuffle(to_plot_shuffle, sesgp, metric, gplbl):
    # to_plot_shuffle['group'] = np.random.permutation(to_plot_shuffle['group'].loc[:].values)
    vals = to_plot_shuffle.loc[to_plot_shuffle['session_id'] == sesgp, 'group'].to_numpy()
    # shuffle the vals
    np.random.shuffle(vals)
    to_plot_shuffle.loc[to_plot_shuffle['session_id'] == sesgp, 'group'] = vals

    use = to_plot_shuffle[to_plot_shuffle['session_id'] == sesgp]
    use = use[use['group'] == gplbl][metric]
    mn = np.mean(use)
    # print(metric, sesgp, mn, gplbl)
    return mn

# Tall figure holding a 3x2 outer grid (one outer cell per entry of scores_to_use).
fig = plt.figure(figsize=(23, 46))
gs_main = gridspec.GridSpec(3, 2, width_ratios=[1,1], height_ratios=[1,1, 1])  # Adjust width_ratios and height_ratios as needed


# Column of `df` analysed in this cell; swap the commented line to use 'obj_w' instead.
# metric = 'obj_w'
metric = 'obj_q'

for i, score in enumerate(scores_to_use):
    # scores averaged for each session
    if score != 'firing_rate':
        to_plot = df[df['score'] == score]
        # .groupby(['group', 'name', 'depth', 'date','stim','session_id']).mean().reset_index()
        # to_plot_count = df[df['score'] == score].groupby(['group', 'name', 'depth', 'date','stim','session_id']).count().reset_index()
        # to_plot_shuffle = to_plot.copy()
    else:
        to_plot = df[df['score'] == 'whole']
        # .groupby(['group', 'name', 'depth', 'date','stim','session_id']).mean().reset_index()
        # to_plot_count = df[df['score'] == 'whole'].groupby(['group', 'name', 'depth', 'date','stim','session_id']).count().reset_index()
        # to_plot_shuffle = to_plot.copy()


    row, col = quad_arrange[i]
    gs_sub = gridspec.GridSpecFromSubplotSpec(5, 1, subplot_spec=gs_main[row, col], height_ratios=[12,12,1,1,1], hspace=0.1)

    ax = plt.subplot(gs_sub[0])
    axf = ax
    bps = []
    lbls = []
    means = []
    sems = []
    n = []
    comps = {}
    comp_maxs = []
    for k in range(3):
        c = gp_colors[k]
        if score != 'firing_rate':
            mtouse = metric 
        else:
            mtouse = 'firing_rate'
        bp = ax.boxplot(to_plot[to_plot['group'] == gp_labels[k]][mtouse], positions=[k], widths=0.5, 
                    notch=False, patch_artist=True,
                    boxprops=dict(facecolor=c, color='k'),
                    capprops=dict(color='k'),
                    whiskerprops=dict(color='k'),
                    flierprops=dict(color='k', markeredgecolor='k'),
                    medianprops=dict(color='k'),
                    showmeans=True, 
                    meanprops=dict(markeredgecolor='k', markerfacecolor='k', markersize=10))
        comp_maxs.append(np.max(to_plot[to_plot['group'] == gp_labels[k]][mtouse]))
        for k2 in range(3):
            if k2 != k:
                lbl_pair = np.sort([gp_labels[k], gp_labels[k2]])
                comp_group = lbl_pair[0] + '_' + lbl_pair[1]
                if comp_group not in comps.keys():
                    comps[comp_group] = np.nan

                    if 'B6' in comp_group and 'ANT' in comp_group:
                        if score != 'firing_rate':
                            to_plot_model = df[df['score'] == score]
                            # .groupby(['group', 'name', 'depth', 'date','stim','session_id']).mean().reset_index()
                        else:
                            to_plot_model = df[df['score'] == 'whole']
                            # .groupby(['group', 'name', 'depth', 'date','stim','session_id']).mean().reset_index()
                        model_data = to_plot_model[to_plot_model['group'].isin(['B6', 'ANT'])]
                        group_order = ['B6', 'ANT']  # 'B6' becomes the reference group
                    elif 'B6' in comp_group and 'NON' in comp_group:
                        if score != 'firing_rate':
                            to_plot_model = df[df['score'] == score]
                            # .groupby(['group', 'name', 'depth', 'date','stim','session_id']).mean().reset_index()
                        else:
                            to_plot_model = df[df['score'] == 'whole']
                            # .groupby(['group', 'name', 'depth', 'date','stim','session_id']).mean().reset_index()
                        model_data = to_plot_model[to_plot_model['group'].isin(['B6', 'NON'])]
                        group_order = ['B6', 'NON']
                    elif 'ANT' in comp_group and 'NON' in comp_group:
                        if score != 'firing_rate':
                            to_plot_model = df[df['score'] == score]
                            # .groupby(['group', 'name', 'depth', 'date','stim','session_id']).mean().reset_index()
                        else:
                            to_plot_model = df[df['score'] == 'whole']
                            # .groupby(['group', 'name', 'depth', 'date','stim','session_id']).mean().reset_index()
                        model_data = to_plot_model[to_plot_model['group'].isin(['NON', 'ANT'])]
                        group_order = ['NON', 'ANT']
                    # model_data['group'] = pd.Categorical(model_data['group'], categories=group_order, ordered=True)
                    model_data.loc[:, 'group'] = pd.Categorical(model_data.loc[:, 'group'], categories=group_order, ordered=True)

                    if score != 'firing_rate':
                        # model = sm.MixedLM.from_formula(metric + ' ~ group', data=model_data, groups=model_data['name'])

                        # https://stats.stackexchange.com/questions/31300/dealing-with-0-1-values-in-a-beta-regression
                        # 𝑥′=𝑥(𝑁−1)+𝑠𝑁
                        # where N is the sample size and s is a constant between 0 and 1. 
                        # From a Bayesian standpoint, s acts as if we are taking a prior into account. 
                        # A reasonable choice for s would be .5.

                        model_data['obj_q'] = (model_data['obj_q'] * (len(model_data['obj_q']) - 1) + .5) / len(model_data['obj_q'])
                        model = BetaModel.from_formula(metric + ' ~ group', data=model_data, groups=model_data['name'])

                    else:
                        # model = sm.MixedLM.from_formula('firing_rate ~ group', data=model_data, groups=model_data['name'])

                        model_data['obj_q'] = (model_data['obj_q'] * (len(model_data['obj_q']) - 1) + .5) / len(model_data['obj_q'])
                        model = BetaModel.from_formula('firing_rate ~ group', data=model_data, groups=model_data['name'])
                    result = model.fit()
                    # result = model.fit()
                    # robust_model = RLM(model.endog, model.exog, M=sm.robust.norms.HuberT())
                    # result = robust_model.fit()
                    emp_coeff = result.params[1]

                                        # Create a dictionary to map animal names to groups
                    itms = model_data.groupby(['name','group']).groups.keys()
                    unique_animal_group_pairs_dict = {itm[0]: itm[1] for itm in itms}


                    bootstrap_coeffs = []
                    prev_sample = None
                    for b in range(1000):
                        resampled_animal_group_pairs = np.random.choice(list(unique_animal_group_pairs_dict.keys()), size=len(unique_animal_group_pairs_dict), replace=True)
                        sample = None
                        for pr in resampled_animal_group_pairs:
                            if sample is None:
                                sample = model_data[model_data['name'] == pr]
                                sample = sample.sample(frac=1, replace=True)
                            else:
                                sp = model_data[model_data['name'] == pr]
                                sp = sp.sample(frac=1, replace=True)
                                sample = pd.concat([sample, sp])
                        # check sample and prev_sample are NOT the same
                        if prev_sample is not None:
                            assert not prev_sample.equals(sample), 'Sample and previous sample are the same'

                        # model = sm.MixedLM.from_formula(metric + ' ~ group', data=sample, groups=sample['name'])
                        # robust_model = RLM(model.endog, model.exog, M=sm.robust.norms.HuberT())
                        if score != 'firing_rate':
                            model = BetaModel.from_formula(metric + '~ group', data=sample, groups=sample['name'])
                        else:
                            model = BetaModel.from_formula('firing_rate ~ group', data=sample, groups=sample['name'])
                        result = model.fit()
                        coeff = result.params[1]
                        bootstrap_coeffs.append(coeff)
                    p = np.sum(np.abs(bootstrap_coeffs) > np.abs(emp_coeff)) / len(bootstrap_coeffs)
                    comps[comp_group] = p

        means.append(np.mean(to_plot[to_plot['group'] == gp_labels[k]][metric]))
        sems.append(np.std(to_plot[to_plot['group'] == gp_labels[k]][metric]) / np.sqrt(len(to_plot[to_plot['group'] == gp_labels[k]][metric])))
        n.append(len(to_plot[to_plot['group'] == gp_labels[k]][metric]))
                                    
        bps.append(bp['boxes'][0])
        # lbls.append(str(means[k]) + ' ± ' + str(sems[k]) + ' cm, N = ' + str(n[k]))
        lbls.append(gp_labels[k])
        #  + ': N = ' + str(n[k]))
    
    ax.set_xticklabels(gp_labels)
    ax.legend(bps, lbls, loc='upper right')
    ax.set_xlabel('Group')
    ax.set_ylabel('EMD quantile')

    # benjamini hochberg correction
    kys, vals = zip(*comps.items())
    accepted, pvals_corrected, _, _ = multipletests(vals, alpha=0.05, method='fdr_bh', is_sorted=False, returnsorted=False)
    
    p_count = 0
    for comp_key, val in comps.items():
        comparison = comps[comp_key]
        if 'ANT' in comp_key and 'B6' in comp_key:
            nme = [0,1]
        elif 'ANT' in comp_key and 'NON' in comp_key:
            nme = [0,2]
        elif 'B6' in comp_key and 'NON' in comp_key:
            nme = [1,2]

        # if accepted[k]:
        barplot_annotate_brackets(nme[0],nme[1],pvals_corrected[p_count],[0,1,2], comp_maxs, maxasterix=5)
        
        p_count += 1

    # ax = fig.add_subplot(2, 2, i+1)
    ax = plt.subplot(gs_sub[1])
    ax1 = ax

    # # every row for that score
    if score != 'firing_rate':
        to_plot = df[df['score'] == score]
        # to_plot = df[df['score'] == score].groupby(['group', 'name', 'depth', 'date','stim','session_id']).mean().reset_index()
    else:
        to_plot = df[df['score'] == 'whole']
        # to_plot = df[df['score'] == 'whole'].groupby(['group', 'name', 'depth', 'date','stim','session_id']).mean().reset_index()
        metric = 'firing_rate'
        # to_plot = to_plot[to_plot['depth']]
    to_plot_shuffle = to_plot.copy()



    bps = []
    lbls = []
    shuffle_count = 1000
    mns = [[] for j in range(3)]
    mns_shuffle = [[[] for sc in range(shuffle_count)] for j in range(3)]
    
    session_comps = {}
    group_ses_frs = {'ANT': [], 'B6': [], 'NON': []}
    for k in range(len(gps)):
        # c = gp_colors[k]
        ses_visited = []
        # ses_frs = []
        for j in range(3):
            # get group means + CI
            to_plot_now = to_plot[to_plot['group'] == gp_labels[j]]
            ses_fr = to_plot_now[to_plot_now['session_id'] == gps[k]][metric]
            # ses_fr = to_plot_now[to_plot_now['session_id'] == gps[k]]['information']
            bp = ax.boxplot(to_plot_now[to_plot_now['session_id'] == gps[k]][metric], positions=[k*3+j*.5], widths=0.5, 
                        notch=False, patch_artist=True,
                        boxprops=dict(facecolor=gp_colors[j],color='k'),
                        capprops=dict(color='k'),
                        whiskerprops=dict(color='k'),
                        flierprops=dict(color='k', markeredgecolor='k'),
                        medianprops=dict(color='k'),
                        showmeans=False, 
                        meanprops=dict(markeredgecolor='k', markerfacecolor='k', markersize=10))
            if k == 0:
                bps.append(bp['boxes'][0])   

            mn = np.mean(to_plot_now[to_plot_now['session_id'] == gps[k]][metric])
            if mn == mn:
                mns[j].append(mn)

            for j2 in range(3):
                to_plot_now2 = to_plot[to_plot['group'] == gp_labels[j2]]
                if j != j2:
                    # welch's t-test
                    if len(to_plot_now[to_plot_now['session_id'] == gps[k]][metric]) > 1 and len(to_plot_now2[to_plot_now2['session_id'] == gps[k]][metric]) > 1:
                        # res = ttest_ind(to_plot_now[to_plot_now['session_id'] == gps[k]][metric],
                        #                     to_plot_now2[to_plot_now2['session_id'] == gps[k]][metric],
                        #                     usevar='unequal', alternative='two-sided')
                        # pvalue = res[1]

                        # mann whitney u test
                        _, pvalue = mannwhitneyu(to_plot_now[to_plot_now['session_id'] == gps[k]][metric],
                                           to_plot_now2[to_plot_now2['session_id'] == gps[k]][metric],
                                          alternative='two-sided')
                         
                    
                        sorted_labels = np.sort([gp_labels[j], gp_labels[j2]])
                        pair_id = sorted_labels[0] + '_' + sorted_labels[1]
                        if pair_id not in session_comps:
                            session_comps[pair_id] = []
                        if pair_id not in ses_visited:
                            session_comps[pair_id].append(pvalue)
                            # session_comps[pair_id].append(pvalue)
                            ses_visited.append(pair_id)
            if len(ses_fr) > 0 and np.mean(ses_fr) == np.mean(ses_fr):       
                group_ses_frs[gp_labels[j]].append(np.mean(ses_fr))

        for j in range(3):
            for sc in range(shuffle_count):
                out = _single_shuffle(to_plot_shuffle, gps[k], metric, gp_labels[j])
                if out == out:
                    mns_shuffle[j][sc].append(out)

    from mpl_toolkits.axes_grid1 import make_axes_locatable
    from scipy import interpolate
    from scipy.interpolate import griddata


    # benjamini-hochberg correction
    pvalsANT_B6 = np.array(session_comps['ANT_B6'])
    pvalsB6_NON = np.array(session_comps['B6_NON'])
    pvalsANT_NON = np.array(session_comps['ANT_NON'])
    pvals = np.concatenate((pvalsANT_B6, pvalsB6_NON, pvalsANT_NON))
    reject, pvals_corrected, alphacSidak, alphacBonf = multitest.multipletests(pvals, alpha=0.05, method='fdr_bh', is_sorted=False, returnsorted=False)
    bh_ant_b6 = np.ones(len(pvalsANT_B6))
    bh_ant_b6[pvals_corrected[:len(pvalsANT_B6)] <= 0.05] = 0
    bh_b6_non = np.ones(len(pvalsB6_NON))
    bh_b6_non[pvals_corrected[len(pvalsANT_B6):len(pvalsANT_B6)+len(pvalsB6_NON)] <= 0.05] = 0
    bh_ant_non = np.ones(len(pvalsANT_NON))
    bh_ant_non[pvals_corrected[len(pvalsANT_B6)+len(pvalsB6_NON):] <= 0.05] = 0

    # plt.setp(ax1.get_xticklabels(), visible=False)

    ax = plt.subplot(gs_sub[2])
    # bh_ant_b6 = np.hstack((bh_ant_b6, [.5]))
    # ax.imshow(np.expand_dims(bh_ant_b6, 0), cmap='Greys_r', aspect='auto', extent=[0, len(gps), 0, 1], vmin=0, vmax=1)
    ant_fr = group_ses_frs['ANT']
    ant_interp = np.linspace(0, len(ant_fr)-1, 1000)
    ant_fr_smooth = np.interp(ant_interp, np.arange(len(ant_fr)), ant_fr)
    im = ax.imshow(np.expand_dims(ant_fr_smooth, 0), cmap='jet', aspect='auto', extent=[0, len(gps), 0, 1])
    # colorbar
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax)
    ylbl = ax.set_ylabel('ANT', labelpad=15, rotation=0)
    # pos = ylbl.get_position()
    # ylbl.set_position((pos[0], pos[1] -.5))
    ax.set_yticks([])
    # ax.set_ylabel('ANT-B6', labelpad=15, rotation=45)
    ax.set_xticks(np.arange(len(gps)) + 0.5)
    ax.set_xticklabels(gps)
    plt.setp(ax.get_xticklabels(), visible=False)
    ax.spines['bottom'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.tick_params(axis='x', which='both', bottom=False, top=False)


    ax = plt.subplot(gs_sub[3])
    # bh_b6_non = np.hstack((bh_b6_non, [.5]))
    # ax.imshow(np.expand_dims(bh_b6_non, 0), cmap='Greys_r', aspect='auto', extent=[0, len(gps), 0, 1], vmin=0, vmax=1)
    b6_fr = group_ses_frs['B6']
    b6_interp = np.linspace(0, len(b6_fr)-1, 1000)
    b6_fr_smooth = np.interp(b6_interp, np.arange(len(b6_fr)), b6_fr)
    im = ax.imshow(np.expand_dims(b6_fr_smooth, 0), cmap='jet', aspect='auto', extent=[0, len(gps), 0, 1])
    # colorbar
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax)
    ylbl = ax.set_ylabel('B6', labelpad=15, rotation=0)
    # pos = ylbl.get_position()
    # ylbl.set_position((pos[0], pos[1]-.5))
    ax.set_yticks([])
    # ax.set_ylabel('B6-NON', labelpad=15, rotation=45)
    ax.set_xticks(np.arange(len(gps)) + 0.5)
    ax.set_xticklabels(gps)
    plt.setp(ax.get_xticklabels(), visible=False)
    ax.spines['bottom'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.tick_params(axis='x', which='both', bottom=False, top=False)


    ax = plt.subplot(gs_sub[4])
    # # bh_ant_non = np.hstack((bh_ant_non, [.5]))
    # ax.imshow(np.expand_dims(bh_ant_non, 0), cmap='Greys_r', aspect='auto', extent=[0, len(gps), 0, 1], vmin=0, vmax=1)
    non_fr = group_ses_frs['NON']
    non_interp = np.linspace(0, len(non_fr)-1, 1000)
    non_fr_smooth = np.interp(non_interp, np.arange(len(non_fr)), non_fr)
    im = ax.imshow(np.expand_dims(non_fr_smooth, 0), cmap='jet', aspect='auto', extent=[0, len(gps), 0, 1])
    # colorbar
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax)
    ylbl = ax.set_ylabel('NON', labelpad=15, rotation=0)
    # pos = ylbl.get_position()
    # ylbl.set_position((pos[0], pos[1] - .5))
    plt.setp(ax.get_xticklabels(), visible=False)

    ax.set_yticks([])
    # ylbl = ax.set_ylabel('ANT-NON', labelpad=15, rotation=45)
    ax.set_xticks(np.arange(len(gps)) + 0.5)
    ax.set_xticklabels(gps)
    ax.spines['bottom'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.tick_params(axis='x', which='both', bottom=False, top=False)


    # mann kendall
    import pymannkendall as mk
    from statsmodels.stats import multitest
    lbls = []
    lbl_colors = []
    empirical = []
    ps = []
    mns_shuffle = np.array(mns_shuffle)
    print(mns_shuffle.shape)
    print('Metric: ' + score)
    for j in range(3):
        # polyfit 

        memp, c = np.polyfit(np.arange(len(mns[j])), mns[j], 1)
        empirical.append(memp)

        shuffled = []
        for sc in range(shuffle_count):
            # if len(mns_shuffle[j,sc]) > 0:
            ses_dist = mns_shuffle[j,sc]
            mshuffled, c = np.polyfit(np.arange(len(ses_dist)), ses_dist, 1)
            shuffled.append(mshuffled)
                # result = mk.original_test(ses_dist)
                # ps.append(result.p)


        # # two sided p-value 
        pgreater = np.sum(np.array(shuffled) < empirical[j]) / len(shuffled) 
        plower = np.sum(np.array(shuffled) > empirical[j]) / len(shuffled)
        pvalue = np.min([pgreater, plower]) * 2
        ps.append(pvalue)
        print('Group: ' + gp_labels[j])
        print('Empirical: ' + str(empirical[j]))
        print('Shuffled: ' + str(np.mean(shuffled)) + ' ± ' + str(np.std(shuffled)))
        # print('p-value: ' + str(np.min([pgreater, plower]) * 2))


        # if empirical[j] > np.mean(shuffled):
        #     tag = 'greater'
        # elif empirical[j] < np.mean(shuffled):
        #     tag = 'lower'
        # lbl = 'Slope is ' + str(tag) + ' than shuffled: ' + str(np.round(memp, 2)) + ' , p = ' + str(np.round(pvalue, 3)) 
        # lbls.append(lbl)

        # lbl = gp_labels[j] + ' 
        lbl = 'slope: ' + str(np.round(memp, 2)) 
        lbl_colors.append('k')
        lbls.append(lbl)

    # result = mk.original_test(ses_dist)
    print(ps)
    out = multitest.multipletests(ps, alpha=0.05, method='fdr_bh', is_sorted=False, returnsorted=False)
    cc = 0
    for case in out[0]:
        if case:
            lbl_colors[cc] = 'r'
            # lbls[cc] = lbls[cc] + ' & is sig after BH correction'
        # else:
            # lbls[cc] = lbls[cc] + ' & is NOT sig after BH correction'
        cc += 1
    print(out)

    ax1.legend(bps, lbls, loc='upper right')
    # color label text in legend
    for text, color in zip(ax1.legend_.get_texts(), lbl_colors):
        text.set_color(color)

    # ax1.set_title(score)
    # ax.set_xlabel('Session')
    axf.set_title(titles_to_use[i])
    ax1.set_ylabel('EMD quantile')
              
    # ax1.set_xticks(np.arange(len(gps)) * 3 + 1.25/2)
    # ax1.set_xticklabels(gps)
    # ax1.set_xlim([-1.25/2, len(gps) * 3 - 1.25/2])

    ax1.set_xticks(np.arange(len(gps)) * 3 + .5)
    ax1.set_xticklabels(gps)
    ax1.set_ylim(-0.05,)
    # ax1.set_xlim([-.5/2, len(gps) * 3 - .5/2])


fig.suptitle('All indiv. cell-session appearances', fontweight='bold')
# fig.suptitle('Averaged by session', fontweight='bold')
gs_main.tight_layout(fig, rect=[0, 0, 1, 0.98])
# fig.suptitle('Averaged by session')
fig.tight_layout()
plt.show()

In [None]:
# Sanity-check plot: distribution of the bootstrap coefficients currently held in the
# kernel (set by the analysis loop), with the empirical coefficient marked in red.
plt.hist(bootstrap_coeffs, bins=100)
# axvline spans the whole axes height, so the marker no longer forces the y-axis up to a
# hard-coded count of 100 (nor gets cut off when a bin is taller than that).
plt.axvline(emp_coeff, color='r')
plt.show()

In [None]:
""" Amount of remapping per group """
from statsmodels.stats.weightstats import ttest_ind

scores_to_use = ['whole', 'spike_density', 'field', 'binary', 'centroid', 'firing_rate']
quad_arrange = [[0,0],[0,1],[1,0],[1,1], [2,0], [2,1]]
titles_to_use = ['Whole-map', 'Spike Density', 'Field', 'Binary', 'Centroid', 'Firing Rate']
gps = np.unique(df['session_id'])
gp_labels = ['ANT', 'B6', 'NON']
gp_colors = ['r', 'b', 'g']
np.random.seed(0)
def _single_shuffle(to_plot_shuffle, sesgp, metric, gplbl):
    # to_plot_shuffle['group'] = np.random.permutation(to_plot_shuffle['group'].loc[:].values)
    vals = to_plot_shuffle.loc[to_plot_shuffle['session_id'] == sesgp, 'group'].to_numpy()
    # shuffle the vals
    np.random.shuffle(vals)
    to_plot_shuffle.loc[to_plot_shuffle['session_id'] == sesgp, 'group'] = vals

    use = to_plot_shuffle[to_plot_shuffle['session_id'] == sesgp]
    use = use[use['group'] == gplbl][metric]
    mn = np.mean(use)
    # print(metric, sesgp, mn, gplbl)
    return mn

fig = plt.figure(figsize=(23, 46))
gs_main = gridspec.GridSpec(3, 2, width_ratios=[1,1], height_ratios=[1,1, 1])  # Adjust width_ratios and height_ratios as needed
np.random.seed(0)

metric = 'obj_w'
# metric = 'obj_q'

for i, score in enumerate(scores_to_use):
    # scores averaged for each session
    if score != 'firing_rate':
        to_plot = df[df['score'] == score].groupby(['group', 'name', 'depth', 'date','stim','session_id']).mean().reset_index()
        # to_plot_count = df[df['score'] == score].groupby(['group', 'name', 'depth', 'date','stim','session_id']).count().reset_index()
        # to_plot_shuffle = to_plot.copy()
    else:
        to_plot = df[df['score'] == 'whole'].groupby(['group', 'name', 'depth', 'date','stim','session_id']).mean().reset_index()
        # to_plot_count = df[df['score'] == 'whole'].groupby(['group', 'name', 'depth', 'date','stim','session_id']).count().reset_index()
        # to_plot_shuffle = to_plot.copy()


    row, col = quad_arrange[i]
    gs_sub = gridspec.GridSpecFromSubplotSpec(5, 1, subplot_spec=gs_main[row, col], height_ratios=[12,12,1,1,1], hspace=0.1)

    ax = plt.subplot(gs_sub[0])
    axf = ax
    bps = []
    lbls = []
    means = []
    sems = []
    n = []
    comps = {}
    comp_maxs = []
    for k in range(3):
        c = gp_colors[k]
        if score != 'firing_rate':
            mtouse = metric 
        else:
            mtouse = 'firing_rate'
        bp = ax.boxplot(to_plot[to_plot['group'] == gp_labels[k]][mtouse], positions=[k], widths=0.5, 
                    notch=False, patch_artist=True,
                    boxprops=dict(facecolor=c, color='k'),
                    capprops=dict(color='k'),
                    whiskerprops=dict(color='k'),
                    flierprops=dict(color='k', markeredgecolor='k'),
                    medianprops=dict(color='k'),
                    showmeans=True, 
                    meanprops=dict(markeredgecolor='k', markerfacecolor='k', markersize=10))
        comp_maxs.append(np.max(to_plot[to_plot['group'] == gp_labels[k]][mtouse]))
        for k2 in range(3):
            if k2 != k:
                lbl_pair = np.sort([gp_labels[k], gp_labels[k2]])
                comp_group = lbl_pair[0] + '_' + lbl_pair[1]
                if comp_group not in comps.keys():
                    comps[comp_group] = np.nan

                    if 'B6' in comp_group and 'ANT' in comp_group:
                        if score != 'firing_rate':
                            to_plot_model = df[df['score'] == score].groupby(['group', 'name', 'depth', 'date','stim','session_id']).mean().reset_index()
                        else:
                            to_plot_model = df[df['score'] == 'whole'].groupby(['group', 'name', 'depth', 'date','stim','session_id']).mean().reset_index()
                        model_data = to_plot_model[to_plot_model['group'].isin(['B6', 'ANT'])]
                        group_order = ['B6', 'ANT']  # 'B6' becomes the reference group
                    elif 'B6' in comp_group and 'NON' in comp_group:
                        if score != 'firing_rate':
                            to_plot_model = df[df['score'] == score].groupby(['group', 'name', 'depth', 'date','stim','session_id']).mean().reset_index()
                        else:
                            to_plot_model = df[df['score'] == 'whole'].groupby(['group', 'name', 'depth', 'date','stim','session_id']).mean().reset_index()
                        model_data = to_plot_model[to_plot_model['group'].isin(['B6', 'NON'])]
                        group_order = ['B6', 'NON']
                    elif 'ANT' in comp_group and 'NON' in comp_group:
                        if score != 'firing_rate':
                            to_plot_model = df[df['score'] == score].groupby(['group', 'name', 'depth', 'date','stim','session_id']).mean().reset_index()
                        else:
                            to_plot_model = df[df['score'] == 'whole'].groupby(['group', 'name', 'depth', 'date','stim','session_id']).mean().reset_index()
                        model_data = to_plot_model[to_plot_model['group'].isin(['NON', 'ANT'])]
                        group_order = ['NON', 'ANT']
                    # Turn `group` into a categorical whose category order is group_order, so
                    # that the first entry is meant to act as the reference level of the model.
                    # .assign() builds a new frame: the former `.loc[:, 'group'] = ...` wrote
                    # into a filtered slice (SettingWithCopyWarning) and, on pandas >= 2.0, is
                    # an in-place write that keeps the object dtype and loses the category order.
                    model_data = model_data.assign(
                        group=pd.Categorical(model_data['group'], categories=group_order, ordered=True))

                    # Response column: the EMD metric, or the firing rate for the 'firing_rate' panel.
                    response = metric if score != 'firing_rate' else 'firing_rate'
                    formula = response + ' ~ group'

                    # MixedLM.from_formula is only used to turn the formula into endog/exog
                    # arrays (the mixed model itself is never fitted); the estimate comes from
                    # a Huber-weighted robust linear model on those arrays.
                    model = sm.MixedLM.from_formula(formula, data=model_data, groups=model_data['name'])
                    robust_model = RLM(model.endog, model.exog, M=sm.robust.norms.HuberT())
                    result = robust_model.fit()
                    # params[0] is the intercept, params[1] the group contrast
                    emp_coeff = result.params[1]

                    # Map each animal name to its group. observed=True keeps only the
                    # (name, group) combinations that actually occur in the data.
                    itms = model_data.groupby(['name', 'group'], observed=True).groups.keys()
                    unique_animal_group_pairs_dict = {itm[0]: itm[1] for itm in itms}
                    animal_names = list(unique_animal_group_pairs_dict.keys())
                    # An animal's rows never change between resamples, so select them once.
                    rows_by_animal = {nm: model_data[model_data['name'] == nm] for nm in animal_names}

                    # Two-level (cluster) bootstrap: draw animals with replacement, then draw
                    # each drawn animal's rows with replacement, and refit the robust model.
                    bootstrap_coeffs = []
                    prev_sample = None
                    for b in range(1000):
                        resampled_animal_group_pairs = np.random.choice(animal_names, size=len(animal_names), replace=True)
                        # build all per-animal resamples first and concatenate once
                        sample = pd.concat([rows_by_animal[pr].sample(frac=1, replace=True)
                                            for pr in resampled_animal_group_pairs])
                        # check sample and prev_sample are NOT the same
                        if prev_sample is not None:
                            assert not prev_sample.equals(sample), 'Sample and previous sample are the same'
                        model = sm.MixedLM.from_formula(formula, data=sample, groups=sample['name'])
                        robust_model = RLM(model.endog, model.exog, M=sm.robust.norms.HuberT())
                        result = robust_model.fit()
                        coeff = result.params[1]
                        bootstrap_coeffs.append(coeff)
                        prev_sample = sample

                    # Two-sided bootstrap p-value for H0: coefficient == 0.
                    # The resampled coefficients scatter around emp_coeff, not around zero, so
                    # they are centred on the empirical estimate to approximate the null
                    # distribution before being compared with |emp_coeff|. Comparing the
                    # uncentred |bootstrap| values with |emp_coeff| gives p of about 0.5 or
                    # more whenever there is a real effect, so nothing could ever be significant.
                    # The +1 terms keep p strictly positive for a finite number of resamples.
                    centred = np.asarray(bootstrap_coeffs) - emp_coeff
                    p = (np.sum(np.abs(centred) >= np.abs(emp_coeff)) + 1) / (len(bootstrap_coeffs) + 1)
                    comps[comp_group] = p

        means.append(np.mean(to_plot[to_plot['group'] == gp_labels[k]][metric]))
        sems.append(np.std(to_plot[to_plot['group'] == gp_labels[k]][metric]) / np.sqrt(len(to_plot[to_plot['group'] == gp_labels[k]][metric])))
        n.append(len(to_plot[to_plot['group'] == gp_labels[k]][metric]))
                                    
        bps.append(bp['boxes'][0])
        # lbls.append(str(means[k]) + ' ± ' + str(sems[k]) + ' cm, N = ' + str(n[k]))
        lbls.append(gp_labels[k])
        #  + ': N = ' + str(n[k]))
    
    ax.set_xticklabels(gp_labels)
    ax.legend(bps, lbls, loc='upper right')
    ax.set_xlabel('Group')
    ax.set_ylabel('EMD (cm)')

    print('comps here')
    print(comps)


    # benjamini hochberg correction
    kys, vals = zip(*comps.items())
    accepted, pvals_corrected, _, _ = multipletests(vals, alpha=0.05, method='fdr_bh', is_sorted=False, returnsorted=False)
    
    p_count = 0
    for comp_key, val in comps.items():
        comparison = comps[comp_key]
        if 'ANT' in comp_key and 'B6' in comp_key:
            nme = [0,1]
        elif 'ANT' in comp_key and 'NON' in comp_key:
            nme = [0,2]
        elif 'B6' in comp_key and 'NON' in comp_key:
            nme = [1,2]

        # if accepted[k]:
        barplot_annotate_brackets(nme[0],nme[1],pvals_corrected[p_count],[0,1,2], comp_maxs, maxasterix=5)
        
        p_count += 1

    # ax = fig.add_subplot(2, 2, i+1)
    ax = plt.subplot(gs_sub[1])
    ax1 = ax

    # # every row for that score
    if score != 'firing_rate':
        # to_plot = df[df['score'] == score]
        to_plot = df[df['score'] == score].groupby(['group', 'name', 'depth', 'date','stim','session_id']).mean().reset_index()
    else:
        # to_plot = df[df['score'] == 'whole']
        to_plot = df[df['score'] == 'whole'].groupby(['group', 'name', 'depth', 'date','stim','session_id']).mean().reset_index()
        metric = 'firing_rate'
        # to_plot = to_plot[to_plot['depth']]
    to_plot_shuffle = to_plot.copy()



    bps = []
    lbls = []
    shuffle_count = 1000
    mns = [[] for j in range(3)]
    mns_shuffle = [[[] for sc in range(shuffle_count)] for j in range(3)]
    
    session_comps = {}
    group_ses_frs = {'ANT': [], 'B6': [], 'NON': []}
    for k in range(len(gps)):
        # c = gp_colors[k]
        ses_visited = []
        # ses_frs = []
        for j in range(3):
            # get group means + CI
            to_plot_now = to_plot[to_plot['group'] == gp_labels[j]]
            ses_fr = to_plot_now[to_plot_now['session_id'] == gps[k]][metric]
            # ses_fr = to_plot_now[to_plot_now['session_id'] == gps[k]]['information']
            bp = ax.boxplot(to_plot_now[to_plot_now['session_id'] == gps[k]][metric], positions=[k*3+j*.5], widths=0.5, 
                        notch=False, patch_artist=True,
                        boxprops=dict(facecolor=gp_colors[j],color='k'),
                        capprops=dict(color='k'),
                        whiskerprops=dict(color='k'),
                        flierprops=dict(color='k', markeredgecolor='k'),
                        medianprops=dict(color='k'),
                        showmeans=False, 
                        meanprops=dict(markeredgecolor='k', markerfacecolor='k', markersize=10))
            if k == 0:
                bps.append(bp['boxes'][0])   

            mn = np.mean(to_plot_now[to_plot_now['session_id'] == gps[k]][metric])
            if mn == mn:
                mns[j].append(mn)

            for j2 in range(3):
                to_plot_now2 = to_plot[to_plot['group'] == gp_labels[j2]]
                if j != j2:
                    # welch's t-test
                    if len(to_plot_now[to_plot_now['session_id'] == gps[k]][metric]) > 1 and len(to_plot_now2[to_plot_now2['session_id'] == gps[k]][metric]) > 1:
                        # res = ttest_ind(to_plot_now[to_plot_now['session_id'] == gps[k]][metric],
                        #                     to_plot_now2[to_plot_now2['session_id'] == gps[k]][metric],
                        #                     usevar='unequal', alternative='two-sided')
                        # pvalue = res[1]

                        # mann whitney u test
                        _, pvalue = mannwhitneyu(to_plot_now[to_plot_now['session_id'] == gps[k]][metric],
                                           to_plot_now2[to_plot_now2['session_id'] == gps[k]][metric],
                                          alternative='two-sided')
                         
                    
                        sorted_labels = np.sort([gp_labels[j], gp_labels[j2]])
                        pair_id = sorted_labels[0] + '_' + sorted_labels[1]
                        if pair_id not in session_comps:
                            session_comps[pair_id] = []
                        if pair_id not in ses_visited:
                            session_comps[pair_id].append(pvalue)
                            # session_comps[pair_id].append(pvalue)
                            ses_visited.append(pair_id)
            if len(ses_fr) > 0 and np.mean(ses_fr) == np.mean(ses_fr):       
                group_ses_frs[gp_labels[j]].append(np.mean(ses_fr))

        for j in range(3):
            for sc in range(shuffle_count):
                out = _single_shuffle(to_plot_shuffle, gps[k], metric, gp_labels[j])
                if out == out:
                    mns_shuffle[j][sc].append(out)

    from mpl_toolkits.axes_grid1 import make_axes_locatable
    from scipy import interpolate
    from scipy.interpolate import griddata


    # benjamini-hochberg correction
    pvalsANT_B6 = np.array(session_comps['ANT_B6'])
    pvalsB6_NON = np.array(session_comps['B6_NON'])
    pvalsANT_NON = np.array(session_comps['ANT_NON'])
    pvals = np.concatenate((pvalsANT_B6, pvalsB6_NON, pvalsANT_NON))
    reject, pvals_corrected, alphacSidak, alphacBonf = multitest.multipletests(pvals, alpha=0.05, method='fdr_bh', is_sorted=False, returnsorted=False)
    bh_ant_b6 = np.ones(len(pvalsANT_B6))
    bh_ant_b6[pvals_corrected[:len(pvalsANT_B6)] <= 0.05] = 0
    bh_b6_non = np.ones(len(pvalsB6_NON))
    bh_b6_non[pvals_corrected[len(pvalsANT_B6):len(pvalsANT_B6)+len(pvalsB6_NON)] <= 0.05] = 0
    bh_ant_non = np.ones(len(pvalsANT_NON))
    bh_ant_non[pvals_corrected[len(pvalsANT_B6)+len(pvalsB6_NON):] <= 0.05] = 0

    # plt.setp(ax1.get_xticklabels(), visible=False)

    ax = plt.subplot(gs_sub[2])
    # bh_ant_b6 = np.hstack((bh_ant_b6, [.5]))
    # ax.imshow(np.expand_dims(bh_ant_b6, 0), cmap='Greys_r', aspect='auto', extent=[0, len(gps), 0, 1], vmin=0, vmax=1)
    ant_fr = group_ses_frs['ANT']
    ant_interp = np.linspace(0, len(ant_fr)-1, 1000)
    ant_fr_smooth = np.interp(ant_interp, np.arange(len(ant_fr)), ant_fr)
    im = ax.imshow(np.expand_dims(ant_fr_smooth, 0), cmap='jet', aspect='auto', extent=[0, len(gps), 0, 1])
    # colorbar
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax)
    ylbl = ax.set_ylabel('ANT', labelpad=15, rotation=0)
    # pos = ylbl.get_position()
    # ylbl.set_position((pos[0], pos[1] -.5))
    ax.set_yticks([])
    # ax.set_ylabel('ANT-B6', labelpad=15, rotation=45)
    ax.set_xticks(np.arange(len(gps)) + 0.5)
    ax.set_xticklabels(gps)
    plt.setp(ax.get_xticklabels(), visible=False)
    ax.spines['bottom'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.tick_params(axis='x', which='both', bottom=False, top=False)


    ax = plt.subplot(gs_sub[3])
    # bh_b6_non = np.hstack((bh_b6_non, [.5]))
    # ax.imshow(np.expand_dims(bh_b6_non, 0), cmap='Greys_r', aspect='auto', extent=[0, len(gps), 0, 1], vmin=0, vmax=1)
    b6_fr = group_ses_frs['B6']
    b6_interp = np.linspace(0, len(b6_fr)-1, 1000)
    b6_fr_smooth = np.interp(b6_interp, np.arange(len(b6_fr)), b6_fr)
    im = ax.imshow(np.expand_dims(b6_fr_smooth, 0), cmap='jet', aspect='auto', extent=[0, len(gps), 0, 1])
    # colorbar
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax)
    ylbl = ax.set_ylabel('B6', labelpad=15, rotation=0)
    # pos = ylbl.get_position()
    # ylbl.set_position((pos[0], pos[1]-.5))
    ax.set_yticks([])
    # ax.set_ylabel('B6-NON', labelpad=15, rotation=45)
    ax.set_xticks(np.arange(len(gps)) + 0.5)
    ax.set_xticklabels(gps)
    plt.setp(ax.get_xticklabels(), visible=False)
    ax.spines['bottom'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.tick_params(axis='x', which='both', bottom=False, top=False)


    ax = plt.subplot(gs_sub[4])
    # # bh_ant_non = np.hstack((bh_ant_non, [.5]))
    # ax.imshow(np.expand_dims(bh_ant_non, 0), cmap='Greys_r', aspect='auto', extent=[0, len(gps), 0, 1], vmin=0, vmax=1)
    non_fr = group_ses_frs['NON']
    non_interp = np.linspace(0, len(non_fr)-1, 1000)
    non_fr_smooth = np.interp(non_interp, np.arange(len(non_fr)), non_fr)
    im = ax.imshow(np.expand_dims(non_fr_smooth, 0), cmap='jet', aspect='auto', extent=[0, len(gps), 0, 1])
    # colorbar
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax)
    ylbl = ax.set_ylabel('NON', labelpad=15, rotation=0)
    # pos = ylbl.get_position()
    # ylbl.set_position((pos[0], pos[1] - .5))
    plt.setp(ax.get_xticklabels(), visible=False)

    ax.set_yticks([])
    # ylbl = ax.set_ylabel('ANT-NON', labelpad=15, rotation=45)
    ax.set_xticks(np.arange(len(gps)) + 0.5)
    ax.set_xticklabels(gps)
    ax.spines['bottom'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.tick_params(axis='x', which='both', bottom=False, top=False)


    # mann kendall
    import pymannkendall as mk
    from statsmodels.stats import multitest
    lbls = []
    lbl_colors = []
    empirical = []
    ps = []
    mns_shuffle = np.array(mns_shuffle)
    print(mns_shuffle.shape)
    print('Metric: ' + score)
    for j in range(3):
        # polyfit 

        memp, c = np.polyfit(np.arange(len(mns[j])), mns[j], 1)
        empirical.append(memp)

        shuffled = []
        for sc in range(shuffle_count):
            # if len(mns_shuffle[j,sc]) > 0:
            ses_dist = mns_shuffle[j,sc]
            mshuffled, c = np.polyfit(np.arange(len(ses_dist)), ses_dist, 1)
            shuffled.append(mshuffled)
                # result = mk.original_test(ses_dist)
                # ps.append(result.p)


        # # two sided p-value 
        pgreater = np.sum(np.array(shuffled) < empirical[j]) / len(shuffled) 
        plower = np.sum(np.array(shuffled) > empirical[j]) / len(shuffled)
        pvalue = np.min([pgreater, plower]) * 2
        ps.append(pvalue)
        print('Group: ' + gp_labels[j])
        print('Empirical: ' + str(empirical[j]))
        print('Shuffled: ' + str(np.mean(shuffled)) + ' ± ' + str(np.std(shuffled)))
        # print('p-value: ' + str(np.min([pgreater, plower]) * 2))


        # if empirical[j] > np.mean(shuffled):
        #     tag = 'greater'
        # elif empirical[j] < np.mean(shuffled):
        #     tag = 'lower'
        # lbl = 'Slope is ' + str(tag) + ' than shuffled: ' + str(np.round(memp, 2)) + ' , p = ' + str(np.round(pvalue, 3)) 
        # lbls.append(lbl)

        # lbl = gp_labels[j] + ' 
        lbl = 'slope: ' + str(np.round(memp, 2)) 
        lbl_colors.append('k')
        lbls.append(lbl)

    # result = mk.original_test(ses_dist)
    print(ps)
    out = multitest.multipletests(ps, alpha=0.05, method='fdr_bh', is_sorted=False, returnsorted=False)
    cc = 0
    for case in out[0]:
        if case:
            lbl_colors[cc] = 'r'
            # lbls[cc] = lbls[cc] + ' & is sig after BH correction'
        # else:
            # lbls[cc] = lbls[cc] + ' & is NOT sig after BH correction'
        cc += 1
    print(out)

    ax1.legend(bps, lbls, loc='upper right')
    # color label text in legend
    for text, color in zip(ax1.legend_.get_texts(), lbl_colors):
        text.set_color(color)

    # ax1.set_title(score)
    # ax.set_xlabel('Session')
    axf.set_title(titles_to_use[i], fontweight='bold')
    ax1.set_ylabel('EMD (cm)')
              
    # ax1.set_xticks(np.arange(len(gps)) * 3 + 1.25/2)
    # ax1.set_xticklabels(gps)
    # ax1.set_xlim([-1.25/2, len(gps) * 3 - 1.25/2])

    ax1.set_xticks(np.arange(len(gps)) * 3 + .5)
    ax1.set_xticklabels(gps)
    ax1.set_ylim(-0.05,)
    # ax1.set_xlim([-.5/2, len(gps) * 3 - .5/2])


fig.suptitle('All indiv. cell-session appearances', fontweight='bold')
# fig.suptitle('Averaged by session', fontweight='bold')
gs_main.tight_layout(fig, rect=[0, 0, 1, 0.98])
# fig.suptitle('Averaged by session')
fig.tight_layout()
plt.show()

In [None]:

# Columns that identify one recording session.
_SESSION_KEYS = ['group', 'name', 'depth', 'date', 'stim', 'session_id']


def _session_average(score_rows):
    """
    Average the numeric columns of ``score_rows`` within each session.

    numeric_only=True skips text columns such as ``score``; without it pandas >= 2.0
    raises a TypeError when asked to average them.

    :param score_rows: rows of ``df`` belonging to one score type
    :return: DataFrame with one row per combination of the session key columns
    """
    return score_rows.groupby(_SESSION_KEYS).mean(numeric_only=True).reset_index()


# indiv_*   : the unaggregated rows of df for one score type
# ses_avg_* : the per-session means of those rows
indiv_whole = df[df['score'] == 'whole']
ses_avg_whole = _session_average(indiv_whole)

indiv_sd = df[df['score'] == 'spike_density']
ses_avg_sd = _session_average(indiv_sd)

indiv_field = df[df['score'] == 'field']
ses_avg_field = _session_average(indiv_field)

indiv_binary = df[df['score'] == 'binary']
ses_avg_binary = _session_average(indiv_binary)

indiv_centroid = df[df['score'] == 'centroid']
ses_avg_centroid = _session_average(indiv_centroid)

# All frames above are for obj_w AND obj_q. The 'whole' frames also include the firing rate.

In [None]:
np.random.choice(list(unique_animal_group_pairs_dict.keys()), size=len(unique_animal_group_pairs_dict), replace=True)


In [None]:
unique_animal_group_pairs_dict

In [None]:
resampled_animal_group_pairs

In [None]:
resampled_animal_group_pairs

In [None]:
                #    unique_animal_group_pairs = list(zip(model_data['name'].unique(), model_data['group'].unique()))
                #     unique_animal_group_pairs_dict = {animal: group for animal, group in unique_animal_group_pairs}


                #     bootstrap_coeffs = []
                #     for b in range(1000):
                #         resampled_animal_group_pairs = np.random.choice(['ANT','NON','B6'], size=len(unique_animal_group_pairs_dict), replace=True)
                #         sample = model_data[model_data['group'].isin(resampled_animal_group_pairs)]
                #         print(sample)
                #         sample = sample.groupby(['name']).apply(lambda x: x.sample(frac=1, replace=True)).reset_index(drop=True)
                #         # bootstrap for p-value


unique_animal_group_pairs           

In [None]:
comps

In [None]:
# Scratch check: build design matrices through BetaModel's formula interface on simulated
# data and fit a Huber robust linear model on them.
# NOTE(review): this cell overwrites `model_data`, `model` and `result` left in the kernel by
# the analysis cell, and relies on `pd` and `RLM` having been imported by earlier cells
# (RLM presumably from statsmodels.robust.robust_linear_model - confirm).
import statsmodels.api as sm
from statsmodels.othermod.betareg import BetaModel
import numpy as np

# Simulated data
np.random.seed(0)
n_samples = 100
groups = np.random.choice(['ANT', 'B6'], size=n_samples)
names = np.random.choice(['A', 'B', 'C'], size=n_samples)
firing_rates = np.random.beta(2, 5, size=n_samples)  # Simulated beta-distributed firing rates

# Create a DataFrame for the data
model_data = pd.DataFrame({'firing_rate': firing_rates, 'group': groups, 'name': names})

# Build (but do not fit) a beta regression model; only its endog/exog arrays are used below.
# NOTE(review): `groups=` is passed through as an extra keyword; whether BetaModel uses it
# is not visible here - confirm it is not silently ignored.
model = BetaModel.from_formula('firing_rate ~ group', data=model_data, groups=model_data['name'])

# The model that is actually fitted is a robust linear model with the Huber norm.
result = RLM(model.endog, model.exog, M=sm.robust.norms.HuberT()).fit()
# Summary of the robust linear fit (not of a beta regression)
print(result.summary())


In [None]:
plt.plot(bootstrap_coeffs)

In [None]:
emp_coeff

In [None]:
comps.keys()

In [None]:
# https://stackoverflow.com/questions/11517986/indicating-the-statistically-significant-difference-in-bar-graph

def barplot_annotate_brackets(num1, num2, data, center, height, yerr=None, dh=.05, barh=.05, fs=None, maxasterix=None, ax=None):
    """
    Annotate barplot with p-values.

    Draws a bracket between two bars and writes a label above it: either the given
    string, or a run of asterisks derived from a p-value.

    :param num1: number of left bar to put bracket over
    :param num2: number of right bar to put bracket over
    :param data: string to write or number for generating asterixes
    :param center: centers of all bars (like plt.bar() input)
    :param height: heights of all bars (like plt.bar() input)
    :param yerr: yerrs of all bars (like plt.bar() input); list or array, None/empty to skip
    :param dh: height offset over bar / bar + yerr in axes coordinates (0 to 1)
    :param barh: bar height in axes coordinates (0 to 1)
    :param fs: font size
    :param maxasterix: maximum number of asterixes to write (for very small p-values)
    :param ax: axes to draw on; defaults to the current axes (plt.gca())
    """
    if ax is None:
        ax = plt.gca()

    if isinstance(data, str):
        text = data
    else:
        # * is p < 0.05
        # ** is p < 0.005
        # *** is p < 0.0005
        # etc.
        text = ''
        p = .05

        # `p > 0` makes the loop terminate once the threshold underflows to zero, so an
        # invalid negative "p-value" cannot hang the notebook when maxasterix is unset.
        while data < p and p > 0:
            text += '*'
            p /= 10.

            if maxasterix and len(text) == maxasterix:
                break

        if len(text) == 0:
            text = 'n. s.'

    lx, ly = center[num1], height[num1]
    rx, ry = center[num2], height[num2]

    # Explicit None/length test: `if yerr:` raises for NumPy arrays with more than one
    # element because their truth value is ambiguous.
    if yerr is not None and len(yerr) > 0:
        ly += yerr[num1]
        ry += yerr[num2]

    # dh and barh are given as fractions of the visible y-range; convert to data units
    ax_y0, ax_y1 = ax.get_ylim()
    dh *= (ax_y1 - ax_y0)
    barh *= (ax_y1 - ax_y0)

    y = max(ly, ry) + dh

    barx = [lx, lx, rx, rx]
    bary = [y, y+barh, y+barh, y]
    mid = ((lx+rx)/2, y+barh)

    ax.plot(barx, bary, c='black')

    kwargs = dict(ha='center', va='bottom')
    if fs is not None:
        kwargs['fontsize'] = fs

    ax.text(mid[0], mid[1], text, **kwargs)

In [None]:
sample = df.sample(frac=1, replace=True) 
print(len(sample), len(df))
print(sample['group'].value_counts(), df['group'].value_counts())
# check order 


In [None]:
to_plot_now.columns

In [None]:
pos[0]

In [None]:
    window_size = 3  # Adjust the window size as needed
    ant_fr_smooth = np.convolve(ant_fr, np.ones(window_size)/window_size, mode='same')

In [None]:
ant_fr_smooth

In [None]:
""" Amount of remapping per group """
from statsmodels.stats.weightstats import ttest_ind
from scipy.stats import norm
from sklearn.neighbors import KernelDensity
from scipy.stats import shapiro, kstest, mannwhitneyu, ttest_ind, ttest_rel, wilcoxon, ks_2samp, anderson_ksamp, anderson

scores_to_use = ['whole', 'spike_density', 'field', 'binary']
quad_arrange = [[0,0],[0,1],[1,0],[1,1]]
titles_to_use = ['Whole-map', 'Spike Density', 'Field', 'Binary']
gps = np.unique(df['session_id'])
gp_labels = ['ANT', 'B6', 'NON']
gp_colors = ['r', 'b', 'g']
np.random.seed(0)
def _single_shuffle(to_plot_shuffle, sesgp, metric, gplbl):
    # to_plot_shuffle['group'] = np.random.permutation(to_plot_shuffle['group'].loc[:].values)
    vals = to_plot_shuffle.loc[to_plot_shuffle['session_id'] == sesgp, 'group'].to_numpy()
    # shuffle the vals
    np.random.shuffle(vals)
    to_plot_shuffle.loc[to_plot_shuffle['session_id'] == sesgp, 'group'] = vals

    use = to_plot_shuffle[to_plot_shuffle['session_id'] == sesgp]
    use = use[use['group'] == gplbl][metric]
    mn = np.mean(use)
    # print(metric, sesgp, mn, gplbl)
    return mn

# Figure: one panel per map score. Each panel shows, per session and group, a
# kernel-density outline of the session-averaged metric, and below it three
# one-row strips (one per group) marking the sessions whose Shapiro-Wilk
# normality test is rejected after Benjamini-Hochberg correction.

# Imported here so the cell also runs on a fresh kernel: `multitest` is needed
# for the Benjamini-Hochberg step and `gridspec` for the nested layout.
from matplotlib import gridspec
from statsmodels.stats import multitest

fig = plt.figure(figsize=(16, 12))
gs_main = gridspec.GridSpec(2, 2, width_ratios=[1, 1], height_ratios=[1, 1])  # Adjust width_ratios and height_ratios as needed


# metric = 'obj_w'
metric = 'obj_q'

# columns that together identify one recording session
session_keys = ['group', 'name', 'depth', 'date', 'stim', 'session_id']
# y positions at which each KDE is evaluated and drawn
kde_grid = np.linspace(0, 1, 100)
# y position of the "N=..." annotation for each group
n_label_y = {'ANT': 0.9, 'B6': 0.5, 'NON': 0.1}

for i, score in enumerate(scores_to_use):
    row, col = quad_arrange[i]
    gs_sub = gridspec.GridSpecFromSubplotSpec(4, 1, subplot_spec=gs_main[row, col], height_ratios=[12,1,1,1])

    ax = plt.subplot(gs_sub[0])
    ax1 = ax

    # scores averaged for each session
    # numeric_only=True: the frame still carries string columns (e.g. 'score')
    # that are not group keys, and pandas >= 2.0 raises instead of dropping them.
    score_rows = df[df['score'] == score]
    to_plot = score_rows.groupby(session_keys).mean(numeric_only=True).reset_index()
    to_plot_count = score_rows.groupby(session_keys).count().reset_index()
    to_plot_shuffle = to_plot.copy()

    shuffle_count = 1
    mns = [[] for j in range(3)]
    mns_shuffle = [[[] for sc in range(shuffle_count)] for j in range(3)]

    # group label -> Shapiro-Wilk p-values, one per session with more than 3 values
    session_comps = {}
    for k in range(len(gps)):
        for j in range(3):
            to_plot_now = to_plot[to_plot['group'] == gp_labels[j]]
            ses_vals = to_plot_now[to_plot_now['session_id'] == gps[k]][metric]
            n_ses = len(ses_vals)
            x0 = k*3 + j*.5  # x offset of this (session, group) outline

            if n_ses > 0:
                kde = KernelDensity(kernel='gaussian', bandwidth=0.2).fit(ses_vals.values.reshape(-1, 1))
                density = np.exp(kde.score_samples(kde_grid.reshape(-1, 1)))
                ax.plot(x0 + density, kde_grid, color=gp_colors[j], alpha=.5)
                ax.text(x0, n_label_y[gp_labels[j]], 'N='+str(n_ses), color=gp_colors[j], fontsize=8, fontweight='bold')

            mn = np.mean(ses_vals)
            if mn == mn:  # skip NaN means (group absent from this session)
                mns[j].append(mn)

            if n_ses > 3:
                # shapiro-wilk test
                stat, p = shapiro(ses_vals)
                session_comps.setdefault(gp_labels[j], []).append(p)

        for j in range(3):
            for sc in range(shuffle_count):
                out = _single_shuffle(to_plot_shuffle, gps[k], metric, gp_labels[j])
                if out == out:
                    mns_shuffle[j][sc].append(out)

    # benjamini-hochberg correction over all group/session p-values of this panel
    pvalsANT = np.array(session_comps['ANT'])
    pvalsB6 = np.array(session_comps['B6'])
    pvalsNON = np.array(session_comps['NON'])
    pvals = np.concatenate((pvalsANT, pvalsB6, pvalsNON))
    reject, pvals_corrected, alphacSidak, alphacBonf = multitest.multipletests(pvals, alpha=0.05, method='fdr_bh', is_sorted=False, returnsorted=False)
    n_ant, n_b6 = len(pvalsANT), len(pvalsB6)
    # strip values: 0 = corrected p <= 0.05, 1 = otherwise
    bh_ant = np.ones(n_ant)
    bh_ant[pvals_corrected[:n_ant] <= 0.05] = 0
    bh_b6 = np.ones(n_b6)
    bh_b6[pvals_corrected[n_ant:n_ant+n_b6] <= 0.05] = 0
    bh_non = np.ones(len(pvalsNON))
    bh_non[pvals_corrected[n_ant+n_b6:] <= 0.05] = 0

    plt.setp(ax1.get_xticklabels(), visible=False)

    # NOTE(review): the number of 0.5 ("not tested") cells appended below is
    # hard-coded; it presumably matches how many sessions had <= 3 values per
    # group in the current data, and the strips only line up with the session
    # ticks if those sessions come last. Confirm whenever the data changes.
    bh_b6 = np.hstack((bh_b6, [.5, .5]))
    bh_non = np.hstack((bh_non, [.5]))
    for strip_idx, (strip_label, strip_vals) in enumerate([('ANT', bh_ant), ('B6', bh_b6), ('NON', bh_non)], start=1):
        ax = plt.subplot(gs_sub[strip_idx])
        ax.imshow(np.expand_dims(strip_vals, 0), cmap='Greys_r', aspect='auto', extent=[0, len(gps), 0, 1], vmin=0, vmax=1)
        ax.set_yticks([])
        ylbl = ax.set_ylabel(strip_label, labelpad=15, rotation=45)
        ax.set_xticks(np.arange(len(gps)) + 0.5)
        ax.set_xticklabels(gps)
        if strip_idx < 3:
            # only the bottom strip keeps its session tick labels
            plt.setp(ax.get_xticklabels(), visible=False)

    ax1.set_title(titles_to_use[i])
    ax1.set_ylabel('Quantile')
    ax1.set_xticks(np.arange(len(gps)) * 3 + .5)
    ax1.set_xticklabels(gps)



fig.suptitle('Averaged by session')
fig.tight_layout()
plt.show()

In [None]:
# Scratch inspection: ANT significance strip left over from the last panel of the figure cell above.
bh_ant

In [None]:
# Scratch preview of the NON strip as a one-row image stretched over the session axis.
# No vmin/vmax is given here, so colours are auto-scaled to the values present in `bh_non`.
plt.imshow(np.expand_dims(bh_non, 0), extent=[0, len(gps), 0, 1])
# one tick at the centre of each session cell
plt.xticks(np.arange(len(gps)) + 0.5)

plt.show()

In [None]:
# Scratch inspection: NON significance strip (including padding) from the last panel drawn.
bh_non

In [None]:
# Scratch inspection: number of p-values that entered the last Benjamini-Hochberg correction.
pvals.shape

In [None]:
# Scratch cell: normality checks on one session's values, using whatever
# `to_plot_now` and `metric` were left in the kernel by a plotting cell.
# Alternatives that were tried are kept below, commented out.
# plt.hist(to_plot_now[to_plot_now['session_id'] == 'session_1'][metric], bins=20)
# plt.show()
# GRIN013_H16_M33_S54
# sns.distplot(to_plot_now[to_plot_now['session_id'] == 'session_1'][metric], bins=20)
# shapiro-wilk
# p = shapiro(to_plot_now[to_plot_now['session_id'] == 'session_1'][metric])
# normaltest
# from scipy.stats import normaltest
# p = normaltest(to_plot_now[to_plot_now['session_id'] == 'session_1'][metric])
# jarque_bera
# from scipy.stats import jarque_bera
# p = jarque_bera(to_plot_now[to_plot_now['session_id'] == 'session_1'][metric])
# lilliefors
from statsmodels.stats.diagnostic import lilliefors
# NOTE(review): lilliefors returns a (statistic, p-value) pair, so `p` holds a tuple here, not a bare p-value.
p = lilliefors(to_plot_now[to_plot_now['session_id'] == 'session_1'][metric])
print(p)

In [None]:
""" Amount of remapping per group """
from statsmodels.stats.weightstats import ttest_ind
from scipy.stats import norm
from sklearn.neighbors import KernelDensity
from scipy.stats import shapiro, kstest, mannwhitneyu, ttest_ind, ttest_rel, wilcoxon, ks_2samp, anderson_ksamp, anderson

# Map scores analysed in this cell, paired index-for-index with their panel titles.
scores_to_use, titles_to_use = (
    ['whole', 'spike_density', 'field', 'binary'],
    ['Whole-map', 'Spike Density', 'Field', 'Binary'],
)
# [row, col] slot of each score's panel inside the 2x2 figure grid.
quad_arrange = [[r, c] for r in (0, 1) for c in (0, 1)]
# Sorted unique session ids; they define the order of positions on the x axis.
gps = np.unique(df['session_id'])
# Animal groups and the colour each one is drawn with.
gp_labels = ['ANT', 'B6', 'NON']
gp_colors = ['r', 'b', 'g']
# Seed the legacy global NumPy RNG so the label shuffles below are reproducible.
np.random.seed(0)
def _single_shuffle(to_plot_shuffle, sesgp, metric, gplbl):
    # to_plot_shuffle['group'] = np.random.permutation(to_plot_shuffle['group'].loc[:].values)
    vals = to_plot_shuffle.loc[to_plot_shuffle['session_id'] == sesgp, 'group'].to_numpy()
    # shuffle the vals
    np.random.shuffle(vals)
    to_plot_shuffle.loc[to_plot_shuffle['session_id'] == sesgp, 'group'] = vals

    use = to_plot_shuffle[to_plot_shuffle['session_id'] == sesgp]
    use = use[use['group'] == gplbl][metric]
    mn = np.mean(use)
    # print(metric, sesgp, mn, gplbl)
    return mn

# Figure: one panel per map score. Each panel shows, per session and group, a
# kernel-density outline of the session-averaged metric, and below it three
# one-row strips (one per group) marking the sessions whose Shapiro-Wilk
# normality test is rejected after Benjamini-Hochberg correction.

# Imported here so the cell also runs on a fresh kernel: `multitest` is needed
# for the Benjamini-Hochberg step and `gridspec` for the nested layout.
from matplotlib import gridspec
from statsmodels.stats import multitest

fig = plt.figure(figsize=(16, 8))
gs_main = gridspec.GridSpec(2, 2, width_ratios=[1, 1], height_ratios=[1, 1])  # Adjust width_ratios and height_ratios as needed


metric = 'obj_w'
# metric = 'obj_q'

# columns that together identify one recording session
session_keys = ['group', 'name', 'depth', 'date', 'stim', 'session_id']
# y position of the "N=..." annotation for each group (in metric units)
n_label_y = {'ANT': 50, 'B6': 30, 'NON': 10}

for i, score in enumerate(scores_to_use):
    row, col = quad_arrange[i]
    gs_sub = gridspec.GridSpecFromSubplotSpec(4, 1, subplot_spec=gs_main[row, col], height_ratios=[12,1,1,1])

    ax = plt.subplot(gs_sub[0])
    ax1 = ax

    # scores averaged for each session
    # numeric_only=True: the frame still carries string columns (e.g. 'score')
    # that are not group keys, and pandas >= 2.0 raises instead of dropping them.
    score_rows = df[df['score'] == score]
    to_plot = score_rows.groupby(session_keys).mean(numeric_only=True).reset_index()
    to_plot_count = score_rows.groupby(session_keys).count().reset_index()
    to_plot_shuffle = to_plot.copy()

    shuffle_count = 1
    mns = [[] for j in range(3)]
    mns_shuffle = [[[] for sc in range(shuffle_count)] for j in range(3)]

    # group label -> Shapiro-Wilk p-values, one per session with more than 3 values
    session_comps = {}
    for k in range(len(gps)):
        for j in range(3):
            to_plot_now = to_plot[to_plot['group'] == gp_labels[j]]
            ses_vals = to_plot_now[to_plot_now['session_id'] == gps[k]][metric]
            n_ses = len(ses_vals)
            x0 = k*3 + j*.5  # x offset of this (session, group) outline

            if n_ses > 0:
                kde = KernelDensity(kernel='gaussian', bandwidth=1).fit(ses_vals.values.reshape(-1, 1))
                # Evaluate the density over the observed range of this sample.
                # The grid was previously assigned to `lspcs` but read back as
                # `lspc`, which raised NameError (or silently reused a stale grid).
                lspc = np.linspace(np.min(ses_vals), np.max(ses_vals), 20)
                density = np.exp(kde.score_samples(lspc.reshape(-1, 1)))
                # density is scaled by 5 for visibility only
                ax.plot(x0 + density*5, lspc, color=gp_colors[j], alpha=.5)
                # hide xtick lines
                ax.xaxis.set_ticks_position('none')
                ax.text(x0, n_label_y[gp_labels[j]], 'N='+str(n_ses), color=gp_colors[j], fontsize=8, fontweight='bold')

            mn = np.mean(ses_vals)
            if mn == mn:  # skip NaN means (group absent from this session)
                mns[j].append(mn)

            if n_ses > 3:
                # shapiro-wilk test
                stat, p = shapiro(ses_vals)
                session_comps.setdefault(gp_labels[j], []).append(p)

        for j in range(3):
            for sc in range(shuffle_count):
                out = _single_shuffle(to_plot_shuffle, gps[k], metric, gp_labels[j])
                if out == out:
                    mns_shuffle[j][sc].append(out)

    # benjamini-hochberg correction over all group/session p-values of this panel
    pvalsANT = np.array(session_comps['ANT'])
    pvalsB6 = np.array(session_comps['B6'])
    pvalsNON = np.array(session_comps['NON'])
    pvals = np.concatenate((pvalsANT, pvalsB6, pvalsNON))
    reject, pvals_corrected, alphacSidak, alphacBonf = multitest.multipletests(pvals, alpha=0.05, method='fdr_bh', is_sorted=False, returnsorted=False)
    n_ant, n_b6 = len(pvalsANT), len(pvalsB6)
    # strip values: 0 = corrected p <= 0.05, 1 = otherwise
    bh_ant = np.ones(n_ant)
    bh_ant[pvals_corrected[:n_ant] <= 0.05] = 0
    bh_b6 = np.ones(n_b6)
    bh_b6[pvals_corrected[n_ant:n_ant+n_b6] <= 0.05] = 0
    bh_non = np.ones(len(pvalsNON))
    bh_non[pvals_corrected[n_ant+n_b6:] <= 0.05] = 0

    plt.setp(ax1.get_xticklabels(), visible=False)

    # NOTE(review): the number of 0.5 ("not tested") cells appended below is
    # hard-coded; it presumably matches how many sessions had <= 3 values per
    # group in the current data, and the strips only line up with the session
    # ticks if those sessions come last. Confirm whenever the data changes.
    bh_ant = np.hstack((bh_ant, [.5]))
    bh_b6 = np.hstack((bh_b6, [.5,.5]))
    bh_non = np.hstack((bh_non, [.5]))
    for strip_idx, (strip_label, strip_vals) in enumerate([('ANT', bh_ant), ('B6', bh_b6), ('NON', bh_non)], start=1):
        ax = plt.subplot(gs_sub[strip_idx])
        # vmin/vmax pin the colour scale to 0..1; without them a strip holding
        # only 1 and 0.5 is auto-scaled and the 0.5 padding renders as black,
        # indistinguishable from a significant (0) cell.
        ax.imshow(np.expand_dims(strip_vals, 0), cmap='Greys_r', aspect='auto', extent=[0, len(gps), 0, 1], vmin=0, vmax=1)
        ax.set_yticks([])
        ylbl = ax.set_ylabel(strip_label, labelpad=15, rotation=45)
        ax.set_xticks(np.arange(len(gps)) + 0.5)
        ax.set_xticklabels(gps)
        if strip_idx < 3:
            # only the bottom strip keeps its session tick labels
            plt.setp(ax.get_xticklabels(), visible=False)

    ax1.set_title(titles_to_use[i])
    ax1.set_ylabel('EMD (cm)')
    ax1.set_xticks(np.arange(len(gps)) * 3 + .5)
    ax1.set_xticklabels(gps)



fig.suptitle('Averaged by session')
fig.tight_layout()
plt.show()

In [None]:
# Scratch inspection: NON significance strip (including padding) from the last panel drawn.
bh_non

In [None]:
import numpy as np
from scipy.stats import anderson

# Sample data
data = np.array([10, 15, 20, 25, 30])

# Perform the Anderson-Darling test against a normal distribution
result = anderson(data, dist='norm')

# Set significance level (choose based on your desired level), as a fraction
significance_level = 0.05

# result.significance_level is expressed in percent (15, 10, 5, 2.5, 1), so the
# fraction must be converted before looking up the matching critical value.
# Comparing against 0.05 directly selects nothing and leaves an empty array.
level_matches = np.flatnonzero(np.isclose(result.significance_level, significance_level * 100))
if level_matches.size == 0:
    raise ValueError('No tabulated critical value for significance level ' + str(significance_level))
critical_value = result.critical_values[level_matches[0]]

# Determine if the test is significant
is_significant = result.statistic > critical_value

# Print the result
if is_significant:
    print("Test is Significant (Departure from Normality): 1")
else:
    print("Test is Not Significant (Consistent with Normality): 0")


In [None]:
# Scratch inspection: session id at the loop index `k` left over from the last figure cell.
gps[k]

In [None]:
# Scratch preview: last KDE `density` left in the kernel, drawn sideways against its sample index,
# in the colour of the group index `j` left over from the figure cell.
plt.plot(density, np.arange(len(density)), color=gp_colors[j], alpha=.5)

In [None]:
# Scratch inspection: values of the last KDE evaluated by the figure cell.
density

In [None]:
# Scratch inspection: number of grid points the last KDE was evaluated on.
density.shape

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Toy example: draw a row of p-values as shaded cells and write each cell's
# sample size on top of it.
p_values = np.array([0.05, 0.001, 0.2, 0.0001, 0.02, 0.1, 0.03])
N_values = np.array([100, 200, 50, 300, 150, 80, 120])

# imshow needs a 2-D array: a single row with one column per p-value
grid = p_values[np.newaxis, :]

fig, ax = plt.subplots()
im = ax.imshow(grid, cmap='binary', aspect='auto')

# cell i is centred at x = i, y = 0
for i, n_in_cell in enumerate(N_values):
    ax.text(i, 0, f'N={n_in_cell}', color='black', ha='center', va='center')

# one tick per cell, but no tick labels and no y axis
ax.set_xticks(np.arange(p_values.size))
ax.set_yticks([])
ax.set_xticklabels([])
ax.set_title('P-Values with Corresponding N Values')

fig.colorbar(im, ax=ax)
plt.show()


In [None]:
""" Amount of remapping per group """
from statsmodels.stats.weightstats import ttest_ind

# Map scores analysed in this cell, paired index-for-index with their panel titles.
scores_to_use, titles_to_use = (
    ['whole', 'spike_density', 'field', 'binary'],
    ['Whole-map', 'Spike Density', 'Field', 'Binary'],
)
# [row, col] slot of each score's panel inside the 2x2 figure grid.
quad_arrange = [[r, c] for r in (0, 1) for c in (0, 1)]
# Sorted unique session ids; they define the order of positions on the x axis.
gps = np.unique(df['session_id'])
# Animal groups and the colour each one is drawn with.
gp_labels = ['ANT', 'B6', 'NON']
gp_colors = ['r', 'b', 'g']
# Seed the legacy global NumPy RNG so the label shuffles below are reproducible.
np.random.seed(0)
def _single_shuffle(to_plot_shuffle, sesgp, metric, gplbl):
    # to_plot_shuffle['group'] = np.random.permutation(to_plot_shuffle['group'].loc[:].values)
    vals = to_plot_shuffle.loc[to_plot_shuffle['session_id'] == sesgp, 'group'].to_numpy()
    # shuffle the vals
    np.random.shuffle(vals)
    to_plot_shuffle.loc[to_plot_shuffle['session_id'] == sesgp, 'group'] = vals

    use = to_plot_shuffle[to_plot_shuffle['session_id'] == sesgp]
    use = use[use['group'] == gplbl][metric]
    mn = np.mean(use)
    # print(metric, sesgp, mn, gplbl)
    return mn

# Figure: one panel per map score. Each panel shows per-session boxplots of the
# session-averaged metric for the three groups; below it, three one-row strips
# mark the sessions where a pairwise weighted Welch t-test is significant after
# Benjamini-Hochberg correction. The legend reports each group's across-session
# slope, coloured red when a label-shuffle permutation test on that slope
# survives Benjamini-Hochberg correction.

# Imports are hoisted to the top of the cell: `multitest` is needed for the
# first Benjamini-Hochberg step, which runs before the point inside the loop
# where it used to be imported. `gridspec` is imported for the same
# fresh-kernel reason.
import pymannkendall as mk
from matplotlib import gridspec
from statsmodels.stats import multitest

fig = plt.figure(figsize=(16, 16))
gs_main = gridspec.GridSpec(2, 2, width_ratios=[1, 1], height_ratios=[1, 1])  # Adjust width_ratios and height_ratios as needed


metric = 'obj_w'
# metric = 'obj_q'

# columns that together identify one recording session
session_keys = ['group', 'name', 'depth', 'date', 'stim', 'session_id']

for i, score in enumerate(scores_to_use):
    row, col = quad_arrange[i]
    gs_sub = gridspec.GridSpecFromSubplotSpec(4, 1, subplot_spec=gs_main[row, col], height_ratios=[12,1,1,1])

    ax = plt.subplot(gs_sub[0])
    ax1 = ax

    # scores averaged for each session
    # numeric_only=True: the frame still carries string columns (e.g. 'score')
    # that are not group keys, and pandas >= 2.0 raises instead of dropping them.
    score_rows = df[df['score'] == score]
    to_plot = score_rows.groupby(session_keys).mean(numeric_only=True).reset_index()
    # non-null row counts per session; used as weights for the t-tests
    to_plot_count = score_rows.groupby(session_keys).count().reset_index()
    to_plot_shuffle = to_plot.copy()

    bps = []
    shuffle_count = 1000
    mns = [[] for j in range(3)]
    mns_shuffle = [[[] for sc in range(shuffle_count)] for j in range(3)]

    # pair id ('ANT_B6', ...) -> p-values, one per session where both groups have > 1 value
    session_comps = {}
    for k in range(len(gps)):
        for j in range(3):
            to_plot_now = to_plot[to_plot['group'] == gp_labels[j]]
            to_plot_count_now = to_plot_count[to_plot_count['group'] == gp_labels[j]]
            vals1 = to_plot_now[to_plot_now['session_id'] == gps[k]][metric]

            bp = ax.boxplot(vals1, positions=[k*3+j*.5], widths=0.5,
                        notch=False, patch_artist=True,
                        boxprops=dict(facecolor=gp_colors[j],color='k'),
                        capprops=dict(color='k'),
                        whiskerprops=dict(color='k'),
                        flierprops=dict(color='k', markeredgecolor='k'),
                        medianprops=dict(color='k'),
                        showmeans=False,
                        meanprops=dict(markeredgecolor='k', markerfacecolor='k', markersize=10))
            if k == 0:
                # one box per group is enough to serve as legend handle
                bps.append(bp['boxes'][0])

            mn = np.mean(vals1)
            if mn == mn:  # skip NaN means (group absent from this session)
                mns[j].append(mn)

            # Each unordered pair of groups is tested once per session (j < j2);
            # the test is two-sided, so the order of the two samples does not matter.
            for j2 in range(j + 1, 3):
                to_plot_now2 = to_plot[to_plot['group'] == gp_labels[j2]]
                to_plot_count_now2 = to_plot_count[to_plot_count['group'] == gp_labels[j2]]
                vals2 = to_plot_now2[to_plot_now2['session_id'] == gps[k]][metric]
                if len(vals1) > 1 and len(vals2) > 1:
                    # welch's t-test, each session mean weighted by its row
                    # count (normalised to mean 1 within the group)
                    w1 = to_plot_count_now[to_plot_count_now['session_id'] == gps[k]][metric]
                    w2 = to_plot_count_now2[to_plot_count_now2['session_id'] == gps[k]][metric]
                    w1 = w1 / np.mean(w1)
                    w2 = w2 / np.mean(w2)
                    res = ttest_ind(vals1, vals2,
                                    usevar='unequal', alternative='two-sided', weights=(w1, w2))
                    pvalue = res[1]
                    pair_id = '_'.join(np.sort([gp_labels[j], gp_labels[j2]]))
                    session_comps.setdefault(pair_id, []).append(pvalue)

        for j in range(3):
            for sc in range(shuffle_count):
                out = _single_shuffle(to_plot_shuffle, gps[k], metric, gp_labels[j])
                if out == out:
                    mns_shuffle[j][sc].append(out)

    # benjamini-hochberg correction over all pairwise p-values of this panel
    pvalsANT_B6 = np.array(session_comps['ANT_B6'])
    pvalsB6_NON = np.array(session_comps['B6_NON'])
    pvalsANT_NON = np.array(session_comps['ANT_NON'])
    pvals = np.concatenate((pvalsANT_B6, pvalsB6_NON, pvalsANT_NON))
    reject, pvals_corrected, alphacSidak, alphacBonf = multitest.multipletests(pvals, alpha=0.05, method='fdr_bh', is_sorted=False, returnsorted=False)
    n_ab, n_bn = len(pvalsANT_B6), len(pvalsB6_NON)
    # strip values: 0 = corrected p <= 0.05, 1 = otherwise
    bh_ant_b6 = np.ones(n_ab)
    bh_ant_b6[pvals_corrected[:n_ab] <= 0.05] = 0
    bh_b6_non = np.ones(n_bn)
    bh_b6_non[pvals_corrected[n_ab:n_ab+n_bn] <= 0.05] = 0
    bh_ant_non = np.ones(len(pvalsANT_NON))
    bh_ant_non[pvals_corrected[n_ab+n_bn:] <= 0.05] = 0

    plt.setp(ax1.get_xticklabels(), visible=False)

    # NOTE(review): the number of 0.5 ("not tested") cells appended below is
    # hard-coded; it presumably matches how many sessions could not be tested
    # for each pair in the current data, and the strips only line up with the
    # session ticks if those sessions come last. Confirm whenever the data changes.
    bh_ant_b6 = np.hstack((bh_ant_b6, [.5, .5]))
    bh_b6_non = np.hstack((bh_b6_non, [.5, .5]))
    bh_ant_non = np.hstack((bh_ant_non, [.5]))
    for strip_idx, (strip_label, strip_vals) in enumerate([('ANT-B6', bh_ant_b6), ('B6-NON', bh_b6_non), ('ANT-NON', bh_ant_non)], start=1):
        ax = plt.subplot(gs_sub[strip_idx])
        ax.imshow(np.expand_dims(strip_vals, 0), cmap='Greys_r', aspect='auto', extent=[0, len(gps), 0, 1], vmin=0, vmax=1)
        ax.set_yticks([])
        ylbl = ax.set_ylabel(strip_label, labelpad=15, rotation=45)
        ax.set_xticks(np.arange(len(gps)) + 0.5)
        ax.set_xticklabels(gps)
        if strip_idx < 3:
            # only the bottom strip keeps its session tick labels
            plt.setp(ax.get_xticklabels(), visible=False)

    # Permutation test on the across-session trend: compare the slope of each
    # group's session means with the slopes obtained after shuffling labels.
    lbls = []
    lbl_colors = []
    empirical = []
    ps = []
    # NOTE(review): this conversion assumes every shuffle produced a mean for
    # every session; NaN means are dropped above, which would make the lists
    # ragged.
    mns_shuffle = np.array(mns_shuffle)
    print(mns_shuffle.shape)
    print('Metric: ' + score)
    for j in range(3):
        # slope of a straight-line fit through the group's session means
        memp, c = np.polyfit(np.arange(len(mns[j])), mns[j], 1)
        empirical.append(memp)

        shuffled = []
        for sc in range(shuffle_count):
            ses_dist = mns_shuffle[j,sc]
            mshuffled, c = np.polyfit(np.arange(len(ses_dist)), ses_dist, 1)
            shuffled.append(mshuffled)

        # two sided p-value: twice the smaller tail fraction of the shuffled slopes
        pgreater = np.sum(np.array(shuffled) < empirical[j]) / len(shuffled)
        plower = np.sum(np.array(shuffled) > empirical[j]) / len(shuffled)
        pvalue = np.min([pgreater, plower]) * 2
        ps.append(pvalue)
        print('Group: ' + gp_labels[j])
        print('Empirical: ' + str(empirical[j]))
        print('Shuffled: ' + str(np.mean(shuffled)) + ' ± ' + str(np.std(shuffled)))

        lbl = gp_labels[j] + ' slope: ' + str(np.round(memp, 2))
        lbl_colors.append('k')
        lbls.append(lbl)

    print(ps)
    out = multitest.multipletests(ps, alpha=0.05, method='fdr_bh', is_sorted=False, returnsorted=False)
    # out[0] holds the per-group reject flags; significant slopes get a red legend entry
    for cc, case in enumerate(out[0]):
        if case:
            lbl_colors[cc] = 'r'
    print(out)

    ax1.legend(bps, lbls, loc='upper right')
    # color label text in legend
    for text, color in zip(ax1.legend_.get_texts(), lbl_colors):
        text.set_color(color)

    ax1.set_title(titles_to_use[i])
    ax1.set_ylabel('EMD (cm)')

    ax1.set_xticks(np.arange(len(gps)) * 3 + .5)
    ax1.set_xticklabels(gps)


# fig.suptitle('All indiv. cell-session appearances')
fig.suptitle('Averaged by session')
fig.tight_layout()
plt.show()

In [None]:
""" Amount of remapping per group """
from statsmodels.stats.weightstats import ttest_ind

# Map scores analysed in this cell, paired index-for-index with their panel titles.
scores_to_use, titles_to_use = (
    ['whole', 'spike_density', 'field', 'binary'],
    ['Whole-map', 'Spike Density', 'Field', 'Binary'],
)
# [row, col] slot of each score's panel inside the 2x2 figure grid.
quad_arrange = [[r, c] for r in (0, 1) for c in (0, 1)]
# Sorted unique session ids; they define the order of positions on the x axis.
gps = np.unique(df['session_id'])
# Animal groups and the colour each one is drawn with.
gp_labels = ['ANT', 'B6', 'NON']
gp_colors = ['r', 'b', 'g']
# Seed the legacy global NumPy RNG so the label shuffles below are reproducible.
np.random.seed(0)
def _single_shuffle(to_plot_shuffle, sesgp, metric, gplbl):
    # to_plot_shuffle['group'] = np.random.permutation(to_plot_shuffle['group'].loc[:].values)
    vals = to_plot_shuffle.loc[to_plot_shuffle['session_id'] == sesgp, 'group'].to_numpy()
    # shuffle the vals
    np.random.shuffle(vals)
    to_plot_shuffle.loc[to_plot_shuffle['session_id'] == sesgp, 'group'] = vals

    use = to_plot_shuffle[to_plot_shuffle['session_id'] == sesgp]
    use = use[use['group'] == gplbl][metric]
    mn = np.mean(use)
    # print(metric, sesgp, mn, gplbl)
    return mn

# Imported at the top of this block so the cell runs on a fresh kernel:
# `gridspec` and `multitest` are both needed before the point where the
# original code (or a later cell) imported them.
import matplotlib.gridspec as gridspec
import pymannkendall as mk  # only referenced by the commented-out Mann-Kendall alternative
from statsmodels.stats import multitest

fig = plt.figure(figsize=(16, 16))
# Outer 2x2 grid: one quadrant per score type.
gs_main = gridspec.GridSpec(2, 2, width_ratios=[1, 1], height_ratios=[1, 1])

metric = 'obj_w'
# metric = 'obj_q'
# y-axis label matching the selected metric (falls back to the column name).
metric_ylabels = {'obj_w': 'EMD (cm)', 'obj_q': 'Quantile'}

n_groups = len(gp_labels)
shuffle_count = 1000
# Pairwise group comparisons in the order their strips are stacked under each
# panel. The keys are the alphabetically sorted group labels joined by '_',
# i.e. they assume gp_labels == ['ANT', 'B6', 'NON'].
pair_order = ['ANT_B6', 'B6_NON', 'ANT_NON']
pair_titles = ['ANT-B6', 'B6-NON', 'ANT-NON']

for i, score in enumerate(scores_to_use):
    row, col = quad_arrange[i]
    # Each quadrant: one tall boxplot panel above three thin significance strips.
    gs_sub = gridspec.GridSpecFromSubplotSpec(4, 1, subplot_spec=gs_main[row, col], height_ratios=[12, 1, 1, 1])
    ax1 = fig.add_subplot(gs_sub[0])

    # every row for that score
    to_plot = df[df['score'] == score]
    to_plot_shuffle = to_plot.copy()

    # # alternative: scores averaged for each session (update the suptitle when switching)
    # to_plot = df[df['score'] == score].groupby(['group', 'name', 'depth', 'date','stim','session_id']).mean().reset_index()
    # to_plot_shuffle = to_plot.copy()

    bps = []
    # mns[j]: per-session means of group j (sessions without data are skipped)
    mns = [[] for j in range(n_groups)]
    # mns_shuffle[j][sc]: the same series for the sc-th label shuffle
    mns_shuffle = [[[] for sc in range(shuffle_count)] for j in range(n_groups)]

    # pair id -> raw Mann-Whitney p-values, and the session index each one
    # belongs to, so results can be drawn in the correct session column.
    session_comps = {}
    session_idx = {}
    for k in range(len(gps)):
        in_session = to_plot[to_plot['session_id'] == gps[k]]
        samples = [in_session.loc[in_session['group'] == gp_labels[j], metric] for j in range(n_groups)]

        for j in range(n_groups):
            bp = ax1.boxplot(samples[j], positions=[k*3+j*.5], widths=0.5,
                        notch=False, patch_artist=True,
                        boxprops=dict(facecolor=gp_colors[j],color='k'),
                        capprops=dict(color='k'),
                        whiskerprops=dict(color='k'),
                        flierprops=dict(color='k', markeredgecolor='k'),
                        medianprops=dict(color='k'),
                        showmeans=False,
                        meanprops=dict(markeredgecolor='k', markerfacecolor='k', markersize=10))
            if k == 0:
                # one legend handle per group
                bps.append(bp['boxes'][0])

            mn = np.mean(samples[j])
            if mn == mn:  # NaN != NaN: skip sessions where the group has no data
                mns[j].append(mn)

            # Two-sided Mann-Whitney U test against every later group, so each
            # unordered pair is tested exactly once per session.
            for j2 in range(j + 1, n_groups):
                if len(samples[j]) > 1 and len(samples[j2]) > 1:
                    _, pvalue = mannwhitneyu(samples[j], samples[j2], alternative='two-sided')
                    pair_id = '_'.join(sorted([gp_labels[j], gp_labels[j2]]))
                    session_comps.setdefault(pair_id, []).append(pvalue)
                    session_idx.setdefault(pair_id, []).append(k)

        # Null distribution: group means after shuffling labels within the session.
        for j in range(n_groups):
            for sc in range(shuffle_count):
                shuffled_mean = _single_shuffle(to_plot_shuffle, gps[k], metric, gp_labels[j])
                if shuffled_mean == shuffled_mean:
                    mns_shuffle[j][sc].append(shuffled_mean)

    # Benjamini-Hochberg correction across all pairwise session tests of this score.
    # .get(..., []) keeps a pair that never had a valid test from raising KeyError.
    pvals_by_pair = [np.array(session_comps.get(pair_id, []), dtype=float) for pair_id in pair_order]
    pvalsANT_B6, pvalsB6_NON, pvalsANT_NON = pvals_by_pair
    pvals = np.concatenate(pvals_by_pair)
    if pvals.size > 0:
        reject, pvals_corrected, alphacSidak, alphacBonf = multitest.multipletests(pvals, alpha=0.05, method='fdr_bh', is_sorted=False, returnsorted=False)
    else:
        pvals_corrected = np.array([])

    # One value per session and pair: 0 = significant after BH (drawn black),
    # 1 = not significant (white), 0.5 = no test possible for that session (grey).
    # Results are written at the session index they were computed for, instead
    # of being appended in order and padded at the end.
    strips = []
    offset = 0
    for pair_id, pair_pvals in zip(pair_order, pvals_by_pair):
        strip = np.full(len(gps), 0.5)
        corrected = pvals_corrected[offset:offset + len(pair_pvals)]
        offset += len(pair_pvals)
        tested_sessions = np.array(session_idx.get(pair_id, []), dtype=int)
        strip[tested_sessions] = np.where(corrected <= 0.05, 0.0, 1.0)
        strips.append(strip)
    bh_ant_b6, bh_b6_non, bh_ant_non = strips

    plt.setp(ax1.get_xticklabels(), visible=False)

    for strip_no, (strip, pair_title) in enumerate(zip(strips, pair_titles), start=1):
        ax = fig.add_subplot(gs_sub[strip_no])
        ax.imshow(np.expand_dims(strip, 0), cmap='Greys_r', aspect='auto', extent=[0, len(gps), 0, 1], vmin=0, vmax=1)
        ax.set_yticks([])
        ax.set_ylabel(pair_title, labelpad=15, rotation=45)
        ax.set_xticks(np.arange(len(gps)) + 0.5)
        ax.set_xticklabels(gps)
        # only the bottom strip shows the session labels
        if strip_no < len(strips):
            plt.setp(ax.get_xticklabels(), visible=False)

    # Permutation test on the slope of the per-session group means.
    lbls = []
    lbl_colors = []
    empirical = []
    ps = []
    print('Metric: ' + score)
    for j in range(n_groups):
        # slope of a straight-line fit through the observed session means
        memp, c = np.polyfit(np.arange(len(mns[j])), mns[j], 1)
        empirical.append(memp)

        shuffled = []
        for sc in range(shuffle_count):
            # Index the nested lists directly: groups can cover different
            # numbers of sessions, which would make np.array(mns_shuffle) ragged.
            ses_dist = mns_shuffle[j][sc]
            mshuffled, c = np.polyfit(np.arange(len(ses_dist)), ses_dist, 1)
            shuffled.append(mshuffled)
            # result = mk.original_test(ses_dist)
            # ps.append(result.p)

        # two-sided p-value: twice the smaller tail fraction of the shuffled slopes
        pgreater = np.sum(np.array(shuffled) < empirical[j]) / len(shuffled)
        plower = np.sum(np.array(shuffled) > empirical[j]) / len(shuffled)
        pvalue = np.min([pgreater, plower]) * 2
        ps.append(pvalue)
        print('Group: ' + gp_labels[j])
        print('Empirical: ' + str(empirical[j]))
        print('Shuffled: ' + str(np.mean(shuffled)) + ' ± ' + str(np.std(shuffled)))

        lbls.append(gp_labels[j] + ' slope: ' + str(np.round(memp, 2)))
        lbl_colors.append('k')

    print(ps)
    # BH correction across the three group slopes; significant ones are drawn in red.
    out = multitest.multipletests(ps, alpha=0.05, method='fdr_bh', is_sorted=False, returnsorted=False)
    for cc, case in enumerate(out[0]):
        if case:
            lbl_colors[cc] = 'r'
    print(out)

    legend = ax1.legend(bps, lbls, loc='upper right')
    # color label text in legend
    for text, color in zip(legend.get_texts(), lbl_colors):
        text.set_color(color)

    ax1.set_title(titles_to_use[i])
    ax1.set_ylabel(metric_ylabels.get(metric, metric))

    # centre each tick under its triplet of boxes (drawn at k*3, k*3+.5, k*3+1)
    ax1.set_xticks(np.arange(len(gps)) * 3 + .5)
    ax1.set_xticklabels(gps)


fig.suptitle('All indiv. cell-session appearances')
# fig.suptitle('Averaged by session')
fig.tight_layout()
plt.show()

In [None]:
# Debug inspection: B6-vs-NON significance strip left over from the last panel
# drawn by the remapping cell (0 = significant after BH, 1 = not significant,
# 0.5 = filler). Assumes the cells were run in order.
bh_b6_non

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Sample data
data = [1, 2, 2, 3, 3, 3, 4, 4, 5, 5, 5, 5, 6, 6, 6]

# Create a figure and axis
fig, ax = plt.subplots()

# Vertical KDE as a stand-in for a boxplot: assigning the sample to `y` puts
# the values on the y-axis and the density on the x-axis. This replaces the
# deprecated `vertical=True` flag of sns.kdeplot.
sns.kdeplot(y=data, ax=ax, color='blue')

# Set labels and title
ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_title('Distribution of Data (KDE)')

# Show the plot
plt.show()


In [None]:
# Debug inspection: length of the ANT-vs-NON significance strip; it is drawn
# with extent [0, len(gps)], i.e. one column per session is intended.
bh_ant_non.shape

In [None]:
# Debug inspection: uncorrected Mann-Whitney p-values for B6 vs NON, one per
# session in which both groups had more than one sample (last score plotted).
session_comps['B6_NON']

In [None]:
# Debug inspection: the ANT-vs-B6 p-values as an array (before BH correction).
pvalsANT_B6

In [None]:
# Debug inspection: ANT-vs-B6 significance strip (0 = significant after BH,
# 1 = not significant, 0.5 = filler) from the last panel drawn.
bh_ant_b6

In [None]:
# Debug inspection: all uncorrected pairwise p-values, concatenated in the
# order ANT_B6, B6_NON, ANT_NON, as passed to the BH correction.
pvals

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

# Layout prototype: a 2x2 outer grid in which every quadrant holds one larger
# panel stacked above three shorter panels that share its x-axis.
fig = plt.figure(figsize=(10, 10))  # Adjust the figure size as needed

# Outer grid; adjust width_ratios and height_ratios as needed
gs_main = gridspec.GridSpec(2, 2, width_ratios=[1, 1], height_ratios=[1, 1])

for row, col in np.ndindex(2, 2):
    # Inner 4x1 grid nested inside the current quadrant; adjust height_ratios as needed
    gs_sub = gridspec.GridSpecFromSubplotSpec(
        4, 1,
        subplot_spec=gs_main[row, col],
        height_ratios=[1, 0.5, 0.5, 0.5],
    )

    # Larger panel on top
    ax1 = fig.add_subplot(gs_sub[0])
    ax1.plot([1, 2, 3], [4, 5, 6])
    ax1.set_title('Larger Subplot')

    # Three lower panels, all sharing the x-axis of the top panel
    ax2, ax3, ax4 = [fig.add_subplot(gs_sub[slot], sharex=ax1) for slot in (1, 2, 3)]
    for lower_ax, y_values in zip((ax2, ax3, ax4), ([2, 1, 3], [0.5, 1.5, 0.7], [3, 2, 1])):
        lower_ax.plot([1, 2, 3], y_values)

    # Only the bottom panel of each quadrant keeps its x tick labels
    for upper_ax in (ax1, ax2, ax3):
        plt.setp(upper_ax.get_xticklabels(), visible=False)

# Adjust spacing between subplots
fig.tight_layout()

# Display the plot
plt.show()


In [None]:
# Debug inspection: result tuple of the last multipletests (BH) call on the
# three group-slope p-values, assuming the cells were run in order.
out

In [None]:
""" Amount of remapping per group """
from statsmodels.stats.weightstats import ttest_ind

# Score types (values of df['score']) and the panel title shown for each,
# kept side by side so the two lists cannot drift apart.
scores_to_use, titles_to_use = (
    list(column) for column in zip(
        ('whole', 'Whole-map'),
        ('spike_density', 'Spike Density'),
        ('field', 'Field'),
        ('binary', 'Binary'),
    )
)
# Sorted unique session identifiers; these define the x-axis positions.
gps = np.unique(df.loc[:, 'session_id'])
# Group labels (values of df['group']) and the box colour used for each.
gp_labels, gp_colors = ['ANT', 'B6', 'NON'], ['r', 'b', 'g']
# Fix the global NumPy RNG so the label shuffles below are reproducible.
np.random.seed(seed=0)
def _single_shuffle(to_plot_shuffle, sesgp, metric, gplbl):
    # to_plot_shuffle['group'] = np.random.permutation(to_plot_shuffle['group'].loc[:].values)
    vals = to_plot_shuffle.loc[to_plot_shuffle['session_id'] == sesgp, 'group'].to_numpy()
    # shuffle the vals
    # for i in range(shuffle_count):
    np.random.shuffle(vals)
    to_plot_shuffle.loc[to_plot_shuffle['session_id'] == sesgp, 'group'] = vals

    use = to_plot_shuffle[to_plot_shuffle['session_id'] == sesgp]
    use = use[use['group'] == gplbl][metric]
    mn = np.mean(use)
    # print(metric, sesgp, mn, gplbl)
    return mn

# Imported once at the top of the block instead of on every loop iteration.
import pymannkendall as mk  # only referenced by the commented-out Mann-Kendall alternative
from statsmodels.stats import multitest

fig = plt.figure(figsize=(25, 20))

# metric = 'obj_w'
metric = 'obj_q'
# y-axis label matching the selected metric (falls back to the column name).
metric_ylabels = {'obj_w': 'Wasserstein distances (cm)', 'obj_q': 'Quantile'}

n_groups = len(gp_labels)
shuffle_count = 1000

for i, score in enumerate(scores_to_use):
    ax = fig.add_subplot(2, 2, i+1)

    # every row for that score
    to_plot = df[df['score'] == score]
    # Alternatives (update the suptitle below when switching):
    # # scores averaged for each animal
    # to_plot = df[df['score'] == score].groupby(['group', 'name']).mean().reset_index()
    # # scores averaged for each session
    # to_plot = df[df['score'] == score].groupby(['group', 'name', 'depth', 'date','stim','session_id']).mean().reset_index()
    # # scores averaged for each neuron
    # to_plot = df[df['score'] == score].groupby(['group', 'name', 'depth', 'date','tetrode', 'unit_id']).mean().reset_index()

    to_plot_shuffle = to_plot.copy()

    bps = []
    # mns[j]: per-session means of group j (sessions without data are skipped)
    mns = [[] for j in range(n_groups)]
    # mns_shuffle[j][sc]: the same series for the sc-th label shuffle
    mns_shuffle = [[[] for sc in range(shuffle_count)] for j in range(n_groups)]
    for k in range(len(gps)):
        in_session = to_plot[to_plot['session_id'] == gps[k]]
        for j in range(n_groups):
            sample = in_session.loc[in_session['group'] == gp_labels[j], metric]
            bp = ax.boxplot(sample, positions=[k*3+j*.5], widths=0.5,
                        notch=False, patch_artist=True,
                        boxprops=dict(facecolor=gp_colors[j],color='k'),
                        capprops=dict(color='k'),
                        whiskerprops=dict(color='k'),
                        flierprops=dict(color='k', markeredgecolor='k'),
                        medianprops=dict(color='k'),
                        showmeans=False,
                        meanprops=dict(markeredgecolor='k', markerfacecolor='k', markersize=10))
            if k == 0:
                # one legend handle per group
                bps.append(bp['boxes'][0])

            mn = np.mean(sample)
            if mn == mn:  # NaN != NaN: skip sessions where the group has no data
                mns[j].append(mn)

        # Null distribution: group means after shuffling labels within the session.
        for j in range(n_groups):
            for sc in range(shuffle_count):
                shuffled_mean = _single_shuffle(to_plot_shuffle, gps[k], metric, gp_labels[j])
                if shuffled_mean == shuffled_mean:
                    mns_shuffle[j][sc].append(shuffled_mean)

    # centre each tick under its triplet of boxes (drawn at k*3, k*3+.5, k*3+1)
    ax.set_xticks(np.arange(len(gps)) * 3 + .5)
    ax.set_xticklabels(gps)

    # Permutation test on the slope of the per-session group means.
    lbls = []
    empirical = []
    ps = []
    print('Metric: ' + score)
    for j in range(n_groups):
        # slope of a straight-line fit through the observed session means
        memp, c = np.polyfit(np.arange(len(mns[j])), mns[j], 1)
        empirical.append(memp)

        shuffled = []
        for sc in range(shuffle_count):
            # Index the nested lists directly: groups can cover different
            # numbers of sessions, which would make np.array(mns_shuffle) ragged.
            ses_dist = mns_shuffle[j][sc]
            mshuffled, c = np.polyfit(np.arange(len(ses_dist)), ses_dist, 1)
            shuffled.append(mshuffled)
            # result = mk.original_test(ses_dist)
            # ps.append(result.p)

        # two-sided p-value: twice the smaller tail fraction of the shuffled slopes
        pgreater = np.sum(np.array(shuffled) < empirical[j]) / len(shuffled)
        plower = np.sum(np.array(shuffled) > empirical[j]) / len(shuffled)
        pvalue = np.min([pgreater, plower]) * 2
        ps.append(pvalue)

        # Always bind `tag`, so a tie (or NaN slope) neither raises NameError
        # nor silently reuses the previous group's tag.
        shuffled_mean_slope = np.mean(shuffled)
        if empirical[j] > shuffled_mean_slope:
            tag = 'greater'
        elif empirical[j] < shuffled_mean_slope:
            tag = 'lower'
        else:
            tag = 'no different'
        lbl = 'Slope is ' + str(tag) + ' than shuffled: ' + str(np.round(memp, 2)) + ' , p = ' + str(np.round(pvalue, 3))
        lbls.append(lbl)

    print(ps)
    # BH correction across the three group slopes; the verdict is appended to each label.
    out = multitest.multipletests(ps, alpha=0.05, method='fdr_bh', is_sorted=False, returnsorted=False)
    for cc, case in enumerate(out[0]):
        if case:
            lbls[cc] = lbls[cc] + ' & is sig after BH correction'
        else:
            lbls[cc] = lbls[cc] + ' & is NOT sig after BH correction'
    print(out)

    ax.legend(bps, lbls, loc='upper right')
    ax.set_xlabel('Session')
    ax.set_title(titles_to_use[i])
    ax.set_ylabel(metric_ylabels.get(metric, metric))

# The data above is every row for the score, not a per-session average.
fig.suptitle('All individual cell-session appearances')
fig.tight_layout()
plt.show()

In [None]:
""" Amount of remapping per group per angle """
from statsmodels.stats.weightstats import ttest_ind

# One panel per object angle, all for the same score type.
# scores_to_use = ['whole', 'spike_density', 'field', 'binary']
# titles_to_use = ['Whole-map', 'Spike Density', 'Field', 'Binary']
obj_angles = [0, 90, 180, 270]
scores_to_use = ['spike_density' for i in range(4)]
titles_to_use = ['Spike density' for i in range(4)]
gps = np.unique(df['session_id'])
gp_labels = ['ANT', 'B6', 'NON']
gp_colors = ['r', 'b', 'g']

fig = plt.figure(figsize=(25, 15))

metric = 'obj_w'
# metric = 'obj_q'
# y-axis label and slope unit matching the selected metric
metric_ylabels = {'obj_w': 'Wasserstein distances (cm)', 'obj_q': 'Quantile'}
slope_units = {'obj_w': ' cm/session', 'obj_q': ' /session'}

for i, score in enumerate(scores_to_use):
    ax = fig.add_subplot(2, 2, i+1)
    # keep only rows whose selected angle ('obj_a') is this panel's angle
    df_angle = df[df['obj_a'] == obj_angles[i]]
    # every row for that score
    to_plot = df_angle[df_angle['score'] == score]
    # # alternative: scores averaged for each session
    # to_plot = df_angle[df_angle['score'] == score].groupby(['group', 'name', 'depth', 'date','stim','session_id']).mean().reset_index()

    bps = []
    # mns[j]: per-session means of group j (sessions without data are skipped)
    mns = [[] for j in range(3)]
    for k in range(len(gps)):
        in_session = to_plot[to_plot['session_id'] == gps[k]]
        for j in range(3):
            sample = in_session.loc[in_session['group'] == gp_labels[j], metric]
            bp = ax.boxplot(sample, positions=[k*3+j*.5], widths=0.5,
                        notch=False, patch_artist=True,
                        boxprops=dict(facecolor=gp_colors[j],color='k'),
                        capprops=dict(color='k'),
                        whiskerprops=dict(color='k'),
                        flierprops=dict(color='k', markeredgecolor='k'),
                        medianprops=dict(color='k'),
                        showmeans=False,
                        meanprops=dict(markeredgecolor='k', markerfacecolor='k', markersize=10))
            if k == 0:
                # one legend handle per group
                bps.append(bp['boxes'][0])

            mn = np.mean(sample)
            if mn == mn:  # NaN != NaN: skip sessions where the group has no data
                mns[j].append(mn)

    # centre each tick under its triplet of boxes (drawn at k*3, k*3+.5, k*3+1)
    ax.set_xticks(np.arange(len(gps)) * 3 + .5)
    ax.set_xticklabels(gps)

    lbls = []
    for j in range(len(mns)):
        # Filtering by angle can leave a group with fewer than two session
        # means, for which a line fit is undefined (np.polyfit would raise
        # on an empty series).
        if len(mns[j]) < 2:
            lbls.append('Mean slope: n/a (fewer than 2 sessions)')
            continue
        # slope of a straight-line fit through the session means
        m, c = np.polyfit(np.arange(len(mns[j])), mns[j], 1)
        lbl = 'Mean slope: ' + str(np.round(m, 2)) + slope_units.get(metric, ' /session')
        lbls.append(lbl)
    ax.legend(bps, lbls, loc='upper right')
    ax.set_xlabel('Session')
    ax.set_title(titles_to_use[i] + ' ' + str(obj_angles[i]) + '°')
    ax.set_ylabel(metric_ylabels.get(metric, metric))

fig.suptitle('All individual cell-session appearances')
fig.tight_layout()
plt.show()