# Bombcell classification reason audit (GOOD / MUA / NOISE / NON-SOMA)

This notebook shows **exactly why each unit got its Bombcell label** by recomputing threshold pass/fail flags from saved run outputs.


In [21]:
from pathlib import Path
import sys
import numpy as np
import pandas as pd
import bombcell as bc

# Put the current working directory at the front of sys.path so that the
# local helper module `post_analysis_setup` can be imported from it.
analysis_dir = Path.cwd().resolve()
sys.path.insert(0, str(analysis_dir))
from post_analysis_setup import load_post_analysis_context

# ---- User configuration ----
# CONFIG_FILE is handed to load_post_analysis_context() to build the context
# dictionary from which the staging roots are read.
# NOTE(review): this is an absolute, machine-specific path — adjust it before
# running the notebook on another machine.
CONFIG_FILE = r'C:\Users\user\Documents\github\bombcell\py_bombcell\grant\configs\grant_recording_config_reach15_20260201_session007.json'
# RUN_MODE selects which staging root is used when the paths are resolved.
RUN_MODE = 'batch'   # batch | single_probe | np20_rerun
# TARGET_PROBE is the suffix of the 'kilosort4_<probe>' folder that is audited.
TARGET_PROBE = 'B'   # A-F


In [22]:
# Resolve the Kilosort output folder and the Bombcell save folder for the
# selected RUN_MODE / TARGET_PROBE.
ctx = load_post_analysis_context(CONFIG_FILE)

# Each run mode maps to (staging root, sub-directory below '<ks_dir>/bombcell').
# The sub-directory is '' for every mode, so results are read directly from
# the 'bombcell' folder.
mode_to_roots = {
    'batch': (ctx['DEFAULT_KS_STAGING_ROOT'], ''),
    'np20_rerun': (ctx['NP20_KS_STAGING_ROOT'], ''),
    'single_probe': (ctx['BOMBCELL_KS_SINGLEPROBE_STAGING_ROOT'], ''),
}

# Fail with an actionable message instead of a bare KeyError when RUN_MODE
# was mistyped in the configuration cell.
if RUN_MODE not in mode_to_roots:
    raise ValueError(
        f"Unknown RUN_MODE {RUN_MODE!r}; expected one of {sorted(mode_to_roots)}"
    )

staging_root, save_subdir = mode_to_roots[RUN_MODE]
ks_dir = Path(staging_root) / f'kilosort4_{TARGET_PROBE}'
save_path = ks_dir / 'bombcell' / save_subdir

print('ks_dir:', ks_dir)
print('save_path:', save_path)


ks_dir: H:\Grant\Neuropixels\Kilosort_Recordings\Reach15_20260201_session007_NP_Recording_Number02_2026-02-01_18-25-00\bombcell\bombcell_DEFAULT\kilosort4_B
save_path: H:\Grant\Neuropixels\Kilosort_Recordings\Reach15_20260201_session007_NP_Recording_Number02_2026-02-01_18-25-00\bombcell\bombcell_DEFAULT\kilosort4_B\bombcell


In [23]:
# Load the saved Bombcell run and recompute the per-unit labels from the saved
# parameters and quality metrics.
param, quality_metrics, _ = bc.load_bc_results(str(save_path))
unit_type, unit_type_string = bc.qm.get_quality_unit_type(param, quality_metrics)

# One row per unit; `unit_index` is the row position in the Bombcell tables.
qm_df = pd.DataFrame(quality_metrics).copy()
qm_df['bombcell_label'] = unit_type_string
qm_df['unit_index'] = np.arange(len(qm_df))

# Make sure a `cluster_id` column exists for cross-referencing with the sorter.
# Prefer the cluster IDs stored with the metrics when they are available; the
# row index is only a last-resort stand-in and need not equal the sorter's ID.
# NOTE(review): 'phy_clusterID' is the column name Bombcell is expected to
# write — confirm against the installed Bombcell version.
if 'cluster_id' not in qm_df.columns:
    if 'phy_clusterID' in qm_df.columns:
        qm_df['cluster_id'] = qm_df['phy_clusterID']
    else:
        qm_df['cluster_id'] = qm_df['unit_index']

print('Loaded units:', len(qm_df))
print(qm_df['bombcell_label'].value_counts(dropna=False))


Loaded units: 1316
bombcell_label
MUA         764
NON-SOMA    286
NOISE       247
GOOD         19
Name: count, dtype: int64


