# Bombcell classification reason audit (GOOD / MUA / NOISE / NON-SOMATIC)

This notebook shows **exactly why each unit got its Bombcell label** by recomputing threshold pass/fail flags from saved run outputs.


In [None]:
from pathlib import Path
import sys
import numpy as np
import pandas as pd
import bombcell as bc

# Put the current working directory at the front of sys.path so the local
# helper module imported below can be resolved.
# NOTE(review): presumably `post_analysis_setup.py` lives next to this
# notebook — confirm if the notebook is launched from another directory.
analysis_dir = Path.cwd().resolve()
sys.path.insert(0, str(analysis_dir))
from post_analysis_setup import load_post_analysis_context

# User-editable run selection: config file, which staging root to read from,
# and which probe's Kilosort output to audit.
CONFIG_FILE = '../configs/grant_recording_config.json'
RUN_MODE = 'batch'   # batch | single_probe | np20_rerun
TARGET_PROBE = 'B'   # A-F


In [None]:
# Resolve the Kilosort output directory and the Bombcell save directory for
# the selected RUN_MODE / TARGET_PROBE.
ctx = load_post_analysis_context(CONFIG_FILE)

# Context key that holds the staging root for each supported run mode.
mode_to_root_key = {
    'batch': 'DEFAULT_KS_STAGING_ROOT',
    'np20_rerun': 'NP20_KS_STAGING_ROOT',
    'single_probe': 'BOMBCELL_KS_SINGLEPROBE_STAGING_ROOT',
}

# Resolve roots tolerantly: a context that lacks the root of a mode we are
# NOT using must not prevent the selected mode from working.
mode_to_roots = {}
for _mode, _root_key in mode_to_root_key.items():
    try:
        mode_to_roots[_mode] = ctx[_root_key]
    except KeyError:
        continue  # root not configured for this mode

if RUN_MODE not in mode_to_root_key:
    raise ValueError(
        f'Unknown RUN_MODE {RUN_MODE!r}; expected one of {sorted(mode_to_root_key)}'
    )
if RUN_MODE not in mode_to_roots:
    raise KeyError(
        f'Context loaded from {CONFIG_FILE!r} has no {mode_to_root_key[RUN_MODE]!r} '
        f'entry required for RUN_MODE {RUN_MODE!r}'
    )

staging_root = mode_to_roots[RUN_MODE]
ks_dir = Path(staging_root) / f'kilosort4_{TARGET_PROBE}'
save_path = ks_dir / 'bombcell'

print('ks_dir:', ks_dir)
print('save_path:', save_path)


In [None]:
# Load the saved Bombcell run and recompute the per-unit labels from it.
param, quality_metrics, _ = bc.load_bc_results(str(save_path))
unit_type, unit_type_string = bc.qm.get_quality_unit_type(param, quality_metrics)

# Work on a private copy so the audit columns never leak into quality_metrics.
qm_df = pd.DataFrame(quality_metrics).copy()
qm_df['bombcell_label'] = unit_type_string
# Positional row number; later cells use it to index the numpy flag arrays.
qm_df['unit_index'] = np.arange(len(qm_df))

if 'cluster_id' not in qm_df.columns:
    if 'phy_clusterID' in qm_df.columns:
        # NOTE(review): Bombcell appears to store the sorter's cluster id under
        # 'phy_clusterID' — confirm against the installed bombcell version.
        # Preferring it keeps 'cluster_id' meaningful instead of a row counter.
        qm_df['cluster_id'] = qm_df['phy_clusterID']
    else:
        # Last resort: no id column available, fall back to the row position.
        qm_df['cluster_id'] = qm_df['unit_index']

print('Loaded units:', len(qm_df))
print(qm_df['bombcell_label'].value_counts(dropna=False))


## Build explicit threshold-fail flags for each class block


In [None]:
def col(name):
    """Return column ``name`` of ``qm_df``, or an all-NaN Series if it is absent.

    The fallback Series is indexed like ``qm_df`` so that threshold
    comparisons against it still produce a result aligned with the table.
    """
    if name not in qm_df.columns:
        return pd.Series(np.nan, index=qm_df.index)
    return qm_df[name]

