In [None]:
# Enable IPython's autoreload extension so edits made to the repo's modules are
# picked up by this kernel without having to restart the notebook.
%load_ext autoreload
%autoreload 2

In [None]:
import os, sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import pearsonr, mannwhitneyu, wilcoxon, ttest_rel, ttest_ind
import seaborn as sns
from matplotlib.patches import FancyArrowPatch
from matplotlib.colors import ColorConverter


# Put the notebook's current working directory on sys.path so that modules
# located there can be imported by later cells.
PROJECT_PATH = os.getcwd()
sys.path.append(PROJECT_PATH)

In [None]:
# Load the score table into `df`. The commented-out lines are alternative input
# workbooks; only one read_excel call is meant to be active at a time.
# NOTE(review): these are absolute, user-specific Windows paths, so the notebook
# only runs unchanged on the original author's machine -- consider a configurable
# data directory.
# df = pd.read_excel(r'C:\Users\aaoun\OneDrive - cumc.columbia.edu\Desktop\HussainiLab\neuroscikit\LEC_full_merged_scores.xlsx')
df = pd.read_excel(r'C:\Users\aaoun\OneDrive - cumc.columbia.edu\Desktop\HussainiLab\neuroscikit\df_full_LEC.xlsx')
# df = pd.read_excel(r'C:\Users\aaoun\OneDrive - cumc.columbia.edu\Desktop\HussainiLab\neuroscikit\df_full_LEC_only_trace_cells.xlsx')
# Show the available column names as the cell output.
df.columns

In [None]:
""" QUALITY CHECK DATA """ 
# Rows without an `obj_q_0` value are reported (which animals / dates they come
# from) and then removed from `df`.
nan_mask = df['obj_q_0'].isna()

# Positional indices, kept under their original names for any cell that reads them.
nan_idx = np.flatnonzero(nan_mask.to_numpy())
not_nan_idx = np.flatnonzero(~nan_mask.to_numpy())

# Select with the boolean mask rather than feeding positional indices to the
# label-based `df['date'][...]` lookup, which is only correct while the index
# happens to be the default 0..n-1 range.
nan_dates = df.loc[nan_mask, 'date'].unique()
nan_names = df.loc[nan_mask, 'name'].unique()

print('Number of NaN rows: ' + str(len(nan_idx)))
print('Animals with NaN rows: ' + str(nan_names))
print('Dates with NaN rows: ' + str(nan_dates))

# remove rows with NaN values
print('Removing nan rows')
# .copy() gives an independent frame, so the column assignments made in later
# cells do not trigger SettingWithCopyWarning on a slice of the loaded table.
df = df.iloc[not_nan_idx].copy()

In [None]:
""" FILTERING """

""" REMOVE FIELD WITH LOW COVERAGE % """
# remove rows with field_coverage < 0.1
# df = df[df['field_coverage'] >= 0.1]

""" ONLY KEEPING MAIN FIELD """
# remove rows where field_id is not 1 and score is not 'whole' or 'spike_density'
# df = df[(df['field_id'] == 1) | (df['score'] == 'whole') | (df['score'] == 'spike_density')]

""" CHOOSING ANGLE FOR EACH ROW """
# `obj_q` is currently the no-object quantile; enable the commented line instead
# to use the lowest quantile across the four object angles.
# df['obj_q'] = df[['obj_q_0', 'obj_q_90', 'obj_q_180', 'obj_q_270']].min(axis=1)
df['obj_q'] = df['obj_q_NO']
# Name of the angle column holding the lowest quantile (e.g. 'obj_q_90') ...
df['obj_a'] = df[['obj_q_0', 'obj_q_90', 'obj_q_180', 'obj_q_270']].idxmin(axis=1)
# ... converted to the angle in degrees (the part after the second underscore).
df['obj_a'] = df['obj_a'].apply(lambda x: int(x.split('_')[2]))
# use obj_wass with angle of min quantile
df['obj_w'] = df.apply(lambda x: x['obj_wass_' + str(x['obj_a'])], axis=1)

""" ASSESSING SIG FOR EACH ROW AT EACH ANGLE """
# obj_s_rows = ['obj_s_0', 'obj_s_90', 'obj_s_180', 'obj_s_270']
# obj_q_rows = ['obj_q_0', 'obj_q_90', 'obj_q_180', 'obj_q_270']
# for i in range(len(obj_s_rows)):
#     obj_q_x = obj_q_rows[i]
#     df[obj_s_rows[i]] = df[obj_q_x].apply(lambda x: 1 if x < quantile_threshold else 0)


# df2 = df[df['score'] == 'whole'].copy()
df2 = df.copy()
# group_by_cell = ['group', 'name', 'depth', 'date','tetrode', 'unit_id']
# df2 = df2.groupby(group_by_cell).mean().reset_index()

# Report, for each candidate quality threshold, how many rows it would remove and
# how they split across groups. Nothing is filtered here; this is diagnostic only.
# Each entry: (label used in the message, column, 'gt' or 'lt', threshold).
drop_checks = [
    ('Spike count upper', 'spike_count', 'gt', 30000),
    ('Spike count lower', 'spike_count', 'lt', 100),
    ('Spatial info', 'information', 'lt', 0.25),
    ('Selectivity', 'selectivity', 'lt', 5),
    ('Isolation distance', 'iso_dist', 'lt', 5),
    ('Firing rate', 'firing_rate', 'gt', 20),
    ('Spike width', 'spike_width', 'lt', 0.00005),
]
for check_label, check_column, direction, threshold in drop_checks:
    if direction == 'gt':
        would_drop = df2[df2[check_column] > threshold]
    else:
        would_drop = df2[df2[check_column] < threshold]
    cts = would_drop['group'].value_counts()
    # .get(..., 0) covers groups that have no row beyond the threshold.
    print('{} of {} would drop {} cells including {} ANT, {} B6 and {} NON'.format(
        check_label, threshold, str(len(would_drop)),
        cts.get('ANT', 0), cts.get('B6', 0), cts.get('NON', 0)))

# Replace the first spike_count column with 'spike_count.1'. The guard keeps the
# cell re-runnable: once the swap has happened there is no 'spike_count.1' left,
# and dropping 'spike_count' again would throw away the renamed column.
if 'spike_count.1' in df2.columns:
    df2 = df2.drop(columns=['spike_count'])
    df2 = df2.rename(columns={'spike_count.1': 'spike_count'})

# df_unfiltered = df2.copy()
# df_unfiltered.to_excel(r'C:\Users\aaoun\OneDrive - cumc.columbia.edu\Desktop\HussainiLab\neuroscikit\df_full_LEC_unfiltered.xlsx')

# Active filters. Only the isolation-distance criterion is applied; the others are
# kept commented out as optional criteria.
# NOTE(review): this keeps iso_dist > 5 while the report above counts iso_dist < 5,
# so rows with iso_dist exactly 5 or NaN are removed without being counted there.
# df2 = df2[df2['spike_count'] < 30000]
# df2 = df2[df2['spike_count'] > 100]
# df2 = df2[df2['information'] > 0.25]
# df2 = df2[df2['selectivity'] > 5]
df2 = df2[df2['iso_dist'] > 5]
# df2 = df2[df2['firing_rate'] < 80]
# df2 = df2[df2['spike_width'] > 0.00005]

print('Remaining cells: ' + str(len(df2)) + ' of which ' + str(len(df2[df2['group'] == 'ANT'])) + ' ANT, ' + str(len(df2[df2['group'] == 'B6'])) + ' B6 and ' + str(len(df2[df2['group'] == 'NON'])) + ' NON')
# Later cells read the filtered table through both `df` and `df2`.
df = df2.copy()

In [None]:
# Write the filtered table (`df2`) to an Excel workbook.
# NOTE(review): absolute, user-specific output path; the file is overwritten on
# every run of this cell.
df2.to_excel(r'C:\Users\aaoun\OneDrive - cumc.columbia.edu\Desktop\HussainiLab\neuroscikit\df_full_LEC_filtered.xlsx')

In [None]:
# Grid of `obj_q` histograms: one row per score, one column per group, with one
# hue per animal (`name`). Panels in the same row share their x-range.
fig = plt.figure(figsize=(16, 12))

j = 1  # 1-based subplot index, advanced once per panel
for scr in ['whole','spike_density','field','binary']:
    mxs = []
    axs = []
    for i in ['B6','NON','ANT']:
        ax = fig.add_subplot(4, 3, j)

        df_to_use = df2[df2['group'] == i]
        scored = df_to_use[df_to_use['score'] == scr]
        # plt.hist(df_to_use[df_to_use['score'] == 'whole']['obj_w'], bins=100)

        # if i == 'B6':
        #     desired_order = ['B6-LEC1', 'B6-1M', 'B6-2M', 'B6-LEC2']
        #     df_to_use.loc[:, 'name'] = pd.Categorical(df_to_use['name'], categories=desired_order, ordered=True)
        # elif i == 'NON':
        #     desired_order = ['NON-INT-01', 'NON-88-1', 'NON-73-6', 'NON-INT-02', 'NON-INT-03']
        #     df_to_use.loc[:, 'name'] = pd.Categorical(df_to_use['name'], categories=desired_order, ordered=True)

        sns.histplot(data=scored, x='obj_q', bins=50, hue='name', kde=False, ax=ax, stat='density', common_norm=False)
        ax.set_title(i)
        ax.set_xlabel(str(scr) + ' obj_q')
        mxs.append(scored['obj_q'].max())
        axs.append(ax)

        j += 1

    # An empty group/score subset has a NaN maximum. NaN makes the built-in max()
    # order-dependent and set_xlim() rejects it, so only finite maxima are used.
    finite_mxs = [m for m in mxs if np.isfinite(m)]
    if finite_mxs:
        for ax in axs:
            ax.set_xlim(0, max(finite_mxs))

fig.suptitle('Animal quantile distributions for each score', x=0.5, fontweight='bold')
fig.tight_layout()
plt.show()

In [None]:
obj_order = ['0', '90', '180', '270', 'NO']
grp_order = ['B6', 'NON', 'ANT']


def _with_ordered_category(frame, column, order):
    """Return a copy of `frame` in which `column` is an ordered categorical following `order`.

    A new frame is built with `assign` instead of writing through
    `frame.loc[:, column] = pd.Categorical(...)`: that form writes into a slice of
    `df2` and, on recent pandas, sets the values in place without changing the
    column dtype, so the requested ordering would not reach the plot.
    """
    return frame.assign(**{column: pd.Categorical(frame[column], categories=order, ordered=True)})


