In [None]:
# This notebook analyzes neuronal spikes and bursts recorded in ABF files.
# Pipeline: spike detection -> ISI-based burst detection -> burst classification
# -> per-burst visibility graphs with SVD embeddings; results are plotted and exported as CSV.

import pyabf  # Library to load ABF files
import numpy as np  # Fundamental numerical operations
import pandas as pd  # Data handling and CSV export
import matplotlib.pyplot as plt  # Plotting
import networkx as nx  # Visibility graph creation and analysis
from scipy.signal import find_peaks  # Peak/spike detection
from sklearn.decomposition import TruncatedSVD  # Dimensionality reduction for graph embeddings
%matplotlib widget  # Interactive plotting in Jupyter

# Load ABF file
file_path = "bursting/cell89basal.abf"  # Path to ABF file
abf = pyabf.ABF(file_path)  # Load ABF file

# Concatenate all sweeps into a single continuous voltage trace
# Each sweep is set using abf.setSweep(i) and its data appended
signal = np.concatenate([abf.setSweep(i) or abf.sweepY for i in range(abf.sweepCount)])
dt = 1.0 / abf.dataRate  # Sampling interval in seconds
time = np.arange(len(signal)) * dt  # Time vector corresponding to the signal

print(f"File: {file_path} | sweeps: {abf.sweepCount} | total length: {time[-1]:.2f} s")

# Spike Detection
threshold = -35  # Minimum peak height (mV) for a local maximum to count as a spike
# find_peaks returns (peak_indices, properties); only the indices are needed.
# With height=threshold it keeps the local maxima whose value is at least `threshold`.
spike_indices = find_peaks(signal, height=threshold)[0]
spike_times = time[spike_indices]  # Time (s) of every detected spike

# Burst Detection
# Inter-spike intervals (ISI): isi[k] is the gap between spike k and spike k+1
isi = np.diff(spike_times)
burst_threshold = 0.3  # Maximum ISI for spikes to belong to same burst (s)

# Cut the spike train after every spike that is followed by a gap of at least
# burst_threshold. Splitting on the whole ISI array means the final spike is
# included in its group, and an empty spike train simply yields no bursts.
split_points = np.flatnonzero(isi >= burst_threshold) + 1
spike_groups = np.split(spike_times, split_points)

# List of detected bursts as tuples (start_time, end_time);
# isolated spikes (groups of one) are not bursts and are discarded
bursts = [(group[0], group[-1]) for group in spike_groups if len(group) > 1]

# Burst Classification
square_wave_bursts = []  # Bursts whose minimum voltage stays above the local baseline
parabolic_bursts = []    # Bursts whose minimum voltage dips below the local baseline
other_bursts = []        # Bursts that match neither pattern

final_burst = len(bursts) - 1
for i, (burst_start, burst_end) in enumerate(bursts):
    # Lowest voltage between the first and last spike of the burst
    within_burst = (time >= burst_start) & (time <= burst_end)
    burst_min = np.min(signal[within_burst])

    # Mean voltage of the quiet gap preceding the burst (NaN for the first burst)
    if i > 0:
        gap_before = (time > bursts[i-1][1]) & (time < burst_start)
        prev_mean = np.mean(signal[gap_before])
    else:
        prev_mean = np.nan

    # Mean voltage of the quiet gap following the burst (NaN for the last burst)
    if i < final_burst:
        gap_after = (time > burst_end) & (time < bursts[i+1][0])
        next_mean = np.mean(signal[gap_after])
    else:
        next_mean = np.nan

    # Local baseline: average of whichever neighbouring gap means are not NaN
    inter_mean = np.nanmean([prev_mean, next_mean])

    # Pick the destination list from the burst minimum relative to the baseline
    if burst_min > inter_mean:
        destination = square_wave_bursts
    elif burst_min < inter_mean:
        destination = parabolic_bursts
    else:
        # Reached when the two values are equal or the baseline is NaN
        destination = other_bursts
    destination.append((burst_start, burst_end))

