# Bombcell classification reason audit (GOOD / MUA / NOISE / NON-SOMATIC)

This notebook shows **exactly why each unit got its Bombcell label** by recomputing threshold pass/fail flags from saved run outputs.


In [None]:
from pathlib import Path
import sys
import numpy as np
import pandas as pd
import bombcell as bc

# Put the current working directory first on sys.path so the project-local
# `post_analysis_setup` module can be imported on the next line.
# NOTE(review): presumably that module sits next to this notebook — confirm
# if the notebook is ever launched from a different working directory.
analysis_dir = Path.cwd().resolve()
sys.path.insert(0, str(analysis_dir))
from post_analysis_setup import load_post_analysis_context

# User settings for this audit run.
# CONFIG_FILE is handed to load_post_analysis_context; RUN_MODE and
# TARGET_PROBE select which Kilosort/Bombcell output folder is audited.
CONFIG_FILE = '../configs/grant_recording_config.json'
RUN_MODE = 'batch'   # batch | single_probe | np20_rerun
TARGET_PROBE = 'B'   # A-F


In [None]:
# Load the shared post-analysis context; it provides the staging roots used below.
ctx = load_post_analysis_context(CONFIG_FILE)

# RUN_MODE -> (Kilosort staging root, Bombcell output subfolder name).
mode_to_roots = {
    'batch': (ctx['DEFAULT_KS_STAGING_ROOT'], 'DEFAULT'),
    'np20_rerun': (ctx['NP20_KS_STAGING_ROOT'], 'NP2_RERUN'),
    'single_probe': (ctx['BOMBCELL_KS_SINGLEPROBE_STAGING_ROOT'], 'SINGLE_PROBE'),
}

# Fail early with an actionable message instead of a bare KeyError when
# RUN_MODE is mistyped.
if RUN_MODE not in mode_to_roots:
    raise ValueError(
        f'Unknown RUN_MODE {RUN_MODE!r}; expected one of {sorted(mode_to_roots)}'
    )

staging_root, save_subdir = mode_to_roots[RUN_MODE]
ks_dir = Path(staging_root) / f'kilosort4_{TARGET_PROBE}'
save_path = ks_dir / 'bombcell' / save_subdir

print('ks_dir:', ks_dir)
print('save_path:', save_path)

# Surface a wrong probe/mode combination here rather than as a harder-to-read
# load failure in a later cell. Only a warning: the paths are still defined.
if not save_path.is_dir():
    print('WARNING: save_path does not exist or is not a directory:', save_path)


In [None]:
# Load the saved Bombcell parameters and per-unit quality metrics, then
# re-derive the unit labels from them.
param, quality_metrics, _ = bc.load_bc_results(str(save_path))
unit_type, unit_type_string = bc.qm.get_quality_unit_type(param, quality_metrics)
qm_df = pd.DataFrame(quality_metrics).copy()
qm_df['bombcell_label'] = unit_type_string
# Row position of each unit; the per-unit lookups later in the notebook are positional.
qm_df['unit_index'] = np.arange(len(qm_df))

# Make sure a 'cluster_id' column exists. Prefer the sorter's own cluster id
# when the metrics carry it, because the row position is not necessarily the
# cluster id.
# NOTE(review): Bombcell is expected to store that id as 'phy_clusterID' —
# confirm against the installed Bombcell version. If the column is absent the
# previous behaviour (fall back to the row position) is kept.
if 'cluster_id' not in qm_df.columns:
    if 'phy_clusterID' in qm_df.columns:
        qm_df['cluster_id'] = qm_df['phy_clusterID']
    else:
        qm_df['cluster_id'] = qm_df['unit_index']

print('Loaded units:', len(qm_df))
print(qm_df['bombcell_label'].value_counts(dropna=False))


## Build explicit threshold-fail flags for each class block


In [None]:
def col(name):
    """Return the `name` column of `qm_df`, or an all-NaN stand-in if it is absent.

    The stand-in shares `qm_df`'s index, so comparisons against it are simply
    False for every unit instead of raising for a metric that was not computed.
    """
    if name not in qm_df.columns:
        return pd.Series(np.nan, index=qm_df.index)
    return qm_df[name]

