In [None]:
# Imports and Setup
import pyabf  # ABF file reader
import numpy as np  # Numerical operations
import matplotlib.pyplot as plt  # Plotting
from scipy.signal import find_peaks  # Spike detection
from scipy.interpolate import interp1d  # Rescaling and interpolation
from tslearn.preprocessing import TimeSeriesScalerMinMax  # Normalization for shapelets
from tslearn.shapelets import LearningShapelets  # Shapelet model
from tensorflow.keras.optimizers import Adam  # Optimizer for shapelet learning
from matplotlib import cm  # Color maps for plotting
%matplotlib widget  # Interactive plots

# Load and concatenate ABF sweeps
file_path = "bursting/cell89basal.abf"  # Path to ABF file
abf = pyabf.ABF(file_path)  # Load ABF
signal = np.concatenate([abf.setSweep(i) or abf.sweepY for i in range(abf.sweepCount)])
dt = 1.0 / abf.dataRate  # Sampling interval in seconds
time = np.arange(len(signal)) * dt  # Time vector

# Spike detection
def detect_spikes(signal, threshold=-35):
    """Locate spikes as local maxima of ``signal`` that reach ``threshold``.

    Parameters
    ----------
    signal : array-like
        1-D voltage trace to search.
    threshold : float, optional
        Minimum peak height, in the same units as ``signal`` (default -35).

    Returns
    -------
    numpy.ndarray
        Sample indices of the peaks reported by ``scipy.signal.find_peaks``.
    """
    # find_peaks returns (indices, properties); only the indices are needed.
    peak_indices = find_peaks(signal, height=threshold)[0]
    return peak_indices

# Run threshold spike detection on the concatenated trace, then look each
# detected peak up in the shared time vector to express it in seconds.
spike_indices = detect_spikes(signal)
spike_times = np.take(time, spike_indices)

# Burst detection based on inter-spike intervals (ISI)
def detect_bursts(spike_times, burst_threshold=0.3):
    """Group spikes into bursts using an inter-spike-interval (ISI) threshold.

    Consecutive spikes separated by less than ``burst_threshold`` belong to the
    same burst; a larger gap closes the current burst. Only groups of two or
    more spikes are reported.

    Parameters
    ----------
    spike_times : array-like
        Spike times in ascending order, in the same units as
        ``burst_threshold`` (seconds in this notebook).
    burst_threshold : float, optional
        Largest ISI for which two spikes are considered part of one burst.

    Returns
    -------
    list of tuple
        ``(first_spike_time, last_spike_time)`` for each burst, in order.
        Empty when there are fewer than two spikes.
    """
    spike_times = np.asarray(spike_times)
    bursts = []
    if spike_times.size == 0:
        return bursts

    isi = np.diff(spike_times)
    current_burst = [spike_times[0]]
    # Pair every spike after the first with the ISI that precedes it, so the
    # final spike is examined too (the previous index loop stopped one short).
    for spike_time, gap in zip(spike_times[1:], isi):
        if gap < burst_threshold:
            current_burst.append(spike_time)
        else:
            if len(current_burst) > 1:
                bursts.append((current_burst[0], current_burst[-1]))
            current_burst = [spike_time]
    if len(current_burst) > 1:
        bursts.append((current_burst[0], current_burst[-1]))
    return bursts

# Group the detected spikes into bursts (0.3 s ISI threshold) and report the count.
bursts = detect_bursts(spike_times, burst_threshold=0.3)
print(f"Total detected bursts: {len(bursts)}")