def _threshold_flags(rules):
    """Evaluate ``(label, metric, op, param_key)`` rules into ``{label: Series}``.

    ``op`` is ``'<'`` or ``'>'``; each rule compares the metric column (via
    ``col``) against ``param[param_key]``.
    """
    flags = {}
    for label, metric, op, param_key in rules:
        values = col(metric)
        limit = param[param_key]
        flags[label] = values < limit if op == '<' else values > limit
    return flags


# --- NOISE block: waveform-shape rules ---------------------------------------
noise_fail = _threshold_flags([
    ('nPeaks>maxNPeaks', 'nPeaks', '>', 'maxNPeaks'),
    ('nTroughs>maxNTroughs', 'nTroughs', '>', 'maxNTroughs'),
    ('wvDuration<minWvDuration', 'waveformDuration_peakTrough', '<', 'minWvDuration'),
    ('wvDuration>maxWvDuration', 'waveformDuration_peakTrough', '>', 'maxWvDuration'),
    ('baselineFlatness>maxWvBaselineFraction', 'waveformBaselineFlatness', '>', 'maxWvBaselineFraction'),
    ('scndPeakToTroughRatio>maxScndPeakToTroughRatio_noise', 'scndPeakToTroughRatio', '>', 'maxScndPeakToTroughRatio_noise'),
])
# Spatial-decay rules only apply when that metric was computed; the threshold
# names depend on whether the linear or the exponential fit was used.
if param.get('computeSpatialDecay', False):
    if param.get('spDecayLinFit', False):
        noise_fail.update(_threshold_flags([
            ('spatialDecaySlope<minSpatialDecaySlope', 'spatialDecaySlope', '<', 'minSpatialDecaySlope'),
        ]))
    else:
        noise_fail.update(_threshold_flags([
            ('spatialDecaySlope<minSpatialDecaySlopeExp', 'spatialDecaySlope', '<', 'minSpatialDecaySlopeExp'),
            ('spatialDecaySlope>maxSpatialDecaySlopeExp', 'spatialDecaySlope', '>', 'maxSpatialDecaySlopeExp'),
        ]))

# --- MUA block: spike-train quality rules, plus optional metric groups -------
mua_fail = _threshold_flags([
    ('percentageSpikesMissing_gaussian>maxPercSpikesMissing', 'percentageSpikesMissing_gaussian', '>', 'maxPercSpikesMissing'),
    ('nSpikes<minNumSpikes', 'nSpikes', '<', 'minNumSpikes'),
    ('fractionRPVs_estimatedTauR>maxRPVviolations', 'fractionRPVs_estimatedTauR', '>', 'maxRPVviolations'),
    ('presenceRatio<minPresenceRatio', 'presenceRatio', '<', 'minPresenceRatio'),
])
if param.get('extractRaw', False):
    mua_fail.update(_threshold_flags([
        ('rawAmplitude<minAmplitude', 'rawAmplitude', '<', 'minAmplitude'),
        ('signalToNoiseRatio<minSNR', 'signalToNoiseRatio', '<', 'minSNR'),
    ]))
if param.get('computeDrift', False):
    mua_fail.update(_threshold_flags([
        ('maxDriftEstimate>maxDrift', 'maxDriftEstimate', '>', 'maxDrift'),
    ]))
if param.get('computeDistanceMetrics', False):
    mua_fail.update(_threshold_flags([
        ('isolationDistance<isoDmin', 'isolationDistance', '<', 'isoDmin'),
        ('Lratio>lratioMax', 'Lratio', '>', 'lratioMax'),
    ]))