# ---- Figure 1: object locations per group, normalised over all groups together.
fig, ax = plt.subplots(figsize=(8, 5))
dfuse = df2[df2['score'] == 'whole']
dfuse = _with_ordered_category(dfuse, 'object_location', obj_order)
dfuse = _with_ordered_category(dfuse, 'group', grp_order)
sns.histplot(data=dfuse, x='group', multiple='dodge', shrink=0.8, stat='density', common_norm=True, hue='object_location', ax=ax)
ttle = '% of sessions for each object location across all groups'
fig.suptitle(ttle, x=0.5, fontweight='bold')
fig.tight_layout()
plt.show()

# ---- Figure 2: same plot, but each group is drawn (and normalised) separately on
# one shared set of axes.
fig, ax = plt.subplots(figsize=(8, 5))
for grp in ['B6', 'NON', 'ANT']:
    dfuse = df2[df2['score'] == 'whole']
    dfuse = dfuse[dfuse['group'] == grp]
    dfuse = _with_ordered_category(dfuse, 'object_location', obj_order)
    sns.histplot(data=dfuse, x='group', multiple='dodge', shrink=0.8, stat='density', common_norm=True, hue='object_location', ax=ax)
ttle = '% of sessions for each object location within a group'
# Palettes kept as notebook-level variables; `clr_palette_store` is read by the
# per-session bar-plot cell that follows.
clr_palette_store = sns.color_palette()
clr_palette_settings = sns.color_palette('Set2')
fig.suptitle(ttle, x=0.5, fontweight='bold')
fig.tight_layout()
plt.show()

# ---- Figures 3-5: one figure per group with one panel per session.
for grp in ['B6', 'NON', 'ANT']:
    fig = plt.figure(figsize=(14, 2))
    for ses in ['session_1', 'session_2', 'session_3', 'session_4', 'session_5', 'session_6', 'session_7']:
        ax = fig.add_subplot(1, 7, int(ses.split('_')[1]))
        if grp == 'B6' and ses == 'session_7':
            # This combination is deliberately not plotted; its panel stays empty.
            continue

        dfuse = df2[df2['score'] == 'whole']
        dfuse = dfuse[dfuse['group'] == grp]
        dfuse = dfuse[dfuse['session_id'] == ses]
        dfuse = _with_ordered_category(dfuse, 'object_location', obj_order)
        sns.histplot(data=dfuse, x='group', multiple='dodge', shrink=0.8, stat='density', common_norm=True, hue='object_location', ax=ax)
        ax.set_title(ses)
        # turn off legend (there is none when nothing was drawn for this panel)
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
        # turn off x label
        ax.set_xlabel('')
        ax.set_ylim(0, 1)

    ttle = '% of object location within that session for {}'.format(grp)
    fig.suptitle(ttle, x=0.5, fontweight='bold')
    fig.tight_layout()
    plt.show()



In [None]:
# Define the order of object locations
obj_order = ['0', '90', '180', '270', 'NO']
# Sessions shown, one panel each (session_7 is not part of this figure).
sessions_to_plot = ['session_1', 'session_2', 'session_3', 'session_4', 'session_5', 'session_6']

# One figure per group. Bar heights are the share of that group's 'whole'-score
# rows (summed over ALL sessions) that fall in the given session and object location.
for grp in ['B6', 'NON', 'ANT']:
    # Filter the DataFrame df2 for the current group
    dfuse = df2[(df2['score'] == 'whole') & (df2['group'] == grp)]
    total_count = len(dfuse)

    fig, axes = plt.subplots(1, len(sessions_to_plot), figsize=(14, 2))

    for i, ses in enumerate(sessions_to_plot):
        # Filter the DataFrame for the current session
        df_session = dfuse[dfuse['session_id'] == ses]

        # Calculate the proportions of each object location for the current session.
        # A group with no rows gets all-zero bars instead of a ZeroDivisionError.
        obj_loc_counts = [len(df_session[df_session['object_location'] == obj_loc]) for obj_loc in obj_order]
        obj_loc_proportions = [count / total_count if total_count else 0.0 for count in obj_loc_counts]

        # Create a bar plot for the current session.
        # `clr_palette_store` is defined in the preceding plotting cell.
        bcolors = clr_palette_store
        axes[i].bar(obj_order, obj_loc_proportions, color=bcolors, edgecolor='black', linewidth=1.2, alpha=0.8, width=1)
        axes[i].set_title(ses)
        axes[i].set_xlabel('')
        axes[i].set_ylim(0, 0.25)
    # Set a common ylabel for the group
    axes[0].set_ylabel('Density')

    # Set the overall title for the group
    fig.suptitle(f'% of object location for {grp} across all sessions', x=0.5, fontweight='bold')

    # Adjust the layout and display the plot
    fig.tight_layout()
    plt.show()


In [None]:
# https://stackoverflow.com/questions/11517986/indicating-the-statistically-significant-difference-in-bar-graph

def barplot_annotate_brackets(num1, num2, data, center, height, yerr=None, dh=.05, barh=.05, fs=None, maxasterix=None, ax=None):
    """
    Draw a significance bracket with a label above two bars.

    :param num1: number of left bar to put bracket over
    :param num2: number of right bar to put bracket over
    :param data: string to write as-is, or a p-value that is turned into asterisks
        (``*`` p <= 0.05, ``**`` p <= 0.01, ``***`` p <= 0.001, ``****`` p <= 0.0001;
        anything else, including NaN, is written as ``n. s.``)
    :param center: centers of all bars (like plt.bar() input)
    :param height: heights of all bars (like plt.bar() input)
    :param yerr: yerrs of all bars (like plt.bar() input); list or array, optional
    :param dh: height offset over bar / bar + yerr in axes coordinates (0 to 1)
    :param barh: bar height in axes coordinates (0 to 1)
    :param fs: font size; note that any label other than ``n. s.`` is always drawn
        bold at size 25, which overrides this value
    :param maxasterix: maximum number of asterixes to write (for very small p-values);
        values below 1 are treated as 1 so a significant result is never blanked
    :param ax: axes to draw on; defaults to the current pyplot axes
    """
    if ax is None:
        ax = plt.gca()

    # isinstance also accepts str subclasses such as numpy.str_.
    if isinstance(data, str):
        text = data
    else:
        # Count how many thresholds the p-value passes; NaN passes none.
        stars = sum(1 for threshold in (0.05, 0.01, 0.001, 0.0001) if data <= threshold)
        if stars and maxasterix is not None:
            stars = min(stars, max(int(maxasterix), 1))
        text = '*' * stars if stars else 'n. s.'

    lx, ly = center[num1], height[num1]
    rx, ry = center[num2], height[num2]

    # Explicit None/length check: `if yerr:` raises for NumPy arrays.
    if yerr is not None and len(yerr) > 0:
        ly += yerr[num1]
        ry += yerr[num2]

    # Convert the offsets from axes-fraction units to data units.
    ax_y0, ax_y1 = ax.get_ylim()
    dh *= (ax_y1 - ax_y0)
    barh *= (ax_y1 - ax_y0)

    # The bracket starts `dh` above the taller of the two bars.
    y = max(ly, ry) + dh

    barx = [lx, lx, rx, rx]
    bary = [y, y+barh, y+barh, y]
    mid = ((lx+rx)/2, y+barh)

    ax.plot(barx, bary, c='black')

    kwargs = dict(ha='center', va='bottom')
    if fs is not None:
        kwargs['fontsize'] = fs
    if text != 'n. s.':
        kwargs['weight'] = 'bold'
        kwargs['fontsize'] = 25

    ax.text(*mid, text, **kwargs)

In [None]:
# Directory holding the per-comparison result tables.
# NOTE(review): absolute, user-specific Windows path -- consider a configurable
# data directory.
beta_results_dir = r"C:\Users\aaoun\OneDrive - cumc.columbia.edu\Desktop\HussainiLab\neuroscikit"

# p-values keyed first by score, then by comparison pair ('ANT_B6', 'ANT_NON', 'B6_NON').
score_beta_ps = {'whole': {}, 'spike_density': {}, 'field': {}, 'binary': {}}

# File names follow '<A>vs<B>_quantile_indiv_beta_<score>.csv'. They are built from
# the pair and the score instead of being parsed back out of the file name, because
# splitting on '_' cannot recover a score that itself contains an underscore
# ('spike_density').
beta_scores = list(score_beta_ps)
paths = []
for comp_group in ['B6vsNON', 'B6vsANT', 'NONvsANT']:
    a1, a2 = comp_group.split('vs')
    # Pair key is the two group names in alphabetical order, joined by '_'.
    a12 = sorted([a1, a2])
    cgroup = a12[0] + '_' + a12[1]

    for score in beta_scores:
        fname = '{}_quantile_indiv_beta_{}.csv'.format(comp_group, score)
        pth = beta_results_dir + '\\' + fname
        paths.append(pth)

        data = pd.read_csv(pth)
        # Second row of the 'Pvalue' column.
        # NOTE(review): presumably the row of the group term -- confirm against the
        # script that writes these tables.
        score_beta_ps[score][cgroup] = data['Pvalue'].iloc[1]

In [None]:
# Display the collected p-values (keyed by score, then by comparison pair) as the cell output.
score_beta_ps

In [None]:
""" Amount of remapping per group """
from statsmodels.stats.weightstats import ttest_ind
from statsmodels.stats import multitest
from statsmodels.stats.multitest import multipletests
from matplotlib import gridspec
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy import interpolate
from scipy.interpolate import griddata


# Scores drawn by the figure loop below, in plotting order.
scores_to_use = ['whole', 'spike_density', 'field', 'binary']
# 'field', 'binary', 'centroid', 'firing_rate']
# (row, col) cell of the main GridSpec for the i-th score, and its panel title.
quad_arrange = [[0,0],[0,1],[1,0],[1,1], [2,0], [2,1]]
titles_to_use = ['Whole-map', 'Spike Density', 'Field', 'Binary', 'Centroid', 'Firing Rate']
# Distinct session ids present in df, in the sorted order np.unique returns.
# NOTE(review): for string ids that order is lexicographic, so a 'session_10'
# would sort before 'session_2' -- confirm the ids never reach two digits.
gps = np.unique(df['session_id'])
# Group names and the box colour used for each, index-aligned.
gp_labels = ['B6', 'NON', 'ANT']
gp_colors = ['b', 'g', 'r']

# Seed NumPy's global RNG, which _single_shuffle draws from, so reruns of this
# cell reproduce the same label shuffles.
np.random.seed(0)