# Burst classification
def classify_bursts(bursts, signal, time):
    """Classify bursts as Square Wave, Parabolic or Other.

    For each burst, the minimum of ``signal`` between its first and last spike
    is compared with the mean of ``signal`` over the neighbouring inter-burst
    gap(s):

    * burst minimum above the inter-burst mean -> square-wave
    * burst minimum below the inter-burst mean -> parabolic
    * otherwise (equal, NaN, or no usable gap, e.g. a lone burst) -> other

    Parameters
    ----------
    bursts : sequence of (start, end)
        Burst boundaries, in the same units as ``time``.
    signal, time : array-like
        Equal-length sample values and their sample times.

    Returns
    -------
    tuple of three lists
        ``(square_wave_bursts, parabolic_bursts, other_bursts)``, each holding
        ``(start, end)`` tuples in their original order.
    """
    signal = np.asarray(signal)
    time = np.asarray(time)
    square_wave_bursts = []
    parabolic_bursts = []
    other_bursts = []

    # Mean of each gap between consecutive bursts, computed once: gap k lies
    # strictly between burst k and burst k + 1 and is shared by both. A gap
    # with no samples has no mean; record NaN instead of averaging an empty
    # slice (which would emit a RuntimeWarning).
    gap_means = []
    for (_, prev_end), (next_start, _) in zip(bursts[:-1], bursts[1:]):
        gap = signal[(time > prev_end) & (time < next_start)]
        gap_means.append(gap.mean() if gap.size else np.nan)

    for i, (burst_start, burst_end) in enumerate(bursts):
        burst_min = np.min(signal[(time >= burst_start) & (time <= burst_end)])

        # Burst i is bordered by gap i - 1 (before) and gap i (after) when they exist.
        neighbour_means = gap_means[max(i - 1, 0):i + 1]
        finite_means = [m for m in neighbour_means if not np.isnan(m)]
        if not finite_means:
            # No inter-burst baseline to compare against (e.g. a single burst).
            other_bursts.append((burst_start, burst_end))
            continue
        inter_mean = np.mean(finite_means)

        if burst_min > inter_mean:
            square_wave_bursts.append((burst_start, burst_end))
        elif burst_min < inter_mean:
            parabolic_bursts.append((burst_start, burst_end))
        else:
            other_bursts.append((burst_start, burst_end))
    return square_wave_bursts, parabolic_bursts, other_bursts

# Split the bursts into the three shape classes (square-wave, parabolic, other).
(square_wave_bursts,
 parabolic_bursts,
 other_bursts) = classify_bursts(bursts=bursts, signal=signal, time=time)

# Burst normalization functions
def normalize_y(signal_segment):
    """Normalize amplitude (z-score): subtract the mean, divide by the std.

    A constant segment has zero standard deviation; dividing would give 0/0
    (an all-NaN array that poisons later scaling and model fitting), so such
    a segment is returned as all zeros instead.

    Parameters
    ----------
    signal_segment : array-like
        1-D samples to normalize.

    Returns
    -------
    numpy.ndarray
        The z-scored segment (zeros if the segment is flat).
    """
    segment = np.asarray(signal_segment)
    centered = segment - np.mean(segment)
    spread = np.std(segment)
    if spread == 0:
        return np.zeros_like(centered)
    return centered / spread

def rescale_x(time_segment, signal_segment, n_points=100):
    """Rescale duration to ``n_points`` samples by linear interpolation.

    The segment's samples are placed evenly on [0, 1] and resampled at
    ``n_points`` evenly spaced positions on the same interval, so the result
    depends only on sample order, not on the segment's absolute duration.

    Parameters
    ----------
    time_segment : array-like
        Not used; accepted for interface compatibility (uniform sample
        spacing is assumed).
    signal_segment : array-like
        1-D samples to resample. A single sample is repeated ``n_points``
        times; an empty segment raises ``ValueError``.
    n_points : int, optional
        Number of output samples (default 100).

    Returns
    -------
    numpy.ndarray
        Array of length ``n_points``.
    """
    values = np.asarray(signal_segment)
    source_grid = np.linspace(0.0, 1.0, len(values))
    target_grid = np.linspace(0.0, 1.0, n_points)
    # np.interp replaces the legacy scipy interp1d; unlike interp1d it also
    # accepts a single-sample segment (returning that value everywhere).
    return np.interp(target_grid, source_grid, values)

def extract_normalized_bursts(burst_list, signal, time, n_points=100):
    """Cut out each burst, resample it to a fixed length and z-score it.

    Parameters
    ----------
    burst_list : iterable of (start, end)
        Burst boundaries in the same units as ``time``; both ends inclusive.
    signal, time : numpy.ndarray
        Equal-length sample values and their sample times.
    n_points : int, optional
        Length every burst is resampled to (default 100).

    Returns
    -------
    list of numpy.ndarray
        One normalized array per burst, in input order (input for shapelet
        learning).
    """
    def _window(start, end):
        # Samples of the trace that fall inside [start, end].
        keep = (time >= start) & (time <= end)
        return time[keep], signal[keep]

    return [
        normalize_y(rescale_x(t_seg, v_seg, n_points))
        for t_seg, v_seg in (_window(start, end) for start, end in burst_list)
    ]

n_points = 100

# Resample and z-score every burst of each class. The class order fixes the
# integer labels: square-wave = 0, parabolic = 1, other = 2.
square_bursts_normalized, parabolic_bursts_normalized, other_bursts_normalized = (
    extract_normalized_bursts(class_bursts, signal, time, n_points)
    for class_bursts in (square_wave_bursts, parabolic_bursts, other_bursts)
)

