In [5]:
import numpy as np
from numpy.linalg import lstsq
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
from sys import path
path.append('..')
from oasis.functions import gen_data, gen_sinusoidal_data, deconvolve, estimate_parameters
from oasis.plotting import simpleaxis
from oasis.oasis_methods import oasisAR1, oasisAR2
from scipy.signal import find_peaks
from scipy.optimize import curve_fit

#%matplotlib inline
%matplotlib qt

## Load data low-pass filtered Minian raw output (no temporal deconvolution update), zarr format

In [7]:
# --- load filtered traces (all units) ---
dpath = "./minian_intermediate"

C_ds = xr.open_zarr(f"{dpath}/C_lowpass.zarr", consolidated=False)
C_all = C_ds["C_lp"].astype(np.float32)

print("Before curation:", C_all.sizes)

# --- load your curation CSV ---
curation_path = "./curation_results.csv"

curation_df = pd.read_csv(curation_path)
keep_ids = curation_df.loc[curation_df["keep"] == 1, "unit_id"].values

print(f"Keeping {len(keep_ids)} curated units out of {len(curation_df)}")
print("Curated unit IDs:", keep_ids.tolist())

# --- filter down to curated units ---
C = C_all.sel(unit_id=keep_ids)

print("After curation:", C.sizes)


print("Vars:", list(C_ds.data_vars))
print("Shape:", C.shape, "| Dims:", C.dims)
print("Units:", C.sizes["unit_id"], "Frames:", C.sizes["frame"])

Before curation: Frozen({'unit_id': 114, 'frame': 6000})
Keeping 3 curated units out of 3
Curated unit IDs: [9, 10, 11]
After curation: Frozen({'unit_id': 3, 'frame': 6000})
Vars: ['C_lp']
Shape: (3, 6000) | Dims: ('unit_id', 'frame')
Units: 3 Frames: 6000


## Implementing functions to measure rise and fall time

estimates rise times and decay time constants (τ) from calcium imaging data (e.g., GCaMP6f). The goal is to extract biophysically meaningful kinetics while staying robust to noise and outliers

1. Event Detection
dettrend trace against baseline

- estimate noise (MAD)
- peaks are called using `scipy.find_peaks`
- which requires **Promience** and **Seperation**
- extract a small window around the event (pre + post)

2. Rise Time (10%-90%)
- fluorescence trace is normalized to peak amplitude
- find the time when the trace crosses **10%** and **90%** of the peak.
- Checks: requires both crossings to exist and be in the correct order
   
3. Decay Constant
- fit the falling phase with a log-linear exponential mode
- Two options for when to stop fitting:
  
      tail_to="baseline" → fit until trace returns near baseline
      tail_to="frac" → fit until a fixed fraction of peak (e.g., 30%)
  
  
- avoid over-fitting noise:
  
       Stop fitting once signal < noise_k × MAD.
	   Require at least min_tail_pts samples.
	   Only accept fits with R² ≥ r2_min on the log-scale.
- apply tau_cap to drop or clip unrealistically long decays
  
4. Outlier & Noise Handling
- MAD multiplied with noise_k. Prevents fitting tails eblow detetcable level
- Goddness of fit R², rject fits that dont look exponetial
- tau_cap avoids rare run-away fits biasing the median
- max_events_per_unit keeps plots readable and stats balanced

   
5. Group summary
- Reported statistics
- Median ± IQR (25–75%) for rise and decay.
- Count of valid events contributing to each.
- These can then be converted into AR(2) parameters (g) for deconvolution (OASIS).

In [8]:
def measure_event_kinetics(
    t, y, peak_idx, base_val,
    rise_lo=0.10, rise_hi=0.90,
    tail_to="baseline",          # "baseline" (preferred) or "frac"
    tail_frac=0.10,              # only used if tail_to == "frac"
    min_tail_pts=8,              # need enough points to fit a line
    r2_min=0.85,                 # require decent log-linear fit
    noise_k=2.0,                 # exclude tail once below noise floor = k * MAD
    max_tau_s=None               # e.g., 3.0 to drop >3s, or None to keep all
):
    """
    Measure rise time (10–90%) and decay time constant τ for one event window.

    t, y       : 1D arrays for the event window (seconds and fluorescence)
    peak_idx   : index (in this window) of the event peak
    base_val   : baseline value for this unit/segment (same units as y)

    Returns:
      dict(rise_t10_90_s, tau_decay_s, A_peak, ok_rise, ok_tau)
    """
    # amplitude at peak
    A = float(y[peak_idx] - base_val)
    if not np.isfinite(A) or A <= 1e-9:
        return dict(rise_t10_90_s=np.nan, tau_decay_s=np.nan, A_peak=0.0,
                    ok_rise=False, ok_tau=False)

    # ---------- Rise (10→90%) ----------
    yn = (y - base_val) / A
    tn = t

    def _t_at_level(tseg, yseg, level):
        above = yseg >= level
        if not np.any(above):
            return np.nan
        i = np.argmax(above)
        if i == 0:
            return tseg[0]
        # linear interpolation between i-1 and i
        t0, t1 = tseg[i-1], tseg[i]
        y0, y1 = yseg[i-1], yseg[i]
        if y1 == y0:
            return t1
        return t0 + (level - y0) * (t1 - t0) / (y1 - y0)

    t10 = _t_at_level(tn[:peak_idx+1], yn[:peak_idx+1], rise_lo)
    t90 = _t_at_level(tn[:peak_idx+1], yn[:peak_idx+1], rise_hi)
    ok_rise = np.isfinite(t10) and np.isfinite(t90) and (t90 > t10)
    rise_t = (t90 - t10) if ok_rise else np.nan

    # ---------- Decay τ (baseline-to-peak tail) ----------
    # Use the original (not normalized) trace for stability.
    z = np.maximum(y - base_val, 0.0)

    # noise floor from pre-peak tail: MAD × 1.4826
    pre = z[max(0, peak_idx-20):peak_idx]
    mad = np.median(np.abs(pre - np.median(pre))) * 1.4826 if pre.size else 0.0
    noise_floor = noise_k * mad

    # decide tail stop value
    if tail_to == "baseline":
        stop_level = max(noise_floor, tail_frac * A)  # don't go below noise
    else:
        stop_level = max(tail_frac * A, noise_floor)

    # tail region: from peak_idx forward until z <= stop_level
    end = peak_idx + 1
    while end < len(z) and z[end] > stop_level:
        end += 1

    # need enough samples
    if (end - peak_idx) < min_tail_pts:
        return dict(rise_t10_90_s=rise_t, tau_decay_s=np.nan, A_peak=A,
                    ok_rise=ok_rise, ok_tau=False)

    tt = t[peak_idx:end]
    zz = z[peak_idx:end]

    # ensure strictly positive for log
    mask = zz > max(noise_floor, 1e-9)
    if np.count_nonzero(mask) < min_tail_pts:
        return dict(rise_t10_90_s=rise_t, tau_decay_s=np.nan, A_peak=A,
                    ok_rise=ok_rise, ok_tau=False)

    tt = tt[mask]
    zz = zz[mask]

    # log-linear fit: log(zz) = c - (1/τ) * t
    X = np.vstack([np.ones_like(tt), -tt]).T
    ylog = np.log(zz)
    # least squares
    beta, *_ = lstsq(X, ylog, rcond=None)   # [c, 1/τ]
    inv_tau = beta[1]
    if inv_tau <= 0 or not np.isfinite(inv_tau):
        return dict(rise_t10_90_s=rise_t, tau_decay_s=np.nan, A_peak=A,
                    ok_rise=ok_rise, ok_tau=False)
    tau = 1.0 / inv_tau

    # R^2 on log scale to verify exponentiality
    yhat = X @ beta
    ss_res = np.sum((ylog - yhat)**2)
    ss_tot = np.sum((ylog - np.mean(ylog))**2) + 1e-12
    r2 = 1.0 - ss_res/ss_tot

    ok_tau = np.isfinite(tau) and (tau > 0) and (r2 >= r2_min)
    if not ok_tau:
        tau = np.nan

    # optional cap/drop for outliers
    if ok_tau and (max_tau_s is not None) and (tau > max_tau_s):
        # Drop (return NaN) rather than clamp to avoid biasing medians upward
        tau = np.nan
        ok_tau = False

    return dict(rise_t10_90_s=rise_t, tau_decay_s=tau, A_peak=A,
                ok_rise=ok_rise, ok_tau=ok_tau)

