# Unified Electrophysiology Analysis Pipeline

This notebook processes ABF files to detect spikes and bursts, classify bursts, extract normalized bursts for shapelet learning, train a shapelet model, perform UMAP embedding, and detect conflict regions between burst types.

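Each code cell begins with a header comment naming its section. Sections 5, 7 and 8 are commented-out placeholder examples, and section 6 leaves metric computation as a TODO. These parts must be completed before the full pipeline runs end to end.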

In [ ]:
# ================================
# 1. Imports and parameters
# ================================
import os
import numpy as np
import pandas as pd
import pyabf
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.signal import find_peaks, hilbert
from scipy.stats import skew, kurtosis
from scipy.interpolate import interp1d
from numpy.linalg import lstsq
from sklearn.preprocessing import StandardScaler
import umap
from sklearn.neighbors import NearestNeighbors
from tslearn.preprocessing import TimeSeriesScalerMinMax
from tslearn.shapelets import LearningShapelets
from tensorflow.keras.optimizers import Adam
from matplotlib import cm

# Parameters -- edit these values to configure the run
folder_path = "bursting"   # directory scanned for .abf recordings
threshold = -35            # minimum peak height treated as a spike, in mV
burst_threshold = 0.3      # ISIs shorter than this (in s) keep spikes in one burst
# NOTE(review): fs/dt are hard-coded and not read by the code visible in this
# notebook; confirm they match each recording's actual sample rate before use.
fs = 10000                 # nominal sampling rate, in Hz
dt = 1 / fs                # nominal sampling interval, in s

In [ ]:
# ================================
# 2. Helper functions for spikes and bursts
# ================================
def normalize_y(signal_segment):
    """Z-score a 1-D segment: subtract its mean, divide by its standard deviation.

    A constant (flat) segment has zero standard deviation. Dividing by it would
    produce all-NaN output plus a RuntimeWarning, which then poisons any
    downstream model. Such a segment is mapped to all zeros instead.

    Args:
        signal_segment: 1-D array-like of samples.

    Returns:
        numpy array of the same length containing the z-scored values.
    """
    segment = np.asarray(signal_segment)
    centered = segment - np.mean(segment)
    spread = np.std(segment)
    if spread == 0:
        # Flat segment: every sample equals the mean, so the z-score is 0.
        return np.zeros_like(centered)
    return centered / spread

def rescale_x(time_segment, signal_segment, n_points=100):
    """Linearly resample a segment onto ``n_points`` evenly spaced samples.

    The segment is parameterised by normalised sample position in [0, 1].
    ``time_segment`` is therefore not used; it is kept only so existing
    callers keep working. This assumes the samples are evenly spaced in time.

    Args:
        time_segment: time stamps of the segment (unused).
        signal_segment: 1-D array-like of samples, at least one element.
        n_points: number of output samples.

    Returns:
        numpy array of length ``n_points``.

    Raises:
        ValueError: if ``signal_segment`` is empty.
    """
    values = np.asarray(signal_segment, dtype=float)
    if values.size == 0:
        raise ValueError("rescale_x: signal_segment is empty")
    if values.size == 1:
        # Linear interpolation needs two points; a single sample is a constant.
        return np.full(n_points, values[0])
    source_pos = np.linspace(0, 1, values.size)
    target_pos = np.linspace(0, 1, n_points)
    return np.interp(target_pos, source_pos, values)

def detect_spikes(signal, threshold=-35):
    """Find spike peaks in a voltage trace.

    Thin wrapper over ``scipy.signal.find_peaks`` that keeps only local maxima
    whose height reaches ``threshold``. The peak-property dict is discarded.

    Returns:
        numpy array of sample indices of the detected peaks.
    """
    peak_idx = find_peaks(signal, height=threshold)[0]
    return peak_idx

def detect_bursts(spike_times, burst_threshold=0.3):
    """Group spikes into bursts by inter-spike interval (ISI).

    Consecutive spikes whose ISI is strictly below ``burst_threshold`` go into
    the same group. Only groups of two or more spikes are reported as bursts.

    Args:
        spike_times: 1-D sequence of spike times, assumed sorted ascending and
            in the same units as ``burst_threshold``.
        burst_threshold: exclusive upper bound on the ISI within a burst.

    Returns:
        List of ``(first_spike_time, last_spike_time)`` tuples, one per burst.
        Empty when fewer than two spikes are given.
    """
    if len(spike_times) < 2:
        # No pair of spikes, so no burst (also avoids indexing an empty input).
        return []
    isi = np.diff(spike_times)
    bursts = []
    current_burst = [spike_times[0]]
    # Iterate over every spike after the first. isi[i - 1] is the gap between
    # spike i-1 and spike i, so the final spike and final ISI are included.
    for i in range(1, len(spike_times)):
        if isi[i - 1] < burst_threshold:
            current_burst.append(spike_times[i])
        else:
            if len(current_burst) > 1:
                bursts.append((current_burst[0], current_burst[-1]))
            current_burst = [spike_times[i]]
    if len(current_burst) > 1:
        bursts.append((current_burst[0], current_burst[-1]))
    return bursts


In [ ]:
# ================================
# 3. Burst classification
# ================================
def classify_bursts(bursts, signal, time):
    """Split bursts into square-wave, parabolic and other by trough level.

    For each burst, the minimum of ``signal`` within the burst is compared
    with the mean of ``signal`` in the neighbouring inter-burst interval(s).
    The intervals used are the gap to the previous burst and the gap to the
    next burst, whichever exist; when both exist their means are averaged.

    - burst minimum above the inter-burst mean -> square-wave
    - burst minimum below the inter-burst mean -> parabolic
    - equal, or no usable inter-burst interval (e.g. a single burst) -> other

    Args:
        bursts: list of ``(start_time, end_time)`` tuples, in time order.
        signal: 1-D array of samples.
        time: 1-D array of sample times, same length as ``signal``.

    Returns:
        ``(square_wave_bursts, parabolic_bursts, other_bursts)``, three lists
        of the input ``(start, end)`` tuples.
    """
    signal = np.asarray(signal)
    time = np.asarray(time)
    square_wave_bursts = []
    parabolic_bursts = []
    other_bursts = []

    def _interval_mean(lo, hi):
        # Mean of samples strictly between lo and hi; NaN (without a
        # RuntimeWarning) when the interval contains no samples.
        segment = signal[(time > lo) & (time < hi)]
        return float(np.mean(segment)) if segment.size else np.nan

    n_bursts = len(bursts)
    for i, (burst_start, burst_end) in enumerate(bursts):
        burst_mask = (time >= burst_start) & (time <= burst_end)
        if not burst_mask.any():
            # No samples fall inside the burst window; nothing to compare.
            other_bursts.append((burst_start, burst_end))
            continue
        burst_min = np.min(signal[burst_mask])

        neighbour_means = []
        if i > 0:
            neighbour_means.append(_interval_mean(bursts[i - 1][1], burst_start))
        if i < n_bursts - 1:
            neighbour_means.append(_interval_mean(burst_end, bursts[i + 1][0]))
        finite_means = [m for m in neighbour_means if not np.isnan(m)]
        # Same value as np.nanmean over the neighbours, but without the
        # "Mean of empty slice" warning when none are available.
        inter_mean = float(np.mean(finite_means)) if finite_means else np.nan

        if np.isnan(inter_mean):
            other_bursts.append((burst_start, burst_end))
        elif burst_min > inter_mean:
            square_wave_bursts.append((burst_start, burst_end))
        elif burst_min < inter_mean:
            parabolic_bursts.append((burst_start, burst_end))
        else:
            other_bursts.append((burst_start, burst_end))
    return square_wave_bursts, parabolic_bursts, other_bursts


In [ ]:
# ================================
# 4. Extract normalized bursts for shapelet learning
# ================================
def extract_normalized_bursts(burst_list, signal, time, n_points=100):
    """Cut each burst out of ``signal``, resample it and z-score it.

    Every ``(start, end)`` window selects the samples with
    ``start <= time <= end``. The selection is passed through ``rescale_x``
    (to ``n_points`` samples) and then ``normalize_y``.

    Returns:
        List with one normalised array per entry of ``burst_list``, in order.
    """
    def _prepare(start, end):
        in_window = (time >= start) & (time <= end)
        resampled = rescale_x(time[in_window], signal[in_window], n_points)
        return normalize_y(resampled)

    return [_prepare(start, end) for start, end in burst_list]


In [ ]:
# ================================
# 5. Shapelet model training (example placeholders)
# ================================
# After extracting normalized bursts, X and y can be created as follows:
# X = np.array(all_normalized_bursts)[:, :, np.newaxis]
# y = np.array(labels)
# X = TimeSeriesScalerMinMax().fit_transform(X)
# shp_clf = LearningShapelets(n_shapelets_per_size={X.shape[1]: 2}, weight_regularizer=0.0001,
# optimizer=Adam(0.01), max_iter=300, verbose=0, scale=False, random_state=42)
# shp_clf.fit(X, y)
# distances = shp_clf.transform(X).reshape((-1, 2))

In [ ]:
# ================================
# 6. Segment metrics computation (example loop)
# ================================
# Walk every ABF file in folder_path and every sweep in it, then detect and
# classify bursts using the values from the parameters cell.
all_segment_metrics = []
# sorted() makes the processing order reproducible (os.listdir order is
# arbitrary); the case-insensitive check also picks up files named "*.ABF".
abf_files = sorted(f for f in os.listdir(folder_path) if f.lower().endswith('.abf'))
for file_name in abf_files:
    abf = pyabf.ABF(os.path.join(folder_path, file_name))
    for sweep in range(abf.sweepCount):
        abf.setSweep(sweep)
        signal = abf.sweepY
        time = abf.sweepX
        # Pass the notebook parameters explicitly. Relying on the functions'
        # defaults would silently ignore edits to threshold/burst_threshold.
        spike_indices = detect_spikes(signal, threshold=threshold)
        spike_times = time[spike_indices]
        # A sweep without spikes has no bursts; skip the call instead of
        # depending on detect_bursts accepting an empty input.
        if len(spike_times):
            bursts = detect_bursts(spike_times, burst_threshold=burst_threshold)
        else:
            bursts = []
        sq, par, oth = classify_bursts(bursts, signal, time)
        # TODO: compute metrics and append to all_segment_metrics

In [ ]:
# ================================
# 7. UMAP embedding and visualization (example)
# ================================
# df_segments = pd.DataFrame(all_segment_metrics, columns=[...])
# scaler = StandardScaler()
# X_scaled = scaler.fit_transform(df_segments[feature_columns])
# reducer = umap.UMAP(n_components=2, random_state=42)
# embedding = reducer.fit_transform(X_scaled)
# df_segments['UMAP1'] = embedding[:,0]
# df_segments['UMAP2'] = embedding[:,1]

In [ ]:
# ================================
# 8. Conflict detection and plotting (example)
# ================================
# Identify conflict regions using nearest neighbors or UMAP binning
# Plot conflicts and overlay traces, save to CSV if needed