def _single_shuffle(to_plot_shuffle, sesgp, metric, gplbl):
    vals = to_plot_shuffle.loc[to_plot_shuffle['session_id'] == sesgp, 'group'].to_numpy()
    np.random.shuffle(vals)
    to_plot_shuffle.loc[to_plot_shuffle['session_id'] == sesgp, 'group'] = vals

    use = to_plot_shuffle[to_plot_shuffle['session_id'] == sesgp]
    use = use[use['group'] == gplbl][metric]
    mn = np.mean(use)
    return mn

# Imported once here instead of on every pass through the score loop.
import pymannkendall as mk
from statsmodels.stats import multitest


def _significance_label(pval):
    """Return the marker appended to a legend entry for a (corrected) p-value.

    '*' for p <= 0.05, '**' for p <= 0.01, '***' for p <= 0.001, '****' for
    p <= 0.0001, and 'n.s' otherwise -- including NaN, which fails every comparison.
    """
    label = 'n.s'
    for threshold, stars in ((0.05, '*'), (0.01, '**'), (0.001, '***'), (0.0001, '****')):
        if pval <= threshold:
            label = stars
    return label


def _draw_session_mean_strip(ax, session_means, group_label, n_points, session_labels):
    """Draw one group's per-session means as a horizontal colour strip with a colorbar.

    The means are linearly interpolated onto `n_points` samples and shown with
    imshow across the x-range [0, len(session_labels)]. Nothing is drawn when
    `session_means` is empty (np.interp cannot interpolate from zero points).
    NOTE(review): a group with fewer sessions than `session_labels` is stretched
    over the full width, so its colours do not line up with the session positions.
    """
    n_sessions = len(session_labels)
    if len(session_means) > 0:
        sample_at = np.linspace(0, len(session_means) - 1, n_points)
        smoothed = np.interp(sample_at, np.arange(len(session_means)), session_means)
        im = ax.imshow(np.expand_dims(smoothed, 0), cmap='jet', aspect='auto', extent=[0, n_sessions, 0, 1])
        # colorbar in its own narrow axes to the right of the strip
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        plt.colorbar(im, cax=cax)
    ax.set_ylabel(group_label, labelpad=15, rotation=0)
    ax.set_yticks([])
    ax.set_xticks(np.arange(n_sessions) + 0.5)
    ax.set_xticklabels(session_labels)
    plt.setp(ax.get_xticklabels(), visible=False)
    for side in ('bottom', 'top', 'right'):
        ax.spines[side].set_visible(False)
    ax.tick_params(axis='x', which='both', bottom=False, top=False)


fig = plt.figure(figsize=(23, 46))
gs_main = gridspec.GridSpec(3, 2, width_ratios=[1,1], height_ratios=[1,1, 1])  # Adjust width_ratios and height_ratios as needed


# metric = 'obj_w'
metric = 'obj_q'
# Number of label shuffles per session and group; also used as the number of
# interpolation samples of the colour strips.
shuffle_count = 1000

# Box styling shared by the group panel and the per-session panel.
box_style = dict(notch=False, patch_artist=True,
                 capprops=dict(color='k'),
                 whiskerprops=dict(color='k'),
                 flierprops=dict(color='k', markeredgecolor='k'),
                 medianprops=dict(color='k'),
                 meanprops=dict(markeredgecolor='k', markerfacecolor='k', markersize=10))

for i, score in enumerate(scores_to_use):
    # 'firing_rate' is plotted from the 'whole' rows using the firing_rate column.
    # The column is held in a per-iteration variable so that it cannot leak into
    # the scores processed afterwards.
    is_firing_rate = score == 'firing_rate'
    mtouse = 'firing_rate' if is_firing_rate else metric
    to_plot = df[df['score'] == ('whole' if is_firing_rate else score)]
    # .groupby(['group', 'name', 'depth', 'date','stim','session_id']).mean().reset_index()

    row, col = quad_arrange[i]
    gs_sub = gridspec.GridSpecFromSubplotSpec(5, 1, subplot_spec=gs_main[row, col], height_ratios=[12,12,1,1,1], hspace=0.1)

    # ------------------------------------------------------------------
    # Panel 1: one box per group, with brackets for the pairwise p-values.
    # ------------------------------------------------------------------
    ax = plt.subplot(gs_sub[0])
    axf = ax
    bps = []
    lbls = []
    means = []
    sems = []
    n = []
    comps = score_beta_ps[score]
    comp_maxs = []
    for k in range(3):
        group_vals = to_plot[to_plot['group'] == gp_labels[k]][mtouse]
        bp = ax.boxplot(group_vals, positions=[k], widths=0.5,
                        boxprops=dict(facecolor=gp_colors[k], color='k'),
                        showmeans=True, **box_style)
        comp_maxs.append(np.max(group_vals))

        # Any pair without a loaded p-value gets a NaN placeholder so that it is
        # still annotated (as 'n. s.') below.
        for k2 in range(3):
            if k2 != k:
                lbl_pair = sorted([gp_labels[k], gp_labels[k2]])
                comp_group = lbl_pair[0] + '_' + lbl_pair[1]
                if comp_group not in comps.keys():
                    comps[comp_group] = np.nan

        means.append(np.mean(group_vals))
        sems.append(np.std(group_vals) / np.sqrt(len(group_vals)))
        n.append(len(group_vals))

        bps.append(bp['boxes'][0])
        # lbls.append(str(means[k]) + ' ± ' + str(sems[k]) + ' cm, N = ' + str(n[k]))
        lbls.append(gp_labels[k])
        #  + ': N = ' + str(n[k]))

    ax.set_xticklabels(gp_labels)
    ax.legend(bps, lbls, loc='upper right')
    ax.set_xlabel('Group')
    ax.set_ylabel('EMD quantile')

    # benjamini hochberg correction is currently switched off: the brackets show
    # the p-values exactly as loaded.
    # accepted, pvals_corrected, _, _ = multipletests(vals, alpha=0.05, method='fdr_bh', is_sorted=False, returnsorted=False)
    for comp_key, pval in comps.items():
        # Box positions of the two groups named in the key, left one first.
        nme = sorted(gp_labels.index(group_name) for group_name in comp_key.split('_'))
        # Panel 1 is still the current pyplot axes here, which is what the helper draws on.
        barplot_annotate_brackets(nme[0],nme[1],pval,[0,1,2], comp_maxs, maxasterix=5)

    # ------------------------------------------------------------------
    # Panel 2: one box per group and session, plus the label shuffles.
    # ------------------------------------------------------------------
    ax = plt.subplot(gs_sub[1])
    ax1 = ax

    # Working copy whose group labels are permuted in place by _single_shuffle.
    to_plot_shuffle = to_plot.copy()

    bps = []
    lbls = []
    mns = [[] for j in range(3)]
    mns_shuffle = [[[] for sc in range(shuffle_count)] for j in range(3)]

    group_ses_frs = {'ANT': [], 'B6': [], 'NON': []}
    for k in range(len(gps)):
        for j in range(3):
            # values of this group in this session
            to_plot_now = to_plot[to_plot['group'] == gp_labels[j]]
            ses_fr = to_plot_now[to_plot_now['session_id'] == gps[k]][mtouse]
            # ses_fr = to_plot_now[to_plot_now['session_id'] == gps[k]]['information']
            bp = ax1.boxplot(ses_fr, positions=[k*3+j*.5], widths=0.5,
                             boxprops=dict(facecolor=gp_colors[j],color='k'),
                             showmeans=False, **box_style)
            if k == 0:
                bps.append(bp['boxes'][0])

            # Sessions in which the group has no (non-NaN) mean are skipped, so a
            # group's list can be shorter than the number of sessions.
            mn = np.mean(ses_fr)
            if len(ses_fr) > 0 and mn == mn:
                group_ses_frs[gp_labels[j]].append(mn)
            if mn == mn:
                mns[j].append(mn)

        # Same call order as before (group, then shuffle index), so a seeded run
        # consumes the random stream identically.
        for j in range(3):
            for sc in range(shuffle_count):
                out = _single_shuffle(to_plot_shuffle, gps[k], mtouse, gp_labels[j])
                if out == out:
                    mns_shuffle[j][sc].append(out)

    # plt.setp(ax1.get_xticklabels(), visible=False)

    # ------------------------------------------------------------------
    # Panels 3-5: per-session mean of each group as a colour strip.
    # Drawn in the original order (ANT in the last slot, then B6, then NON).
    # ------------------------------------------------------------------
    ax = plt.subplot(gs_sub[4])
    _draw_session_mean_strip(ax, group_ses_frs['ANT'], 'ANT', shuffle_count, gps)

    ax = plt.subplot(gs_sub[2])
    _draw_session_mean_strip(ax, group_ses_frs['B6'], 'B6', shuffle_count, gps)

    ax = plt.subplot(gs_sub[3])
    _draw_session_mean_strip(ax, group_ses_frs['NON'], 'NON', shuffle_count, gps)

    # ------------------------------------------------------------------
    # mann kendall trend test per group, compared against the shuffles.
    # ------------------------------------------------------------------
    lbls = []
    empirical = []
    slopes = []
    ps = []
    # The nested lists are indexed directly rather than converted with np.array:
    # groups can differ in their number of sessions, which makes the structure
    # ragged, and ragged input is rejected by np.array on recent NumPy.
    print((len(mns_shuffle), shuffle_count))
    print('Metric: ' + score)
    for j in range(3):
        # slp, c = np.polyfit(np.arange(len(mns[j])), mns[j], 1)
        memp = mk.original_test(mns[j])
        empirical.append(memp.z)
        slopes.append(memp.slope)

        shuffled = []
        for sc in range(shuffle_count):
            ses_dist = mns_shuffle[j][sc]
            # mshuffled, c = np.polyfit(np.arange(len(ses_dist)), ses_dist, 1)
            mshuffled = mk.original_test(ses_dist)
            shuffled.append(mshuffled.z)

        # two sided p-value: twice the smaller tail fraction of the shuffled z
        # scores lying strictly below / strictly above the empirical z score
        frac_below = np.sum(np.array(shuffled) < empirical[j]) / len(shuffled)
        frac_above = np.sum(np.array(shuffled) > empirical[j]) / len(shuffled)
        pvalue = np.min([frac_below, frac_above]) * 2
        ps.append(pvalue)
        print('Group: ' + gp_labels[j])
        print('Empirical: ' + str(empirical[j]))
        print('Shuffled: ' + str(np.mean(shuffled)) + ' ± ' + str(np.std(shuffled)))

        lbls.append('slope: ' + str(np.round(memp.slope, 4)))

    print(ps)
    # Benjamini-Hochberg correction across the three groups:
    # out[0] = reject flags, out[1] = corrected p-values.
    out = multitest.multipletests(ps, alpha=0.05, method='fdr_bh', is_sorted=False, returnsorted=False)
    print(out)

    legend = ax1.legend(bps, lbls, loc='upper right')
    for text, rejected, pval in zip(legend.get_texts(), out[0], out[1]):
        # Entries that stay significant after correction are set in bold.
        if rejected:
            text.set_weight('bold')
        # The marker is derived fresh for every entry, so a NaN p-value yields
        # 'n.s' instead of re-using the previous entry's marker (or failing on
        # the first entry).
        text.set_text(str(text.get_text()) + ' ' + _significance_label(pval))

    # ax1.set_title(score)
    # ax.set_xlabel('Session')
    axf.set_title(titles_to_use[i])
    ax1.set_ylabel('EMD quantile')

    # ax1.set_xticks(np.arange(len(gps)) * 3 + 1.25/2)
    # ax1.set_xticklabels(gps)
    # ax1.set_xlim([-1.25/2, len(gps) * 3 - 1.25/2])

    # One tick per session, centred on its three boxes (offsets 0, .5 and 1).
    ax1.set_xticks(np.arange(len(gps)) * 3 + .5)
    ax1.set_xticklabels(gps)
    ax1.set_ylim(-0.05,)
    # ax1.set_xlim([-.5/2, len(gps) * 3 - .5/2])