# One flat training set plus a parallel list with each burst's class label.
all_normalized_bursts = [
    *square_bursts_normalized,
    *parabolic_bursts_normalized,
    *other_bursts_normalized,
]
labels = [
    class_id
    for class_id, class_group in enumerate(
        (square_bursts_normalized, parabolic_bursts_normalized, other_bursts_normalized)
    )
    for _ in class_group
]

# Shapelet learning
if not all_normalized_bursts:
    # Without this check the indexing below fails with an opaque IndexError.
    raise ValueError("No bursts were detected, so there is nothing to train the shapelet model on.")

# tslearn layout: (n_series, n_timestamps, n_features); each burst is one channel.
X = np.array(all_normalized_bursts)[:, :, np.newaxis]
y = np.array(labels)
X = TimeSeriesScalerMinMax().fit_transform(X)  # Normalize between 0 and 1

# Single source of truth for the shapelet count: it sets the model size and
# the width of the distance matrix, so the two can no longer drift apart.
# The 2-D plot below needs exactly two shapelets.
n_shapelets = 2
# NOTE(review): the shapelet length equals the full burst length (X.shape[1]),
# so each shapelet is matched against the whole resampled burst rather than
# slid over sub-windows — confirm this whole-series template is intended.
n_shapelets_per_size = {X.shape[1]: n_shapelets}
shp_clf = LearningShapelets(
    n_shapelets_per_size=n_shapelets_per_size,
    weight_regularizer=0.0001,
    optimizer=Adam(0.01),
    max_iter=300,
    verbose=0,
    scale=False,
    random_state=42
)
shp_clf.fit(X, y)

# Shapelet transform: distance from each burst to each learned shapelet,
# one row per burst and one column per shapelet.
distances = shp_clf.transform(X).reshape((-1, n_shapelets))

# Plot shapelets and 2D burst embeddings
%matplotlib inline
viridis = cm.get_cmap('viridis', 4)
fig = plt.figure(constrained_layout=True, figsize=(10,6))
gs = fig.add_gridspec(3, 9)
fig_ax1 = fig.add_subplot(gs[0, :2])
fig_ax2 = fig.add_subplot(gs[0, 2:4])
fig_ax4 = fig.add_subplot(gs[:, 4:])

# Plot shapelets
fig_ax1.plot(shp_clf.shapelets_[0].flatten())
fig_ax1.set_title('Shapelet $\mathbf{s}_1$')
fig_ax2.plot(shp_clf.shapelets_[1].flatten())
fig_ax2.set_title('Shapelet $\mathbf{s}_2$')

# Scatter plot of bursts in 2D distance space
for i, label in enumerate(np.unique(y)):
    mask = y == label
    fig_ax4.scatter(
        distances[mask][:,0],
        distances[mask][:,1],
        c=[viridis(i / max(1, len(np.unique(y)) - 1))]*np.sum(mask),
        edgecolors='k',
        label=f'Class {label}'
    )

# Decision boundaries for visualization
xmin, xmax = distances[:,0].min()-0.1, distances[:,0].max()+0.1
ymin, ymax = distances[:,1].min()-0.1, distances[:,1].max()+0.1
xx, yy = np.meshgrid(np.linspace(xmin, xmax, 200), np.linspace(ymin, ymax, 200))
W, b = shp_clf.model_.get_layer("classification").get_weights()
n_classes = len(np.unique(y))
Z = []
for x_val, y_val in np.c_[xx.ravel(), yy.ravel()]:
    if n_classes == 2:
        logit = b[0] + W[0,0]*x_val + W[1,0]*y_val
        pred = int(logit >= 0)
    else:
        scores = [b[i] + W[0,i]*x_val + W[1,i]*y_val for i in range(n_classes)]
        pred = np.argmax(scores)
    Z.append(pred)
Z = np.array(Z).reshape(xx.shape)
fig_ax4.contourf(xx, yy, Z/max(1, n_classes-1), cmap=viridis, alpha=0.25)
fig_ax4.set_xlabel('$d(\mathbf{x}, \mathbf{s}_1)$')
fig_ax4.set_ylabel('$d(\mathbf{x}, \mathbf{s}_2)$')
fig_ax4.set_xlim(xmin, xmax)
fig_ax4.set_ylim(ymin, ymax)
fig_ax4.set_title('Distance transformed bursts')
fig_ax4.legend()
plt.show()