## Build explicit threshold-fail flags for each class block


In [24]:
def col(name):
    """Return the `name` column of `qm_df`, or an all-NaN stand-in if it is absent.

    The stand-in shares `qm_df`'s index, so threshold comparisons against it
    evaluate to False for every unit instead of raising.
    """
    is_missing = name not in qm_df.columns
    return pd.Series(np.nan, index=qm_df.index) if is_missing else qm_df[name]

def _above(metric, threshold_key):
    """Per-unit flag: the metric column exceeds param[threshold_key]."""
    return col(metric) > param[threshold_key]


def _below(metric, threshold_key):
    """Per-unit flag: the metric column is under param[threshold_key]."""
    return col(metric) < param[threshold_key]


# Each dictionary maps a human-readable rule name to a boolean Series
# (True = the unit violates that threshold). Key order is preserved because it
# determines both the order of the exported columns and of the listed reasons.

# ---- NOISE thresholds ----
noise_fail = {
    'nPeaks>maxNPeaks': _above('nPeaks', 'maxNPeaks'),
    'nTroughs>maxNTroughs': _above('nTroughs', 'maxNTroughs'),
    'wvDuration<minWvDuration': _below('waveformDuration_peakTrough', 'minWvDuration'),
    'wvDuration>maxWvDuration': _above('waveformDuration_peakTrough', 'maxWvDuration'),
    'baselineFlatness>maxWvBaselineFraction': _above('waveformBaselineFlatness', 'maxWvBaselineFraction'),
    'scndPeakToTroughRatio>maxScndPeakToTroughRatio_noise': _above('scndPeakToTroughRatio', 'maxScndPeakToTroughRatio_noise'),
}
# Spatial-decay checks are only added when that metric was enabled; the linear
# fit uses a single lower bound, the other fit a lower and an upper bound.
if param.get('computeSpatialDecay', False):
    if param.get('spDecayLinFit', False):
        noise_fail['spatialDecaySlope<minSpatialDecaySlope'] = _below('spatialDecaySlope', 'minSpatialDecaySlope')
    else:
        noise_fail.update({
            'spatialDecaySlope<minSpatialDecaySlopeExp': _below('spatialDecaySlope', 'minSpatialDecaySlopeExp'),
            'spatialDecaySlope>maxSpatialDecaySlopeExp': _above('spatialDecaySlope', 'maxSpatialDecaySlopeExp'),
        })

# ---- MUA thresholds ----
mua_fail = {
    'percentageSpikesMissing_gaussian>maxPercSpikesMissing': _above('percentageSpikesMissing_gaussian', 'maxPercSpikesMissing'),
    'nSpikes<minNumSpikes': _below('nSpikes', 'minNumSpikes'),
    'fractionRPVs_estimatedTauR>maxRPVviolations': _above('fractionRPVs_estimatedTauR', 'maxRPVviolations'),
    'presenceRatio<minPresenceRatio': _below('presenceRatio', 'minPresenceRatio'),
}
# Optional MUA checks, each gated by the parameter that enables the metric.
if param.get('extractRaw', False):
    mua_fail.update({
        'rawAmplitude<minAmplitude': _below('rawAmplitude', 'minAmplitude'),
        'signalToNoiseRatio<minSNR': _below('signalToNoiseRatio', 'minSNR'),
    })
if param.get('computeDrift', False):
    mua_fail['maxDriftEstimate>maxDrift'] = _above('maxDriftEstimate', 'maxDrift')
if param.get('computeDistanceMetrics', False):
    mua_fail.update({
        'isolationDistance<isoDmin': _below('isolationDistance', 'isoDmin'),
        'Lratio>lratioMax': _above('Lratio', 'lratioMax'),
    })

# ---- NON-SOMATIC thresholds ----
non_soma_fail = {
    'troughToPeak2Ratio<minTroughToPeak2Ratio_nonSomatic': _below('troughToPeak2Ratio', 'minTroughToPeak2Ratio_nonSomatic'),
    'mainPeak_before_width<minWidthFirstPeak_nonSomatic': _below('mainPeak_before_width', 'minWidthFirstPeak_nonSomatic'),
    'mainTrough_width<minWidthMainTrough_nonSomatic': _below('mainTrough_width', 'minWidthMainTrough_nonSomatic'),
    'peak1ToPeak2Ratio>maxPeak1ToPeak2Ratio_nonSomatic': _above('peak1ToPeak2Ratio', 'maxPeak1ToPeak2Ratio_nonSomatic'),
    'mainPeakToTroughRatio>maxMainPeakToTroughRatio_nonSomatic': _above('mainPeakToTroughRatio', 'maxMainPeakToTroughRatio_nonSomatic'),
}