fig.suptitle('All indiv. cell-session appearances', fontweight='bold')
# fig.suptitle('Averaged by session', fontweight='bold')
gs_main.tight_layout(fig, rect=[0, 0, 1, 0.98])
# fig.suptitle('Averaged by session')
fig.tight_layout()
plt.show()

In [None]:
# stop()

In [None]:
# Reload the score table from disk so the cell-type analysis below starts from the full file.
# df = pd.read_excel(r'C:\Users\aaoun\OneDrive - cumc.columbia.edu\Desktop\HussainiLab\neuroscikit\LEC_full_merged_scores.xlsx')
df = pd.read_excel(r'C:\Users\aaoun\OneDrive - cumc.columbia.edu\Desktop\HussainiLab\neuroscikit\df_full_LEC.xlsx')

# Quantile columns, one per candidate angle suffix (0 / 90 / 180 / 270).
obj_q_angle_cols = ['obj_q_0', 'obj_q_90', 'obj_q_180', 'obj_q_270']

# A row with no quantile at any angle has no minimum: idxmin cannot return a column
# name for it, and the angle parsing below would then fail on a non-string value.
# Drop such rows up front (the earlier quality-check cell is bypassed by this reload).
n_rows_loaded = len(df)
df = df.dropna(subset=obj_q_angle_cols, how='all').copy()
if len(df) < n_rows_loaded:
    print('Removed ' + str(n_rows_loaded - len(df)) + ' rows with no obj_q value at any angle')

# df['obj_q'] = df['obj_q_NO']
# lowest quantile across the four angles, and the column it came from
df['obj_q'] = df[obj_q_angle_cols].min(axis=1)
df['obj_a'] = df[obj_q_angle_cols].idxmin(axis=1)
# convert obj_a to degrees ('obj_q_90' -> 90)
df['obj_a'] = df['obj_a'].apply(lambda x: int(x.split('_')[2]))
# use obj_wass with angle of min quantile
df['obj_w'] = df.apply(lambda x: x['obj_wass_' + str(x['obj_a'])], axis=1)
# drop spike count column
df = df.drop(columns=['spike_count'])
# rename spike_count.1 to spike_count
df = df.rename(columns={'spike_count.1': 'spike_count'})

# Analysis settings read by the cells below.
consecutive_sessions_threshold = 2
quantile_threshold = 0.2
consecutive = False
score = 'field'
main_field_only = False

In [None]:
# One identifier_dict (object / trace / unassigned cell keys per group) is appended per session limit.
dlist = []
# Keyed by session cut-off; filled inside the loop with per-group counts for each exclusion criterion.
ses_cut_dict = {}

