# MEA Exports Analysis — Spikes + Waveforms

Analyze and aggregate the per-pair exports produced by the Pair Viewer or the batch exporter.

- Exports location: `<output_root>/exports/spikes_waveforms/...`
- Format details: see `docs/exports_spikes_waveforms.md` in this repo (variable types, dataset names, and attributes).
- This notebook auto-discovers pairs, loads HDF5/CSV per pair, and can aggregate per-plate.

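Tip: If your external drive is not mounted, set `OUTPUT_ROOT` to `_mcs_mea_outputs_local`. The cells below also fall back to `<repo>/_mcs_mea_outputs_local` on their own when `CONFIG.output_root` does not exist.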

## FR Box Plots
Channel-level and per-pair mean firing rate distributions comparing CTZ vs VEH across all exported pairs.
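This cell discovers exports on disk and does not depend on prior DataFrames; it reuses `EXPORTS_DIR` if an earlier cell already defined it and resolves the directory itself otherwise.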

In [None]:
# Self-contained FR box plots (no dependency on df_pairs)
import sys
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

try:
    import seaborn as sns
except Exception:
    sns = None

# Resolve the exports directory, keeping any EXPORTS_DIR that is already defined
if 'EXPORTS_DIR' not in globals():
    try:
        from mcs_mea_analysis.config import CONFIG

        # Repo root = nearest of cwd / its ancestors that contains the package folder (cwd if none does)
        REPO_ROOT = next(
            (folder for folder in [Path.cwd(), *Path.cwd().parents] if (folder / 'mcs_mea_analysis').exists()),
            Path.cwd(),
        )
        # Prefer the configured output root; use the repo-local folder when it is not present
        if CONFIG.output_root.exists():
            OUTPUT_ROOT = CONFIG.output_root
        else:
            OUTPUT_ROOT = REPO_ROOT / '_mcs_mea_outputs_local'
        EXPORTS_DIR = OUTPUT_ROOT / 'exports' / 'spikes_waveforms'
    except Exception:
        # Package/config unavailable: fall back to the hard-coded external-drive location
        EXPORTS_DIR = Path('/Volumes/Manny2TB/mcs_mea_outputs/exports/spikes_waveforms')
print('Exports dir ->', EXPORTS_DIR)

# Collect every per-pair summary CSV found below EXPORTS_DIR into one long table
rows = []
for csvp in EXPORTS_DIR.rglob('*_summary.csv'):
    try:
        # Round label = grandparent folder name; plate number = parent folder named plate_<n>
        lineage = csvp.parents
        round_name = lineage[1].name if len(lineage) > 1 else None
        plate = None
        folder = csvp.parent.name
        if folder.startswith('plate_'):
            try:
                plate = int(folder.replace('plate_', ''))
            except Exception:
                pass  # non-numeric suffix: leave the plate unknown
        pair_id = csvp.stem.replace('_summary', '')
        df = pd.read_csv(csvp)
        # Coerce column types, then tag each row with where it came from
        df = df.assign(
            fr_hz=pd.to_numeric(df['fr_hz'], errors='coerce'),
            channel=pd.to_numeric(df['channel'], errors='coerce').astype('Int64'),
            side=df['side'].astype(str),
            pair_id=pair_id,
            plate=plate,
            round=round_name,
        )
        rows.append(df)
    except Exception as e:
        print('Skip', csvp.name, ':', e)

if rows:
    data = pd.concat(rows, ignore_index=True)
else:
    # Nothing loaded: keep the expected columns so later code can still inspect the frame
    data = pd.DataFrame(columns=['channel','side','n_spikes','fr_hz','pair_id','plate','round'])
print('Rows loaded:', len(data), 'from', len(rows), 'pairs')

if not data.empty:
    # Channel-level box plot (CTZ vs VEH): one observation per channel row
    fig, ax = plt.subplots(figsize=(6, 4))
    if sns is None:
        # Matplotlib-only fallback
        groups = [data[data['side'] == 'CTZ']['fr_hz'].dropna(), data[data['side'] == 'VEH']['fr_hz'].dropna()]
        ax.boxplot(groups, labels=['CTZ', 'VEH'])
    else:
        sns.boxplot(data=data, x='side', y='fr_hz', ax=ax)
        sns.stripplot(data=data, x='side', y='fr_hz', color='k', size=2, alpha=0.3, ax=ax)
    ax.set_title('Firing Rate (Hz) by Treatment — Channel Level')
    ax.set_ylabel('FR (Hz)')
    ax.set_xlabel('')
    plt.show()

    # Per-pair mean FR box plot: average the channels of each (pair, side) first
    mean_per_pair = data.groupby(['pair_id', 'side'], as_index=False)['fr_hz'].mean()
    mean_per_pair = mean_per_pair.rename(columns={'fr_hz': 'mean_fr_hz'})

    fig, ax = plt.subplots(figsize=(6, 4))
    if sns is None:
        groups = [
            mean_per_pair[mean_per_pair['side'] == 'CTZ']['mean_fr_hz'].dropna(),
            mean_per_pair[mean_per_pair['side'] == 'VEH']['mean_fr_hz'].dropna(),
        ]
        ax.boxplot(groups, labels=['CTZ', 'VEH'])
    else:
        sns.boxplot(data=mean_per_pair, x='side', y='mean_fr_hz', ax=ax)
        sns.stripplot(data=mean_per_pair, x='side', y='mean_fr_hz', color='k', size=3, alpha=0.5, ax=ax)
    ax.set_title('Mean Firing Rate (Hz) by Treatment — Per-Pair Means')
    ax.set_ylabel('Mean FR (Hz)')
    ax.set_xlabel('')
    plt.show()
else:
    print('No summary CSVs found under', EXPORTS_DIR)


In [None]:
# FR box plots across all exported pairs (no other vars needed)
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

try:
    import seaborn as sns
except Exception:
    sns = None

# Resolve exports directory (always recomputed in this cell)
try:
    from mcs_mea_analysis.config import CONFIG

    # First of cwd / its ancestors that contains the package folder; cwd when none does
    _search = [Path.cwd(), *Path.cwd().parents]
    REPO_ROOT = next((d for d in _search if (d / 'mcs_mea_analysis').exists()), Path.cwd())
    if CONFIG.output_root.exists():
        OUTPUT_ROOT = CONFIG.output_root
    else:
        # Configured root not present on disk: use the repo-local outputs folder
        OUTPUT_ROOT = REPO_ROOT / '_mcs_mea_outputs_local'
    EXPORTS_DIR = OUTPUT_ROOT / 'exports' / 'spikes_waveforms'
except Exception:
    # Config import or lookup failed: hard-coded external-drive location
    EXPORTS_DIR = Path('/Volumes/Manny2TB/mcs_mea_outputs/exports/spikes_waveforms')

print('Exports dir ->', EXPORTS_DIR)

# Load all per-pair summary CSVs into a single frame
def _plate_number(folder_name):
    """Return the integer after 'plate_' in a folder name, or None when absent/unparseable."""
    if not folder_name.startswith('plate_'):
        return None
    try:
        return int(folder_name.replace('plate_', ''))
    except Exception:
        return None


rows = []
for csvp in EXPORTS_DIR.rglob('*_summary.csv'):
    try:
        # Round = name of the folder two levels up; plate parsed from the containing folder
        if len(csvp.parents) > 1:
            round_name = csvp.parents[1].name
        else:
            round_name = None
        plate = _plate_number(csvp.parent.name)
        pair_id = csvp.stem.replace('_summary', '')

        df = pd.read_csv(csvp)
        # Normalise types; unparseable numbers become NaN / <NA>
        for col, caster in (
            ('fr_hz', lambda s: pd.to_numeric(s, errors='coerce')),
            ('channel', lambda s: pd.to_numeric(s, errors='coerce').astype('Int64')),
            ('side', lambda s: s.astype(str)),
        ):
            df[col] = caster(df[col])
        # Provenance columns
        for col, value in (('pair_id', pair_id), ('plate', plate), ('round', round_name)):
            df[col] = value
        rows.append(df)
    except Exception as e:
        print('Skip', csvp.name, ':', e)

if rows:
    data = pd.concat(rows, ignore_index=True)
else:
    data = pd.DataFrame(columns=['channel','side','n_spikes','fr_hz','pair_id','plate','round'])
print('Rows loaded:', len(data), 'from', len(rows), 'pairs')

def _fr_box(frame, ycol, point_size, point_alpha, title, ylabel):
    """Draw one CTZ-vs-VEH box plot of ``frame[ycol]`` (seaborn if available, else matplotlib)."""
    plt.figure(figsize=(6,4))
    if sns is None:
        groups = [frame[frame['side'] == label][ycol].dropna() for label in ('CTZ', 'VEH')]
        plt.boxplot(groups, labels=['CTZ','VEH'])
    else:
        sns.boxplot(data=frame, x='side', y=ycol)
        sns.stripplot(data=frame, x='side', y=ycol, color='k', size=point_size, alpha=point_alpha)
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel('')
    plt.show()


if data.empty:
    print('No summary CSVs found under', EXPORTS_DIR)
else:
    # Channel-level box plot (CTZ vs VEH)
    _fr_box(data, 'fr_hz', 2, 0.3, 'Firing Rate (Hz) by Treatment — Channel Level', 'FR (Hz)')

    # Per-pair mean FR box plot: one value per (pair_id, side)
    mean_per_pair = data.groupby(['pair_id','side'], as_index=False)['fr_hz'].mean()
    mean_per_pair = mean_per_pair.rename(columns={'fr_hz':'mean_fr_hz'})
    _fr_box(mean_per_pair, 'mean_fr_hz', 3, 0.5,
            'Mean Firing Rate (Hz) by Treatment — Per-Pair Means', 'Mean FR (Hz)')


In [None]:
# Pre/Post FR and spike durations (uses exported HDF5s)
import json
from pathlib import Path
import numpy as np
import pandas as pd
import h5py
import matplotlib.pyplot as plt

try:
    import seaborn as sns
except Exception:
    sns = None

try:
    from scipy import signal
except Exception as e:
    raise RuntimeError("This cell needs scipy installed (for peak widths).") from e

# Resolve exports directory: configured output root, else repo-local folder, else hard-coded drive path
try:
    from mcs_mea_analysis.config import CONFIG

    REPO_ROOT = next(
        (anc for anc in [Path.cwd(), *Path.cwd().parents] if (anc / 'mcs_mea_analysis').exists()),
        Path.cwd(),
    )
    _configured_ok = CONFIG.output_root.exists()
    OUTPUT_ROOT = CONFIG.output_root if _configured_ok else (REPO_ROOT / '_mcs_mea_outputs_local')
    EXPORTS_DIR = OUTPUT_ROOT / 'exports' / 'spikes_waveforms'
except Exception:
    EXPORTS_DIR = Path('/Volumes/Manny2TB/mcs_mea_outputs/exports/spikes_waveforms')

def _noise_level(y: np.ndarray, method: str, pctl: float) -> float:
    """Estimate the noise scale of ``y``.

    ``method`` selects the estimator:
      - 'mad'  : 1.4826 * median absolute deviation about the median
      - 'rms'  : root-mean-square of the samples
      - 'pctl' : ``pctl``-th percentile of the absolute deviation about the median
    Returns NaN for an empty array or an unrecognised method.
    """
    if y.size == 0:
        return np.nan
    if method == 'rms':
        return float(np.sqrt(np.mean(np.square(y))))
    if method in ('mad', 'pctl'):
        # Both remaining estimators work on |y - median(y)|
        abs_dev = np.abs(y - np.median(y))
        if method == 'mad':
            return 1.4826 * np.median(abs_dev)
        return float(np.percentile(abs_dev, pctl))
    return np.nan

def _parse_bounds_attr(attr_val):
    """Return ``(t0, t1)`` as floats from a bounds attribute.

    Accepts a mapping, a JSON string, or JSON bytes. Missing keys count as 0.0,
    and any decoding/parsing/conversion failure yields ``(0.0, 0.0)``.
    """
    raw = attr_val
    try:
        if isinstance(raw, (bytes, bytearray)):
            raw = raw.decode()
        bounds = json.loads(raw) if isinstance(raw, str) else raw
        start, stop = (float(bounds.get(key, 0.0)) for key in ('t0', 't1'))
    except Exception:
        return 0.0, 0.0
    return start, stop

def detect_pre_post(t: np.ndarray, y: np.ndarray, sr_hz: float,
                    t0b: float, t1b: float, t0a: float, t1a: float, dcfg: dict):
    """Detect spikes in the baseline window [t0b, t1b] and the analysis window [t0a, t1a].

    The threshold is ``K`` times the noise level estimated on the baseline samples only
    and is applied to both windows. Detection parameters are read from ``dcfg`` with
    defaults: noise='mad', noise_percentile=68.0, K=5.0, refractory_ms=1.0,
    min_width_ms=0.3, polarity='neg'.

    Returns ``(st_pre, w_pre_ms, st_post, w_post_ms)``: spike times taken from ``t`` and
    peak widths at half height in milliseconds. All four are empty when the baseline
    noise is non-finite or not positive.
    """
    baseline_mask = (t >= t0b) & (t <= t1b)
    analysis_mask = (t >= t0a) & (t <= t1a)

    baseline_y = y[baseline_mask]
    noise = _noise_level(baseline_y, str(dcfg.get('noise','mad')), float(dcfg.get('noise_percentile',68.0)))
    if not (np.isfinite(noise) and noise > 0):
        return np.empty(0), np.empty(0), np.empty(0), np.empty(0)

    def _ms_to_samples(key, default_ms):
        # Convert a millisecond setting to a sample count of at least 1
        return max(1, int(round(float(dcfg.get(key, default_ms)) * 1e-3 * sr_hz)))

    threshold = float(dcfg.get('K', 5.0)) * noise
    min_gap = _ms_to_samples('refractory_ms', 1.0)
    min_width = _ms_to_samples('min_width_ms', 0.3)
    # 'neg' and 'both' are both searched on the inverted trace; anything else on the trace as-is
    invert = str(dcfg.get('polarity','neg')) in ('neg','both')

    def _spikes_in(mask):
        t_win = t[mask]
        y_win = y[mask]
        if t_win.size == 0:
            return np.empty(0), np.empty(0)
        trace = -y_win if invert else y_win
        peak_idx, _ = signal.find_peaks(trace, height=threshold, distance=min_gap, width=min_width)
        if not peak_idx.size:
            return np.empty(0), np.empty(0)
        half_widths = signal.peak_widths(trace, peak_idx, rel_height=0.5)[0]
        # widths are in samples -> milliseconds
        return t_win[peak_idx], (half_widths / float(sr_hz)) * 1000.0

    st_pre, w_pre_ms = _spikes_in(baseline_mask)
    st_post, w_post_ms = _spikes_in(analysis_mask)
    return st_pre, w_pre_ms, st_post, w_post_ms

# Re-detect spikes per channel on every exported pair and collect:
#   rows_fr  - one row per (pair, side, channel, period) with the firing rate
#   rows_dur - one row per detected spike with its half-height width
rows_fr = []
rows_dur = []


def _rate_hz(n_spikes, t_start, t_stop):
    """Spikes per second over [t_start, t_stop]; NaN when the window has no positive length.

    Missing bounds parse as (0, 0); reporting NaN for them keeps non-existent windows
    out of the box plots instead of adding spurious 0 Hz points.
    """
    span = t_stop - t_start
    return n_spikes / span if span > 0 else np.nan


# sorted() keeps the row order reproducible across runs and filesystems
for h5 in sorted(EXPORTS_DIR.rglob('*.h5')):
    if h5.name.endswith('_summary.h5'):
        continue
    try:
        # Round = folder two levels up; plate = integer after 'plate_' in the containing folder
        round_name = h5.parents[1].name if len(h5.parents) > 1 else None
        plate = None
        ps = h5.parent.name
        if ps.startswith('plate_'):
            try:
                plate = int(ps.replace('plate_', ''))
            except ValueError:
                plate = None
        pair_id = h5.stem
        with h5py.File(h5.as_posix(), 'r') as f:
            dcfg = {}
            if 'detect_config_json' in f:
                try:
                    dcfg = json.loads(f['detect_config_json'][()].decode())
                except Exception:
                    dcfg = {}  # unreadable config: detector defaults apply
            for side in ('CTZ', 'VEH'):
                if side not in f:
                    continue
                g = f[side]
                sr = float(g.attrs.get('sr_hz', 0.0)) or 10000.0  # missing/zero rate -> 10 kHz
                t0b, t1b = _parse_bounds_attr(g.attrs.get('baseline_bounds', '{"t0":0,"t1":0}'))
                t0a, t1a = _parse_bounds_attr(g.attrs.get('analysis_bounds', '{"t0":0,"t1":0}'))
                # Channel numbers come from dataset names of the form chNN_time
                chs = sorted(int(k[2:4]) for k in g.keys() if k.startswith('ch') and k.endswith('_time'))
                for ch in chs:
                    t = g[f'ch{ch:02d}_time'][:] if f'ch{ch:02d}_time' in g else np.empty(0)
                    yf = g[f'ch{ch:02d}_filtered'][:] if f'ch{ch:02d}_filtered' in g else np.empty(0)
                    if t.size == 0 or yf.size == 0:
                        continue
                    st_pre, w_pre, st_post, w_post = detect_pre_post(t, yf, sr, t0b, t1b, t0a, t1a, dcfg)
                    meta = {'pair_id': pair_id, 'plate': plate, 'round': round_name, 'side': side, 'channel': ch}
                    per_period = (
                        ('pre', st_pre, w_pre, t0b, t1b),
                        ('post', st_post, w_post, t0a, t1a),
                    )
                    for period, spike_times, widths_ms, w_start, w_stop in per_period:
                        rows_fr.append({**meta, 'period': period,
                                        'fr_hz': _rate_hz(spike_times.size, w_start, w_stop)})
                        rows_dur.extend({**meta, 'period': period, 'width_ms': float(dms)} for dms in widths_ms)
    except Exception as e:
        print('Skip H5', h5.name, ':', e)

fr_data = pd.DataFrame(rows_fr) if rows_fr else pd.DataFrame(columns=['pair_id','plate','round','side','channel','period','fr_hz'])
dur_data = pd.DataFrame(rows_dur) if rows_dur else pd.DataFrame(columns=['pair_id','plate','round','side','channel','period','width_ms'])
print('FR rows:', len(fr_data), '| Duration rows:', len(dur_data))

def _prepost_boxplot(df, ycol, title, ylabel):
    """Draw one figure of ``df[ycol]`` grouped by side (x axis) and period (pre/post)."""
    plt.figure(figsize=(7, 4))
    if sns is not None:
        ax = sns.boxplot(data=df, x='side', y=ycol, hue='period')
        sns.stripplot(data=df, x='side', y=ycol, hue='period', dodge=True,
                      color='k', size=2, alpha=0.3, ax=ax)
        # Both layers register one legend entry per period, which duplicates the legend.
        # pyplot has no `legend_` attribute, so rebuild the legend on the Axes and keep
        # only the first set of handles (the box plot was drawn first).
        handles, labels = ax.get_legend_handles_labels()
        n_periods = df['period'].nunique()
        ax.legend(handles[:n_periods], labels[:n_periods], title='period')
    else:
        # Matplotlib-only fallback: pre/post boxes side by side, one cluster per treatment
        for i, side in enumerate(['CTZ', 'VEH']):
            grp = [df[(df.side == side) & (df.period == 'pre')][ycol].dropna(),
                   df[(df.side == side) & (df.period == 'post')][ycol].dropna()]
            plt.boxplot(grp, positions=[i*3+1, i*3+2], labels=[f'{side}-pre', f'{side}-post'])
    plt.title(title)
    plt.ylabel(ylabel)
    plt.show()


# FR nested box plot (CTZ/VEH × pre/post)
if not fr_data.empty:
    _prepost_boxplot(fr_data, 'fr_hz', 'FR (Hz) — Pre vs Post by Treatment', 'FR (Hz)')

# Spike duration nested box plot (CTZ/VEH × pre/post)
if not dur_data.empty:
    _prepost_boxplot(dur_data, 'width_ms', 'Spike Duration (ms, FWHM) — Pre vs Post by Treatment', 'Width (ms)')


In [None]:
# Per-pair and combined plots: FR (CTZ vs VEH, pre vs post) + waveform traces
import json, random
from pathlib import Path
import numpy as np
import pandas as pd
import h5py
import matplotlib.pyplot as plt

try:
    import seaborn as sns
except Exception:
    sns = None

try:
    from scipy import signal
except Exception as e:
    raise RuntimeError("This cell needs scipy installed (for spike widths).") from e

# Resolve exports directory
def _exports_dir():
    """Return the spikes/waveforms exports directory.

    Uses ``CONFIG.output_root`` when it exists on disk, otherwise
    ``<repo>/_mcs_mea_outputs_local`` (repo = nearest of cwd / its ancestors containing
    ``mcs_mea_analysis``, else cwd). If the config cannot be imported or queried, the
    hard-coded external-drive path is returned.
    """
    fallback = Path('/Volumes/Manny2TB/mcs_mea_outputs/exports/spikes_waveforms')
    try:
        from mcs_mea_analysis.config import CONFIG

        cwd_chain = [Path.cwd(), *Path.cwd().parents]
        REPO_ROOT = next((d for d in cwd_chain if (d / 'mcs_mea_analysis').exists()), Path.cwd())
        if CONFIG.output_root.exists():
            OUTPUT_ROOT = CONFIG.output_root
        else:
            OUTPUT_ROOT = REPO_ROOT / '_mcs_mea_outputs_local'
        return OUTPUT_ROOT / 'exports' / 'spikes_waveforms'
    except Exception:
        return fallback

EXPORTS_DIR = _exports_dir()
print('Exports dir ->', EXPORTS_DIR)

# Discover per-pair H5 exports (skipping *_summary.h5) in sorted path order,
# then keep at most MAX_PAIRS of them; a falsy MAX_PAIRS (0/None) keeps all.
MAX_PAIRS = 3
h5_pairs = sorted(
    candidate for candidate in EXPORTS_DIR.rglob('*.h5')
    if not candidate.name.endswith('_summary.h5')
)
if MAX_PAIRS:
    sel_pairs = h5_pairs[:MAX_PAIRS]
else:
    sel_pairs = h5_pairs
print('Pairs found:', len(h5_pairs), '| Using:', len(sel_pairs))
for p in sel_pairs:
    print('  -', p.name)

# Helper: noise estimator used for detection
def _noise_level(y: np.ndarray, method: str, pctl: float) -> float:
    """Return a noise-scale estimate for ``y`` using the named ``method``.

    'mad' -> 1.4826 * median(|y - median|); 'rms' -> sqrt(mean(y**2));
    'pctl' -> ``pctl``-th percentile of |y - median|. Empty input or an unknown
    method gives NaN.
    """
    if y.size == 0:
        return np.nan

    def _abs_dev():
        return np.abs(y - np.median(y))

    # Estimators are wrapped in callables so only the requested one is evaluated
    estimators = {
        'mad': lambda: 1.4826 * np.median(_abs_dev()),
        'rms': lambda: float(np.sqrt(np.mean(np.square(y)))),
        'pctl': lambda: float(np.percentile(_abs_dev(), pctl)),
    }
    compute = estimators.get(method)
    return np.nan if compute is None else compute()

# Detect spikes on pre/post windows using detect_config saved in H5
def detect_pre_post(t: np.ndarray, y: np.ndarray, sr_hz: float,
                    t0b: float, t1b: float, t0a: float, t1a: float, dcfg: dict):
    """Find spikes in the baseline [t0b, t1b] and analysis [t0a, t1a] windows of one trace.

    A single threshold (``K`` x baseline noise) is used for both windows. ``dcfg`` may
    supply noise, noise_percentile, K, refractory_ms, min_width_ms and polarity; the
    defaults are 'mad', 68.0, 5.0, 1.0, 0.3 and 'neg'.

    Returns a 4-tuple ``(st_pre, w_pre_ms, st_post, w_post_ms)`` of spike times and
    half-height widths in ms; all empty if the baseline noise is not finite and positive.
    """
    # masks
    mb = (t >= t0b) & (t <= t1b)
    ma = (t >= t0a) & (t <= t1a)

    noise = _noise_level(y[mb], str(dcfg.get('noise','mad')), float(dcfg.get('noise_percentile',68.0)))
    if not np.isfinite(noise) or noise <= 0:
        return np.empty(0), np.empty(0), np.empty(0), np.empty(0)

    ms_to_samples = 1e-3 * sr_hz
    peak_kwargs = dict(
        height=float(dcfg.get('K', 5.0)) * noise,
        distance=max(1, int(round(float(dcfg.get('refractory_ms',1.0)) * ms_to_samples))),
        width=max(1, int(round(float(dcfg.get('min_width_ms',0.3)) * ms_to_samples))),
    )
    pol = str(dcfg.get('polarity','neg'))

    out = []
    for window in (mb, ma):  # baseline first, then analysis
        ta = t[window]
        ya = y[window]
        times, widths_ms = np.empty(0), np.empty(0)
        if ta.size:
            # negative-going spikes ('neg'/'both') are detected as peaks of the flipped trace
            arr = -ya if pol in ('neg','both') else ya
            peaks, _ = signal.find_peaks(arr, **peak_kwargs)
            if peaks.size:
                widths = signal.peak_widths(arr, peaks, rel_height=0.5)[0]
                times = ta[peaks]
                widths_ms = (widths / float(sr_hz)) * 1000.0
        out.extend((times, widths_ms))
    return tuple(out)

def _parse_bounds_attr(attr_val):
    """Parse a bounds attribute (mapping, JSON text, or JSON bytes) into ``(t0, t1)`` floats.

    Absent keys default to 0.0; anything that cannot be parsed returns ``(0.0, 0.0)``.
    """
    try:
        value = attr_val.decode() if isinstance(attr_val, (bytes, bytearray)) else attr_val
        if isinstance(value, str):
            value = json.loads(value)
        return float(value.get('t0', 0.0)), float(value.get('t1', 0.0))
    except Exception:
        return 0.0, 0.0

# Per-pair plots (separate figures per selected pair) + accumulators for the combined plots
combined_fr_rows = []
combined_fr_prepost_rows = []

N_WF = 200  # max waveforms overlaid per side (sampled from all channels together)

for h5 in sel_pairs:
    pair_id = h5.stem

    # Load per-pair summary CSV (post-only FR per channel); a missing CSV becomes an empty table
    csv_sum = h5.with_name(h5.stem + '_summary.csv')
    if csv_sum.exists():
        df_sum = pd.read_csv(csv_sum)
    else:
        df_sum = pd.DataFrame(columns=['channel', 'side', 'n_spikes', 'fr_hz'])
    df_sum['side'] = df_sum['side'].astype(str)
    df_sum['pair_id'] = pair_id

    # Plot: FR channel-level (CTZ vs VEH) for this pair
    if df_sum.empty:
        # Nothing to draw or to contribute to the combined plot
        print(f'No summary FR rows for {pair_id}; skipping channel-level plot')
    else:
        combined_fr_rows.append(df_sum.assign(period='post'))
        plt.figure(figsize=(6, 4))
        if sns is not None:
            sns.boxplot(data=df_sum, x='side', y='fr_hz')
            sns.stripplot(data=df_sum, x='side', y='fr_hz', color='k', size=2, alpha=0.3)
        else:
            groups = [df_sum[df_sum['side'] == 'CTZ']['fr_hz'].dropna(),
                      df_sum[df_sum['side'] == 'VEH']['fr_hz'].dropna()]
            plt.boxplot(groups, labels=['CTZ', 'VEH'])
        plt.title(f'FR (Hz) by Treatment — Channels — {pair_id}')
        plt.ylabel('FR (Hz)'); plt.xlabel('')
        plt.show()

    # Compute pre/post FR (re-detect spikes on exported filtered traces)
    fr_prepost = []
    with h5py.File(h5.as_posix(), 'r') as f:
        dcfg = {}
        if 'detect_config_json' in f:
            try:
                dcfg = json.loads(f['detect_config_json'][()].decode())
            except Exception:
                dcfg = {}  # unreadable config: detector defaults apply
        for side in ('CTZ', 'VEH'):
            if side not in f:
                continue
            g = f[side]
            sr = float(g.attrs.get('sr_hz', 0.0)) or 10000.0  # missing/zero rate -> 10 kHz
            t0b, t1b = _parse_bounds_attr(g.attrs.get('baseline_bounds', '{"t0":0,"t1":0}'))
            t0a, t1a = _parse_bounds_attr(g.attrs.get('analysis_bounds', '{"t0":0,"t1":0}'))
            # Window lengths; a missing or degenerate window gives NaN rates rather than 0 Hz
            span_pre = t1b - t0b
            span_post = t1a - t0a
            chs = sorted(int(k[2:4]) for k in g.keys() if k.startswith('ch') and k.endswith('_time'))
            for ch in chs:
                t = g[f'ch{ch:02d}_time'][:] if f'ch{ch:02d}_time' in g else np.empty(0)
                yf = g[f'ch{ch:02d}_filtered'][:] if f'ch{ch:02d}_filtered' in g else np.empty(0)
                if t.size == 0 or yf.size == 0:
                    continue
                st_pre, _, st_post, _ = detect_pre_post(t, yf, sr, t0b, t1b, t0a, t1a, dcfg)
                fr_pre = st_pre.size / span_pre if span_pre > 0 else np.nan
                fr_post = st_post.size / span_post if span_post > 0 else np.nan
                fr_prepost.append({'pair_id': pair_id, 'side': side, 'channel': ch, 'period': 'pre', 'fr_hz': fr_pre})
                fr_prepost.append({'pair_id': pair_id, 'side': side, 'channel': ch, 'period': 'post', 'fr_hz': fr_post})

    # Explicit columns keep the frame usable even when no channel produced a row
    df_pp = pd.DataFrame(fr_prepost, columns=['pair_id', 'side', 'channel', 'period', 'fr_hz'])

    # Plot: Nested pre/post FR by side for this pair
    if df_pp.empty:
        print(f'No pre/post FR rows for {pair_id}; skipping pre/post plot')
    else:
        combined_fr_prepost_rows.append(df_pp)
        plt.figure(figsize=(7, 4))
        if sns is not None:
            ax = sns.boxplot(data=df_pp, x='side', y='fr_hz', hue='period')
            sns.stripplot(data=df_pp, x='side', y='fr_hz', hue='period', dodge=True,
                          color='k', size=2, alpha=0.3, ax=ax)
            # Box and strip layers each add legend entries per period; keep only the
            # first set (pyplot has no `legend_`, so the legend is rebuilt on the Axes).
            handles, labels = ax.get_legend_handles_labels()
            n_periods = df_pp['period'].nunique()
            ax.legend(handles[:n_periods], labels[:n_periods], title='period')
        else:
            for i, side in enumerate(['CTZ', 'VEH']):
                grp = [df_pp[(df_pp.side == side) & (df_pp.period == 'pre')]['fr_hz'].dropna(),
                       df_pp[(df_pp.side == side) & (df_pp.period == 'post')]['fr_hz'].dropna()]
                plt.boxplot(grp, positions=[i*3+1, i*3+2], labels=[f'{side}-pre', f'{side}-post'])
        plt.title(f'FR (Hz) — Pre vs Post by Treatment — {pair_id}')
        plt.ylabel('FR (Hz)')
        plt.show()

    # Plot: Waveform traces (overlay, sampled) for CTZ and VEH
    # Sample up to N_WF waveforms total per side from across channels
    with h5py.File(h5.as_posix(), 'r') as f:
        for side in ('CTZ', 'VEH'):
            if side not in f:
                continue
            g = f[side]
            sr = float(g.attrs.get('sr_hz', 10000.0)) or 10000.0
            # collect non-empty waveform arrays from all channels
            waves = []
            for k in sorted(g.keys()):
                if k.startswith('ch') and k.endswith('_waveforms'):
                    arr = g[k][:]
                    if arr.size:
                        waves.append(arr)
            if not waves:
                print(f'No waveforms in {pair_id} {side}')
                continue
            W = np.vstack(waves)
            # sample rows with a fixed seed so the overlay is reproducible
            if W.shape[0] > N_WF:
                rng = np.random.default_rng(0)
                W = W[rng.choice(W.shape[0], size=N_WF, replace=False)]
            n_samp = W.shape[1]
            # approximate pre/post split ~ 1:2 (export used 0.8ms pre, 1.6ms post)
            n_pre = int(round(n_samp / 3))
            t_ms = (np.arange(n_samp) - n_pre) * (1000.0 / sr)
            med = np.median(W, axis=0)
            plt.figure(figsize=(6, 3))
            # one call draws every sampled waveform (columns of W.T)
            plt.plot(t_ms, W.T, color='C0' if side == 'CTZ' else 'C1', alpha=0.1, linewidth=0.6)
            plt.plot(t_ms, med, color='k', linewidth=2, label='median')
            plt.axvline(0.0, color='k', linestyle='--', alpha=0.4)
            plt.legend()  # show the 'median' label
            plt.title(f'Waveform Traces — {side} — {pair_id} (n={W.shape[0]})')
            plt.xlabel('Time (ms, approx)'); plt.ylabel('Filtered amplitude (a.u.)')
            plt.show()

# Combined plots across selected pairs
if combined_fr_rows:
    all_post = pd.concat(combined_fr_rows, ignore_index=True)
    # Pairs without a summary CSV can contribute only empty frames; skip the plot then
    if all_post.empty or 'fr_hz' not in all_post.columns:
        print('No summary FR rows across selected pairs; skipping combined channel-level plot')
    else:
        plt.figure(figsize=(6, 4))
        if sns is not None:
            sns.boxplot(data=all_post, x='side', y='fr_hz')
            sns.stripplot(data=all_post, x='side', y='fr_hz', color='k', size=2, alpha=0.3)
        else:
            groups = [all_post[all_post['side'] == 'CTZ']['fr_hz'].dropna(),
                      all_post[all_post['side'] == 'VEH']['fr_hz'].dropna()]
            plt.boxplot(groups, labels=['CTZ', 'VEH'])
        plt.title('FR (Hz) by Treatment — Channels — Combined')
        plt.ylabel('FR (Hz)'); plt.xlabel('')
        plt.show()

if combined_fr_prepost_rows:
    all_pp = pd.concat(combined_fr_prepost_rows, ignore_index=True)
    # A pair with no CTZ/VEH channels yields a column-less frame; guard before plotting
    if all_pp.empty or not {'side', 'period', 'fr_hz'}.issubset(all_pp.columns):
        print('No pre/post FR rows across selected pairs; skipping combined pre/post plot')
    else:
        plt.figure(figsize=(7, 4))
        if sns is not None:
            ax = sns.boxplot(data=all_pp, x='side', y='fr_hz', hue='period')
            sns.stripplot(data=all_pp, x='side', y='fr_hz', hue='period', dodge=True,
                          color='k', size=2, alpha=0.3, ax=ax)
            # Box and strip layers each add legend entries per period; keep only the
            # first set (pyplot has no `legend_`, so the legend is rebuilt on the Axes).
            handles, labels = ax.get_legend_handles_labels()
            n_periods = all_pp['period'].nunique()
            ax.legend(handles[:n_periods], labels[:n_periods], title='period')
        else:
            for i, side in enumerate(['CTZ', 'VEH']):
                grp = [all_pp[(all_pp.side == side) & (all_pp.period == 'pre')]['fr_hz'].dropna(),
                       all_pp[(all_pp.side == side) & (all_pp.period == 'post')]['fr_hz'].dropna()]
                plt.boxplot(grp, positions=[i*3+1, i*3+2], labels=[f'{side}-pre', f'{side}-post'])
        plt.title('FR (Hz) — Pre vs Post — Combined')
        plt.ylabel('FR (Hz)')
        plt.show()


In [None]:
# PairBurstISIAnalyzer — discovers exports, re-detects spikes, computes ISI and burst durations
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Tuple, List
import numpy as np
import pandas as pd
import h5py
import matplotlib.pyplot as plt

try:
    import seaborn as sns
except Exception:
    sns = None

try:
    from scipy import signal  # needed for detection
except Exception as e:
    raise RuntimeError("This analyzer requires scipy installed.") from e

def _resolve_exports_dir() -> Path:
    """Locate the spikes/waveforms exports directory.

    Order of preference: ``CONFIG.output_root`` if it exists on disk, then
    ``<repo>/_mcs_mea_outputs_local`` (repo found by walking up from the cwd to a folder
    containing ``mcs_mea_analysis``; the cwd itself if none is found). Any failure while
    consulting the config falls back to the hard-coded external-drive path.
    """
    try:
        from mcs_mea_analysis.config import CONFIG

        here = Path.cwd()
        REPO_ROOT = next(
            (level for level in (here, *here.parents) if (level / 'mcs_mea_analysis').exists()),
            Path.cwd(),
        )
        if CONFIG.output_root.exists():
            output_root = CONFIG.output_root
        else:
            output_root = REPO_ROOT / '_mcs_mea_outputs_local'
        return output_root / 'exports' / 'spikes_waveforms'
    except Exception:
        return Path('/Volumes/Manny2TB/mcs_mea_outputs/exports/spikes_waveforms')

def _parse_bounds_attr(attr_val) -> Tuple[float, float]:
    """Convert a bounds attribute into ``(t0, t1)``.

    The attribute may be a mapping, JSON text, or JSON-encoded bytes. Keys that are
    missing read as 0.0, and every failure mode collapses to ``(0.0, 0.0)``.
    """
    default = (0.0, 0.0)
    try:
        payload = attr_val
        if isinstance(payload, (bytes, bytearray)):
            payload = payload.decode()
        if isinstance(payload, str):
            payload = json.loads(payload)
        t0 = float(payload.get('t0', 0.0))
        t1 = float(payload.get('t1', 0.0))
        return t0, t1
    except Exception:
        return default

def _noise_level(y: np.ndarray, method: str, pctl: float) -> float:
    """Noise-scale estimate of ``y``: 'mad', 'rms' or 'pctl'; NaN if empty or method unknown."""
    if not y.size:
        return np.nan
    if method == 'rms':
        mean_square = np.mean(np.square(y))
        return float(np.sqrt(mean_square))
    if method not in ('mad', 'pctl'):
        return np.nan
    # Absolute deviation about the median feeds both remaining estimators
    centre = np.median(y)
    spread = np.abs(y - centre)
    if method == 'mad':
        return 1.4826 * np.median(spread)
    return float(np.percentile(spread, pctl))

def _detect_prepost(t: np.ndarray, y: np.ndarray, sr_hz: float,
                    t0b: float, t1b: float, t0a: float, t1a: float, dcfg: dict):
    """Return spike times in the baseline [t0b, t1b] and analysis [t0a, t1a] windows.

    The detection threshold is ``K`` times the noise level of the baseline samples and
    is shared by both windows. ``dcfg`` supplies noise / noise_percentile / K /
    refractory_ms / min_width_ms / polarity (defaults 'mad', 68.0, 5.0, 1.0, 0.3, 'neg').
    Returns ``(st_pre, st_post)``; both are empty when the baseline noise is not a
    finite positive number.
    """
    in_baseline = (t >= t0b) & (t <= t1b)
    in_analysis = (t >= t0a) & (t <= t1a)

    baseline_y = y[in_baseline]
    noise = _noise_level(baseline_y, str(dcfg.get('noise','mad')), float(dcfg.get('noise_percentile',68.0)))
    if not (np.isfinite(noise) and noise > 0):
        return np.empty(0), np.empty(0)

    find_kwargs = {
        'height': float(dcfg.get('K', 5.0)) * noise,
        # millisecond settings converted to whole samples, never below 1
        'distance': max(1, int(round(float(dcfg.get('refractory_ms',1.0)) * 1e-3 * sr_hz))),
        'width': max(1, int(round(float(dcfg.get('min_width_ms',0.3)) * 1e-3 * sr_hz))),
    }
    # 'neg' and 'both' search the sign-flipped trace; any other polarity uses it as-is
    negate = str(dcfg.get('polarity','neg')) in ('neg','both')

    def _spike_times(mask):
        t_win = t[mask]
        y_win = y[mask]
        if t_win.size == 0:
            return np.empty(0)
        idx, _ = signal.find_peaks(-y_win if negate else y_win, **find_kwargs)
        if idx.size:
            return t_win[idx]
        return np.empty(0)

    st_pre = _spike_times(in_baseline)
    st_post = _spike_times(in_analysis)
    return st_pre, st_post

def _bursts_from_spikes(st_s: np.ndarray, isi_thr_ms: float, min_spikes: int) -> List[Tuple[int,int,float,int]]:
    """
    Group spikes into bursts: maximal runs of consecutive spikes whose inter-spike
    intervals are all <= ``isi_thr_ms``, kept when they contain >= ``min_spikes`` spikes.

    ``st_s`` holds spike times in seconds. Returns a list of tuples
    (start_idx, end_idx, dur_ms, n_spikes), where start_idx/end_idx are inclusive
    spike indices into ``st_s`` and dur_ms is the first-to-last spike span.
    """
    if st_s.size < min_spikes:
        return []

    close = np.diff(st_s) <= isi_thr_ms / 1000.0
    # Pad with False so every run of short ISIs has a rising and a falling edge
    padded = np.concatenate(([False], close, [False])).astype(np.int8)
    edges = np.diff(padded)
    first_spike = np.flatnonzero(edges == 1)   # index of the first spike of each run
    last_spike = np.flatnonzero(edges == -1)   # index of the last spike of each run

    bursts = []
    for s_idx, e_idx in zip(first_spike, last_spike):
        n_spikes = e_idx - s_idx + 1
        if n_spikes < min_spikes:
            continue
        dur_ms = (st_s[e_idx] - st_s[s_idx]) * 1000.0
        bursts.append((int(s_idx), int(e_idx), float(dur_ms), int(n_spikes)))
    return bursts

@dataclass
class PairBurstISIAnalyzer:
    """Detect pre/post spikes in per-pair HDF5 exports and summarise ISIs and bursts.

    Each export file is read for ``CTZ`` and/or ``VEH`` groups holding
    ``chNN_time`` / ``chNN_filtered`` datasets plus the ``sr_hz``,
    ``baseline_bounds`` and ``analysis_bounds`` attributes. An optional root
    dataset ``detect_config_json`` supplies the detector settings.
    """

    # Folder searched recursively for ``*.h5`` exports. NOTE: the default is
    # evaluated once, when the class body is executed.
    exports_dir: Path = _resolve_exports_dir()

    def discover_pairs(self) -> pd.DataFrame:
        """List every export file below ``exports_dir``.

        Files whose name ends in ``_summary.h5`` are ignored. The round is the
        name of the grand-parent folder; the plate number is parsed from a
        parent folder named ``plate_<int>`` and is ``None`` otherwise.

        Returns:
            A frame with columns ``pair_id``, ``h5_path``, ``round`` and
            ``plate`` sorted by plate, round and pair id. The frame is empty
            (but still has those columns) when nothing is found.
        """
        columns = ['pair_id', 'h5_path', 'round', 'plate']
        pairs = []
        for h5 in sorted(self.exports_dir.rglob('*.h5')):
            if h5.name.endswith('_summary.h5'):
                continue
            plate = None
            parent_name = h5.parent.name
            if parent_name.startswith('plate_'):
                try:
                    plate = int(parent_name.replace('plate_', ''))
                except ValueError:
                    plate = None  # e.g. 'plate_A': keep the pair, plate unknown
            round_name = h5.parents[1].name if len(h5.parents) > 1 else None
            pairs.append({'pair_id': h5.stem, 'h5_path': str(h5), 'round': round_name, 'plate': plate})
        # Passing the columns explicitly keeps sort_values from raising KeyError
        # when no export was discovered.
        found = pd.DataFrame(pairs, columns=columns)
        return found.sort_values(['plate','round','pair_id']).reset_index(drop=True)

    def spikes_prepost_for_pair(self, h5_path: Path) -> pd.DataFrame:
        """Detect baseline ("pre") and analysis ("post") spikes for one export file.

        Args:
            h5_path: Location of the pair's HDF5 export.

        Returns:
            One row per side/channel/period with columns ``side``, ``channel``,
            ``period`` (``'pre'`` or ``'post'``) and ``spike_times_s`` (array).
        """
        rows = []
        with h5py.File(Path(h5_path).as_posix(),'r') as f:
            dcfg = {}
            if 'detect_config_json' in f:
                try:
                    dcfg = json.loads(f['detect_config_json'][()].decode())
                except Exception:
                    # Best effort: an unreadable config falls back to the
                    # detector defaults instead of aborting the pair.
                    dcfg = {}
                if not isinstance(dcfg, dict):
                    dcfg = {}
            for side in ('CTZ','VEH'):
                if side not in f:
                    continue
                g = f[side]
                # A missing or zero sampling rate falls back to 10 kHz.
                sr = float(g.attrs.get('sr_hz', 0.0)) or 10000.0
                t0b, t1b = _parse_bounds_attr(g.attrs.get('baseline_bounds', '{"t0":0,"t1":0}'))
                t0a, t1a = _parse_bounds_attr(g.attrs.get('analysis_bounds', '{"t0":0,"t1":0}'))
                # Channel number is read from the two characters after 'ch'.
                chs = sorted(int(k[2:4]) for k in g.keys() if k.startswith('ch') and k.endswith('_time'))
                for ch in chs:
                    t = g.get(f'ch{ch:02d}_time', None)
                    y = g.get(f'ch{ch:02d}_filtered', None)
                    if t is None or y is None:
                        continue
                    t = t[:]; y = y[:]
                    if t.size == 0 or y.size == 0:
                        continue
                    st_pre, st_post = _detect_prepost(t, y, sr, t0b, t1b, t0a, t1a, dcfg)
                    rows.append({'side': side, 'channel': ch, 'period': 'pre',  'spike_times_s': st_pre})
                    rows.append({'side': side, 'channel': ch, 'period': 'post', 'spike_times_s': st_post})
        return pd.DataFrame(rows, columns=['side','channel','period','spike_times_s'])

    def compute_isi_and_bursts(self, pairs_df: pd.DataFrame, isi_thr_ms: float = 100.0, min_spikes: int = 3):
        """Build long-format ISI and burst tables for the given pairs.

        Args:
            pairs_df: Frame as returned by ``discover_pairs`` (reads ``h5_path``,
                ``pair_id``, ``plate`` and ``round``).
            isi_thr_ms: Largest inter-spike interval that still links two spikes
                into the same burst.
            min_spikes: Smallest number of spikes that counts as a burst.

        Returns:
            ``(isi_df, burst_df)``. ``isi_df`` has one row per interval
            (``isi_ms``); ``burst_df`` has one row per burst (``burst_dur_ms``,
            ``n_spikes``). Both carry the pair/plate/round/side/channel/period
            identifiers and keep their columns even when empty.
        """
        id_cols = ['pair_id','plate','round','side','channel','period']
        isi_rows, burst_rows = [], []
        for _, pair in pairs_df.iterrows():
            spikes_df = self.spikes_prepost_for_pair(Path(pair['h5_path']))
            for _, row in spikes_df.iterrows():
                ident = {'pair_id': pair['pair_id'], 'plate': pair['plate'], 'round': pair['round'],
                         'side': row['side'], 'channel': row['channel'], 'period': row['period']}
                st = np.asarray(row['spike_times_s'], dtype=float)
                if st.size >= 2:
                    isi_rows.extend({**ident, 'isi_ms': float(v)} for v in np.diff(st) * 1000.0)
                bursts = _bursts_from_spikes(st, isi_thr_ms=isi_thr_ms, min_spikes=min_spikes)
                burst_rows.extend({**ident, 'burst_dur_ms': float(dur_ms), 'n_spikes': int(nsp)}
                                  for _, _, dur_ms, nsp in bursts)
        isi_df = pd.DataFrame(isi_rows, columns=id_cols + ['isi_ms'])
        burst_df = pd.DataFrame(burst_rows, columns=id_cols + ['burst_dur_ms','n_spikes'])
        return isi_df, burst_df

    # Plot helpers
    def _plot_by_side_period(self, sub: pd.DataFrame, value_col: str, title: str,
                             ylabel: str, logy: bool) -> None:
        """Draw one CTZ/VEH box plot of ``value_col`` split by pre/post period."""
        plt.figure(figsize=(7,4))
        ax = plt.gca()
        if sns is not None:
            sns.boxplot(data=sub, x='side', y=value_col, hue='period', ax=ax)
            n_box_entries = len(ax.get_legend_handles_labels()[0])
            sns.stripplot(data=sub, x='side', y=value_col, hue='period', dodge=True,
                          color='k', size=2, alpha=0.3, ax=ax)
            # The strip layer repeats the hue entries; keep only the box ones.
            handles, labels = ax.get_legend_handles_labels()
            if n_box_entries:
                ax.legend(handles[:n_box_entries], labels[:n_box_entries], title='period')
        else:
            for i, side in enumerate(['CTZ','VEH']):
                grp = [sub[(sub.side==side)&(sub.period=='pre')][value_col].dropna(),
                       sub[(sub.side==side)&(sub.period=='post')][value_col].dropna()]
                ax.boxplot(grp, positions=[i*3+1,i*3+2], labels=[f'{side}-pre', f'{side}-post'])
        ax.set_title(title); ax.set_ylabel(ylabel); ax.set_xlabel('')
        if logy:
            ax.set_yscale('log')
        plt.show()

    def plot_isi_pair(self, isi_df: pd.DataFrame, pair_id: str, logy: bool = True):
        """Box plot of ISIs (ms) for one pair, CTZ vs VEH, pre vs post."""
        sub = isi_df[isi_df['pair_id']==pair_id].copy()
        if sub.empty:
            print('No ISI data for', pair_id); return
        self._plot_by_side_period(sub, 'isi_ms', f'ISI (ms) — {pair_id}', 'ISI (ms)', logy)

    def plot_isi_all(self, isi_df: pd.DataFrame, logy: bool = True):
        """Box plot of ISIs (ms) pooled over every pair in ``isi_df``."""
        sub = isi_df.copy()
        if sub.empty:
            print('No ISI data.'); return
        self._plot_by_side_period(sub, 'isi_ms', 'ISI (ms) — Combined Across Pairs', 'ISI (ms)', logy)

    def plot_burst_pair(self, burst_df: pd.DataFrame, pair_id: str, logy: bool = False):
        """Box plot of burst durations (ms) for one pair, CTZ vs VEH, pre vs post."""
        sub = burst_df[burst_df['pair_id']==pair_id].copy()
        if sub.empty:
            print('No burst data for', pair_id); return
        self._plot_by_side_period(sub, 'burst_dur_ms', f'Burst Duration (ms) — {pair_id}',
                                  'Burst duration (ms)', logy)

    def plot_burst_all(self, burst_df: pd.DataFrame, logy: bool = False):
        """Box plot of burst durations (ms) pooled over every pair in ``burst_df``."""
        sub = burst_df.copy()
        if sub.empty:
            print('No burst data.'); return
        self._plot_by_side_period(sub, 'burst_dur_ms', 'Burst Duration (ms) — Combined Across Pairs',
                                  'Burst duration (ms)', logy)


In [None]:
# Compute ISI and burst metrics; plot per pair and combined.
# Relies on PairBurstISIAnalyzer being defined by an earlier cell and on the
# notebook-provided display() helper.
an = PairBurstISIAnalyzer()
pairs_df = an.discover_pairs()
print('Discovered pairs:', len(pairs_df))
display(pairs_df)

# Choose first 3 pairs; set sel = pairs_df to use all
sel = pairs_df.iloc[:3] if len(pairs_df) > 3 else pairs_df

# Tunables: burst definition
ISI_THR_MS = 100.0   # ISI threshold to link spikes into a burst
MIN_SPIKES = 3       # min spikes per burst

# isi_df holds one row per inter-spike interval, burst_df one row per burst.
isi_df, burst_df = an.compute_isi_and_bursts(sel, isi_thr_ms=ISI_THR_MS, min_spikes=MIN_SPIKES)
print('ISI rows:', len(isi_df), '| Burst rows:', len(burst_df))

# Per-pair plots (separate figures for each selected pair): ISI on a log axis,
# burst duration on a linear axis.
for pid in sel['pair_id']:
    an.plot_isi_pair(isi_df, pid, logy=True)
    an.plot_burst_pair(burst_df, pid, logy=False)

# Combined across selected pairs
an.plot_isi_all(isi_df, logy=True)
an.plot_burst_all(burst_df, logy=False)


In [None]:
# Post-only CTZ vs VEH: FR (channels), ISI, and Burst Duration
from pathlib import Path
import numpy as np, pandas as pd, h5py, matplotlib.pyplot as plt

try:
    import seaborn as sns
except Exception:
    sns = None

# Resolve exports directory: prefer the configured output root, then a local
# '_mcs_mea_outputs_local' folder next to the repo, and finally (when anything
# in the lookup fails) the hard-coded external-drive path.
try:
    from mcs_mea_analysis.config import CONFIG
    # Walk upwards from the working directory until the package folder shows up.
    REPO_ROOT = Path.cwd()
    for _candidate in (Path.cwd(), *Path.cwd().parents):
        if (_candidate/'mcs_mea_analysis').exists():
            REPO_ROOT = _candidate
            break
    if CONFIG.output_root.exists():
        OUTPUT_ROOT = CONFIG.output_root
    else:
        OUTPUT_ROOT = REPO_ROOT / '_mcs_mea_outputs_local'
    EXPORTS_DIR = OUTPUT_ROOT / 'exports' / 'spikes_waveforms'
except Exception:
    EXPORTS_DIR = Path('/Volumes/Manny2TB/mcs_mea_outputs/exports/spikes_waveforms')

print('Exports dir ->', EXPORTS_DIR)

# 1) Post FR (channel-level): gathered from every per-pair *_summary.csv found
#    below EXPORTS_DIR. A file that cannot be read or lacks a needed column is
#    reported and skipped.
rows = []
for csvp in EXPORTS_DIR.rglob('*_summary.csv'):
    try:
        summary = pd.read_csv(csvp)
        summary['side'] = summary['side'].astype(str)
        # The pair id is the file stem without the '_summary' suffix.
        summary['pair_id'] = csvp.stem.replace('_summary','')
        rows.append(summary[['pair_id','side','channel','fr_hz']])
    except Exception as e:
        print('Skip', csvp.name, ':', e)

if rows:
    fr_post = pd.concat(rows, ignore_index=True)
else:
    fr_post = pd.DataFrame(columns=['pair_id','side','channel','fr_hz'])
print('FR (post) rows:', len(fr_post), 'from', len(rows), 'pairs')

if not fr_post.empty:
    plt.figure(figsize=(6,4))
    if sns is None:
        # Plain matplotlib fallback: one box per side.
        groups = [fr_post[fr_post.side==side_name]['fr_hz'].dropna() for side_name in ('CTZ','VEH')]
        plt.boxplot(groups, labels=['CTZ','VEH'])
    else:
        sns.boxplot(data=fr_post, x='side', y='fr_hz')
        sns.stripplot(data=fr_post, x='side', y='fr_hz', color='k', size=2, alpha=0.3)
    plt.title('Post FR (Hz) — Channels — CTZ vs VEH')
    plt.ylabel('FR (Hz)')
    plt.xlabel('')
    plt.show()

# 2) ISI and 3) Burst duration (post) from exported post spike timestamps in the H5
ISI_THR_MS = 100.0   # largest ISI (ms) that still links two spikes into one burst
MIN_SPIKES = 3       # fewest spikes that count as a burst

def bursts_from(spikes_s: np.ndarray, thr_ms: float, min_sp: int):
    """Return bursts as ``(start_idx, end_idx, dur_ms, n_spikes)`` tuples.

    Consecutive spikes are linked when their interval is at most ``thr_ms``;
    chains with at least ``min_sp`` spikes are reported. Indices are inclusive
    positions in ``spikes_s``.
    """
    times = np.asarray(spikes_s, float)
    if times.size < min_sp:
        return []
    linked = np.diff(times) <= thr_ms / 1000.0
    found = []
    run_start = None
    # A trailing False closes a run that reaches the last interval.
    for k, is_linked in enumerate(np.append(linked, False)):
        if is_linked:
            if run_start is None:
                run_start = k
            continue
        if run_start is None:
            continue
        # Intervals run_start..k-1 were linked, i.e. spikes run_start..k.
        first, last = run_start, k
        run_start = None
        n_sp = last - first + 1
        if n_sp >= min_sp:
            dur_ms = (times[last] - times[first]) * 1000.0
            found.append((first, last, float(dur_ms), int(n_sp)))
    return found

isi_rows, burst_rows = [], []
h5_files = [p for p in EXPORTS_DIR.rglob('*.h5') if not p.name.endswith('_summary.h5')]
for h5 in h5_files:
    pair_id = h5.stem
    with h5py.File(h5.as_posix(), 'r') as f:
        for side in ('CTZ','VEH'):
            if side not in f:
                continue
            g = f[side]
            # Post timestamps are already the analysis-window (post) spikes
            ts_keys = [k for k in g.keys() if k.startswith('ch') and k.endswith('_timestamps')]
            for ch in sorted(int(k[2:4]) for k in ts_keys):
                ds = f'{side}/ch{ch:02d}_timestamps'
                if ds not in f:
                    continue
                st = np.asarray(f[ds][:], float)
                ident = {'pair_id': pair_id, 'side': side, 'channel': ch}
                # ISI (ms): one row per interval between consecutive spikes
                if st.size >= 2:
                    isi_rows += [{**ident, 'isi_ms': float(v)} for v in np.diff(st) * 1000.0]
                # Burst durations (ms): one row per detected burst
                burst_rows += [{**ident, 'burst_dur_ms': float(dur_ms), 'n_spikes': int(nsp)}
                               for _, _, dur_ms, nsp in bursts_from(st, ISI_THR_MS, MIN_SPIKES)]

if isi_rows:
    isi_df = pd.DataFrame(isi_rows)
else:
    isi_df = pd.DataFrame(columns=['pair_id','side','channel','isi_ms'])
if burst_rows:
    burst_df = pd.DataFrame(burst_rows)
else:
    burst_df = pd.DataFrame(columns=['pair_id','side','channel','burst_dur_ms','n_spikes'])
print('ISI rows:', len(isi_df), '| Burst rows:', len(burst_df))

# ISI (post) and burst duration (post) — CTZ vs VEH.
# Both figures share one recipe; only the ISI figure uses a log y-axis.
for _frame, _col, _title, _ylabel, _log in (
    (isi_df, 'isi_ms', 'Post ISI (ms) — CTZ vs VEH', 'ISI (ms)', True),
    (burst_df, 'burst_dur_ms', 'Post Burst Duration (ms) — CTZ vs VEH', 'Burst duration (ms)', False),
):
    if _frame.empty:
        continue
    plt.figure(figsize=(7,4))
    if sns is not None:
        sns.boxplot(data=_frame, x='side', y=_col)
        sns.stripplot(data=_frame, x='side', y=_col, color='k', size=2, alpha=0.3)
    else:
        groups = [_frame[_frame.side==side_name][_col].dropna() for side_name in ('CTZ','VEH')]
        plt.boxplot(groups, labels=['CTZ','VEH'])
    if _log:
        plt.yscale('log')
    plt.title(_title)
    plt.ylabel(_ylabel)
    plt.xlabel('')
    plt.show()


In [None]:
# Object-based analyzer for baseline (pre-chem) FR only
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Tuple
import numpy as np
import pandas as pd
import h5py
import matplotlib.pyplot as plt

try:
    import seaborn as sns
except Exception:
    sns = None

try:
    from scipy import signal  # used for detection
except Exception as e:
    raise RuntimeError("This analyzer requires scipy installed.") from e

def _resolve_exports_dir() -> Path:
    try:
        from mcs_mea_analysis.config import CONFIG
        REPO_ROOT = next((p for p in [Path.cwd(), *Path.cwd().parents] if (p/'mcs_mea_analysis').exists()), Path.cwd())
        output_root = CONFIG.output_root if CONFIG.output_root.exists() else (REPO_ROOT / '_mcs_mea_outputs_local')
        return output_root / 'exports' / 'spikes_waveforms'
    except Exception:
        return Path('/Volumes/Manny2TB/mcs_mea_outputs/exports/spikes_waveforms')

def _parse_bounds_attr(attr_val) -> Tuple[float, float]:
    try:
        if isinstance(attr_val, (bytes, bytearray)):
            attr_val = attr_val.decode()
        d = json.loads(attr_val) if isinstance(attr_val, str) else attr_val
        return float(d.get('t0', 0.0)), float(d.get('t1', 0.0))
    except Exception:
        return 0.0, 0.0

def _noise_level(y: np.ndarray, method: str, pctl: float) -> float:
    if y.size == 0:
        return np.nan
    if method == 'mad':
        med = np.median(y)
        return 1.4826 * np.median(np.abs(y - med))
    if method == 'rms':
        return float(np.sqrt(np.mean(np.square(y))))
    if method == 'pctl':
        med = np.median(y)
        return float(np.percentile(np.abs(y - med), pctl))
    return np.nan

def _detect_prepost(t: np.ndarray, y: np.ndarray, sr_hz: float,
                    t0b: float, t1b: float, t0a: float, t1a: float, dcfg: dict):
    """Detect peak times in the baseline and analysis windows of one trace.

    The noise level is estimated from the baseline window only (through
    ``_noise_level``); the threshold ``K * noise`` is then used for both
    windows.

    Args:
        t: Sample times; used to build the window masks and to report peaks.
        y: Signal values, indexed with the same masks as ``t``.
        sr_hz: Sampling rate, used to convert millisecond settings to samples.
        t0b, t1b: Inclusive bounds of the baseline window.
        t0a, t1a: Inclusive bounds of the analysis window.
        dcfg: Detector settings. Keys read here: ``noise``, ``noise_percentile``,
            ``K``, ``refractory_ms``, ``min_width_ms`` and ``polarity``.

    Returns:
        ``(st_pre, st_post)``: arrays of peak times taken from ``t``. Both are
        empty when the baseline noise estimate is not finite or not positive.
    """
    # masks
    mb = (t >= t0b) & (t <= t1b)
    ma = (t >= t0a) & (t <= t1a)
    yb = y[mb]
    noise = _noise_level(yb, str(dcfg.get('noise','mad')), float(dcfg.get('noise_percentile',68.0)))
    if not np.isfinite(noise) or noise <= 0:
        # Must have the same arity as the normal return: callers unpack two
        # values, so returning four arrays here raised ValueError.
        return np.empty(0), np.empty(0)
    thr = float(dcfg.get('K', 5.0)) * noise
    # Millisecond settings become sample counts, never smaller than one sample.
    dist = max(1, int(round(float(dcfg.get('refractory_ms',1.0)) * 1e-3 * sr_hz)))
    minw = max(1, int(round(float(dcfg.get('min_width_ms',0.3)) * 1e-3 * sr_hz)))
    pol = str(dcfg.get('polarity','neg'))

    def _detect(mask):
        ta = t[mask]; ya = y[mask]
        if ta.size == 0:
            return np.empty(0)
        # NOTE(review): 'both' is handled exactly like 'neg' (only the negated
        # trace is searched) — confirm this matches the exporter's detector.
        arr = -ya if pol in ('neg','both') else ya
        peaks, _ = signal.find_peaks(arr, height=thr, distance=dist, width=minw)
        # Peak widths were previously computed here and then discarded by the
        # only callers; that unused work has been dropped.
        return ta[peaks] if peaks.size else np.empty(0)

    st_pre = _detect(mb)
    st_post = _detect(ma)
    return st_pre, st_post

@dataclass
class PairExportsAnalyzer:
    """Compute and plot baseline firing rates from per-pair HDF5 exports.

    Each export file is read for ``CTZ`` and/or ``VEH`` groups holding
    ``chNN_time`` / ``chNN_filtered`` datasets plus the ``sr_hz``,
    ``baseline_bounds`` and ``analysis_bounds`` attributes. An optional root
    dataset ``detect_config_json`` supplies the detector settings.
    """

    # Folder searched recursively for ``*.h5`` exports. NOTE: the default is
    # evaluated once, when the class body is executed.
    exports_dir: Path = _resolve_exports_dir()

    def discover_pairs(self) -> pd.DataFrame:
        """List every export file below ``exports_dir``.

        Files whose name ends in ``_summary.h5`` are ignored. The round is the
        name of the grand-parent folder; the plate number is parsed from a
        parent folder named ``plate_<int>`` and is ``None`` otherwise.

        Returns:
            A frame with columns ``pair_id``, ``h5_path``, ``round`` and
            ``plate`` sorted by plate, round and pair id. The frame is empty
            (but still has those columns) when nothing is found.
        """
        columns = ['pair_id', 'h5_path', 'round', 'plate']
        pairs = []
        for h5 in sorted(self.exports_dir.rglob('*.h5')):
            if h5.name.endswith('_summary.h5'):
                continue
            plate = None
            parent_name = h5.parent.name
            if parent_name.startswith('plate_'):
                try:
                    plate = int(parent_name.replace('plate_', ''))
                except ValueError:
                    plate = None  # e.g. 'plate_A': keep the pair, plate unknown
            round_name = h5.parents[1].name if len(h5.parents) > 1 else None
            pairs.append({'pair_id': h5.stem, 'h5_path': str(h5), 'round': round_name, 'plate': plate})
        # Passing the columns explicitly keeps sort_values from raising KeyError
        # when no export was discovered.
        found = pd.DataFrame(pairs, columns=columns)
        return found.sort_values(['plate','round','pair_id']).reset_index(drop=True)

    def fr_prepost_for_pair(self, h5_path: Path) -> pd.DataFrame:
        """Compute per-channel firing rates before and after treatment for one file.

        The rate is the number of detected spikes divided by the window length
        taken from ``baseline_bounds`` ("pre") or ``analysis_bounds`` ("post").
        A window whose length is not positive yields NaN.

        Args:
            h5_path: Location of the pair's HDF5 export.

        Returns:
            One row per side/channel/period with columns ``side``, ``channel``,
            ``period`` (``'pre'`` or ``'post'``) and ``fr_hz``.
        """
        rows = []
        with h5py.File(Path(h5_path).as_posix(),'r') as f:
            dcfg = {}
            if 'detect_config_json' in f:
                try:
                    dcfg = json.loads(f['detect_config_json'][()].decode())
                except Exception:
                    # Best effort: an unreadable config falls back to the
                    # detector defaults instead of aborting the pair.
                    dcfg = {}
                if not isinstance(dcfg, dict):
                    dcfg = {}
            for side in ('CTZ','VEH'):
                if side not in f:
                    continue
                g = f[side]
                # A missing or zero sampling rate falls back to 10 kHz.
                sr = float(g.attrs.get('sr_hz', 0.0)) or 10000.0
                t0b, t1b = _parse_bounds_attr(g.attrs.get('baseline_bounds', '{"t0":0,"t1":0}'))
                t0a, t1a = _parse_bounds_attr(g.attrs.get('analysis_bounds', '{"t0":0,"t1":0}'))
                # Previously the durations were clamped to 1e-9, which made the
                # NaN branch unreachable and turned a missing or degenerate
                # window into a numeric rate.
                dur_pre = t1b - t0b
                dur_post = t1a - t0a
                # Channel number is read from the two characters after 'ch'.
                chs = sorted(int(k[2:4]) for k in g.keys() if k.startswith('ch') and k.endswith('_time'))
                for ch in chs:
                    t = g.get(f'ch{ch:02d}_time', None)
                    y = g.get(f'ch{ch:02d}_filtered', None)
                    if t is None or y is None:
                        continue
                    t = t[:]; y = y[:]
                    if t.size == 0 or y.size == 0:
                        continue
                    st_pre, st_post = _detect_prepost(t, y, sr, t0b, t1b, t0a, t1a, dcfg)
                    fr_pre = (st_pre.size / dur_pre) if dur_pre > 0 else np.nan
                    fr_post = (st_post.size / dur_post) if dur_post > 0 else np.nan
                    rows.append({'side': side, 'channel': ch, 'period': 'pre', 'fr_hz': fr_pre})
                    rows.append({'side': side, 'channel': ch, 'period': 'post', 'fr_hz': fr_post})
        return pd.DataFrame(rows, columns=['side','channel','period','fr_hz'])

    def compute_all(self, pairs_df: pd.DataFrame) -> pd.DataFrame:
        """Stack the firing-rate tables of every pair listed in ``pairs_df``.

        Args:
            pairs_df: Frame as returned by ``discover_pairs`` (reads ``h5_path``,
                ``pair_id``, ``plate`` and ``round``).

        Returns:
            The concatenated per-pair tables with ``pair_id``, ``plate`` and
            ``round`` added; an empty frame with the expected columns when no
            pair produced data.
        """
        frames = []
        for _, r in pairs_df.iterrows():
            d = self.fr_prepost_for_pair(Path(r['h5_path']))
            if not d.empty:
                d['pair_id'] = r['pair_id']; d['plate'] = r['plate']; d['round'] = r['round']
                frames.append(d)
        cols = ['pair_id','plate','round','side','channel','period','fr_hz']
        return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame(columns=cols)

    def _box_by_side(self, data: pd.DataFrame, value_col: str, title: str, ylabel: str,
                     point_size: float = 2, point_alpha: float = 0.3) -> None:
        """Draw one CTZ-vs-VEH box plot of ``value_col`` with the points overlaid."""
        plt.figure(figsize=(6,4))
        if sns is not None:
            sns.boxplot(data=data, x='side', y=value_col)
            sns.stripplot(data=data, x='side', y=value_col, color='k', size=point_size, alpha=point_alpha)
        else:
            groups = [data[data['side']=='CTZ'][value_col].dropna(),
                      data[data['side']=='VEH'][value_col].dropna()]
            plt.boxplot(groups, labels=['CTZ','VEH'])
        plt.title(title)
        plt.ylabel(ylabel); plt.xlabel('')
        plt.show()

    def plot_baseline_pair(self, df: pd.DataFrame, pair_id: str) -> None:
        """Box plot of baseline ("pre") channel firing rates for one pair."""
        sub = df[(df['pair_id']==pair_id) & (df['period']=='pre')].copy()
        if sub.empty:
            print('No baseline data for', pair_id); return
        self._box_by_side(sub, 'fr_hz', f'Baseline FR (Hz) — {pair_id}', 'FR (Hz)')

    def plot_baseline_all(self, df: pd.DataFrame) -> None:
        """Box plot of baseline channel firing rates pooled over every pair in ``df``."""
        sub = df[df['period']=='pre'].copy()
        if sub.empty:
            print('No baseline data.'); return
        self._box_by_side(sub, 'fr_hz', 'Baseline FR (Hz) — Combined Across Pairs', 'FR (Hz)')

    def plot_baseline_pair_means(self, df: pd.DataFrame) -> None:
        """Box plot of baseline firing rates averaged per pair and side."""
        sub = df[df['period']=='pre'].copy()
        if sub.empty:
            print('No baseline data.'); return
        means = (sub.groupby(['pair_id','side'], as_index=False)['fr_hz']
                   .mean().rename(columns={'fr_hz':'mean_fr_hz'}))
        self._box_by_side(means, 'mean_fr_hz', 'Baseline FR (Hz) — Per-Pair Means', 'Mean FR (Hz)',
                          point_size=3, point_alpha=0.5)


In [None]:
# Use the analyzer — baseline-only plots, by pair and combined.
# Relies on PairExportsAnalyzer being defined by an earlier cell and on the
# notebook-provided display() helper.
an = PairExportsAnalyzer()
pairs_df = an.discover_pairs()
print('Discovered pairs:', len(pairs_df))
display(pairs_df)

# Limit to first 3 if you want; or use all by setting sel = pairs_df
sel = pairs_df.iloc[:3] if len(pairs_df) > 3 else pairs_df

# fr_df holds one row per pair/side/channel/period with the firing rate in Hz.
fr_df = an.compute_all(sel)
print('Computed rows:', len(fr_df))

# Combined baseline CTZ vs VEH across selected pairs
an.plot_baseline_all(fr_df)

# Baseline per pair (CTZ vs VEH), one figure per pair
for pid in sel['pair_id']:
    an.plot_baseline_pair(fr_df, pid)

# Baseline per-pair means (one mean per side per pair)
an.plot_baseline_pair_means(fr_df)
