# Comprehensive Spike and Burst Analysis with Visibility Graphs

This notebook provides a detailed workflow for analyzing an electrophysiological recording stored as an ABF (Axon Binary Format) file.
It goes through the following steps:
1. Loading ABF files and concatenating sweeps
2. Spike detection
3. Burst detection
4. Burst classification
5. Saving bursts to CSV
6. Visibility graph creation for each burst
7. 2D and 3D embeddings of the burst graphs
8. Visualization of spikes, bursts, and graphs
9. Saving nodes and edges data to CSV

## Step 0: Load Libraries

In [None]:
import pyabf  # Reads ABF electrophysiology files
import numpy as np  # Numerical operations
import pandas as pd  # DataFrames and CSV handling
import matplotlib.pyplot as plt  # Plotting
from mpl_toolkits.mplot3d import Axes3D  # 3D plotting support
import networkx as nx  # Graph creation and analysis
from scipy.signal import find_peaks  # Spike detection
from sklearn.decomposition import TruncatedSVD  # Embedding
%matplotlib widget

## Step 1: Load ABF File and Concatenate Sweeps

In [None]:
file_path = "bursting/cell89basal.abf"
abf = pyabf.ABF(file_path)

# Collect every sweep, then join them end to end into one trace.
# np.copy is defensive: it protects against sweepY being a buffer that
# setSweep() reuses (assumption about pyabf internals — not shown here).
sweeps = []
for sweep_number in range(abf.sweepCount):
    abf.setSweep(sweep_number)
    sweeps.append(np.copy(abf.sweepY))

if sweeps:
    signal = np.concatenate(sweeps)
else:
    signal = np.array([])

# Sample period in seconds. Falls back to 1.0 when the header reports no
# positive sampling rate.
if abf.dataRate > 0:
    dt = 1.0 / abf.dataRate
else:
    dt = 1.0

# NOTE(review): the time axis treats the concatenated sweeps as gap-free.
# Any inter-sweep interval in the protocol is not represented — TODO confirm.
time = np.arange(len(signal)) * dt

if len(time) == 0:
    print(f"File: {file_path} loaded but empty signal.")
else:
    print(f"File: {file_path} | sweeps: {abf.sweepCount} | total duration: {time[-1]:.2f} s")

## Step 2: Spike Detection

In [None]:
# Spike detection: every local maximum of the membrane potential that rises
# above `threshold` is taken as a spike peak.
threshold = -35  # mV; minimum peak height for a local maximum to count as a spike
if len(signal) > 0:
    # `props` holds find_peaks' per-peak properties (with `height` given,
    # this includes "peak_heights").
    spike_indices, props = find_peaks(signal, height=threshold)
    spike_times = time[spike_indices]
else:
    # Integer dtype keeps the empty result usable as an index into
    # `signal` / `time`. An empty float array raises IndexError when used
    # for indexing.
    spike_indices = np.array([], dtype=int)
    spike_times = np.array([])
    # Define `props` on this path too, so later cells never hit a NameError.
    props = {}
# NOTE(review): with only `height`, small ripples near a spike's apex can be
# picked up as separate peaks. Consider `distance=` or `prominence=`.
# TODO confirm against the recordings.

## Step 3: Burst Detection

In [None]:
# Group consecutive spikes into bursts. A burst is a run of >= 2 spikes in
# which each inter-spike interval is shorter than `burst_threshold`.
# Each burst is stored as a list of indices into `spike_times`.
bursts = []
if len(spike_times) > 0:
    isi = np.diff(spike_times)
    burst_threshold = 0.3  # seconds
    current_burst = [0]

    # Gap k lies between spike k and spike k + 1, so enumerating the gaps
    # from 1 yields the index of the spike that follows each gap.
    for next_spike, gap in enumerate(isi, start=1):
        if gap < burst_threshold:
            current_burst.append(next_spike)
            continue
        # A long gap closes the running group. Lone spikes are not bursts.
        if len(current_burst) > 1:
            bursts.append(current_burst)
        current_burst = [next_spike]
    if len(current_burst) > 1:
        bursts.append(current_burst)

    print(f"Detected {len(bursts)} bursts")

    # Per-spike ISI (ms) to the previous spike of the same burst.
    # The first spike of each burst, and every non-burst spike, stays 0.
    isi_per_spike_burst = np.zeros(len(spike_times))
    for burst in bursts:
        isi_per_spike_burst[burst[1:]] = np.diff(spike_times[burst]) * 1000
        isi_per_spike_burst[burst[0]] = 0
