# Spike and Burst Detection, Normalization, and Shapelet Learning

Full pipeline: ABF loading → spike detection → burst detection → burst classification → burst normalization → shapelet learning → visualization of shapelet distances.

In [None]:
import pyabf
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import find_peaks
from scipy.interpolate import interp1d
%matplotlib widget

In [None]:
# Load ABF
file_path = "bursting/cell89basal.abf"
abf = pyabf.ABF(file_path)

# Concatenate every sweep (channel 0, setSweep's default) into one trace.
# setSweep() changes which sweep abf.sweepY refers to, so each sweep is
# copied out explicitly before moving on to the next one.
sweeps = []
for sweep_number in range(abf.sweepCount):
    abf.setSweep(sweep_number)
    sweeps.append(abf.sweepY.copy())
signal = np.concatenate(sweeps)

# Sample interval and a matching time axis (1 / dataRate per sample).
# NOTE(review): the time axis is continuous across sweeps, so any gap between
# the end of one sweep and the start of the next is ignored — confirm this is
# acceptable for the recording protocol.
dt = 1.0 / abf.dataRate
time = np.arange(len(signal)) * dt
print(f"Loaded {file_path}, duration: {time[-1]:.2f}s")

In [None]:
# Spike detection
# Every local maximum of the trace that reaches this height counts as a spike.
# NOTE(review): units are those of `signal` (presumably mV) — confirm.
threshold = -35
spike_indices = find_peaks(signal, height=threshold)[0]
spike_times = time[spike_indices]

In [None]:
# Burst detection
# Two consecutive spikes belong to the same burst when their inter-spike
# interval (ISI) is below this value (same units as `time`).
burst_threshold = 0.3
isi = np.diff(spike_times)  # inter-spike intervals, kept for inspection


def detect_bursts(spike_times, max_isi):
    """Group spike times into bursts of closely spaced spikes.

    Every spike is considered, including the last one. (The previous loop
    stopped one spike early, which truncated the final burst and missed a
    burst made of only two spikes.)

    Parameters
    ----------
    spike_times : array-like of float
        Spike times in ascending order.
    max_isi : float
        Consecutive spikes whose interval is strictly below this value are
        merged into the same burst.

    Returns
    -------
    list of (float, float)
        ``(first_spike_time, last_spike_time)`` for every run of at least two
        spikes. Isolated single spikes are not reported. Empty or one-spike
        input returns an empty list.
    """
    spike_times = np.asarray(spike_times, dtype=float)
    # A spike that is too far from its predecessor opens a new group.
    group_starts = np.flatnonzero(np.diff(spike_times) >= max_isi) + 1
    groups = np.split(spike_times, group_starts)
    return [(group[0], group[-1]) for group in groups if len(group) > 1]


bursts = detect_bursts(spike_times, burst_threshold)
print(f"Total bursts detected: {len(bursts)}")

In [None]:
# Burst classification
# Each burst is labelled by comparing the lowest signal value inside it with
# the mean signal in the gaps separating it from its neighbouring bursts:
#   burst minimum > inter-burst mean -> square-wave
#   burst minimum < inter-burst mean -> parabolic
#   otherwise (equal, or no gap samples to compare with) -> other


def _mean_or_nan(values):
    """Return ``np.mean(values)``, or NaN (without a RuntimeWarning) if empty."""
    return np.mean(values) if len(values) else np.nan


square_wave_bursts, parabolic_bursts, other_bursts = [], [], []
for i, (start, end) in enumerate(bursts):
    # `time` is an evenly spaced, increasing axis, so each sample window is
    # located by binary search. A boolean mask over the whole recording per
    # burst made this loop O(n_bursts * n_samples).
    burst_lo = np.searchsorted(time, start, side="left")  # first t >= start
    burst_hi = np.searchsorted(time, end, side="right")   # one past last t <= end
    burst_min = np.min(signal[burst_lo:burst_hi])

    # Samples strictly between the previous burst's end and this burst's start.
    if i > 0:
        gap_lo = np.searchsorted(time, bursts[i - 1][1], side="right")
        prev_mean = _mean_or_nan(signal[gap_lo:burst_lo])
    else:
        prev_mean = np.nan
    # Samples strictly between this burst's end and the next burst's start.
    if i < len(bursts) - 1:
        gap_hi = np.searchsorted(time, bursts[i + 1][0], side="left")
        next_mean = _mean_or_nan(signal[burst_hi:gap_hi])
    else:
        next_mean = np.nan

    # Same value as np.nanmean([prev_mean, next_mean]), without the
    # "Mean of empty slice" warning when both are NaN (e.g. a single burst).
    available = [m for m in (prev_mean, next_mean) if not np.isnan(m)]
    inter_mean = np.mean(available) if available else np.nan

    if burst_min > inter_mean:
        square_wave_bursts.append((start, end))
    elif burst_min < inter_mean:
        parabolic_bursts.append((start, end))
    else:
        other_bursts.append((start, end))
print(f"Square Wave: {len(square_wave_bursts)}, Parabolic: {len(parabolic_bursts)}, Other: {len(other_bursts)}")

