In [None]:
# Notebook: Spike detection, burst classification, visibility graphs and embeddings
# Content: load ABF(s), detect spikes, group bursts, classify bursts, build visibility graphs,
# compute embeddings (TruncatedSVD), compute segment metrics and visualize with UMAP/plots.

import os
import pyabf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.signal import find_peaks, hilbert
from scipy.stats import skew, kurtosis
from numpy.linalg import lstsq
from sklearn.metrics import r2_score
import networkx as nx
from sklearn.decomposition import TruncatedSVD
import umap
from sklearn.preprocessing import StandardScaler
import seaborn as sns
import plotly.express as px
import plotly.io as pio
pio.renderers.default = 'browser'
%matplotlib widget

# Parameters
# Directory that is scanned for *.abf recordings.
folder_path = 'bursting'
# Single recording used as a fallback when the folder yields no *.abf files.
example_file = 'bursting/cell89basal.abf'
# Minimum peak height handed to scipy.signal.find_peaks (same units as the
# recorded trace; presumably mV - confirm against the recordings).
threshold = -35
# Consecutive spikes closer together than this inter-spike interval are put in
# the same burst. Spike times are peak index / sampling rate, so this is in
# seconds if the sampling rate is in Hz - TODO confirm.
burst_threshold = 0.3
# Sampling rate assumed when an ABF object does not expose `dataRate`.
default_fs = 10000
# Sample interval implied by default_fs.
dt = 1.0 / default_fs
# A run of spikes needs at least this many members to be kept as a burst.
min_spikes_in_burst = 4

# One row of metrics per segment; filled by the main loop and later turned
# into a DataFrame.
all_segment_metrics = []
# NOTE(review): these two lists are not written to anywhere in the visible
# code - presumably reserved for visibility-graph export; confirm or remove.
global_nodes_list = []
global_edges_list = []

def detect_spikes(signal, data_rate, threshold_local=None):
    """Find peaks in ``signal`` that reach a minimum height.

    Args:
        signal: 1-D sequence of samples passed to ``scipy.signal.find_peaks``.
        data_rate: Divisor that converts peak sample indices into times.
        threshold_local: Minimum peak height. When ``None`` the module-level
            ``threshold`` is used instead.

    Returns:
        A tuple ``(peak_indices, peak_times, properties)`` where
        ``peak_times`` is ``peak_indices / data_rate`` and ``properties`` is
        the dictionary returned by ``find_peaks``.
    """
    if threshold_local is None:
        min_height = threshold
    else:
        min_height = threshold_local
    peak_indices, peak_properties = find_peaks(signal, height=min_height)
    return peak_indices, peak_indices / data_rate, peak_properties

def detect_bursts_from_spike_times(spike_times, burst_isi_thresh=burst_threshold, min_spikes=min_spikes_in_burst):
    """Group consecutive spikes into bursts based on their inter-spike interval.

    Spikes are walked in order; a spike joins the current run when the
    interval to the preceding spike is strictly below ``burst_isi_thresh``,
    otherwise it starts a new run. Runs with at least ``min_spikes`` members
    are reported as bursts.

    Args:
        spike_times: Ordered spike times.
        burst_isi_thresh: Largest interval (exclusive) that keeps two
            neighbouring spikes in the same run.
        min_spikes: Minimum run length for a run to count as a burst.

    Returns:
        A tuple ``(bursts, bursts_idx)``. ``bursts`` holds one
        ``(first_spike_time, last_spike_time)`` pair per burst and
        ``bursts_idx`` the matching lists of spike indices. Both are empty
        when there are fewer than ``min_spikes`` spikes overall.
    """
    burst_windows = []
    burst_members = []
    if len(spike_times) < min_spikes:
        return burst_windows, burst_members

    def _close_run(run):
        # Keep the run only if it is long enough to qualify as a burst.
        if len(run) >= min_spikes:
            burst_members.append(list(run))
            burst_windows.append((spike_times[run[0]], spike_times[run[-1]]))

    intervals = np.diff(spike_times)
    run = [0]
    # intervals[k] separates spike k from spike k + 1, hence start=1.
    for spike_idx, gap in enumerate(intervals, start=1):
        if gap < burst_isi_thresh:
            run.append(spike_idx)
            continue
        _close_run(run)
        run = [spike_idx]
    # The last run is still open when the loop ends.
    _close_run(run)
    return burst_windows, burst_members

