# Spike and Burst Analysis in ABF Files

This notebook walks through the processing of electrophysiological recordings in ABF format. We will detect spikes and bursts, extract features from each segment, and visualize the results using UMAP and plotting libraries.

In [ ]:
# Load libraries
import os  # Work with directories and files
import pyabf  # Read ABF files
import numpy as np
import pandas as pd
from scipy.stats import skew, kurtosis  # Statistical metrics
from scipy.signal import find_peaks  # Detect spikes
%matplotlib widget

# Dimensionality reduction & ML
import umap  # UMAP for nonlinear dimensionality reduction
from sklearn.preprocessing import StandardScaler  # Standardize features

# Plotting
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.io as pio

# Force Plotly to display in browser: figures shown through Plotly are opened
# in a web browser tab rather than rendered inline in the notebook output.
pio.renderers.default = "browser"

## Parameters

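- `folder_path`: folder containing the ABF files to analyse
- `threshold`: minimum peak voltage for a local maximum to count as a spike (same units as the recorded trace)
- `burst_threshold`: maximum interspike interval (ISI), in seconds, for two consecutive spikes to belong to the same burst
- `fs`: sampling frequency in Hz
- `dt`: time step between samples in seconds (`1 / fs`)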

In [ ]:
# Analysis settings
folder_path = "bursting"  # directory that is scanned for .abf recordings
threshold = -35  # minimum peak height handed to find_peaks
burst_threshold = 0.3  # spikes closer together than this share a burst
fs = 10000  # sampling frequency of the recordings (samples per second)
dt = 1 / fs  # time step between two consecutive samples

# Accumulator: one row of metrics per segment, across all files and sweeps
all_segment_metrics = []

## Processing ABF files
For each ABF file and sweep, we extract spikes and bursts, classify bursts, and compute metrics for each segment.

In [ ]:
# List all ABF files in the folder. os.listdir returns entries in arbitrary
# order, so sort them to make the processing order (and therefore the row
# order of the exported metrics) reproducible between runs.
abf_files = sorted(f for f in os.listdir(folder_path) if f.endswith(".abf"))

for file_name in abf_files:
    file_path = os.path.join(folder_path, file_name)
    abf = pyabf.ABF(file_path)

    for sweep in range(abf.sweepCount):
        abf.setSweep(sweep)
        time = abf.sweepX  # time vector
        voltage = abf.sweepY  # voltage trace

        # Spike detection: every local maximum at or above `threshold`.
        peaks, _ = find_peaks(voltage, height=threshold)
        # Read the spike times straight from the sweep's own time vector
        # instead of dividing by the hard-coded `fs`: the times are later
        # compared against `time`, so they must be on exactly the same axis
        # even when a file was recorded at a different sampling rate.
        spike_times = time[peaks]

### Burst detection
- Define bursts as sequences of spikes with interspike intervals below `burst_threshold`.
- Only groups of ≥2 spikes are kept as bursts, and burst detection is only attempted for sweeps that contain at least 5 spikes.

In [ ]:
        # Group consecutive spikes into bursts. A gap that is not shorter than
        # burst_threshold closes the current group; each group of two or more
        # spikes is stored as (time of first spike, time of last spike).
        bursts = []
        if len(spike_times) >= 5:
            isi = np.diff(spike_times)
            # Positions in spike_times at which a new group begins.
            group_starts = np.flatnonzero(~(isi < burst_threshold)) + 1
            for spike_group in np.split(spike_times, group_starts):
                # A lone spike is not a burst.
                if len(spike_group) > 1:
                    bursts.append((spike_group[0], spike_group[-1]))

### Segmentation
- Label each time segment as 'Burst' or 'Non-burst'.
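- Each burst is then classified (in the next code cell) as 'Square Wave' when its minimum voltage lies above the mean voltage of the neighbouring inter-burst intervals, 'Parabolic' when it lies below, and 'Other' when no neighbouring interval is available.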

In [ ]:
        # Build the (start, end, label) list for this sweep.
        # The list is deliberately NOT chronological: an optional leading
        # "Non-burst" stretch comes first, then all bursts back to back, then
        # the gaps between bursts, then an optional trailing stretch. The
        # burst classification that follows looks at neighbouring list entries
        # and relies on this layout.
        if not bursts:
            # No burst at all: the whole sweep is a single non-burst segment.
            segments = [(time[0], time[-1], "Non-burst")]
        else:
            sweep_start, sweep_end = time[0], time[-1]
            first_burst_start = bursts[0][0]
            last_burst_end = bursts[-1][1]

            segments = [(b_start, b_end, "Burst") for b_start, b_end in bursts]
            if first_burst_start > sweep_start:
                segments = [(sweep_start, first_burst_start, "Non-burst")] + segments
            segments.extend(
                (earlier[1], later[0], "Non-burst")
                for earlier, later in zip(bursts, bursts[1:])
            )
            if last_burst_end < sweep_end:
                segments.append((last_burst_end, sweep_end, "Non-burst"))

### Metrics extraction
- Compute statistics (mean, std, min, max, skewness, kurtosis, area) for voltage and derivative.
- Count spikes and calculate mean ISI.

In [ ]:
        # Classify every segment and record one row of metrics for it.
        #
        # `segments` is not chronological: the bursts are stored contiguously,
        # ahead of the inter-burst gaps. For a burst, the neighbouring list
        # entries are therefore usually the previous/next burst, and the masks
        # below select the silent interval between the two bursts.
        #
        # NOTE(review): the original cell never appended anything to
        # all_segment_metrics, so the exported table was empty. The metric
        # definitions below (trapezoidal area, finite-difference derivative,
        # 0.0 for an undefined mean ISI) are inferred from the output column
        # names - confirm them against the intended analysis.
        burst_types = []
        for i, (seg_start, seg_end, seg_type) in enumerate(segments):
            seg_mask = (time >= seg_start) & (time <= seg_end)

            if seg_type == "Burst":
                burst_min = np.min(voltage[seg_mask])
                prev_mean = np.nan
                next_mean = np.nan
                if i > 0:
                    prev_end = segments[i-1][1]
                    inter_mask_prev = (time > prev_end) & (time < seg_start)
                    if np.any(inter_mask_prev):
                        prev_mean = np.mean(voltage[inter_mask_prev])
                if i < len(segments)-1:
                    next_start = segments[i+1][0]
                    inter_mask_next = (time > seg_end) & (time < next_start)
                    if np.any(inter_mask_next):
                        next_mean = np.mean(voltage[inter_mask_next])
                # Average whichever neighbouring means exist. Same result as
                # np.nanmean, but without the "Mean of empty slice" warning
                # when neither neighbour is available.
                known_means = [m for m in (prev_mean, next_mean) if not np.isnan(m)]
                inter_mean = np.mean(known_means) if known_means else np.nan
                if np.isnan(inter_mean):
                    burst_type = "Other"
                elif burst_min > inter_mean:
                    burst_type = "Square Wave"
                elif burst_min < inter_mean:
                    burst_type = "Parabolic"
                else:
                    burst_type = "Other"
            else:
                burst_type = "Non-burst"
            burst_types.append(burst_type)

            # ---- Per-segment metrics -------------------------------------
            seg_time = np.asarray(time[seg_mask], dtype=float)
            seg_voltage = np.asarray(voltage[seg_mask], dtype=float)
            if seg_voltage.size < 2:
                # Too few samples to integrate or differentiate; no row.
                continue

            # A burst owns the spikes on its boundaries; a non-burst segment
            # shares its boundaries with bursts and must not count them again.
            if seg_type == "Burst":
                spike_mask = (spike_times >= seg_start) & (spike_times <= seg_end)
            else:
                spike_mask = (spike_times > seg_start) & (spike_times < seg_end)
            seg_spikes = spike_times[spike_mask]
            num_peaks = len(seg_spikes)
            # Kept finite (0.0) when fewer than two spikes exist, so the row
            # remains usable by the scaler/UMAP step.
            mean_isi = float(np.mean(np.diff(seg_spikes))) if num_peaks > 1 else 0.0

            # Trapezoidal integral of the voltage over the segment.
            time_steps = np.diff(seg_time)
            area = float(np.sum((seg_voltage[1:] + seg_voltage[:-1]) * time_steps) / 2.0)
            # First derivative of the voltage by finite differences.
            deriv = np.diff(seg_voltage) / time_steps

            # Order matches the column list used to build the DataFrame.
            all_segment_metrics.append([
                file_name, sweep, seg_type, burst_type, seg_end - seg_start,
                num_peaks, mean_isi,
                np.mean(seg_voltage), np.std(seg_voltage),
                np.min(seg_voltage), np.max(seg_voltage),
                skew(seg_voltage), kurtosis(seg_voltage), area,
                np.mean(deriv), np.std(deriv), np.min(deriv), np.max(deriv),
                skew(deriv), kurtosis(deriv),
            ])

### DataFrame and UMAP visualization
- Save all segment metrics into a CSV.
- Use `StandardScaler` and `UMAP` for dimensionality reduction.
- Visualize clusters with Seaborn and Plotly.

In [ ]:
# Create DataFrame and save
df_segments = pd.DataFrame(all_segment_metrics, columns=[
    "File_Name", "Sweep", "Segment_Type", "Burst_Type", "Duration",
    "Num_Peaks", "Mean_ISI",
    "Mean", "Std", "Min", "Max", "Skewness", "Kurtosis", "Area",
    "Mean_Deriv", "Std_Deriv", "Min_Deriv", "Max_Deriv", "Skewness_Deriv", "Kurtosis_Deriv"
])
df_segments.to_csv("segment_voltage_metrics_with_ISI_and_peaks.csv", index=False)

# UMAP projection
sns.set(style='white', context='notebook', rc={'figure.figsize':(14,10)})
df = pd.read_csv("segment_voltage_metrics_with_ISI_and_peaks.csv")
if df.empty:
    # Fail here with a readable message instead of deep inside scikit-learn.
    raise ValueError(
        "No segment metrics were collected; check that the processing loop "
        "appends rows to all_segment_metrics."
    )

# The first four columns are identifiers/labels; everything after is a feature.
header = list(df)[4:]
features = df[header].to_numpy(dtype=float)
# Rows holding NaN/inf cannot be embedded; drop them while keeping the labels
# aligned with the remaining rows.
usable_rows = np.isfinite(features).all(axis=1)
if not usable_rows.any():
    raise ValueError("Every segment has a non-finite feature; nothing to embed.")
X = features[usable_rows]
y = np.array(df.Burst_Type)[usable_rows]

# Standardise so that no single metric dominates the distance computation.
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# Fixed random_state keeps the embedding reproducible between runs.
reducer = umap.UMAP(n_components=2, random_state=42)
embedding = reducer.fit_transform(X_scaled)

df_umap = pd.DataFrame(embedding, columns=["UMAP1", "UMAP2"])
df_umap["Burst_Type"] = y

sns.scatterplot(data=df_umap, x="UMAP1", y="UMAP2", hue="Burst_Type", palette="tab10", alpha=0.7)
plt.title("UMAP segments metrics")
plt.show()