# --- NON-SOMATIC block: waveform-geometry triggers ---------------------------
non_soma_fail = _threshold_flags([
    ('troughToPeak2Ratio<minTroughToPeak2Ratio_nonSomatic', 'troughToPeak2Ratio', '<', 'minTroughToPeak2Ratio_nonSomatic'),
    ('mainPeak_before_width<minWidthFirstPeak_nonSomatic', 'mainPeak_before_width', '<', 'minWidthFirstPeak_nonSomatic'),
    ('mainTrough_width<minWidthMainTrough_nonSomatic', 'mainTrough_width', '<', 'minWidthMainTrough_nonSomatic'),
    ('peak1ToPeak2Ratio>maxPeak1ToPeak2Ratio_nonSomatic', 'peak1ToPeak2Ratio', '>', 'maxPeak1ToPeak2Ratio_nonSomatic'),
    ('mainPeakToTroughRatio>maxMainPeakToTroughRatio_nonSomatic', 'mainPeakToTroughRatio', '>', 'maxMainPeakToTroughRatio_nonSomatic'),
])

# Store every flag on the table as a strict boolean column, with a prefix
# naming the block it belongs to.
for block_prefix, block_flags in (
    ('noise_', noise_fail),
    ('mua_', mua_fail),
    ('nonsoma_', non_soma_fail),
):
    for flag_label, flag_series in block_flags.items():
        qm_df[f'{block_prefix}{flag_label}'] = flag_series.fillna(False).astype(bool)


In [None]:
def _as_bool_arrays(flag_series_by_name):
    """Convert ``{name: Series}`` into ``{name: boolean ndarray}``."""
    return {
        name: np.asarray(series, dtype=bool)
        for name, series in flag_series_by_name.items()
    }


# Plain numpy copies of the flags and labels, for positional per-unit lookups.
noise_fail_np = _as_bool_arrays(noise_fail)
mua_fail_np = _as_bool_arrays(mua_fail)
nonsoma_fail_np = _as_bool_arrays(non_soma_fail)
labels_np = qm_df['bombcell_label'].astype(str).to_numpy()

def reasons_for_unit(i: int) -> str:
    """Describe why the unit at row position ``i`` received its Bombcell label.

    Parameters
    ----------
    i : int
        Positional index into ``labels_np`` and the ``*_fail_np`` flag arrays.

    Returns
    -------
    str
        ``' | '``-joined reasons, each prefixed with the rule block it comes
        from (``NOISE:``, ``MUA:``, ``NON-SOMA:`` or ``GOOD:``). When a block
        applies but none of its recomputed flags fired, a generic fallback
        reason for that block is emitted instead. Labels outside the known
        set yield an empty string.
    """
    label = labels_np[i]

    def fired(flag_arrays):
        # Names of the rules in this block whose flag is set for unit i.
        return [name for name, mask in flag_arrays.items() if bool(mask[i])]

    reasons = []

    if label == 'NOISE':
        reasons.extend([f'NOISE: {r}' for r in fired(noise_fail_np)] or ['NOISE: classified by noise mask'])
    elif label == 'GOOD':
        reasons.append('GOOD: passed NOISE and MUA thresholds, and not non-somatic')
    else:
        if label in ('MUA', 'NON-SOMA MUA'):
            reasons.extend([f'MUA: {r}' for r in fired(mua_fail_np)] or ['MUA: classified by MUA mask'])
        # Non-somatic labels (including the split 'NON-SOMA MUA' case) must also
        # report the non-somatic triggers; otherwise a NON-SOMA MUA unit would
        # look identical to a plain MUA unit in the audit.
        if label in ('NON-SOMA', 'NON-SOMA GOOD', 'NON-SOMA MUA'):
            reasons.extend([f'NON-SOMA: {r}' for r in fired(nonsoma_fail_np)] or ['NON-SOMA: non-somatic waveform rule'])

    return ' | '.join(reasons)

# One reason string per unit, in row order.
qm_df['classification_reason'] = list(map(reasons_for_unit, range(len(qm_df))))


In [None]:
# Columns worth eyeballing next to the label; any that this run did not
# produce are dropped so the selection below cannot fail on a missing column.
_preferred_cols = (
    'unit_index', 'cluster_id', 'bombcell_label', 'classification_reason',
    'nPeaks', 'nTroughs', 'waveformDuration_peakTrough',
    'fractionRPVs_estimatedTauR', 'presenceRatio', 'nSpikes',
)
show_cols = [name for name in _preferred_cols if name in qm_df.columns]

display(qm_df[show_cols].head(30))