def classify_burst_segment(voltage, time, seg_start, seg_end, seg_index, segments):
    """Label a burst by comparing its minimum with the neighbouring baseline.

    The baseline is the mean of ``voltage`` in the gaps on either side of the
    segment. The gaps are located through the *list neighbours* of the
    segment: samples strictly between the end of ``segments[seg_index - 1]``
    and ``seg_start``, and samples strictly between ``seg_end`` and the start
    of ``segments[seg_index + 1]``. A neighbour that abuts or overlaps this
    segment therefore yields an empty gap and contributes nothing.

    Args:
        voltage: NumPy array of samples (indexed with boolean masks).
        time: NumPy array of sample times, same length as ``voltage``.
        seg_start: Start time of the segment (inclusive).
        seg_end: End time of the segment (inclusive).
        seg_index: Position of this segment inside ``segments``.
        segments: Sequence whose items expose a start time at index 0 and an
            end time at index 1.

    Returns:
        ``'Square Wave'`` if the segment minimum lies above the baseline,
        ``'Parabolic'`` if it lies below it, and ``'Other'`` when the segment
        has no samples, no baseline is available, or the two are equal.
    """
    seg_voltage = voltage[(time >= seg_start) & (time <= seg_end)]
    if seg_voltage.size == 0:
        return 'Other'

    gap_means = []
    if seg_index > 0:
        prev_end = segments[seg_index - 1][1]
        gap_before = (time > prev_end) & (time < seg_start)
        if np.any(gap_before):
            gap_means.append(np.mean(voltage[gap_before]))
    if seg_index < len(segments) - 1:
        next_start = segments[seg_index + 1][0]
        gap_after = (time > seg_end) & (time < next_start)
        if np.any(gap_after):
            gap_means.append(np.mean(voltage[gap_after]))

    # Ignore gaps whose mean is NaN (e.g. NaN samples in the trace) and bail
    # out explicitly when nothing is left, instead of averaging an all-NaN
    # list and triggering NumPy's "Mean of empty slice" warning.
    gap_means = [m for m in gap_means if not np.isnan(m)]
    if not gap_means:
        return 'Other'

    inter_mean = np.mean(gap_means)
    burst_min = np.min(seg_voltage)
    if burst_min > inter_mean:
        return 'Square Wave'
    if burst_min < inter_mean:
        return 'Parabolic'
    # Equal values, or a NaN minimum for which both comparisons are False.
    return 'Other'

def build_visibility_graph(isi_ms):
    """Build a visibility graph over the values of ``isi_ms``.

    Every position in ``isi_ms`` becomes a node. Two positions are connected
    when every value strictly between them lies strictly below the straight
    line joining the two end values; adjacent positions are always connected.

    Args:
        isi_ms: Indexable sequence of numeric values.

    Returns:
        An undirected ``networkx.Graph`` with nodes ``0 .. len(isi_ms) - 1``.
    """
    count = len(isi_ms)
    graph = nx.Graph()
    graph.add_nodes_from(range(count))
    for left in range(count):
        for right in range(left + 1, count):
            # The pair is blocked as soon as one intermediate value touches
            # or rises above the line of sight between the two ends.
            blocked = any(
                isi_ms[mid] >= isi_ms[right] + (isi_ms[left] - isi_ms[right]) * (right - mid) / (right - left)
                for mid in range(left + 1, right)
            )
            if not blocked:
                graph.add_edge(left, right)
    return graph

# Main loop
# Collect the recordings to process. A missing folder is treated like an
# empty one so that the example-file fallback can still be used.
if os.path.isdir(folder_path):
    abf_files = [f for f in os.listdir(folder_path) if f.lower().endswith('.abf')]
else:
    abf_files = []
abf_paths = {name: os.path.join(folder_path, name) for name in abf_files}
if not abf_files and os.path.exists(example_file):
    # Load the example from its own location instead of re-joining its
    # basename onto folder_path, which may point somewhere else.
    example_name = os.path.basename(example_file)
    abf_files = [example_name]
    abf_paths[example_name] = example_file

# np.trapz was renamed to np.trapezoid in NumPy 2.0; use whichever exists.
_trapezoid = getattr(np, 'trapezoid', None) or np.trapz

