In [1]:
# Link to the author's source repository. Not referenced elsewhere in the visible notebook;
# presumably read by the dashboard template (the name suggests flexdashboard/voila) — TODO confirm.
flex_source_link = "https://github.com/et22"

In [2]:
import seaborn as sns
import pandas as pd
import numpy as np
import ipywidgets as widgets
from scipy.io import loadmat
from IPython.display import display, Markdown, clear_output
from ipywidgets import interact, interactive, interactive_output
import warnings
warnings.filterwarnings("ignore")
import numpy as np
import ipywidgets as widgets
from IPython.display import clear_output, display
from sklearn.cluster import AgglomerativeClustering, KMeans, MiniBatchKMeans
from sklearn.metrics import silhouette_score, silhouette_samples
from sklearn.decomposition import PCA

from scipy.spatial.distance import euclidean
from scipy.cluster.hierarchy import dendrogram
from scipy.ndimage import gaussian_filter1d

from scipy.stats import chi2_contingency
                      
import networkx as nx

from pyvis.network import Network
import pandas as pd

import matplotlib.pyplot as plt
from IPython.display import set_matplotlib_formats
%matplotlib inline
set_matplotlib_formats('svg')

np.random.seed(42)


In [3]:
def make_cont_mtx(cat_row, cat_col, cont, num_rows):
    """Build a num_rows x num_rows matrix of NaN-ignoring means of `cont`.

    Cell (i, j) holds np.nanmean of the entries of `cont` whose row category is
    i + 1 and whose column category is j + 1 (category codes are 1-based).
    Cells with no matching (or only NaN) entries come out as NaN.

    Parameters
    ----------
    cat_row, cat_col : numpy.ndarray
        1-based category codes, same shape as `cont`.
    cont : numpy.ndarray
        Continuous values to average.
    num_rows : int
        Number of categories on each axis.

    Returns
    -------
    numpy.ndarray of shape (num_rows, num_rows)
    """
    codes = np.arange(num_rows) + 1
    mean_mtx = np.zeros((codes.size, codes.size))
    for i, row_code in enumerate(codes):
        row_hit = cat_row == row_code
        for j, col_code in enumerate(codes):
            pair_hit = np.logical_and(row_hit, cat_col == col_code)
            mean_mtx[i, j] = np.nanmean(cont[pair_hit])
    return mean_mtx

def make_prop_mtx(cat_row, cat_col, numer_var, num_rows):
    """Per category pair, the fraction of entries for which `numer_var` is set.

    For each pair of 1-based category codes (i, j), the denominator counts the
    entries with cat_row == i and cat_col == j. The numerator sums `numer_var`
    over those entries, ignoring NaNs.

    Parameters
    ----------
    cat_row, cat_col : numpy.ndarray
        1-based category codes, same shape as `numer_var`.
    numer_var : numpy.ndarray
        Values (typically booleans) summed into the numerator.
    num_rows : int
        Number of categories on each axis.

    Returns
    -------
    prop_mtx : numpy.ndarray (num_rows, num_rows)
        prop_numer / prop_denom; NaN for pairs with no entries.
    prop_numer : numpy.ndarray (num_rows, num_rows)
    prop_denom : numpy.ndarray (num_rows, num_rows)
    """
    codes = np.arange(num_rows) + 1
    prop_numer = np.zeros((codes.size, codes.size))
    prop_denom = np.zeros((codes.size, codes.size))
    for i, row_code in enumerate(codes):
        for j, col_code in enumerate(codes):
            in_pair = np.logical_and(cat_row == row_code, cat_col == col_code)
            selected = numer_var[in_pair]
            prop_numer[i, j] = np.nansum(selected)
            prop_denom[i, j] = np.size(selected)

    # Empty pairs give 0/0: keep the NaN result but do not emit a RuntimeWarning for it.
    with np.errstate(invalid='ignore', divide='ignore'):
        prop_mtx = np.divide(prop_numer, prop_denom)
    return prop_mtx, prop_numer, prop_denom

def get_checkbox_inclusions(boxes, pre_tlabels, post_tlabels):
    """Mask of pairs whose pre- and post-synaptic categories are both checked.

    boxes[k] stands for category code k + 1. For every unchecked box, pairs where
    either side carries that code are dropped.

    Parameters
    ----------
    boxes : sequence of objects with a boolean `.value` (checkbox widgets)
    pre_tlabels, post_tlabels : numpy.ndarray of 1-based category codes

    Returns
    -------
    numpy.ndarray of bool, shaped like `pre_tlabels`.
    """
    keep = np.ones_like(pre_tlabels, dtype='bool')
    excluded_codes = [pos + 1 for pos, box in enumerate(boxes) if not box.value]
    for code in excluded_codes:
        keep = np.logical_and.reduce((keep, pre_tlabels != code, post_tlabels != code))
    return keep
            
