
# ABF — Simple Peak Workflow

**What this does (in order):**

1. Load a `.abf` file and choose a sweep.  
2. Pick a **baseline time window** and subtract its mean.  
3. Apply an **8‑pole Bessel low‑pass filter** (adjustable cutoff), run forward and backward so detected peaks are not time‑shifted.  
4. Run a **threshold-based detector** (positive- or negative-going).  
5. Extract **snippets** around each detected peak (configurable pre/post window in ms).  

**And it shows a separate plot at each step.**

> Minimal controls, minimal code. No exports or extras—just the essentials.


In [None]:
import pyabf  # noqa: F401
import ipywidgets as widgets
from scipy import signal 
import numpy as np
import matplotlib.pyplot as plt 
import io, os, tempfile
import pyabf
import ipywidgets as widgets
from IPython.display import display, clear_output, Markdown
plt.rcParams['figure.figsize'] = (10, 4)

In [None]:

def bessel_lowpass(y, fs, cutoff_hz, order=8):
    """Zero-phase Bessel low-pass filter (forward-backward).

    Parameters
    ----------
    y : 1D array-like
        Signal to filter.
    fs : float
        Sampling rate in Hz.
    cutoff_hz : float
        Cutoff in Hz of a phase-normalized (``norm='phase'``) Bessel design.
        Values <= 0 disable filtering and ``y`` is returned unchanged; values
        at or above Nyquist are clamped to just below it.
    order : int
        Order of the single-pass design (default 8). Because the filter is run
        forward and then backward, the magnitude response is squared and the
        phase delay cancels.

    Returns
    -------
    ndarray
        Filtered signal with the same length as ``y`` (or ``y`` itself when
        filtering is disabled).

    Notes
    -----
    The filter is applied as second-order sections via ``sosfiltfilt``. The
    equivalent single ``(b, a)`` polynomial of an 8th-order filter is badly
    conditioned when the cutoff is a small fraction of the sampling rate
    (e.g. a few Hz at tens of kHz) and can give a distorted or unstable
    output. Very short inputs (shorter than SciPy's default edge padding)
    raise ``ValueError``.
    """
    nyq = 0.5 * fs
    if cutoff_hz <= 0:
        return y
    if cutoff_hz >= nyq:
        # Digital designs require 0 < Wn < 1; stay just below Nyquist.
        cutoff_hz = nyq * 0.999
    sos = signal.bessel(order, cutoff_hz / nyq, btype='low', norm='phase',
                        output='sos')
    return signal.sosfiltfilt(sos, y)

def bessel_lowpass_clampfit(y, fs, cutoff_hz, order=8, initial='zero'):
    """
    Causal, single-pass Bessel low-pass, intended to mimic Clampfit's 8-pole IIR:
      - magnitude-normalized design (norm='mag'): -3 dB at 'cutoff_hz' for the
        single pass
      - causal (one forward pass), so unlike `bessel_lowpass` it introduces a
        phase delay

    Params
    ------
    y : 1D array-like
    fs : sampling rate (Hz)
    cutoff_hz : -3 dB cutoff in Hz (single pass). Clamped to at most 0.49*fs.
        Values <= 0 disable filtering and `y` is returned unchanged (same
        convention as `bessel_lowpass`).
    order : filter order, 8 by default
    initial : 'zero'   - filter state starts at zero, so the output has a
                         startup transient rising from 0 toward the signal;
              'steady' - state is initialised to the steady-state response to
                         a constant y[0], which suppresses that transient.
              Any other value behaves like 'zero'.

    Returns
    -------
    ndarray with the same length as `y` (an empty input gives an empty array).
    """
    if cutoff_hz <= 0:
        return y
    y = np.asarray(y)
    if y.size == 0:
        # Nothing to filter; also avoids indexing y[0] for initial='steady'.
        return y.astype(float)

    # Safety: keep cutoff below Nyquist
    cutoff_hz = float(min(cutoff_hz, 0.49*fs))

    # Design in Hz (preferred); fall back to normalized Wn on SciPy versions
    # whose filter design functions have no `fs` keyword.
    try:
        sos = signal.bessel(order, cutoff_hz, btype='low', analog=False,
                            output='sos', norm='mag', fs=fs)
    except TypeError:
        Wn = cutoff_hz / (fs/2.0)
        sos = signal.bessel(order, Wn, btype='low', analog=False,
                            output='sos', norm='mag')

    if initial == 'steady':
        # Scale the unit-step steady state by the first sample so a signal
        # that starts at a non-zero level produces no startup transient.
        zi = signal.sosfilt_zi(sos)
        y_f, _ = signal.sosfilt(sos, y, zi=zi*y[0])
        return y_f
    # Zero initial state: keeps the IIR startup transient.
    return signal.sosfilt(sos, y)

