# Bombcell Post-Run Analysis (Open Ephys + Kilosort4)

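This notebook assumes Bombcell has already been run and that per-probe CSV/JSON summaries have been exported. Each results folder is expected to contain an optional `batch_summary.csv` plus one `Probe_{X}/` subfolder per probe holding `Probe_{X}_quality_metrics.csv`, `Probe_{X}_unit_type_counts.csv`, `Probe_{X}_param.json` and `Probe_{X}_checks.json` (or an `ERROR.txt` if that probe failed).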

**Expected folder convention**
- `{NP_recording_name}/bombcell_DEFAULT/`
  - `DUPLICATED_KILOSORT4_FILES/`
  - `batch_DEFAULT_results/`
- `{NP_recording_name}/bombcell_NP2.0/`
  - `DUPLICATED_KILOSORT4_FILES_ACD/`
  - `NP2_ReRun_results/`

In [None]:
# =========================
# Configure
# =========================
from pathlib import Path
import pandas as pd
import numpy as np
import json

# Recording folder name as it appears directly under BASE_ROOT.
NP_recording_name = "Reach15_20260201_session007_NP_Recording_Number02_2026-02-01_18-25-00"  # <-- edit

BASE_ROOT = Path(r"H:\Grant\Neuropixels\Kilosort_Recordings")
RECORDING_ROOT = BASE_ROOT.joinpath(NP_recording_name)

# Export folders written by the two Bombcell runs (see folder convention above).
DEFAULT_EXPORT_ROOT = RECORDING_ROOT.joinpath("bombcell_DEFAULT", "batch_DEFAULT_results")
NP20_EXPORT_ROOT = RECORDING_ROOT.joinpath("bombcell_NP2.0", "NP2_ReRun_results")

# Every probe has DEFAULT exports; only this subset was re-run with NP2.0 settings.
PROBES_ALL = ["A", "B", "C", "D", "E", "F"]
PROBES_NP20 = ["A", "C", "D"]

# Quick sanity check that the configured folders are reachable.
print("RECORDING_ROOT:", RECORDING_ROOT)
print("DEFAULT_EXPORT_ROOT exists:", DEFAULT_EXPORT_ROOT.exists())
print("NP20_EXPORT_ROOT exists:", NP20_EXPORT_ROOT.exists())

In [None]:
# =========================
# Helpers
# =========================
def load_probe_exports(export_root: Path, probe: str):
    """Load the Bombcell exports for one probe from ``export_root / "Probe_{probe}"``.

    Files looked for inside the probe folder (each one is optional):
      - ``Probe_{probe}_quality_metrics.csv``  -> ``out["qm"]`` (DataFrame or None)
      - ``Probe_{probe}_unit_type_counts.csv`` -> ``out["counts"]`` (DataFrame or None)
      - ``Probe_{probe}_param.json``           -> ``out["param"]`` (dict, ``{}`` if absent)
      - ``Probe_{probe}_checks.json``          -> ``out["checks"]`` (dict, ``{}`` if absent)
      - ``ERROR.txt``                          -> marks the probe as failed

    Parameters
    ----------
    export_root : Path
        Results folder of one Bombcell run (contains the ``Probe_*`` subfolders).
    probe : str
        Probe letter, e.g. ``"A"``.

    Returns
    -------
    dict
        Always contains ``"probe"``, ``"status"`` and ``"probe_dir"``.

        * ``status == "FAILED"``: ``ERROR.txt`` exists. ``"error"`` holds its text,
          and no other file is read.
        * ``status == "MISSING"``: no quality-metrics CSV was found. ``"error"``
          explains why. ``qm`` (None), ``counts``, ``param``, ``checks`` and
          ``cluster_id_col`` are still present.
        * ``status == "OK"``: the quality-metrics CSV was loaded. ``cluster_id_col``
          is the first recognised cluster-id column, or None if there is none.
    """
    probe_dir = export_root / f"Probe_{probe}"
    qm_path = probe_dir / f"Probe_{probe}_quality_metrics.csv"
    counts_path = probe_dir / f"Probe_{probe}_unit_type_counts.csv"
    param_path = probe_dir / f"Probe_{probe}_param.json"
    checks_path = probe_dir / f"Probe_{probe}_checks.json"
    err_path = probe_dir / "ERROR.txt"

    if err_path.exists():
        # errors="replace" so an oddly-encoded traceback can still be shown.
        error_text = err_path.read_text(encoding="utf-8", errors="replace")
        return {"probe": probe, "status": "FAILED", "error": error_text, "probe_dir": probe_dir}

    qm = pd.read_csv(qm_path) if qm_path.exists() else None

    # Candidate names for the cluster-id column, in order of preference.
    id_candidates = ["cluster_id", "clusterID", "cluster_id_ks", "cluster_id_phy", "cluster"]
    cluster_id_col = None
    if qm is not None:
        cluster_id_col = next((c for c in id_candidates if c in qm.columns), None)

    out = {
        "probe": probe,
        "status": "OK",
        "probe_dir": probe_dir,
        "qm": qm,
        "counts": pd.read_csv(counts_path) if counts_path.exists() else None,
        # JSON text is UTF-8; don't depend on the platform default encoding.
        "param": json.loads(param_path.read_text(encoding="utf-8")) if param_path.exists() else {},
        "checks": json.loads(checks_path.read_text(encoding="utf-8")) if checks_path.exists() else {},
        "cluster_id_col": cluster_id_col,
    }

    if qm is None:
        # Downstream analysis needs the metrics table, so don't report "OK" without it.
        out["status"] = "MISSING"
        out["error"] = f"Quality metrics CSV not found: {qm_path}"

    return out

def load_batch_summary(export_root: Path):
    """Read ``batch_summary.csv`` from *export_root*.

    Returns the DataFrame, or None when the file does not exist.
    """
    summary_csv = export_root / "batch_summary.csv"
    if not summary_csv.exists():
        return None
    return pd.read_csv(summary_csv)

def summarize_unit_types(qm: pd.DataFrame, label_col="Bombcell_unit_type"):
    """Tally how many units carry each label in *label_col*.

    Returns a DataFrame with columns ``unit_type`` and ``count`` (in
    ``value_counts`` order, i.e. most frequent first). Returns None when *qm*
    is None or lacks *label_col*. NaN labels are not counted.
    """
    unusable = qm is None or label_col not in qm.columns
    if unusable:
        return None
    tally = qm[label_col].value_counts()
    tally = tally.rename_axis("unit_type")
    return tally.reset_index(name="count")

def add_percentages(df_counts: pd.DataFrame):
    """Return a copy of *df_counts* with a ``pct`` column (share of the ``count`` total, in %).

    None and empty frames are passed through unchanged (same object).
    """
    if df_counts is None or df_counts.empty:
        return df_counts
    result = df_counts.copy()
    grand_total = result["count"].sum()
    result["pct"] = 100 * result["count"] / grand_total
    return result

def find_cluster_row(qm: pd.DataFrame, cluster_id: int, cluster_id_col: str):
    """Return the first row of *qm* whose *cluster_id_col* equals *cluster_id*.

    Raises
    ------
    ValueError
        If *qm* is None or *cluster_id_col* is missing from it.
    KeyError
        If no row matches *cluster_id*.
    """
    if qm is None:
        raise ValueError("qm is None")
    has_id_column = cluster_id_col is not None and cluster_id_col in qm.columns
    if not has_id_column:
        raise ValueError("No cluster_id column found in quality_metrics.csv")
    matches = qm[qm[cluster_id_col] == cluster_id]
    if matches.empty:
        raise KeyError(f"Cluster id {cluster_id} not found in {cluster_id_col}")
    return matches.iloc[0]

def threshold_fail_report(row, qm_cols, param):
    """List the common Bombcell gates that a single unit violates.

    Parameters
    ----------
    row : Series or mapping
        One unit's quality metrics, indexable by column name.
    qm_cols : container of str
        Available metric column names. Gates on absent metrics are skipped.
    param : dict
        Bombcell parameters. Missing keys fall back to the defaults below.
        A key whose value is None (JSON ``null``) disables that gate.

    Returns
    -------
    list of (metric, value, op, threshold) tuples
        One tuple per violated gate, in rule order. ``op`` is ``"<"`` (value
        below minimum) or ``">"`` (value above maximum). NaN metric values are
        not reported.
    """
    rules = [
        ("nSpikes", "<", param.get("minNumSpikes", 300)),
        ("rawAmplitude", "<", param.get("minAmplitude", 40)),
        ("signalToNoiseRatio", "<", param.get("minSNR", 5)),
        ("presenceRatio", "<", param.get("minPresenceRatio", 0.7)),
        ("fractionRPVs_estimatedTauR", ">", param.get("maxRPVviolations", 0.1)),
        ("percentageSpikesMissing_gaussian", ">", param.get("maxPercSpikesMissing", 20)),
        ("waveformDuration_peakTrough", "<", param.get("minWvDuration", 100)),
        ("waveformDuration_peakTrough", ">", param.get("maxWvDuration", 1150)),
        ("nPeaks", ">", param.get("maxNPeaks", 2)),
        ("nTroughs", ">", param.get("maxNTroughs", 1)),
        ("waveformBaselineFlatness", ">", param.get("maxWvBaselineFraction", 0.3)),
    ]
    fails = []
    for col, op, thr in rules:
        # Without this check, comparing a number with None raises TypeError.
        if col not in qm_cols or thr is None:
            continue
        v = row[col]
        if pd.isna(v):
            continue
        if (op == "<" and v < thr) or (op == ">" and v > thr):
            fails.append((col, float(v), op, float(thr)))
    return fails

## Load DEFAULT exports (all probes)

In [None]:
# Batch-level summary of the DEFAULT run (None if batch_summary.csv is absent).
default_summary = load_batch_summary(export_root=DEFAULT_EXPORT_ROOT)
default_summary

In [None]:
# Load every probe of the DEFAULT run and show its unit-type breakdown.
default_data = {p: load_probe_exports(DEFAULT_EXPORT_ROOT, p) for p in PROBES_ALL}

for p in PROBES_ALL:
    d = default_data[p]
    print("=" * 60, f"Probe {p} ({d['status']})")
    if d["status"] != "OK":
        print(d.get("error", ""))
        continue
    counts = add_percentages(summarize_unit_types(d["qm"]))
    if counts is None:
        # summarize_unit_types returns None when qm is missing or has no label column;
        # say so explicitly instead of displaying a bare None.
        print("No Bombcell_unit_type labels available for this probe.")
        continue
    display(counts)

## Load NP2.0 rerun exports (A/C/D)

In [None]:
# Batch-level summary of the NP2.0 rerun (None if batch_summary.csv is absent).
np20_summary = load_batch_summary(export_root=NP20_EXPORT_ROOT)
np20_summary

In [None]:
# Load the re-run probes of the NP2.0 run and show their unit-type breakdown.
np20_data = {p: load_probe_exports(NP20_EXPORT_ROOT, p) for p in PROBES_NP20}

for p in PROBES_NP20:
    d = np20_data[p]
    print("=" * 60, f"Probe {p} ({d['status']})")
    if d["status"] != "OK":
        print(d.get("error", ""))
        continue
    counts = add_percentages(summarize_unit_types(d["qm"]))
    if counts is None:
        # summarize_unit_types returns None when qm is missing or has no label column;
        # say so explicitly instead of displaying a bare None.
        print("No Bombcell_unit_type labels available for this probe.")
        continue
    display(counts)

## Compare DEFAULT vs NP2.0 rerun (A/C/D)

In [None]:
def _count_label(counts, label):
    """Return the ``count`` for *label* in a summarize_unit_types() table; 0 if absent or None."""
    if counts is None:
        return 0
    hits = counts.loc[counts["unit_type"] == label, "count"]
    return int(hits.iloc[0]) if len(hits) else 0


rows = []
for p in PROBES_NP20:
    d0 = default_data.get(p, {})
    d1 = np20_data.get(p, {})
    # Only compare probes that loaded successfully in both runs.
    if d0.get("status") != "OK" or d1.get("status") != "OK":
        continue
    qm0, qm1 = d0.get("qm"), d1.get("qm")
    if qm0 is None or qm1 is None:
        # Guard against a missing quality-metrics table; len(None) below would raise.
        print(f"Skipping probe {p}: quality metrics missing in one of the runs.")
        continue

    c0 = summarize_unit_types(qm0)
    c1 = summarize_unit_types(qm1)

    # Columns alternate DEFAULT_/RERUN_ per label, matching the original layout.
    entry = {"probe": p}
    for label in ("GOOD", "MUA", "NOISE", "NON-SOMA"):
        entry[f"DEFAULT_{label}"] = _count_label(c0, label)
        entry[f"RERUN_{label}"] = _count_label(c1, label)
    entry["DEFAULT_TOTAL"] = len(qm0)
    entry["RERUN_TOTAL"] = len(qm1)
    rows.append(entry)
pd.DataFrame(rows)

## Drill-down: why Bombcell labeled a specific cluster as MUA (or not GOOD)

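This reports which of the common Bombcell thresholds are violated for the chosen `run`, `probe` and `cluster_id`. Only the gates listed in `threshold_fail_report` are checked, so an empty result does not guarantee that the cluster passed every Bombcell criterion.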

In [None]:
probe = "B"          # <-- edit
cluster_id = 39      # <-- edit
run = "DEFAULT"      # "DEFAULT" or "NP20"

# Fail loudly on a typo instead of silently falling back to the NP2.0 data.
runs = {"DEFAULT": default_data, "NP20": np20_data}
if run not in runs:
    raise ValueError(f"run must be one of {sorted(runs)}, got {run!r}")
if probe not in runs[run]:
    raise KeyError(f"Probe {probe!r} was not loaded for run {run!r} (available: {sorted(runs[run])})")
d = runs[run][probe]
if d.get("status") != "OK":
    # FAILED entries have no "qm"/"param" keys; report why instead of a bare KeyError.
    raise RuntimeError(f"Probe {probe} ({run}) has status {d.get('status')}: {d.get('error', '')}")

qm = d["qm"]
param = d["param"]
cluster_id_col = d.get("cluster_id_col", None)

print("Run:", run)
print("Probe:", probe)
print("cluster_id_col:", cluster_id_col)

row = find_cluster_row(qm, cluster_id, cluster_id_col)
print("Bombcell label:", row.get("Bombcell_unit_type", "UNKNOWN"))

fails = threshold_fail_report(row, qm.columns, param)

print("\n---- FAILING GATES ----")
if not fails:
    print("No fails among common checks; expand rules or inspect full row.")
else:
    for col, v, op, thr in fails:
        print(f"{col:35s} {v:>10.4f}  FAIL ({op}{thr})")

print("\n---- Key values ----")
key_cols = [
    "rawAmplitude", "signalToNoiseRatio", "presenceRatio",
    "fractionRPVs_estimatedTauR", "percentageSpikesMissing_gaussian",
    "waveformDuration_peakTrough", "nPeaks", "nTroughs", "waveformBaselineFlatness",
]
for c in key_cols:
    if c in qm.columns:
        print(f"{c:35s} {row[c]}")

## Distributions (RPV, presenceRatio)

In [None]:
import matplotlib.pyplot as plt

probe = "A"      # <-- edit
run = "NP20"     # "DEFAULT" or "NP20"

# Fail loudly on a typo instead of silently falling back to the NP2.0 data.
runs = {"DEFAULT": default_data, "NP20": np20_data}
if run not in runs:
    raise ValueError(f"run must be one of {sorted(runs)}, got {run!r}")
if probe not in runs[run]:
    raise KeyError(f"Probe {probe!r} was not loaded for run {run!r} (available: {sorted(runs[run])})")
d = runs[run][probe]
qm = d.get("qm")
if qm is None:
    raise RuntimeError(f"No quality metrics for probe {probe} ({run}); status={d.get('status')}")

for col in ["fractionRPVs_estimatedTauR", "presenceRatio"]:
    if col not in qm.columns:
        print(f"{col} not present for probe {probe} ({run}); skipping.")
        continue
    plt.figure()
    plt.hist(qm[col].dropna(), bins=50)
    plt.title(f"{probe} {run}: {col}")
    plt.xlabel(col)
    plt.ylabel("count")
    plt.show()

## Metric-by-label medians

In [None]:
probe = "A"      # <-- edit
run = "NP20"     # "DEFAULT" or "NP20"

# Fail loudly on a typo instead of silently falling back to the NP2.0 data.
runs = {"DEFAULT": default_data, "NP20": np20_data}
if run not in runs:
    raise ValueError(f"run must be one of {sorted(runs)}, got {run!r}")
if probe not in runs[run]:
    raise KeyError(f"Probe {probe!r} was not loaded for run {run!r} (available: {sorted(runs[run])})")
d = runs[run][probe]
qm = d.get("qm")
if qm is None or "Bombcell_unit_type" not in qm.columns:
    raise RuntimeError(
        f"No Bombcell_unit_type labels for probe {probe} ({run}); status={d.get('status')}"
    )

# Median of each available metric, grouped by Bombcell label.
metrics = ["fractionRPVs_estimatedTauR", "presenceRatio", "rawAmplitude", "signalToNoiseRatio"]
present = [m for m in metrics if m in qm.columns]
qm.groupby("Bombcell_unit_type")[present].median()

## Compact overview table

In [None]:
def overview_table(data_dict: dict, label_col: str = "Bombcell_unit_type", metrics=None):
    """Build a one-row-per-probe overview of Bombcell results.

    Parameters
    ----------
    data_dict : dict
        Mapping ``probe -> dict`` as produced by ``load_probe_exports``
        (reads the ``"status"`` and ``"qm"`` keys).
    label_col : str
        Column of ``qm`` holding the Bombcell unit label.
    metrics : iterable of str, optional
        Metric columns to summarize with their median. The default is the RPV
        fraction, presence ratio, raw amplitude and SNR. Metrics absent from
        ``qm`` are skipped.

    Returns
    -------
    pd.DataFrame
        Columns ``probe``, ``status``, ``n_total``, ``n_GOOD``, ``pct_GOOD`` and
        ``median_<metric>``.

        * A probe whose status is not "OK" gets only ``probe`` and ``status``
          (its own status string, or "FAILED" if it has none).
        * A probe with status "OK" but no ``qm`` table gets status "NO_QM".
        * When ``label_col`` is missing, ``n_GOOD`` and ``pct_GOOD`` are NaN.
    """
    if metrics is None:
        metrics = ["fractionRPVs_estimatedTauR", "presenceRatio", "rawAmplitude", "signalToNoiseRatio"]

    rows = []
    for p, d in data_dict.items():
        if d.get("status") != "OK":
            rows.append({"probe": p, "status": d.get("status") or "FAILED"})
            continue
        qm = d.get("qm")
        if qm is None:
            # Status "OK" without a metrics table: report it instead of crashing on qm[...].
            rows.append({"probe": p, "status": "NO_QM"})
            continue

        total = len(qm)
        row = {"probe": p, "status": "OK", "n_total": int(total)}
        if label_col in qm.columns:
            n_good = int(qm[label_col].value_counts().get("GOOD", 0))
            row["n_GOOD"] = n_good
            row["pct_GOOD"] = 100 * n_good / total if total else np.nan
        else:
            row["n_GOOD"] = np.nan
            row["pct_GOOD"] = np.nan
        for m in metrics:
            if m in qm.columns:
                row[f"median_{m}"] = float(qm[m].median())
        rows.append(row)
    return pd.DataFrame(rows)

# One-row-per-probe summaries for both runs.
default_overview = overview_table(data_dict=default_data)
np20_overview = overview_table(data_dict=np20_data)

print("DEFAULT overview")
display(default_overview.sort_values(by="probe"))

print("NP2.0 rerun overview")
display(np20_overview.sort_values(by="probe"))