def get_subset(change):
    """Apply the current widget filters to the selected CCG metric.

    Reads ccg_selection, maxima_selection, the lag/std/area range sliders and the
    layer, putative-type and functional-type checkboxes. Keeps only CCGs that pass
    every filter. `change` (the widget event) is not used.

    Returns
    -------
    ccg_curr : copy of ccg_data[metric][0][0] in which every field except
        'config', 'cluster' and 'ccg_control' is reduced to the kept rows
    maxima : the maxima_selection value ('peak' or 'trough')
    metric : the ccg_selection value ('ccg', 'ccgpren' or 'ccgn')
    subset : boolean mask over the unfiltered CCGs of that metric
    """
    metric = ccg_selection.value
    maxima = maxima_selection.value
    lag_min, lag_max = lag_selection.value
    std_min, std_max = std_selection.value
    area_min, area_max = area_selection.value

    def _in_range(values, low, high):
        # inclusive at both ends, matching the slider semantics
        return np.logical_and(values >= low, values <= high)

    ccg_curr = ccg_data[metric][0][0].copy()
    ccg_fields = ccg_data['ccg'][0][0].dtype.names

    noise_sd = ccg_curr['noise_std2']
    noise_mean = ccg_curr['noise_mean2']
    masks = (
        _in_range(ccg_curr[maxima + '_lag'], lag_min, lag_max),
        # maxima must lie between k_min and k_max noise SDs above the noise mean
        _in_range(ccg_curr[maxima + 's'], std_min * noise_sd + noise_mean, std_max * noise_sd + noise_mean),
        _in_range(ccg_curr['area'], area_min, area_max),
        get_checkbox_inclusions(cl_checkboxes[1:], ccg_curr['pre_cl'], ccg_curr['post_cl']),
        get_checkbox_inclusions(ct_checkboxes[1:], ccg_curr['pre_ct'], ccg_curr['post_ct']),
        get_checkbox_inclusions(sc_checkboxes[1:], ccg_curr['pre_sc'], ccg_curr['post_sc']),
    )
    subset = np.logical_and.reduce(masks)

    keep_rows = np.squeeze(subset)
    per_ccg_fields = [name for name in ccg_fields if name not in ('config', 'cluster', 'ccg_control')]
    for field in per_ccg_fields:
        ccg_curr[field] = ccg_curr[field][keep_rows]

    return ccg_curr, maxima, metric, subset

In [4]:
## clustering functions 
def cluster_flow(ccgs, max_samples=10000, n_clusters=3, smooth_sigma=2, var_threshold=.95):
    """Subsample, smooth, rescale, PCA-reduce and hierarchically cluster CCGs.

    Parameters
    ----------
    ccgs : numpy.ndarray, shape (n_ccgs, n_lags)
        One CCG per row.
    max_samples : int
        At most this many CCGs are clustered, for efficiency.
    n_clusters : int
        Number of clusters passed to AgglomerativeClustering.
    smooth_sigma : float
        Width (in lag bins) of the Gaussian smoothing applied along each CCG.
    var_threshold : float
        Keep the smallest number of principal components whose cumulative
        explained-variance ratio exceeds this value (all components if none does).

    Returns
    -------
    inc_idx : numpy.ndarray of int
        Rows of `ccgs` that were clustered, in the order of the outputs below.
    smoothed : numpy.ndarray, shape (len(inc_idx), n_lags)
        Smoothed, per-lag min-max scaled CCGs that were clustered.
    cluster_obj : fitted AgglomerativeClustering
        Fitted with the full tree and merge distances, so plot_dendrogram can use it.
    """
    n_ccgs = ccgs.shape[0]
    num_elems = min(n_ccgs, max_samples)
    # Sample WITHOUT replacement: with replacement, some CCGs would be duplicated and
    # others silently dropped even when all of them fit in the budget.
    inc_idx = np.random.choice(n_ccgs, size=num_elems, replace=False)
    ccgs = ccgs[inc_idx, :]

    # Smooth along the lag axis (axis 1). Smoothing along axis 0 would blur
    # unrelated, randomly ordered CCGs into each other.
    smoothed = gaussian_filter1d(ccgs, smooth_sigma, axis=1)
    # Min-max scale each lag column; a constant column has zero range, so divide by 1 there.
    span = np.ptp(smoothed, 0)
    span = np.where(span == 0, 1, span)
    smoothed = (smoothed - np.min(smoothed, 0)) / span

    # PCA cannot have more components than min(n_samples, n_features).
    pca_obj = PCA(n_components=min(smoothed.shape))
    x_new = np.ascontiguousarray(pca_obj.fit_transform(smoothed))
    cum_var = np.cumsum(pca_obj.explained_variance_ratio_)
    # smallest k with cum_var[k-1] > var_threshold, capped at the number of components
    k_comp = min(int(np.searchsorted(cum_var, var_threshold, side='right')) + 1, x_new.shape[1])

    cluster_obj = AgglomerativeClustering(n_clusters=n_clusters, compute_full_tree=True, compute_distances=True)
    cluster_obj.fit(x_new[:, :k_comp])

    return inc_idx, smoothed, cluster_obj

def find_number_of_clusters(ccgs, range_n_clusters, weights=None):
    """Score candidate cluster counts by their (optionally weighted) mean silhouette.

    Parameters
    ----------
    ccgs : array-like, shape (n_samples, n_features)
        Data to cluster.
    range_n_clusters : iterable of int
        Candidate numbers of clusters.
    weights : array-like of shape (n_samples,), optional
        Per-sample weights for averaging the silhouette values; None gives the plain mean.

    Returns
    -------
    list
        One score per candidate k. k <= 1 scores 0, because the silhouette is
        undefined for a single cluster.
    """
    scores = []
    for k in range_n_clusters:
        if k <= 1:
            scores.append(0)
            continue
        cluster_labels = AgglomerativeClustering(n_clusters=k).fit_predict(ccgs)
        silhouette_samps = silhouette_samples(ccgs, cluster_labels)
        scores.append(np.average(silhouette_samps, weights=weights))
    return scores

 