In [None]:
# Show the audit rows of a single class (GOOD / MUA / NOISE / NON-SOMA).
# Matching is a literal substring test on the label text.
CLASS_TO_INSPECT = 'MUA'
_label_text = qm_df['bombcell_label'].astype(str)
mask = _label_text.str.contains(CLASS_TO_INSPECT, regex=False)
display(qm_df.loc[mask, show_cols].head(100))


In [None]:
# Optional: write the full audit table next to the Bombcell outputs.
out_csv = Path(save_path).joinpath(f'Probe_{TARGET_PROBE}_classification_reason_audit.csv')
qm_df.to_csv(out_csv, index=False)
print('Saved:', out_csv)


## Visualization suite: classification audit reasons

These plots summarize **why units were assigned each class** and allow **per-neuron drill-down**.


In [None]:
import matplotlib.pyplot as plt
from collections import Counter

# Split a 'classification_reason' string into its individual reason tokens.
def parse_reason_tokens(reason_text: str):
    """Return the bare reasons contained in a ``' | '``-joined reason string.

    Each ``|``-separated piece is stripped of surrounding whitespace and empty
    pieces are skipped. A leading class prefix (everything up to the first
    ``': '``, e.g. ``"MUA: "``) is removed; pieces without one are kept as-is.
    Non-string or blank input yields an empty list.
    """
    if not isinstance(reason_text, str):
        return []

    cleaned = []
    for piece in reason_text.split('|'):
        token = piece.strip()
        if not token:
            continue
        _, separator, remainder = token.partition(': ')
        cleaned.append(remainder if separator else token)
    return cleaned

# Tokenise every unit's reason string and collect the distinct reasons seen.
_reason_text = qm_df['classification_reason'].astype(str)
qm_df['reason_tokens'] = _reason_text.apply(parse_reason_tokens)

_distinct_reasons = set()
for _tokens in qm_df['reason_tokens']:
    _distinct_reasons.update(_tokens)
all_reasons = sorted(_distinct_reasons)

print('Unique reasons found:', len(all_reasons))
print('Classes:', qm_df['bombcell_label'].value_counts().to_dict())


### Plots 1–3: high-level class and reason summaries


In [None]:
fig, axes = plt.subplots(1, 3, figsize=(20, 5))
ax_counts, ax_share, ax_reasons = axes

class_counts = qm_df['bombcell_label'].value_counts()
class_names = class_counts.index.astype(str)

# Panel 1: number of units per class
ax_counts.bar(class_names, class_counts.values)
ax_counts.set_title('Unit counts by Bombcell class')
ax_counts.set_ylabel('n units')
ax_counts.tick_params(axis='x', rotation=35)

# Panel 2: the same counts as proportions
ax_share.pie(class_counts.values, labels=class_names, autopct='%1.1f%%', startangle=90)
ax_share.set_title('Class proportion')

# Panel 3: most frequent reasons across all units
overall_reason_counts = Counter()
for tokens in qm_df['reason_tokens']:
    overall_reason_counts.update(tokens)

if not overall_reason_counts:
    ax_reasons.text(0.5, 0.5, 'No reason tokens', ha='center', va='center')
else:
    top = overall_reason_counts.most_common(15)
    # Reversed so the most frequent reason is drawn at the top of the barh.
    labels = [reason for reason, _ in reversed(top)]
    vals = [count for _, count in reversed(top)]
    ax_reasons.barh(labels, vals)
ax_reasons.set_title('Top audit reasons (overall)')
ax_reasons.set_xlabel('count')

plt.tight_layout()
plt.show()


### Plots 4–6: class-specific reason structure


In [None]:
# Classes to report, in a fixed display order, limited to labels that occur.
# The set is built once instead of once per candidate class.
present_labels = set(qm_df['bombcell_label'])
classes = [c for c in ['GOOD','MUA','NOISE','NON-SOMA','NON-SOMA GOOD','NON-SOMA MUA'] if c in present_labels]

