# MEA Exports Analysis — Spikes + Waveforms

Analyze and aggregate the per-pair exports produced by the Pair Viewer or the batch exporter.

- Exports location: `<output_root>/exports/spikes_waveforms/...`
- Format details: see `docs/exports_spikes_waveforms.md` in this repo (variable types, dataset names, and attributes).
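- This notebook auto-discovers pairs, loads HDF5/CSV per pair, and can aggregate per plate. The summary sections near the top use `df_pairs` and `aggregate_plate`, which are defined by the Setup, discovery, and helper cells further down — run those first (only the FR Box Plots cell is standalone).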

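Tip: If the external drive (`CONFIG.output_root`) is not mounted, the Setup cell automatically falls back to `<repo>/_mcs_mea_outputs_local`. To point somewhere else, set `OUTPUT_ROOT` and `EXPORTS_DIR` by hand.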

## FR Box Plots
Channel-level and per-pair mean firing rate distributions comparing CTZ vs VEH across all exported pairs.
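This cell scans the exports directory for per-pair `*_summary.csv` files on its own and does not depend on variables from other cells; it reuses `EXPORTS_DIR` if that is already set.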

In [None]:
# Self-contained FR box plots (no dependency on df_pairs)
import sys
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

try:
    import seaborn as sns
except Exception:
    sns = None

# Treatment groups in plotting order (shared by the seaborn and matplotlib paths).
_FR_SIDES = ['CTZ', 'VEH']

# Resolve exports directory; reuse EXPORTS_DIR if another cell already set it.
try:
    EXPORTS_DIR
except NameError:
    # Put the repo root on sys.path before importing the package so the import
    # also works when the notebook is started from a subdirectory.
    REPO_ROOT = next((p for p in [Path.cwd(), *Path.cwd().parents] if (p / 'mcs_mea_analysis').exists()), Path.cwd())
    if str(REPO_ROOT) not in sys.path:
        sys.path.insert(0, str(REPO_ROOT))
    try:
        from mcs_mea_analysis.config import CONFIG
        OUTPUT_ROOT = CONFIG.output_root if CONFIG.output_root.exists() else (REPO_ROOT / '_mcs_mea_outputs_local')
        EXPORTS_DIR = OUTPUT_ROOT / 'exports' / 'spikes_waveforms'
    except Exception:
        EXPORTS_DIR = Path('/Volumes/Manny2TB/mcs_mea_outputs/exports/spikes_waveforms')
print('Exports dir ->', EXPORTS_DIR)

# Load all per-pair summary CSVs
rows = []
for csvp in sorted(EXPORTS_DIR.rglob('*_summary.csv')):
    # The aggregation cells write *_aggregate_summary.csv files into the same tree;
    # those have a different schema and are not per-pair summaries.
    if csvp.name.endswith('_aggregate_summary.csv'):
        continue
    # Path layout: .../<round>/plate_<N>/<pair>_summary.csv
    round_name = csvp.parents[1].name if len(csvp.parents) > 1 else None
    plate = None
    plate_dir = csvp.parent.name
    if plate_dir.startswith('plate_'):
        try:
            plate = int(plate_dir[len('plate_'):])
        except ValueError:
            pass  # non-numeric plate folder: plate stays unknown
    # Drop only the trailing '_summary' (guaranteed by the glob); the stem may
    # contain the same token elsewhere.
    pair_id = csvp.stem[:-len('_summary')]
    try:
        df = pd.read_csv(csvp)
        # Coerce types; unparseable values become NaN / <NA> instead of failing.
        df['fr_hz'] = pd.to_numeric(df['fr_hz'], errors='coerce')
        df['channel'] = pd.to_numeric(df['channel'], errors='coerce').astype('Int64')
        df['side'] = df['side'].astype(str)
    except Exception as e:
        # Best-effort load: report and move on to the next file.
        print('Skip', csvp.name, ':', e)
        continue
    df['pair_id'] = pair_id
    df['plate'] = plate
    df['round'] = round_name
    rows.append(df)

data = pd.concat(rows, ignore_index=True) if rows else pd.DataFrame(columns=['channel','side','n_spikes','fr_hz','pair_id','plate','round'])
print('Rows loaded:', len(data), 'from', len(rows), 'pairs')


def _box_by_side(frame, y, title, ylabel, point_size, point_alpha):
    """Draw one CTZ-vs-VEH box plot of column ``y`` of ``frame``.

    Uses seaborn (box plus jittered points) when it is installed, otherwise a
    plain matplotlib box plot. Groups are always drawn in ``_FR_SIDES`` order.
    """
    plt.figure(figsize=(6, 4))
    if sns is not None:
        sns.boxplot(data=frame, x='side', y=y, order=_FR_SIDES)
        sns.stripplot(data=frame, x='side', y=y, order=_FR_SIDES, color='k', size=point_size, alpha=point_alpha)
    else:
        groups = [frame.loc[frame['side'] == side, y].dropna() for side in _FR_SIDES]
        plt.boxplot(groups)
        # Label ticks separately: boxplot's `labels=` keyword was renamed
        # `tick_labels` in Matplotlib 3.9 and is deprecated.
        plt.xticks(range(1, len(_FR_SIDES) + 1), _FR_SIDES)
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel('')
    plt.show()


if data.empty:
    print('No summary CSVs found under', EXPORTS_DIR)
else:
    # Channel-level box plot (CTZ vs VEH)
    _box_by_side(data, 'fr_hz', 'Firing Rate (Hz) by Treatment — Channel Level', 'FR (Hz)', 2, 0.3)

    # Per-pair mean FR box plot
    mean_per_pair = (
        data.groupby(['pair_id','side'], as_index=False)['fr_hz'].mean()
        .rename(columns={'fr_hz':'mean_fr_hz'})
    )
    _box_by_side(mean_per_pair, 'mean_fr_hz', 'Mean Firing Rate (Hz) by Treatment — Per-Pair Means', 'Mean FR (Hz)', 3, 0.5)


## Pair-Level Summary
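Roll up the aggregated channel rows into one row per exported pair (CTZ vs VEH). Each row holds the mean FRs and the mean ΔFR across the pair's channels (optionally filtered to accepted channels).
This cell reads `aggregated_all`, which the **Aggregate All Pairs** cell below builds; run that cell first.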

In [None]:
# Pair-level rollup of `aggregated_all` (built by the "Aggregate All Pairs" cell).
if 'aggregated_all' not in globals() or aggregated_all.empty:
    print('Aggregate all plates first to build `aggregated_all`.')
else:
    # One output row per (plate, round, CTZ recording, VEH recording).
    group_cols = ['plate', 'round', 'ctz_stem', 'veh_stem']
    pair_summary = aggregated_all.groupby(group_cols).agg(
        n_channels=('channel', 'count'),          # non-null channel rows in the pair
        ctz_mean_hz=('CTZ', 'mean'),
        veh_mean_hz=('VEH', 'mean'),
        delta_fr_mean_hz=('delta_fr_hz', 'mean'),
    )
    pair_summary = pair_summary.reset_index().sort_values(group_cols)
    print('Pairs summarized:', len(pair_summary))
    display(pair_summary)


## Aggregate All Pairs
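Combine every exported pair across all plates into one DataFrame; optionally save it as a single CSV. This needs `df_pairs` and `aggregate_plate`, which are defined by the discovery and per-plate aggregate cells further down.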

In [None]:
# Aggregate all plates found in df_pairs.
# This cell sits above the discovery/helper cells in the notebook, so check that
# they have been run instead of failing with a NameError.
_agg_columns = ['plate','round','ctz_stem','veh_stem','channel','CTZ','VEH','delta_fr_hz']
if 'df_pairs' not in globals() or 'aggregate_plate' not in globals():
    print('Run the pair-discovery and aggregate_plate cells (further below) first.')
    aggregated_all = pd.DataFrame(columns=_agg_columns)
elif df_pairs.empty:
    print('No pairs found. Run exports first.')
    aggregated_all = pd.DataFrame(columns=_agg_columns)
else:
    plates = sorted(df_pairs['plate'].dropna().astype(int).unique().tolist())
    frames = []
    for p in plates:
        g = aggregate_plate(int(p), use_selections=True)
        if not g.empty:
            frames.append(g)
    aggregated_all = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame(columns=_agg_columns)
    print('Aggregated rows (all plates):', len(aggregated_all))
    display(aggregated_all.head(20))

# Quick rollups
if not aggregated_all.empty:
    print('\nDelta FR mean per plate:')
    display(aggregated_all.groupby('plate')['delta_fr_hz'].mean().to_frame('delta_fr_mean_hz'))
    print('\nCounts per plate:')
    display(aggregated_all.groupby('plate').size().to_frame('n_rows'))


In [None]:
# Save a single combined CSV (all plates) next to the per-pair exports.
if 'aggregated_all' not in globals() or aggregated_all.empty:
    print('Nothing to save (no aggregated data).')
else:
    out_csv_all = EXPORTS_DIR.joinpath('all_plates_aggregate_summary.csv')
    aggregated_all.to_csv(out_csv_all, index=False)
    print('Wrote ->', out_csv_all)


## Pairs Overview
Quick ways to see how many pairs were discovered and which recordings they are.

- `df_pairs` is the discovered-pairs index (one row per exported pair), built by the discovery cell that follows Setup.
- Use the cell below to summarize counts and list pairs.

In [None]:
# Summary counts and listing of pairs (needs df_pairs from the discovery cell).
if 'df_pairs' not in globals() or df_pairs.empty:
    print('No pairs found. Run GUI export or batch exporter, then re-run discovery.')
else:
    print('Total pairs:', len(df_pairs))
    print('\nPairs per plate:')
    display(df_pairs.groupby('plate').size().to_frame('n_pairs').reset_index().sort_values('plate'))
    print('\nPairs per plate and round:')
    display(df_pairs.groupby(['plate','round']).size().to_frame('n_pairs').reset_index().sort_values(['plate','round']))
    print('\nPair list (plate, round, stems):')
    cols = ['plate','round','ctz_stem','veh_stem']
    display(df_pairs[cols].sort_values(cols).reset_index(drop=True))
    # Example filter: set a plate number you care about
    # plate_focus = 5
    # display(df_pairs[df_pairs['plate']==plate_focus][cols].sort_values(cols))


In [None]:
# Setup
import sys, json
from pathlib import Path
import numpy as np
import pandas as pd
import h5py

# Enable Qt event-loop integration when running under IPython with a Qt binding;
# otherwise report it and continue. This is done via run_line_magic because IPython
# passes a trailing "# comment" on a `%gui` line to the magic as extra arguments.
try:
    get_ipython().run_line_magic('gui', 'qt5')
except Exception as exc:  # not under IPython, no Qt binding, or GUI setup failed
    print('Qt GUI integration not enabled:', exc)

# Ensure repo root on path
def _ensure_repo_on_path():
    """Find the nearest ancestor of cwd containing `mcs_mea_analysis`, add it to sys.path, return it (cwd if none)."""
    here = Path.cwd()
    for cand in [here, *here.parents]:
        if (cand / 'mcs_mea_analysis').exists():
            if str(cand) not in sys.path:
                sys.path.insert(0, str(cand))
            return cand
    return here
REPO_ROOT = _ensure_repo_on_path()

from mcs_mea_analysis.config import CONFIG

# Pick output root: external drive if present, else local mirror
OUTPUT_ROOT = CONFIG.output_root if CONFIG.output_root.exists() else (REPO_ROOT / '_mcs_mea_outputs_local')
EXPORTS_DIR = OUTPUT_ROOT / 'exports' / 'spikes_waveforms'
print('Using OUTPUT_ROOT ->', OUTPUT_ROOT)
print('Exports dir ->', EXPORTS_DIR)

In [None]:
# Discover exported pairs (HDF5 + CSV summary)
# Fixed schema so df_pairs keeps its columns (and later cells keep working)
# even when no exports are found.
_PAIR_COLUMNS = ['round', 'plate', 'ctz_stem', 'veh_stem', 'h5_path', 'csv_summary']
pairs = []
for h5 in sorted(EXPORTS_DIR.rglob('*.h5')):
    if h5.name.endswith('_summary.h5'):
        continue
    # Path structure: .../<round>/plate_<N>/<CTZ>__VS__<VEH>.h5
    # (immediate parent is plate_*, next is round)
    round_name = h5.parents[1].name if len(h5.parents) > 1 else None
    plate = None
    if h5.parent.name.startswith('plate_'):
        try:
            plate = int(h5.parent.name[len('plate_'):])
        except ValueError:
            pass  # non-numeric plate folder: plate stays unknown, round is kept
    # Without the '__VS__' separator the whole stem is treated as the CTZ stem.
    ctz_stem, _, veh_stem = h5.stem.partition('__VS__')
    csv_sum = h5.with_name(h5.stem + '_summary.csv')
    pairs.append({
        'round': round_name, 'plate': plate,
        'ctz_stem': ctz_stem, 'veh_stem': veh_stem,
        # '' marks "no summary CSV next to this HDF5"
        'h5_path': str(h5), 'csv_summary': str(csv_sum) if csv_sum.exists() else ''
    })

df_pairs = (
    pd.DataFrame(pairs, columns=_PAIR_COLUMNS)
    .sort_values(['plate', 'round', 'ctz_stem'])
    .reset_index(drop=True)
)
print('Found pairs:', len(df_pairs))
df_pairs.head(3)

In [None]:
# Helpers to read exported HDF5
from typing import Optional, Tuple, Dict, Any

def read_pair_attrs(h5_path: Path) -> Dict[str, Any]:
    """Collect metadata from one exported pair HDF5 file.

    Returns a flat dict containing:
    - every root attribute, as stored;
    - ``filter_config`` / ``detect_config``: the ``filter_config_json`` /
      ``detect_config_json`` datasets parsed as JSON, or None when the dataset is
      missing or cannot be read/parsed;
    - for each of the ``CTZ`` / ``VEH`` groups present: ``<side>_sr_hz`` (float,
      0.0 if the attribute is absent), ``<side>_baseline_bounds`` and
      ``<side>_analysis_bounds`` ('' if absent).
    """
    out: Dict[str, Any] = {}
    with h5py.File(h5_path.as_posix(), 'r') as f:
        # Root attrs
        out.update(f.attrs.items())

        # Config JSON datasets
        def _json_of(name: str):
            if name not in f:
                return None
            try:
                raw = f[name][()]
                # Scalar string datasets may come back as bytes (np.bytes_ is a
                # bytes subclass) or as str depending on how they were written;
                # accept both instead of failing on `.decode()`.
                if isinstance(raw, bytes):
                    raw = raw.decode()
                return json.loads(raw)
            except Exception:
                # Best-effort: an unreadable or invalid config is reported as None.
                return None

        out['filter_config'] = _json_of('filter_config_json')
        out['detect_config'] = _json_of('detect_config_json')
        # Group-level attrs
        for side in ('CTZ','VEH'):
            if side in f:
                g = f[side]
                out[f'{side.lower()}_sr_hz'] = float(g.attrs.get('sr_hz', 0.0))
                out[f'{side.lower()}_baseline_bounds'] = g.attrs.get('baseline_bounds', '')
                out[f'{side.lower()}_analysis_bounds'] = g.attrs.get('analysis_bounds', '')
    return out

def load_channel(h5_path: Path, side: str, ch: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Read the exported arrays for a single channel.

    ``side`` names the group to read ('CTZ' or 'VEH'); ``ch`` is formatted as two
    digits in the dataset names (``chNN_time``, ``chNN_raw``, ...).

    Returns ``(time, raw, filtered, timestamps, waveforms)``. A missing dataset
    comes back empty: shape ``(0,)`` for the first four, ``(0, 0)`` for waveforms.
    Accessing a ``side`` group that is not in the file raises (h5py KeyError).
    """
    prefix = f'ch{ch:02d}_'
    # (dataset suffix, shape of the empty placeholder when it is absent)
    layout = (
        ('time', (0,)),
        ('raw', (0,)),
        ('filtered', (0,)),
        ('timestamps', (0,)),
        ('waveforms', (0, 0)),
    )
    with h5py.File(h5_path.as_posix(), 'r') as f:
        grp = f[side]
        arrays = tuple(
            grp[prefix + suffix][:] if prefix + suffix in grp else np.empty(empty_shape)
            for suffix, empty_shape in layout
        )
    return arrays

def n_channels(h5_path: Path, side: str='CTZ') -> int:
    """Count the channels stored for ``side``: the number of ``ch*_time`` datasets in its group."""
    with h5py.File(h5_path.as_posix(), 'r') as f:
        names = f[side].keys()
        return sum(1 for name in names if name.startswith('ch') and name.endswith('_time'))

def read_summary_csv(csv_path: Path) -> pd.DataFrame:
    """Load a per-pair summary CSV with typed columns.

    ``channel`` and ``n_spikes`` are cast to int, ``side`` to str, ``fr_hz`` to float;
    the int casts raise if those columns contain missing values.
    When ``csv_path`` is not an existing regular file, returns an empty frame with
    columns ``channel, side, n_spikes, fr_hz``. ``csv_path`` may also be a str.
    """
    csv_path = Path(csv_path)
    # is_file() rather than exists(): an empty csv_summary entry becomes
    # Path('') == Path('.'), which "exists" as a directory and breaks read_csv.
    if not csv_path.is_file():
        return pd.DataFrame(columns=['channel','side','n_spikes','fr_hz'])
    df = pd.read_csv(csv_path)
    df['channel'] = df['channel'].astype(int)
    df['side'] = df['side'].astype(str)
    df['n_spikes'] = df['n_spikes'].astype(int)
    df['fr_hz'] = df['fr_hz'].astype(float)
    return df

def _selections_from(payload: Dict[str, Any]) -> Dict[int, str]:
    """Convert a selections JSON payload's ``selections`` mapping into {channel: verdict}."""
    return {int(k): str(v) for k, v in (payload.get('selections') or {}).items()}


def load_selections_if_any(output_root: Path, plate: Optional[int], ctz_stem: str, veh_stem: str) -> Dict[int, str]:
    """Return per-channel selections ({channel: verdict}) for a pair, or {} if none are found.

    Looks in ``<output_root>/selections``:
    1. the file named
       ``plate_<N|NA>__<ctz>_ifr_per_channel_1ms__<veh>_ifr_per_channel_1ms.json``;
    2. otherwise, scans ``*.json`` files in sorted order and uses the first one
       whose ``ctz_npz`` / ``veh_npz`` values contain ``ctz_stem`` / ``veh_stem``
       (substring match).
    Unreadable or malformed JSON files are skipped.
    """
    sel_dir = Path(output_root) / 'selections'
    if not sel_dir.is_dir():
        return {}
    plate_tag = 'NA' if plate is None else plate  # plate 0 is a real plate, not "NA"
    exact = sel_dir / f"plate_{plate_tag}__{ctz_stem}_ifr_per_channel_1ms__{veh_stem}_ifr_per_channel_1ms.json"
    if exact.is_file():
        try:
            return _selections_from(json.loads(exact.read_text()))
        except Exception:
            pass  # malformed exact-name file: fall back to scanning by stems
    # Fallback: scan selections dir for matching stems if the naming pattern differs.
    # sorted() makes the "first match" deterministic across filesystems.
    for p in sorted(sel_dir.glob('*.json')):
        try:
            data = json.loads(p.read_text())
            if (ctz_stem in str(data.get('ctz_npz',''))) and (veh_stem in str(data.get('veh_npz',''))):
                return _selections_from(data)
        except Exception:
            continue
    return {}


In [None]:
# Inspect one pair: print its metadata and the array shapes of CTZ channel 0.
if not len(df_pairs):
    print('No exports found. Run GUI export or batch exporter first.')
else:
    i = 0  # change index to pick a pair
    r = df_pairs.iloc[i]
    h5p = Path(r['h5_path'])
    attrs = read_pair_attrs(h5p)
    print('Pair:', r['plate'], r['round'], r['ctz_stem'], 'VS', r['veh_stem'])
    print('Attrs keys:', sorted(attrs))
    print('CTZ sr_hz:', attrs.get('ctz_sr_hz'), '| VEH sr_hz:', attrs.get('veh_sr_hz'))
    # Channel 0 of the CTZ side; datasets that are missing load as empty arrays.
    t, raw, fil, ts, wf = load_channel(h5p, 'CTZ', 0)
    print('CTZ ch00 shapes: t/raw/fil:', t.shape, raw.shape, fil.shape, '| spikes:', ts.shape, '| wf:', wf.shape)


In [None]:
# Build per-plate aggregate summary (joins CTZ/VEH FR per channel; optional selections)
def aggregate_plate(plate: int, use_selections: bool = True) -> pd.DataFrame:
    """Join CTZ and VEH firing rates per channel for every exported pair on one plate.

    For each pair in ``df_pairs`` with this plate number, reads its summary CSV
    (pairs without one are skipped), pivots it to one row per channel with
    ``CTZ`` / ``VEH`` FR columns, and adds ``delta_fr_hz = CTZ - VEH`` plus the
    pair's plate/round/stems. If ``use_selections`` is True and a selections file
    with entries exists for the pair, only channels marked 'accept'
    (case-insensitive) are kept, and an ``accepted`` column is added.

    Reads the notebook globals ``df_pairs`` and ``OUTPUT_ROOT``.
    Returns an empty frame with the standard aggregate columns when nothing matches.
    """
    rows = []
    dpf = df_pairs[df_pairs['plate'] == plate]
    for _, r in dpf.iterrows():
        # csv_summary is '' when discovery found no CSV; Path('') would mean '.'.
        if not r['csv_summary']:
            continue
        df = read_summary_csv(Path(r['csv_summary']))
        if df.empty:
            continue
        # Optional: apply selections (no selections file -> keep every channel)
        accepted = None
        if use_selections:
            sel = load_selections_if_any(OUTPUT_ROOT, plate, r['ctz_stem'], r['veh_stem'])
            accepted = {ch for ch, v in sel.items() if str(v).lower() == 'accept'} if sel else None
        # pivot to have CTZ and VEH FR columns per channel
        piv = df.pivot_table(index='channel', columns='side', values='fr_hz', aggfunc='first').reset_index()
        piv['plate'] = plate
        piv['round'] = r['round']
        piv['ctz_stem'] = r['ctz_stem']
        piv['veh_stem'] = r['veh_stem']
        # A side with no (non-NaN) rows is absent from the pivot; add it as NaN.
        if 'CTZ' not in piv.columns: piv['CTZ'] = np.nan
        if 'VEH' not in piv.columns: piv['VEH'] = np.nan
        piv['delta_fr_hz'] = piv['CTZ'] - piv['VEH']
        if accepted is not None:
            piv['accepted'] = piv['channel'].isin(accepted)
            piv = piv[piv['accepted']].copy()
        rows.append(piv)
    if not rows:
        return pd.DataFrame(columns=['plate','round','ctz_stem','veh_stem','channel','CTZ','VEH','delta_fr_hz'])
    return pd.concat(rows, ignore_index=True)

# Example: aggregate one plate (the first plate number found in df_pairs).
if not (len(df_pairs) and df_pairs['plate'].notna().any()):
    print('No plate numbers found in exports.')
else:
    plate_example = int(df_pairs['plate'].dropna().iloc[0])
    agg = aggregate_plate(plate_example, use_selections=True)
    print('Aggregated rows:', len(agg))
    display(agg.head(10))


In [None]:
# Save per-plate aggregates to CSV for downstream analysis
def write_plate_aggregate_csv(plate: int, use_selections: bool = True) -> Path:
    """Aggregate one plate and save it as ``EXPORTS_DIR/plate_<N>/plate_<N>_aggregate_summary.csv``.

    Creates the directory if needed, prints the written path, and returns it.
    """
    table = aggregate_plate(plate, use_selections=use_selections)
    target_dir = EXPORTS_DIR.joinpath(f'plate_{plate}')
    target_dir.mkdir(parents=True, exist_ok=True)
    target = target_dir.joinpath(f'plate_{plate}_aggregate_summary.csv')
    table.to_csv(target, index=False)
    print('Wrote ->', target)
    return target

# Example (uncomment to run):
# write_plate_aggregate_csv(plate_example, use_selections=True)


## Notes
- HDF5 data types are float64; CSV fields are typed on load as indicated.
- The analysis window for FR is `post_s` (stored in HDF5 root attrs).
- To inspect waveforms programmatically, use `load_channel(...)[-1]` to get an `n_spikes × n_snippet` array.
- Use selections JSON (if available) to focus aggregation on accepted channels.
- You can extend aggregation to compute pre/post baselines, averages of waveforms, etc., using the same loaders.