In [5]:
## clustering plot functions
def plot_dendrogram(model, **kwargs):
    """Draw a scipy dendrogram for a fitted sklearn AgglomerativeClustering model.

    The model must expose `children_`, `labels_` and `distances_` (i.e. it was
    fitted with merge distances computed). Extra keyword arguments are passed
    to scipy.cluster.hierarchy.dendrogram.
    """
    n_leaves = len(model.labels_)
    merges = model.children_
    # A scipy linkage row needs the number of original samples under each new node.
    sizes = np.zeros(merges.shape[0])
    for node, (left, right) in enumerate(merges):
        sizes[node] = sum(
            1 if child < n_leaves else sizes[child - n_leaves]
            for child in (left, right)
        )

    linkage_matrix = np.column_stack([merges, model.distances_, sizes]).astype(float)

    dendrogram(linkage_matrix, **kwargs)
    plt.xlabel("Number of points in node (or index of point if no parenthesis).")
    plt.show()
    
## plot cluster templates
def plot_cluster_templates(ccgs, idxes, lag_start=-10):
    """Plot the mean CCG ("template") of each cluster, one panel per cluster.

    Parameters
    ----------
    ccgs : numpy.ndarray, shape (n_ccgs, n_lags)
    idxes : numpy.ndarray
        Cluster label of each row of `ccgs`.
    lag_start : int
        Lag of the first column; the x axis runs lag_start, lag_start + 1, ...
    """
    uq_idxes = np.unique(idxes)
    for panel, idx in enumerate(uq_idxes, start=1):
        template = np.mean(ccgs[idxes == idx], axis=0)
        # The panel position comes from the enumeration, not the label value, so labels
        # need not be 1..K (a label of 0 would be an invalid subplot index).
        plt.subplot(1, uq_idxes.size, panel)
        plt.plot(range(lag_start, lag_start + template.size), template, 'o-')
        plt.title("cluster " + str(idx))
        plt.xlabel("tau")
    plt.tight_layout()
    plt.show()
                
## plot cluster examples
def plot_cluster_examples(ccgs, idxes, n_examples=3, lag_start=-10):
    """Plot randomly drawn example CCGs of each cluster in a grid.

    Rows are examples and columns are clusters. Examples are drawn with replacement,
    so a cluster with fewer than `n_examples` members still fills its column.

    Parameters
    ----------
    ccgs : numpy.ndarray, shape (n_ccgs, n_lags)
    idxes : numpy.ndarray
        Cluster label of each row of `ccgs`.
    n_examples : int
        Examples (grid rows) per cluster.
    lag_start : int
        Lag of the first column of `ccgs`.
    """
    uq_idxes = np.unique(idxes)
    n_cols = uq_idxes.size
    for col, idx in enumerate(uq_idxes):
        ex_ccgs = ccgs[idxes == idx]
        ex_idxes = np.random.choice(ex_ccgs.shape[0], size=n_examples, replace=True)
        for row, example in enumerate(ex_idxes):
            # 1-based, row-major subplot index; depends on the number of columns, not on 3
            plt.subplot(n_examples, n_cols, row * n_cols + col + 1)
            plt.plot(range(lag_start, lag_start + ex_ccgs[example].size), ex_ccgs[example], 'o-')
            if row == 0:
                plt.title("cluster " + str(idx))
            if row == n_examples - 1:
                plt.xlabel("tau")
    plt.tight_layout()
    plt.show()
                
## plot cluster heatmaps
def plot_cluster_heatmaps(idxes, ccg_curr, pre_lab, post_lab, num_rows, row_labels):
    """For each cluster, heatmap the fraction of each (pre, post) category pair's CCGs in it.

    Parameters
    ----------
    idxes : numpy.ndarray
        Cluster label of each CCG, in the same order as the fields of `ccg_curr`.
    ccg_curr : record
        Its `pre_lab` / `post_lab` fields hold 1-based category codes.
    pre_lab, post_lab : str
        Field names, e.g. 'pre_cl' / 'post_cl'.
    num_rows : int
        Number of categories.
    row_labels : list of str
        Category names used as tick labels.
    """
    pre_codes = np.squeeze(ccg_curr[pre_lab])
    post_codes = np.squeeze(ccg_curr[post_lab])
    for idx in np.unique(idxes):
        mtx, _, _ = make_prop_mtx(pre_codes, post_codes, np.squeeze(idxes == idx), num_rows)
        prop_df = pd.DataFrame(mtx, columns=row_labels, index=row_labels)

        sns.heatmap(prop_df, cmap="vlag", annot=True)
        # without a title the per-cluster figures are indistinguishable
        plt.xlabel('post-syn. neuron')
        plt.ylabel('pre-syn. neuron')
        plt.title("prop. in cluster " + str(idx))
        plt.show()
        
## plot cluster network
def plot_cluster_network(idxes, ccg_curr, n_edges=400):
    """Draw a directed graph of a random sample of pre -> post unit connections.

    Parameters
    ----------
    idxes : numpy.ndarray
        Cluster label per CCG. Currently unused; colouring edges by cluster is not implemented.
    ccg_curr : record
        Has 'pre_id' and 'post_id' fields with one unit id per CCG.
    n_edges : int
        Number of CCGs sampled (with replacement) as edges; duplicates collapse in the DiGraph.
    """
    # ravel so each element is a scalar even when MATLAB stored the ids as an (n, 1) column
    pre_ids = np.ravel(ccg_curr['pre_id'])
    post_ids = np.ravel(ccg_curr['post_id'])
    inc_idxes = np.random.choice(pre_ids.shape[0], size=n_edges, replace=True)

    G = nx.DiGraph()
    G.add_edges_from((int(pre_ids[i]), int(post_ids[i])) for i in inc_idxes)
    nx.draw(G)
    plt.title(str(G.number_of_nodes()) + " nodes, " + str(G.number_of_edges()) + " edges")
    plt.show()




In [6]:
ccg_data = loadmat('int_output/combined_ccg_data.mat', chars_as_strings=True)

In [7]:
ccg_data = ccg_data['ccg_data'][0][0]