def detect_peaks_threshold(y, threshold, fs, polarity='positive',
                           lookahead_ms=2.0, min_dist_ms=1.0):
    """Detect threshold crossings and return the extremum following each one.

    A crossing is the first sample at or above ``threshold`` after a sample
    below it (``polarity='positive'``), or the first sample at or below it
    after a sample above it (any other ``polarity`` value). For each crossing,
    the maximum (positive) or minimum (negative) inside the window
    ``[crossing, crossing + lookahead)`` is taken as the peak. A peak closer
    than ``min_dist_ms`` to the previously accepted peak is discarded
    (greedy, in time order).

    Parameters
    ----------
    y : 1D array-like
        Signal to search.
    threshold : float
        Detection level.
    fs : float
        Sampling rate in Hz; converts the millisecond windows to samples
        (each at least 1 sample).
    polarity : str
        'positive' for upward crossings/maxima; anything else for downward
        crossings/minima.
    lookahead_ms, min_dist_ms : float
        Search window after each crossing and minimum spacing between
        accepted peaks, in milliseconds.

    Returns
    -------
    list
        Sample indices of accepted peaks, strictly increasing.
    """
    sig = np.asarray(y)
    before, after = sig[:-1], sig[1:]
    if polarity == 'positive':
        crossings = np.flatnonzero((before < threshold) & (after >= threshold)) + 1
        extremum = np.argmax
    else:
        crossings = np.flatnonzero((before > threshold) & (after <= threshold)) + 1
        extremum = np.argmin

    window = max(1, int(round(lookahead_ms * 1e-3 * fs)))
    spacing = max(1, int(round(min_dist_ms * 1e-3 * fs)))
    total = len(sig)

    found = []
    previous = None
    for start in crossings:
        stop = min(start + window, total)
        if stop <= start:
            continue
        candidate = start + int(extremum(sig[start:stop]))
        # Enforce the refractory spacing against the last accepted peak only.
        if previous is not None and candidate - previous < spacing:
            continue
        found.append(candidate)
        previous = candidate
    return found

def snippets_around(y, peaks, fs, pre_ms, post_ms):
    """Cut fixed-length windows of ``y`` centred on each peak index.

    Each window spans ``pre`` samples before the peak through ``post``
    samples after it (inclusive of the peak), where ``pre``/``post`` are
    ``pre_ms``/``post_ms`` converted to samples at rate ``fs`` and rounded.
    Peaks whose window would extend past either end of ``y`` are skipped.

    Returns
    -------
    (snips, kept, pre, post)
        ``snips`` is a 2D array with one row per kept peak (an empty
        ``(0, pre+post+1)`` array when none are kept); ``kept`` lists the
        peak indices that produced rows, in input order; ``pre`` and
        ``post`` are the window sizes in samples.
    """
    n_pre = int(round(pre_ms * 1e-3 * fs))
    n_post = int(round(post_ms * 1e-3 * fs))
    width = n_pre + n_post + 1
    length = len(y)

    # Edge events are dropped rather than padded, to keep every row aligned.
    kept = [p for p in peaks if p - n_pre >= 0 and p + n_post + 1 <= length]
    if not kept:
        return np.empty((0, width)), kept, n_pre, n_post

    rows = [y[p - n_pre:p + n_post + 1] for p in kept]
    return np.vstack(rows), kept, n_pre, n_post



In [None]:
# UI controls
# Only the file picker and Load button start enabled. _load_clicked enables
# sweep/baseline/cutoff/pre/post/run after a file is parsed, and
# _update_thr_range enables threshold/polarity on the first analysis run.
# continuous_update=False: sliders report a new value only on release.
upload = widgets.FileUpload(accept='.abf', multiple=False)
load_btn = widgets.Button(description='Load ABF', button_style='primary')

