In [ ]:
# ==== Imports and setup ====
import os
import pyabf
import numpy as np
import pandas as pd
from scipy.stats import skew, kurtosis
from scipy.signal import find_peaks, hilbert
from numpy.linalg import lstsq
from sklearn.metrics import r2_score
%matplotlib widget
import umap
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.io as pio
pio.renderers.default = 'browser'
# Parameters
folder_path = 'bursting'    # directory scanned for ABF recordings
threshold = -35             # minimum peak height passed to find_peaks for spike detection (presumably mV)
burst_threshold = 0.3       # max inter-spike interval (s) for consecutive spikes to share a burst
fs = 10000                  # nominal sampling rate (Hz)
dt = 1/fs                   # nominal sample interval (s)
min_spikes_in_burst = 4     # minimum number of spikes for a run of spikes to count as a burst
bin_size = 0.5              # side length of the square UMAP grid cells used for conflict detection
N = 3                       # presumably the number of conflict pairs to plot; not referenced in this section
all_segment_metrics = []    # one row of metrics per detected segment
# os.listdir returns entries in arbitrary order; sort so files (and therefore CSV
# rows and the UMAP input) are processed reproducibly. Match the extension
# case-insensitively so recordings saved as ".ABF" are not silently skipped.
abf_files = sorted(f for f in os.listdir(folder_path) if f.lower().endswith('.abf'))

In [ ]:
# ==== Process ABF files: spike detection, burst detection, segment definition ====
# Start from an empty table each time this cell runs so re-executing it does not
# append duplicate rows to the results of a previous run.
all_segment_metrics = []

# NumPy 2.0 renamed np.trapz to np.trapezoid; prefer the new name when available.
_trapezoid = getattr(np, 'trapezoid', None) or np.trapz


def _finite_or_zero(value):
    """Return ``value``, or 0.0 if it is NaN/inf (e.g. skewness of a constant trace).

    0 matches the placeholder this pipeline already uses for undefined metrics.
    """
    return value if np.isfinite(value) else 0.0


def _detect_bursts(spike_times):
    """Group spikes into bursts and return a list of (first_spike, last_spike) times.

    Consecutive spikes whose inter-spike interval is below ``burst_threshold``
    are grouped; a group is kept only if it holds at least
    ``min_spikes_in_burst`` spikes.
    """
    bursts = []
    if len(spike_times) < min_spikes_in_burst:
        return bursts
    current = [spike_times[0]]
    for prev_t, t in zip(spike_times[:-1], spike_times[1:]):
        if t - prev_t < burst_threshold:
            current.append(t)
        else:
            if len(current) >= min_spikes_in_burst:
                bursts.append((current[0], current[-1]))
            current = [t]
    if len(current) >= min_spikes_in_burst:
        bursts.append((current[0], current[-1]))
    return bursts


def _build_segments(bursts, t_start, t_end):
    """Tile [t_start, t_end] into chronologically ordered (start, end, type) segments.

    Bursts become 'Burst' segments; the stretches before, between and after
    them become 'Non-burst' segments. With no bursts the whole sweep is one
    'Non-burst' segment. Chronological order matters: burst classification
    looks at the segments immediately before and after each burst.
    """
    segments = []
    cursor = t_start
    for b_start, b_end in bursts:
        if b_start > cursor:
            segments.append((cursor, b_start, 'Non-burst'))
        segments.append((b_start, b_end, 'Burst'))
        cursor = b_end
    if not bursts or cursor < t_end:
        segments.append((cursor, t_end, 'Non-burst'))
    return segments


def _classify_burst(segments, idx, time, voltage, burst_min):
    """Label burst ``segments[idx]`` as 'Square Wave', 'Parabolic' or 'Other'.

    The burst's minimum voltage is compared with the mean voltage of its
    neighbouring segments. Consecutive bursts are always separated by an
    inter-spike interval >= burst_threshold, so these neighbours are the
    'Non-burst' (inter-burst) segments. Shared boundary samples are excluded.
    """
    seg_start, seg_end, _ = segments[idx]
    neighbour_means = []
    if idx > 0:
        prev_start = segments[idx - 1][0]
        mask = (time >= prev_start) & (time < seg_start)
        if np.any(mask):
            neighbour_means.append(np.mean(voltage[mask]))
    if idx < len(segments) - 1:
        next_end = segments[idx + 1][1]
        mask = (time > seg_end) & (time <= next_end)
        if np.any(mask):
            neighbour_means.append(np.mean(voltage[mask]))
    if not neighbour_means:
        return 'Other'
    inter_mean = np.mean(neighbour_means)
    if burst_min > inter_mean:
        return 'Square Wave'
    if burst_min < inter_mean:
        return 'Parabolic'
    return 'Other'