In [9]:
# ---------- helpers ----------
def robust_baseline(y, q=10):
    return np.percentile(y, q)

def simple_lowpass(y, wlen_sec=0.3, fps=20.0):
    w = max(3, int(round(wlen_sec * fps)) | 1)   # odd window
    k = np.ones(w, dtype=float) / w
    return np.convolve(y, k, mode="same")

def detect_isolated_events(y, fps, min_prom_sigma=3.5, min_separation_s=1.2,
                           pre_s=0.5, post_s=2.0, smooth_sec=None):
    """
    Returns (events, baseline), where events is a list of (start, peak, end) indices.
    min_prom_sigma is relative to a robust noise estimate (MAD).
    """
    if smooth_sec is not None:
        y_use = simple_lowpass(y, wlen_sec=smooth_sec, fps=fps)
    else:
        y_use = y

    base = robust_baseline(y_use, q=10)
    z = y_use - base

    # robust noise estimate
    mad = np.median(np.abs(z - np.median(z))) + 1e-12
    sig = 1.4826 * mad
    prom = max(1e-6, float(min_prom_sigma) * sig)

    distance = int(round(min_separation_s * fps))
    peaks, props = find_peaks(z, prominence=prom, distance=distance)

    events = []
    pre = int(round(pre_s * fps))
    post = int(round(post_s * fps))
    n = len(y)
    for p in peaks:
        s = max(0, p - pre)
        e = min(n, p + post)
        if e - s >= 5:
            events.append((s, p, e))
    return events, base

def frac_time_to_level(t, y, level):
    above = y >= level
    if not np.any(above):
        return np.nan
    idx = np.argmax(above)
    if idx == 0:
        return t[0]
    t0, t1 = t[idx-1], t[idx]
    y0, y1 = y[idx-1], y[idx]
    if y1 == y0:
        return t1
    return t0 + (level - y0) * (t1 - t0) / (y1 - y0)

def _exp_decay(t, A, tau, B):
    return A * np.exp(-(t - t[0]) / max(1e-9, tau)) + B

def measure_event_kinetics(t, y, peak_idx, base_val,
                           rise_lo=0.1, rise_hi=0.9,
                           tail_to="baseline",    # "baseline" or "frac"
                           tail_frac=0.50,        # only used if tail_to=="frac"
                           min_tail_pts=8, r2_min=0.85,
                           noise_k=2.0, max_tau_s=None):
    """
    Compute rise t10→90 and exponential tau on the decay.
    Returns dict with keys: rise_t10_90_s, tau_decay_s, ok_rise, ok_tau.
    """
    A = y[peak_idx] - base_val
    if A <= 1e-12:
        return dict(rise_t10_90_s=np.nan, tau_decay_s=np.nan,
                    ok_rise=False, ok_tau=False)

    yn = (y - base_val) / A
    tn = t

    # Rise
    t10 = frac_time_to_level(tn[:peak_idx+1], yn[:peak_idx+1], rise_lo)
    t90 = frac_time_to_level(tn[:peak_idx+1], yn[:peak_idx+1], rise_hi)
    ok_rise = np.isfinite(t10) and np.isfinite(t90) and (t90 > t10)
    rise_t = (t90 - t10) if ok_rise else np.nan

    # Decide where to stop decay fit
    if tail_to == "baseline":
        # go until the trace returns to near-baseline (use small band around 0)
        end_idx = peak_idx + 1
        while end_idx < len(yn) and yn[end_idx] > 0.02:
            end_idx += 1
    else:  # "frac"
        end_idx = peak_idx + 1
        while end_idx < len(yn) and yn[end_idx] > float(tail_frac):
            end_idx += 1

    ok_tau = False
    tau = np.nan
    if end_idx - peak_idx >= max(5, min_tail_pts):
        tx = tn[peak_idx:end_idx]
        yx = y[peak_idx:end_idx]
        try:
            p0 = (A, 1.0, base_val)
            popt, _ = curve_fit(_exp_decay, tx, yx, p0=p0, maxfev=5000)
            tau = float(popt[1])
            if max_tau_s is not None and np.isfinite(tau) and tau > max_tau_s:
                ok_tau = False
            else:
                # quick R^2 check
                yhat = _exp_decay(tx, *popt)
                ss_res = np.sum((yx - yhat)**2)
                ss_tot = np.sum((yx - np.mean(yx))**2) + 1e-12
                r2 = 1.0 - ss_res/ss_tot
                ok_tau = (r2 >= r2_min)
        except Exception:
            ok_tau = False

    return dict(rise_t10_90_s=rise_t, tau_decay_s=tau, ok_rise=ok_rise, ok_tau=ok_tau)