In [8]:
# Display names for the categorical codes: code k in the data maps to list index k - 1.
# cl = cortical layer, sc = functional type, ct = putative cell type.
labels = dict(
    cl=["2/3", "4a/b", "4cα", "4cβ", "5", "6", "WM"],
    sc=["Complex", "Simple"],
    ct=["AS", "FS", "RM", "RL"],
)

# Introduction


### 

# Correlated Neural Activity in Macaque V1
---

Welcome to the interactive analysis tool for *Functional Connectivity of Neurons within Single Cortical Columns Measured with Neuropixels*! 

This tool lets the community explore a large dataset of Neuropixels recordings from macaque V1 and test hypotheses about functional connectivity among distinct cortical layers and among functional and putative cell types.
## Tutorial 

## Complete Documentation 

### Input
####  Metrics
#### Exclusion Criteria 
<center>
<img src='https://static.wixstatic.com/media/2997bf_9718a381f05f4cc9a721428ad26c6639~mv2.jpeg/v1/crop/x_35,y_47,w_1223,h_614/fill/w_388,h_195,al_c,q_80,usm_0.66_1.00_0.01/Image%207-17-20%20at%203_45%20PM.webp'>
</center>

# Input Selection

## Column

In [9]:
# Input widgets for the dashboard; get_subset and the on_* callbacks read their values.
default_label = widgets.Label(value="Default settings:")
default_selection = widgets.RadioButtons(
    options=[('s.d. (smith/kohn)', 'sd'),('s.d.+area (ours)', 'sd_area')],
    value='sd_area',
    disabled=False
)
default_button = widgets.Button(
    description='reset to defaults',
    disabled=False,
    button_style='', # 'success', 'info', 'warning', 'danger' or ''
    icon='check' # (FontAwesome names without the `fa-` prefix)
)


ccg_label = widgets.Label(value="Metric:")
ccg_selection = widgets.RadioButtons(
    options=[('Efficacy', 'ccg'),('Contribution', 'ccgpren'), ('Geom. mean','ccgn')],
    value='ccg',
    disabled=False
)
maxima_label = widgets.Label(value="Maxima:")
maxima_selection = widgets.RadioButtons(
    options=[('Peaks', 'peak'), ('Troughs', 'trough')],
    value='peak',
    disabled=False
)
std_label = widgets.Label(value="Maxima > or < k std. + mean noise:")
std_selection = widgets.FloatRangeSlider(
    value=[3, 10],
    min=-10.0,
    max=10.0,
    step=0.1,
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.1f',
)
lag_label = widgets.Label(value="Tau:")
lag_selection = widgets.IntRangeSlider(
    value=[0, 10],
    min=0,
    max=10,
    step=1,
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d',
)
area_label = widgets.Label(value="CCG integral over |Tau|=0-10:")
area_selection = widgets.FloatRangeSlider(
    value=[.05, .5],
    min=-.5,
    max=.5,
    step=.01,
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.2f',
)


def _make_checkboxes(names, unchecked=()):
    """One indented Checkbox per category name; names listed in `unchecked` start unticked.

    Checkbox k (0-based) stands for category code k + 1, as get_checkbox_inclusions expects.
    """
    return [widgets.Checkbox(value=name not in unchecked, indent=True, description=name)
            for name in names]


# Each *_checkboxes list starts with its title Label; get_subset skips it with [1:].
cl_label = widgets.Label(value="Layer:")
# Select white matter by name rather than by position so reordering labels['cl'] stays safe.
cl_checkboxes = [cl_label] + _make_checkboxes(labels['cl'], unchecked=("WM",))
cl_selection = widgets.VBox(children=cl_checkboxes)

ct_label = widgets.Label(value="Putat. type:")
ct_checkboxes = [ct_label] + _make_checkboxes(labels['ct'])
ct_selection = widgets.VBox(children=ct_checkboxes)

sc_label = widgets.Label(value="Func. type:")
sc_checkboxes = [sc_label] + _make_checkboxes(labels['sc'])
sc_selection = widgets.VBox(children=sc_checkboxes)
ex_selection = widgets.HBox([cl_selection, ct_selection, sc_selection])

plot_label = widgets.Label(value="Plot By:")
plot_selection = widgets.RadioButtons(
    options=[('Layer', 'cl'), ('Putative cell type', 'ct'), ('Functional cell type', 'sc')],
    value='cl',
    disabled=False
)

slide_ex_label = widgets.Label(value="Sliders determine which CCGs to include in the analysis, e.g., selecting a range for Tau of 1-10 means CCGs with Tau = 0 are excluded.")
check_ex_label = widgets.Label(value="Checkboxes determine which neurons to include in the analysis, e.g., checking WM means neurons in white matter will be included in figures.")





### Metric Input

In [10]:
# Metric controls: CCG metric, peak vs. trough, and the category to plot by.
# The default-settings controls (default_label, default_selection, default_button) are not shown yet.
widgets.HBox([
    widgets.VBox([title, control])
    for title, control in (
        (ccg_label, ccg_selection),
        (maxima_label, maxima_selection),
        (plot_label, plot_selection),
    )
])