# Sweep index; its max is set from abf.sweepCount on load.
sweep = widgets.IntSlider(description='Sweep', min=0, max=0, step=1, value=0, disabled=True, continuous_update=False)
# Baseline window in seconds; bounds are reset to the sweep's time span on load.
baseline = widgets.FloatRangeSlider(description='Baseline (s)', min=0.0, max=1.0, step=0.001, value=(0.0, 0.05),
                                    readout_format='.3f', disabled=True, continuous_update=False)
# Low-pass cutoff; its max is recomputed from the sampling rate on load.
cutoff = widgets.FloatSlider(description='LPF cutoff (Hz)', min=1.0, max=5000.0, step=1.0, value=200.0,
                             disabled=True, continuous_update=False)
polarity = widgets.ToggleButtons(description='Polarity', options=['positive','negative'], value='positive', disabled=True)
# Placeholder range; _update_thr_range refits it to the filtered trace.
thr = widgets.FloatSlider(description='Threshold', min=-100.0, max=100.0, step=0.1, value=0.0,
                          disabled=True, continuous_update=False)
# Snippet window before/after each detected peak.
pre_ms = widgets.FloatSlider(description='Pre (ms)', min=0.0, max=200.0, step=0.5, value=5.0, disabled=True, continuous_update=False)
post_ms = widgets.FloatSlider(description='Post (ms)', min=0.0, max=200.0, step=0.5, value=10.0, disabled=True, continuous_update=False)
run_btn = widgets.Button(description='Run analysis', button_style='success', disabled=True)

# One Output area per step so each plot can be cleared and redrawn on its own.
status = widgets.Output()
out_raw = widgets.Output()
out_base = widgets.Output()
out_filt = widgets.Output()
out_detect = widgets.Output()
out_snips = widgets.Output()
out_info = widgets.Output()

# Page layout: load row, controls, then one titled panel per processing step.
display(Markdown("### Load file"), widgets.HBox([upload, load_btn]), status)
display(Markdown("### Controls"), widgets.HBox([sweep, baseline]),
        widgets.HBox([cutoff, polarity, thr]), widgets.HBox([pre_ms, post_ms, run_btn]))
display(Markdown("### Step 1: Raw trace"), out_raw)
display(Markdown("### Step 2: After baseline subtraction"), out_base)
display(Markdown("### Step 3: After 8‑pole Bessel low‑pass"), out_filt)
display(Markdown("### Step 4: Detected peaks (on filtered signal)"), out_detect)
display(Markdown("### Step 5: Snippets overlay"), out_snips, out_info)

# Shared state: _load_clicked stores the ABF object, its sampling rate
# (abf.dataRate, Hz) and axis labels; the plotting/analysis callbacks read them.
state = {'abf': None, 'fs': None, 'xlab': 'time (s)', 'ylab': 'amplitude'}

def _get_sweep():
    """Activate the sweep selected on the slider and return ``(t, y)``.

    Returns copies of the ABF object's time and data arrays, so in-place
    changes made downstream never touch the arrays the ABF object holds.
    """
    rec = state['abf']
    rec.setSweep(int(sweep.value))
    t = rec.sweepX.copy()
    y = rec.sweepY.copy()
    return t, y

def _plot_raw():
    """Draw the selected sweep, unprocessed, into the Step 1 panel.

    Does nothing until a file has been loaded.
    """
    if state['abf'] is not None:
        times, values = _get_sweep()
        heading = f"Sweep {sweep.value} — Raw"
        with out_raw:
            # wait=True avoids flicker: the old plot stays until the new one draws.
            clear_output(wait=True)
            plt.figure()
            plt.plot(times, values)
            plt.xlabel(state['xlab'])
            plt.ylabel(state['ylab'])
            plt.title(heading)
            plt.show()

def _update_thr_range(yF):
    """Fit the threshold slider to the filtered trace and enable detection controls.

    Bounds become the data range of ``yF`` padded by 5 % of its span on each
    side, with the step set to 1/500 of the new range. The value is reset to
    0.0; the slider clamps it if 0 falls outside the new range. Also enables
    the threshold and polarity controls.
    """
    ymin, ymax = float(np.min(yF)), float(np.max(yF))
    pad = 0.05*(ymax-ymin+1e-12)  # tiny epsilon keeps a non-zero span for flat traces
    new_min, new_max = ymin - pad, ymax + pad
    # ipywidgets bounded sliders reject min > max at assignment time. When the
    # new range lies entirely above the current max, move max first; otherwise
    # move min first. Either order then never passes through an inverted range.
    if new_min > thr.max:
        thr.max = new_max
        thr.min = new_min
    else:
        thr.min = new_min
        thr.max = new_max
    thr.step = (thr.max-thr.min)/500.0
    thr.value = 0.0
    thr.disabled = False
    polarity.disabled = False