In [10]:
def plot_transients_with_summary(
    C, units, fps, *,
    n_examples=10,                 # how many random units to show (total)
    window=2.0,                    # seconds after peak to include when aligning
    min_prom_sigma=3.5,            # -> detect_isolated_events
    min_separation_s=1.2,          # -> detect_isolated_events
    smooth_sec=None,               # -> detect_isolated_events (None if C already low-pass)
    max_events_per_unit=None,      # cap #events per unit in plots/stats
    random_seed=0,                 # for reproducible event subsampling / unit picks

    # KINETICS (passed to measure_event_kinetics for every event)
    tail_to="baseline",            # "baseline" or "frac"
    tail_frac=0.30,                # used iff tail_to=="frac"
    min_tail_pts=8,
    r2_min=0.85,
    noise_k=2.0,
    tau_cap=None,                  # e.g. 3.0 (seconds). If None, no cap/drop.
    drop_capped=True,              # if True drop tau>cap; else clip to cap

    # SUMMARY RETURN
    return_stats=False,            # if True, return summary dict(s)
    summarize="per_group"          # "per_group" or "all"
):
    """
    Plot aligned calcium transients + separate rise/decay histograms for random units.
    Requires helpers:
      - detect_isolated_events(y, fps, ...)
      - measure_event_kinetics(t, y, peak_idx, base_val, ...)

    Returns:
      None, or list of summary dicts if return_stats=True.
    """
    assert summarize in ("per_group", "all")
    rng = np.random.default_rng(random_seed)

    # pick units and split into 2 figures for readability
    chosen_units = np.random.choice(units, size=n_examples, replace=False)
    groups = np.array_split(chosen_units, 2)

    summaries = []
    all_rise, all_decay = [], []

    for fig_idx, group in enumerate(groups, start=1):
        # --- three columns per unit: aligned • rise-only • decay-only ---
        fig, axes = plt.subplots(len(group), 3, figsize=(15, 2.6*len(group)))
        if len(group) == 1:
            axes = np.array([axes])  # ensure 2D

        # per-figure accumulators (for suptitle + optional per_group stats)
        rise_all, decay_all = [], []

        for i, uid in enumerate(group):
            y = C.sel(unit_id=int(uid)).values.astype(float)
            t = np.arange(len(y)) / float(fps)

            # detect candidate events (returns list of (start, peak, end) idx)
            events, base = detect_isolated_events(
                y, fps,
                min_prom_sigma=min_prom_sigma,
                min_separation_s=min_separation_s,
                pre_s=0.5, post_s=window,
                smooth_sec=smooth_sec
            )

            # optionally cap #events per unit
            if max_events_per_unit is not None and len(events) > max_events_per_unit:
                keep = np.sort(rng.choice(len(events), size=max_events_per_unit, replace=False))
                events = [events[j] for j in keep]

            aligned = []
            rises, decays = [], []

            for s, p, e in events:
                seg_t = t[s:e]
                seg_y = y[s:e]

                kin = measure_event_kinetics(
                    seg_t, seg_y, peak_idx=p - s, base_val=base,
                    rise_lo=0.10, rise_hi=0.90,
                    tail_to=tail_to, tail_frac=tail_frac,
                    min_tail_pts=min_tail_pts, r2_min=r2_min,
                    noise_k=noise_k, max_tau_s=(tau_cap if drop_capped else None)
                )
                # normalize segment for plotting (peak -> 1)
                A = max(seg_y[p - s] - base, 1e-9)
                aligned.append((seg_t - seg_t[p - s], (seg_y - base) / A))

                # collect metrics (apply clip if we keep outliers)
                if kin.get("ok_rise", False) and np.isfinite(kin["rise_t10_90_s"]):
                    rises.append(float(kin["rise_t10_90_s"]))

                if kin.get("ok_tau", False) and np.isfinite(kin["tau_decay_s"]):
                    tau = float(kin["tau_decay_s"])
                    if (tau_cap is not None) and (not drop_capped):
                        tau = min(tau, float(tau_cap))
                    decays.append(tau)

            # accumulate across unit, across figure, across all
            rise_all.extend(rises);  decay_all.extend(decays)
            all_rise.extend(rises);  all_decay.extend(decays)

            # --- PLOTTING ---
            ax1, ax_rise, ax_decay = axes[i]

            # 1) aligned transients
            for tt, yy in aligned:
                ax1.plot(tt, yy, lw=0.8, alpha=0.65)
            ax1.axvline(0, color='k', lw=0.6)
            ax1.set_xlim(-0.5, window)
            ax1.set_ylim(0, 1.2)
            ax1.set_title(f"Unit {int(uid)}: aligned (n={len(aligned)})")
            ax1.set_xlabel("Time (s)")

            # 2) rise-only histogram
            ax_rise.cla()
            if len(rises):
                r = np.asarray(rises, dtype=float)
                r_max = max(0.2, float(np.nanpercentile(r, 99)))
                ax_rise.hist(r, bins=np.linspace(0, r_max, 15), alpha=0.75, color="tab:blue")
                ax_rise.set_xlim(0, r_max)
            ax_rise.set_title("Rise 10–90%")
            ax_rise.set_xlabel("Seconds")

            # 3) decay-only histogram
            ax_decay.cla()
            if len(decays):
                d = np.asarray(decays, dtype=float)
                d_max = max(0.2, float(np.nanpercentile(d, 99)))
                ax_decay.hist(d, bins=np.linspace(0, d_max, 15), alpha=0.75, color="orange")
                ax_decay.set_xlim(0, d_max)
            ax_decay.set_title("Decay τ")
            ax_decay.set_xlabel("Seconds")

        # helper for median ± IQR text
        def _txt(x):
            if not x: return "n/a"
            q25, med, q75 = np.percentile(x, [25, 50, 75])
            return f"{med:.2f} s (IQR {q25:.2f}–{q75:.2f})"

        plt.suptitle(
            f"Group {fig_idx}  |  Rise: {_txt(rise_all)}   •   Decay: {_txt(decay_all)}",
            y=1.02
        )
        plt.tight_layout()
        plt.show()

        # per-figure stats (optional)
        if return_stats and summarize == "per_group":
            def _stats(x):
                if not x: return None, (None, None), 0
                q25, med, q75 = np.percentile(x, [25, 50, 75])
                return float(med), (float(q25), float(q75)), int(len(x))
            r_med, r_iqr, r_n = _stats(rise_all)
            d_med, d_iqr, d_n = _stats(decay_all)
            summaries.append({
                "group": fig_idx,
                "units": list(map(int, group)),
                "rise_median": r_med, "rise_iqr": r_iqr, "n_rise": r_n,
                "decay_median": d_med, "decay_iqr": d_iqr, "n_decay": d_n,
            })

    # single aggregated stats (optional)
    if return_stats and summarize == "all":
        def _stats(x):
            if not x: return None, (None, None), 0
            q25, med, q75 = np.percentile(x, [25, 50, 75])
            return float(med), (float(q25), float(q75)), int(len(x))
        r_med, r_iqr, r_n = _stats(all_rise)
        d_med, d_iqr, d_n = _stats(all_decay)
        return [{
            "group": "all",
            "units": list(map(int, chosen_units)),
            "rise_median": r_med, "rise_iqr": r_iqr, "n_rise": r_n,
            "decay_median": d_med, "decay_iqr": d_iqr, "n_decay": d_n,
        }]

    return summaries if return_stats else None

## Run rise & decay time constant functions on dataset

In [12]:
fps = 20.0


CONFIG = dict(
    n_examples=3,              # number of random units
    random_seed=0,              # random unit selcter on/off
    max_events_per_unit=50,     # calcium events per unit
    window=2.5,                 # time after peak include oin aligned plot

    #Peak detection
    min_prom_sigma=4,           # peak detection: minimum prominence in units of noise
    min_separation_s=1.2,       # peak detection: minimum seperation between different events (1.2s)
    smooth_sec=None,            # optinal pre smooting window lenght in s

    #Decay fit
    tail_to="baseline",             # or "frac" fit decay: "baseline" (down to baseline) or "frac" (to a fraction of peak)
    tail_frac=0.20,             # if tail_to="frac": fraction of peak height (e.g. 0.3 = 30%)
    min_tail_pts=8,             # minimum number of data points required in the decay tail for fitting
    r2_min=0.9,                # minimum R² for exponential decay fit to be considered valid
    noise_k=2.0,                # noise robustness factor for event detection (higher = stricter detection)

    #Outlier control for stats
    tau_cap=3.0,                # maximum allowed decay time constant [s] (anything longer will be clipped/dropped)
    drop_capped=True           # True = drop events with τ > tau_cap, False = include but clip to tau_cap
)

stats = plot_transients_with_summary(
    C, list(C.unit_id.values), fps=fps,
    return_stats=True, summarize="all",
    **{k:v for k,v in CONFIG.items() if k in
       ["n_examples","window","min_prom_sigma","min_separation_s","smooth_sec",
        "max_events_per_unit","random_seed","tau_cap","drop_capped"]}
)

print("CONFIG:", CONFIG)
print("STATS:", stats)

CONFIG: {'n_examples': 3, 'random_seed': 0, 'max_events_per_unit': 50, 'window': 2.5, 'min_prom_sigma': 4, 'min_separation_s': 1.2, 'smooth_sec': None, 'tail_to': 'baseline', 'tail_frac': 0.2, 'min_tail_pts': 8, 'r2_min': 0.9, 'noise_k': 2.0, 'tau_cap': 3.0, 'drop_capped': True}
STATS: [{'group': 'all', 'units': [9, 11, 10], 'rise_median': 0.3102955952162745, 'rise_iqr': (0.29367875722680914, 0.3337003992468901), 'n_rise': 34, 'decay_median': 0.6316911256056117, 'decay_iqr': (0.5651564764884391, 0.7992108721981664), 'n_decay': 11}]