# Store every flag on qm_df as a plain boolean column, prefixed by its group.
flag_groups = {'noise_': noise_fail, 'mua_': mua_fail, 'nonsoma_': non_soma_fail}
for column_prefix, group in flag_groups.items():
    for rule_name, failed in group.items():
        qm_df[f'{column_prefix}{rule_name}'] = failed.fillna(False).astype(bool)


In [25]:
# Convert each group of flag Series into plain boolean numpy arrays (same
# keys, same order) so the per-unit lookup can index them by position.
noise_fail_np, mua_fail_np, nonsoma_fail_np = (
    {rule_name: np.asarray(failed, dtype=bool) for rule_name, failed in group.items()}
    for group in (noise_fail, mua_fail, non_soma_fail)
)
# Labels as strings, positionally aligned with the rows of qm_df.
labels_np = qm_df['bombcell_label'].astype(str).to_numpy()

def reasons_for_unit(i: int) -> str:
    """Build a ' | '-separated explanation of the Bombcell label of unit `i`.

    Parameters
    ----------
    i : int
        Positional index of the unit in `labels_np` and in every flag array.

    Returns
    -------
    str
        One entry per violated threshold of the group(s) that match the label.
        If the label's group has no recorded violation, a generic fallback
        entry is emitted for that group. Labels not handled here yield ''.

    Notes
    -----
    A 'NON-SOMA MUA' unit carries two pieces of information in its label, so
    both the MUA violations and the non-somatic violations are reported.
    NOTE(review): the non-somatic thresholds are listed individually here; how
    Bombcell combines them into the final non-somatic decision is not visible
    in this notebook — confirm before treating each hit as sufficient on its own.
    """
    label = labels_np[i]

    def _hits(flags):
        # Names of the checks in `flags` that are True for this unit.
        return [rule_name for rule_name, failed in flags.items() if bool(failed[i])]

    if label == 'GOOD':
        return 'GOOD: passed NOISE and MUA thresholds, and not non-somatic'

    reasons = []
    if label == 'NOISE':
        reasons.extend([f'NOISE: {r}' for r in _hits(noise_fail_np)] or ['NOISE: classified by noise mask'])
    if label in ('MUA', 'NON-SOMA MUA'):
        reasons.extend([f'MUA: {r}' for r in _hits(mua_fail_np)] or ['MUA: classified by MUA mask'])
    if label in ('NON-SOMA', 'NON-SOMA GOOD', 'NON-SOMA MUA'):
        # Previously skipped for 'NON-SOMA MUA', which hid the waveform rule
        # responsible for the NON-SOMA part of that label.
        reasons.extend([f'NON-SOMA: {r}' for r in _hits(nonsoma_fail_np)] or ['NON-SOMA: non-somatic waveform rule'])

    return ' | '.join(reasons)

# One explanation string per unit, in row order.
qm_df['classification_reason'] = list(map(reasons_for_unit, range(len(qm_df))))


In [26]:
# Columns for the summary table: identifiers, label and reason first, then a
# few of the raw metrics. Names missing from qm_df are dropped.
preferred_cols = (
    'unit_index', 'cluster_id', 'bombcell_label', 'classification_reason',
    'nPeaks', 'nTroughs', 'waveformDuration_peakTrough',
    'fractionRPVs_estimatedTauR', 'presenceRatio', 'nSpikes',
)
available_cols = set(qm_df.columns)
show_cols = [name for name in preferred_cols if name in available_cols]

display(qm_df.loc[:, show_cols].head(30))


