# Export Classification Reason to Phy

This notebook extracts the main classification reason for each unit from Bombcell results and exports it as a TSV file that can be viewed in Phy's cluster view tab.

The classification reason is the first threshold a unit failed for its Bombcell label (NOISE, MUA, or NON-SOMA); GOOD units are reported as having passed all thresholds.

## Usage:
1. Set the configuration parameters below (RUN_MODE, TARGET_PROBE, etc.)
2. Run all cells
3. The notebook will create a `cluster_bc_classificationReason.tsv` file in your Kilosort directory
4. Open the data in Phy - you'll see a new "bc_classificationReason" column in the cluster view

## Configuration

In [None]:
# Configuration - CHANGE THESE VALUES FOR YOUR DATA
# JSON recording config; it is passed to load_grant_config() further down.
CONFIG_FILE = r'C:\Users\user\Documents\github\bombcell\py_bombcell\grant\configs\grant_recording_config_reach15_20260201_session007.json'
# Selects which staging-root entry of the loaded config is used.
RUN_MODE = 'single_probe'  # 'batch', 'single_probe', or 'np20_rerun'
# Probe letter; the Kilosort directory is built as kilosort4_<TARGET_PROBE>
# under the staging root in every run mode, not only in single_probe mode.
TARGET_PROBE = 'E'

# Optional: set to True to also print up to three example units per label,
# with their main reason and (when there are several) all failed criteria.
VERBOSE = True

## Setup and Load Data

In [None]:
import sys
from pathlib import Path
import numpy as np
import pandas as pd

# Add parent directory to path for grant_config import
sys.path.insert(0, str(Path.cwd().parent))
from grant_config import load_grant_config

import bombcell as bc

In [None]:
# Load configuration
cfg = load_grant_config(CONFIG_FILE)

# Each supported run mode reads its staging root from a different config key.
staging_root_keys = {
    'batch': 'default_ks_staging_root',
    'np20_rerun': 'np20_ks_staging_root',
    'single_probe': 'bombcell_singleprobe_root',
}

# Fail loudly on a mistyped RUN_MODE instead of silently falling back to the
# single-probe root and reading/writing the wrong Kilosort directory.
if RUN_MODE not in staging_root_keys:
    raise ValueError(
        f'Unknown RUN_MODE {RUN_MODE!r}; expected one of {sorted(staging_root_keys)}'
    )
staging_root = cfg[staging_root_keys[RUN_MODE]]

# Build paths: Bombcell results live in a 'bombcell' subfolder of the
# probe's Kilosort directory.
ks_dir = Path(staging_root) / f'kilosort4_{TARGET_PROBE}'
save_path = ks_dir / 'bombcell'

print('Kilosort directory:', ks_dir)
print('Bombcell save path:', save_path)
print()
print('TSV file will be saved to:', ks_dir / 'cluster_bc_classificationReason.tsv')

## Load Bombcell Results

In [None]:
# Read the saved Bombcell output and derive the per-unit labels from it
param, quality_metrics, _ = bc.load_bc_results(str(save_path))
unit_type, unit_type_string = bc.qm.get_quality_unit_type(param, quality_metrics)

# One row per unit: the quality metrics plus the label and the row position
qm_df = pd.DataFrame(quality_metrics).copy()
qm_df['bombcell_label'] = unit_type_string
qm_df['unit_index'] = np.arange(len(qm_df))

# Make sure a cluster_id column exists. Preference order: an existing
# column, then param['unique_templates'], then the phy_clusterID metric,
# and finally the plain row position.
if 'cluster_id' not in qm_df.columns:
    if 'unique_templates' in param:
        cluster_ids = param['unique_templates']
    elif 'phy_clusterID' in quality_metrics:
        cluster_ids = quality_metrics['phy_clusterID']
    else:
        cluster_ids = qm_df['unit_index']
    qm_df['cluster_id'] = cluster_ids

label_counts = qm_df['bombcell_label'].value_counts(dropna=False)
print(f'Loaded {len(qm_df)} units')
print('\nLabel distribution:')
print(label_counts)

## Define Classification Reason Extraction Logic

In [None]:
def col(name):
    """Return the metric column `name` from qm_df, or an all-NaN stand-in.

    When the column is absent, the returned Series shares qm_df's index and
    holds only NaN, so any `<` / `>` comparison against it is False for
    every unit.
    """
    if name not in qm_df.columns:
        return pd.Series(np.nan, index=qm_df.index)
    return qm_df[name]

# Small comparison helpers: each returns a per-unit boolean Series telling
# whether the metric column lies beyond the named threshold in `param`.
def _above(column, threshold_key):
    return col(column) > param[threshold_key]


def _below(column, threshold_key):
    return col(column) < param[threshold_key]