def _run(_b=None):
    """Run the analysis pipeline on the selected sweep and redraw Steps 2-5.

    Stages:
      1. Subtract the mean of the baseline window.
      2. Low-pass with ``bessel_lowpass``.
      3. Detect threshold crossings on the filtered trace.
      4. Cut fixed-width snippets around each detection.

    Each stage plots into its own Output widget, and a one-line summary is
    printed to ``out_info``.

    ``_b`` receives the clicked Button from ``on_click`` and is unused; the
    default allows calling ``_run()`` directly. Does nothing until a file has
    been loaded.
    """
    if state['abf'] is None:
        return
    fs = state['fs']
    t, y = _get_sweep()

    # Baseline: subtract the mean of the samples inside the selected window.
    # A window containing no samples leaves the trace unshifted.
    t0, t1 = baseline.value
    mask = (t>=t0) & (t<=t1)
    base = float(np.mean(y[mask])) if np.any(mask) else 0.0
    y0 = y - base
    with out_base:
        clear_output(wait=True)
        plt.figure()
        plt.plot(t, y0)
        plt.xlabel(state['xlab']); plt.ylabel(state['ylab'])
        plt.title(f"Baseline-subtracted (mean over [{t0:.3f}, {t1:.3f}] s = {base:.3g})")
        plt.show()

    # Filter (zero-phase forward-backward Bessel low-pass)
    yF = bessel_lowpass(y0, fs, cutoff.value, order=8)
    with out_filt:
        clear_output(wait=True)
        plt.figure()
        plt.plot(t, yF)
        plt.xlabel(state['xlab']); plt.ylabel(state['ylab'])
        plt.title(f"8‑pole Bessel low‑pass (cutoff={cutoff.value:.1f} Hz)")
        plt.show()

    # Detector. The threshold slider is disabled after every load; use that as
    # the "first run" signal to fit its range to this filtered trace.
    if thr.disabled:
        _update_thr_range(yF)
    peaks = detect_peaks_threshold(yF, thr.value, fs, polarity.value, lookahead_ms=2.0, min_dist_ms=1.0)
    with out_detect:
        clear_output(wait=True)
        plt.figure()
        plt.plot(t, yF)
        plt.axhline(thr.value, linestyle='--')
        if len(peaks):
            plt.scatter(t[peaks], yF[peaks], s=20, marker='o')
        plt.xlabel(state['xlab']); plt.ylabel(state['ylab'])
        plt.title(f"Threshold detections: {len(peaks)} ({polarity.value}-going)")
        plt.show()

    # Snippets are cut from the filtered trace. snippets_around drops events
    # whose window would run past either end of the sweep, so `kept` may be
    # shorter than `peaks`.
    snips, kept, pre, post = snippets_around(yF, peaks, fs, pre_ms.value, post_ms.value)
    snip_t = np.arange(-pre, post+1)/fs  # time axis relative to each peak (s)
    with out_snips:
        clear_output(wait=True)
        plt.figure()
        if snips.size:
            for row in snips:
                plt.plot(snip_t, row)
            plt.axvline(0.0, linestyle='--')
            plt.xlabel('Time around peak (s)'); plt.ylabel(state['ylab'])
            plt.title(f"Snippets: {snips.shape[0]} × {snips.shape[1]} samples")
        else:
            plt.title("No snippets (adjust threshold or windows)")
            plt.xlabel('Time (s)'); plt.ylabel(state['ylab'])
        plt.show()

    with out_info:
        clear_output(wait=True)
        print(f"Peaks kept: {len(kept)} | Pre={pre_ms.value} ms, Post={post_ms.value} ms")

