# Spike and Burst Analysis with Visibility Graphs

This notebook detects spikes and bursts in an ABF recording, classifies the bursts, and constructs visibility graphs with 2D and 3D embeddings for the bursts.

In [None]:
# Load libraries
import pyabf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.signal import find_peaks
%matplotlib widget

## Load ABF file and concatenate sweeps
We combine all sweeps from the ABF into a single continuous voltage signal.

In [None]:
# Open the recording; the path is relative to the notebook's working directory.
file_path = "bursting/cell89basal.abf"
abf = pyabf.ABF(file_path)

# Select every sweep in turn and collect its trace, then join the traces
# end to end into one continuous signal.
sweep_traces = []
for sweep_number in range(abf.sweepCount):
    abf.setSweep(sweep_number)  # makes abf.sweepY refer to this sweep
    sweep_traces.append(abf.sweepY)
signal = np.concatenate(sweep_traces)

# Time axis for the concatenated signal: one sample every dt seconds.
dt = 1.0 / abf.dataRate
time = dt * np.arange(len(signal))

print(f"File: {file_path} | sweeps: {abf.sweepCount} | total length: {time[-1]:.2f} s")

## Spike detection
Detect spikes as local maxima of the voltage trace whose height is at least `threshold` (in mV).

In [None]:
# Minimum height (mV) a local maximum must reach to be counted as a spike.
threshold = -35

# find_peaks returns (peak_indices, properties); only the indices are used here.
spike_indices = find_peaks(signal, height=threshold)[0]

# Look up the time of each detected peak on the shared time axis.
spike_times = time[spike_indices]

## Burst detection
Define bursts as runs of two or more consecutive spikes in which every interspike interval (ISI) is below `burst_threshold` (in seconds).

In [None]:
# Interspike intervals: isi[k] is the gap between spike k and spike k + 1.
isi = np.diff(spike_times)
burst_threshold = 0.3  # seconds

# Each burst is stored as (time of first spike, time of last spike).
bursts = []
current_burst = []
if len(spike_times) > 0:  # guard: no spikes detected means no bursts
    current_burst = [spike_times[0]]
    # Pair every ISI with the spike that ends it. This visits every spike
    # after the first, including the last one (the previous index-based loop
    # stopped one spike short and never examined the final spike).
    for gap, spike_time in zip(isi, spike_times[1:]):
        if gap < burst_threshold:
            # Close enough to the previous spike: same burst.
            current_burst.append(spike_time)
        else:
            # Gap too long: close the running group (kept only if it holds
            # at least two spikes) and start a new one at this spike.
            if len(current_burst) > 1:
                bursts.append((current_burst[0], current_burst[-1]))
            current_burst = [spike_time]
    # Flush the group that was still open when the spikes ran out.
    if len(current_burst) > 1:
        bursts.append((current_burst[0], current_burst[-1]))

## Burst classification
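Classify each burst as 'Square Wave', 'Parabolic', or 'Other' by comparing its minimum voltage with the baseline, i.e. the mean voltage of the neighbouring inter-burst intervals: a minimum above the baseline is 'Square Wave', a minimum below it is 'Parabolic', and anything else is 'Other'.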

In [None]:
square_wave_bursts, parabolic_bursts, other_bursts = [], [], []


def _interburst_mean(gap_start, gap_end):
    """Mean voltage of the samples strictly between two times.

    Returns NaN when no sample falls inside the open interval
    (gap_start, gap_end). Relies on `time` being sorted in ascending order,
    which lets the interval be located by binary search instead of a
    full-length boolean mask.
    """
    lo = np.searchsorted(time, gap_start, side="right")  # first sample with time > gap_start
    hi = np.searchsorted(time, gap_end, side="left")     # one past last sample with time < gap_end
    return np.mean(signal[lo:hi]) if hi > lo else np.nan


for i, (burst_start, burst_end) in enumerate(bursts):
    # Samples with burst_start <= time <= burst_end.
    first = np.searchsorted(time, burst_start, side="left")
    stop = np.searchsorted(time, burst_end, side="right")
    burst_min = np.min(signal[first:stop])

    # Baseline on each side: the gap to the previous burst and the gap to the
    # next burst. The first burst has no previous gap, the last has no next gap.
    prev_mean = _interburst_mean(bursts[i - 1][1], burst_start) if i > 0 else np.nan
    next_mean = _interburst_mean(burst_end, bursts[i + 1][0]) if i < len(bursts) - 1 else np.nan

    # Average whichever baselines exist. With none available (e.g. a single
    # burst in the recording) keep NaN explicitly rather than averaging an
    # all-NaN list, which would emit a RuntimeWarning.
    available = [m for m in (prev_mean, next_mean) if not np.isnan(m)]
    inter_mean = np.mean(available) if available else np.nan

    # Any comparison with NaN is False, so a burst without a usable baseline
    # falls through to "Other".
    if burst_min > inter_mean:
        square_wave_bursts.append((burst_start, burst_end))
    elif burst_min < inter_mean:
        parabolic_bursts.append((burst_start, burst_end))
    else:
        other_bursts.append((burst_start, burst_end))

## Save burst info to CSV

In [None]:
# Rows are emitted grouped by type (square wave, then parabolic, then other)
# and numbered consecutively in that order, starting at 1.
bursts_by_type = (
    (square_wave_bursts, "Square Wave"),
    (parabolic_bursts, "Parabolic"),
    (other_bursts, "Other"),
)

burst_list = []
for group, label in bursts_by_type:
    already_numbered = len(burst_list)  # rows written for the earlier types
    burst_list.extend(
        [already_numbered + position, start, end, label]
        for position, (start, end) in enumerate(group, start=1)
    )

# One row per burst; written to disk and echoed for inspection.
df_bursts = pd.DataFrame(
    burst_list,
    columns=["Burst_Number", "Start_Time_s", "End_Time_s", "Type"],
)
df_bursts.to_csv("burst_info_cell89.csv", index=False)
print(df_bursts)

## Plot spikes and bursts
Visualize spikes as red dots and bursts as shaded regions (blue: square wave, green: parabolic, orange: other).

In [None]:
# Full voltage trace with the detected spike peaks marked as red dots.
plt.figure(figsize=(12, 4))
plt.plot(time, signal, lw=0.2)
plt.plot(spike_times, signal[spike_indices], 'r.', markersize=3)

# Shade every burst interval, colour-coded by burst type.
shading_by_type = (
    (square_wave_bursts, 'blue'),
    (parabolic_bursts, 'green'),
    (other_bursts, 'orange'),
)
for group, shade in shading_by_type:
    for start, end in group:
        plt.axvspan(start, end, color=shade, alpha=0.3)

plt.xlabel("Time (s)")
plt.ylabel("Voltage (mV)")
plt.title("Detected Spikes and Burst Types")
plt.show()