In [13]:
def tau_to_g(tau_d, tau_r, fps, in_seconds=False):
    if in_seconds:
        tau_d *= fps; tau_r *= fps
    r1 = np.exp(-1.0/float(tau_d)); r2 = np.exp(-1.0/float(tau_r))
    return np.array([r1 + r2, -r1*r2], dtype=np.float64)

In [14]:
decay_med_s   = stats[0]["decay_median"]
decay_q25_s, decay_q75_s = stats[0]["decay_iqr"]
rise_med_s    = stats[0]["rise_median"]

# seconds → frames
to_frames = lambda s: max(1.0, s * fps)
tau_d_med  = to_frames(decay_med_s)
tau_d_q25  = to_frames(decay_q25_s)
tau_d_q75  = to_frames(decay_q75_s)
tau_r_med  = to_frames(rise_med_s)

# compute AR(2) coefficients
g_med  = tau_to_g(tau_d_med, tau_r_med, fps)   # balanced default
g_slow = tau_to_g(tau_d_q75, tau_r_med, fps)   # longer decay
g_fast = tau_to_g(tau_d_q25, tau_r_med, fps)   # shorter decay

print(f"Decay median: {decay_med_s:.3f} s = {tau_d_med:.1f} frames")
print(f"Decay q25:    {decay_q25_s:.3f} s = {tau_d_q25:.1f} frames")
print(f"Decay q75:    {decay_q75_s:.3f} s = {tau_d_q75:.1f} frames")
print(f"Rise median:  {rise_med_s:.3f} s = {tau_r_med:.1f} frames\n")

print("g_med  =", g_med,  " (median decay)")
print("g_slow =", g_slow, " (q75 decay → slower tail)")
print("g_fast =", g_fast, " (q25 decay → faster tail)")

Decay median: 0.632 s = 12.6 frames
Decay q25:    0.565 s = 11.3 frames
Decay q75:    0.799 s = 16.0 frames
Rise median:  0.310 s = 6.2 frames

g_med  = [ 1.77507465 -0.78640034]  (median decay)
g_slow = [ 1.79053082 -0.79955626]  (q75 decay → slower tail)
g_fast = [ 1.76650532 -0.77910633]  (q25 decay → faster tail)


The **median decay** is a robust central tendency but can be biased downward by noise and truncated events.

Using the **q75 decay (upper quartile)** gives a slightly longer τ, ensuring the AR(2) model does not underestimate tail length.
Prevents overfitting to noisy, fast falloffs

rise = median, decay = q75 gives a balanced AR(2) model that avoids underestimating calcium dynamics while remaining robust to noise.

# Deconvolve Low-level AR(2)

## c, s = oasisAR2 (y_in, g1, g2, s_min=smin_u)

c= deconvolved calcium

s= pruned spikes

y= original raw trace

g1,g2= rise and decay time

s_min: built-in spike sparsity/threshold used by oasisAR2 (bigger ⇒ fewer spikes)

- Fixed kinetics: uses your chosen rise/decay (via g)—no re-estimation
- only optimizes spikes (and calcium), not the AR parameters
- two levers—s_min inside OASIS and post_thresh after—to push more/less into the residual vs spikes
- unit_specific_smin tame noisy units without sacrificing others

## just helper functions

In [15]:
# forward recursion of an AR(2) calcium model:
#-- reconstructed calcium trace c that would result from those spikes under the chosen AR(2) dynamics--
def forward_ar2(s, g):
    c = np.zeros_like(s, dtype=float)
    for t in range(2, len(s)):
        c[t] = g[0]*c[t-1] + g[1]*c[t-2] + s[t]
    return c

#_prune_spikes applies flexible thresholds (robust z-score, absolute amplitude, or quantiles) to remove small/noisy inferred spikes
#clean up inferred spike train after deconvolution
# "z": Computes a robust z-score of each spike relative to the median and MAD

# "z_abs" : First tries the z-score rule, If that eliminates everything (too strict) or MAD≈0, 
#it falls back to a simple absolute rule: keep spikes ≥ frac * max(spike amplitude)

# "abs" : spikes whose amplitude ≥ thr

# "q" : Keep only the top (1-q) fraction of spikes, i.e. above the amplitude at quantile q
# None: return spikes as is

def _prune_spikes(s, post_thresh, debug=False):
    if post_thresh is None:
        return s

    mode = post_thresh[0] if isinstance(post_thresh, tuple) else "abs"

    if mode == "z":
        z = float(post_thresh[1])
        nz = s[s > 0]
        if nz.size == 0:
            if debug: print("[prune] z: no nonzero spikes -> keep none")
            return s*0
        mu  = np.median(nz)
        mad = np.median(np.abs(nz - mu))*1.4826
        if mad > 1e-9:
            keep = ((s - mu)/mad) >= z
            if debug: print(f"[prune] z: MAD={mad:.3g}, kept={keep.sum()} / {s.size}")
            return s*keep
        if debug: print("[prune] z: MAD≈0 -> no-op")
        return s

    if mode == "z_abs":
        z, frac = float(post_thresh[1]), float(post_thresh[2])
        nz = s[s > 0]
        if nz.size == 0:
            if debug: print("[prune] z_abs: no nonzero spikes -> keep none")
            return s*0
        mu  = np.median(nz)
        mad = np.median(np.abs(nz - mu))*1.4826
        if mad > 1e-9:
            keep = ((s - mu)/mad) >= z
            if keep.sum() > 0:
                if debug: print(f"[prune] z_abs: z-kept={keep.sum()} / {s.size}")
                return s*keep
            if debug: print("[prune] z_abs: z kept 0 -> fallback abs")
        else:
            if debug: print("[prune] z_abs: MAD≈0 -> fallback abs")
        # absolute fallback
        m = nz.max()
        thr = frac * m
        keep = s >= thr
        if debug: print(f"[prune] z_abs: abs thr={thr:.5g}, kept={keep.sum()} / {s.size}")
        return s*keep

    if mode == "abs":
        thr = float(post_thresh[1])
        keep = s >= thr
        if debug: print(f"[prune] abs: thr={thr:.5g}, kept={keep.sum()} / {s.size}")
        return s*keep

    if mode == "q":
        q = float(post_thresh[1])
        nz = s[s > 0]
        if nz.size == 0:
            if debug: print("[prune] q: no nonzero spikes -> keep none")
            return s*0
        thr = np.quantile(nz, q)
        keep = s >= thr
        if debug: print(f"[prune] q: q={q}, thr={thr:.5g}, kept={keep.sum()} / {s.size}")
        return s*keep

    return s

In [16]:
#low-level OASIS AR(2) wrapper
#deconvolution per unit with fixed kinetics, optional baseline subtraction, and optional post-hoc spike pruning
def run_oasis_ar2(
    C_da, unit_ids, g, s_min=1.0, baseline="p10", fps=20.0,
    post_thresh=("z_abs", 2.5, 0.2),
    unit_specific_smin=None,
    debug_prune=False,
    keep_unpruned=True,     # <-- NEW: keep s before pruning
):
    g = np.asarray(g, dtype=np.float64); g1, g2 = float(g[0]), float(g[1])
    out = {}
    for uid in unit_ids:
        y = np.ascontiguousarray(C_da.sel(unit_id=int(uid)).values, dtype=np.float64)

        # baseline
        if isinstance(baseline, str) and baseline.startswith("p"):
            p = float(baseline[1:]); b = float(np.percentile(y, p))
        elif isinstance(baseline, (int, float)):
            b = float(baseline)
        else:
            b = 0.0

        # unit-level s_min override
        smin_u = unit_specific_smin.get(int(uid), s_min) if unit_specific_smin else s_min

        # OASIS
        c, s = oasisAR2(y - b, g1, g2, s_min=float(smin_u))

        # keep a copy BEFORE pruning
        s_unpruned = s.copy() if keep_unpruned else None

        # post-prune spikes
        s = _prune_spikes(s, post_thresh, debug=debug_prune)

        out[int(uid)] = dict(
            raw=y, c=c, s=s, s_unpruned=s_unpruned,  # <-- store both
            b=b, g=np.array([g1, g2], dtype=np.float64),
            time=np.arange(len(y), dtype=np.float64)/float(fps)
        )
    return out