In [None]:
# Normalization functions
def normalize_y(segment):
    """Z-score a 1-D segment: subtract its mean and divide by its standard deviation.

    Parameters
    ----------
    segment : array-like of float
        Values to normalize.

    Returns
    -------
    numpy.ndarray of float
        Normalized values, same length as *segment*. A constant segment
        (standard deviation 0) is returned as all zeros instead of the NaNs
        that 0/0 would produce, so one flat burst cannot inject NaNs into the
        later training data.
    """
    segment = np.asarray(segment, dtype=float)
    std = segment.std()
    if std == 0:
        return np.zeros_like(segment)
    return (segment - segment.mean()) / std

def rescale_x(t_segment, segment, n_points=100):
    """Resample *segment* to *n_points* samples by linear interpolation.

    The segment's samples are treated as evenly spaced. Its index range is
    mapped onto [0, 1] and re-sampled at *n_points* equidistant positions, so
    every burst ends up with the same length regardless of its duration.

    Parameters
    ----------
    t_segment : array-like
        Sample times of *segment*. Unused (resampling is by sample index);
        kept so existing callers need no change.
    segment : array-like of float
        1-D signal to resample. A single sample resamples to a constant
        array. (Legacy ``interp1d`` raised for fewer than two samples.)
    n_points : int
        Length of the returned array.

    Returns
    -------
    numpy.ndarray of float, shape (n_points,)
    """
    segment = np.asarray(segment, dtype=float)
    source_x = np.linspace(0, 1, len(segment))
    return np.interp(np.linspace(0, 1, n_points), source_x, segment)

In [None]:
def extract_normalized_bursts(burst_list, signal, time, n_points=100):
    """Cut each burst out of *signal*, resample it, and z-score it.

    Parameters
    ----------
    burst_list : iterable of (start, end)
        Burst boundaries on the same axis as *time*; both ends are inclusive.
    signal, time : numpy.ndarray
        The recorded trace and its sample times (equal length).
    n_points : int
        Number of samples every burst is resampled to via ``rescale_x``.

    Returns
    -------
    list of numpy.ndarray
        One resampled, ``normalize_y``-normalized array per burst, in input order.
    """
    def _process(window):
        start, end = window
        in_burst = (start <= time) & (time <= end)
        return normalize_y(rescale_x(time[in_burst], signal[in_burst], n_points))

    return [_process(window) for window in burst_list]

# Resample + z-score every burst of each class, then stack them into a single
# dataset with integer class labels: 0 = square-wave, 1 = parabolic, 2 = other.
square_bursts_norm, parabolic_bursts_norm, other_bursts_norm = (
    extract_normalized_bursts(group, signal, time)
    for group in (square_wave_bursts, parabolic_bursts, other_bursts)
)

all_normalized_bursts = [*square_bursts_norm, *parabolic_bursts_norm, *other_bursts_norm]
labels = [
    class_id
    for class_id, group in enumerate((square_bursts_norm, parabolic_bursts_norm, other_bursts_norm))
    for _ in group
]

In [None]:
# Shapelet learning
from tslearn.preprocessing import TimeSeriesScalerMinMax
from tslearn.shapelets import LearningShapelets
from tensorflow.keras.optimizers import Adam

if not all_normalized_bursts:
    # np.array([])[:, :, np.newaxis] would otherwise fail with a cryptic IndexError.
    raise ValueError("No bursts were detected; nothing to learn shapelets from.")

# (n_bursts, n_points) -> (n_bursts, n_points, 1): one feature per time step.
X = np.array(all_normalized_bursts)[:, :, np.newaxis]
y = np.array(labels)
# Min-max scale each burst (tslearn default range [0, 1]).
X = TimeSeriesScalerMinMax().fit_transform(X)

# Single source of truth for the shapelet count. It is used both to configure
# the model and to shape the distance matrix below.
n_shapelets = 2
# Maps shapelet length -> number of shapelets of that length. Here one length
# is used, equal to the full resampled burst length.
n_shapelets_per_size = {X.shape[1]: n_shapelets}
shp_clf = LearningShapelets(n_shapelets_per_size=n_shapelets_per_size,
                             weight_regularizer=0.0001,
                             optimizer=Adam(0.01),
                             max_iter=300,
                             verbose=0,
                             scale=False,
                             random_state=42)
shp_clf.fit(X, y)
# One row per burst, one column per learned shapelet.
distances = shp_clf.transform(X).reshape((-1, n_shapelets))

In [None]:
# Visualize distances
# Each burst is drawn at (distance to shapelet 1, distance to shapelet 2) and
# coloured by its class label (0 = square wave, 1 = parabolic, 2 = other).
# matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap is the
# supported equivalent and returns the same 4-entry resampled colormap.
viridis = plt.get_cmap('viridis', 4)
class_names = ("Square wave", "Parabolic", "Other")
plt.figure(figsize=(8, 6))
for class_id, class_name in enumerate(class_names):
    in_class = y == class_id
    if not np.any(in_class):
        continue  # keep the legend limited to classes that are actually plotted
    plt.scatter(distances[in_class, 0], distances[in_class, 1],
                color=viridis(class_id / 2), edgecolors='k', label=class_name)
plt.xlabel('Distance to Shapelet 1')
plt.ylabel('Distance to Shapelet 2')
plt.title('Distance-transformed bursts')
plt.legend()
plt.show()