# NOISE criteria. Insertion order matters: the first failing entry is the
# one later reported as the unit's main reason.
noise_fail = {
    'nPeaks>maxNPeaks': _above('nPeaks', 'maxNPeaks'),
    'nTroughs>maxNTroughs': _above('nTroughs', 'maxNTroughs'),
    'wvDuration<minWvDuration': _below('waveformDuration_peakTrough', 'minWvDuration'),
    'wvDuration>maxWvDuration': _above('waveformDuration_peakTrough', 'maxWvDuration'),
    'baselineFlatness>maxWvBaselineFraction': _above('waveformBaselineFlatness', 'maxWvBaselineFraction'),
    'scndPeakToTroughRatio>maxScndPeakToTroughRatio_noise': _above('scndPeakToTroughRatio', 'maxScndPeakToTroughRatio_noise'),
}

# Spatial decay is only checked when it was computed; the linear-fit variant
# has a single lower bound, the other variant has a lower and an upper bound.
if param.get('computeSpatialDecay', False):
    if param.get('spDecayLinFit', False):
        noise_fail['spatialDecaySlope<minSpatialDecaySlope'] = _below('spatialDecaySlope', 'minSpatialDecaySlope')
    else:
        noise_fail.update({
            'spatialDecaySlope<minSpatialDecaySlopeExp': _below('spatialDecaySlope', 'minSpatialDecaySlopeExp'),
            'spatialDecaySlope>maxSpatialDecaySlopeExp': _above('spatialDecaySlope', 'maxSpatialDecaySlopeExp'),
        })

# MUA criteria that are always evaluated
mua_fail = {
    'percentageSpikesMissing_gaussian>maxPercSpikesMissing': _above('percentageSpikesMissing_gaussian', 'maxPercSpikesMissing'),
    'nSpikes<minNumSpikes': _below('nSpikes', 'minNumSpikes'),
    'fractionRPVs_estimatedTauR>maxRPVviolations': _above('fractionRPVs_estimatedTauR', 'maxRPVviolations'),
    'presenceRatio<minPresenceRatio': _below('presenceRatio', 'minPresenceRatio'),
}

# Optional MUA criteria, each gated on the flag that enables its metric
if param.get('extractRaw', False):
    mua_fail.update({
        'rawAmplitude<minAmplitude': _below('rawAmplitude', 'minAmplitude'),
        'signalToNoiseRatio<minSNR': _below('signalToNoiseRatio', 'minSNR'),
    })

if param.get('computeDrift', False):
    mua_fail['maxDriftEstimate>maxDrift'] = _above('maxDriftEstimate', 'maxDrift')

if param.get('computeDistanceMetrics', False):
    mua_fail.update({
        'isolationDistance<isoDmin': _below('isolationDistance', 'isoDmin'),
        'Lratio>lratioMax': _above('Lratio', 'lratioMax'),
    })

# NON-SOMA criteria
non_soma_fail = {
    'troughToPeak2Ratio<minTroughToPeak2Ratio_nonSomatic': _below('troughToPeak2Ratio', 'minTroughToPeak2Ratio_nonSomatic'),
    'mainPeak_before_width<minWidthFirstPeak_nonSomatic': _below('mainPeak_before_width', 'minWidthFirstPeak_nonSomatic'),
    'mainTrough_width<minWidthMainTrough_nonSomatic': _below('mainTrough_width', 'minWidthMainTrough_nonSomatic'),
    'peak1ToPeak2Ratio>maxPeak1ToPeak2Ratio_nonSomatic': _above('peak1ToPeak2Ratio', 'maxPeak1ToPeak2Ratio_nonSomatic'),
    'mainPeakToTroughRatio>maxMainPeakToTroughRatio_nonSomatic': _above('mainPeakToTroughRatio', 'maxMainPeakToTroughRatio_nonSomatic'),
}

print('Defined classification rules:')
print(f'  NOISE: {len(noise_fail)} criteria')
print(f'  MUA: {len(mua_fail)} criteria')
print(f'  NON-SOMA: {len(non_soma_fail)} criteria')

## Extract Classification Reasons

In [None]:
def _as_bool_masks(criteria):
    """Turn a {reason: per-unit Series} mapping into {reason: bool ndarray}."""
    return {reason: np.asarray(mask, dtype=bool) for reason, mask in criteria.items()}


# Plain numpy arrays are cheaper to index one unit at a time than Series
noise_fail_np = _as_bool_masks(noise_fail)
mua_fail_np = _as_bool_masks(mua_fail)
nonsoma_fail_np = _as_bool_masks(non_soma_fail)
labels_np = qm_df['bombcell_label'].astype(str).to_numpy()