# Repeat the whole classification once per session limit ('session_N' restricts the data to sessions 1..N).
for ses_limit in ['session_3', 'session_4', 'session_5', 'session_6']:
# for ses_limit in ['session_7']:

    """"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
    """"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
    """"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""

    # Score types whose preferred angles are compared further down.
    scores_to_compare = ['centroid', 'field']

    # 'session_N' -> 'N' (kept as a string; converted with int() where a number is needed)
    lim = ses_limit.split('_')[1]
    # Sessions 1..N. Sessions 1 and 2 are listed up front as well, so they appear twice;
    # the list is only used with isin(), where duplicates make no difference.
    ses_to_use = ['session_1', 'session_2'] + ['session_' + str(ses_num) for ses_num in range(1, int(lim) + 1)]

    df_use = df[df['session_id'].isin(ses_to_use)]

    has_field_score = df_use['score'] == 'field'
    has_centroid_score = df_use['score'] == 'centroid'
    is_first_field = df_use['field_id'] == 1

    # "all" = every field of the score type, "main" = field_id 1 only
    all_field_df = df_use[has_field_score]
    all_centroid_df = df_use[has_centroid_score]
    main_field_df = df_use[has_field_score & is_first_field]
    main_centroid_df = df_use[has_centroid_score & is_first_field]

    # field and centroid tables are compared row-by-row by position, so their sizes must agree
    assert len(main_centroid_df) == len(main_field_df)
    assert len(all_field_df) == len(all_centroid_df)

    # Columns holding the quantile for each candidate angle.
    angle_quantile_cols = ['obj_q_0', 'obj_q_90', 'obj_q_180', 'obj_q_270']

    # Gap between the two smallest angle quantiles of every field row (np.sort places NaN last).
    sorted_angle_quantiles = np.sort(all_field_df[angle_quantile_cols].to_numpy(), axis=1)
    lowest_two_gap = np.abs(sorted_angle_quantiles[:, 0] - sorted_angle_quantiles[:, 1])
    all_diffs = list(lowest_two_gap)

    # A row is "ambiguous" when its two best angles are less than 0.05 apart.
    gap_is_small = lowest_two_gap < 0.05
    all_ambig = [row_pos for row_pos, small in enumerate(gap_is_small) if small]
    all_unambig = [row_pos for row_pos, small in enumerate(gap_is_small) if not small]

    # Positional comparison of the preferred angle from the field score vs. the centroid score.
    all_same_angle = all_field_df['obj_a'].to_numpy() == all_centroid_df['obj_a'].to_numpy()
    all_matched = [row_pos for row_pos, same in enumerate(all_same_angle) if same]
    all_unmatched = [row_pos for row_pos, same in enumerate(all_same_angle) if not same]

    main_same_angle = main_field_df['obj_a'].to_numpy() == main_centroid_df['obj_a'].to_numpy()
    matched = [row_pos for row_pos, same in enumerate(main_same_angle) if same]
    unmatched = [row_pos for row_pos, same in enumerate(main_same_angle) if not same]

    # Row subsets (by position) for each category.
    df_matched_field = main_field_df.iloc[matched]
    df_unmatched_field = main_field_df.iloc[unmatched]
    df_matched_centroid = main_centroid_df.iloc[matched]
    df_unmatched_centroid = main_centroid_df.iloc[unmatched]

    df_all_matched_field = all_field_df.iloc[all_matched]
    df_all_unmatched_field = all_field_df.iloc[all_unmatched]
    df_all_matched_centroid = all_centroid_df.iloc[all_matched]
    df_all_unmatched_centroid = all_centroid_df.iloc[all_unmatched]

    df_all_ambiguous_field = all_field_df.iloc[all_ambig]
    df_all_unambiguous_field = all_field_df.iloc[all_unambig]

    assert len(df_matched_field) == len(df_matched_centroid)
    assert len(df_all_matched_field) == len(df_all_matched_centroid)
    assert len(df_unmatched_field) == len(df_unmatched_centroid)
    assert len(df_all_unmatched_field) == len(df_all_unmatched_centroid)

    for count_message, counted_rows in [
        ('There are {} rows where main field distance and centroid distance are in the same direction', df_matched_field),
        ('There are {} rows where main field distance and centroid distance are in different directions', df_unmatched_field),
        ('There are {} rows where all field distance and all centroid distance are in the same direction', df_all_matched_field),
        ('There are {} rows where all field distance and all centroid distance are in different directions', df_all_unmatched_field),
    ]:
        print(count_message.format(len(counted_rows)))

    # For each group: the fraction of rows whose field-based and centroid-based preferred
    # angles agree ("same") or disagree ("diff"). The first pass uses the main-field tables,
    # the second pass the all-field tables.
    main_same_frac, main_diff_frac = {}, {}
    all_same_frac, all_diff_frac = {}, {}
    direction_reports = [
        ('main field distance and centroid distance', df_matched_field, df_unmatched_field, main_same_frac, main_diff_frac),
        ('all field distance and all centroid distance', df_all_matched_field, df_all_unmatched_field, all_same_frac, all_diff_frac),
    ]
    for report_label, same_rows, diff_rows, same_out, diff_out in direction_reports:
        for grp_label in ['ANT', 'B6', 'NON']:
            n_same = len(same_rows[same_rows['group'] == grp_label])
            n_diff = len(diff_rows[diff_rows['group'] == grp_label])
            n_total = n_same + n_diff
            # A group with no rows in this session range has no defined fraction:
            # report NaN instead of raising ZeroDivisionError.
            same_out[grp_label] = float(n_same / n_total) if n_total else float('nan')
            diff_out[grp_label] = float(n_diff / n_total) if n_total else float('nan')
            # The printed values are fractions of the group's rows, not row counts.
            print(('Fraction of rows where ' + report_label + ' are in the same direction for ' + grp_label + ': {}').format(same_out[grp_label]))
            print(('Fraction of rows where ' + report_label + ' are in different directions for ' + grp_label + ': {}').format(diff_out[grp_label]))

    # Keep the individual names available to any later cell that reads them.
    ANT_same_dir, ANT_diff_dir = main_same_frac['ANT'], main_diff_frac['ANT']
    B6_same_dir, B6_diff_dir = main_same_frac['B6'], main_diff_frac['B6']
    NON_same_dir, NON_diff_dir = main_same_frac['NON'], main_diff_frac['NON']

    ANT_same_dir_all, ANT_diff_dir_all = all_same_frac['ANT'], all_diff_frac['ANT']
    B6_same_dir_all, B6_diff_dir_all = all_same_frac['B6'], all_diff_frac['B6']
    NON_same_dir_all, NON_diff_dir_all = all_same_frac['NON'], all_diff_frac['NON']

    """"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
    """"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
    """"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""

    # Columns identifying one cell, and one field of one cell.
    group_by_unique_cell = ['group', 'name', 'depth', 'date','tetrode', 'unit_id']
    group_by_unique_cell_field = group_by_unique_cell + ['field_id']

    # Rows of the configured score type within the selected sessions.
    object_cell_df = df[(df['score'] == score) & (df['session_id'].isin(ses_to_use))]

    # One freshly indexed table per experimental group.
    ANT_df, B6_df, NON_df = [
        object_cell_df[object_cell_df['group'] == grp_label].reset_index(drop=True)
        for grp_label in ('ANT', 'B6', 'NON')
    ]

    prepared_group_dfs = []
    for grp_rows in (ANT_df, B6_df, NON_df):
        # Sort so rows of the same cell-field sit next to each other, then add the
        # classification columns with their "nothing assigned yet" defaults.
        grp_cell_df = grp_rows.sort_values(by=group_by_unique_cell_field).assign(
            cell_type='unassigned',
            isTrace=0,
            isObject=0,
            trace_a=None,
        )
        # isObject: the preferred angle equals the object location (both compared as strings).
        angle_is_object_location = grp_cell_df['obj_a'].astype(str) == grp_cell_df['object_location'].astype(str)
        grp_cell_df.loc[angle_is_object_location, 'isObject'] = 1
        prepared_group_dfs.append(grp_cell_df)

    ANT_object_cell_df, B6_object_cell_df, NON_object_cell_df = prepared_group_dfs

    # Flag "trace" appearances: rows whose preferred angle equals an object location
    # recorded on an earlier row of the same cell-field.
    # NOTE(review): rows are walked in table order; nothing here orders them by session_id
    # within a cell-field, so "earlier" relies on the order of the input rows — confirm.
    cell_field_key_cols = ['unit_id', 'field_id', 'tetrode', 'name', 'date', 'depth']
    keep_trace_appearances = []
    for df_touse in [ANT_object_cell_df, B6_object_cell_df, NON_object_cell_df]:
        prev_key = (None,) * len(cell_field_key_cols)
        prev_angles = []  # object locations (as strings) seen so far for the current cell-field
        to_keep_trace_appearance = []
        for i, row in df_touse.iterrows():
            curr_key = tuple(row[key_col] for key_col in cell_field_key_cols)
            curr_angle = row['obj_a']

            # element-wise == (not tuple equality) so NaN key values never count as equal
            same_cell_field = all(curr_val == prev_val for curr_val, prev_val in zip(curr_key, prev_key))
            if same_cell_field:
                if str(curr_angle) in prev_angles:
                    to_keep_trace_appearance.append(i)
            else:
                # first row of a new cell-field: discard the previous cell-field's history
                prev_angles = []

            prev_key = curr_key
            # Stored as str so the str(curr_angle) membership test above can match even when
            # object_location holds non-string values (isObject also compares both as str).
            prev_angles.append(str(row['object_location']))

        keep_trace_appearances.append(to_keep_trace_appearance)

    # iterrows() yields index labels, so .loc is the matching accessor here.
    ANT_object_cell_df.loc[keep_trace_appearances[0],'isTrace'] = 1
    B6_object_cell_df.loc[keep_trace_appearances[1],'isTrace'] = 1
    NON_object_cell_df.loc[keep_trace_appearances[2],'isTrace'] = 1

    """"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
    """"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
    """"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""

    import copy

    # One key per (cell, session): rows of different fields from the same session share a key.
    group_by_unique_cell_session = ['group', 'name', 'depth', 'date', 'tetrode', 'unit_id','session_id']

    for ses_cut in [lim]:
        ses_cut = int(ses_cut)
        cut_group_labels = ['ANT', 'B6', 'NON']

        # Per-group counts of cell-sessions each criterion would drop (*_type_dict)
        # and keep (*_inv_type_dict).
        center_type_dict = {lbl: 0 for lbl in cut_group_labels}
        center_inv_type_dict = {lbl: 0 for lbl in cut_group_labels}
        quality_type_dict = {lbl: 0 for lbl in cut_group_labels}
        quality_inv_type_dict = {lbl: 0 for lbl in cut_group_labels}
        ambiguous_type_dict = {lbl: 0 for lbl in cut_group_labels}
        ambiguous_inv_type_dict = {lbl: 0 for lbl in cut_group_labels}

        ses_cut_dict[ses_cut] = {}

        # Restrict every table to sessions 1..ses_cut.
        sessions_up_to_cut = ['session_' + str(ses_num) for ses_num in range(1, ses_cut + 1)]
        ANT_object_cell_df_to_use = ANT_object_cell_df[ANT_object_cell_df['session_id'].isin(sessions_up_to_cut)]
        B6_object_cell_df_to_use = B6_object_cell_df[B6_object_cell_df['session_id'].isin(sessions_up_to_cut)]
        NON_object_cell_df_to_use = NON_object_cell_df[NON_object_cell_df['session_id'].isin(sessions_up_to_cut)]
        df_all_unmatched_field_to_use = df_all_unmatched_field[df_all_unmatched_field['session_id'].isin(sessions_up_to_cut)]
        df_all_matched_field_to_use = df_all_matched_field[df_all_matched_field['session_id'].isin(sessions_up_to_cut)]
        # disabled alternative: use df_all_ambiguous_field / df_all_unambiguous_field for the two tables above

        c = 0
        for df_current in [ANT_object_cell_df_to_use, B6_object_cell_df_to_use, NON_object_cell_df_to_use]:
            current_label = cut_group_labels[c]

            # quality criterion: iso_dist below / at-or-above 5
            low_iso_rows = df_current[df_current['iso_dist'] < 5]
            ok_iso_rows = df_current[df_current['iso_dist'] >= 5]
            quality_dropped_identifiers = low_iso_rows.groupby(group_by_unique_cell_session).groups.keys()
            quality_non_dropped_identifiers = ok_iso_rows.groupby(group_by_unique_cell_session).groups.keys()

            # center criterion: obj_q_NO below / at-or-above obj_q
            center_rows = df_current[df_current['obj_q_NO'] < df_current['obj_q']]
            non_center_rows = df_current[df_current['obj_q_NO'] >= df_current['obj_q']]
            center_dropped_identifiers = center_rows.groupby(group_by_unique_cell_session).groups.keys()
            center_non_dropped_identifiers = non_center_rows.groupby(group_by_unique_cell_session).groups.keys()

            # ambiguity criterion: field and centroid angles disagree / agree, for this group
            unmatched_rows = df_all_unmatched_field_to_use[df_all_unmatched_field_to_use['group'] == current_label]
            matched_rows = df_all_matched_field_to_use[df_all_matched_field_to_use['group'] == current_label]
            ambiguous_dropped_identifiers = unmatched_rows.groupby(group_by_unique_cell_session).groups.keys()
            ambiguous_non_dropped_identifiers = matched_rows.groupby(group_by_unique_cell_session).groups.keys()

            quality_type_dict[current_label] = len(quality_dropped_identifiers)
            quality_inv_type_dict[current_label] = len(quality_non_dropped_identifiers)
            center_type_dict[current_label] = len(center_dropped_identifiers)
            center_inv_type_dict[current_label] = len(center_non_dropped_identifiers)
            ambiguous_type_dict[current_label] = len(ambiguous_dropped_identifiers)
            ambiguous_inv_type_dict[current_label] = len(ambiguous_non_dropped_identifiers)

            c += 1

        # Store independent snapshots of the six count tables under this cut-off.
        for counts_name, counts in [
            ('quality_type_dict', quality_type_dict),
            ('quality_inv_type_dict', quality_inv_type_dict),
            ('center_type_dict', center_type_dict),
            ('center_inv_type_dict', center_inv_type_dict),
            ('ambiguous_type_dict', ambiguous_type_dict),
            ('ambiguous_inv_type_dict', ambiguous_inv_type_dict),
        ]:
            ses_cut_dict[ses_cut][counts_name] = copy.deepcopy(counts)

    # Apply the iso_dist threshold, then drop cells left with fewer than two rows,
    # separately for each group's table.
    param = 6.5  # rows with iso_dist below this value are discarded
    qc_filtered_dfs = []
    for df_current in [ANT_object_cell_df, B6_object_cell_df, NON_object_cell_df]:

        # cell-sessions that fail / pass the iso_dist threshold
        meets_iso_threshold = df_current['iso_dist'] >= param
        quality_dropped_identifiers = df_current[df_current['iso_dist'] < param].groupby(group_by_unique_cell_session).groups.keys()
        quality_non_dropped_identifiers = df_current[meets_iso_threshold].groupby(group_by_unique_cell_session).groups.keys()
        df_current = df_current[meets_iso_threshold]

        # Disabled filters, kept for reference:
        # # filter out rows where obj_q_NO is < obj_q - CLOSER to middle than a side
        # center_dropped_identifiers = df_current[df_current['obj_q_NO'] < df_current['obj_q']].groupby(group_by_unique_cell_session).groups.keys()
        # center_non_dropped_identifiers = df_current[df_current['obj_q_NO'] >= df_current['obj_q']].groupby(group_by_unique_cell_session).groups.keys()
        # df_current = df_current[df_current['obj_q_NO'] >= df_current['obj_q']] 
        # """ HAVE TO RE RUN ALL U HAD THIS LINE AS /3, need ot check combos of ambiguous classic and mbiuous theshold 0.05 """

        # # filter out rows where obj_a for centroid is != obj_a for field - Ambiguous
        # ambiguous_dropped_identifiers = df_all_unmatched_field.groupby(group_by_unique_cell_session).groups.keys()
        # # ambiguous_dropped_identifiers = df_all_ambiguous_field.groupby(group_by_unique_cell_session).groups.keys()
        # mask = pd.concat([(df_current['group'] == id1) & (df_current['name'] == id2) & (df_current['depth'] == id3) & (df_current['date'] == id4) & (df_current['tetrode'] == id5) & (df_current['unit_id'] == id6 & (df_current['session_id'] == id7)) for id1, id2, id3, id4, id5, id6, id7 in ambiguous_dropped_identifiers], axis=1).any(axis=1)
        # df_current = df_current[~mask]

        # # filter out > 3 fields
        # field_dropped_identifiers = df_current[df_current['field_id'] > 3].groupby(group_by_unique_cell).groups.keys()
        # df_current = df_current[df_current['field_id'] <= 3]

        # Keep only cells that still have at least two rows.
        # NOTE(review): this counts rows per cell (a row is one field in one session),
        # not distinct sessions — confirm that is the intended criterion.
        cells_with_too_few_rows = df_current.groupby(group_by_unique_cell).filter(lambda x: len(x) < 2)
        remaining_session_dropped_identifiers = cells_with_too_few_rows.groupby(group_by_unique_cell).groups.keys()
        df_current = df_current.groupby(group_by_unique_cell).filter(lambda x: len(x) >= 2)

        qc_filtered_dfs.append(df_current)

    # Rebind the per-group names to their filtered tables.
    ANT_object_cell_df, B6_object_cell_df, NON_object_cell_df = qc_filtered_dfs

    # Independent copies that receive the cell_type labels.
    ANT_cell_type_df = ANT_object_cell_df.copy()
    B6_cell_type_df = B6_object_cell_df.copy()
    NON_cell_type_df = NON_object_cell_df.copy()

    """"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
    """"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
    """"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""

    # identifier_dict[cell_type][group] -> list of cell keys (tuples over group_by_unique_cell)
    identifier_dict = {'object': {}, 'trace': {}, 'unassigned': {}}

    # Object cells: cells with at least one field that has isObject == 1 on two or more rows.
    for grp_label, grp_cell_type_df in [('ANT', ANT_cell_type_df), ('B6', B6_cell_type_df), ('NON', NON_cell_type_df)]:
        repeated_object_rows = grp_cell_type_df[grp_cell_type_df['isObject'] == 1].groupby(group_by_unique_cell_field).filter(lambda x: len(x) >= 2)
        identifier_dict['object'][grp_label] = list(repeated_object_rows.groupby(group_by_unique_cell).groups.keys())

        # Mark every row belonging to one of the identified cells. The mask is accumulated
        # key by key, starting from all-False, so a group with no object cells simply gets
        # an empty mask (pd.concat over an empty list would raise ValueError).
        mask = pd.Series(False, index=grp_cell_type_df.index)
        for cell_key in identifier_dict['object'][grp_label]:
            key_match = pd.Series(True, index=grp_cell_type_df.index)
            for key_col, key_val in zip(group_by_unique_cell, cell_key):
                key_match &= (grp_cell_type_df[key_col] == key_val)
            mask |= key_match
        # .loc assignment mutates the group's table in place
        grp_cell_type_df.loc[mask,'cell_type'] = 'object'

    if ses_limit == 'session_3':
        # For the session_3 limit only: export the rows of fields that have isObject == 1 at least twice.
        only_object_ANT, only_object_B6, only_object_NON = [
            grp_cell_type_df[grp_cell_type_df['isObject'] == 1]
            .groupby(group_by_unique_cell_field)
            .filter(lambda x: len(x) >= 2)
            for grp_cell_type_df in (ANT_cell_type_df, B6_cell_type_df, NON_cell_type_df)
        ]
        only_object = pd.concat([only_object_ANT, only_object_B6, only_object_NON], axis=0)
        only_object.to_excel(r'C:\Users\aaoun\OneDrive - cumc.columbia.edu\Desktop\HussainiLab\neuroscikit\df_full_LEC_only_object.xlsx')

    # Trace cells: cells (not already classified as object) that have a trace row, either
    #   (1) among isObject rows of fields with fewer than two isObject rows, or
    #   (2) among rows with isObject == 0.
    trace_export_frames = {}
    for grp_label, grp_cell_type_df in [('ANT', ANT_cell_type_df), ('B6', B6_cell_type_df), ('NON', NON_cell_type_df)]:
        filt = grp_cell_type_df[grp_cell_type_df['isObject'] == 1].groupby(group_by_unique_cell_field).filter(lambda x: len(x) < 2)
        first_set = filt[filt['isTrace'] == 1].groupby(group_by_unique_cell).groups.keys()
        second_set = grp_cell_type_df[(grp_cell_type_df['isObject'] == 0) & (grp_cell_type_df['isTrace'] == 1)].groupby(group_by_unique_cell).groups.keys()
        unique_set = list(set(list(first_set) + list(second_set)))
        # object classification takes precedence over trace
        unique_set = [x for x in unique_set if x not in identifier_dict['object'][grp_label]]
        identifier_dict['trace'][grp_label] = list(unique_set)

        # Accumulate the row mask key by key from all-False, so a group without trace cells
        # yields an empty mask (pd.concat over an empty list would raise ValueError).
        mask = pd.Series(False, index=grp_cell_type_df.index)
        for cell_key in identifier_dict['trace'][grp_label]:
            key_match = pd.Series(True, index=grp_cell_type_df.index)
            for key_col, key_val in zip(group_by_unique_cell, cell_key):
                key_match &= (grp_cell_type_df[key_col] == key_val)
            mask |= key_match
        # .loc assignment mutates the group's table in place
        grp_cell_type_df.loc[mask,'cell_type'] = 'trace'

        if ses_limit == 'session_3':
            # Built after the labels are written, so the second frame carries the new cell_type.
            trace_export_frames[grp_label] = pd.concat([
                filt[filt['isTrace'] == 1],
                grp_cell_type_df[(grp_cell_type_df['isObject'] == 0) & (grp_cell_type_df['isTrace'] == 1)],
            ], axis=0)

    if ses_limit == 'session_3':
        # For the session_3 limit only: export the rows that contributed to the trace sets.
        only_trace_ANT = trace_export_frames['ANT']
        only_trace_B6 = trace_export_frames['B6']
        only_trace_NON = trace_export_frames['NON']

        only_trace = pd.concat([only_trace_ANT, only_trace_B6, only_trace_NON], axis=0)
        only_trace.to_excel(r'C:\Users\aaoun\OneDrive - cumc.columbia.edu\Desktop\HussainiLab\neuroscikit\df_full_LEC_only_trace.xlsx')

    # Unassigned cells: cells with non-trace rows (either isObject rows of fields with fewer
    # than two isObject rows, or rows with isObject == 0) that are in neither the trace
    # nor the object list.
    unassigned_export_frames = {}
    for grp_label, grp_cell_type_df in [('ANT', ANT_cell_type_df), ('B6', B6_cell_type_df), ('NON', NON_cell_type_df)]:
        objectfiltout = grp_cell_type_df[grp_cell_type_df['isObject'] == 1].groupby(group_by_unique_cell_field).filter(lambda x: len(x) < 2)
        sparse_object_non_trace_rows = objectfiltout[objectfiltout['isTrace'] == 0]
        plain_rows = grp_cell_type_df[(grp_cell_type_df['isTrace'] == 0) & (grp_cell_type_df['isObject'] == 0)]

        first_set = sparse_object_non_trace_rows.groupby(group_by_unique_cell).groups.keys()
        second_set = plain_rows.groupby(group_by_unique_cell).groups.keys()

        already_classified = identifier_dict['trace'][grp_label] + identifier_dict['object'][grp_label]
        unique_set = [x for x in set(list(first_set) + list(second_set)) if x not in already_classified]
        identifier_dict['unassigned'][grp_label] = list(unique_set)

        if ses_limit == 'session_3':
            unassigned_export_frames[grp_label] = pd.concat([sparse_object_non_trace_rows, plain_rows], axis=0)

    if ses_limit == 'session_3':
        # For the session_3 limit only: export the rows that contributed to the unassigned sets.
        only_unassigned_ANT = unassigned_export_frames['ANT']
        only_unassigned_B6 = unassigned_export_frames['B6']
        only_unassigned_NON = unassigned_export_frames['NON']

        only_unassigned = pd.concat([only_unassigned_ANT, only_unassigned_B6, only_unassigned_NON], axis=0)
        only_unassigned.to_excel(r'C:\Users\aaoun\OneDrive - cumc.columbia.edu\Desktop\HussainiLab\neuroscikit\df_full_LEC_only_unassigned.xlsx')

    # Record this session limit's classification.
    dlist.append(identifier_dict)

    """"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
    """"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
    """"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""




In [None]:
# NOTE(review): `stop` is not defined in the visible part of this notebook; presumably this call
# is a deliberate way to halt "Run All" at this point (by raising NameError) — confirm.
stop()

In [None]:
# Stack the three per-group tables and write the combined cell-type assignments to Excel.
# The per-group tables hold whatever the last executed session-limit iteration left in them.
# full_cell_type_df.to_excel('/Users/alexgonzalez/Google Drive/PostDoc/Data/ephys_summary/summary_cell_type_df.xlsx')
per_group_cell_type_dfs = [ANT_cell_type_df, B6_cell_type_df, NON_cell_type_df]
full_cell_type_df = pd.concat(per_group_cell_type_dfs, axis=0)
full_cell_type_df.to_excel(
    r'C:\Users\aaoun\OneDrive - cumc.columbia.edu\Desktop\HussainiLab\neuroscikit\df_full_LEC_assigned.xlsx'
)


In [None]:
import pandas as pd
from scipy.stats import fisher_exact
from statsmodels.stats.multitest import multipletests

# Pairwise Fisher's exact tests between groups, one 2x2 table per cell type,
# built from the identifier lists stored in the first entry of `dlist`.
groups = ['ANT', 'NON', 'B6']
cell_types = ['object','trace','unassigned']
cell_type_opp = ['non-object', 'non-trace', 'assigned']
identifier_dict = dlist[0]

# type_dict[group][ctype]     -> number of identifiers of that cell type
# inv_type_dict[group][ctype] -> number of identifiers of every other cell type
type_dict = {grp: dict.fromkeys(cell_types, 0) for grp in ('ANT', 'B6', 'NON')}
inv_type_dict = {grp: dict.fromkeys(cell_types, 0) for grp in ('ANT', 'B6', 'NON')}

for ctype, ids_by_group in identifier_dict.items():
    for group, ids in ids_by_group.items():
        type_dict[group][ctype] += len(ids)
        inv_type_dict[group][ctype] += sum(
            len(other_ids[group])
            for other_type, other_ids in identifier_dict.items()
            if other_type != ctype
        )

# One entry per (group pair, cell type) test, kept in parallel lists
comparisons, odds_ratios, p_values = [], [], []

for i, group1 in enumerate(groups):
    for group2 in groups[i + 1:]:
        for cell_type, opposite in zip(cell_types, cell_type_opp):
            # rows: the two groups; columns: in-type count vs. out-of-type count
            in_type = [type_dict[group1][cell_type], type_dict[group2][cell_type]]
            out_of_type = [inv_type_dict[group1][cell_type], inv_type_dict[group2][cell_type]]
            contingency_table_2x2 = pd.DataFrame(
                {str(cell_type): in_type, str(opposite): out_of_type},
                index=[group1, group2],
            )
            print(contingency_table_2x2)

            odds_ratio, p_value = fisher_exact(contingency_table_2x2)

            comparisons.append(f"{group1} vs. {group2} - {cell_type.upper()}")
            odds_ratios.append(odds_ratio)
            p_values.append(p_value)

# Benjamini-Hochberg correction across all tests run above
rejected, adjusted_p_values, _, _ = multipletests(p_values, method='fdr_bh', alpha=0.05, is_sorted=False, returnsorted=False)

# The "Accepted" column carries the reject-null flags returned by multipletests
results_df = pd.DataFrame({
    "Accepted": rejected,
    "Comparison": comparisons,
    "Odds Ratio": odds_ratios,
    "P-value": p_values,
    "Adjusted P-value": adjusted_p_values
})

print(results_df)


In [None]:
# Pairwise Fisher's exact tests between groups on the exclusion counts stored
# under ses_cut_dict[3]: one result table per exclusion kind, collected in
# `dflist` in the order quality, center, ambiguous.
EXCLUSION_KINDS = ['quality', 'center', 'ambiguous']

# Suffix written into every comparison label. The plotting cells select rows
# from `dflist` by looking for this substring. It used to come from the
# `cell_type` loop variable left behind by the previous cell, which made the
# labels depend on execution order; it is now spelled out explicitly.
COMPARISON_TAG = 'UNASSIGNED'

session3_stats = ses_cut_dict[3]
dlist_to_use = [session3_stats[f'{kind}_type_dict'] for kind in EXCLUSION_KINDS]
dlist_inv = [session3_stats[f'{kind}_inv_type_dict'] for kind in EXCLUSION_KINDS]
dflist = []

for tp_dict, tp_inv_dict in zip(dlist_to_use, dlist_inv):
    # Parallel result lists for this exclusion kind
    comparisons = []
    odds_ratios = []
    p_values = []

    for i, group1 in enumerate(groups):
        for group2 in groups[i + 1:]:
            # rows: the two groups; columns: excluded vs. included counts
            contingency_table_2x2 = pd.DataFrame(
                {'excluded': [tp_dict[group1], tp_dict[group2]],
                 'included': [tp_inv_dict[group1], tp_inv_dict[group2]]},
                index=[group1, group2],
            )
            print(contingency_table_2x2)

            odds_ratio, p_value = fisher_exact(contingency_table_2x2)

            comparisons.append(f"{group1} vs. {group2} - {COMPARISON_TAG}")
            odds_ratios.append(odds_ratio)
            p_values.append(p_value)

    # Benjamini-Hochberg correction within this exclusion kind
    rejected, adjusted_p_values, _, _ = multipletests(p_values, method='fdr_bh', alpha=0.05, is_sorted=False, returnsorted=False)

    # The "Accepted" column carries the reject-null flags returned by multipletests
    res_df = pd.DataFrame({
        "Accepted": rejected,
        "Comparison": comparisons,
        "Odds Ratio": odds_ratios,
        "P-value": p_values,
        "Adjusted P-value": adjusted_p_values
    })

    dflist.append(res_df)

    print(res_df)


In [None]:
# https://stackoverflow.com/questions/11517986/indicating-the-statistically-significant-difference-in-bar-graph

def barplot_annotate_brackets(num1, num2, data, center, height, yerr=None, dh=.05, barh=.05, fs=None, maxasterix=None):
    """ 
    Annotate barplot with p-values.

    Draws a bracket between two bars on the current axes (plt.gca()) and
    writes a label centred above it.

    :param num1: number of left bar to put bracket over
    :param num2: number of right bar to put bracket over
    :param data: string to write or number for generating asterixes
    :param center: centers of all bars (like plt.bar() input)
    :param height: heights of all bars (like plt.bar() input)
    :param yerr: yerrs of all bars (like plt.bar() input); None or empty to ignore
    :param dh: height offset over bar / bar + yerr in axes coordinates (0 to 1)
    :param barh: bar height in axes coordinates (0 to 1)
    :param fs: font size
    :param maxasterix: maximum number of asterixes to write (for very small p-values)
    """

    if isinstance(data, str):
        text = data
    else:
        # *    is p <= 0.05
        # **   is p <= 0.01
        # ***  is p <= 0.001
        # **** is p <= 0.0001
        text = ''
        for n_stars, cutoff in enumerate((0.05, 0.01, 0.001, 0.0001), start=1):
            if data <= cutoff:
                text = '*' * n_stars

        # Honour the documented cap on the number of stars
        if maxasterix:
            text = text[:maxasterix]

        if len(text) == 0:
            text = 'n. s.'

    lx, ly = center[num1], height[num1]
    rx, ry = center[num2], height[num2]

    # Explicit None/length check: `if yerr:` raises for multi-element NumPy arrays
    if yerr is not None and len(yerr) > 0:
        ly += yerr[num1]
        ry += yerr[num2]

    # dh and barh are given as fractions of the current y-axis span
    ax = plt.gca()
    ax_y0, ax_y1 = ax.get_ylim()
    dh *= (ax_y1 - ax_y0)
    barh *= (ax_y1 - ax_y0)

    # The bracket starts dh above the taller of the two bars
    y = max(ly, ry) + dh

    barx = [lx, lx, rx, rx]
    bary = [y, y+barh, y+barh, y]
    mid = ((lx+rx)/2, y+barh)

    ax.plot(barx, bary, c='black')

    kwargs = dict(ha='center', va='bottom')
    if fs is not None:
        kwargs['fontsize'] = fs

    ax.text(*mid, text, **kwargs)

In [None]:
from matplotlib import gridspec

# Figure: per-group cell-type percentages (top row) and exclusion percentages
# (bottom row), overlaid for every entry of `dlist`. The first entry is drawn
# opaque; significance brackets are added on the last entry only.
dtitles = ['Session 1 to 3', 'Session 1 to 4', 'Session 1 to 5', 'Session 1 to 6']
lbls = ['N=3', 'N=4', 'N=5', 'N=6']
GROUP_ORDER = ['B6', 'NON', 'ANT']        # left-to-right bar order in every panel
GROUP_COLORS = ['blue', 'green', 'red']


def _count_repeated_object_cells(cell_type_df):
    """Count unique cells (grouped by `group_by_unique_cell`) that keep at least
    two rows once rows whose object_location is 'NO' are dropped."""
    with_object = cell_type_df[cell_type_df['object_location'].astype(str) != 'NO']
    repeated = with_object.groupby(group_by_unique_cell).filter(lambda rows: len(rows) >= 2)
    return len(repeated.groupby(group_by_unique_cell))


def _cell_type_fractions(ids_by_type, group, n_cells):
    """Return (object, trace, unassigned) fractions for one group.

    Object cells are divided by `n_cells`; trace cells by the number of
    identifiers across the three types; unassigned is the remainder.
    """
    n_obj = len(ids_by_type['object'][group])
    n_trace = len(ids_by_type['trace'][group])
    n_classified = n_obj + n_trace + len(ids_by_type['unassigned'][group])
    obj_frac = n_obj / n_cells
    trace_frac = n_trace / n_classified
    return obj_frac, trace_frac, 1 - obj_frac - trace_frac


def _exclusion_percentages(cut_stats, kind):
    """Percentage kind_type_dict / (kind_type_dict + kind_inv_type_dict) per group, in GROUP_ORDER."""
    counted = cut_stats[f'{kind}_type_dict']
    remaining = cut_stats[f'{kind}_inv_type_dict']
    return np.array([counted[g] / (counted[g] + remaining[g]) for g in GROUP_ORDER]) * 100


def _draw_group_bars(ax, values, highlight):
    """Draw faint bars for the three groups, plus an opaque overlay when `highlight` is set."""
    ax.bar(GROUP_ORDER, values, color=GROUP_COLORS, alpha=0.2)
    if highlight:
        ax.bar(GROUP_ORDER, values, color=GROUP_COLORS, alpha=1)


def _annotate_group_brackets(stats_df, tag, heights):
    """Draw one bracket on the current axes per row of `stats_df` whose
    'Comparison' label contains `tag`, placed above bars of the given heights."""
    for label, p_adjusted in zip(stats_df['Comparison'], stats_df['Adjusted P-value']):
        label = str(label)
        if tag not in label:
            continue
        # Map the pair of group names in the label to bar positions in GROUP_ORDER
        if 'ANT' in label and 'B6' in label:
            left, right = 0, 2
        elif 'ANT' in label and 'NON' in label:
            left, right = 1, 2
        elif 'B6' in label and 'NON' in label:
            left, right = 0, 1
        else:
            continue
        barplot_annotate_brackets(left, right, p_adjusted, [0, 1, 2], heights, maxasterix=5)


def _label_points(ax, text, values, label_b6):
    """Write `text` above each group's marker; B6 is labelled only when `label_b6` is set."""
    for grp, value in zip(GROUP_ORDER, values):
        if grp == 'B6':
            if label_b6:
                ax.annotate(text, xy=(grp, value), textcoords='offset points', xytext=(0,10), ha='center')
        else:
            ax.annotate(text, xy=(grp, value), textcoords='offset points', xytext=(0,3), ha='center')