HBox(children=(VBox(children=(Label(value='Metric:'), RadioButtons(options=(('Efficacy', 'ccg'), ('Contributio…

### CCG Subset

In [11]:
# Filters: range sliders (which CCGs are kept) above the category checkboxes (which neurons are kept).
widgets.VBox([
    widgets.VBox([slide_ex_label,
                  std_label, std_selection,
                  lag_label, lag_selection,
                  area_label, area_selection]),
    widgets.VBox([check_ex_label, ex_selection]),
])

VBox(children=(VBox(children=(Label(value='Sliders determine which CCGs to include in the analysis, e.g., sele…

In [12]:
# Per-unit table from ccg_data['ccg']['cluster']; layer_picker_output reads both `data` and `df`.
data = ccg_data['ccg'][0][0]['cluster'][0][0]
df = pd.DataFrame({
    column: np.squeeze(data[field])
    for column, field in (
        ('session', 'Cluster_session'),
        ('celllayer', 'Cluster_celllayer'),
        ('cell depth', 'Cluster_celldepth'),
        ('celltype', 'Cluster_celltype'),
        ('MI_max', 'Cluster_MI_max'),
        ('simpcomp', 'Cluster_simpcomp'),
    )
})

def on_layer_change(change):
    depth = data['Cluster_celldepth']
    new_layers = np.sum(depth<l23.value,depth<l4ab.value,depth<l4ca.value,depth<l4cb.value,depth<l5.value,depth<l6.value,depth<lWM.value)
    new_layers[new_layers==0] = float("NaN")
    
def layer_picker_output(ses_num):
    """Show layer-boundary sliders next to the depth profiles of one recording session.

    There is one IntSlider (0-3000) per layer in labels['cl']. Each starts at the
    largest 'Cluster_celldepth' among the session's units assigned to that layer,
    or 0 if the layer has no units in the session. Moving a slider redraws three
    things against cell depth: KDEs of units by functional type, KDEs by putative
    cell type, and one horizontal line per layer boundary.

    Parameters
    ----------
    ses_num : int
        Value of 'Cluster_session' selecting the session.
    """
    df_ses = df[df["session"] == ses_num]
    depth = np.squeeze(data['Cluster_celldepth'])
    in_session = np.squeeze(data['Cluster_session'] == ses_num)
    layer_codes = np.squeeze(data['Cluster_celllayer'])

    # keyword names passed to plot_layer_lines, in the order of labels['cl']
    slider_keys = ['l23', 'l4ab', 'l4ca', 'l4cb', 'l5', 'l6', 'lWM']
    sliders = {}
    for code, (key, name) in enumerate(zip(slider_keys, labels['cl']), start=1):
        layer_depths = depth[np.logical_and(in_session, layer_codes == code)]
        # np.max raises on an empty array, i.e. when a layer has no units in this session
        start_depth = int(np.max(layer_depths)) if layer_depths.size else 0
        sliders[key] = widgets.IntSlider(value=start_depth, min=0, max=3000, description=name,
                                         continuous_update=False)
    ui = widgets.VBox(list(sliders.values()))

    def _tidy_axis(ax):
        """Hide the x ticks, the x label and the right/top/bottom spines of `ax`."""
        ax.set_xticks([])
        ax.set_xlabel('')
        for side in ('right', 'top', 'bottom'):
            ax.spines[side].set_visible(False)

    def plot_layer_lines(**bounds):
        """Redraw the depth profiles; `bounds` maps each slider key to its current depth."""
        plt.figure(2)
        # functional-type KDEs; widening the x range squeezes them into the left half
        ax1 = sns.kdeplot(data=df_ses, y="cell depth", hue="simpcomp", fill=False, cut=0, bw_adjust=.4)
        x_lo, x_hi = ax1.get_xlim()
        ax1.set_xlim((x_lo, 2 * x_hi))
        ax1.legend(labels['sc'], frameon=False, loc='upper left')
        _tidy_axis(ax1)

        # putative-cell-type KDEs on a twin axis, squeezed into the right half
        ax2 = ax1.twiny()
        sns.kdeplot(data=df_ses, y="cell depth", hue="celltype", fill=False, ax=ax2, cut=0, bw_adjust=.4)
        x_lo, x_hi = ax2.get_xlim()
        ax2.set_xlim((x_lo - x_hi, x_hi))
        ax2.legend(labels['ct'], frameon=False, loc='upper center')
        _tidy_axis(ax2)

        # one horizontal line per layer boundary spanning the third axis' width
        ax3 = ax1.twiny()
        x_lo, x_hi = ax3.get_xlim()
        layer_df = pd.DataFrame({
            'x': [x for _ in slider_keys for x in (x_lo, x_hi)],
            'y': [bounds[key] for key in slider_keys for _ in range(2)],
            'layer_lab': [name for name in labels['cl'] for _ in range(2)],
        })
        sns.lineplot(data=layer_df, x='x', y='y', hue='layer_lab', ax=ax3, palette="crest")
        ax3.legend(labels['cl'], frameon=False, loc='upper right')
        _tidy_axis(ax3)
        y_lo, y_hi = ax3.get_ylim()
        ax3.set_ylim(y_lo, y_hi * 1.5)

        plt.show()

    output = interactive_output(plot_layer_lines, sliders)
    display(widgets.HBox([ui, output]))

### Ses. 1 Layer

In [13]:
# Interactive layer-boundary picker for recording session 1.
layer_picker_output(1)



HBox(children=(VBox(children=(IntSlider(value=2210, continuous_update=False, description='2/3', max=3000), Int…

### Ses. 2 Layer

In [14]:
# Interactive layer-boundary picker for recording session 2.
layer_picker_output(2)


HBox(children=(VBox(children=(IntSlider(value=2210, continuous_update=False, description='2/3', max=3000), Int…

### Ses. 3 Layer

In [15]:
# Interactive layer-boundary picker for recording session 3.
layer_picker_output(3)

HBox(children=(VBox(children=(IntSlider(value=2210, continuous_update=False, description='2/3', max=3000), Int…

In [16]:
# Pre/post category selectors for the pairwise tests. Options are (name, 1-based code)
# for the category type chosen in plot_selection; on_plot_type_change rebuilds them.
pre_pair_label = widgets.Label(value="Pre:")
pre_opt = [(name, code) for code, name in enumerate(labels[plot_selection.value], start=1)]
pre_pair_selection = widgets.RadioButtons(options=pre_opt, value=pre_opt[0][1], disabled=False)

post_pair_label = widgets.Label(value="Post:")
post_opt = [(name, code) for code, name in enumerate(labels[plot_selection.value], start=1)]
post_pair_selection = widgets.RadioButtons(options=post_opt, value=post_opt[0][1], disabled=False)

In [17]:
# Output areas that the callbacks defined below draw into.
# CCG examples (on_value_change) / SD-vs-area joint histogram (on_metric_change)
out0 = widgets.Output()
out0s = widgets.Output()

# Lag: mean-lead heatmap (on_value_change) / lag histogram for the selected pair
# (on_pairwise_change) / pair selectors (on_plot_type_change)
out = widgets.Output()
outs = widgets.Output()
outsi = widgets.Output()

# Maxima: mean-maxima heatmap / maxima histogram for the pair / pair selectors
out2 = widgets.Output()
out2s = widgets.Output()
out2si = widgets.Output()

# Direction: "prop. j leads k" heatmap / signed-lag histogram for the pair / pair selectors
out3 = widgets.Output()
out3s = widgets.Output()
out3si = widgets.Output()

# Significance: proportion-significant heatmap / chi^2 test print-out / pair selectors
out4 = widgets.Output()
out4s = widgets.Output()
out4si = widgets.Output()

# Clustering figures (on_value_change)
out5 = widgets.Output()


def _plot_example_ccgs(ex_ccgs, metric):
    """Plot four CCGs drawn at random (with replacement) in a 2x2 grid; x axis is lag from -10."""
    if ex_ccgs.shape[0] == 0:
        # np.random.choice cannot sample from an empty population
        print("No CCGs pass the current selection.")
        return
    ex_idxes = np.random.choice(ex_ccgs.shape[0], size=4, replace=True)
    for panel, example in enumerate(ex_idxes):
        plt.subplot(2, 2, panel + 1)
        plt.plot(range(-10, -10 + ex_ccgs[example].size), ex_ccgs[example], 'o-')
        if panel > 1:
            plt.xlabel("tau")
        if panel in (0, 2):
            plt.ylabel(metric)
    plt.tight_layout()
    plt.show()


def _plot_mean_heatmap(ccg_curr, pre_lab, post_lab, field, num_rows, row_labels, title, cmap=None):
    """Heatmap of the mean of ccg_curr[field] for every (pre, post) category pair."""
    cont_mtx = make_cont_mtx(ccg_curr[pre_lab], ccg_curr[post_lab], ccg_curr[field], num_rows)
    sns.heatmap(pd.DataFrame(cont_mtx, columns=row_labels, index=row_labels), cmap=cmap, annot=True)
    plt.xlabel('post-syn. neuron')
    plt.ylabel('pre-syn. neuron')
    plt.title(title)
    plt.show()


def on_value_change(change):
    """Recompute the filtered CCG subset and redraw every summary panel.

    Observer for the metric/maxima/plot-type selectors, the range sliders and the
    category checkboxes. `change` is only forwarded. Redraws out0 (examples),
    out (mean lead), out2 (mean maxima), out3 (prop. j leads k), out4 (prop.
    significant) and out5 (clustering), then refreshes the pairwise panels through
    on_pairwise_change.
    """
    ccg_curr, maxima, metric, subset = get_subset(change)
    plot_type = plot_selection.value
    row_labels = labels[plot_type]
    num_rows = len(row_labels)
    pre_lab = "pre_" + plot_type
    post_lab = "post_" + plot_type
    n_sel = str(ccg_curr[pre_lab].size)

    with out0:
        out0.clear_output(wait=True)
        _plot_example_ccgs(ccg_curr['ccgs'], metric)

    with out:
        out.clear_output(wait=True)
        _plot_mean_heatmap(ccg_curr, pre_lab, post_lab, maxima + '_lag', num_rows, row_labels,
                           "mean lead (ms), n = " + n_sel)

    with out2:
        out2.clear_output(wait=True)
        # maxima are values of the selected metric, not times, so no "(ms)" unit
        _plot_mean_heatmap(ccg_curr, pre_lab, post_lab, maxima + 's', num_rows, row_labels,
                           "mean maxima, n = " + n_sel, cmap='YlOrBr')

    with out3:
        out3.clear_output(wait=True)
        # The Tau slider's minimum is 0, so every kept CCG has lag > -1 and numer[j, k]
        # counts all kept CCGs with pre = j and post = k. Use the selected maxima's lag field.
        _, numer, _ = make_prop_mtx(ccg_curr[pre_lab], ccg_curr[post_lab],
                                    ccg_curr[maxima + '_lag'] > -1, num_rows)
        with np.errstate(invalid='ignore', divide='ignore'):
            prop_mtx = numer / (numer + numer.T)
        sns.heatmap(pd.DataFrame(prop_mtx, columns=row_labels, index=row_labels),
                    cmap='vlag', center=.5, annot=True)
        plt.xlabel('neuron k')
        plt.ylabel('neuron j')
        plt.title("prop. j leads k, n = " + n_sel)
        plt.show()

    with out4:
        out4.clear_output(wait=True)
        # fraction of ALL CCGs of each pair (before filtering) that pass the current filters
        ccg_all = ccg_data[metric][0][0]
        mtx, _, _ = make_prop_mtx(ccg_all[pre_lab], ccg_all[post_lab], subset, num_rows)
        sns.heatmap(pd.DataFrame(mtx, columns=row_labels, index=row_labels), cmap='YlOrBr', annot=True)
        plt.xlabel('neuron k')
        plt.ylabel('neuron j')
        plt.title("prop. sig, n = " + n_sel + "/" + str(ccg_all[pre_lab].size))
        plt.show()

    with out5:
        out5.clear_output(wait=True)
        ccgs = ccg_curr['ccgs']
        if ccgs.shape[0] < 3:
            # cluster_flow asks for 3 clusters by default, which needs at least 3 CCGs
            print("Too few CCGs selected to cluster.")
        else:
            inc_idx, smoothed, cluster_obj = cluster_flow(ccgs)
            # keep only the clustered CCGs so every per-CCG field lines up with cluster_obj.labels_
            for field in ccg_data['ccg'][0][0].dtype.names:
                if field not in ('config', 'cluster', 'ccg_control'):
                    ccg_curr[field] = ccg_curr[field][inc_idx]
            cluster_ids = cluster_obj.labels_ + 1
            plot_cluster_examples(smoothed, cluster_ids)
            plot_cluster_templates(smoothed, cluster_ids)
            plot_dendrogram(cluster_obj, truncate_mode='level', p=3)
            # follow the "Plot By" choice like every other heatmap
            plot_cluster_heatmaps(cluster_ids, ccg_curr, pre_lab, post_lab, num_rows, row_labels)

    on_pairwise_change(change)

def on_metric_change(change):
    """Redraw the joint histogram of maxima height (in noise SDs) against CCG area.

    Observer for ccg_selection and maxima_selection; `change` is unused. Uses the
    unfiltered CCGs of the selected metric and shows only finite points with
    0 <= SDs <= 10 and -0.1 <= area <= 0.1.
    """
    ccg_all = ccg_data[ccg_selection.value][0][0].copy()
    peak_vals = ccg_all[maxima_selection.value + 's']

    sds_above = np.squeeze((peak_vals - ccg_all['noise_mean2']) / ccg_all['noise_std2'])
    area = np.squeeze(ccg_all['area'])

    bad = np.logical_or.reduce((np.isnan(sds_above), np.isnan(area), np.isinf(sds_above), np.isinf(area)))
    in_x = np.logical_and(sds_above <= 10, sds_above >= 0)
    in_y = np.logical_and(area >= -.1, area <= .1)
    keep = np.logical_and.reduce((~bad, in_x, in_y))

    with out0s:
        out0s.clear_output(wait=True)
        grid = sns.jointplot(x=sds_above[keep], y=area[keep], kind='hist', dropna=True,
                             xlim=(0, 10), ylim=(-.1, .1))
        grid.set_axis_labels("x SDs above noise mean", "ccg area")
        plt.show()
        
def on_plot_type_change(change):
    """Rebuild the pre/post pair selectors for the newly chosen category type.

    Resets each selector to the first category and swaps in that type's options.
    Redisplays both selectors in outsi, out2si, out3si and out4si, then reruns
    on_pairwise_change.
    """
    names = labels[plot_selection.value]
    for selector in (pre_pair_selection, post_pair_selection):
        opts = [(name, code) for code, name in enumerate(names, start=1)]
        selector.value = opts[0][1]
        selector.options = opts

    for pane in (outsi, out2si, out3si, out4si):
        with pane:
            pane.clear_output(wait=True)
            display(widgets.HBox([widgets.VBox([pre_pair_label, pre_pair_selection]),
                                  widgets.VBox([post_pair_label, post_pair_selection])]))

    on_pairwise_change(change)

def on_pairwise_change(change):
    """Redraw the pairwise comparisons for the selected pre/post categories.

    Let A = pre_pair_selection.value and B = post_pair_selection.value (1-based
    codes of the current plot type). Filtered CCGs are tagged A->B ("pre->post"),
    B->A ("post->pre") or "other". Draws:

    - the lag histogram by tag (outs);
    - the maxima histogram by tag (out2s);
    - the lags of A/B pairs with A->B lags negated and zero lags dropped (out3s);
    - a chi^2 test of significant vs. not significant counts for A->B against B->A (out4s).
    """
    ccg_curr, maxima, metric, _ = get_subset(change)
    plot_type = plot_selection.value
    pre_lab = "pre_" + plot_type
    post_lab = "post_" + plot_type
    pre_code = pre_pair_selection.value
    post_code = post_pair_selection.value

    pre_to_post = np.squeeze(np.logical_and(ccg_curr[pre_lab] == pre_code, ccg_curr[post_lab] == post_code))
    post_to_pre = np.squeeze(np.logical_and(ccg_curr[post_lab] == pre_code, ccg_curr[pre_lab] == post_code))
    # Each CCG appears twice: the first copy is tagged 1 if A->B, the second 2 if B->A, otherwise 0.
    conx_labels = {1: "pre->post", 2: "post->pre", 0: "other"}
    connection = pd.Categorical(np.append(pre_to_post, 2 * post_to_pre)).rename_categories(conx_labels)

    with outs:
        outs.clear_output(wait=True)
        lag = np.squeeze(ccg_curr[maxima + '_lag'])
        lead_df = pd.DataFrame({'lead': np.append(lag, lag), 'connection': connection})
        sns.histplot(data=lead_df, x='lead', hue="connection", stat="density", common_norm=False, fill=False)
        plt.show()

    with out2s:
        out2s.clear_output(wait=True)
        peak = np.squeeze(ccg_curr[maxima + 's'])
        maxima_df = pd.DataFrame({maxima: np.append(peak, peak), 'connection': connection})
        sns.histplot(data=maxima_df, x=maxima, hue="connection", stat="density", common_norm=False, fill=False)
        plt.show()

    with out3s:
        out3s.clear_output(wait=True)
        lags = ccg_curr[maxima + '_lag']
        # negate A->B lags so the two directions fall on opposite sides of zero; zero lags are dropped
        signed_lags = np.append(-1 * lags[pre_to_post], lags[post_to_pre])
        signed_lags = signed_lags[signed_lags != 0]
        sns.histplot(x=signed_lags, stat="density", fill=False, bins=21)
        plt.show()

    with out4s:
        out4s.clear_output(wait=True)
        ccg_all = ccg_data[metric][0][0]
        all_pre_to_post = np.squeeze(np.logical_and(ccg_all[pre_lab] == pre_code, ccg_all[post_lab] == post_code))
        all_post_to_pre = np.squeeze(np.logical_and(ccg_all[post_lab] == pre_code, ccg_all[pre_lab] == post_code))

        n_sig_fwd = np.sum(pre_to_post)
        n_sig_rev = np.sum(post_to_pre)
        # rows: pre->post, post->pre; columns: significant (passes the filters), not significant
        sig_notsig = np.array([[n_sig_fwd, np.sum(all_pre_to_post) - n_sig_fwd],
                               [n_sig_rev, np.sum(all_post_to_pre) - n_sig_rev]], dtype=float)
        try:
            chi2, p, dof, expctd = chi2_contingency(sig_notsig)
        except ValueError as err:
            # raised e.g. when the pair has no CCGs (a zero row/column in the table)
            print("Chi^2 test not applicable: " + str(err))
        else:
            print("Chi^2 test: chi^2 = " + str(chi2) + " p = " + str(p))

def on_click_reset(change):
    """Reset the filter widgets to the preset chosen in default_selection.

    'sd_area': Efficacy metric ('ccg'), area in [0.05, 0.5], maxima 3-10 SDs.
    'sd': Geom. mean metric ('ccgn'), area in [-0.5, 0.5], maxima 5-10 SDs.
    Both presets plot by layer, use peaks and allow Tau 0-10. Widgets are assigned
    in a fixed order, and each assignment triggers that widget's observers.

    Raises
    ------
    ValueError
        If default_selection holds an unknown preset name.
    """
    presets = {
        "sd_area": ('ccg', (.05, .5), (3, 10)),
        "sd": ('ccgn', (-.5, .5), (5, 10)),
    }
    choice = default_selection.value
    if choice not in presets:
        raise ValueError("unknown default setting: " + repr(choice))
    metric, area_range, std_range = presets[choice]

    plot_selection.value = 'cl'
    ccg_selection.value = metric
    maxima_selection.value = 'peak'
    area_selection.value = area_range
    lag_selection.value = (0, 10)
    std_selection.value = std_range

# Observer wiring. The registration order matches the original cell: on plot_selection,
# on_value_change comes before on_plot_type_change; on ccg/maxima_selection,
# on_value_change comes before on_metric_change.
for cbox in ct_checkboxes + cl_checkboxes + sc_checkboxes:  # index 0 of each list is its title Label
    cbox.observe(on_value_change, names="value")

for control in (plot_selection, ccg_selection, maxima_selection,
                area_selection, lag_selection, std_selection):
    control.observe(on_value_change, names="value")

plot_selection.observe(on_plot_type_change, names="value")

for selector in (pre_pair_selection, post_pair_selection):
    selector.observe(on_pairwise_change, names="value")

for control in (ccg_selection, maxima_selection):
    control.observe(on_metric_change, names="value")

default_button.on_click(on_click_reset)

# CCG Examples

## Column 1

### Output

In [18]:
# Draw every panel once with the initial widget values, then show the CCG examples.
on_value_change(None)
out0

Output()

## Column 2


### Output

In [19]:
# Draw the SD-vs-area joint histogram once, then show it.
on_metric_change(None)
out0s

Output()

# |Tau| Heatmap

## Column 1

### Output

In [20]:
# Mean-lead heatmap (drawn by on_value_change).
out

Output()

## Column 2

### Pairwise Test Input


In [21]:
# Build the pair selectors (and run the pairwise tests) once, then show the first copy.
on_plot_type_change(None)
outsi

Output()

### Pairwise Test Output

In [22]:
# Lag histogram for the selected pair (drawn by on_pairwise_change).
outs


Output()

# Maxima Heatmap

## Column 1

### Output

In [23]:
# Mean-maxima heatmap (drawn by on_value_change).
out2

Output()

## Column 2

### Pairwise Test Input

In [24]:
# Pair selectors (displayed by on_plot_type_change).
out2si

Output()

### Pairwise Test Output

In [25]:
# Maxima histogram for the selected pair (drawn by on_pairwise_change).
out2s

Output()

# Prop. j leads k Heatmap

## Column 1

### Output

In [26]:
# "prop. j leads k" heatmap (drawn by on_value_change).
out3

Output()

## Column 2

### Pairwise Test Input

In [27]:
# Pair selectors (displayed by on_plot_type_change).
out3si

Output()

### Pairwise Test Output

In [28]:
# Signed-lag histogram for the selected pair (drawn by on_pairwise_change).
out3s

Output()

# Prop. significant

## Column 1

### Output

In [29]:
# Proportion of each pair's CCGs passing the filters (drawn by on_value_change).
out4

Output()

## Column 2

### Pairwise Test Input

In [30]:
# Pair selectors (displayed by on_plot_type_change).
out4si

Output()

### Pairwise Test Output

In [31]:
# Chi^2 test print-out for the selected pair (written by on_pairwise_change).
out4s

Output()

# Clustering

## Column

### Cluster Output

In [32]:
# Clustering figures (drawn by on_value_change).
out5

Output()