# Build reason-by-class matrix: rows are reasons, columns are classes, cells
# count the units of that class that carry that reason.
reason_class = pd.DataFrame(0, index=all_reasons, columns=classes, dtype=int)
known_reasons = set(all_reasons)
# Count (reason, class) pairs in one pass, then write each cell once; this
# avoids iterrows() and per-unit scalar .loc increments.
pair_counts = Counter(
    (reason, label)
    for label, tokens in zip(qm_df['bombcell_label'], qm_df['reason_tokens'])
    if label in classes
    for reason in tokens
    if reason in known_reasons
)
for (reason, label), n_units in pair_counts.items():
    reason_class.loc[reason, label] = n_units

fig, axes = plt.subplots(1, 3, figsize=(24, 6))

# 4) Stacked bars for top reasons overall
top_reasons = reason_class.sum(axis=1).sort_values(ascending=False).head(12).index
if len(top_reasons) > 0:
    bottom = np.zeros(len(top_reasons))
    x = np.arange(len(top_reasons))
    for c in classes:
        vals = reason_class.loc[top_reasons, c].to_numpy()
        axes[0].bar(x, vals, bottom=bottom, label=c)
        bottom += vals  # running stack height for the next class
    axes[0].set_xticks(x)
    axes[0].set_xticklabels(top_reasons, rotation=60, ha='right')
    axes[0].set_title('Top reasons split by class (stacked)')
    axes[0].set_ylabel('count')
    axes[0].legend(fontsize=8)
else:
    axes[0].text(0.5,0.5,'No reasons',ha='center',va='center')

# 5) Heatmap: reason x class
if len(reason_class.index) > 0 and len(reason_class.columns) > 0:
    top_heat = reason_class.sum(axis=1).sort_values(ascending=False).head(20).index
    heat = reason_class.loc[top_heat, classes].to_numpy()
    im = axes[1].imshow(heat, aspect='auto')
    axes[1].set_yticks(np.arange(len(top_heat)))
    axes[1].set_yticklabels(top_heat)
    axes[1].set_xticks(np.arange(len(classes)))
    axes[1].set_xticklabels(classes, rotation=35, ha='right')
    axes[1].set_title('Reason × Class heatmap (top 20 reasons)')
    plt.colorbar(im, ax=axes[1], fraction=0.046, pad=0.04)
else:
    axes[1].text(0.5,0.5,'No reasons',ha='center',va='center')

# 6) Reason count distribution per class (boxplot)
reason_counts_per_unit = qm_df['reason_tokens'].apply(len)
box_data = [reason_counts_per_unit[qm_df['bombcell_label']==c].to_numpy() for c in classes]
if any(len(x)>0 for x in box_data):
    # Tick labels are set explicitly rather than through boxplot(labels=...),
    # which Matplotlib 3.9 deprecated in favour of tick_labels; this form
    # works on both older and newer releases. Boxes sit at positions 1..N.
    axes[2].boxplot(box_data, showfliers=False)
    axes[2].set_xticks(np.arange(1, len(classes) + 1))
    axes[2].set_xticklabels(classes)
    axes[2].set_title('# triggered reasons per unit by class')
    axes[2].set_ylabel('n reasons')
    axes[2].tick_params(axis='x', rotation=35)
else:
    axes[2].text(0.5,0.5,'No data',ha='center',va='center')

plt.tight_layout()
plt.show()


### Plots 7–8: per-class top-reason panels and reason co-occurrence


In [None]:
# Plot 7: one panel per class family showing its ten most frequent reasons.
fig, axes = plt.subplots(2, 2, figsize=(18, 12))
axes = axes.ravel()
panel_classes = ['GOOD','MUA','NOISE','NON-SOMA']

for ax, c in zip(axes, panel_classes):
    # Literal substring match on the label text selects the panel's units.
    in_panel = qm_df['bombcell_label'].astype(str).str.contains(c, regex=False)
    sub = qm_df[in_panel]
    cnt = Counter()
    for tokens in sub['reason_tokens']:
        cnt.update(tokens)
    if not cnt:
        ax.text(0.5,0.5,'No reasons found',ha='center',va='center')
    else:
        top = cnt.most_common(10)
        # Reversed so the most frequent reason is drawn at the top.
        labels = [reason for reason, _ in reversed(top)]
        vals = [count for _, count in reversed(top)]
        ax.barh(labels, vals)
    ax.set_title(f'Top reasons for {c}')
    ax.set_xlabel('count')

