# Retroactive ROI labels for Phy (IN_ROI / OUT_ROI)

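Use this notebook to write `cluster_bc_roiLabel.tsv` into existing Kilosort folders **without rerunning Bombcell**.
It reuses the quality metrics already saved in each folder (`bombcell/templates._bc_qMetrics.parquet`, or `metrics.csv` as a fallback), so one of those files must already exist.
Phy can then show `bc_roiLabel` in ClusterView after reopening the dataset.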


In [None]:
from pathlib import Path
import sys
import numpy as np
import pandas as pd
import bombcell as bc

sys.path.insert(0, str((Path.cwd() / '..' / '..').resolve()))
from grant.grant_config import load_grant_config

# Path to the grant recording config JSON that is passed to load_grant_config().
CONFIG_FILE = Path('../configs/grant_recording_config.json')
# Probe names to label; each entry is upper-cased before it is looked up.
PROBES_TO_PROCESS = ['A', 'B', 'C', 'D', 'E', 'F']
# Optional per-probe ROI end value that takes precedence over the config entry.
ROI_END_UM_OVERRIDE = {}  # example: {'B': 950}
# Optional per-probe Kilosort folder that takes precedence over the config entry.
KS_DIR_OVERRIDE_BY_PROBE = {}  # example: {'B': Path(r'D:/.../kilosort4_B')}
# Which end of the shank is treated as the tip when measuring distance: the
# channel with the smallest y ('min_y') or the largest y ('max_y').
TIP_POSITION = 'min_y'  # 'min_y' or 'max_y'
# When True, labels are computed and summarized but no TSV file is written.
DRY_RUN = False


In [None]:
# Load the recording configuration once; the helper functions in this cell read
# this module-level `cfg` for per-probe Kilosort folders and ROI limits.
cfg = load_grant_config(CONFIG_FILE)
# Echo which config file and recording were picked up so a wrong path is obvious.
print('Loaded config:', cfg['config_path'])
print('Recording:', cfg['recording_name'])

def resolve_ks_dir(probe):
    """Return the Kilosort output folder for ``probe`` as a ``Path``.

    An entry in ``KS_DIR_OVERRIDE_BY_PROBE`` takes precedence over
    ``cfg['probe_kilosort_dirs']``. Override keys are compared
    case-insensitively, so ``{'b': ...}`` and ``{'B': ...}`` are equivalent.

    Args:
        probe: Probe name (any case); it is upper-cased before lookup.

    Returns:
        ``Path`` to the Kilosort folder for the probe.

    Raises:
        KeyError: If the probe has no override and is missing from
            ``cfg['probe_kilosort_dirs']``.
    """
    probe = probe.upper()
    # Normalize override keys the same way as the probe name; otherwise a
    # lower-case key typed into the notebook would be silently ignored.
    overrides = {str(key).upper(): value for key, value in KS_DIR_OVERRIDE_BY_PROBE.items()}
    if probe in overrides:
        return Path(overrides[probe])
    return Path(cfg['probe_kilosort_dirs'][probe])

def resolve_roi_end_um(probe):
    """Return the ROI end value for ``probe``, or ``None`` when it is not set.

    An entry in ``ROI_END_UM_OVERRIDE`` takes precedence over
    ``cfg['probe_recording_roi']``. Override keys are compared
    case-insensitively, so ``{'b': 950}`` and ``{'B': 950}`` are equivalent.

    Args:
        probe: Probe name (any case); it is upper-cased before lookup.

    Returns:
        The override value if one exists, otherwise the config value, otherwise
        ``None`` (also when the config has no usable ``probe_recording_roi``
        section).
    """
    probe = probe.upper()
    # Normalize override keys the same way as the probe name; otherwise a
    # lower-case key typed into the notebook would be silently ignored.
    overrides = {str(key).upper(): value for key, value in ROI_END_UM_OVERRIDE.items()}
    if probe in overrides:
        return overrides[probe]
    # `or {}` also covers a section that is present but null in the config,
    # which would otherwise raise AttributeError on `.get`.
    roi_by_probe = cfg.get('probe_recording_roi') or {}
    return roi_by_probe.get(probe)

def load_quality_metrics_table(ks_dir):
    """Load the quality-metrics table stored in a Kilosort folder.

    Looks for ``bombcell/templates._bc_qMetrics.parquet`` first and falls back
    to ``metrics.csv`` directly inside ``ks_dir``.

    Args:
        ks_dir: Kilosort folder, as a ``Path`` or a path string.

    Returns:
        Tuple ``(table, source_path)`` where ``table`` is the loaded
        ``DataFrame`` and ``source_path`` is the ``Path`` it was read from.

    Raises:
        FileNotFoundError: If neither file exists.
    """
    # Accept plain strings too; the `/` joins below need a Path.
    ks_dir = Path(ks_dir)
    parquet_path = ks_dir / 'bombcell' / 'templates._bc_qMetrics.parquet'
    csv_path = ks_dir / 'metrics.csv'
    if parquet_path.exists():
        return pd.read_parquet(parquet_path), parquet_path
    if csv_path.exists():
        return pd.read_csv(csv_path), csv_path
    # Report paths relative to ks_dir so the `bombcell/` subfolder is visible;
    # the bare file name alone suggests the parquet sits next to metrics.csv.
    raise FileNotFoundError(
        f'Could not find quality metrics in {ks_dir}. '
        f'Expected either {parquet_path.relative_to(ks_dir)} or {csv_path.relative_to(ks_dir)}.'
    )

def compute_roi_labels(quality_metrics_df, ks_dir, roi_end_um, tip_position='min_y', in_label='IN_ROI', out_label='OUT_ROI'):
    """Label each unit as inside or outside the ROI by its distance from the tip.

    A unit's position is the y coordinate of the channel named in its
    ``maxChannels`` entry. Its distance from the tip is measured against the
    smallest (``'min_y'``) or largest (``'max_y'``) channel y on the shank, and
    units with ``distance <= roi_end_um`` receive ``in_label``.

    Args:
        quality_metrics_df: Table with a ``maxChannels`` column of channel
            indices into the channel-position array.
        ks_dir: Kilosort folder handed to ``bc.load_ephys_data``.
        roi_end_um: ROI extent measured from the tip; must be convertible to
            ``float``. Assumed to be in the same units as the channel
            positions -- confirm against the probe geometry.
        tip_position: ``'min_y'`` or ``'max_y'``.
        in_label: Label for units within the ROI.
        out_label: Label for units beyond the ROI.

    Returns:
        NumPy array of labels, one per row of ``quality_metrics_df``.

    Raises:
        KeyError: If ``maxChannels`` is missing from the table.
        ValueError: If ``tip_position`` is not ``'min_y'`` or ``'max_y'``.
        IndexError: If any ``maxChannels`` value is outside the channel range.
    """
    if 'maxChannels' not in quality_metrics_df.columns:
        raise KeyError("quality metrics table must include 'maxChannels'.")
    # Validate the cheap arguments before loading the ephys data, so a typo in
    # the notebook settings fails immediately rather than after the load.
    if tip_position not in ('min_y', 'max_y'):
        raise ValueError("tip_position must be 'min_y' or 'max_y'.")
    roi_limit_um = float(roi_end_um)

    ephys_data = bc.load_ephys_data(str(ks_dir))
    # NOTE(review): element 6 of the returned tuple is taken to be the channel
    # positions array, with y in column 1 -- confirm against the installed
    # bombcell version.
    channel_positions = ephys_data[6]
    shank_y = channel_positions[:, 1].astype(float)

    max_channels = quality_metrics_df['maxChannels'].astype(int).to_numpy()
    if np.any(max_channels < 0) or np.any(max_channels >= len(channel_positions)):
        raise IndexError(f'maxChannels out of range for {ks_dir}')

    unit_y = shank_y[max_channels]
    if tip_position == 'min_y':
        dist_um = unit_y - float(np.nanmin(shank_y))
    else:
        dist_um = float(np.nanmax(shank_y)) - unit_y

    return np.where(dist_um <= roi_limit_um, in_label, out_label)

def resolve_cluster_ids(quality_metrics_df):
    """Return the cluster IDs from a quality-metrics table as an int array.

    The ``phy_clusterID`` column is preferred; ``cluster_id`` is used when the
    former is absent.

    Raises:
        KeyError: If the table has neither column.
    """
    # Candidate columns in priority order.
    for id_column in ('phy_clusterID', 'cluster_id'):
        if id_column not in quality_metrics_df.columns:
            continue
        as_int = quality_metrics_df[id_column].astype(int)
        return as_int.to_numpy()
    raise KeyError("Could not find cluster IDs. Expected 'phy_clusterID' or 'cluster_id' in quality metrics table.")


In [None]:
# Label every requested probe independently and collect one summary row per
# probe, so a problem with one probe never hides the outcome of the others.
results = []
for probe in [p.upper() for p in PROBES_TO_PROCESS]:
    ks_dir = None  # stays None if the folder lookup itself fails
    try:
        # Resolve inside the try: a probe listed in PROBES_TO_PROCESS but
        # missing from the config becomes an ERROR row instead of aborting
        # the loop before the summary is built.
        ks_dir = resolve_ks_dir(probe)
        roi_end_um = resolve_roi_end_um(probe)

        if roi_end_um is None:
            print(f'Skipping probe {probe}: ROI not set in config and no override provided.')
            results.append({'probe': probe, 'status': 'SKIPPED_ROI_NOT_SET', 'ks_dir': str(ks_dir)})
            continue

        qm_df, qm_source = load_quality_metrics_table(ks_dir)
        cluster_ids = resolve_cluster_ids(qm_df)
        roi_labels = compute_roi_labels(qm_df, ks_dir, roi_end_um=roi_end_um, tip_position=TIP_POSITION)

        # Both arrays come from the same table, so a mismatch means the table
        # or the labelling step is inconsistent; refuse to write a partial file.
        if len(cluster_ids) != len(roi_labels):
            raise ValueError(
                f'cluster_ids length ({len(cluster_ids)}) != roi_labels length ({len(roi_labels)}) for probe {probe}'
            )

        # Tab-separated file with a `cluster_id` column plus the label column.
        out_tsv = ks_dir / 'cluster_bc_roiLabel.tsv'
        out_df = pd.DataFrame({'cluster_id': cluster_ids, 'bc_roiLabel': roi_labels})

        if not DRY_RUN:
            out_df.to_csv(out_tsv, sep='\t', index=False)

        counts = out_df['bc_roiLabel'].value_counts().to_dict()
        print(f"Probe {probe}: {'would write' if DRY_RUN else 'wrote'} {out_tsv}")
        print(f'  Source quality metrics: {qm_source}')
        print(f'  Counts: {counts}')

        results.append({
            'probe': probe,
            'status': 'DRY_RUN' if DRY_RUN else 'WROTE',
            'ks_dir': str(ks_dir),
            'quality_metrics_source': str(qm_source),
            'roi_end_um': float(roi_end_um),
            'n_clusters': int(len(out_df)),
            # These literals match the default labels of compute_roi_labels.
            'n_in_roi': int((out_df['bc_roiLabel'] == 'IN_ROI').sum()),
            'n_out_roi': int((out_df['bc_roiLabel'] == 'OUT_ROI').sum()),
        })
    except Exception as exc:
        # Per-probe boundary: record the failure and continue with the next
        # probe. The exception type is included because str(KeyError) is only
        # the missing key, which is unreadable on its own.
        error_text = f'{type(exc).__name__}: {exc}'
        print(f'Probe {probe}: ERROR -> {error_text}')
        results.append({
            'probe': probe,
            'status': 'ERROR',
            'ks_dir': None if ks_dir is None else str(ks_dir),
            'error': error_text,
        })

# Last expression of the cell so the notebook renders the summary table.
summary_df = pd.DataFrame(results)
summary_df


## Next step in Phy
After writing `cluster_bc_roiLabel.tsv`, reopen the dataset in Phy.
You should be able to show the `bc_roiLabel` column in ClusterView.