def _get_upload_name_and_bytes(upload):
    """
    Extract ``(name, content)`` of the first uploaded file from a FileUpload.

    Handles both payload layouts:
      * ipywidgets 7.x - ``value`` is a dict ``{filename: {'content': ...}}``
      * ipywidgets 8.x - ``value`` is a tuple/list whose items are dicts with
        'name'/'content' keys or objects with ``name``/``content`` attributes
        (missing name -> "uploaded", missing content -> None)

    Returns ``(None, None)`` when nothing is uploaded or ``value`` is missing
    or of an unrecognized type.
    """
    value = getattr(upload, "value", None)

    if isinstance(value, dict):
        # ipywidgets 7.x: take the first (filename, info) pair, if any.
        for fname, meta in value.items():
            return fname, meta.get("content", None)
        return None, None

    if isinstance(value, (tuple, list)):
        if not value:
            return None, None
        first = value[0]
        if isinstance(first, dict):
            return first.get("name", "uploaded"), first.get("content", None)
        return getattr(first, "name", "uploaded"), getattr(first, "content", None)

    # None, or a payload shape this helper does not understand.
    return None, None


def _load_clicked(_b):
    """Handle the Load button: parse the uploaded ABF and initialise the UI.

    Steps:
      * Copy the uploaded bytes into a fresh temporary directory, because
        ``pyabf.ABF`` is given a file path.
      * Store the ABF object, its sampling rate and axis labels in ``state``.
      * Reset the sweep, baseline and cutoff controls for the new file.
      * Disable the threshold/polarity controls until the next analysis run.
      * Draw the raw trace.

    Missing upload data or a parse failure is reported in ``status``.
    ``_b`` is the clicked Button (unused).
    """
    import shutil  # only needed to clean up after a failed parse

    with status:
        clear_output()
        print("Loading...")

    name, raw = _get_upload_name_and_bytes(upload)
    if not name or raw is None:
        with status:
            clear_output()
            print("No file data found. Make sure a file appears in the widget before clicking Load.")
        return

    # Best-effort reset so selecting the same file again fires a new upload
    # event. Some ipywidgets versions may reject the assignment; that is not
    # needed for loading, so it is deliberately ignored.
    try:
        upload.value = () if isinstance(upload.value, tuple) else {}
    except Exception:
        pass

    # The file name comes from the browser and is untrusted. Keep only its
    # final path component, otherwise an absolute path or '..' segments would
    # make os.path.join write outside the temporary directory.
    safe_name = os.path.basename(str(name).replace('\\', '/'))
    if safe_name in ('', '.', '..'):
        safe_name = 'uploaded.abf'
    tmp = tempfile.mkdtemp(prefix='abf_')
    path = os.path.join(tmp, safe_name)
    with open(path, 'wb') as f:
        f.write(raw)

    try:
        abf = pyabf.ABF(path)
    except Exception as e:
        # Nothing refers to the copy after a failed parse, so remove it.
        # On success the directory is kept: the ABF object was opened from it.
        shutil.rmtree(tmp, ignore_errors=True)
        with status:
            clear_output()
            print("Failed:", e)
        return

    state['abf'] = abf
    state['fs'] = float(abf.dataRate)
    state['xlab'] = getattr(abf, 'sweepLabelX', 'time (s)')
    state['ylab'] = getattr(abf, 'sweepLabelY', 'amplitude')

    sweep.max = abf.sweepCount - 1
    sweep.disabled = False

    # Init baseline slider from sweep 0: span the whole sweep and preselect
    # the first 5 % of it (at least 0.05 s, capped at the sweep end).
    abf.setSweep(0)
    t = abf.sweepX
    baseline.min = float(t[0])
    baseline.max = float(t[-1])
    span = max(0.05*(t[-1]-t[0]), 0.05)
    baseline.value = (float(t[0]), float(min(t[0]+span, t[-1])))
    baseline.disabled = False

    # Keep the cutoff slider below Nyquist (0.45*fs), between 10 Hz and 10 kHz.
    cutoff.max = max(10.0, min(10000.0, 0.45*state['fs']))
    cutoff.value = min(200.0, cutoff.max)
    cutoff.disabled = False

    # Re-disabling thr makes the next _run refit its range to the new data.
    polarity.disabled = True
    thr.disabled = True
    pre_ms.disabled = False
    post_ms.disabled = False
    run_btn.disabled = False

    with status:
        clear_output()
        print(f"Loaded {name}: sweeps={abf.sweepCount}, fs={abf.dataRate} Hz, duration={abf.sweepLengthSec:.3f} s")

    _plot_raw()

# Wire callbacks: Load parses the uploaded file; changing the sweep redraws
# only the raw-trace panel (Step 1); Run recomputes and redraws Steps 2-5.
load_btn.on_click(_load_clicked)
sweep.observe(lambda ch: _plot_raw(), names='value')  # ch (change dict) unused
run_btn.on_click(_run)