# Save burst information as CSV
# Rows are emitted group by group (square wave, parabolic, other) and numbered
# consecutively across the groups, starting at 1.
labelled_groups = (
    (square_wave_bursts, "Square Wave"),
    (parabolic_bursts, "Parabolic"),
    (other_bursts, "Other"),
)
burst_list = []
for group, label in labelled_groups:
    for start, end in group:
        burst_list.append([len(burst_list) + 1, start, end, label])

burst_columns = ["Burst_Number", "Start_Time_s", "End_Time_s", "Type"]
df_bursts = pd.DataFrame(burst_list, columns=burst_columns)
df_bursts.to_csv("burst_info_cell89.csv", index=False)
print(df_bursts)

# Plot spikes and burst types
plt.figure(figsize=(12,4))
plt.plot(time, signal, lw=0.2)  # Raw voltage trace
plt.plot(spike_times, signal[spike_indices], 'r.', markersize=3)  # Detected spikes as red dots

# Shade every burst interval with the colour of its class
shading = (
    (square_wave_bursts, 'blue'),
    (parabolic_bursts, 'green'),
    (other_bursts, 'orange'),
)
for group, shade in shading:
    for start, end in group:
        plt.axvspan(start, end, color=shade, alpha=0.3)

plt.xlabel("Time (s)")
plt.ylabel("Voltage (mV)")
plt.title("Detected Spikes and Burst Types")
plt.show()

# Visibility Graph and Embeddings
colors_map = {"Square Wave":"blue", "Parabolic":"green", "Other":"orange"}  # Node colors
nodes_list, edges_list = [], []  # Rows for the node and edge tables

# `bursts` holds (start_time, end_time) tuples, not spike indices, so recover the
# global spike indices of each burst by selecting the spikes inside its interval.
burst_spike_indices = [
    np.flatnonzero((spike_times >= burst_start) & (spike_times <= burst_end))
    for burst_start, burst_end in bursts
]

# Internal ISI for each spike in each burst (ms): 0 for the first spike of a
# burst, otherwise the time elapsed since the previous spike of that burst.
isi_per_spike_burst = np.zeros(len(spike_times))
for spike_idx in burst_spike_indices:
    isi_per_spike_burst[spike_idx[1:]] = np.diff(spike_times[spike_idx]) * 1000

# Burst type lookup keyed by (start_time, end_time). The classification lists
# built earlier are reused as-is, so no burst is appended to them a second time.
burst_types = {}
for label, group in (("Square Wave", square_wave_bursts),
                     ("Parabolic", parabolic_bursts),
                     ("Other", other_bursts)):
    for span in group:
        burst_types[tuple(span)] = label


def _svd_embedding(adjacency, n_dims):
    """Return an (n_nodes, n_dims) TruncatedSVD embedding of an adjacency matrix.

    Graphs with fewer nodes than requested dimensions get an all-zero embedding.
    NOTE(review): whether n_components may equal the number of nodes appears to
    depend on the scikit-learn version - confirm against the installed release.
    """
    n_nodes = adjacency.shape[0]
    if n_nodes < n_dims:
        return np.zeros((n_nodes, n_dims))
    return TruncatedSVD(n_components=n_dims, random_state=42).fit_transform(adjacency)


max_bursts = 15  # Graphs are built and plotted for the first 15 bursts only

