# Synapse polyadicity in the *Megaphragma* lamina
- R1-6, L2, and AC consistently form presynaptic terminals in every lamina cartridge.
- How many postsynaptic partners does each terminal contact on average?
- Does average multiplicity (polyadicity) vary among cell types?
- Is there a characteristic set of postsynaptic partners that is common among a cell's terminals?

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os.path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import mannwhitneyu
import statsmodels.api as sm

import sys
from cx_analysis.dataframe_tools import extract_connector_table
from cx_analysis.cartridge_metadata import ret_clusters
from cx_analysis.vis.fig_tools import subtype_cm
#from cx_analysis.vis.hex_lattice import hexplot

# import matplotlib as mpl
# import matplotlib.gridspec as gridspec
# from matplotlib.lines import Line2D

In [3]:
# Start from matplotlib defaults, then layer the project's lamina style sheet on top.
# NOTE: the relative path assumes the notebook runs next to the cx_analysis source
# tree; it may not resolve if cx_analysis is installed as a module.
plt.rcdefaults()
plt.style.use('../cx_analysis/vis/lamina.mplstyle')

# When True, figure cells write their plots to disk
save_figs = False

# Palette from subtype_cm(), passed to seaborn as `palette=` by the plotting cells
st_cm = subtype_cm()

In [4]:
# Load the link table for snapshot `tp`: each row is one pre -> post link.
tp = '210809'
data_path = f'~/Data/{tp}_lamina/{tp}_linkdf.pickle'
# NOTE(review): unpickling can execute arbitrary code -- only load trusted files.
df = pd.read_pickle(data_path)

# Sorted unique labels. R1-6 appear as pairs in post_type (R1R4, R2R5, ...).
ommatidia = np.unique(df['pre_om'])
subtypes = np.unique(df['post_type'])

# One row per connector (presynaptic terminal). post_count is the row-sum over
# the per-subtype columns (presumably contact counts per postsynaptic subtype).
ct_df = extract_connector_table(df)
ct_df = ct_df.assign(post_count=ct_df.loc[:, subtypes].sum(axis=1))

# # Rhabdom vols from Anastasia
# rb = pd.read_csv('~/Data/lamina_additional_data/ret_cell_vol.csv').set_index('rtype').T
# rb.index.name = 'om'
# rb = rb.loc[sorted(rb.index), sorted(rb.columns)]
# rb_frac = (rb.T/rb.sum(axis=1)).T.rename(mapper={'vol': 'fvol'}, axis=1)

# rtypes = rb.columns  # R1-6 are [R1, R2, ...]


In [5]:
# Preview the link table (one row per presynaptic -> postsynaptic link)
display(df)

Unnamed: 0,link_id,cx_id,pre_neuron,pre_om,pre_type,pre_skel,post_neuron,post_om,post_type,post_skel
0,194220,276258,omB6_LN,B6,LMC_N,25,omB6_L1,B6,LMC_1,175606
1,175596,276258,omB6_LN,B6,LMC_N,25,omB6_L2,B6,LMC_2,44725
2,175139,276258,omB6_LN,B6,LMC_N,25,omB6_L1,B6,LMC_1,175606
3,175128,276258,omB6_LN,B6,LMC_N,25,omB6_R3,B6,R_quartet,174970
4,175628,277482,omB6_LN,B6,LMC_N,25,168408,UNKNOWN,UNKNOWN,168408
...,...,...,...,...,...,...,...,...,...,...
20774,318465,479967,omC2_centri_nc,C2,centri,319210,omC2_centri_nc,C2,centri,319210
20775,318468,479967,omC2_centri_nc,C2,centri,319210,omC2_R4_nc,C2,R_quartet,294885
20776,318467,479967,omC2_centri_nc,C2,centri,319210,omC2_L3_nc,C2,LMC_3,309836
20777,318466,479967,omC2_centri_nc,C2,centri,319210,omC2_R7p_nc,C2,R7p,294545


## Average polyadicity of each presynaptic cell

In [None]:
# Per-presynaptic-neuron summary of its terminals (rows of ct_df):
#   pre_om, pre_type -> taken from the neuron's first terminal row
#   mn_poly          -> mean postsynaptic contacts per terminal (multiplicity)
#   n_terms          -> number of presynaptic terminals
#   n_contacts       -> total postsynaptic contacts over all terminals
# sort=False keeps neurons in order of first appearance in ct_df. Using
# 'first' avoids label-based `[0]` indexing, which fails on any group whose
# index does not contain the label 0.
# TODO Preprocessing: get a dict with kv for each cx_id: {post_name: [node coords]}
poly_df = (
    ct_df.groupby('pre_neuron', sort=False)
    .agg(pre_om=('pre_om', 'first'),
         pre_type=('pre_type', 'first'),
         mn_poly=('post_count', 'mean'),
         n_terms=('post_count', 'size'),
         n_contacts=('post_count', 'sum'))
    .reset_index()
)

display("### Class Averages ####")
display(poly_df.groupby('pre_type').describe().T.round(decimals=1))
# NOTE: describe() reports the sample std (ddof=1); use .std(ddof=0) for the
# population value.

display("### R1-6 Overall ####")
r1_6_pairs = ['R1R4', 'R2R5', 'R3R6']
r1_6_ind = poly_df.index[poly_df['pre_type'].isin(r1_6_pairs)]
# percentiles=[] drops the quartiles; pandas still adds the median (50%).
display(poly_df.loc[r1_6_ind].describe(percentiles=[]).round(decimals=1))

print("\nNOTE: ")
print("n != 29 or 58 when the subtype is not consistently presynaptic in the lamina")

In [None]:
# Preview the connector table (one row per presynaptic terminal)
display(ct_df)

In [None]:
# Strip plot of each presynaptic neuron's mean multiplicity (mean postsynaptic
# contacts per terminal), grouped by presynaptic cell type.
pretypes = ['R1R4', 'R2R5', 'R3R6', 'LMC_2', 'centri']
tick_labels = {'R1R4': 'R1/R4', 'R2R5': 'R2/R5', 'R3R6': 'R3/R6',
               'LMC_2': 'L2', 'centri': 'AC'}

data = poly_df.loc[poly_df['pre_type'].isin(pretypes)]

fig, ax = plt.subplots(1, figsize=[2.3, 2.3])
sns.stripplot(data=data, x='pre_type', y='mn_poly',
              hue='pre_type', palette=st_cm,
              order=pretypes, ax=ax)

# hue repeats the x categories, so any legend is redundant. get_legend()
# returns None when no legend was created; guard before removing.
legend = ax.get_legend()
if legend is not None:
    legend.remove()
ax.set_xticklabels([tick_labels[p] for p in pretypes], fontsize=7.0)
ax.set_xlabel('Presynaptic cell type')
ax.set_ylabel('Multiplicity')

if save_figs:
    fig_dir = '/mnt/home/nchua/Dropbox/lamina_figures'
    for ext in ('svg', 'png'):
        fig.savefig(os.path.join(fig_dir, f'pre-multi-scatter.{ext}'))


In [None]:
# Disabled alternative to the strip-plot cell: the same per-neuron multiplicity
# data drawn with sns.boxplot instead of sns.stripplot.
# NOTE(review): if re-enabled, this saves to the same pre-multi-scatter.svg/.png
# paths as the strip-plot cell and would overwrite those files -- rename first.
# pretypes = ['R1R4', 'R2R5', 'R3R6', 'LMC_2', 'centri']

# data = poly_df.loc[[i for i, p in poly_df['pre_type'].items() if p in pretypes]]

# fig, ax = plt.subplots(1, figsize=[2.3, 2.3])
# sns.boxplot(data=data, x='pre_type', y='mn_poly', 
#               hue='pre_type', palette=subtype_cm(), 
#               order=pretypes, ax=ax)


# ax.legend_.remove()
# ax.set_xticklabels(['R1/R4', 'R2/R5', 'R3/R6', 'L2', 'AC'], fontsize=7.0)
# ax.set_xlabel('Presynaptic cell type')
# ax.set_ylabel('Multiplicity')

# if save_figs:
#     fig.savefig('/mnt/home/nchua/Dropbox/lamina_figures/pre-multi-scatter.svg')
#     fig.savefig('/mnt/home/nchua/Dropbox/lamina_figures/pre-multi-scatter.png')


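## Hypothesis: average multiplicity is larger among R1-6 than among L2 and AC
Let $x_{i}$ be the mean multiplicity of the $i$-th R1-6 neuron and $y_{j}$ that of the $j$-th L2 or AC neuron.
$$H_{0}: P(x_{i} > y_{j}) \leq 1/2$$
- One-sided Mann-Whitney U test, which compares the two samples by their ranks; the null is rejected when $p \leq 0.001$.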

In [None]:
# One-sided Mann-Whitney U test on per-neuron mean multiplicity:
# R1-6 (R1R4, R2R5, R3R6) versus L2 and AC (LMC_2, centri).
alpha = 0.001  # significance level; reject the null when p <= alpha

r1_6_mn_poly = poly_df.loc[poly_df['pre_type'].isin(['R1R4', 'R2R5', 'R3R6']), 'mn_poly']
l2_ac_mn_poly = poly_df.loc[poly_df['pre_type'].isin(['LMC_2', 'centri']), 'mn_poly']

# alternative='greater': the distribution underlying the R1-6 sample is
# stochastically greater than the one underlying the L2/AC sample.
s, p = mannwhitneyu(r1_6_mn_poly, l2_ac_mn_poly, alternative='greater')
print("###### RESULTS ######")
print(f"Test statistic: {s}, p-value: {p: .2e}")
if p > alpha:
    print("Fail to reject null")
else:
    print("Reject null")

### Bivariate distribution (number of terminals vs. average multiplicity)

In [None]:
# Joint distribution of mean multiplicity (x) and number of presynaptic
# terminals (y) per presynaptic neuron, coloured by presynaptic type, with
# per-type KDE marginals.
pretypes = ['R1R4', 'R2R5', 'R3R6', 'LMC_2', 'centri']
x_var = 'mn_poly'
y_var = 'n_terms'

# Display names for the presynaptic types, applied to the legend entries below
labels = {'R1R4': 'R1/R4', 'R2R5': 'R2/R5', 'R3R6': 'R3/R6', 
          'LMC_2': 'L2', 'centri': 'AC'}

data = poly_df.loc[poly_df['pre_type'].isin(pretypes)]

g = sns.JointGrid(data=data, x=x_var, y=y_var, hue='pre_type', palette=st_cm,
                  hue_order=pretypes, height=3.0)

g.plot_joint(sns.scatterplot)
# The joint-axes legend lists the raw type codes; swap in the display names.
legend = g.ax_joint.get_legend()
if legend is not None:
    for text in legend.get_texts():
        text.set_text(labels.get(text.get_text(), text.get_text()))
g.plot_marginals(sns.kdeplot)
g.set_axis_labels('Multiplicity', '# presynaptic terminals')

if save_figs:
    g.savefig('/mnt/home/nchua/Dropbox/lamina_figures/pre-joint-multi-terms.svg')
    g.savefig('/mnt/home/nchua/Dropbox/lamina_figures/pre-joint-multi-terms.png')



In [None]:
# Per-terminal boolean contact tables. Each keeps ct_df's columns 0, 1, 2 and
# its last column (presumably post_count), followed by one column per
# postsynaptic subtype:
#   any_contact   -> True where the terminal's count for that subtype is > 0
#   multi_contact -> True where the terminal's count for that subtype is > 1
# pd.concat builds new frames, so nothing is written through a slice of ct_df
# (the previous .loc writes into an .iloc selection raised SettingWithCopyWarning).
base_cols = ct_df.iloc[:, [0, 1, 2, -1]]
subtype_counts = ct_df.loc[:, subtypes]

any_contact = pd.concat([base_cols, subtype_counts > 0], axis=1)
multi_contact = pd.concat([base_cols, subtype_counts > 1], axis=1)
display(multi_contact)
    

In [None]:
# Probability of multi-contact for each presynaptic neuron: the fraction of its
# terminals that contact each postsynaptic subtype more than once (mean of the
# boolean multi_contact columns). One row per presynaptic neuron.
# Uses groupby reductions instead of Series.append (removed in pandas 2.0) and
# label-based `[0]` indexing (fails for groups whose index lacks the label 0).
by_neuron = multi_contact.groupby('pre_neuron')
percent_multi = pd.concat(
    [by_neuron[['pre_type', 'pre_om']].first(),
     by_neuron[list(subtypes)].mean()],
    axis=1,
)
# Keep pre_neuron as the first column as well; leave the index unnamed so that
# grouping by 'pre_neuron' stays unambiguous.
percent_multi.insert(0, 'pre_neuron', percent_multi.index)
percent_multi.index.name = None

display(percent_multi)

In [None]:
# For every presynaptic type: mean and population SD (ddof=0) across neurons of
# the multi-contact fraction onto each LMC subtype and AC. Then one histogram
# figure per R1-6 presynaptic type.
these_pre = ['R1R4', 'R2R5', 'R3R6']
these_post = ['LMC_1', 'LMC_2', 'LMC_3', 'LMC_4', 'centri']

for pre_type, rows in percent_multi.groupby('pre_type'):
    # astype(float) guards against object-dtype columns in percent_multi
    fractions = rows.loc[:, these_post].astype(float)
    display(f"{pre_type}->")
    display(fractions.mean())
    display(fractions.std(ddof=0))

for pre_type, rows in percent_multi.groupby('pre_type'):
    if pre_type not in these_pre:
        continue
    g = sns.displot(data=rows.loc[:, these_post].astype(float), palette=st_cm)
    # FacetGrid has no set_title; set() forwards the title to each Axes
    g.set(title=pre_type)

In [None]:
# One histogram per (R1-6 presynaptic type, LMC postsynaptic type) pair, showing
# across neurons the fraction of terminals with multiple contacts onto that LMC.
these_pre = ['R1R4', 'R2R5', 'R3R6']
these_post = ['LMC_1', 'LMC_2', 'LMC_3', 'LMC_4']

r1_6_multi = percent_multi.loc[percent_multi['pre_type'].isin(these_pre)]
for post_type in these_post:
    for pre_type, rows in r1_6_multi.groupby('pre_type'):
        # No hue variable here, so no palette is passed. astype(float) guards
        # against an object-dtype column.
        g = sns.displot(data=rows.astype({post_type: float}), x=post_type)
        # Title each figure; otherwise the 12 plots are indistinguishable
        g.set(title=f'{pre_type} -> {post_type}')