In [17]:
def plot_fit_split_with_spikes(results, C_lp=None, start_s=200, duration_s=300,
                               fps=20.0, normalize=False, spike_scale="auto"):
    """
    results: dict[unit_id] -> {"raw","c","s","time","g",...}
    C_lp   : optional low-pass traces to overlay on Raw
    spike_scale:
        - "auto": scale spikes to ~max(c) in the shown window (per unit)
        - float: use a fixed multiplier for spike heights (e.g., 1.0 keeps raw s)
    """
    units = list(results.keys())
    s0 = int(round(start_s * fps))
    s1 = int(round((start_s + duration_s) * fps))
    t  = np.arange(s0, s1) / fps

    n = len(units)
    fig, axes = plt.subplots(n, 4, figsize=(16, 2.4*n),
                             gridspec_kw={"wspace": 0.25, "hspace": 0.35},
                             sharex=True)
    if n == 1:
        axes = axes.reshape(1, -1)

    for i, uid in enumerate(units):
        R = results[uid]
        raw = np.asarray(R["raw"])[s0:s1]
        c   = np.asarray(R["c"])[s0:s1]
        s   = np.asarray(R["s"])[s0:s1]

        # reconvolved fit from spikes (same AR(2) as used in results)
        g = np.asarray(R["g"])
        yhat = np.zeros_like(s, dtype=float)
        for k in range(2, len(s)):
            yhat[k] = g[0]*yhat[k-1] + g[1]*yhat[k-2] + s[k]

        if normalize:
            denom = np.nanmax(raw) + 1e-12
            raw = raw / denom; c = c / denom; yhat = yhat / denom

        # indices of spike samples
        idx = np.flatnonzero(s > 0)
        # spike scaling (used for the rightmost "Spikes" panel and to derive overlay heights)
        if spike_scale == "auto":
            cmax = float(np.nanmax(c)) if np.any(np.isfinite(c)) else 1.0
            smax = float(np.nanmax(s)) if idx.size else 1.0
            scale = (0.9 * cmax / smax) if smax > 0 else 1.0
        else:
            scale = float(spike_scale)

        # --- Raw (±low-pass) ---
        ax = axes[i, 0]
        ax.plot(t, raw, color="0.25", lw=0.8, label="Raw")
        if C_lp is not None:
            ax.plot(t, C_lp.sel(unit_id=int(uid)).values[s0:s1],
                    color="tab:blue", lw=0.9, alpha=0.8, label="Low-pass")
        # overlay spikes on Raw (fixed-height stems relative to Raw range)
        if idx.size:
            rmin, rmax = float(np.nanmin(raw)), float(np.nanmax(raw))
            rrange = max(1e-9, rmax - rmin)
            base_raw = rmin + 0.02 * rrange            # start a bit above bottom
            height_raw = 0.15 * rrange                 # ~15% panel height
            ax.vlines(t[idx], base_raw, base_raw + height_raw, colors="tab:green", lw=0.8, alpha=0.9)
        if i == 0:
            ax.set_title("Raw (±low-pass)")
            ax.legend(frameon=False, fontsize=8)
        ax.set_ylabel(f"Unit {uid}")

        # --- Deconvolved (c) ---
        axc = axes[i, 1]
        axc.plot(t, c, color="tab:purple", lw=0.9)
        # overlay spikes on Deconvolved (scaled to c range)
        if idx.size:
            cmin, cmax = float(np.nanmin(c)), float(np.nanmax(c))
            crange = max(1e-9, cmax - cmin)
            base_c = cmin + 0.03 * crange
            height_c = 0.20 * crange                   # ~20% panel height
            axc.vlines(t[idx], base_c, base_c + height_c, colors="tab:green", lw=0.8, alpha=0.9)
        if i == 0:
            axc.set_title("Deconvolved (c)")

        # --- Fit & Residual ---
        res = raw - yhat
        axfr = axes[i, 2]
        axfr.plot(t, yhat, color="tab:red", lw=0.9, label="Reconvolved fit")
        axfr.plot(t, res,  color="0.6",    lw=0.7, label="Residual")
        if i == 0:
            axfr.set_title("Fit & Residual")
            axfr.legend(frameon=False, fontsize=8)

        # --- Spikes (stick plot) ---
        axsp = axes[i, 3]
        if idx.size:
            axsp.vlines(t[idx], 0, s[idx] * scale, colors="tab:green", linewidth=0.8)
        axsp.set_ylim(0, max(1e-6, (np.nanmax(s[idx] * scale) if idx.size else 1.0)))
        if i == 0: axsp.set_title("Spikes (scaled)")

        # cosmetics
        for a in axes[i, :]:
            a.spines["top"].set_visible(False)
            a.spines["right"].set_visible(False)
            a.tick_params(length=3)
        if i == n-1:
            for a in axes[i, :]:
                a.set_xlabel("Time (s)")

    plt.tight_layout()
    plt.show()

In [28]:
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def plot_fit_split_with_spikes_plotly(results, C_lp=None, start_s=200, duration_s=300,
                                      fps=20.0, normalize=False, spikes_height_frac=0.15,
                                      cols=2, html_path="spikes_overlay.html"):
    units = list(results.keys())
    s0 = int(round(start_s*fps)); s1 = int(round((start_s+duration_s)*fps))

    rows = int(np.ceil(len(units)/cols))
    fig = make_subplots(rows=rows, cols=cols,
                        subplot_titles=[f"Unit {u}" for u in units],
                        shared_xaxes=False, shared_yaxes=False)

    for i, u in enumerate(units):
        r = results[u]
        t   = np.arange(s0, s1)/fps
        raw = np.asarray(r["raw"])[s0:s1]
        c   = np.asarray(r["c"])[s0:s1]
        s   = np.asarray(r["s"])[s0:s1]
        idx = np.flatnonzero(s > 0)

        if normalize:
            denom = np.nanmax(raw) + 1e-12
            raw = raw/denom; c = c/denom

        r_i = i//cols + 1
        c_i = i%cols + 1
        showleg = (i == 0)  # only first unit contributes legend entries  # <<

        # Raw (dark gray)
        fig.add_trace(go.Scatter(
            x=t, y=raw, name="Raw", mode="lines",
            line=dict(width=1.5, color="rgba(60,60,60,0.7)"),  # alpha 0.7
            showlegend=showleg
        ), r_i, c_i)


        # Optional low-pass (blue)
        if C_lp is not None:
            lp = C_lp.sel(unit_id=int(u)).values[s0:s1]
            if normalize:
                lp = lp / (np.nanmax(lp) + 1e-12)
            fig.add_trace(go.Scatter(
                x=t, y=lp, name="Low-pass", mode="lines",            # <<
                line=dict(width=1.2, color="rgba(30,120,180,0.8)"),
                showlegend=showleg                                    # <<
            ), r_i, c_i)

        # Deconvolved c (purple)
        fig.add_trace(go.Scatter(
            x=t, y=c, name="Deconvolved c", mode="lines",
            line=dict(width=1.5, color="rgba(148,0,211,0.7)"), # alpha 0.7
            showlegend=showleg
        ), r_i, c_i)

        # Spikes (green stems as shapes)
        if idx.size:
            ymin = float(np.nanmin(raw)); ymax = float(np.nanmax(raw))
            yrng = max(1e-9, ymax - ymin)
            y0 = ymin + 0.02*yrng; y1 = y0 + spikes_height_frac*yrng
            for ti in t[idx]:
                fig.add_shape(
                    type="line", x0=ti, x1=ti, y0=y0, y1=y1,
                    line=dict(color="rgba(0,150,0,1)", width=3),
                    row=r_i, col=c_i
                )
        # Dummy trace so "Spikes" appears in legend once
        if showleg:
            fig.add_trace(go.Scatter(
                x=[None], y=[None], mode="lines",
                line=dict(color="rgba(0,150,0,1)", width=1.2),
                name="Spikes",
                showlegend=True
            ), r_i, c_i)

    fig.update_layout(
        height=700*rows, width=1600,
        showlegend=True,                                            # << turn legend on
        legend=dict(orientation="h", yanchor="bottom", y=1.02,
                    xanchor="left", x=0),                           # << place legend on top
        template="plotly_white",
        title="Raw + c + Spikes (interactive)"
    )
    fig.update_xaxes(title_text="Time (s)")

    fig.write_html(html_path, include_plotlyjs="cdn", full_html=True)
    print(f"Saved interactive HTML to {html_path}")