fig = plt.figure(figsize=(15,8))
gspec = gridspec.GridSpec(2,3)

# These denominators do not depend on the loop index, so compute them once
eligible_cells = {
    'B6': _count_repeated_object_cells(B6_cell_type_df),
    'NON': _count_repeated_object_cells(NON_cell_type_df),
    'ANT': _count_repeated_object_cells(ANT_cell_type_df),
}

for i in range(len(dlist)):
    dtouse = dlist[i]
    first_window = (i == 0)
    last_window = (i == len(dlist) - 1)

    fractions = {g: _cell_type_fractions(dtouse, g, eligible_cells[g]) for g in GROUP_ORDER}

    # Diagnostic: ANT object fraction with the new vs. the previous denominator
    ant_ids = {kind: dtouse[kind]['ANT'] for kind in ('object', 'trace', 'unassigned')}
    print(fractions['ANT'][0], len(ant_ids['object']) / sum(len(ids) for ids in ant_ids.values()))

    obj_per = np.array([fractions[g][0] for g in GROUP_ORDER]) * 100
    trace_per = np.array([fractions[g][1] for g in GROUP_ORDER]) * 100
    # The third panel shows the assigned share, i.e. the complement of unassigned
    unassigned_per = 100 - np.array([fractions[g][2] for g in GROUP_ORDER]) * 100

    # ---- top row: cell-type percentages, statistics from results_df ----
    top_axes = []
    top_panels = [('Object', 'OBJECT', obj_per),
                  ('Trace', 'TRACE', trace_per),
                  ('Assigned', 'UNASSIGNED', unassigned_per)]
    for col, (title, tag, values) in enumerate(top_panels):
        ax = plt.subplot(gspec[0, col])
        _draw_group_bars(ax, values, first_window)
        if last_window:
            # Brackets sit above this panel's own bars. The Object and Trace
            # panels previously received the third panel's heights.
            _annotate_group_brackets(results_df, tag, values)
        ax.plot(GROUP_ORDER, values, 'k-', marker='o', alpha=0.5, lw=0)
        _label_points(ax, lbls[i], values, label_b6=first_window)
        ax.set_title(title)
        top_axes.append(ax)
    ax1, ax2, ax3 = top_axes
    ax1.set_ylabel('% unique cells')
    ax3.set_ylim([0,100])

    # ---- bottom row: exclusion percentages, statistics from dflist[col] ----
    cut_stats = ses_cut_dict[int(3+i)]
    bottom_panels = [('quality', 'Low quality'),
                     ('center', 'Closer to center'),
                     ('ambiguous', 'Between 2 angles')]
    bottom = []
    for col, (kind, title) in enumerate(bottom_panels):
        ax = plt.subplot(gspec[1, col])
        values = _exclusion_percentages(cut_stats, kind)
        _draw_group_bars(ax, values, first_window)
        if last_window:
            _annotate_group_brackets(dflist[col], 'UNASSIGNED', values)
        ax.plot(GROUP_ORDER, values, 'k-', marker='o', alpha=0.5, lw=0)
        ax.set_title(title)
        bottom.append((ax, values))
    (ax4, quality_pers), (ax5, center_pers), (ax6, ambiguous_pers) = bottom
    print(ambiguous_pers)
    ax4.set_ylabel('% of cell-session appearances')