def _segment_metrics(seg_voltage, seg_spike_times, sample_dt):
    """Return [num_peaks, mean_isi, mean, std, min, max, skew, kurtosis, area,
    and the same six statistics of the first derivative] for one segment.

    Segments with fewer than two samples get 0 for every voltage/derivative metric.
    """
    num_peaks = len(seg_spike_times)
    mean_isi = np.mean(np.diff(seg_spike_times)) if num_peaks >= 2 else 0
    if len(seg_voltage) < 2:
        return [num_peaks, mean_isi] + [0] * 13
    deriv = np.diff(seg_voltage) / sample_dt
    return [
        num_peaks, mean_isi,
        np.mean(seg_voltage), np.std(seg_voltage), np.min(seg_voltage), np.max(seg_voltage),
        _finite_or_zero(skew(seg_voltage)), _finite_or_zero(kurtosis(seg_voltage)),
        _trapezoid(seg_voltage, dx=sample_dt),
        np.mean(deriv), np.std(deriv), np.min(deriv), np.max(deriv),
        _finite_or_zero(skew(deriv)), _finite_or_zero(kurtosis(deriv)),
    ]


for file_name in abf_files:
    abf = pyabf.ABF(os.path.join(folder_path, file_name))
    # Use the recording's own sampling rate rather than the nominal `fs`.
    sample_dt = 1.0 / abf.dataRate
    for sweep in range(abf.sweepCount):
        abf.setSweep(sweep)
        time = abf.sweepX
        voltage = abf.sweepY
        peaks, _ = find_peaks(voltage, height=threshold)
        # Take spike times from the sweep's time axis so they share a time base
        # with `time`, independent of the sampling rate.
        spike_times = time[peaks]
        segments = _build_segments(_detect_bursts(spike_times), time[0], time[-1])

        # ==== Classify bursts and compute metrics (per sweep) ====
        for idx, (seg_start, seg_end, seg_type) in enumerate(segments):
            seg_voltage = voltage[(time >= seg_start) & (time <= seg_end)]
            # Count spikes from the sweep-level detection: find_peaks on a slice
            # cannot report peaks on the slice's first/last sample, and burst
            # segments start and end exactly on spike peaks. Burst segments own
            # their boundary spikes; non-burst segments use an open interval so
            # each spike is counted exactly once (sweep edges are never peaks).
            if seg_type == 'Burst':
                burst_type = _classify_burst(segments, idx, time, voltage, np.min(seg_voltage))
                seg_spikes = spike_times[(spike_times >= seg_start) & (spike_times <= seg_end)]
            else:
                burst_type = 'Non-burst'
                seg_spikes = spike_times[(spike_times > seg_start) & (spike_times < seg_end)]
            all_segment_metrics.append(
                [file_name, sweep, seg_type, burst_type, seg_start, seg_end, seg_end - seg_start]
                + _segment_metrics(seg_voltage, seg_spikes, sample_dt)
            )

In [ ]:
# ==== Create DataFrame, scale features, run UMAP ====
# Identifying columns for each row, followed by the numeric features fed to UMAP.
# Building the column list from `header` keeps the two from drifting apart.
meta_cols = ['File_Name', 'Sweep', 'Segment_Type', 'Burst_Type', 'Segment_Start', 'Segment_End']
header = ['Duration', 'Num_Peaks', 'Mean_ISI', 'Mean', 'Std', 'Min', 'Max', 'Skewness', 'Kurtosis', 'Area',
          'Mean_Deriv', 'Std_Deriv', 'Min_Deriv', 'Max_Deriv', 'Skewness_Deriv', 'Kurtosis_Deriv']
df_segments = pd.DataFrame(all_segment_metrics, columns=meta_cols + header)

# skew/kurtosis can be NaN (e.g. for constant traces). UMAP's input validation is
# expected to reject non-finite values, so they are fed to the scaler as 0;
# the CSV keeps the raw values.
features = df_segments[header].replace([np.inf, -np.inf], np.nan).fillna(0)
X_scaled = StandardScaler().fit_transform(features)
embedding = umap.UMAP(n_components=2, random_state=42).fit_transform(X_scaled)
df_segments['UMAP1'] = embedding[:, 0]
df_segments['UMAP2'] = embedding[:, 1]
df_segments.to_csv('segment_voltage_metrics_with_ISI_and_peaks.csv', index=False)
print(f'CSV saved with {len(df_segments)} segments and UMAP coordinates.')

In [ ]:
# ==== Conflict detection ====
import itertools

# Assign each segment to a square UMAP grid cell of side `bin_size`, labelled "ix_iy".
df_segments['UMAP_bin'] = (
    (df_segments['UMAP1'] // bin_size).astype(int).astype(str)
    + '_'
    + (df_segments['UMAP2'] // bin_size).astype(int).astype(str)
)
# A bin is a "conflict" bin when it contains segments of more than one Burst_Type.
types_per_bin = df_segments.groupby('UMAP_bin')['Burst_Type'].nunique()
conflict_bins = types_per_bin[types_per_bin > 1].index
conflict_data = df_segments[df_segments['UMAP_bin'].isin(conflict_bins)]
print(f'{len(conflict_data)} segments in conflict regions.')

# Collect every pair of segments that share a conflict bin but have different
# Burst_Type values (nothing is plotted in this cell). Every bin in
# conflict_data already has >1 type by construction, so no re-check is needed;
# each group is iterated once and paired with itertools.combinations.
pairs = []
for _, group in conflict_data.groupby('UMAP_bin'):
    rows = [row for _, row in group.iterrows()]
    pairs.extend(
        (row1, row2)
        for row1, row2 in itertools.combinations(rows, 2)
        if row1['Burst_Type'] != row2['Burst_Type']
    )
print(f'{len(pairs)} conflict pairs detected.')