# Burst detection of action potentials

In this tutorial, I will explain how to analyze bursts of action potentials using Python. To read the full tutorial, please see [Patch-clamp data analysis in Python: bursts of action potentials](https://spikesandbursts.wordpress.com/2023/08/24/patch-clamp-data-analysis-in-python-bursts/) on the [Spikes and Bursts](https://spikesandbursts.wordpress.com/) blog.




# Import the libraries

In [None]:
import numpy as np
import pandas as pd
import os

import pyabf

import scipy
from scipy import signal
from scipy.signal import find_peaks
from scipy.optimize import curve_fit
from scipy.stats import skew, kurtosis

import matplotlib.pyplot as plt

# If you want to display interactive plots using the ipympl backend:
# %matplotlib widget 
# Ipympl creates multiple interactive plots so you need to close them:
# plt.close('all') 

# Example data

Example data is the file **stg_pd_bursts.abf**. The file is a segment of the file "877_093_0003" from [Haley et al., 2018](https://elifesciences.org/articles/41877) (https://osf.io/r7aes/).

# Create the paths

In [None]:
# Name of this notebook; used to name the output sub-folders
notebook_name = 'action_potentials_bursts'

# Base directory: one level up from the current working directory.
# Change accordingly to your data structure.
data_path = os.path.dirname(os.getcwd())

# Folder used at each stage of the analysis (change the folder names accordingly)
paths = dict(
    data=f'{data_path}/Data',
    processed_data=f'{data_path}/Processed_data/{notebook_name}',
    analysis=f'{data_path}/Analysis/{notebook_name}',
)

# Create every folder that does not exist yet
for folder in paths.values():
    os.makedirs(folder, exist_ok=True)

# Load the data

In [None]:
# ABF file/s
filename = "stg_pd_bursts"

# Full path of the recording. It gets its own name so that `data_path`
# (the base directory defined in the paths cell) is not overwritten.
abf_path = f"{paths['data']}/{filename}.abf"
abf = pyabf.ABF(abf_path)
print(abf)

# Sampling rate (Hz).
# NOTE(review): `abf.dataRate` is read directly instead of
# `abf.dataPointsPerMs * 1000`; the latter is presumably a whole number of
# points per millisecond and would lose precision for rates that are not a
# multiple of 1 kHz - confirm against the pyabf documentation.
fs = int(abf.dataRate)

# Quick plot to see the trace/s: overlay every sweep on the same axes
plt.figure(figsize=(8,4))

for sweepNumber in abf.sweepList:
    abf.setSweep(sweepNumber)
    plt.plot(abf.sweepX, abf.sweepY)

# Axis labels only need to be set once, after all sweeps are drawn
plt.ylabel(abf.sweepLabelY)
plt.xlabel(abf.sweepLabelX)

plt.show()

# Select the sweep and/or channel
# abf.setSweep(10)  # Sweep
# abf.setSweep(sweepNumber=0, channel=0)  # Sweep and channel

# Pre-process the signal: filtering

In [None]:
# Sampling rate (Hz).
# NOTE(review): `abf.dataRate` is read directly instead of
# `abf.dataPointsPerMs * 1000`; the latter is presumably a whole number of
# points per millisecond and would lose precision for rates that are not a
# multiple of 1 kHz - confirm against the pyabf documentation.
fs = int(abf.dataRate)

# Lowpass Bessel filter: compute the numerator (b) and denominator (a) coefficients
b_lowpass, a_lowpass = signal.bessel(4,     # Order of the filter
                                     2000,  # Cutoff frequency (same units as fs)
                                     'low', # Type of filter
                                     analog=False,  # Analog or digital filter
                                     norm='phase',  # Critical frequency normalization
                                     fs=fs)  # fs: sampling frequency

# Apply the filter forward and backward over the currently selected sweep
signal_filtered = signal.filtfilt(b_lowpass, a_lowpass, abf.sweepY)

# Find Peaks

To select a range within the trace, you have to slice both `peaks_signal` and `time`. E.g. `peaks_signal = abf.sweepY[50000:100000]` and `time = abf.sweepX[50000:100000]`. If you want the absolute peak times, add the amount of time you subtracted (e.g., 50000/fs). 

In [None]:
# Assign the variables here to simplify the code
time = abf.sweepX
peaks_signal = abf.sweepY  # Or signal_filtered

# Set parameters for the Find peaks function (set to None if not needed).
# Widths and distances are converted from ms to samples with fs/1000.
thresh_min = -25                    # Min height a peak must reach to be detected as a spike
thresh_prominence = 15              # Min spike amplitude (prominence)
thresh_min_width = 0.5 * (fs/1000)  # Min required width: 0.5 ms, in samples
distance_min = 1 * (fs/1000)        # Min horizontal distance between peaks: 1 ms, in samples
pretrigger_window = (1.5 * fs)/1000  # 1.5 ms, in samples (not used in this cell)
posttrigger_window = (2 * fs)/1000   # 2 ms, in samples (not used in this cell)

# Find peaks function
peaks, peaks_dict = find_peaks(peaks_signal, 
           height=thresh_min, 
           # `threshold` is the vertical distance to the neighbouring samples,
           # not a voltage level, so the height threshold must not be reused here.
           threshold=None,  
           distance=distance_min,  
           prominence=thresh_prominence,  
           width=thresh_min_width, 
           wlen=None,       # Window length to calculate prominence
           rel_height=0.5,  # Relative height at which the peak width is measured
           plateau_size=None)
 
# Create table with results (one row per detected spike)
spikes_table = pd.DataFrame(columns = ['spike', 'spike_index', 'spike_time',
                                       'inst_freq', 'isi_s',
                                       'width', 'rise_half_ms', 'decay_half_ms',
                                       'spike_peak', 'spike_amplitude'])

spikes_table['spike'] = np.arange(1, len(peaks) + 1)
spikes_table['spike_index'] = peaks
spikes_table['spike_time'] = peaks / fs  # Divided by fs to get s
# The first spike has no preceding spike: prepending peaks[0] makes its ISI 0,
# and therefore its instantaneous frequency infinite.
spikes_table['isi_s'] = np.diff(peaks, axis=0, prepend=peaks[0]) / fs
spikes_table['inst_freq'] = 1 / spikes_table['isi_s']
spikes_table['width'] = peaks_dict['widths']/(fs/1000) # Width (ms) at half-height
spikes_table['rise_half_ms'] = (peaks - peaks_dict['left_ips'])/(fs/1000) 
spikes_table['decay_half_ms'] = (peaks_dict['right_ips'] - peaks)/(fs/1000)
spikes_table['spike_peak'] = peaks_dict['peak_heights']  # height parameter is needed
spikes_table['spike_amplitude'] = peaks_dict['prominences']  # prominence parameter is needed
     
# Plot the detected spikes in the trace
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(time, peaks_signal)

# Red dot on each detected spike
ax.plot(peaks/fs, peaks_signal[peaks], "r.")

# Add a number to each detected peak
# for i, txt in enumerate(spikes_table.spike):  
#     ax.annotate(txt, (peaks[i]/fs, peaks_signal[peaks][i]))

ax.set_title("Event detection")  
ax.set_xlabel("Time (s)")
ax.set_ylabel("Voltage (mV)")
# ax.set_xlim(0, 10000)  # Zoom in the trace

# Save the plot and the table
fig.savefig(f"{paths['analysis']}/{filename}_spikes_plot.png", dpi=300)
spikes_table.to_csv(f"{paths['analysis']}/{filename}_spikes_results.csv", index=False)

# Show graph and table
plt.show()
spikes_table

# Estimate the interspike intervals of bursts

Visually inspect the histogram to identify the two firing modes and the valley between them, and quantify how symmetrical the distribution is using:

* [Cumulative moving average](https://pmc.ncbi.nlm.nih.gov/articles/PMC3378047/).
* [scipy.stats.skew](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.skew.html) quantifies how symmetrical the distribution is. Symmetrical distribution (skew = 0), asymmetrical with long right tail (positive skew), asymmetrical with long left tail (negative skew). Skew > 1 or <-1 is substantial.
* [scipy.stats.kurtosis](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.kurtosis.html) quantifies whether the tails of the data distribution match the Gaussian distribution. Kurtosis of 0 (Gaussian), negative (fewer tail values than Gaussian), positive (more values in the tails than Gaussian). 


In [None]:
# Assign ISI data to this variable
hist_data = spikes_table['isi_s']

# Empty DataFrame for histogram stats
hist_stats = pd.DataFrame()

# Bin size
bin_size = 10  # in milliseconds

# Histogram: the number of bins is chosen so that each bin is ~bin_size wide
isi_range = np.ptp(hist_data)
bins = int((isi_range * 1000 / bin_size) + 0.5)  # Round to the nearest integer
hist_counts, hist_bins = np.histogram(hist_data, bins=bins)

# Left edge (s) of every bin. Bin indices are converted to time through these
# edges rather than `index * bin_size/1000`, because the histogram starts at
# min(hist_data) and its real bin width is isi_range/bins, not exactly bin_size.
bin_starts = hist_bins[:-1]

# Cumulative moving average
cum = np.cumsum(hist_counts)  # Cumulative sum
cma = cum / np.arange(1, len(cum) + 1)

# Calculate peaks and valleys of the cma
cma_peaks_indexes = scipy.signal.argrelextrema(cma, np.greater)
cma_valleys_indexes = scipy.signal.argrelextrema(cma, np.less)

# Select the peak you're interested in
peak_index = cma_peaks_indexes[0][0]  # Change second number to select the peak
alpha = cma[peak_index] * 0.5  # Half-peak, adapt the value to your threshold criterion

# First bin at or after the selected peak where the cma is below alpha
# (if the cma never drops below alpha, argmin returns 0, i.e. the peak itself)
cma_threshold_index = np.argmin(cma[peak_index:] >= alpha) + peak_index
cma_threshold = bin_starts[cma_threshold_index]  # In s

# Dataframe with histogram statistics
length = len(hist_stats)
hist_stats.loc[length, 'mean_isi'] = np.mean(hist_data)
hist_stats.loc[length, 'median_isi'] = np.median(hist_data)
# Note: kurtosis and skewness are computed on the bin counts
hist_stats.loc[length, 'kurtosis'] = kurtosis(hist_counts)
hist_stats.loc[length, 'skewness'] = skew(hist_counts, bias=True)
hist_stats.loc[length, 'cma_threshold'] = cma_threshold
hist_stats.loc[length, 'cma_valley_time'] = bin_starts[cma_valleys_indexes[0][1]]  # Change valley index as needed
hist_stats.loc[length, 'cma_peak_time'] = bin_starts[peak_index]  # Time of the selected peak

# Plot ISI histogram
fig, ax = plt.subplots(figsize=(8, 4))
ax.set_title("ISI histogram") 
ax.hist(hist_data, bins=bins, alpha=0.6)

# Plot CMA on the same time axis as the histogram bins
cma_x = bin_starts
ax.plot(cma_x, cma)

# Plot CMA threshold line
ax.axvline(cma_threshold, linestyle="dotted", color="gray")

# Plot CMA valleys (black) and peaks (magenta)
ax.plot(cma_x[cma_valleys_indexes], cma[cma_valleys_indexes], 'ko')
ax.plot(cma_x[cma_peaks_indexes], cma[cma_peaks_indexes], 'mo')

# ax.set_xscale('log')  # Logarithmic scale may be easier to set the threshold
ax.set_xlabel("Time bins (s)")
ax.set_ylabel("Count")

# Save the table and the plot
fig.savefig(f"{paths['analysis']}/{filename}_isi_plot.png", dpi=300)  # or svg
hist_stats.to_csv(f"{paths['analysis']}/{filename}_isi_stats.csv", index=False)

# Show graph and table
plt.show()
hist_stats

# Burst detection: Maximum interval method

The function `burst_detection` takes the Pandas DataFrame with event parameters detected by `find_peaks` and detects bursts of events. The function uses the MaxInterval algorithm (see [Cotterill et al., 2016](https://pubmed.ncbi.nlm.nih.gov/27098024/)) to detect bursts with a minimum number of spikes and minimum interburst interval.  

The function sorts the input DataFrame by spike positions, creates a new column for burst labels, and iterates through the spikes to assign them to bursts. The function then filters out any bursts with fewer spikes than the minimum required, and calculates burst information by grouping the spikes by burst label. Finally, it returns a DataFrame containing burst information.


## Function

In [None]:
def burst_detection(df, spike_times, spike_amplitudes, spike_peaks,
                    n_spikes, 
                    max_isi, 
                    min_ibi,
                    min_duration=None):
    
    """
    Detect bursts in spike data based on spike times.

    Spikes are visited in time order. A spike joins the current burst when it
    follows the last accepted spike by at most `max_isi`; it opens a new burst
    when the gap is larger than `min_ibi`. A spike whose gap is larger than
    `max_isi` but not larger than `min_ibi` is left out of every burst and is
    not used as the reference for the next spike. Bursts with fewer than
    `n_spikes` spikes (and, optionally, shorter than `min_duration`) are
    discarded. The input DataFrame is not modified.
    
    Arguments: 
        df: DataFrame with spike data.
        spike_times: Column name for spike positions.
        spike_amplitudes: Column name for spike amplitudes.
        spike_peaks: Column name for spike peak amplitudes. 
        n_spikes: Minimum number of spikes within a burst.
        max_isi: Max interspike interval within the burst.
        min_ibi: Minimum interburst interval.
        min_duration: Minimum burst duration (optional). When None (default)
            bursts are not filtered by duration.
    
    Returns:
        DataFrame with one row per burst and the columns: burst_number
        (1-based, renumbered after filtering), burst_start, burst_end,
        burst_length (burst_end - burst_start), spikes_in_bursts,
        avg_spike_amplitude, avg_spike_peaks and spikes_frequency
        (spikes_in_bursts / burst_length).
    """
    
    df = df.sort_values(by=spike_times)  # Sorted copy; the caller's frame is untouched

    # Burst label of every spike, by position (NaN = not part of any burst)
    labels = np.full(len(df), np.nan)
    burst_num = 0      # Current burst number
    last_spike = None  # Position of the last spike assigned to a burst

    for pos, spike in enumerate(df[spike_times].to_numpy()):
        if last_spike is None or spike - last_spike <= max_isi:
            # First spike of the recording, or still within max_isi: same burst
            labels[pos] = burst_num
            last_spike = spike
        elif spike - last_spike > min_ibi:
            # The interburst interval has been reached: open a new burst
            burst_num += 1
            labels[pos] = burst_num
            last_spike = spike

    df = df.assign(burst=labels)

    # Filter bursts with less than n_spikes (unlabelled spikes are dropped here too)
    spikes_per_burst = df.groupby('burst')[spike_times].transform('count')
    df = df[spikes_per_burst >= n_spikes]

    # Filter bursts shorter than min_duration
    if min_duration is not None:
        times_by_burst = df.groupby('burst')[spike_times]
        burst_span = times_by_burst.transform('max') - times_by_burst.transform('min')
        df = df[burst_span >= min_duration]

    # Calculate burst information by aggregating single spike information
    by_burst = df.groupby('burst')
    bursts = by_burst[spike_times].agg(['min', 'max', 'count'])
    bursts.columns = ['burst_start', 'burst_end', 'spikes_in_bursts']
    bursts['burst_length'] = bursts['burst_end'] - bursts['burst_start']
    bursts['avg_spike_amplitude'] = by_burst[spike_amplitudes].mean()
    bursts['avg_spike_peaks'] = by_burst[spike_peaks].mean()
    bursts['spikes_frequency'] = bursts['spikes_in_bursts'] / bursts['burst_length']
    bursts = bursts.reset_index()
    bursts['burst_number'] = bursts.index + 1  # Consecutive numbering of the kept bursts

    return bursts[['burst_number', 'burst_start', 'burst_end', 
                   'burst_length', 'spikes_in_bursts', 'avg_spike_amplitude', 
                   'avg_spike_peaks', 'spikes_frequency']]

In [None]:
# Print the signature and docstring of burst_detection
help(burst_detection)

## Table and plot

In [None]:
# Burst table: detect bursts from the spike table
bursts = burst_detection(
    spikes_table,        # Dataframe with spike positions as input data
    'spike_time',
    'spike_amplitude',
    'spike_peak',
    n_spikes=2,
    max_isi=0.1,
    # min_duration=0.5,  # Optional
    min_ibi=0.2,
)

# Figure with two stacked panels: whole trace (top) and single burst (bottom)
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6))

# ---- Panel 1: trace and detected bursts ----
ax1.set_title("Burst detection")
ax1.set_ylabel("Voltage (mV)")
ax1.set_xlabel("Time (s)")

# Remove top and right frame borders
for side in ('top', 'right'):
    ax1.spines[side].set_visible(False)

ax1.plot(time, peaks_signal, color='gray')
# Magenta dot on each detected spike
ax1.scatter(spikes_table['spike_time'], spikes_table['spike_peak'], color="magenta", s=10)

# One horizontal bar per burst, labelled with the burst number
for row in bursts.itertuples(index=False):
    # Height of the burst line
    # line_y = row.avg_spike_peaks + 5  # Option A
    line_y = np.median(spikes_table.spike_peak) + 5  # Option B

    # Horizontal line from the beginning to the end of the burst
    ax1.plot([row.burst_start, row.burst_end], [line_y, line_y], 'black')
    # Burst number just above the start of the line
    ax1.annotate(str(int(row.burst_number)),
                 xy=(row.burst_start, line_y),
                 xytext=(row.burst_start, line_y + 1))

ax1.set_xlim(0, 4)  # Optional: Zoom in the trace

# ---- Panel 2: single burst ----
burst_number = 2  # Change here the burst number
ax2.set_title("Burst viewer")
ax2.set_ylabel("Voltage (mV)")
ax2.set_xlabel("Time (s)")

# Signal with detected spikes; the legend entry shows the selected burst number
ax2.plot(time, peaks_signal, color='gray', label=burst_number)
ax2.scatter(spikes_table['spike_time'], spikes_table['spike_peak'], color="magenta", s=10)

# Row/s of the selected burst (indexing below fails if that burst number does not exist)
selected = bursts[bursts['burst_number'] == burst_number]
burst_start = selected['burst_start'].values[0]
burst_end = selected['burst_end'].values[0]
burst_line_y = selected['avg_spike_peaks'].values[0] + 5

# Burst line, and time window of the burst + 0.1 s before and after
ax2.plot([burst_start, burst_end], [burst_line_y, burst_line_y], 'black')
ax2.set_xlim(burst_start - 0.1, burst_end + 0.1)

ax2.legend()
fig.tight_layout()

# Save plot and table
fig.savefig(f"{paths['analysis']}/{filename}_bursts_plot.png", dpi=300)  # or svg
bursts.to_csv(f"{paths['analysis']}/{filename}_bursts_results.csv", index=False)

# Display the plots and table
plt.show()
bursts

## Summary statistics

In [None]:
# Summary statistics of the detected bursts
burst_number = len(bursts)                            # Number of bursts
spikes_in_bursts = bursts['spikes_in_bursts'].sum()   # Spikes that belong to a burst
total_spikes = len(spikes_table)                      # All detected spikes
spikes_bursts_pct = 100 * (spikes_in_bursts / total_spikes)
mean_burst_duration = bursts['burst_length'].mean()

# One-row DataFrame with the summary of this recording
summary_row = {
    'Recording': [filename],
    'Number of bursts': [burst_number],
    'Spikes in Bursts': [spikes_in_bursts],
    'Spikes in Bursts (%)': [spikes_bursts_pct],
    'Mean Burst Duration': [mean_burst_duration],
}
bursts_stats = pd.DataFrame(summary_row)

# Save the table
bursts_stats.to_csv(f"{paths['analysis']}/{filename}_bursts_stats.csv", index=False)

bursts_stats