fig.suptitle("Fisher's exact test, capped at first 3 session")

fig.tight_layout()


In [None]:
from matplotlib import gridspec

# Figure: per-group cell-type percentages (top row) and exclusion percentages
# (bottom row) for the first entry of `dlist` only, with significance brackets
# and every y-axis fixed to 0-100.
dtitles = ['Session 1 to 3']
lbls = ['N=3']
GROUP_ORDER = ['B6', 'NON', 'ANT']        # left-to-right bar order in every panel
GROUP_COLORS = ['blue', 'green', 'red']


def _count_repeated_object_cells(cell_type_df):
    """Count unique cells (grouped by `group_by_unique_cell`) that keep at least
    two rows once rows whose object_location is 'NO' are dropped."""
    with_object = cell_type_df[cell_type_df['object_location'].astype(str) != 'NO']
    repeated = with_object.groupby(group_by_unique_cell).filter(lambda rows: len(rows) >= 2)
    return len(repeated.groupby(group_by_unique_cell))


def _cell_type_fractions(ids_by_type, group, n_cells):
    """Return (object, trace, unassigned) fractions for one group.

    Object cells are divided by `n_cells`; trace cells by the number of
    identifiers across the three types; unassigned is the remainder.
    """
    n_obj = len(ids_by_type['object'][group])
    n_trace = len(ids_by_type['trace'][group])
    n_classified = n_obj + n_trace + len(ids_by_type['unassigned'][group])
    obj_frac = n_obj / n_cells
    trace_frac = n_trace / n_classified
    return obj_frac, trace_frac, 1 - obj_frac - trace_frac