In [29]:
def summarize_spike_counts(results, fps=20.0, force_frames=None):
    """
    results: dict[unit_id] -> {"s", "s_unpruned", "time", ...}
    fps: frames/sec, used only if force_frames is provided instead of time
    force_frames: optional int; if given, duration = force_frames / fps

    Returns: (df_per_unit, totals_dict)
    """
    rows = []
    for uid, R in results.items():
        s_after   = np.asarray(R["s"])
        s_before  = np.asarray(R.get("s_unpruned")) if ("s_unpruned" in R and R["s_unpruned"] is not None) else None

        n_after  = int((s_after  > 0).sum())
        n_before = int((s_before > 0).sum()) if s_before is not None else None

        # duration (prefer time if present)
        if force_frames is not None:
            duration_s = force_frames / float(fps)
        elif "time" in R and R["time"] is not None:
            duration_s = float(np.nanmax(R["time"]) - np.nanmin(R["time"]))
        else:
            duration_s = len(s_after) / float(fps)

        duration_min = duration_s / 60.0 if duration_s > 0 else np.nan

        rate_after  = n_after  / duration_min if duration_min > 0 else np.nan
        rate_before = (n_before / duration_min) if (n_before is not None and duration_min > 0) else None

        rows.append(dict(
            unit_id=uid,
            spikes_before=n_before,
            spikes_after=n_after,
            reduction=(None if n_before is None else (n_before - n_after)),
            reduction_pct=(None if n_before in (None, 0) else 100.0*(n_before - n_after)/n_before),
            minutes=duration_min,
            rate_before_per_min=rate_before,
            rate_after_per_min=rate_after,
        ))

    df = pd.DataFrame(rows).sort_values("unit_id").reset_index(drop=True)

    # Overall totals (only where we have before)
    has_before = df["spikes_before"].notna()
    total_before = int(df.loc[has_before, "spikes_before"].sum()) if has_before.any() else None
    total_after  = int(df["spikes_after"].sum())
    total_redux  = (None if total_before is None else total_before - total_after)
    total_redux_pct = (None if (total_before in (None, 0)) else 100.0*total_redux/total_before)

    totals = dict(
        total_units=len(df),
        total_spikes_before=total_before,
        total_spikes_after=total_after,
        total_reduction=total_redux,
        total_reduction_pct=total_redux_pct,
        mean_rate_after_per_min=float(df["rate_after_per_min"].mean()),
        mean_rate_before_per_min=(float(df["rate_before_per_min"].mean()) if has_before.any() else None),
    )
    return df, totals

## Run deconvolution

### Choose here the appropiate g from the rise / decay estimation perfomed earlier

In [30]:
print(f"Decay median: {decay_med_s:.3f} s = {tau_d_med:.1f} frames")
print(f"Decay q25 fast:    {decay_q25_s:.3f} s = {tau_d_q25:.1f} frames")
print(f"Decay q75 slow:    {decay_q75_s:.3f} s = {tau_d_q75:.1f} frames")
print(f"Rise median:  {rise_med_s:.3f} s = {tau_r_med:.1f} frames\n")

print("g_med  =", g_med,  " (median decay)")
print("g_slow =", g_slow, " (q75 decay → slower tail)")
print("g_fast =", g_fast, " (q25 decay → faster tail)")

Decay median: 0.632 s = 12.6 frames
Decay q25 fast:    0.565 s = 11.3 frames
Decay q75 slow:    0.799 s = 16.0 frames
Rise median:  0.310 s = 6.2 frames

g_med  = [ 1.77507465 -0.78640034]  (median decay)
g_slow = [ 1.79053082 -0.79955626]  (q75 decay → slower tail)
g_fast = [ 1.76650532 -0.77910633]  (q25 decay → faster tail)


# Change parameters here

In [31]:
curated_units = list(C.unit_id.values)
print(f"Running deconvolution on {len(curated_units)} curated units")

results_all = run_oasis_ar2(
    C, curated_units, g_slow,
    s_min=1.0,
    baseline="p10",
    post_thresh=("q",0.98),
    #unit_specific_smin={13: 2.0, 14: 1.5 },   #  only affects unit 13,   # or add per-unit overrides if needed
    fps=fps,
    debug_prune=True
)

Running deconvolution on 3 curated units
[prune] q: q=0.98, thr=8.8818e-16, kept=80 / 6000
[prune] q: q=0.98, thr=1.1102e-16, kept=31 / 6000
[prune] q: q=0.98, thr=8.2601e-16, kept=29 / 6000


In [32]:
plot_fit_split_with_spikes_plotly(results_all, fps=20.0,
                                  cols=2, html_path="spikes_overlay.html")

Saved interactive HTML to spikes_overlay.html


## save deconvolution results (like minian as zarr)

In [33]:
import xarray as xr\

output_dir = "./output/"
animal_name = "WL25"  # <-- Change this for each animal
session_date = "DEC1"  # Optional: add session info

unit_ids = curated_units

# Extract arrays from results
C_deconv = np.stack([results_all[uid]["c"] for uid in unit_ids])
S = np.stack([results_all[uid]["s"] for uid in unit_ids])
S_unpruned = np.stack([results_all[uid]["s_unpruned"] for uid in unit_ids])

# Create xarray Dataset
ds_oasis = xr.Dataset({
    "C_deconv": (["unit_id", "frame"], C_deconv.astype(np.float32)),
    "S": (["unit_id", "frame"], S.astype(np.float32)),
    "S_unpruned": (["unit_id", "frame"], S_unpruned.astype(np.float32))
}, coords={
    "unit_id": unit_ids,
    "frame": np.arange(C_deconv.shape[1])
})

# Add metadata as attributes
ds_oasis.attrs["fps"] = fps
ds_oasis.attrs["g_params"] = g_slow.tolist()
ds_oasis.attrs["s_min"] = 1.0
ds_oasis.attrs["baseline"] = "p10"
ds_oasis.attrs["post_thresh"] = str(("q", 0.98))

## Save to zarr - NOW includes session_date in filename
zarr_filename = f"{animal_name}_{session_date}_oasis_deconv.zarr"
ds_oasis.to_zarr(f"{output_dir}/{zarr_filename}", mode="w")
print(f"✅ Saved zarr dataset: {output_dir}/{zarr_filename}")

print("\n" + "="*60)
print("SAVE SUMMARY")
print("="*60)
print(f"Animal: {animal_name}")
print(f"Session: {session_date}")
print(f"Units: {len(unit_ids)}")
print(f"Frames: {C_deconv.shape[1]}")
print(f"Arrays saved: C_deconv, S, S_unpruned")
print(f"Filename: {zarr_filename}")
print("="*60)

