# Demonstration of Spike Train Analysis Algorithm Replication

Wesley Borden

## Introduction

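Here, I demonstrate a set of Python functions I developed to replicate two spike train analysis algorithms, cross correlation and transfer entropy, and use them to infer a network of putative neurons from IBL spike-sorted data. I also provide exploratory graph analysis and visualizations of the resulting network.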

## Setup

### Imports

In [None]:
import os
import uuid
from typing import Callable, Optional

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd

from brainbox.io.one import SpikeSortingLoader
from iblutil.util import Bunch
from one.alf.io import AlfBunch
from one.api import OneAlyx, ONE  # Docs: https://int-brain-lab.github.io/ONE/

from cc.cc import cross_correlate
from gu.utils import adjacency_matrix_from_pairwise, show_graph
from te.te import transfer_entropy

### API

IBL API as demonstrated in `.../data-demo/nsp_data_demo_jwb.ipynb`

In [None]:
one_alyx: OneAlyx = ONE(
    cache_dir="/Users/wesley/GitHub/BYU/ms-proj/tmp/one-cache",  # any directory where temporary files can be synced
    base_url="https://openalyx.internationalbrainlab.org",  # base url for the API
    password="international",  # public-access password
    silent=True,  # don't print progress, etc.
)  # most 'type: ignore' are because IBL's libraries are less strict on types # type: ignore

In [None]:
data_tag = "2024_Q2_IBL_et_al_BWM_iblsort"  # tag for the most recent data release at the time of writing
all_sessions: list = one_alyx.search(  # list of sessions
    tag=data_tag, query_type="remote"
)  # type: ignore
n_sessions = len(all_sessions)
print(f"Session count: {n_sessions}")
print(f"Session example: {all_sessions[0]}")

all_insertions: list = one_alyx.search_insertions(  # list of insertions
    tag=data_tag, query_type="remote"
)  # type: ignore
n_insertions = len(all_insertions)
print(f"Insertion count: {n_insertions}")
print(f"Insertion example: {all_insertions[0]}")

In [None]:
# Use the same insertion as the other IBL demos for consistency
pid_i = 534
pid: str = str(all_insertions[pid_i])
pid_details: tuple[str, str] = one_alyx.pid2eid(pid)
eid, p_name = pid_details

print(f"Probe ID: {pid}")
print(f"Probe Name: {p_name}")
print(f"Experiment ID: {eid}")

### Load Spike-Sorted Data

As demonstrated in `.../data-demo/nsp_data_demo_jwb.ipynb`

In [None]:
spike_loader = SpikeSortingLoader(pid=pid, one=one_alyx)

In [None]:
spike_sorting_data: tuple[AlfBunch, AlfBunch, Bunch] = spike_loader.load_spike_sorting()  # type: ignore
spikes, clusters, channels = spike_sorting_data

In [None]:
spikes_df = spikes.to_df()
spikes_df

In [None]:
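# Flatten the cluster fields into columns: 1-D arrays become one column each,
# 2-D tables such as a pandas DataFrame contribute one column per key
clusters_wrangled: dict = {}
for k, v in clusters.items():
    if v.ndim == 1:
        clusters_wrangled[k] = v
    elif v.ndim == 2:
        for k_sub in v:
            clusters_wrangled[k_sub] = v[k_sub]
    else:
        raise ValueError(f"Unexpected ndim={v.ndim} for cluster field {k!r}")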


clusters_df: pd.DataFrame = pd.DataFrame(clusters_wrangled)
clusters_df

In [None]:
channels_df = AlfBunch(channels).to_df()
channels_df

In [None]:
merged_clusters: AlfBunch = spike_loader.merge_clusters(spikes, clusters, channels)  # type: ignore

In [None]:
merged_clusters_df = merged_clusters.to_df()
merged_clusters_df

### Timeframe

As demonstrated in `.../data-demo/nsp_data_demo_jwb.ipynb`

In [None]:
start_time = 150  # seconds since beginning the electrophysiology recording
end_time = 152  # seconds since beginning the electrophysiology recording

In [None]:
# Half-open window [start_time, end_time) so every kept spike maps to a valid time bin
spikes_df_timeframe = spikes_df[start_time <= spikes_df["times"]]
spikes_df_timeframe = spikes_df_timeframe[spikes_df_timeframe["times"] < end_time]
spikes_df_timeframe

### Determine Bin Size

In `.../data-demo/nsp_data_demo_jwb.ipynb`, we used high-resolution bins that slowed down processing. For spike train analysis, bin size can be tuned as a hyperparameter. Here we use 10 ms bins, which aligns with a prior study (Moore, 1970, Statistical Signs of Synaptic Interaction in Neurons, https://doi.org/10.1016/S0006-3495(70)86341-X).

In [None]:
bins_per_s = 100 # each bin is 10ms
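To make the 10 ms bin size concrete, here is a toy sketch of how spike times become a binary spike train, using the same `floor((t - start) * bins_per_s)` rule as the wrangling step below. The helper `bin_spike_times` and the spike times are hypothetical, for illustration only.

```python
import numpy as np

def bin_spike_times(times: np.ndarray, start: float, end: float, bins_per_s: int) -> np.ndarray:
    """Binary spike train: 1 where at least one spike falls in a bin, else 0.

    Uses the half-open window [start, end) so every kept spike maps to a valid bin.
    """
    n_bins = int(round((end - start) * bins_per_s))
    in_window = times[(times >= start) & (times < end)]
    bin_idx = np.floor((in_window - start) * bins_per_s).astype(int)
    binned = np.zeros(n_bins, dtype=int)
    binned[bin_idx] = 1  # two spikes in the same bin still produce a single 1
    return binned

# Toy spike times (seconds) for one hypothetical cluster
toy_times = np.array([150.001, 150.004, 150.015, 150.999, 151.995])
train = bin_spike_times(toy_times, start=150, end=152, bins_per_s=100)
print(train.shape)            # (200,)
print(np.flatnonzero(train))  # bins 0, 1, 99 and 199 are set
```

The first two spikes are 3 ms apart and share bin 0, so the binary train records four spikes rather than five: coarser bins trade temporal precision for speed.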

### Wrangle to Clusters-by-Time Matrix

Adapted from `.../data-demo/nsp_data_demo_jwb.ipynb`

In [None]:
cluster_channel_map = (
    merged_clusters_df[["cluster_id", "channels"]]
    .copy()
    .sort_values(by="channels", ascending=True)
    .reset_index(drop=True)
    .reset_index(drop=False)
    .rename(inplace=False, columns={"index": "cluster_channel_id"})
)
cluster_channel_map

In [None]:
spikes_df_timeframe = spikes_df_timeframe.merge(
    cluster_channel_map, left_on="clusters", right_on="cluster_id", how="left"
)
spikes_df_timeframe["time_bin"] = (np.floor((spikes_df_timeframe["times"] - start_time) * bins_per_s)).astype(int)  # 10 ms bins when bins_per_s = 100
spikes_df_timeframe

In [None]:
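n_clusters = cluster_channel_map.shape[0]
n_bins = (end_time - start_time) * bins_per_s  # number of time bins in the window
clusters_spikes_matrix = np.zeros((n_clusters, n_bins))  # type: ignore
# Rows are flipped so row 0 holds the highest cluster_channel_id; using n_clusters - 1
# (rather than the max id seen in this window) keeps the row-to-cluster mapping fixed
clusters_spikes_matrix[
    (n_clusters - 1) - spikes_df_timeframe["cluster_channel_id"].values,
    spikes_df_timeframe["time_bin"].values,
] = 1  # 1 marks at least one spike from that cluster in that bin # type: ignore
clusters_spikes_matrix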

### Visualize Spike Trains

As demonstrated in `.../data-demo/nsp_data_demo_jwb.ipynb`

In [None]:
fig, axs = plt.subplots(figsize=(10, 8))

axs.scatter(
    spikes_df_timeframe["times"].values,  # type: ignore
    spikes_df_timeframe["cluster_channel_id"].values,  # type: ignore
    s=1,
    alpha=0.5,
    c="#000000",
    marker="s",
)

axs.set_title("Putative Neural Spikes")
axs.set_xlabel("Time (s)")
axs.set_ylabel("Putative Neuron")

### Notes

Everything up to this point has been copied or adapted from `.../data-demo/nsp_data_demo_jwb.ipynb`. Next, we use the binned spike trains to infer a putative biological neural network: a partial functional connectome.

## Cross Correlation

Cross correlation computes a sliding dot product of two vectors that represent parallel spike trains, giving one value per time lag (a cross-correlogram). If the two spike trains are significantly correlated, this distribution contains outliers, such as a sharp peak at a particular lag. This is implemented in `cross_correlate`, which returns a category as follows:

|Category | Meaning |
|---|---|
|  1| relationship |
|  0| no relationship |
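The source of `cross_correlate` is not shown here, so the following is only a minimal sketch of the idea under stated assumptions: the correlogram is computed for lags in `[-max_lag, max_lag]`, and a lag counts as an outlier when it lies more than `z_thresh = 4` standard deviations above the correlogram's mean. Both `max_lag` and the z-score rule are assumptions, and `cross_correlogram` and `has_outlier_peak` are hypothetical helpers.

```python
import numpy as np

def cross_correlogram(x: np.ndarray, y: np.ndarray, max_lag: int) -> np.ndarray:
    """Sliding dot product c[lag] = sum_t x[t] * y[t + lag] for lag = -max_lag..max_lag."""
    n = len(x)
    out = np.empty(2 * max_lag + 1)
    for i, lag in enumerate(range(-max_lag, max_lag + 1)):
        if lag >= 0:
            out[i] = np.dot(x[: n - lag], y[lag:])
        else:
            out[i] = np.dot(x[-lag:], y[: n + lag])
    return out

def has_outlier_peak(ccg: np.ndarray, z_thresh: float = 4.0) -> int:
    """Return 1 if some lag is more than z_thresh standard deviations above the mean, else 0."""
    sd = ccg.std()
    if sd == 0:  # flat correlogram, e.g. one train is silent
        return 0
    z = (ccg - ccg.mean()) / sd
    return int(np.any(z > z_thresh))

# Toy data: y repeats x two bins later, so the correlogram should peak at lag +2
rng = np.random.default_rng(0)
x = (rng.random(2000) < 0.1).astype(float)
y = np.roll(x, 2)
lags = np.arange(-20, 21)
ccg = cross_correlogram(x, y, max_lag=20)
print(lags[np.argmax(ccg)], has_outlier_peak(ccg))  # peak lag and category
```

The sign of the peak lag also hints at direction (here x leads y), but the 0/1 category discards that information.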

In [None]:
print(f"Comparing a spike train to itself returns {cross_correlate(clusters_spikes_matrix[0], clusters_spikes_matrix[0])}")
print(f"Comparing a spike train to a distant spike train returns {cross_correlate(clusters_spikes_matrix[0], clusters_spikes_matrix[-1])}")

In [None]:
sample_limit = 10  # stop after printing this many connected pairs
sample_count = 0
for i, _ in enumerate(clusters_spikes_matrix):
    # rows are sorted by channel, so compare each train only with the next 99 (nearby channels)
    for j in range(i+1, min(i+100, len(clusters_spikes_matrix))):
        if cross_correlate(clusters_spikes_matrix[i], clusters_spikes_matrix[j]):
            sample_count += 1
            print(f"Spike train {i} is functionally connected to spike train {j}")
            if sample_count >= sample_limit:
                break
    if sample_count >= sample_limit:
        break


We can use cross correlation to construct an adjacency matrix and a graph:

In [None]:
adjacency_matrix_from_cc = adjacency_matrix_from_pairwise(clusters_spikes_matrix, cross_correlate)
g_from_cc = nx.from_numpy_array(adjacency_matrix_from_cc)
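`adjacency_matrix_from_pairwise` comes from the author's `gu.utils` module, whose source is not shown. As a hedged illustration of what such a helper might do, the hypothetical `pairwise_adjacency` below applies a pairwise metric to every ordered pair of spike trains. Leaving the diagonal at zero (no self-loops) is an assumption, and `coincident` is a hypothetical toy metric. `nx.from_numpy_array` then turns each nonzero entry into an edge with a `weight` attribute.

```python
from typing import Callable

import numpy as np

def pairwise_adjacency(
    trains: np.ndarray, metric: Callable[[np.ndarray, np.ndarray], float]
) -> np.ndarray:
    """Hypothetical helper: adj[i, j] = metric(trains[i], trains[j]) for every ordered pair i != j."""
    n = trains.shape[0]
    adj = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            if i != j:  # assumption: leave the diagonal at 0 so the graph has no self-loops
                adj[i, j] = metric(trains[i], trains[j])
    return adj

def coincident(a: np.ndarray, b: np.ndarray) -> int:
    """Hypothetical toy metric: 1 if the two trains ever spike in the same bin."""
    return int(np.any((a > 0) & (b > 0)))

trains = np.array([[1, 0, 1, 0], [1, 0, 0, 0], [0, 1, 0, 1]])
adj = pairwise_adjacency(trains, coincident)
print(adj)
```

Filling both `(i, j)` and `(j, i)` matters for an asymmetric metric such as transfer entropy, where the two entries can differ and `create_using=nx.DiGraph` keeps both directions.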

Let's visualize the resulting graph.

In [None]:
show_graph(g_from_cc, (5, 3))

The graph contains many isolated nodes, so let's look at the largest connected component.

In [None]:
largest_connected_component_nodes_cc = max(nx.connected_components(g_from_cc), key=len)
largest_connected_component_subgr_cc = g_from_cc.subgraph(largest_connected_component_nodes_cc).copy()
show_graph(largest_connected_component_subgr_cc, layout_fun_=nx.spring_layout)
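`nx.connected_components` yields one set of nodes per component, and `max(..., key=len)` keeps the biggest. As an illustration only (the notebook itself uses networkx), here is a plain breadth-first-search version for an undirected graph stored as adjacency sets; `connected_components` here is a hypothetical stand-in, not the networkx function.

```python
from collections import deque

def connected_components(adj: dict[int, set[int]]) -> list[set[int]]:
    """Components of an undirected graph given as adjacency sets, found by breadth-first search."""
    seen: set[int] = set()
    components: list[set[int]] = []
    for start in adj:
        if start in seen:
            continue
        component = {start}
        seen.add(start)
        queue = deque([start])
        while queue:
            u = queue.popleft()
            for v in adj[u]:
                if v not in seen:
                    seen.add(v)
                    component.add(v)
                    queue.append(v)
        components.append(component)
    return components

# Toy graph: a path 0-1-2, an isolated node 3, and an edge 4-5
adj = {0: {1}, 1: {0, 2}, 2: {1}, 3: set(), 4: {5}, 5: {4}}
comps = connected_components(adj)
largest = max(comps, key=len)
print(largest)  # {0, 1, 2}
```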


## Transfer Entropy

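Transfer entropy from time series X to time series Y measures how much knowing the past of X reduces uncertainty about the next value of Y, beyond what the past of Y already provides, using L time points of history. Unlike cross correlation, it is directional (transfer entropy from X to Y generally differs from transfer entropy from Y to X), so it yields a directed graph.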
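For reference, the standard definition with history length $L$, writing $y_t^{(L)} = (y_{t-L+1}, \dots, y_t)$ and likewise for $x$, is

$$
T_{X \to Y} = \sum p\left(y_{t+1}, y_t^{(L)}, x_t^{(L)}\right) \log_2 \frac{p\left(y_{t+1} \mid y_t^{(L)}, x_t^{(L)}\right)}{p\left(y_{t+1} \mid y_t^{(L)}\right)}
$$

Below is a minimal plug-in (histogram-count) estimator of this quantity for binary spike trains. It is a sketch, not necessarily how the notebook's `transfer_entropy` works; `transfer_entropy_plugin` is a hypothetical name, and the default `history=1` is an assumption.

```python
import math
import random
from collections import Counter
from typing import Sequence

def transfer_entropy_plugin(x: Sequence[float], y: Sequence[float], history: int = 1) -> float:
    """Plug-in estimate (in bits) of transfer entropy from x to y with `history` past bins."""
    joint: Counter = Counter()  # counts of (y_next, y_past, x_past)
    for t in range(history - 1, len(y) - 1):
        y_past = tuple(y[t - history + 1 : t + 1])
        x_past = tuple(x[t - history + 1 : t + 1])
        joint[(y[t + 1], y_past, x_past)] += 1
    total = sum(joint.values())

    # Marginal counts for the two conditional probabilities
    c_past: Counter = Counter()       # (y_past, x_past)
    c_next_self: Counter = Counter()  # (y_next, y_past)
    c_self: Counter = Counter()       # y_past
    for (y_next, y_past, x_past), c in joint.items():
        c_past[(y_past, x_past)] += c
        c_next_self[(y_next, y_past)] += c
        c_self[y_past] += c

    te = 0.0
    for (y_next, y_past, x_past), c in joint.items():
        p_full = c / c_past[(y_past, x_past)]                    # p(y_next | y_past, x_past)
        p_self = c_next_self[(y_next, y_past)] / c_self[y_past]  # p(y_next | y_past)
        te += (c / total) * math.log2(p_full / p_self)
    return te

# Toy data: y copies x one bin later, so information flows from x to y only
rng = random.Random(0)
x = [rng.randint(0, 1) for _ in range(5000)]
y = [0] + x[:-1]
te_xy = transfer_entropy_plugin(x, y)
te_yx = transfer_entropy_plugin(y, x)
print(f"TE x->y: {te_xy:.3f} bits, TE y->x: {te_yx:.3f} bits")
```

Under this definition, passing the same train as both source and target gives exactly 0, because the source's past adds nothing beyond the target's own past. Plug-in estimates are also biased slightly upward, so small positive values do not by themselves indicate a connection.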

In [None]:
print(f"Comparing a spike train to itself returns {transfer_entropy(clusters_spikes_matrix[0], clusters_spikes_matrix[0])}")
print(f"Comparing a spike train to a distant spike train returns {transfer_entropy(clusters_spikes_matrix[0], clusters_spikes_matrix[-1])}")

In [None]:
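sample_limit = 10  # stop after printing this many connected pairs
sample_count = 0
for i, _ in enumerate(clusters_spikes_matrix):
    # rows are sorted by channel, so compare each train only with the next 99 (nearby channels)
    for j in range(i+1, min(i+100, len(clusters_spikes_matrix))):
        # TODO: any positive transfer entropy counts as a connection here; a significance threshold is still needed
        if transfer_entropy(clusters_spikes_matrix[i], clusters_spikes_matrix[j]) > 0:
            sample_count += 1
            print(f"Spike train {i} is functionally connected to spike train {j}")
            if sample_count >= sample_limit:
                break
    if sample_count >= sample_limit:
        break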


We can use transfer entropy to construct an adjacency matrix and a graph:

In [None]:
adjacency_matrix_from_te = adjacency_matrix_from_pairwise(clusters_spikes_matrix, transfer_entropy)
g_from_te = nx.from_numpy_array(adjacency_matrix_from_te, create_using=nx.DiGraph)

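Let's visualize the resulting graph and its largest weakly connected component (edges are directed, so this component ignores edge direction when testing connectivity).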

In [None]:
show_graph(g_from_te, (5, 3))

largest_connected_component_nodes_te = max(nx.weakly_connected_components(g_from_te), key=len)
largest_connected_component_subgr_te = g_from_te.subgraph(largest_connected_component_nodes_te).copy()  # subgraph of the directed TE graph
show_graph(largest_connected_component_subgr_te, layout_fun_=nx.spring_layout)


## Graph Analysis

The cross correlation and transfer entropy algorithms yielded `largest_connected_component_subgr_cc` and `largest_connected_component_subgr_te`, respectively: the largest connected component of each inferred partial connectome (for the directed transfer entropy graph, the largest weakly connected component). In each graph, nodes correspond to putative neurons and edges to putative neural connections. From here on, we focus on the transfer-entropy graph, `G`, and analyze it with basic network science techniques.

In [None]:
G = largest_connected_component_subgr_te

In [None]:
for n in G.nodes:
    G.nodes[n]["ind"] = n

### Degree Distribution

In [None]:
for n in G.nodes:
    G.nodes[n]['degree'] = G.degree[n]

print(f"The degree distribution has max: {max(dict(G.degree()).values())}, min: {min(dict(G.degree()).values())}, and mean: {sum(dict(G.degree()).values())/len(dict(G.degree()).values())}")

def show_degree_distribution(G: nx.Graph, fp: Optional[str] = None) -> None:
    """ 
    Adapted from CS 575 work, which was adapted from Hands-On Graph Neural Networks Using Python by Maxime Labonne, chapter 6.
    """
    degree_list: list[int] = [y for (_,y) in G.degree] # type: ignore
    
    _, ax = plt.subplots()
    ax.set_title('Degree Distribution')
    ax.set_xlabel('Node degree')
    ax.set_ylabel('Number of nodes')
    
    plt.hist(degree_list)

    if fp:
        plt.savefig(fp)
        plt.close()
    else:
        plt.show()

show_degree_distribution(G)

The network can be visualized by degree:

![degree image](/Users/wesley/GitHub/BYU/ms-proj/replication/gephi/degree.png)

### Scale Free Property

In [None]:
def show_degree_density(G: nx.Graph, log_log: bool = True) -> None:
    """Plot the empirical degree distribution P(k), optionally on log-log axes."""
    degree_list: list[int] = [y for (_, y) in G.degree]  # type: ignore
    degree_dict: dict[int, int] = {}
    for deg in degree_list:
        degree_dict[deg] = degree_dict.get(deg, 0) + 1

    _, ax = plt.subplots()
    ax.set_title(f'{"Log-Log " if log_log else ""}Degree Distribution')
    ax.set_xlabel('Node degree')
    ax.set_ylabel('Probability')

    x = np.array(sorted(degree_dict), dtype=float)
    y = np.array([degree_dict[d] for d in sorted(degree_dict)], dtype=float)
    y = y / y.sum()  # normalize counts to probabilities
    if log_log:
        # log axes cannot show zero, so plot only positive degrees
        ax.loglog(x[x > 0], y[x > 0])
    else:
        # anchor the linear plot at (0, 0) when no node has degree 0
        if 0 not in degree_dict:
            x = np.insert(x, 0, 0.0)
            y = np.insert(y, 0, 0.0)
        ax.plot(x, y)

show_degree_density(G, False)
show_degree_density(G, True)

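The degree distribution does not follow the straight line on log-log axes that a power law would produce, so this network does not appear to be scale-free.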
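A rough numeric companion to the log-log plot is the least-squares slope of log P(k) against log k. This is a heuristic only: a straight line on a log-log histogram is weak evidence for a power law, and a rigorous test would use maximum-likelihood fitting with goodness-of-fit comparisons. `loglog_slope` is a hypothetical helper, checked here on a degree sequence built to follow k^-2 exactly.

```python
import numpy as np

def loglog_slope(degrees: list[int]) -> float:
    """Least-squares slope of log P(k) against log k, using only degrees k > 0."""
    values, counts = np.unique(np.asarray(degrees), return_counts=True)
    keep = values > 0  # log(0) is undefined
    k = values[keep]
    p = counts[keep] / counts.sum()
    slope, _intercept = np.polyfit(np.log(k), np.log(p), 1)
    return float(slope)

# Toy degree sequence constructed so that P(k) is proportional to k**-2
toy_degrees = [1] * 64 + [2] * 16 + [4] * 4 + [8] * 1
slope = loglog_slope(toy_degrees)
print(round(slope, 6))  # -2.0
```

On `G`, one could pass `[d for _, d in G.degree]`; a poor linear fit (large residuals) is consistent with the conclusion above.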

### Centrality

In [None]:
g_pagerank = nx.pagerank(G)

for n, val in g_pagerank.items():
    G.nodes[n]['centrality'] = val

print(f"The centrality distribution has max: {max(g_pagerank.values())}, min: {min(g_pagerank.values())}, and mean: {sum(g_pagerank.values())/len(g_pagerank.values())}")

def show_centrality_distribution(G: nx.Graph, fp: Optional[str] = None) -> None:
    """ 
    Adapted from CS 575 work, which was adapted from Hands-On Graph Neural Networks Using Python by Maxime Labonne, chapter 6.
    """
    centrality_list: list[float] = [G.nodes[n]['centrality'] for n in G.nodes] # type: ignore
    
    _, ax = plt.subplots()
    ax.set_title('Centrality Distribution')
    ax.set_xlabel('Node centrality')
    ax.set_ylabel('Number of nodes')
    
    plt.hist(centrality_list)

    if fp:
        plt.savefig(fp)
        plt.close()
    else:
        plt.show()

show_centrality_distribution(G)
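The centrality used here is PageRank: the stationary probability that a random walker who follows an out-edge with probability `alpha` and otherwise jumps to a uniformly random node is found at each node. The sketch below computes it by power iteration on a dense adjacency matrix with `adj[i, j]` as the weight of edge i to j. `pagerank_power` is a hypothetical helper; the tolerance and the uniform spreading of dangling-node mass are assumptions that mirror networkx's defaults (`alpha=0.85`, and `nx.pagerank` also uses the `weight` edge attribute by default), but exact values may differ slightly.

```python
import numpy as np

def pagerank_power(adj: np.ndarray, alpha: float = 0.85, tol: float = 1e-10, max_iter: int = 1000) -> np.ndarray:
    """PageRank by power iteration; adj[i, j] is the weight of the edge i -> j."""
    n = adj.shape[0]
    out_weight = adj.sum(axis=1)
    dangling = out_weight == 0  # nodes with no out-edges
    # Row-normalize into transition probabilities; dangling rows stay all zero
    P = np.divide(adj, out_weight[:, None], out=np.zeros_like(adj, dtype=float),
                  where=out_weight[:, None] > 0)
    r = np.full(n, 1.0 / n)
    for _ in range(max_iter):
        # follow an edge with prob. alpha; dangling mass and teleportation spread uniformly
        r_new = alpha * (r @ P + r[dangling].sum() / n) + (1 - alpha) / n
        if np.abs(r_new - r).sum() < tol:
            return r_new
        r = r_new
    return r

# Toy graphs: a directed 3-cycle (all nodes equal) and a star where nodes 1 and 2 point at node 0
cycle = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype=float)
star = np.array([[0, 0, 0], [1, 0, 0], [1, 0, 0]], dtype=float)
print(pagerank_power(cycle))
print(pagerank_power(star))
```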

The network can be visualized by centrality:

![centrality image](/Users/wesley/GitHub/BYU/ms-proj/replication/gephi/centrality.png)

### Clustering Coefficient

In [None]:
g_clustering = dict(nx.clustering(G)) # type: ignore

for n, val in g_clustering.items():
    G.nodes[n]['clustering'] = val

print(f"The clustering distribution has max: {max(g_clustering.values())}, min: {min(g_clustering.values())}, and mean: {sum(g_clustering.values())/len(g_clustering.values())}") # type: ignore

def show_clustering_distribution(G: nx.Graph, fp: Optional[str] = None) -> None:
    """ 
    Adapted from CS 575 work, which was adapted from Hands-On Graph Neural Networks Using Python by Maxime Labonne, chapter 6.
    """
    clustering_list: list[float] = [G.nodes[n]['clustering'] for n in G.nodes] # type: ignore
    
    _, ax = plt.subplots()
    ax.set_title('Clustering Distribution')
    ax.set_xlabel('Node clustering coefficient')
    ax.set_ylabel('Number of nodes')
    
    plt.hist(clustering_list)

    if fp:
        plt.savefig(fp)
        plt.close()
    else:
        plt.show()

show_clustering_distribution(G)
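`nx.clustering` gives each node's local clustering coefficient. For an undirected node with k neighbors, it is the fraction of the k(k-1)/2 neighbor pairs that are themselves linked. Below is a minimal pure-Python sketch for undirected graphs stored as adjacency sets; `local_clustering` is a hypothetical helper. For directed graphs networkx uses a directed generalization, so values on a directed `G` can differ from this undirected version.

```python
from itertools import combinations

def local_clustering(adj: dict[int, set[int]], node: int) -> float:
    """Fraction of pairs of `node`'s neighbors that are themselves connected (undirected graph)."""
    nbrs = adj[node] - {node}  # ignore a self-loop if present
    k = len(nbrs)
    if k < 2:
        return 0.0
    links = sum(1 for u, v in combinations(nbrs, 2) if v in adj[u])
    return 2 * links / (k * (k - 1))

# Toy graph: triangle 0-1-2 plus a pendant node 3 attached to node 0
adj = {0: {1, 2, 3}, 1: {0, 2}, 2: {0, 1}, 3: {0}}
print([round(local_clustering(adj, n), 3) for n in adj])  # [0.333, 1.0, 1.0, 0.0]
```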

### Eccentricity and the Small World Property

In [None]:
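# nx.eccentricity needs a connected graph (strongly connected if G is directed),
# so compute shortest paths on the undirected view of the weakly connected component
g_eccentricity = dict(nx.eccentricity(G.to_undirected())) # type: ignore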

for n, val in g_eccentricity.items():
    G.nodes[n]['eccentricity'] = val

print(f"The eccentricity distribution has max (Diameter): {max(g_eccentricity.values())}, min (Radius): {min(g_eccentricity.values())}, and mean: {sum(g_eccentricity.values())/len(g_eccentricity.values())}") # type: ignore

def show_eccentricity_distribution(G: nx.Graph, fp: Optional[str] = None) -> None:
    """ 
    Adapted from CS 575 work, which was adapted from Hands-On Graph Neural Networks Using Python by Maxime Labonne, chapter 6.
    """
    eccentricity_list: list[int] = [G.nodes[n]['eccentricity'] for n in G.nodes] # type: ignore
    
    _, ax = plt.subplots()
    ax.set_title('Eccentricity Distribution')
    ax.set_xlabel('Node eccentricity')
    ax.set_ylabel('Number of nodes')
    
    plt.hist(eccentricity_list)

    if fp:
        plt.savefig(fp)
        plt.close()
    else:
        plt.show()

show_eccentricity_distribution(G)
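A node's eccentricity is its greatest shortest-path distance to any other node; the maximum over all nodes is the diameter and the minimum is the radius, as labeled in the printout above. Here is a minimal sketch for a connected, undirected, unweighted graph, running one breadth-first search per node (`eccentricities` is a hypothetical helper).

```python
from collections import deque

def eccentricities(adj: dict[int, set[int]]) -> dict[int, int]:
    """Eccentricity of every node in a connected, undirected, unweighted graph."""
    ecc: dict[int, int] = {}
    for source in adj:
        dist = {source: 0}
        queue = deque([source])
        while queue:  # breadth-first search gives shortest hop counts
            u = queue.popleft()
            for v in adj[u]:
                if v not in dist:
                    dist[v] = dist[u] + 1
                    queue.append(v)
        if len(dist) != len(adj):
            raise ValueError("graph is not connected; eccentricity is infinite")
        ecc[source] = max(dist.values())
    return ecc

# Toy path graph 0-1-2-3
adj = {0: {1}, 1: {0, 2}, 2: {1, 3}, 3: {2}}
ecc = eccentricities(adj)
print(ecc, "diameter:", max(ecc.values()), "radius:", min(ecc.values()))
```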


### Partition

In [None]:
communities = nx.algorithms.community.louvain_communities(G, seed=0)  # fixed seed so the partition is reproducible
for n in G.nodes:
    G.nodes[n]["community"] = -1  # placeholder (an int, to keep the GEXF attribute type consistent)
for i, c in enumerate(communities): # type: ignore
    for n in c:
        G.nodes[n]["community"] = i
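Louvain greedily maximizes modularity Q, which compares the fraction of edges that fall inside each community with the fraction expected if edges were rewired at random while preserving node degrees. The sketch below computes Q for an undirected, unweighted graph (`modularity` is a hypothetical helper). For real use, networkx provides `nx.community.modularity`; note also that `louvain_communities` uses the `weight` edge attribute by default and has a `resolution` parameter, so its objective can differ from this unweighted version.

```python
def modularity(edges: list[tuple[int, int]], communities: list[set[int]]) -> float:
    """Q = sum over communities c of [L_c / m - (d_c / 2m)^2], undirected and unweighted."""
    m = len(edges)
    degree: dict[int, int] = {}
    for u, v in edges:
        degree[u] = degree.get(u, 0) + 1
        degree[v] = degree.get(v, 0) + 1
    q = 0.0
    for c in communities:
        l_c = sum(1 for u, v in edges if u in c and v in c)  # edges inside community c
        d_c = sum(degree.get(n, 0) for n in c)               # total degree of community c
        q += l_c / m - (d_c / (2 * m)) ** 2
    return q

# Toy graph: two triangles joined by the single edge 2-3
edges = [(0, 1), (1, 2), (0, 2), (3, 4), (4, 5), (3, 5), (2, 3)]
q_split = modularity(edges, [{0, 1, 2}, {3, 4, 5}])
q_whole = modularity(edges, [{0, 1, 2, 3, 4, 5}])
print(round(q_split, 4), q_whole)  # 0.3571 0.0
```

Putting every node in one community always gives Q = 0, so a positive Q indicates more within-community edges than the degree-preserving baseline.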

The network can be visualized by community:

![community image](/Users/wesley/GitHub/BYU/ms-proj/replication/gephi/communities.png)

### Core-Periphery Structures

In [None]:
k_core_info = dict(nx.core_number(G)) # type: ignore

for n, val in k_core_info.items():
    G.nodes[n]['kcore'] = val

print(f"The K core distribution has max: {max(k_core_info.values())}, min: {min(k_core_info.values())}, and mean: {sum(k_core_info.values())/len(k_core_info.values())}") # type: ignore

def show_kcore_distribution(G: nx.Graph, fp: Optional[str] = None) -> None:
    """
    Adapted from CS 575 work, which was adapted from Hands-On Graph Neural Networks Using Python by Maxime Labonne, chapter 6.
    """
    # Reuse the core numbers computed above and stored as the 'kcore' node attribute
    k_core_scores: list[int] = [G.nodes[n]['kcore'] for n in G.nodes] # type: ignore

    _, ax = plt.subplots()
    ax.set_title("K Core Structure")
    ax.set_xlabel('Node core number')
    ax.set_ylabel('Number of nodes')
    plt.hist(k_core_scores)

    if fp:
        plt.savefig(fp)
        plt.close()
    else:
        plt.show()

show_kcore_distribution(G)

The network can be visualized by K core:

![kcore image](/Users/wesley/GitHub/BYU/ms-proj/replication/gephi/kcores.png)

Consistent with prior work in CS 575, this network shows a strong core-periphery structure.
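A node's core number is the largest k such that the node belongs to a k-core, a maximal subgraph in which every node has degree at least k; nodes with high core numbers form the dense core and nodes with low core numbers form the periphery. The sketch below computes core numbers by repeatedly removing a minimum-degree node (`core_numbers` is a hypothetical helper, assuming an undirected graph without self-loops). `nx.core_number` implements a more efficient bucket-based version of the same peeling idea.

```python
def core_numbers(adj: dict[int, set[int]]) -> dict[int, int]:
    """Core number of each node by peeling off minimum-degree nodes (undirected, no self-loops)."""
    degree = {n: len(nbrs) for n, nbrs in adj.items()}
    remaining = set(adj)
    core: dict[int, int] = {}
    k = 0
    while remaining:
        n = min(remaining, key=lambda node: degree[node])
        k = max(k, degree[n])  # the shell level never decreases as nodes are removed
        core[n] = k
        remaining.remove(n)
        for v in adj[n]:
            if v in remaining:
                degree[v] -= 1
    return core

# Toy graph: triangle 0-1-2 (the core) plus a pendant node 3 (the periphery)
adj = {0: {1, 2, 3}, 1: {0, 2}, 2: {0, 1}, 3: {0}}
print(core_numbers(adj))  # node 3 has core number 1, the triangle nodes have 2
```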

## Export to Gephi

In [None]:
tmp_fp = f"/Users/wesley/GitHub/BYU/ms-proj/tmp/{str(uuid.uuid4())}.gexf"
assert os.path.exists(os.path.dirname(tmp_fp))
nx.write_gexf(G, tmp_fp)