# Build nodes and edges with embeddings, one figure per burst
for b_idx, (burst, spike_idx) in enumerate(zip(bursts[:max_bursts], burst_spike_indices)):
    burst_type = burst_types[tuple(burst)]
    n_spikes = len(spike_idx)
    x_peaks = np.arange(n_spikes)             # Spike position within the burst
    y_peaks = isi_per_spike_burst[spike_idx]  # ISI (ms) of each spike in the burst

    # Visibility graph: spikes a and b are linked when every spike c between
    # them lies strictly below the straight line joining (a, y_a) and (b, y_b).
    G = nx.Graph()
    G.add_nodes_from(range(n_spikes))
    for a in range(n_spikes):
        for b in range(a+1, n_spikes):
            blocked = any(
                y_peaks[c] >= y_peaks[b] + (y_peaks[a]-y_peaks[b])*(x_peaks[b]-x_peaks[c])/(x_peaks[b]-x_peaks[a])
                for c in range(a+1, b)
            )
            if not blocked:
                G.add_edge(a, b)
                edges_list.append([b_idx+1, burst_type, a, b])

    # Embed the nodes from the adjacency matrix of the graph
    A = nx.to_numpy_array(G)
    embedding_2d = _svd_embedding(A, 2)
    embedding_3d = _svd_embedding(A, 3)
    for i in range(n_spikes):
        nodes_list.append([
            b_idx+1, burst_type, i,
            embedding_2d[i,0], embedding_2d[i,1],
            embedding_3d[i,0], embedding_3d[i,1], embedding_3d[i,2],
            int(spike_idx[i]),  # Index of this spike in spike_times / spike_indices
        ])

    # Plot visibility graph and embeddings
    fig = plt.figure(figsize=(12,10))

    # Panel 1: ISI sequence with the visibility edges
    ax0 = fig.add_subplot(3,1,1)
    for u,v in G.edges():
        ax0.plot([x_peaks[u], x_peaks[v]], [y_peaks[u], y_peaks[v]], 'gray', alpha=0.5)
    ax0.scatter(x_peaks, y_peaks, color=colors_map[burst_type], s=40)
    ax0.set_title(f"Burst {b_idx+1} Visibility graph ({burst_type})")
    ax0.set_xlabel("Spike index")
    ax0.set_ylabel("ISI (ms)")
    ax0.grid(True)

    # Panel 2: 2D embedding, nodes labelled with their position in the burst
    ax1 = fig.add_subplot(3,1,2)
    for i in range(n_spikes):
        ax1.scatter(embedding_2d[i,0], embedding_2d[i,1], color=colors_map[burst_type], s=50)
        ax1.text(embedding_2d[i,0]+0.01, embedding_2d[i,1]+0.01, str(i), fontsize=8)
    for u,v in G.edges():
        ax1.plot([embedding_2d[u,0], embedding_2d[v,0]], [embedding_2d[u,1], embedding_2d[v,1]], 'r-', alpha=0.3)
    ax1.set_title(f"Burst {b_idx+1} Embedding 2D")
    ax1.grid(True)

    # Panel 3: 3D embedding (all zeros when the burst has fewer than 3 spikes)
    ax2 = fig.add_subplot(3,1,3, projection='3d')
    for i in range(n_spikes):
        ax2.scatter(embedding_3d[i,0], embedding_3d[i,1], embedding_3d[i,2], color=colors_map[burst_type], s=40)
        ax2.text(embedding_3d[i,0], embedding_3d[i,1], embedding_3d[i,2], str(i), fontsize=8)
    for u,v in G.edges():
        ax2.plot([embedding_3d[u,0], embedding_3d[v,0]], [embedding_3d[u,1], embedding_3d[v,1]], [embedding_3d[u,2], embedding_3d[v,2]], 'r-', alpha=0.3)
    ax2.set_title(f"Burst {b_idx+1} Embedding 3D")

    plt.tight_layout()
    plt.show()

# Save CSV of nodes and edges
node_columns = ["Burst_Number","Type","Node_ID","X_2D","Y_2D","X_3D","Y_3D","Z_3D","Spike_Global_Index"]
edge_columns = ["Burst_Number","Type","Node1_ID","Node2_ID"]
df_nodes = pd.DataFrame(nodes_list, columns=node_columns)
df_edges = pd.DataFrame(edges_list, columns=edge_columns)

# Write each table to its own file, without the DataFrame index column
for frame, out_path in ((df_nodes, "burst_nodes_all.csv"), (df_edges, "burst_edges_all.csv")):
    frame.to_csv(out_path, index=False)