for file_name in abf_files:
    file_path = abf_paths[file_name]
    try:
        abf = pyabf.ABF(file_path)
    except Exception as e:
        # Best effort: report the unreadable file and carry on with the rest.
        print(f'Cannot load {file_name}: {e}')
        continue
    data_rate = getattr(abf, 'dataRate', default_fs)
    # Sample interval of THIS recording. Area and derivative must use it,
    # otherwise they are wrong for any file whose rate differs from default_fs.
    sample_dt = 1.0 / data_rate
    for sweep in range(abf.sweepCount):
        abf.setSweep(sweep)
        time = np.array(abf.sweepX)
        voltage = np.array(abf.sweepY)
        peaks, spike_times, props = detect_spikes(voltage, data_rate)
        bursts, bursts_idx = detect_bursts_from_spike_times(spike_times)

        # Segment list layout: [leading non-burst], all bursts, then the
        # remaining non-burst gaps. The order is kept on purpose:
        # classify_burst_segment looks at the list neighbours of a burst, and
        # with bursts adjacent in the list those neighbours are the
        # neighbouring bursts.
        segments = [(start, end, 'Burst') for start, end in bursts]
        if bursts:
            if bursts[0][0] > time[0]:
                segments.insert(0, (time[0], bursts[0][0], 'Non-burst'))
            for i in range(len(bursts) - 1):
                segments.append((bursts[i][1], bursts[i + 1][0], 'Non-burst'))
            if bursts[-1][1] < time[-1]:
                segments.append((bursts[-1][1], time[-1], 'Non-burst'))
        else:
            segments.append((time[0], time[-1], 'Non-burst'))

        burst_types = []
        for i, (s0, s1, t) in enumerate(segments):
            if t == 'Burst':
                btype = classify_burst_segment(voltage, time, s0, s1, i, segments)
            else:
                btype = 'Non-burst'
            burst_types.append(btype)

        # One metrics row per segment.
        for j, (s0, s1, t) in enumerate(segments):
            mask = (time >= s0) & (time <= s1)
            v = voltage[mask]
            duration = s1 - s0
            btype = burst_types[j]

            # Spikes inside the segment; mean ISI is 0 with fewer than 2 peaks.
            seg_peaks, _ = find_peaks(v, height=threshold)
            num_peaks = len(seg_peaks)
            mean_isi = np.mean(np.diff(seg_peaks / data_rate)) if num_peaks >= 2 else 0

            # Descriptive statistics of the voltage; 0 stands in for
            # quantities that are undefined on too few samples.
            has_samples = len(v) > 0
            mean_val = np.mean(v) if has_samples else 0
            std_val = np.std(v) if has_samples else 0
            min_val = np.min(v) if has_samples else 0
            max_val = np.max(v) if has_samples else 0
            skew_val = skew(v) if len(v) > 2 else 0
            kurt_val = kurtosis(v) if len(v) > 2 else 0
            area_val = _trapezoid(v, dx=sample_dt) if len(v) > 1 else 0

            # First derivative of the voltage and the same statistics on it.
            deriv = np.diff(v) / sample_dt if len(v) > 1 else [0]
            mean_d = np.mean(deriv)
            std_d = np.std(deriv)
            min_d = np.min(deriv)
            max_d = np.max(deriv)
            skew_d = skew(deriv) if len(deriv) > 2 else 0
            kurt_d = kurtosis(deriv) if len(deriv) > 2 else 0

            all_segment_metrics.append([
                file_name, sweep, t, btype, duration, num_peaks, mean_isi,
                mean_val, std_val, min_val, max_val, skew_val, kurt_val, area_val,
                mean_d, std_d, min_d, max_d, skew_d, kurt_d,
            ])

# Save CSVs
df = pd.DataFrame(
    all_segment_metrics,
    columns=[
        'File_Name', 'Sweep', 'Segment_Type', 'Burst_Type', 'Duration',
        'Num_Peaks', 'Mean_ISI', 'Mean', 'Std', 'Min', 'Max', 'Skewness',
        'Kurtosis', 'Area', 'Mean_Deriv', 'Std_Deriv', 'Min_Deriv',
        'Max_Deriv', 'Skewness_Deriv', 'Kurtosis_Deriv',
    ],
)
df.to_csv('segment_voltage_metrics_with_ISI_and_peaks.csv', index=False)
print(f'Saved metrics: {len(df)} rows')

# 2-D UMAP embedding of the numeric metrics, coloured by burst type.
sns.set(style='white', context='notebook', rc={'figure.figsize': (10, 7)})
if df.empty:
    # StandardScaler rejects a table with zero rows, so there is nothing to
    # embed or plot when no segment was collected.
    print('No segments found; skipping UMAP embedding.')
else:
    # Columns from 'Duration' onwards are the numeric features.
    X = df.iloc[:, 4:]
    y = df['Burst_Type']
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X.fillna(0))
    reducer = umap.UMAP(random_state=42)
    embedding = reducer.fit_transform(X_scaled)
    df_umap = pd.DataFrame(embedding, columns=['UMAP1', 'UMAP2'])
    # Attach labels by position; the embedding rows follow the row order of df.
    df_umap['Burst_Type'] = y.to_numpy()
    sns.scatterplot(data=df_umap, x='UMAP1', y='UMAP2', hue='Burst_Type')
    plt.show()
print('Notebook done.')