else:
    isi_per_spike_burst = np.array([])

## Step 4: Burst Classification

In [None]:
# Classify each burst by comparing two quantities:
#   - the lowest voltage between its first and last spike peak;
#   - the mean voltage in short windows just before and just after it.
# Outcomes:
#   burst minimum above that baseline -> "Square Wave"
#   burst minimum below that baseline -> "Parabolic"
#   equal, or no usable data          -> "Other"
# NOTE(review): this is a simple voltage heuristic. TODO confirm it matches
# the intended burst taxonomy.
square_wave_bursts, parabolic_bursts, other_bursts = [], [], []
burst_types = {}
baseline_window = 0.05  # seconds of signal inspected on each side of a burst

# Label -> list collecting the bursts that receive that label.
bursts_by_label = {
    "Square Wave": square_wave_bursts,
    "Parabolic": parabolic_bursts,
    "Other": other_bursts,
}

for burst in bursts:
    t_start = spike_times[burst[0]]
    t_end = spike_times[burst[-1]]

    # `time` is sorted ascending, so np.searchsorted yields each window's
    # sample range directly. This replaces building full-length boolean
    # masks, which cost O(len(signal)) per burst.
    burst_lo = np.searchsorted(time, t_start, side="left")  # first sample with time >= t_start
    burst_hi = np.searchsorted(time, t_end, side="right")   # first sample with time >  t_end
    prev_lo = np.searchsorted(time, t_start - baseline_window, side="left")
    next_hi = np.searchsorted(time, t_end + baseline_window, side="right")

    burst_seg = signal[burst_lo:burst_hi]  # t_start <= time <= t_end
    prev_seg = signal[prev_lo:burst_lo]    # t_start - window <= time < t_start
    next_seg = signal[burst_hi:next_hi]    # t_end < time <= t_end + window

    burst_min = np.min(burst_seg) if burst_seg.size else np.nan
    prev_mean = np.mean(prev_seg) if prev_seg.size else np.nan
    next_mean = np.mean(next_seg) if next_seg.size else np.nan

    # Baseline: mean of whichever side windows are available.
    # If neither is available, use the whole-trace mean.
    if np.isnan(prev_mean) and np.isnan(next_mean):
        inter_mean = np.nanmean(signal)
    else:
        inter_mean = np.nanmean([prev_mean, next_mean])

    if np.isnan(inter_mean) or np.isnan(burst_min):
        label = "Other"
    elif burst_min > inter_mean:
        label = "Square Wave"
    elif burst_min < inter_mean:
        label = "Parabolic"
    else:
        label = "Other"

    burst_types[tuple(burst)] = label
    bursts_by_label[label].append(burst)

## Step 5: Save Bursts to CSV

In [None]:
# One row per burst: 1-based burst number, time of its first and last spike
# peak (s), and the label assigned in Step 4 ("Other" if missing).
burst_list = [
    [
        number,
        spike_times[members[0]],
        spike_times[members[-1]],
        burst_types.get(tuple(members), "Other"),
    ]
    for number, members in enumerate(bursts, start=1)
]

burst_columns = ["Burst_Number", "Start_Time_s", "End_Time_s", "Type"]
df_bursts_all = pd.DataFrame(burst_list, columns=burst_columns)
df_bursts_all.to_csv("burst_basic_info_cell89_all_bursts.csv", index=False)
df_bursts_all.head()

## Step 6: Visibility Graphs and Embeddings