plt.tight_layout()
plt.show()

# Plot 8: reason co-occurrence matrix (top 12 reasons). Cell (a, b) counts the
# units carrying both reasons; the diagonal counts units carrying the reason.
reason_totals = Counter()
for tokens in qm_df['reason_tokens']:
    reason_totals.update(tokens)
top_reasons = [reason for reason, _ in reason_totals.most_common(12)]
co = pd.DataFrame(0, index=top_reasons, columns=top_reasons, dtype=int)
for row in qm_df['reason_tokens']:
    row_set = set(row)
    present = [reason for reason in top_reasons if reason in row_set]
    for a in present:
        for b in present:
            co.loc[a,b] += 1

tick_positions = np.arange(len(top_reasons))
plt.figure(figsize=(10,8))
plt.imshow(co.to_numpy(), aspect='auto')
plt.xticks(tick_positions, top_reasons, rotation=60, ha='right')
plt.yticks(tick_positions, top_reasons)
plt.title('Reason co-occurrence (top reasons)')
plt.colorbar(fraction=0.046, pad=0.04)
plt.tight_layout()
plt.show()


## Individual neuron diagnostics


In [None]:
# Fail early with a readable message if earlier cells have not been run.
required = ['qm_df', 'param', 'noise_fail_np', 'mua_fail_np', 'nonsoma_fail_np']
_namespace = globals()
missing = []
for _name in required:
    if _name not in _namespace:
        missing.append(_name)
if missing:
    raise RuntimeError(f'Missing required objects: {missing}. Run all previous cells first.')
print('All required objects are loaded for per-neuron diagnostics.')


In [None]:
# Choose one unit by row index OR by cluster_id
selected_unit_index = 0  # change this to inspect different neurons
selected_cluster_id = None  # set e.g. 123 to select by cluster_id instead

if selected_cluster_id is not None:
    # Resolve to a row POSITION, not an index label: selected_unit_index is
    # used with .iloc and to index the numpy flag arrays, both positional.
    # Taking labels from qm_df.index would silently pick the wrong unit
    # whenever the index is not 0..n-1.
    matches = np.flatnonzero((qm_df['cluster_id'] == selected_cluster_id).to_numpy())
    if len(matches) == 0:
        raise ValueError(f'cluster_id {selected_cluster_id} not found in qm_df')
    # If several rows share the id, the first one is used.
    selected_unit_index = int(matches[0])

row = qm_df.iloc[selected_unit_index]
print('selected_unit_index:', selected_unit_index)
print('cluster_id:', row.get('cluster_id', np.nan))
print('label:', row['bombcell_label'])
print('classification_reason:', row['classification_reason'])


In [None]:
# Per-neuron fail profile by block: which rules fired for the selected unit.
row = qm_df.iloc[selected_unit_index]


def _flags_for_selected_unit(flag_arrays):
    """Return ``{rule name: bool}`` for the unit at ``selected_unit_index``."""
    return {
        name: bool(mask[selected_unit_index])
        for name, mask in flag_arrays.items()
    }


noise_hits = _flags_for_selected_unit(noise_fail_np)
mua_hits = _flags_for_selected_unit(mua_fail_np)
non_hits = _flags_for_selected_unit(nonsoma_fail_np)

def plot_hits(hit_dict, title, color_true='tab:red'):
    """Draw a horizontal pass/fail bar per rule on the current axes.

    Parameters
    ----------
    hit_dict : dict
        Maps rule name to a truthy value when the rule fired.
    title : str
        Title for the current axes.
    color_true : str
        Bar colour for rules that fired; the others are light gray.
    """
    keys = list(hit_dict)
    vals = np.array([int(bool(hit_dict[name])) for name in keys])
    colors = []
    for flagged in vals:
        colors.append(color_true if flagged == 1 else 'lightgray')
    plt.barh(keys, vals, color=colors)
    plt.xlim(0, 1.1)
    plt.xticks([0, 1], ['pass', 'fail'])
    plt.title(title)

