# Spike and Burst Analysis with Visibility Graphs and Embeddings

This notebook demonstrates a complete workflow for analyzing electrophysiological ABF recordings:
1. Loading ABF files and concatenating sweeps
2. Detecting spikes
3. Detecting bursts and computing intra-burst ISI
4. Classifying bursts as Square Wave, Parabolic, or Other
5. Exporting basic burst information to CSV
6. Constructing a visibility graph for each burst and computing its 2D and 3D embeddings
7. Exporting node and edge information to CSV
8. Visualizing the voltage trace, spikes, bursts, and embeddings

In [None]:
# Load libraries
import pyabf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import networkx as nx
from scipy.signal import find_peaks
from sklearn.decomposition import TruncatedSVD
%matplotlib widget

## Step 1: Load ABF file and concatenate sweeps
ABF (Axon Binary File) format is used for patch-clamp or intracellular recordings.
We merge all sweeps into a single continuous voltage trace.

In [None]:
# Path of the recording to analyse (relative to the notebook's working directory).
file_path = "bursting/cell89basal.abf"
abf = pyabf.ABF(file_path)

# Select every sweep in turn and collect its samples. The sweep is selected in
# its own statement instead of relying on `abf.setSweep(i) or abf.sweepY`,
# which only works as long as setSweep() returns a falsy value.
sweep_traces = []
for sweep_number in range(abf.sweepCount):
    abf.setSweep(sweep_number)
    sweep_traces.append(abf.sweepY)

# np.concatenate copies the data, so `signal` is independent of the ABF object.
# NOTE(review): joining sweeps end-to-end assumes they are contiguous in time;
# confirm this for episodic recordings.
signal = np.concatenate(sweep_traces)

# Sampling interval in seconds and the matching time axis (sample k -> k * dt).
dt = 1.0 / abf.dataRate
time = np.arange(len(signal)) * dt

print(f"File: {file_path} | sweeps: {abf.sweepCount} | total length: {time[-1]:.2f} s")

## Step 2: Spike detection
Detect spikes using a voltage threshold. Spikes are identified as local maxima above -35 mV.

In [None]:
# Minimum peak height for a local maximum to count as a spike.
threshold = -35  # mV

# find_peaks returns (peak_indices, properties); only the indices are needed.
peak_search = find_peaks(signal, height=threshold)
spike_indices = peak_search[0]

# Spike times in seconds, looked up on the shared time axis.
spike_times = time[spike_indices]

## Step 3: Burst detection
A burst is a run of two or more consecutive spikes in which every interspike interval (ISI) is below 0.3 s; isolated spikes are not counted as bursts.

In [None]:
# Maximum gap (in seconds) allowed between consecutive spikes of one burst.
burst_threshold = 0.3  # s
isi = np.diff(spike_times)

# A new group starts at every spike whose preceding ISI is NOT below the
# threshold; splitting the spike positions at those points yields the runs
# of closely spaced spikes.
group_starts = np.flatnonzero(~(isi < burst_threshold)) + 1
spike_groups = np.split(np.arange(len(spike_times)), group_starts)

# Keep only groups with at least two spikes, as plain lists of spike positions.
bursts = [group.tolist() for group in spike_groups if len(group) > 1]

print(f"Detected {len(bursts)} bursts")

# Compute internal ISI per spike (ms)
# The first spike of each burst, and every spike outside a burst, keeps 0.
isi_per_spike_burst = np.zeros(len(spike_times))
for members in bursts:
    for earlier, later in zip(members, members[1:]):
        isi_per_spike_burst[later] = (spike_times[later]-spike_times[earlier])*1000

## Step 4: Burst classification
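Classify each burst by comparing the minimum voltage reached during the burst with the mean voltage of the neighbouring inter-burst intervals (the baseline): **Square Wave** if the minimum stays above the baseline, **Parabolic** if it drops below the baseline, and **Other** when neither comparison holds (for example, when no baseline is available).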

In [None]:
# Bursts grouped by class, plus a lookup from tuple(burst) to its class label.
square_wave_bursts, parabolic_bursts, other_bursts = [], [], []
burst_types = {}

# Sample index of the first and last spike of every burst. Because `time` is a
# strictly increasing axis and spike_times == time[spike_indices], comparing
# times is equivalent to comparing sample indices, so each window can be taken
# as a plain slice of `signal` instead of a boolean mask over the whole trace.
burst_first_sample = [spike_indices[b[0]] for b in bursts]
burst_last_sample = [spike_indices[b[-1]] for b in bursts]

for i, burst in enumerate(bursts):
    first, last = burst_first_sample[i], burst_last_sample[i]

    # Lowest voltage between the first and last spike of the burst (inclusive).
    burst_min = np.min(signal[first:last + 1])

    # Samples strictly between this burst and its neighbours on either side.
    # The first burst has no previous gap and the last burst has no next gap.
    gaps = []
    if i > 0:
        gaps.append(signal[burst_last_sample[i - 1] + 1:first])
    if i < len(bursts) - 1:
        gaps.append(signal[last + 1:burst_first_sample[i + 1]])

    # Baseline = average of the available per-gap means. Empty gaps and NaN
    # means are skipped explicitly, which avoids the "Mean of empty slice"
    # warnings while giving the same result as nanmean over both gap means.
    gap_means = [np.mean(gap) for gap in gaps if gap.size]
    gap_means = [m for m in gap_means if not np.isnan(m)]
    inter_mean = np.mean(gap_means) if gap_means else np.nan

    # With no usable baseline (inter_mean is NaN) both comparisons are False
    # and the burst falls through to "Other".
    if burst_min > inter_mean:
        square_wave_bursts.append(burst)
        burst_types[tuple(burst)] = "Square Wave"
    elif burst_min < inter_mean:
        parabolic_bursts.append(burst)
        burst_types[tuple(burst)] = "Parabolic"
    else:
        other_bursts.append(burst)
        burst_types[tuple(burst)] = "Other"

## Step 5: Save bursts info to CSV

In [None]:
# One row per burst: 1-based burst number, time of first and last spike (s), class label.
burst_list = [
    [number, spike_times[members[0]], spike_times[members[-1]], burst_types[tuple(members)]]
    for number, members in enumerate(bursts, start=1)
]

burst_columns = ["Burst_Number","Start_Time_s","End_Time_s","Type"]
df_bursts_all = pd.DataFrame(burst_list, columns=burst_columns)
df_bursts_all.to_csv("burst_basic_info_cell89_all_bursts.csv", index=False)

# Preview the first rows as the cell output.
df_bursts_all.head(10)

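## Step 6: Visibility graphs and embeddings

For each burst, every spike becomes a node whose height is its intra-burst ISI in ms (0 for the first spike of the burst). Two nodes are connected when the straight line between them passes strictly above every intermediate node (natural visibility criterion). The adjacency matrix of the resulting graph is then reduced with truncated SVD to obtain 2D and 3D node coordinates; bursts with fewer than three spikes get an all-zero 3D embedding.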

In [None]:
# Colour associated with each burst class.
colors_map = {"Square Wave":"blue", "Parabolic":"green", "Other":"orange"}

# Flat row buffers, later turned into the node and edge tables.
nodes_list, edges_list = [], []


def _is_visible(heights, left, right):
    """Return True if nodes `left` and `right` can see each other.

    Nodes sit at integer positions 0..len(heights)-1 with the given heights.
    The pair is visible when every node strictly between them lies strictly
    below the straight line joining the two end nodes.
    """
    span = right - left
    for mid in range(left + 1, right):
        sight_line = heights[right]+(heights[left]-heights[right])*(right-mid)/span
        if heights[mid] >= sight_line:
            return False
    return True


for b_idx, burst in enumerate(bursts):
    burst_type = burst_types[tuple(burst)]
    n_nodes = len(burst)

    # Node heights: intra-burst ISI of each spike of this burst.
    y_peaks = isi_per_spike_burst[burst]

    # Build the visibility graph, recording each edge with its 1-based burst number.
    G = nx.Graph()
    G.add_nodes_from(range(n_nodes))
    for a in range(n_nodes):
        for b in range(a+1, n_nodes):
            if not _is_visible(y_peaks, a, b):
                continue
            G.add_edge(a,b)
            edges_list.append([b_idx+1, burst_type, a, b])

    # Embed the nodes from the adjacency matrix. A 3-component SVD needs at
    # least three nodes; smaller bursts get an all-zero 3D embedding instead.
    A = nx.to_numpy_array(G)
    embedding_2d = TruncatedSVD(n_components=2, random_state=42).fit_transform(A)
    if A.shape[0] >= 3:
        embedding_3d = TruncatedSVD(n_components=3, random_state=42).fit_transform(A)
    else:
        embedding_3d = np.zeros((n_nodes,3))

    # One row per node: burst number, class, node id, 2D coords, 3D coords,
    # and the spike's position in the global spike list.
    for node_id, spike_global_index in enumerate(burst):
        xy = embedding_2d[node_id]
        xyz = embedding_3d[node_id]
        nodes_list.append([b_idx+1, burst_type, node_id,
                           xy[0], xy[1],
                           xyz[0], xyz[1], xyz[2],
                           spike_global_index])

## Step 7: Save nodes and edges CSV

In [None]:
# Column layouts matching the row structure built in the graph cell.
node_columns = ["Burst_Number","Type","Node_ID","X_2D","Y_2D","X_3D","Y_3D","Z_3D","Spike_Global_Index"]
edge_columns = ["Burst_Number","Type","Node1_ID","Node2_ID"]

df_nodes = pd.DataFrame(nodes_list, columns=node_columns)
df_edges = pd.DataFrame(edges_list, columns=edge_columns)

# Write both tables without the pandas index column.
for frame, csv_name in ((df_nodes, "burst_nodes_all.csv"), (df_edges, "burst_edges_all.csv")):
    frame.to_csv(csv_name, index=False)

## Step 8: Visualization
Plot voltage trace, spikes, bursts, visibility graphs, and embeddings for the first few bursts.

In [None]:
# Explicit figure/axes handles instead of the implicit pyplot state machine.
fig, ax = plt.subplots(figsize=(12, 4))

# Full voltage trace with the detected spike peaks marked in red.
ax.plot(time, signal, lw=0.2)
ax.plot(spike_times, signal[spike_indices], 'r.', markersize=3, label="Spikes")

# One (bursts, colour, legend label) entry per burst class, so the shading
# logic is written once instead of being copy-pasted per class.
burst_groups = [
    (square_wave_bursts, "blue", "Square Wave"),
    (parabolic_bursts, "green", "Parabolic"),
    (other_bursts, "orange", "Other"),
]
for group, color, label in burst_groups:
    for k, burst in enumerate(group):
        # Shade from the first to the last spike of the burst. Only the first
        # span of each class carries a legend label; labels starting with an
        # underscore are skipped by the automatic legend.
        ax.axvspan(spike_times[burst[0]], spike_times[burst[-1]],
                   color=color, alpha=0.3,
                   label=label if k == 0 else "_nolegend_")

ax.set_xlabel("Time (s)")
ax.set_ylabel("Voltage (mV)")
ax.set_title("Voltage Trace with Spikes and Burst Types")

# Fixed location: the default 'best' placement can be slow on long traces.
ax.legend(loc="upper right")
plt.show()