def _above(metric, limit_key):
    """Per-unit flag: the `metric` column is strictly greater than param[limit_key]."""
    return col(metric) > param[limit_key]


def _below(metric, limit_key):
    """Per-unit flag: the `metric` column is strictly less than param[limit_key]."""
    return col(metric) < param[limit_key]


# --- NOISE block: waveform-shape thresholds ---------------------------------
noise_fail = {
    'nPeaks>maxNPeaks': _above('nPeaks', 'maxNPeaks'),
    'nTroughs>maxNTroughs': _above('nTroughs', 'maxNTroughs'),
    'wvDuration<minWvDuration': _below('waveformDuration_peakTrough', 'minWvDuration'),
    'wvDuration>maxWvDuration': _above('waveformDuration_peakTrough', 'maxWvDuration'),
    'baselineFlatness>maxWvBaselineFraction': _above('waveformBaselineFlatness', 'maxWvBaselineFraction'),
    'scndPeakToTroughRatio>maxScndPeakToTroughRatio_noise': _above('scndPeakToTroughRatio', 'maxScndPeakToTroughRatio_noise'),
}
# Spatial-decay checks only apply when that metric was enabled; the linear-fit
# variant has a lower bound only, the other variant has both bounds.
if param.get('computeSpatialDecay', False):
    if param.get('spDecayLinFit', False):
        noise_fail['spatialDecaySlope<minSpatialDecaySlope'] = _below('spatialDecaySlope', 'minSpatialDecaySlope')
    else:
        noise_fail['spatialDecaySlope<minSpatialDecaySlopeExp'] = _below('spatialDecaySlope', 'minSpatialDecaySlopeExp')
        noise_fail['spatialDecaySlope>maxSpatialDecaySlopeExp'] = _above('spatialDecaySlope', 'maxSpatialDecaySlopeExp')

# --- MUA block: always-on thresholds, then the optional ones ------------------
mua_fail = {
    'percentageSpikesMissing_gaussian>maxPercSpikesMissing': _above('percentageSpikesMissing_gaussian', 'maxPercSpikesMissing'),
    'nSpikes<minNumSpikes': _below('nSpikes', 'minNumSpikes'),
    'fractionRPVs_estimatedTauR>maxRPVviolations': _above('fractionRPVs_estimatedTauR', 'maxRPVviolations'),
    'presenceRatio<minPresenceRatio': _below('presenceRatio', 'minPresenceRatio'),
}
if param.get('extractRaw', False):
    mua_fail['rawAmplitude<minAmplitude'] = _below('rawAmplitude', 'minAmplitude')
    mua_fail['signalToNoiseRatio<minSNR'] = _below('signalToNoiseRatio', 'minSNR')
if param.get('computeDrift', False):
    mua_fail['maxDriftEstimate>maxDrift'] = _above('maxDriftEstimate', 'maxDrift')
if param.get('computeDistanceMetrics', False):
    mua_fail['isolationDistance<isoDmin'] = _below('isolationDistance', 'isoDmin')
    mua_fail['Lratio>lratioMax'] = _above('Lratio', 'lratioMax')

# --- NON-SOMATIC block --------------------------------------------------------
non_soma_fail = {
    'troughToPeak2Ratio<minTroughToPeak2Ratio_nonSomatic': _below('troughToPeak2Ratio', 'minTroughToPeak2Ratio_nonSomatic'),
    'mainPeak_before_width<minWidthFirstPeak_nonSomatic': _below('mainPeak_before_width', 'minWidthFirstPeak_nonSomatic'),
    'mainTrough_width<minWidthMainTrough_nonSomatic': _below('mainTrough_width', 'minWidthMainTrough_nonSomatic'),
    'peak1ToPeak2Ratio>maxPeak1ToPeak2Ratio_nonSomatic': _above('peak1ToPeak2Ratio', 'maxPeak1ToPeak2Ratio_nonSomatic'),
    'mainPeakToTroughRatio>maxMainPeakToTroughRatio_nonSomatic': _above('mainPeakToTroughRatio', 'maxMainPeakToTroughRatio_nonSomatic'),
}