✅ Saved zarr dataset: ./output//WL25_DEC1_oasis_deconv.zarr

SAVE SUMMARY
Animal: WL25
Session: DEC1
Units: 3
Frames: 6000
Arrays saved: C_deconv, S, S_unpruned
Filename: WL25_DEC1_oasis_deconv.zarr



Consolidated metadata is currently not part in the Zarr format 3 specification. It may not be supported by other zarr implementations and may change in the future.



## Plot normalized

In [None]:
start_s = 200
duration_s = 300

plot_fit_split_with_spikes(results_all, C_lp=None, start_s=start_s, duration_s=duration_s, fps=fps,
                           normalize=True, spike_scale="auto")

plt.savefig(
    f"{output_dir}/{animal_name}_deconv_plots_{int(start_s)}-{int(start_s+duration_s)}s.png",
    dpi=300,
    bbox_inches='tight'
)
print(f"✅ Saved plot: {output_dir}/{animal_name}_deconv_plots_{int(start_s)}-{int(start_s+duration_s)}s.png")

plt.show()

ValueError: x and y must have same first dimension, but have shapes (6000,) and (2000,)

## Load saved deconvolved data
### C_deconv - Deconvolved Calcium
### S - Spike Train (pruned)
### S_unpruned - Spike Train

In [None]:
##Load saved deconvolved data
#C_deconv - Deconvolved Calcium
#S - Spike Train (pruned) like post oasis pruning  ("q", 0.98)
#S_unpruned - Spike Train


# Load the zarr file
zarr_path = "/Users/mbrosch/Documents/GitKraken_mac/OASIS/examples/WL3_JUL25_oasis_deconv.zarr"
ds_oasis = xr.open_zarr(zarr_path)

print(ds_oasis)

## Post result analysis: Residual RMS evaluation

In [None]:
models = {
    "g_fast": g_fast,
    "g_med":  g_med,
    "g_slow": g_slow,
}

def run_for_g(g):
    return run_oasis_ar2(
        C, curated_units, g,
        s_min=1.0,
        baseline="p10",
        post_thresh=None,          # no post-pruning while we compare
        unit_specific_smin=None,
        fps=fps,
        debug_prune=False
    )

results_by_g = {name: run_for_g(g) for name, g in models.items()}

In [None]:
def residual_rms(results, start_s=None, duration_s=None, fps=20.0):
    """Return (overall_mean_rms, per_unit_df). Windowing is optional."""
    rows = []
    for uid, r in results.items():
        raw = np.asarray(r["raw"], dtype=float)
        s   = np.asarray(r["s"],   dtype=float)
        yhat = forward_ar2(s, r["g"])

        if start_s is not None and duration_s is not None:
            s0 = int(max(0, round(start_s * fps)))
            s1 = int(min(len(raw), round((start_s + duration_s) * fps)))
            raw = raw[s0:s1]; yhat = yhat[s0:s1]

        res = raw - yhat
        rms = float(np.sqrt(np.mean(res**2)))
        rows.append({"unit_id": int(uid), "rms": rms})
    df = pd.DataFrame(rows).sort_values("unit_id")
    return float(df["rms"].mean()), df

# compute RMS for each model
rms_overall = {}
rms_tables  = {}
for name, res in results_by_g.items():
    mean_rms, df = residual_rms(res, fps=fps)  # or add start_s/duration_s
    rms_overall[name] = mean_rms
    rms_tables[name]  = df

print("Overall residual RMS:")
for k, v in rms_overall.items():
    print(f"  {k:6s}: {v:.4f}")

## residuals g_fast might be low but just because its also overfitting noise and not signal per se

In [None]:
df_all = rms_tables["g_med"].rename(columns={"rms":"rms_med"})
for k in ("g_fast","g_slow"):
    df_all = df_all.merge(rms_tables[k].rename(columns={"rms":f"rms_{k.split('_')[1]}"}),
                          on="unit_id", how="outer")
df_all["best_model"] = df_all[["rms_fast","rms_med","rms_slow"]].idxmin(axis=1)
df_all

## Metric based QC

In [None]:
import numpy as np
import pandas as pd

# ===================== helpers =====================

def _reconvolve_from_s(s, g):
    """Reconvolve spikes s with AR(2) kernel g=[g1,g2] to get fit yhat."""
    s = np.asarray(s, dtype=float)
    yhat = np.zeros_like(s, dtype=float)
    g = np.asarray(g, dtype=float)
    for k in range(2, len(s)):
        yhat[k] = g[0]*yhat[k-1] + g[1]*yhat[k-2] + s[k]
    return yhat

def _safe_corr(a, b):
    a = np.asarray(a); b = np.asarray(b)
    m = np.isfinite(a) & np.isfinite(b)
    if m.sum() < 3: return np.nan
    a = a[m] - np.nanmean(a[m]); b = b[m] - np.nanmean(b[m])
    da = np.sqrt(np.nanvar(a)); db = np.sqrt(np.nanvar(b))
    if da == 0 or db == 0: return np.nan
    return float(np.dot(a, b) / (len(a) * da * db))

def _explained_variance(y, yhat):
    y = np.asarray(y); yhat = np.asarray(yhat)
    m = np.isfinite(y) & np.isfinite(yhat)
    if m.sum() < 3: return np.nan
    var_y = np.nanvar(y[m])
    if var_y <= 0: return np.nan
    return float(1.0 - np.nanvar(y[m] - yhat[m]) / var_y)

def _eventize_indices(idx, min_gap_frames):
    """Group sorted indices into clusters separated by > min_gap_frames."""
    if idx.size == 0: return []
    splits = np.where(np.diff(idx) > min_gap_frames)[0] + 1
    return np.split(idx, splits)

def _count_calcium_events(trace, fps, z=2.5, min_gap_s=0.5):
    """
    Simple, robust event finder on a calcium-like trace:
    - threshold by median + z * MAD (positive side)
    - merge above-threshold samples with a refractory gap
    - pick the local max in each merged segment as the event
    """
    x = np.asarray(trace, float)
    med = np.nanmedian(x)
    mad = np.nanmedian(np.abs(x - med))*1.4826
    thr = med + z*mad if mad > 1e-12 else med + (np.nanmax(x)-med)*0.25
    nz = np.flatnonzero(x > thr)
    if nz.size == 0:
        return 0, np.array([], dtype=int)
    groups = _eventize_indices(nz, int(round(min_gap_s*fps)))
    peaks = [g[np.argmax(x[g])] for g in groups]
    return len(peaks), np.array(peaks, dtype=int)

def _cluster_spike_frames(s, fps, min_gap_s=0.35, min_cluster_frac=0.10):
    """
    Cluster non-zero spike frames into spike events separated by min_gap_s.
    Drop tiny clusters whose summed amplitude < min_cluster_frac of the largest.
    """
    s = np.asarray(s)
    idx = np.flatnonzero(s > 0)
    if idx.size == 0:
        return 0, []

    groups = _eventize_indices(idx, int(round(min_gap_s*fps)))

    # remove tiny clusters by area
    if len(groups) > 0 and (min_cluster_frac is not None) and (min_cluster_frac > 0):
        areas = [float(s[g].sum()) for g in groups]
        amax = max(areas) if areas else 0.0
        thr = min_cluster_frac * amax
        groups = [g for g, a in zip(groups, areas) if a >= thr]

    return len(groups), groups

# ===================== metrics + tagging =====================