Unnamed: 0,unit_index,cluster_id,bombcell_label,classification_reason,nPeaks,nTroughs,waveformDuration_peakTrough,fractionRPVs_estimatedTauR,presenceRatio,nSpikes
0,0,0,MUA,MUA: fractionRPVs_estimatedTauR>maxRPVviolations,1.0,1.0,166.666667,1.0,1.0,342314.0
1,1,1,MUA,MUA: fractionRPVs_estimatedTauR>maxRPVviolatio...,2.0,1.0,133.333333,1.0,0.922078,93291.0
2,2,2,NOISE,NOISE: scndPeakToTroughRatio>maxScndPeakToTrou...,1.0,1.0,700.0,1.0,1.0,21451.0
3,3,3,MUA,MUA: fractionRPVs_estimatedTauR>maxRPVviolations,1.0,1.0,200.0,1.0,0.941558,48982.0
4,4,4,GOOD,"GOOD: passed NOISE and MUA thresholds, and not...",1.0,1.0,233.333333,0.080264,0.915584,17269.0
5,5,5,MUA,MUA: fractionRPVs_estimatedTauR>maxRPVviolations,1.0,1.0,333.333333,1.0,1.0,51091.0
6,6,6,NON-SOMA,NON-SOMA: troughToPeak2Ratio<minTroughToPeak2R...,1.0,1.0,266.666667,1.0,1.0,30325.0
7,7,7,NOISE,NOISE: scndPeakToTroughRatio>maxScndPeakToTrou...,1.0,1.0,633.333333,0.063217,0.876623,42631.0
8,8,8,MUA,MUA: fractionRPVs_estimatedTauR>maxRPVviolatio...,1.0,1.0,200.0,1.0,0.922078,51450.0
9,9,9,MUA,MUA: fractionRPVs_estimatedTauR>maxRPVviolations,1.0,1.0,200.0,1.0,1.0,88921.0


In [28]:
# Show the reasons for a single class (GOOD / MUA / NOISE / NON-SOMA).
# Matching is by plain substring, so e.g. 'MUA' also selects 'NON-SOMA MUA'.
CLASS_TO_INSPECT = 'MUA'
label_text = qm_df['bombcell_label'].astype(str)
mask = label_text.map(lambda text: CLASS_TO_INSPECT in text)
display(qm_df.loc[mask, show_cols].head(100))


Unnamed: 0,unit_index,cluster_id,bombcell_label,classification_reason,nPeaks,nTroughs,waveformDuration_peakTrough,fractionRPVs_estimatedTauR,presenceRatio,nSpikes
0,0,0,MUA,MUA: fractionRPVs_estimatedTauR>maxRPVviolations,1.0,1.0,166.666667,1.000000,1.000000,342314.0
1,1,1,MUA,MUA: fractionRPVs_estimatedTauR>maxRPVviolatio...,2.0,1.0,133.333333,1.000000,0.922078,93291.0
3,3,3,MUA,MUA: fractionRPVs_estimatedTauR>maxRPVviolations,1.0,1.0,200.000000,1.000000,0.941558,48982.0
5,5,5,MUA,MUA: fractionRPVs_estimatedTauR>maxRPVviolations,1.0,1.0,333.333333,1.000000,1.000000,51091.0
8,8,8,MUA,MUA: fractionRPVs_estimatedTauR>maxRPVviolatio...,1.0,1.0,200.000000,1.000000,0.922078,51450.0
...,...,...,...,...,...,...,...,...,...,...
180,180,180,MUA,MUA: fractionRPVs_estimatedTauR>maxRPVviolatio...,1.0,1.0,200.000000,1.000000,0.357143,11223.0
181,181,181,MUA,MUA: fractionRPVs_estimatedTauR>maxRPVviolatio...,1.0,1.0,300.000000,0.131132,0.272727,25798.0
182,182,182,MUA,MUA: fractionRPVs_estimatedTauR>maxRPVviolations,1.0,1.0,200.000000,0.203328,0.740260,117803.0
183,183,183,MUA,MUA: fractionRPVs_estimatedTauR>maxRPVviolatio...,1.0,1.0,266.666667,1.000000,0.051948,576.0


In [29]:
# Optional: write the whole qm_df table as a CSV into the Bombcell save
# folder so it can be audited outside the notebook.
audit_filename = f'Probe_{TARGET_PROBE}_classification_reason_audit.csv'
out_csv = Path(save_path).joinpath(audit_filename)
qm_df.to_csv(out_csv, index=False)
print('Saved:', out_csv)


Saved: H:\Grant\Neuropixels\Kilosort_Recordings\Reach15_20260201_session007_NP_Recording_Number02_2026-02-01_18-25-00\bombcell\bombcell_DEFAULT\kilosort4_B\bombcell\Probe_B_classification_reason_audit.csv