def _exclusion_percentages(cut_stats, kind):
    """Percentage kind_type_dict / (kind_type_dict + kind_inv_type_dict) per group, in GROUP_ORDER."""
    counted = cut_stats[f'{kind}_type_dict']
    remaining = cut_stats[f'{kind}_inv_type_dict']
    return np.array([counted[g] / (counted[g] + remaining[g]) for g in GROUP_ORDER]) * 100


def _draw_group_bars(ax, values):
    """Draw the faint bars and the opaque overlay for the three groups."""
    ax.bar(GROUP_ORDER, values, color=GROUP_COLORS, alpha=0.2)
    ax.bar(GROUP_ORDER, values, color=GROUP_COLORS, alpha=1)


def _annotate_group_brackets(stats_df, tag, heights):
    """Draw one bracket on the current axes per row of `stats_df` whose
    'Comparison' label contains `tag`, placed above bars of the given heights."""
    for label, p_adjusted in zip(stats_df['Comparison'], stats_df['Adjusted P-value']):
        label = str(label)
        if tag not in label:
            continue
        # Map the pair of group names in the label to bar positions in GROUP_ORDER
        if 'ANT' in label and 'B6' in label:
            left, right = 0, 2
        elif 'ANT' in label and 'NON' in label:
            left, right = 1, 2
        elif 'B6' in label and 'NON' in label:
            left, right = 0, 1
        else:
            continue
        barplot_annotate_brackets(left, right, p_adjusted, [0, 1, 2], heights, maxasterix=5)


def _label_points(ax, text, values):
    """Write `text` above each group's marker (B6 offset 10 pt, the others 3 pt)."""
    for grp, value in zip(GROUP_ORDER, values):
        offset = (0,10) if grp == 'B6' else (0,3)
        ax.annotate(text, xy=(grp, value), textcoords='offset points', xytext=offset, ha='center')


fig = plt.figure(figsize=(15,8))
gspec = gridspec.GridSpec(2,3)

# Only the first entry of dlist is plotted here
i = 0
dtouse = dlist[i]

eligible_cells = {
    'B6': _count_repeated_object_cells(B6_cell_type_df),
    'NON': _count_repeated_object_cells(NON_cell_type_df),
    'ANT': _count_repeated_object_cells(ANT_cell_type_df),
}
fractions = {g: _cell_type_fractions(dtouse, g, eligible_cells[g]) for g in GROUP_ORDER}

# Diagnostic: ANT object fraction with the new vs. the previous denominator
ant_ids = {kind: dtouse[kind]['ANT'] for kind in ('object', 'trace', 'unassigned')}
print(fractions['ANT'][0], len(ant_ids['object']) / sum(len(ids) for ids in ant_ids.values()))

obj_per = np.array([fractions[g][0] for g in GROUP_ORDER]) * 100
trace_per = np.array([fractions[g][1] for g in GROUP_ORDER]) * 100
unassigned_per = np.array([fractions[g][2] for g in GROUP_ORDER]) * 100

# ---- top row: cell-type percentages, statistics from results_df ----
top_axes = []
top_panels = [('Object', 'OBJECT', obj_per),
              ('Trace', 'TRACE', trace_per),
              ('Unassigned', 'UNASSIGNED', unassigned_per)]
for col, (title, tag, values) in enumerate(top_panels):
    ax = plt.subplot(gspec[0, col])
    _draw_group_bars(ax, values)
    # Brackets sit above this panel's own bars. The Object and Trace panels
    # previously received the third panel's heights.
    _annotate_group_brackets(results_df, tag, values)
    ax.plot(GROUP_ORDER, values, 'k-', marker='o', alpha=0.5, lw=0)
    _label_points(ax, lbls[i], values)
    ax.set_title(title)
    top_axes.append(ax)
ax1, ax2, ax3 = top_axes
ax1.set_ylabel('% unique cells')

# ---- bottom row: exclusion percentages, statistics from dflist[col] ----
cut_stats = ses_cut_dict[int(3+i)]
bottom_panels = [('quality', 'Low quality'),
                 ('center', 'Closer to center'),
                 ('ambiguous', 'Between 2 angles')]
bottom = []
for col, (kind, title) in enumerate(bottom_panels):
    ax = plt.subplot(gspec[1, col])
    values = _exclusion_percentages(cut_stats, kind)
    _draw_group_bars(ax, values)
    _annotate_group_brackets(dflist[col], 'UNASSIGNED', values)
    ax.plot(GROUP_ORDER, values, 'k-', marker='o', alpha=0.5, lw=0)
    ax.set_title(title)
    bottom.append((ax, values))
(ax4, quality_pers), (ax5, center_pers), (ax6, ambiguous_pers) = bottom
ax4.set_ylabel('% of cell-session appearances')

# Fix every panel to the same percentage scale once all artists are in place
for ax in (ax1, ax2, ax3, ax4, ax5, ax6):
    ax.set_ylim(0,100)

fig.suptitle("Fisher's exact test, capped at first 3 session")

fig.tight_layout()


In [None]:
import pickle

# Persist the first identifier dictionary from dlist to a pickle file in the
# current working directory.
tosave = dlist[0]
with open('ses123_identifier_dict.pkl', 'wb') as pkl_file:
    pickle.dump(tosave, pkl_file)

In [None]:
def _select_cell_rows(cell_ids_by_group):
    """Collect the rows of `df2` that belong to the listed cells.

    :param cell_ids_by_group: mapping of group key -> iterable of 6-tuples
        (group, name, depth, date, tetrode, unit_id) identifying a cell
    :return: DataFrame with the matching rows, concatenated in identifier
        order; an empty frame with df2's columns when nothing is listed
    """
    pieces = []
    for cell_ids in cell_ids_by_group.values():
        for group, name, depth, date, tetrode, unit_id in cell_ids:
            mask = (
                (df2['group'] == group)
                & (df2['name'] == name)
                & (df2['depth'] == depth)
                & (df2['date'] == date)
                & (df2['tetrode'] == tetrode)
                & (df2['unit_id'] == unit_id)
            )
            pieces.append(df2[mask])
    if not pieces:
        # No identifiers at all: export an empty sheet instead of failing on None
        return df2.iloc[0:0]
    # Single concat at the end avoids re-copying the accumulated frame per cell
    return pd.concat(pieces)


trace_cells_df = _select_cell_rows(tosave['trace'])
trace_cells_df.to_excel(r'C:\Users\aaoun\OneDrive - cumc.columbia.edu\Desktop\HussainiLab\neuroscikit\df_full_LEC_only_trace_cells.xlsx')

object_cells_df = _select_cell_rows(tosave['object'])
object_cells_df.to_excel(r'C:\Users\aaoun\OneDrive - cumc.columbia.edu\Desktop\HussainiLab\neuroscikit\df_full_LEC_only_object_cells.xlsx')




In [None]:
# Total number of unassigned identifiers across the three groups
sum(len(tosave['unassigned'][grp]) for grp in ('ANT', 'B6', 'NON'))

In [None]:
# Ad-hoc arithmetic (evaluates to 220).
# NOTE(review): presumably the per-group counts summed in the cell above, typed
# in by hand — confirm, or delete this scratch cell.
126+53+41