def qc_metrics_from_results_with_events(results_all, fps, start_s=None, duration_s=None,
                                        use_lowpass=False, lowpass_da=None,
                                        z_events=2.5, min_gap_event_s=0.5,
                                        min_gap_spike_s=0.35, min_cluster_frac=0.10):
    """
    Build a per-unit QC table with core & event-based metrics.
    Returns a DataFrame indexed by unit_id.
    """
    units = list(results_all.keys())
    rows = []
    for uid in units:
        R = results_all[uid]
        raw = np.asarray(R["raw"])
        s   = np.asarray(R["s"])
        g   = np.asarray(R["g"])
        n   = len(raw)

        # analysis window
        if start_s is not None and duration_s is not None:
            s0 = max(0, int(round(start_s*fps)))
            s1 = min(n, int(round((start_s+duration_s)*fps)))
        else:
            s0, s1 = 0, n
        sl = slice(s0, s1)

        raw_w = raw[sl]; s_w = s[sl]
        yhat_w = _reconvolve_from_s(s_w, g)

        base_for_fit = raw_w
        if use_lowpass and (lowpass_da is not None):
            try:
                base_for_fit = np.asarray(lowpass_da.sel(unit_id=int(uid)).values)[sl]
            except Exception:
                pass  # fall back to raw_w

        # core metrics
        corr_fit = _safe_corr(base_for_fit, yhat_w)
        r2 = _explained_variance(base_for_fit, yhat_w)
        resid_rms_ratio = float(
            np.sqrt(np.nanmean((base_for_fit - yhat_w)**2)) /
            (np.sqrt(np.nanmean(base_for_fit**2)) + 1e-12)
        )
        frames = (s1 - s0)
        minutes = frames / fps / 60.0
        nonzero_frames = int(np.count_nonzero(s_w))
        spikes_per_min = float(nonzero_frames / max(1e-9, minutes))
        nonzero_pct = 100.0 * nonzero_frames / max(1, frames)

        # event metrics
        n_events, _ = _count_calcium_events(base_for_fit, fps,
                                            z=z_events, min_gap_s=min_gap_event_s)
        n_spike_clusters, _ = _cluster_spike_frames(
            s_w, fps, min_gap_s=min_gap_spike_s, min_cluster_frac=min_cluster_frac
        )

        spike_frames_per_event = float(nonzero_frames / max(1, n_events))
        spike_clusters_per_event = float(n_spike_clusters / max(1, n_events))

        rows.append(dict(
            unit_id=int(uid),
            corr_fit=corr_fit,
            r2=r2,
            resid_rms_ratio=resid_rms_ratio,
            spikes_per_min=spikes_per_min,
            nonzero_frames=nonzero_frames,
            nonzero_pct=nonzero_pct,
            n_calcium_events=int(n_events),
            n_spike_clusters=int(n_spike_clusters),
            spike_frames_per_event=spike_frames_per_event,
            spike_clusters_per_event=spike_clusters_per_event
        ))

    df = pd.DataFrame(rows).sort_values("unit_id").reset_index(drop=True)
    return df.set_index("unit_id")  # keep unit_id always visible

def qc_tag_units_events(df,
                        corr_keep=0.30, corr_discard=0.15,
                        r2_keep=0.20,   r2_discard=0.10,
                        resid_keep=0.80, resid_discard=0.95,
                        rate_min=1.0,   rate_max=20.0,
                        frames_per_event_max=6.0,       # relaxed vs 5.0
                        clusters_per_event_max=2.5):    # relaxed vs 2.0
    """
    Add 'qc_tag' column using core + event thresholds.
    """
    df = df.copy()

    good = (
        (df["corr_fit"].fillna(-1) >= corr_keep) |
        (df["r2"].fillna(-1)       >= r2_keep)
    ) & (
        df["resid_rms_ratio"] <= resid_keep
    ) & (
        df["spikes_per_min"].between(rate_min, rate_max)
    ) & (
        df["spike_frames_per_event"]   <= frames_per_event_max
    ) & (
        df["spike_clusters_per_event"] <= clusters_per_event_max
    )

    bad = (
        (df["corr_fit"].fillna(-1) < corr_discard) &
        (df["r2"].fillna(-1)       < r2_discard)
    ) | (
        df["resid_rms_ratio"] >= resid_discard
    ) | (
        ~df["spikes_per_min"].between(0.5, 60.0)   # pathological guardrails
    ) | (
        df["spike_frames_per_event"]   > (frames_per_event_max*2)
    ) | (
        df["spike_clusters_per_event"] > (clusters_per_event_max*2)
    )

    df["qc_tag"] = np.where(bad, "discard", np.where(good, "keep", "borderline"))
    return df

def _soften_for_strong_units(df, corr_strong=0.70, r2_strong=0.50, resid_strong=0.70):
    """
    If core metrics are very strong, upgrade borderline → keep.
    """
    df = df.copy()
    strong = (df["corr_fit"] >= corr_strong) & (df["r2"] >= r2_strong) & (df["resid_rms_ratio"] <= resid_strong)
    df.loc[strong & (df["qc_tag"] == "borderline"), "qc_tag"] = "keep"
    return df

# ===================== top-level convenience =====================

def run_qc(results_all, fps, start_s=None, duration_s=None,
           use_lowpass=False, lowpass_da=None,
           z_events=2.5, min_gap_event_s=0.5,
           min_gap_spike_s=0.35, min_cluster_frac=0.10,
           # tagging thresholds (NAc-friendly defaults)
           corr_keep=0.30, corr_discard=0.15,
           r2_keep=0.20,   r2_discard=0.10,
           resid_keep=0.80, resid_discard=0.95,
           rate_min=1.0,   rate_max=20.0,
           frames_per_event_max=6.0,
           clusters_per_event_max=2.5,
           soften_strong=True):
    """
    Compute metrics, tag units, (optionally soften strong ones), and return (df, summary).
    """
    df = qc_metrics_from_results_with_events(
        results_all, fps, start_s, duration_s,
        use_lowpass, lowpass_da,
        z_events, min_gap_event_s,
        min_gap_spike_s=min_gap_spike_s, min_cluster_frac=min_cluster_frac
    )
    df = qc_tag_units_events(
        df,
        corr_keep=corr_keep, corr_discard=corr_discard,
        r2_keep=r2_keep,     r2_discard=r2_discard,
        resid_keep=resid_keep, resid_discard=resid_discard,
        rate_min=rate_min, rate_max=rate_max,
        frames_per_event_max=frames_per_event_max,
        clusters_per_event_max=clusters_per_event_max
    )
    if soften_strong:
        df = _soften_for_strong_units(df)

    summary = {
        "keep":       df.index[df.qc_tag == "keep"].tolist(),
        "borderline": df.index[df.qc_tag == "borderline"].tolist(),
        "discard":    df.index[df.qc_tag == "discard"].tolist(),
        "counts":     df.qc_tag.value_counts().to_dict()
    }
    return df, summary

- `corr_fit` Does the spike-based model capture the main transient structure ≥0.3–0.4
- `r2` Fraction of variance in the raw trace explained by the fit. ≥0.2
- `resid_rms_ratio` How much variance is left unexplained, <0.8 (residual smaller than raw → fit captured a lot)
- `spikes_per_min` Biological plausibility of inferred firing rate
- `nonzero_frames` How often spikes are being called,a unit with hundreds of “nonzero frames” but only a handful of true calcium events may be over-deconvolving noise


- `spike_frames_per_event` High values mean the algorithm is “smearing” spikes over many frames instead of giving a compact even, close to 1-3 frames
- `spike_clusters_per_event` High values mean the algorithm is “splitting” a single biological event into multiple inferred spikes, clsoe to one cluster per calcium event

In [None]:
df_qc, summary = run_qc(
    results_all, fps=20.0,
    start_s=200, duration_s=300,
    use_lowpass=False,             # your raw is already low-pass
    # (you can tweak any thresholds here if needed)
)

print(df_qc.round(3))   # unit_id is the index
print(summary)