# Store every flag as a plain boolean column of qm_df, prefixed by its class block.
for tag, checks in (('noise_', noise_fail), ('mua_', mua_fail), ('nonsoma_', non_soma_fail)):
    for reason, failed in checks.items():
        qm_df[tag + reason] = failed.fillna(False).astype(bool)


In [None]:
def reasons_for_unit(i: int) -> str:
    """Return a ' | '-joined explanation of the Bombcell label of unit `i`.

    Args:
        i: Zero-based row position of the unit in `qm_df`.

    Returns:
        One entry per failed threshold from the class block(s) matching the
        unit's label, each prefixed with the class name. If a matching block
        has no failed threshold, a generic fallback entry for that class is
        used instead. 'NON-SOMA MUA' units report both their MUA and their
        non-somatic reasons. Labels not handled here give an empty string.
    """
    # Read the label positionally, like the flag Series below, so both always
    # refer to the same row regardless of qm_df's index.
    label = str(qm_df['bombcell_label'].iloc[i])

    def failed(flag_dict):
        """Names of the thresholds in `flag_dict` that unit `i` fails."""
        return [name for name, flags in flag_dict.items() if bool(flags.iloc[i])]

    reasons = []
    if label == 'NOISE':
        reasons.extend([f'NOISE: {r}' for r in failed(noise_fail)] or ['NOISE: classified by noise mask'])
    elif label == 'GOOD':
        reasons.append('GOOD: passed NOISE and MUA thresholds, and not non-somatic')
    else:
        # Two independent checks, not an elif chain: a 'NON-SOMA MUA' unit
        # carries both an MUA verdict and a non-somatic verdict, and both
        # must be explained.
        if label in ('MUA', 'NON-SOMA MUA'):
            reasons.extend([f'MUA: {r}' for r in failed(mua_fail)] or ['MUA: classified by MUA mask'])
        if label in ('NON-SOMA', 'NON-SOMA GOOD', 'NON-SOMA MUA'):
            # non-soma can be reassigned from GOOD (or MUA when split)
            reasons.extend([f'NON-SOMA: {r}' for r in failed(non_soma_fail)] or ['NON-SOMA: non-somatic waveform rule'])

    return ' | '.join(reasons)

# Build the human-readable reason for every unit, in row order.
qm_df['classification_reason'] = list(map(reasons_for_unit, range(len(qm_df))))


In [None]:
# Columns to show in the preview tables: identifiers, label and reason first,
# then a handful of the underlying metrics.
preferred_cols = (
    'unit_index', 'cluster_id', 'bombcell_label', 'classification_reason',
    'nPeaks', 'nTroughs', 'waveformDuration_peakTrough',
    'fractionRPVs_estimatedTauR', 'presenceRatio', 'nSpikes',
)
# Drop any column that is not present in qm_df so the selection cannot fail.
show_cols = [name for name in preferred_cols if name in qm_df.columns]

display(qm_df[show_cols].head(30))


In [None]:
# Inspect the exact reasons for one class (GOOD / MUA / NOISE / NON-SOMA).
# The match is a literal substring match on the label text, so e.g. 'MUA'
# also selects units labelled 'NON-SOMA MUA'.
CLASS_TO_INSPECT = 'MUA'
labels_as_text = qm_df['bombcell_label'].astype(str)
mask = labels_as_text.str.contains(CLASS_TO_INSPECT, regex=False)
display(qm_df.loc[mask, show_cols].head(100))


In [None]:
# Optional: write the full audit table (all qm_df columns) next to the
# Bombcell outputs for downstream auditing.
out_csv = Path(save_path).joinpath(f'Probe_{TARGET_PROBE}_classification_reason_audit.csv')
qm_df.to_csv(out_csv, index=False)
print('Saved:', out_csv)