# One panel per rule block, side by side.
plt.figure(figsize=(18,10))
_panels = [
    (noise_hits, 'Noise-rule fails', 'tab:red'),
    (mua_hits, 'MUA-rule fails', 'tab:orange'),
    (non_hits, 'Non-somatic-rule triggers', 'tab:blue'),
]
for _position, (_hits, _title, _color) in enumerate(_panels, start=1):
    plt.subplot(1, 3, _position)
    plot_hits(_hits, _title, color_true=_color)
plt.tight_layout()
plt.show()


In [None]:
# Per-neuron metric-vs-threshold chart (where metric exists)
row = qm_df.iloc[selected_unit_index]


def _finite_or_nan(value):
    """Coerce ``value`` to float; missing or non-numeric values become NaN."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return np.nan


# (display name, metric column, threshold param, direction); 'le' passes when
# value <= threshold, 'ge' when value >= threshold. This is the complement of
# the strict '>' / '<' fail rules used when building the fail flags.
_check_specs = [
    ('nPeaks', 'nPeaks', 'maxNPeaks', 'le'),
    ('nTroughs', 'nTroughs', 'maxNTroughs', 'le'),
    # Waveform duration is bounded on both sides by the noise rules, so both
    # bounds are shown; with only the minimum, a too-long waveform would be
    # reported as PASS.
    ('waveformDuration_peakTrough (min)', 'waveformDuration_peakTrough', 'minWvDuration', 'ge'),
    ('waveformDuration_peakTrough (max)', 'waveformDuration_peakTrough', 'maxWvDuration', 'le'),
    ('fractionRPVs_estimatedTauR', 'fractionRPVs_estimatedTauR', 'maxRPVviolations', 'le'),
    ('presenceRatio', 'presenceRatio', 'minPresenceRatio', 'ge'),
    ('nSpikes', 'nSpikes', 'minNumSpikes', 'ge'),
]
# Raw-amplitude / SNR thresholds only enter the MUA fail flags when
# extractRaw is enabled, so they are only charted in that case.
if param.get('extractRaw', False):
    _check_specs += [
        ('rawAmplitude', 'rawAmplitude', 'minAmplitude', 'ge'),
        ('signalToNoiseRatio', 'signalToNoiseRatio', 'minSNR', 'ge'),
    ]

checks = [
    (
        name,
        _finite_or_nan(row.get(metric, np.nan)),
        _finite_or_nan(param.get(param_key, np.nan)),
        direction,
    )
    for name, metric, param_key, direction in _check_specs
]

# Keep only pairs where both the metric and its threshold are usable numbers.
plot_rows = [x for x in checks if np.isfinite(x[1]) and np.isfinite(x[2])]
if len(plot_rows)==0:
    print('No finite metric-threshold pairs available for this unit.')
else:
    names = [x[0] for x in plot_rows]
    vals = np.array([x[1] for x in plot_rows], dtype=float)
    thrs = np.array([x[2] for x in plot_rows], dtype=float)
    dirs = [x[3] for x in plot_rows]
    passed = np.array([(v<=t if d=='le' else v>=t) for v,t,d in zip(vals,thrs,dirs)])

    y = np.arange(len(names))
    plt.figure(figsize=(10, max(4, 0.45*len(names))))
    plt.scatter(vals, y, c=np.where(passed, 'green', 'red'), label='unit value', zorder=3)
    plt.scatter(thrs, y, c='black', marker='|', s=250, label='threshold', zorder=4)
    for i,(v,t,d,p) in enumerate(zip(vals,thrs,dirs,passed)):
        sign = '<=' if d=='le' else '>='
        # Faint connector between the unit's value and its threshold.
        plt.plot([min(v,t), max(v,t)], [i,i], color='gray', alpha=0.3, zorder=1)
        plt.text(max(v,t), i+0.12, f'{"PASS" if p else "FAIL"} ({v:.3g} {sign} {t:.3g})', fontsize=8)

    plt.yticks(y, names)
    plt.xlabel('metric value')
    plt.title(f'Unit {selected_unit_index} metric vs threshold ({row["bombcell_label"]})')
    plt.legend(loc='best')
    plt.tight_layout()
    plt.show()