In [None]:
# For every burst, this cell does three things:
#   1. builds a natural visibility graph from its within-burst ISI series;
#   2. embeds the graph's adjacency matrix in 2-D and 3-D with truncated SVD;
#   3. collects node and edge records for export in Step 7.
# colors_map is not used in this cell (presumably used by later plotting cells).
colors_map = {"Square Wave": "blue", "Parabolic": "green", "Other": "orange"}
nodes_list, edges_list = [], []


def natural_visibility_edges(y_values, x_values=None):
    """Return the edges of the natural visibility graph of a series.

    Points a < b are connected when every intermediate point c lies strictly
    below the straight line from (x_a, y_a) to (x_b, y_b).

    Parameters
    ----------
    y_values : sequence of numbers
        Series values, one per node.
    x_values : sequence of numbers, optional
        Node positions. Defaults to 0, 1, ..., len(y_values) - 1.

    Returns
    -------
    list of (int, int)
        Visible pairs (a, b) with a < b, in lexicographic order.
    """
    n = len(y_values)
    if x_values is None:
        x_values = np.arange(n)
    edges = []
    for a in range(n):
        for b in range(a + 1, n):
            # Any intermediate point on or above the a-b sight line blocks it.
            blocked = any(
                y_values[c] >= y_values[a] + (y_values[b] - y_values[a])
                * ((x_values[c] - x_values[a]) / (x_values[b] - x_values[a]))
                for c in range(a + 1, b)
            )
            if not blocked:
                edges.append((a, b))
    return edges


def svd_embedding(adjacency, n_components, random_state=42):
    """Embed graph nodes via truncated SVD of the adjacency matrix.

    Returns an (n_nodes, n_components) array. A graph with fewer than
    max(n_components, 2) nodes gets an all-zero embedding and no SVD is run.
    """
    n = adjacency.shape[0]
    if n >= max(n_components, 2):
        return TruncatedSVD(n_components=n_components,
                            random_state=random_state).fit_transform(adjacency)
    return np.zeros((n, n_components))


for b_idx, burst in enumerate(bursts):
    burst_type = burst_types.get(tuple(burst), "Other")
    n_nodes = len(burst)
    if n_nodes == 0:
        continue

    # Node positions are spike ranks; node values are within-burst ISIs (ms).
    # NOTE(review): the first spike's value is the placeholder 0 set in
    # Step 3, not a real interval. Confirm that is intended for the graph.
    x_peaks = np.arange(n_nodes)
    y_peaks = isi_per_spike_burst[burst]

    G = nx.Graph()
    G.add_nodes_from(range(n_nodes))
    visible_pairs = natural_visibility_edges(y_peaks, x_peaks)
    G.add_edges_from(visible_pairs)
    edges_list.extend([b_idx + 1, burst_type, a, b] for a, b in visible_pairs)

    # Rows/columns of A follow G.nodes(), i.e. 0..n_nodes-1 as added above.
    A = nx.to_numpy_array(G)
    embedding_2d = svd_embedding(A, 2)
    embedding_3d = svd_embedding(A, 3)

    for i in range(n_nodes):
        nodes_list.append([b_idx + 1, burst_type, i,
                           embedding_2d[i, 0], embedding_2d[i, 1],
                           embedding_3d[i, 0], embedding_3d[i, 1], embedding_3d[i, 2],
                           burst[i]])

## Step 7: Save Nodes and Edges CSV

In [None]:
# Column layouts matching the records built in Step 6.
node_columns = [
    "Burst_Number", "Type", "Node_ID",
    "X_2D", "Y_2D", "X_3D", "Y_3D", "Z_3D",
    "Spike_Global_Index",
]
edge_columns = ["Burst_Number", "Type", "Node1_ID", "Node2_ID"]

df_nodes = pd.DataFrame(nodes_list, columns=node_columns)
df_edges = pd.DataFrame(edges_list, columns=edge_columns)

# Write both tables without the pandas index column.
for frame, out_path in ((df_nodes, "burst_nodes_all.csv"),
                        (df_edges, "burst_edges_all.csv")):
    frame.to_csv(out_path, index=False)