# Kilosort4 Splitting / Merging Cells (Open Ephys + Kilosort4)

## This file is for checking the reliability of any MANUAL splits and merges you make within phy
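**Expected folder convention**
- `{NP_recording_name}/bombcell/bombcell_DEFAULT/`
  - `DUPLICATED_KILOSORT4_FILES/`
  - `batch_DEFAULT_results/`
- `{NP_recording_name}/bombcell/bombcell_NP2.0/`
  - `DUPLICATED_KILOSORT4_FILES_ACD/`
  - `NP2_ReRun_results/`
- `{NP_recording_name}/bombcell/bombcell_single_probe/`
  - `kilosort4_{probe}/`
  - `single_probe_results/`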

In [2]:
NP_recording_name = 'Reach15_20260201_session007_NP_Recording_Number02_2026-02-01_18-25-00'


In [3]:
# Build the Bombcell folder paths for this recording
from pathlib import Path

RECORDING_ROOT = Path(r"H:\Grant\Neuropixels\Kilosort_Recordings") / NP_recording_name

BOMBCELL_DEFAULT_ROOT = RECORDING_ROOT / "bombcell" / "bombcell_DEFAULT"
BOMBCELL_NP20_ROOT = RECORDING_ROOT / "bombcell" / "bombcell_NP2.0"
BOMBCELL_SINGLEPROBE_ROOT = RECORDING_ROOT / "bombcell" / "bombcell_single_probe"


DEFAULT_KS_STAGING_ROOT = BOMBCELL_DEFAULT_ROOT 
NP20_KS_STAGING_ROOT = BOMBCELL_NP20_ROOT 
BOMBCELL_KS_SINGLEPROBE_STAGING_ROOT = BOMBCELL_SINGLEPROBE_ROOT 

DEFAULT_EXPORT_ROOT = BOMBCELL_DEFAULT_ROOT / "batch_DEFAULT_results"
NP20_EXPORT_ROOT = BOMBCELL_NP20_ROOT / "NP2_ReRun_results"
SINGLE_EXPORT_ROOT = BOMBCELL_SINGLEPROBE_ROOT / "single_probe_results"

# Make sure they exist
for p in [
    DEFAULT_KS_STAGING_ROOT,
    NP20_KS_STAGING_ROOT,
    BOMBCELL_KS_SINGLEPROBE_STAGING_ROOT,
    DEFAULT_EXPORT_ROOT,
    NP20_EXPORT_ROOT,
    SINGLE_EXPORT_ROOT,
]:
    p.mkdir(parents=True, exist_ok=True)
print("ALL Paths Exist")


ALL Paths Exist


In [4]:
# =========================
# Configure
# =========================
from pathlib import Path
import pandas as pd
import numpy as np
import json

PROBES_ALL  = ["A","B","C","D","E","F"]
PROBES_NP20 = ["A","C","D"]

print("RECORDING_ROOT:", RECORDING_ROOT)
print("DEFAULT_EXPORT_ROOT exists:", DEFAULT_EXPORT_ROOT.exists())
print(DEFAULT_EXPORT_ROOT)
print("NP20_EXPORT_ROOT exists:", NP20_EXPORT_ROOT.exists())
print(NP20_EXPORT_ROOT)
print("SINGLE_EXPORT_ROOT exists:", SINGLE_EXPORT_ROOT.exists())
print(SINGLE_EXPORT_ROOT)

RECORDING_ROOT: H:\Grant\Neuropixels\Kilosort_Recordings\Reach15_20260201_session007_NP_Recording_Number02_2026-02-01_18-25-00
DEFAULT_EXPORT_ROOT exists: True
H:\Grant\Neuropixels\Kilosort_Recordings\Reach15_20260201_session007_NP_Recording_Number02_2026-02-01_18-25-00\bombcell\bombcell_DEFAULT\batch_DEFAULT_results
NP20_EXPORT_ROOT exists: True
H:\Grant\Neuropixels\Kilosort_Recordings\Reach15_20260201_session007_NP_Recording_Number02_2026-02-01_18-25-00\bombcell\bombcell_NP2.0\NP2_ReRun_results
SINGLE_EXPORT_ROOT exists: True
H:\Grant\Neuropixels\Kilosort_Recordings\Reach15_20260201_session007_NP_Recording_Number02_2026-02-01_18-25-00\bombcell\bombcell_single_probe\single_probe_results


In [5]:
# =========================
# Helpers
# =========================
def load_probe_exports(export_root: Path, probe: str):
    # Loads Probe_{probe} exports: quality_metrics.csv, unit_type_counts.csv, param.json, checks.json.
    probe_dir = export_root / f"Probe_{probe}"
    qm_path = probe_dir / f"Probe_{probe}_quality_metrics.csv"
    counts_path = probe_dir / f"Probe_{probe}_unit_type_counts.csv"
    param_path = probe_dir / f"Probe_{probe}_param.json"
    checks_path = probe_dir / f"Probe_{probe}_checks.json"
    err_path = probe_dir / "ERROR.txt"

    if err_path.exists():
        return {"probe": probe, "status": "FAILED", "error": err_path.read_text(), "probe_dir": probe_dir}

    out = {"probe": probe, "status": "OK", "probe_dir": probe_dir}
    out["qm"] = pd.read_csv(qm_path) if qm_path.exists() else None
    out["counts"] = pd.read_csv(counts_path) if counts_path.exists() else None
    out["param"] = json.loads(param_path.read_text()) if param_path.exists() else {}
    out["checks"] = json.loads(checks_path.read_text()) if checks_path.exists() else {}

    out["cluster_id_col"] = None

    out["ks_dir"] = out["checks"].get("ks_dir", None)
    out["save_path"] = out["checks"].get("save_path", None)
    
    if out["qm"] is not None:
        for c in ["cluster_id","clusterID","cluster_id_ks","cluster_id_phy","cluster"]:
            if c in out["qm"].columns:
                out["cluster_id_col"] = c
                break

    return out

def load_batch_summary(export_root: Path):
    p = export_root / "batch_summary.csv"
    return pd.read_csv(p) if p.exists() else None

def summarize_unit_types(qm: pd.DataFrame, label_col="Bombcell_unit_type"):
    if qm is None or label_col not in qm.columns:
        return None
    return qm[label_col].value_counts().rename_axis("unit_type").reset_index(name="count")

def add_percentages(df_counts: pd.DataFrame):
    if df_counts is None or df_counts.empty:
        return df_counts
    total = df_counts["count"].sum()
    df_counts = df_counts.copy()
    df_counts["pct"] = 100 * df_counts["count"] / total
    return df_counts

def find_cluster_row(qm: pd.DataFrame, cluster_id: int, cluster_id_col: str):
    if qm is None:
        raise ValueError("qm is None")
    if cluster_id_col is None or cluster_id_col not in qm.columns:
        raise ValueError("No cluster_id column found in quality_metrics.csv")
    sub = qm.loc[qm[cluster_id_col] == cluster_id]
    if sub.empty:
        raise KeyError(f"Cluster id {cluster_id} not found in {cluster_id_col}")
    return sub.iloc[0]

def threshold_fail_report(row, qm_cols, param):
    # Common Bombcell gates; only checks metrics that exist in the CSV.
    rules = [
        ("rawAmplitude", "<", param.get("minAmplitude", 40)),
        ("signalToNoiseRatio", "<", param.get("minSNR", 5)),
        ("presenceRatio", "<", param.get("minPresenceRatio", 0.7)),
        ("fractionRPVs_estimatedTauR", ">", param.get("maxRPVviolations", 0.1)),
        ("percentageSpikesMissing_gaussian", ">", param.get("maxPercSpikesMissing", 20)),
        ("waveformDuration_peakTrough", "<", param.get("minWvDuration", 100)),
        ("waveformDuration_peakTrough", ">", param.get("maxWvDuration", 1150)),
        ("nPeaks", ">", param.get("maxNPeaks", 2)),
        ("nTroughs", ">", param.get("maxNTroughs", 1)),
        ("waveformBaselineFlatness", ">", param.get("maxWvBaselineFraction", 0.3)),
    ]
    fails = []
    for col, op, thr in rules:
        if col not in qm_cols:
            continue
        v = row[col]
        if pd.isna(v):
            continue
        if (op == "<" and v < thr) or (op == ">" and v > thr):
            fails.append((col, float(v), op, float(thr)))
    return fails
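`threshold_fail_report` checks one unit at a time. When triaging a whole probe it may help to see every gate for every unit at once. The sketch below is a vectorized variant that follows the same conventions (rules whose metric column is missing are skipped, NaN never counts as a failure). `threshold_fail_table` and the shortened `RULES` list are hypothetical names, and the toy metric values are made up for illustration only.

```python
import pandas as pd

# Illustrative subset of the gates in threshold_fail_report (same default thresholds).
RULES = [
    ("rawAmplitude", "<", 40),
    ("signalToNoiseRatio", "<", 5),
    ("presenceRatio", "<", 0.7),
    ("fractionRPVs_estimatedTauR", ">", 0.1),
]

def threshold_fail_table(qm: pd.DataFrame, rules, id_col: str = "cluster_id") -> pd.DataFrame:
    """One boolean column per rule (True = the unit fails that gate).

    Rules whose metric column is absent are skipped, and NaN metrics never
    count as failures, matching the per-row helper above. Only "<" and ">"
    operators are expected.
    """
    out = pd.DataFrame(index=qm.index)
    for col, op, thr in rules:
        if col not in qm.columns:
            continue
        v = qm[col]
        failed = (v < thr) if op == "<" else (v > thr)
        out[f"{col} {op} {thr}"] = failed & v.notna()
    if id_col in qm.columns:
        out.insert(0, id_col, qm[id_col])
    return out

# Toy metrics for three hypothetical units (values made up for illustration).
qm_toy = pd.DataFrame({
    "cluster_id": [0, 1, 2],
    "rawAmplitude": [55.0, 30.0, float("nan")],
    "signalToNoiseRatio": [6.0, 4.0, 7.0],
    "presenceRatio": [0.9, 0.95, 0.5],
})
fails_toy = threshold_fail_table(qm_toy, RULES)
fails_toy["n_fails"] = fails_toy.drop(columns="cluster_id").sum(axis=1)
print(fails_toy[["cluster_id", "n_fails"]])
```

On a real export this would be called as `threshold_fail_table(d["qm"], RULES, id_col=d["cluster_id_col"])`; the per-row helper remains the easier tool for explaining why one specific cluster was labelled the way it was.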

## Load DEFAULT exports (all probes)

In [6]:
default_summary = load_batch_summary(DEFAULT_EXPORT_ROOT)
default_summary

|   | probe | status | ks_dir | save_path | n_NOISE | n_MUA | n_NON-SOMA | n_GOOD | max_raw_metric_nan_frac |
|---|---|---|---|---|---|---|---|---|---|
| 0 | A | OK | H:\Grant\Neuropixels\Kilosort_Recordings\Reach... | H:\Grant\Neuropixels\Kilosort_Recordings\Reach... | 392 | 109 | 51 | 2 | 0.012635 |
| 1 | B | OK | H:\Grant\Neuropixels\Kilosort_Recordings\Reach... | H:\Grant\Neuropixels\Kilosort_Recordings\Reach... | 247 | 764 | 286 | 19 | 0.00228 |
| 2 | C | OK | H:\Grant\Neuropixels\Kilosort_Recordings\Reach... | H:\Grant\Neuropixels\Kilosort_Recordings\Reach... | 154 | 135 | 30 | 36 | 0.014085 |
| 3 | D | OK | H:\Grant\Neuropixels\Kilosort_Recordings\Reach... | H:\Grant\Neuropixels\Kilosort_Recordings\Reach... | 93 | 540 | 109 | 78 | 0.002439 |
| 4 | E | OK | H:\Grant\Neuropixels\Kilosort_Recordings\Reach... | H:\Grant\Neuropixels\Kilosort_Recordings\Reach... | 100 | 394 | 251 | 119 | 0.008102 |
| 5 | F | OK | H:\Grant\Neuropixels\Kilosort_Recordings\Reach... | H:\Grant\Neuropixels\Kilosort_Recordings\Reach... | 95 | 273 | 210 | 113 | 0.007236 |


## ==================================
## Check all probes in each BC session
## ==================================
- batch (DEFAULT)
- NP2.0 re-run
- single probe
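To compare the three sessions side by side, the per-probe unit-type counts can be pivoted into one probe × session table. This is a sketch: `compare_sessions` and `as_counts` are hypothetical helpers, they assume each counts table has the `unit_type`/`count` columns produced by `summarize_unit_types`, and the example uses a few GOOD/MUA counts copied from the outputs in this notebook.

```python
import pandas as pd

def compare_sessions(session_counts: dict, unit_type: str = "GOOD") -> pd.DataFrame:
    """Pivot per-probe unit-type counts into a probe x session table.

    session_counts maps a session label to {probe: counts DataFrame with
    'unit_type' and 'count' columns}. None or empty entries (failed or not run)
    are skipped and show up as NaN if the probe exists in another session.
    """
    rows = []
    for session, per_probe in session_counts.items():
        for probe, counts in per_probe.items():
            if counts is None or counts.empty:
                continue
            n = counts.loc[counts["unit_type"] == unit_type, "count"].sum()
            rows.append({"probe": probe, "session": session, "count": int(n)})
    long = pd.DataFrame(rows, columns=["probe", "session", "count"])
    return long.pivot(index="probe", columns="session", values="count")

def as_counts(d: dict) -> pd.DataFrame:
    """Build a summarize_unit_types-style table from {unit_type: count}."""
    return pd.DataFrame({"unit_type": list(d), "count": list(d.values())})

# A few counts copied from the outputs in this notebook.
session_counts = {
    "DEFAULT": {"A": as_counts({"GOOD": 2, "MUA": 109}), "B": as_counts({"GOOD": 19, "MUA": 764})},
    "NP20": {"A": as_counts({"GOOD": 7, "MUA": 178}), "D": None},
    "SINGLE": {"A": None, "B": as_counts({"GOOD": 425, "MUA": 358})},
}
print(compare_sessions(session_counts))
```

Once the loading cells have run, the real input could be assembled as `{"DEFAULT": {p: summarize_unit_types(d["qm"]) for p, d in default_data.items() if d["status"] == "OK"}, ...}`; the status filter matters because failed probes have no `qm` key.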

In [7]:
default_data = {p: load_probe_exports(DEFAULT_EXPORT_ROOT, p) for p in PROBES_ALL}

for p in PROBES_ALL:
    d = default_data[p]
    print("="*60, f"Probe {p} ({d['status']})")
    if d["status"] != "OK":
        print(d.get("error",""))
        continue
    counts = add_percentages(summarize_unit_types(d["qm"]))
    display(counts)



|   | unit_type | count | pct |
|---|---|---|---|
| 0 | NOISE | 392 | 70.758123 |
| 1 | MUA | 109 | 19.67509 |
| 2 | NON-SOMA | 51 | 9.205776 |
| 3 | GOOD | 2 | 0.361011 |




|   | unit_type | count | pct |
|---|---|---|---|
| 0 | MUA | 764 | 58.054711 |
| 1 | NON-SOMA | 286 | 21.732523 |
| 2 | NOISE | 247 | 18.768997 |
| 3 | GOOD | 19 | 1.443769 |




|   | unit_type | count | pct |
|---|---|---|---|
| 0 | NOISE | 154 | 43.380282 |
| 1 | MUA | 135 | 38.028169 |
| 2 | GOOD | 36 | 10.140845 |
| 3 | NON-SOMA | 30 | 8.450704 |




|   | unit_type | count | pct |
|---|---|---|---|
| 0 | MUA | 540 | 65.853659 |
| 1 | NON-SOMA | 109 | 13.292683 |
| 2 | NOISE | 93 | 11.341463 |
| 3 | GOOD | 78 | 9.512195 |




|   | unit_type | count | pct |
|---|---|---|---|
| 0 | MUA | 394 | 45.601852 |
| 1 | NON-SOMA | 251 | 29.050926 |
| 2 | GOOD | 119 | 13.773148 |
| 3 | NOISE | 100 | 11.574074 |




|   | unit_type | count | pct |
|---|---|---|---|
| 0 | MUA | 273 | 39.507959 |
| 1 | NON-SOMA | 210 | 30.390738 |
| 2 | GOOD | 113 | 16.353111 |
| 3 | NOISE | 95 | 13.748191 |


## Load NP2.0 rerun exports (A/C/D)

In [8]:
np20_summary = load_batch_summary(NP20_EXPORT_ROOT)
np20_summary

|   | probe | status | ks_dir | save_path | n_NOISE | n_MUA | n_NON-SOMA | n_GOOD | max_raw_metric_nan_frac | error |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | A | OK | H:\Grant\Neuropixels\Kilosort_Recordings\Reach... | H:\Grant\Neuropixels\Kilosort_Recordings\Reach... | 297.0 | 178.0 | 72.0 | 7.0 | 0.012635 |  |
| 1 | C | OK | H:\Grant\Neuropixels\Kilosort_Recordings\Reach... | H:\Grant\Neuropixels\Kilosort_Recordings\Reach... | 128.0 | 129.0 | 39.0 | 59.0 | 0.014085 |  |
| 2 | D | FAILED |  |  |  |  |  |  |  | No quality_metrics found (unknown failure) |


## Per-probe unit types: NP2.0 re-run

In [9]:
np20_data = {p: load_probe_exports(NP20_EXPORT_ROOT, p) for p in PROBES_NP20}

for p in PROBES_NP20:
    d = np20_data[p]
    print("="*60, f"Probe {p} ({d['status']})")
    if d["status"] != "OK":
        print(d.get("error",""))
        continue
    counts = add_percentages(summarize_unit_types(d["qm"]))
    display(counts)



|   | unit_type | count | pct |
|---|---|---|---|
| 0 | NOISE | 297 | 53.610108 |
| 1 | MUA | 178 | 32.129964 |
| 2 | NON-SOMA | 72 | 12.99639 |
| 3 | GOOD | 7 | 1.263538 |




|   | unit_type | count | pct |
|---|---|---|---|
| 0 | MUA | 129 | 36.338028 |
| 1 | NOISE | 128 | 36.056338 |
| 2 | GOOD | 59 | 16.619718 |
| 3 | NON-SOMA | 39 | 10.985915 |


No quality_metrics found (unknown failure)


## Load single probe data

In [10]:
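single_summary = load_batch_summary(SINGLE_EXPORT_ROOT)
single_summary  # not the last line of the cell, so Jupyter does not display it; use display(single_summary) to see it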

print(SINGLE_EXPORT_ROOT)

H:\Grant\Neuropixels\Kilosort_Recordings\Reach15_20260201_session007_NP_Recording_Number02_2026-02-01_18-25-00\bombcell\bombcell_single_probe\single_probe_results


In [11]:
single_probe_data = {p: load_probe_exports(SINGLE_EXPORT_ROOT, p) for p in PROBES_ALL}

for p in PROBES_ALL:
    d = single_probe_data[p]
    print("="*60, f"Probe {p} ({d['status']})")
    if d["status"] != "OK":
        print(d.get("error",""))
        continue
    counts = add_percentages(summarize_unit_types(d["qm"]))
    display(counts)



None



|   | unit_type | count | pct |
|---|---|---|---|
| 0 | GOOD | 425 | 32.294833 |
| 1 | MUA | 358 | 27.203647 |
| 2 | NON-SOMA | 286 | 21.732523 |
| 3 | NOISE | 247 | 18.768997 |




None



None



None



None

## ==================================
## Check units here
## ==================================


#### Options for BC_SESSION you may have used 

In [12]:
BOMBCELL_DEFAULT_ROOT = RECORDING_ROOT / "bombcell" / "bombcell_DEFAULT"
BOMBCELL_NP20_ROOT = RECORDING_ROOT / "bombcell" / "bombcell_NP2.0"
BOMBCELL_SINGLEPROBE_ROOT = RECORDING_ROOT / "bombcell" / "bombcell_single_probe"


## Load in the Kilosort4 directory you were using in phy
##### SET the following:
1. `probe`: probe letter, one of 'A', 'B', 'C', 'D', 'E', 'F'
2. `BC_SESSION`: one of BOMBCELL_DEFAULT_ROOT, BOMBCELL_NP20_ROOT, or BOMBCELL_SINGLEPROBE_ROOT (and set `run` to the matching label)
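Because `BC_SESSION` and the `run` label are set separately, it is easy for the two to disagree. One option is to derive the Kilosort4 folder from `run` and `probe` alone. The sketch below assumes the `RECORDING_ROOT / "bombcell" / <session folder> / kilosort4_<probe>` layout used in this notebook; `resolve_ks_dir`, `BC_SUBDIRS`, and `PROBES_BY_RUN` are hypothetical names, and the NP2.0 probe set mirrors `PROBES_NP20`.

```python
from pathlib import Path

# Hypothetical mapping from the `run` label to the Bombcell session folder name.
BC_SUBDIRS = {
    "DEFAULT": "bombcell_DEFAULT",
    "NP20": "bombcell_NP2.0",
    "SINGLE": "bombcell_single_probe",
}
# Probes expected per run; the NP2.0 set mirrors PROBES_NP20 in this notebook.
PROBES_BY_RUN = {
    "DEFAULT": {"A", "B", "C", "D", "E", "F"},
    "NP20": {"A", "C", "D"},
    "SINGLE": {"A", "B", "C", "D", "E", "F"},
}

def resolve_ks_dir(recording_root: Path, run: str, probe: str) -> Path:
    """Return <recording_root>/bombcell/<session folder>/kilosort4_<probe>."""
    if run not in BC_SUBDIRS:
        raise ValueError(f"run must be one of {sorted(BC_SUBDIRS)}, got {run!r}")
    if probe not in PROBES_BY_RUN[run]:
        raise ValueError(f"probe {probe!r} is not expected for run {run!r}")
    return Path(recording_root) / "bombcell" / BC_SUBDIRS[run] / f"kilosort4_{probe}"
```

In the notebook this would read `ks_dir = resolve_ks_dir(RECORDING_ROOT, run, probe)`, followed by the same `ks_dir.exists()` check as in the next cell.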


In [23]:
probe = "B"      

run = "SINGLE"      # "DEFAULT" or "NP20" or 'SINGLE'
BC_SESSION = BOMBCELL_SINGLEPROBE_ROOT # BOMBCELL_DEFAULT_ROOT or BOMBCELL_NP20_ROOT or BOMBCELL_SINGLEPROBE_ROOT


In [20]:
ks_dir = Path(fr"{BC_SESSION}\kilosort4_{probe}")

if ks_dir.exists():
    print(f'Found BC session for probe {probe}')
    print(ks_dir)
else:
    print('No Dir Found')


Found BC session for probe B
H:\Grant\Neuropixels\Kilosort_Recordings\Reach15_20260201_session007_NP_Recording_Number02_2026-02-01_18-25-00\bombcell\bombcell_single_probe\kilosort4_B


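## Check direct Kilosort4 outputs (NOT BC outputs):
Kilosort4 output files used here:
1. **amplitudes.npy** -> AmplitudeView is not computed from the raw waveform you see in WaveformView. It is computed from Kilosort's per-spike amplitude scalar (usually amplitudes.npy, i.e., the template scaling for each spike).
2. **spike_clusters.npy** -> the cluster ID of each spike; phy rewrites this file when you save manual splits and merges.
3. **spike_times.npy** -> the time of each spike, in samples (used by the split checks below).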
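Because `amplitudes.npy` and `spike_clusters.npy` hold one value per spike in the same order, the single-cluster summary in the next cell can be extended to every cluster at once with a pandas group-by. This is a sketch: `amplitude_summary` is a hypothetical helper, `std` uses `ddof=0` to match `np.std`, and `frac_3sd` reproduces the median ± 3*std band fraction computed below.

```python
import numpy as np
import pandas as pd

def amplitude_summary(amplitudes, spike_clusters) -> pd.DataFrame:
    """Per-cluster summary of Kilosort per-spike amplitudes.

    Both inputs are per-spike arrays of equal length, e.g. the contents of
    amplitudes.npy and spike_clusters.npy. 'std' uses ddof=0 to match np.std,
    and 'frac_3sd' is the fraction of spikes strictly inside median +/- 3*std.
    """
    amps = np.asarray(amplitudes, dtype=np.float64).squeeze()
    clus = np.asarray(spike_clusters).squeeze()
    if amps.shape != clus.shape:
        raise ValueError("amplitudes and spike_clusters must have the same length")
    g = pd.DataFrame({"cluster": clus, "amp": amps}).groupby("cluster")["amp"]

    def frac_3sd(a: pd.Series) -> float:
        med, sd = a.median(), a.std(ddof=0)
        return float(((a > med - 3 * sd) & (a < med + 3 * sd)).mean())

    out = g.agg(n_spikes="size", min="min", median="median", max="max")
    out["std"] = g.std(ddof=0)
    out["frac_3sd"] = g.apply(frac_3sd)
    return out
```

For example, `amplitude_summary(np.load(ks_dir / "amplitudes.npy"), np.load(ks_dir / "spike_clusters.npy")).loc[23]` should closely match the cluster-23 numbers printed below (small differences can come from the float64 cast). Sorting the table by `min` or `frac_3sd` may help spot clusters whose amplitude distribution looks truncated.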

In [15]:
import numpy as np

amp_path = ks_dir / "amplitudes.npy"
print('==============================================================')
print('AmplitudeView Overview')
print('==============================================================')
print('')
print("amplitudes.npy exists:", amp_path.exists())

if amp_path.exists():

    print('')
    amplitudes = np.load(amp_path).squeeze()
    spike_clusters = np.load(ks_dir / "spike_clusters.npy").squeeze()

    # Cluster to inspect (Kilosort/phy cluster ID)
    cl = 23
    a = amplitudes[spike_clusters == cl]
    print("Cluster", cl, "n spikes:", a.size)
    print("min/median/max:", float(np.min(a)), float(np.median(a)), float(np.max(a)))
    print("std:", float(np.std(a)))
    print("unique values (first 10):", np.unique(a)[:10])
    print('')

    print('==============================================================')
    print('what fraction of spikes are in the tight band:')
    print('==============================================================')
    print('')

    # Spread of the amplitude distribution: central 98% and the median ± 3*std band
    lo, hi = np.percentile(a, [1, 99])
    print("1st–99th percentile:", lo, hi)
    med, sd = np.median(a), np.std(a)
    print("Fraction within [median ± 3*std]:", np.mean((a > med - 3 * sd) & (a < med + 3 * sd)))



AmplitudeView Overview

amplitudes.npy exists: True

Cluster 23 n spikes: 105079
min/median/max: 10.338539123535156 14.226888656616211 42.52353286743164
std: 1.7066458463668823
unique values (first 10): [10.338539  10.375977  10.391951  10.41609   10.436833  10.43893
 10.527426  10.5430975 10.56171   10.586558 ]

what fraction of spikes are in the tight band:

1st–99th percentile: 11.652233867645263 19.18854705810547
Fraction within [median ± 3*std]: 0.9908925665451708


## Set the new unit IDs created by splitting a single unit

In [21]:
# NEW unitIDs
cl_a = 1324
cl_b = 1325

## 1️⃣ Within-unit RPV check (ISI < 2 ms)
#### For each new unit, compute the fraction of its own inter-spike intervals below 2 ms (within-unit RPV), plus the fraction of cross-intervals (for each spike in one new unit, the distance to the nearest spike in the other) that fall below 2 ms.
- If the cross fraction is low (comparable to the within-unit RPV), the split is defensible regardless of baseline asymmetry.

In [22]:
import numpy as np

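fs = 30000          # sampling rate (Hz); should match sample_rate in the Kilosort params.py
r_ms = 2.0          # refractory window (ms)
r_samp = int(round((r_ms / 1000) * fs))  # refractory window in samples (60 at 30 kHz)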

spike_times = np.load(ks_dir / "spike_times.npy").squeeze().astype(np.int64)
spike_clusters = np.load(ks_dir / "spike_clusters.npy").squeeze().astype(np.int64)

ta = np.sort(spike_times[spike_clusters == cl_a])
tb = np.sort(spike_times[spike_clusters == cl_b])

def within_rpv_fraction(t, r_samp):
    if t.size < 2:
        return np.nan
    return np.mean(np.diff(t) < r_samp)

def nearest_neighbor_cross_fraction(t_src, t_tgt, r_samp):
    if t_src.size == 0 or t_tgt.size == 0:
        return np.nan

    idx = np.searchsorted(t_tgt, t_src)

    d = np.full(t_src.shape, np.inf, dtype=np.float64)

    m = idx > 0
    if np.any(m):
        d[m] = np.minimum(d[m], np.abs(t_src[m] - t_tgt[idx[m] - 1]))

    m = idx < t_tgt.size
    if np.any(m):
        d[m] = np.minimum(d[m], np.abs(t_src[m] - t_tgt[idx[m]]))

    return np.mean(d < r_samp)

print("Within-unit RPV fractions (ISI < 2 ms):")
print("1324:", within_rpv_fraction(ta, r_samp))
print("1325:", within_rpv_fraction(tb, r_samp))

print("\nCross nearest-neighbor fractions (|Δt| < 2 ms):")
print("1324 -> 1325:", nearest_neighbor_cross_fraction(ta, tb, r_samp))
print("1325 -> 1324:", nearest_neighbor_cross_fraction(tb, ta, r_samp))


Within-unit RPV fractions (ISI < 2 ms):
1324: 0.010101010101010102
1325: 0.010700062377045247

Cross nearest-neighbor fractions (|Δt| < 2 ms):
1324 -> 1325: 0.027296301351166917
1325 -> 1324: 0.0049420627114171245


### What to look for in the output above: 1️⃣ Within-unit RPV check (ISI < 2 ms)

1. **Low refractory violation fraction (below ~1–2%)**

- Values near or below ~1% are consistent with a well-isolated single unit.

- Values substantially higher (>3–5%) suggest contamination or a bad split.

2. **Improvement after split**

- If the original merged unit had elevated RPVs and both children now have lower RPVs, the split likely improved isolation.

- If RPVs increase after the split, the split likely introduced contamination.

3. **Symmetry between child units**

- If both children have similarly low RPV values, that supports two plausible neurons.

- If one child has a clean RPV and the other is elevated, the elevated one is likely contamination or a subcluster.
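Point 2 compares the children with the original (pre-split) unit, but the cells above only score the children. Since a split partitions the parent's spikes, the parent train can be reconstructed as the union of the two children, assuming no spikes were discarded in phy. A sketch under that assumption: `split_rpv_report` is a hypothetical helper, and `within_rpv_fraction` is redefined here with the same ISI < r rule as above so the block runs on its own.

```python
import numpy as np

def within_rpv_fraction(t: np.ndarray, r_samp: int) -> float:
    """Fraction of inter-spike intervals shorter than r_samp (spike times in samples)."""
    t = np.sort(np.asarray(t, dtype=np.int64))
    if t.size < 2:
        return float("nan")
    return float(np.mean(np.diff(t) < r_samp))

def split_rpv_report(ta, tb, fs: float = 30000, r_ms: float = 2.0) -> dict:
    """RPV fraction for each child and for the reconstructed parent (union of both).

    Assumes the split kept every spike, so the parent spike train is the two
    child trains combined; the parent's ISIs therefore include the cross
    intervals between the children.
    """
    r_samp = int(round(r_ms / 1000 * fs))
    parent = np.concatenate([np.asarray(ta), np.asarray(tb)])
    return {
        "parent": within_rpv_fraction(parent, r_samp),
        "child_a": within_rpv_fraction(ta, r_samp),
        "child_b": within_rpv_fraction(tb, r_samp),
    }
```

In the notebook this would be `split_rpv_report(ta, tb, fs=fs, r_ms=r_ms)` with `ta`/`tb` from the RPV cell above.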

## =========================================================================================
## =========================================================================================

## 2️⃣ Cross nearest-neighbor fraction check (< 2 ms)
#### Compute the same cross nearest-neighbor fraction, restricted to the time interval where both clusters are active (e.g., the overlap of their 5th–95th percentile spike-time ranges).
- If the elevated 2.73% (1324 → 1325, full session) drops toward ~1% in the overlap window, the asymmetry was mostly due to nonstationarity and the split is fine.
- If it stays high, there is likely still mixing at the split boundary.
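The percentile overlap window gives a single number for the shared time range. A finer-grained variant of the same idea is to compute the cross nearest-neighbor fraction in consecutive time bins, to see whether an elevated value is concentrated in one part of the session. This is a sketch: `binned_cross_fraction`, `nn_within`, and the 10-minute default bin are assumptions, not part of the original analysis.

```python
import numpy as np

def nn_within(t_src: np.ndarray, t_tgt: np.ndarray, r_samp: int) -> np.ndarray:
    """Per t_src spike: True if the nearest t_tgt spike is closer than r_samp samples."""
    t_tgt = np.sort(t_tgt)
    idx = np.searchsorted(t_tgt, t_src)
    # Neighbours on each side; clipping at the ends just repeats the valid side.
    left = np.abs(t_src - t_tgt[np.clip(idx - 1, 0, t_tgt.size - 1)])
    right = np.abs(t_src - t_tgt[np.clip(idx, 0, t_tgt.size - 1)])
    return np.minimum(left, right) < r_samp

def binned_cross_fraction(t_src, t_tgt, fs: float = 30000, r_ms: float = 2.0, bin_s: float = 600.0):
    """Cross nearest-neighbour fraction (|dt| < r_ms) per time bin of the source spikes.

    Returns (bin_start_s, fraction, n_src). Bins without source spikes get NaN.
    Neighbours are searched over the whole target train, not just the same bin.
    """
    t_src = np.sort(np.asarray(t_src, dtype=np.int64))
    t_tgt = np.asarray(t_tgt, dtype=np.int64)
    if t_src.size == 0 or t_tgt.size == 0:
        return np.array([]), np.array([]), np.array([], dtype=np.int64)
    r_samp = int(round(r_ms / 1000 * fs))
    hit = nn_within(t_src, t_tgt, r_samp)
    b = t_src // int(round(bin_s * fs))  # bin index of each source spike
    n_src = np.bincount(b)
    n_hit = np.bincount(b, weights=hit.astype(float), minlength=n_src.size)
    with np.errstate(invalid="ignore"):
        frac = n_hit / n_src  # 0/0 -> NaN for bins with no source spikes
    return np.arange(n_src.size) * bin_s, frac, n_src
```

After the next cell has run, `starts, frac, n = binned_cross_fraction(ta_all, tb_all, fs=fs, r_ms=r_ms)` gives one fraction per bin. Bins with few source spikes give noisy fractions, so it may help to ignore bins where `n` is small.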

In [None]:
import numpy as np

# Assumes ks_dir is already defined in your notebook
fs = 30000
r_ms = 2.0
r_samp = int(round((r_ms / 1000) * fs))

spike_times = np.load(ks_dir / "spike_times.npy").squeeze().astype(np.int64)
spike_clusters = np.load(ks_dir / "spike_clusters.npy").squeeze().astype(np.int64)

ta_all = np.sort(spike_times[spike_clusters == cl_a])
tb_all = np.sort(spike_times[spike_clusters == cl_b])

def within_rpv_fraction(t, r_samp):
    if t.size < 2:
        return np.nan
    return float(np.mean(np.diff(t) < r_samp))

def nearest_neighbor_cross_fraction(t_src, t_tgt, r_samp):
    if t_src.size == 0 or t_tgt.size == 0:
        return np.nan

    idx = np.searchsorted(t_tgt, t_src)
    d = np.full(t_src.shape, np.inf, dtype=np.float64)

    m = idx > 0
    if np.any(m):
        d[m] = np.minimum(d[m], np.abs(t_src[m] - t_tgt[idx[m] - 1]))

    m = idx < t_tgt.size
    if np.any(m):
        d[m] = np.minimum(d[m], np.abs(t_src[m] - t_tgt[idx[m]]))

    return float(np.mean(d < r_samp))

def restrict_to_overlap_window(ta, tb, p_lo=5, p_hi=95):
    if ta.size == 0 or tb.size == 0:
        return ta, tb, (None, None)

    a_lo, a_hi = np.percentile(ta, [p_lo, p_hi])
    b_lo, b_hi = np.percentile(tb, [p_lo, p_hi])

    t0 = int(max(a_lo, b_lo))
    t1 = int(min(a_hi, b_hi))

    if t1 <= t0:
        return np.array([], dtype=ta.dtype), np.array([], dtype=tb.dtype), (t0, t1)

    ta_win = ta[(ta >= t0) & (ta <= t1)]
    tb_win = tb[(tb >= t0) & (tb <= t1)]
    return ta_win, tb_win, (t0, t1)

# Full-session metrics
within_a = within_rpv_fraction(ta_all, r_samp)
within_b = within_rpv_fraction(tb_all, r_samp)
cross_a_to_b = nearest_neighbor_cross_fraction(ta_all, tb_all, r_samp)
cross_b_to_a = nearest_neighbor_cross_fraction(tb_all, ta_all, r_samp)

# Overlap-window metrics (5th–95th percentile overlap)
ta_win, tb_win, (t0, t1) = restrict_to_overlap_window(ta_all, tb_all, p_lo=5, p_hi=95)

within_a_win = within_rpv_fraction(ta_win, r_samp)
within_b_win = within_rpv_fraction(tb_win, r_samp)
cross_a_to_b_win = nearest_neighbor_cross_fraction(ta_win, tb_win, r_samp)
cross_b_to_a_win = nearest_neighbor_cross_fraction(tb_win, ta_win, r_samp)

# Report
print("Clusters:", cl_a, "and", cl_b)
print(f"Refractory window: {r_ms} ms ({r_samp} samples @ {fs} Hz)")

print("\n=== Full session ===")
print("n spikes:", cl_a, ta_all.size, "|", cl_b, tb_all.size)
print("Within RPV(<2ms):", cl_a, within_a, "|", cl_b, within_b)
print("Cross NN(<2ms):", f"{cl_a}->{cl_b}", cross_a_to_b, "|", f"{cl_b}->{cl_a}", cross_b_to_a)

print("\n=== Overlap window (5th–95th percentile intersection) ===")
if ta_win.size == 0 or tb_win.size == 0:
    print("No usable overlap window (clusters largely non-overlapping in time).")
    print("Overlap bounds (samples):", (t0, t1))
else:
    print("Overlap bounds:")
    print("  samples:", (t0, t1))
    print("  seconds:", (t0 / fs, t1 / fs))
    print("n spikes in overlap:", cl_a, ta_win.size, "|", cl_b, tb_win.size)
    print("Within RPV(<2ms):", cl_a, within_a_win, "|", cl_b, within_b_win)
    print("Cross NN(<2ms):", f"{cl_a}->{cl_b}", cross_a_to_b_win, "|", f"{cl_b}->{cl_a}", cross_b_to_a_win)


Clusters: 1324 and 1325
Refractory window: 2.0 ms (60 samples @ 30000 Hz)

=== Full session ===
n spikes: 1324 7327 | 1325 41683
Within RPV(<2ms): 1324 0.010101010101010102 | 1325 0.010700062377045247
Cross NN(<2ms): 1324->1325 0.027296301351166917 | 1325->1324 0.0049420627114171245

=== Overlap window (5th–95th percentile intersection) ===
Overlap bounds:
  samples: (123994856, 269452032)
  seconds: (4133.161866666666, 8981.7344)
n spikes in overlap: 1324 4881 | 1325 35731
Within RPV(<2ms): 1324 0.0063524590163934426 | 1325 0.010495382031905962
Cross NN(<2ms): 1324->1325 0.03278016799836099 | 1325->1324 0.00464582575354734


### What to look for in the output above: 2️⃣ Cross nearest-neighbor fraction check (< 2 ms)


1. **Cross fraction comparable to within-unit RPV**

- If cross NN(<2 ms) is similar to the within-unit RPV values, the split is consistent with two independent spike trains.

- If cross NN(<2 ms) is much higher (e.g., >2–3× within-unit), the split likely has residual mixing or is not separating two independent neurons.

2. **Bidirectional consistency**

- Clean splits usually show similar cross fractions in both directions (A→B and B→A) once firing rates are taken into account. The A→B fraction grows with B's firing rate, so some asymmetry is expected by chance when one child fires much more often than the other.

- Strong asymmetry beyond what the rate difference explains (A→B ≫ B→A) suggests subset structure, temporal coupling, or an amplitude-threshold artifact rather than a clean two-neuron separation.

3. **Cross fraction should be consistent with refractory logic**

- If cross NN(<2 ms) is high while both within-unit RPVs are low, the two clusters are not refractory-independent.

- This is a red flag for either partial mixing or an amplitude-driven partition of one underlying unit.
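One way to judge whether a cross fraction is "high" is to compare it with what chance coincidences alone would produce. If the target unit fired as a homogeneous Poisson process at rate λ, independent of the source, the probability that at least one of its spikes lands within ±r of a given source spike is 1 - exp(-2·λ·r). Real spike trains are not Poisson (bursting and slow rate changes both matter), so treat this only as a rough reference; `poisson_cross_baseline` is a hypothetical helper.

```python
import numpy as np

def poisson_cross_baseline(n_tgt: int, duration_s: float, r_ms: float = 2.0) -> float:
    """Chance fraction of source spikes with a target spike within +/- r_ms.

    Assumes the target train is homogeneous Poisson with rate n_tgt / duration_s,
    independent of the source, so P(at least one target spike in a 2*r window)
    = 1 - exp(-2 * rate * r).
    """
    if duration_s <= 0:
        raise ValueError("duration_s must be positive")
    rate_hz = n_tgt / duration_s
    return float(1.0 - np.exp(-2.0 * rate_hz * r_ms / 1000.0))
```

With the overlap-window counts printed above (35,731 spikes of 1325 and 4,881 of 1324 over roughly 4,849 s), `poisson_cross_baseline(35731, 8981.7344 - 4133.1619)` comes out near 0.029 for 1324 → 1325, and `poisson_cross_baseline(4881, 8981.7344 - 4133.1619)` near 0.004 for 1325 → 1324. These baselines can be set beside the observed cross fractions; for small windows their ratio roughly follows the firing-rate ratio of the two children.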

## =========================================================================================
## =========================================================================================

## 3️⃣ Cross nearest-neighbor distance distribution (in samples)
#### Quantify near-simultaneity in samples for cross nearest neighbors, rather than only the < 2 ms fraction.
##### Specifically, look at the distribution of nearest-neighbor cross distances in samples.

In [24]:
import numpy as np

# Assumes ks_dir, cl_a and cl_b are already defined (see the cells above)
fs = 30000

spike_times = np.load(ks_dir / "spike_times.npy").squeeze().astype(np.int64)
spike_clusters = np.load(ks_dir / "spike_clusters.npy").squeeze().astype(np.int64)

ta = np.sort(spike_times[spike_clusters == cl_a])
tb = np.sort(spike_times[spike_clusters == cl_b])

def nearest_neighbor_distances_samples(t_src, t_tgt):
    """
    For each spike in t_src, compute the absolute time distance (in samples)
    to the nearest spike in t_tgt.
    """
    if t_src.size == 0 or t_tgt.size == 0:
        return np.array([], dtype=np.int64)

    idx = np.searchsorted(t_tgt, t_src)
    d = np.full(t_src.shape, np.inf, dtype=np.float64)

    m = idx > 0
    if np.any(m):
        d[m] = np.minimum(d[m], np.abs(t_src[m] - t_tgt[idx[m] - 1]))

    m = idx < t_tgt.size
    if np.any(m):
        d[m] = np.minimum(d[m], np.abs(t_src[m] - t_tgt[idx[m]]))

    return d.astype(np.int64)

# Compute nearest-neighbor distances
d_a_to_b = nearest_neighbor_distances_samples(ta, tb)
d_b_to_a = nearest_neighbor_distances_samples(tb, ta)

# Summaries
def summarize(d, name):
    print(f"\n=== {name} ===")
    print("n:", d.size)
    if d.size == 0:
        return
    print("min/median/mean/max (samples):",
          int(np.min(d)), float(np.median(d)), float(np.mean(d)), int(np.max(d)))
    print("min/median/mean/max (ms):",
          np.min(d)/fs*1000, np.median(d)/fs*1000, np.mean(d)/fs*1000, np.max(d)/fs*1000)

    for samp in [0, 1, 2, 3, 5, 10, 20, 30, 60, 100]:
        print(f"fraction d <= {samp:3d} samples ({samp/fs*1000:.3f} ms): {np.mean(d <= samp):.6f}")

summarize(d_a_to_b, f"{cl_a} -> {cl_b}")
summarize(d_b_to_a, f"{cl_b} -> {cl_a}")

# Histogram focused on 0..100 samples (0..3.33 ms)
max_samp = 100
bins = np.arange(0, max_samp + 2)  # integer bins

h_a, _ = np.histogram(d_a_to_b, bins=bins)
h_b, _ = np.histogram(d_b_to_a, bins=bins)

print("\n=== Histogram (nearest-neighbor distances, integer samples) ===")
print("Bin = exact distance in samples, from 0 to 100")
print("\nFirst 30 bins (0..29 samples):")
print("samples :", list(range(0, 30)))
print(f"{cl_a}->{cl_b}:", h_a[:30].tolist())
print(f"{cl_b}->{cl_a}:", h_b[:30].tolist())

print("\nTotal fraction in 0..100 samples:")
print(f"{cl_a}->{cl_b}:", np.mean(d_a_to_b <= 100))
print(f"{cl_b}->{cl_a}:", np.mean(d_b_to_a <= 100))



=== 1324 -> 1325 ===
n: 7327
min/median/mean/max (samples): 7 1505.0 13636.40343933397 1806752
min/median/mean/max (ms): 0.23333333333333334 50.166666666666664 454.54678131113235 60225.066666666666
fraction d <=   0 samples (0.000 ms): 0.000000
fraction d <=   1 samples (0.033 ms): 0.000000
fraction d <=   2 samples (0.067 ms): 0.000000
fraction d <=   3 samples (0.100 ms): 0.000000
fraction d <=   5 samples (0.167 ms): 0.000000
fraction d <=  10 samples (0.333 ms): 0.000546
fraction d <=  20 samples (0.667 ms): 0.002866
fraction d <=  30 samples (1.000 ms): 0.006551
fraction d <=  60 samples (2.000 ms): 0.027706
fraction d <= 100 samples (3.333 ms): 0.056230

=== 1325 -> 1324 ===
n: 41683
min/median/mean/max (samples): 7 9196.0 16022.724587961518 610907
min/median/mean/max (ms): 0.23333333333333334 306.5333333333333 534.0908195987173 20363.566666666666
fraction d <=   0 samples (0.000 ms): 0.000000
fraction d <=   1 samples (0.033 ms): 0.000000
fraction d <=   2 samples (0.067 ms): 0

### What to look for in the output above: 3️⃣ Cross nearest-neighbor distance distribution (in samples)

1. **No ultra-short distances (0–3 samples)**

- Any non-zero fraction at 0–3 samples (≤0.1 ms) strongly suggests duplicated events or a bad split.

- A defensible split should have ~0 in this range.

2. **Cumulative fraction growth should be gradual**

- For independent units, the cumulative fraction should increase slowly as you go from 10 → 20 → 30 → 60 samples.

- A sharp jump concentrated at very small sample counts indicates spike sharing / mixing.

3. **Directional asymmetry should match expectations**

- If A→B is much larger than B→A, this implies A spikes frequently occur near B spikes, but B spikes usually do not occur near A spikes.

- This supports the interpretation that the smaller cluster may be a subset-like population (contamination, burst-associated subset, or threshold artifact) rather than a fully independent neuron.
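The nearest-neighbor distance keeps only the closest spike and drops the sign of the lag. A signed cross-correlogram, similar in spirit to the correlograms phy displays, counts every B spike within a window around each A spike and keeps the sign, which makes it easier to see whether one child tends to fire just before or just after the other. A sketch: `cross_correlogram`, the 0.5 ms bins, and the ±20 ms window are illustrative choices, not part of the original analysis.

```python
import numpy as np

def cross_correlogram(t_a, t_b, fs: float = 30000, bin_ms: float = 0.5, window_ms: float = 20.0):
    """Histogram of signed lags (t_b - t_a) within +/- window_ms.

    Every B spike inside the window around each A spike is counted, so this is
    a cross-correlogram rather than a nearest-neighbour distance.
    Returns (bin_edges_ms, counts). Positive lags mean the B spike came later.
    """
    t_a = np.sort(np.asarray(t_a, dtype=np.int64))
    t_b = np.sort(np.asarray(t_b, dtype=np.int64))
    w = int(round(window_ms / 1000 * fs))  # half-window in samples
    lo = np.searchsorted(t_b, t_a - w, side="left")
    hi = np.searchsorted(t_b, t_a + w, side="right")
    chunks = [t_b[l:h] - t for t, l, h in zip(t_a, lo, hi) if h > l]
    lags = np.concatenate(chunks) if chunks else np.empty(0, dtype=np.int64)
    edges_ms = np.arange(-window_ms, window_ms + bin_ms / 2, bin_ms)
    counts, _ = np.histogram(lags * 1000.0 / fs, bins=edges_ms)
    return edges_ms, counts
```

With `ta`/`tb` from the cell above, `edges, counts = cross_correlogram(ta, tb, fs=fs)` puts 1324 spikes at lag zero, so positive lags correspond to 1325 spikes that followed a 1324 spike.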

## =========================================================================================
## =========================================================================================