def get_main_reason(i: int) -> str:
    """
    Return the main classification reason for the unit at position `i`.

    The label decides which criteria table is consulted (NOISE, MUA or
    NON-SOMA); the result is '<PREFIX>: <first failed criterion>' or just
    '<PREFIX>' when no criterion in that table failed. 'NON-SOMA MUA' units
    are explained with the MUA table and 'NON-SOMA GOOD' units with the
    NON-SOMA table. GOOD units get a fixed message, and any other label is
    returned unchanged.
    """
    label = labels_np[i]

    if label == 'GOOD':
        return 'GOOD: passed all thresholds'

    # Pick the criteria table (and its display prefix) matching the label
    if label == 'NOISE':
        prefix, criteria = 'NOISE', noise_fail_np
    elif label in ('MUA', 'NON-SOMA MUA'):
        prefix, criteria = 'MUA', mua_fail_np
    elif label in ('NON-SOMA', 'NON-SOMA GOOD'):
        prefix, criteria = 'NON-SOMA', nonsoma_fail_np
    else:
        return f'{label}'

    # Tables keep their definition order, so the first hit is the main reason
    first_hit = next((reason for reason, mask in criteria.items() if bool(mask[i])), None)
    if first_hit is None:
        return prefix
    return f'{prefix}: {first_hit}'

def get_all_reasons(i: int) -> str:
    """
    Return every classification reason for the unit at position `i`.

    Uses the same label-to-table mapping as the main-reason lookup, but
    lists all failed criteria of that table joined with ' | '. Falls back
    to the bare prefix when nothing failed, to a fixed message for GOOD
    units, and to the label itself for any other label.
    """
    label = labels_np[i]

    if label == 'GOOD':
        return 'GOOD: passed all thresholds'

    if label == 'NOISE':
        prefix, criteria = 'NOISE', noise_fail_np
    elif label in ('MUA', 'NON-SOMA MUA'):
        prefix, criteria = 'MUA', mua_fail_np
    elif label in ('NON-SOMA', 'NON-SOMA GOOD'):
        prefix, criteria = 'NON-SOMA', nonsoma_fail_np
    else:
        return label

    failed = [reason for reason, mask in criteria.items() if bool(mask[i])]
    if not failed:
        return prefix
    return ' | '.join(f'{prefix}: {reason}' for reason in failed)

# Compute the main reason for every unit, in row order
qm_df['main_reason'] = list(map(get_main_reason, range(len(qm_df))))

print('Classification reasons extracted!')
print('\nExample reasons (first 5 units):')
# Pair each of the first five labels with its reason for a quick sanity check
first_five = zip(labels_np[:5], qm_df['main_reason'].head(5))
for unit_pos, (unit_label, unit_reason) in enumerate(first_five):
    print(f"  Unit {unit_pos} ({unit_label}): {unit_reason}")

## View Classification Reason Summary

In [None]:
# Show distribution of main reasons
print('Main reason distribution:')
print(qm_df['main_reason'].value_counts())
print()

# Show example units for each label type
if VERBOSE:
    print('\n' + '='*80)
    print('DETAILED VIEW: Example units for each classification')
    print('='*80)

    for label in ['GOOD', 'NOISE', 'MUA', 'NON-SOMA', 'NON-SOMA GOOD', 'NON-SOMA MUA']:
        examples = qm_df[qm_df['bombcell_label'] == label].head(3)
        if examples.empty:
            continue

        print(f'\n{label} units (showing up to 3 examples):')
        rows = zip(examples['unit_index'], examples['cluster_id'], examples['main_reason'])
        for unit_pos, cluster_id, main_reason in rows:
            # get_all_reasons() indexes positional arrays, so pass the row
            # position stored in 'unit_index' rather than the DataFrame
            # index label, which is only equal to it for a 0..N-1 index.
            all_reasons = get_all_reasons(int(unit_pos))
            print(f'  Cluster {cluster_id}:')
            print(f'    Main: {main_reason}')
            # Only repeat the reasons when there is more than one
            if ' | ' in all_reasons:
                print(f'    All:  {all_reasons}')
            print()

## Export to Phy TSV File

In [None]:
# Create TSV file for Phy
output_file = ks_dir / 'cluster_bc_classificationReason.tsv'

# Two columns: the integer cluster id and the main reason text. The second
# column's header is the column name announced in the messages below.
export_df = pd.DataFrame({
    'cluster_id': qm_df['cluster_id'].astype(int),
    'bc_classificationReason': qm_df['main_reason']
})

# Save as TSV (tab-separated, no index column)
export_df.to_csv(output_file, sep='\t', index=False)

# The check mark is written as an escape so it cannot be corrupted again by
# a wrong file encoding (it had turned into the mojibake 'âœ“').
print('\u2713 Successfully exported classification reasons to:')
print(f'  {output_file}')
print()
print(f'Exported {len(export_df)} units')
print()
print('To view in Phy:')
print('  1. Open your data in Phy')
print('  2. Look for the "bc_classificationReason" column in the cluster view')
print('  3. You can sort and filter by this column to explore classification reasons')

## Preview the Exported Data

In [None]:
# Print the head and the tail of the exported table, separated by a blank line
print(
    'Preview of exported TSV file (first 10 rows):',
    export_df.head(10),
    '',
    'Preview of exported TSV file (last 10 rows):',
    export_df.tail(10),
    sep='\n',
)