In [None]:
import numpy as np
from pathlib import Path

DATA_DIR = Path("data")   # folder holding Patient<N>EEG.npy and Patient<NN>_OfflineMrk.mrk
DATA_FS  = 30_000.0       # data sample-rate (Hz)
MRK_FS   =    512.0       # marker sample-rate (Hz)

def _parse_mrk(path: Path) -> np.ndarray:
    """Return the 512 Hz marker indices stored in a .mrk file.

    The first line is treated as a header and skipped.  Only lines that split
    into exactly three whitespace-separated fields are used; the first field
    of each such line is parsed as the integer marker index.

    Parameters
    ----------
    path : Path
        Location of the marker file.

    Returns
    -------
    np.ndarray
        Integer array of marker indices (empty when the file holds none).

    Raises
    ------
    ValueError
        If a three-field line does not start with an integer.
    """
    with path.open() as f:
        # Default of None: a completely empty file has no header to skip and
        # must not leak StopIteration to the caller.
        next(f, None)
        rows = (ln.split() for ln in f)
        out = [int(fields[0]) for fields in rows if len(fields) == 3]
    return np.asarray(out, int)

def mrk_count_by_patient(patient_ids, fraction=1.0):
    """Print the number of marker events in the first *fraction* of each recording.

    For every patient id, the length of row 0 of ``Patient<pid>EEG.npy`` is
    scaled by *fraction* to get a cut-off sample.  The markers from
    ``Patient<pid:02d>_OfflineMrk.mrk`` are converted from the MRK_FS time
    base to DATA_FS samples and those below the cut-off are counted.

    Parameters
    ----------
    patient_ids : iterable of int
        Patient numbers to report on.
    fraction : float, default 1.0
        Leading share of the recording to consider (1.0 = everything).
    """
    for pid in patient_ids:
        # Only the length of row 0 is needed, so memory-map the file instead
        # of reading the whole recording into RAM.
        eeg = np.load(DATA_DIR / f"Patient{pid}EEG.npy", mmap_mode="r")
        trunc_len = int(eeg[0].size * fraction)
        del eeg  # release the mapping before touching the next file

        # load marker indices (512 Hz) and convert to 30 kHz space
        mrk_path = DATA_DIR / f"Patient{pid:02d}_OfflineMrk.mrk"
        mrk_30k = (_parse_mrk(mrk_path) * DATA_FS / MRK_FS).astype(int)

        n_events = int(np.count_nonzero(mrk_30k < trunc_len))
        print(f"Patient {pid}: {n_events} events in first {fraction:.0%} of data")

# example: whole recording (fraction=1.0) for patients 2,3,4,6,7;
# pass e.g. fraction=0.1 to count only the first 10 %
mrk_count_by_patient([2, 3, 4, 6, 7], fraction=1.0)


In [None]:
import numpy as np
from pathlib import Path
import traceback

# --- Configuration ---
DATA_DIR = Path("data")  # Folder holding Patient<N>EEG.npy and Patient<NN>_OfflineMrk.mrk
DATA_FS = 30_000.0  # Data sample-rate (Hz)
MRK_FS = 512.0      # Marker sample-rate (Hz)

# --- Helper Functions ---
def _parse_mrk(path: Path) -> np.ndarray:
    """Return the 512 Hz marker indices stored in a .mrk file.

    The first line is treated as a header and skipped.  For every remaining
    non-blank line the first whitespace-separated field is parsed as the
    marker index; lines whose first field is not an integer are skipped.

    Parameters
    ----------
    path : Path
        Location of the marker file.

    Returns
    -------
    np.ndarray
        Integer array of marker indices.  Empty when the file is missing,
        empty, or contains no parsable line.
    """
    if not path.exists():
        return np.array([], dtype=int)

    out = []
    with path.open() as f:
        # Default of None: an empty file has no header line, and a bare
        # next(f) would raise StopIteration here.
        next(f, None)
        for ln in f:
            fields = ln.split()
            if not fields:
                continue  # blank line
            try:
                out.append(int(fields[0]))
            except ValueError:
                # Not a marker row (first field is not an integer).
                continue
    return np.asarray(out, dtype=int)

def format_seconds(sec: float) -> str:
    """Format a duration in seconds as ``HH:MM:SS.mmm``.

    Negative durations are rendered as ``00:00:00.000``.

    The value is rounded to whole milliseconds *before* it is split into
    hours, minutes and seconds.  Rounding only at format time would turn
    59.9996 s into ``00:00:60.000`` because the carry never reaches the
    minutes field.
    """
    if sec < 0:
        return "00:00:00.000"
    total_ms = int(round(sec * 1000))
    hours, rest = divmod(total_ms, 3_600_000)
    minutes, rest = divmod(rest, 60_000)
    seconds, millis = divmod(rest, 1000)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d}.{millis:03d}"

# --- Main Verification Logic ---
def verify_marker_data_alignment(patient_ids: list):
    """Print a per-patient table comparing marker times with EEG duration.

    For each patient id the function reads the length of row 0 of
    ``Patient<pid>EEG.npy``, converts the markers of
    ``Patient<pid:02d>_OfflineMrk.mrk`` from MRK_FS to DATA_FS samples and
    prints one row with a status:

    * ERROR   - EEG or marker file missing, a marker lies at or beyond the
                end of the data, or an exception was raised for this patient.
    * WARNING - no markers were parsed, or the last marker falls before 80 %
                of the recording duration.
    * OK      - otherwise.

    Nothing is returned; all results go to stdout.
    """
    print(f"{'Patient':<10} | {'Status':<10} | {'Last Marker Time':<18} | {'Data Duration':<18} | {'Details'}")
    print(f"{'-'*10} | {'-'*10} | {'-'*18} | {'-'*18} | {'-'*40}")

    for pid in patient_ids:
        try:
            # 1. Get the EEG length in samples
            eeg_path = DATA_DIR / f"Patient{pid}EEG.npy"
            if not eeg_path.exists():
                print(f"{pid:<10} | {'❌ ERROR':<10} | {'N/A':<18} | {'N/A':<18} | EEG data file not found")
                continue

            # Only the row length is needed: memory-map the file rather than
            # reading the whole recording into RAM.
            n_samples = np.load(eeg_path, mmap_mode="r")[0].size
            data_duration_sec = n_samples / DATA_FS

            # 2. Load and process marker indices
            mrk_path = DATA_DIR / f"Patient{pid:02d}_OfflineMrk.mrk"
            if not mrk_path.exists():
                # Report a missing file explicitly so it is not confused with
                # a marker file that exists but holds no events.
                print(f"{pid:<10} | {'❌ ERROR':<10} | {'N/A':<18} | {format_seconds(data_duration_sec):<18} | Marker file not found")
                continue
            mrk_512 = _parse_mrk(mrk_path)

            if not mrk_512.size:
                print(f"{pid:<10} | {'⚠️ WARNING':<10} | {'N/A':<18} | {format_seconds(data_duration_sec):<18} | No markers found in file")
                continue

            # 3. Convert marker indices from 512 Hz to 30 kHz sample rate
            mrk_30k = (mrk_512 * DATA_FS / MRK_FS).astype(int)

            # 4. Perform the checks
            last_marker_sample = mrk_30k.max()
            last_marker_sec = last_marker_sample / DATA_FS

            # Markers at or past the last sample cannot belong to this data
            n_outside_bounds = np.sum(mrk_30k >= n_samples)

            status = "✅ OK"
            details = f"{len(mrk_30k)} markers found."

            if n_outside_bounds > 0:
                status = "❌ ERROR"
                details = f"{n_outside_bounds} markers found AFTER data ends."
            # Warn if the last marker falls in the first 80 % of the recording
            elif last_marker_sec < (data_duration_sec * 0.8):
                status = "⚠️ WARNING"
                details = "Last marker occurs unusually early."

            print(f"{pid:<10} | {status:<10} | {format_seconds(last_marker_sec):<18} | {format_seconds(data_duration_sec):<18} | {details}")

        except Exception as e:
            # Per-patient boundary: report the failure as a table row and
            # carry on with the remaining patients.
            print(f"{pid:<10} | {'❌ ERROR':<10} | {'-':<18} | {'-':<18} | Exception: {e}")
            # traceback.print_exc() # Uncomment for full error details

# --- Run Verification ---
if __name__ == "__main__":
    patient_list = [2, 3, 4, 6, 7]  # patient ids to verify
    verify_marker_data_alignment(patient_list)

In [None]:
#!/usr/bin/env python3
"""
DirectNeuralBiasing – multi-threaded grid-search with progress bars + event plots
==================================================================================

•  Outer progress bar: hyper-parameter *combo*.
•  Inner progress bars: per-patient chunks (processed in parallel), shows live TP / FP counts.
•  Fraction switch lets you smoke-test on 1 % of every EEG file.

Adjust **PARAM_GRID**, **PATIENT_IDS**, **FRACTION**, or **MAX_WORKERS** at the top.
"""

from __future__ import annotations
import itertools, os, csv, json, tempfile, threading
from copy import deepcopy
from pathlib import Path
from typing import Dict, List, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
from queue import Queue

import numpy as np, yaml, matplotlib.pyplot as plt
import direct_neural_biasing as dnb
from tqdm.auto import tqdm
import scipy.signal as sig 

# ─────────────────────—— user knobs ─────────────────────────────────────────
PATIENT_IDS      = [2,3,4,6,7]               # patients evaluated for every combo
DATA_DIR         = Path("data")              # folder with Patient<N>EEG.npy / Patient<NN>_OfflineMrk.mrk
FRACTION         = 1.0                 # leading share of each recording to use (1.0 = all)
MAX_WORKERS      = 4                   # number of parallel threads

# Every key below is mapped onto the processor config by make_cfg();
# run_grid() evaluates the Cartesian product of all value lists.
PARAM_GRID = {
    "z_score_threshold"      : [2.5, 3.0, 3.5, 4.0],        # slow-wave detector threshold
    "sinusoidness_threshold" : [0.4, 0.5, 0.6, 0.7],        # slow-wave detector threshold
    "check_sinusoidness"     : [True, False],          # sinusoidness check on / off

    # slow-wave band-pass edges (single value each, i.e. not searched)
    "f_low"   : [0.25],
    "f_high"  : [4.0],

    # wave-length limits in ms (single value each, i.e. not searched)
    "min_wave_ms" : [250.0],
    "max_wave_ms" : [1000.0],

    # IED detector threshold and trigger inhibition cooldown (ms)
    "z_ied_threshold" : [1.5, 2.0, 2.5],
    "refrac_ms"       : [2500.0],
}


DATA_FS      = 30_000.0;  MRK_FS = 512.0   # EEG / marker sample rates (Hz)
TOLERANCE_MS =       50           # max |detection - marker| distance counted as a match
CTX_MS       =       800          # ± ms context for plots

SHOW_PLOTS   = True               # attach plot data to each result and draw event plots
MAX_PLOTS    = 4                  # per patient
PRINT_EVENTS = True               # print individual TP / FP / FN events per patient
MAX_EVENT_PRINT = 6               # number of events kept per patient for printing

# NOTE: the results directory is created as a side effect of running this cell.
RESULTS_DIR = Path("results"); RESULTS_DIR.mkdir(exist_ok=True)
OUT_CSV     = RESULTS_DIR / "grid_metrics_mt.csv"

# Lock held while appending rows to OUT_CSV
csv_lock = threading.Lock()
plot_queue = Queue()  # Queue for deferred plotting (matplotlib isn't thread-safe)

# ─────────────────────—— data helpers ──────────────────────────────────────
def _parse_mrk(p: Path):
    """Return the marker indices stored in the .mrk file *p* as an int array.

    The first line is skipped as a header.  For every remaining non-blank
    line the first whitespace-separated field is parsed as an integer.

    Raises ``ValueError`` if such a field is not an integer.  An empty file
    yields an empty array.
    """
    with p.open() as f:
        # Default of None: a bare next(f) raises StopIteration on an empty file.
        next(f, None)
        return np.asarray(
            [int(fields[0]) for fields in map(str.split, f) if fields],
            int,
        )

def _mrk512_to_30k(idx):
    """Rescale marker indices from the MRK_FS time base to DATA_FS samples.

    The scaled values are truncated to integers by ``astype(int)``.
    """
    scaled = idx * DATA_FS / MRK_FS
    return scaled.astype(int)

def load_patient(pid:int, frac:float)->Tuple[np.ndarray,Dict[int,int]]:
    """Load one patient's signal and ground-truth markers.

    Returns ``(signal, gt)``: ``signal`` is the leading *frac* of row 0 of
    ``Patient<pid>EEG.npy``; ``gt`` maps each marker's DATA_FS sample index
    to its original MRK_FS index, keeping only markers that fall inside the
    truncated signal.
    """
    recording = np.load(DATA_DIR / f"Patient{pid}EEG.npy")[0]
    keep = int(len(recording) * frac)
    recording = recording[:keep]

    markers_512 = _parse_mrk(DATA_DIR / f"Patient{pid:02d}_OfflineMrk.mrk")
    markers_30k = _mrk512_to_30k(markers_512)

    # Drop markers that lie beyond the truncated slice.
    n_kept = len(recording)
    gt = {
        at_30k: at_512
        for at_30k, at_512 in zip(markers_30k, markers_512)
        if at_30k < n_kept
    }
    return recording, gt

# ─────────────────────—— plotting ──────────────────────────────────────────
def _plot(sig_raw: np.ndarray,
          center: int,
          gt_idx: np.ndarray,
          title: str,
          f_low: float = 0.25,         # pass in grid values if varied
          f_high: float = 4.0):
    """
    Show raw (grey) and band-pass-filtered (blue) signal ±CTX_MS around `center`.

    Parameters
    ----------
    sig_raw : 1-D signal sampled at DATA_FS.
    center  : sample index the window is centred on (drawn as a red line).
    gt_idx  : ground-truth marker sample indices; those inside the window are
              drawn as dotted green lines.
    title   : figure title.
    f_low, f_high : band-pass edges in Hz.

    The plot is skipped when `center` is None or outside the signal, when the
    window has < 10 samples, or when the window is all-zero.
    """
    # A negative centre (e.g. the -1 default used when a detection carries no
    # start index) is not a real sample position, so reject it as well.
    if center is None or not 0 <= center < len(sig_raw):
        return
    ctx = int(CTX_MS / 1000 * DATA_FS)
    L, R = max(0, center - ctx), min(len(sig_raw), center + ctx)
    window = sig_raw[L:R]
    if R - L < 10 or np.allclose(window, 0, atol=1e-12):
        return

    # ── design Butterworth band-pass (2nd order each side) ──────────────────
    # The band edges are tiny relative to Nyquist, where transfer-function
    # (b, a) coefficients lose precision; second-order sections stay stable.
    nyq = 0.5 * DATA_FS
    sos = sig.butter(
        N=2,
        Wn=[f_low / nyq, f_high / nyq],
        btype="bandpass",
        analog=False,
        output="sos",
    )
    filt = sig.sosfiltfilt(sos, window)  # zero-phase (forward-backward) filtering

    t = (np.arange(L, R) - center) / DATA_FS * 1000  # ms axis

    plt.figure(figsize=(10, 3))
    plt.plot(t, window, color="lightgray", lw=0.6, label="raw")
    plt.plot(t, filt, color="C0", lw=1.0, label="band-passed")
    plt.axvline(0, c="r", lw=1)

    # GT markers inside window
    in_win = (gt_idx >= L) & (gt_idx < R)
    plt.vlines((gt_idx[in_win] - center) / DATA_FS * 1000,
               *plt.ylim(), colors="g", linestyles=":", label="GT")

    plt.title(title)
    plt.xlabel("time (ms)")
    plt.legend(loc="upper right")
    plt.tight_layout()
    plt.show()

# ─────────────────────—— metric / live inner bar ───────────────────────────
def eval_patient(proc, sig, gt_map, desc:str, position:int=0):
    """
    Evaluate a single patient with a given processor configuration.

    The signal is fed to ``proc.run_chunk`` in 4096-sample chunks.  Every
    output whose slow-wave ``detected`` flag equals 1 is compared with the
    nearest ground-truth index: it is a TP when that index is within
    TOLERANCE_MS and has not been matched before, otherwise an FP.
    Ground-truth indices that were never matched are reported as FN.

    proc     : object exposing ``run_chunk(list) -> (outputs, _)``.
    sig      : 1-D signal array.
    gt_map   : mapping whose keys are the ground-truth sample indices.
    desc     : label shown on the progress bar.
    position : tqdm position for nested progress bars in multi-threaded context

    Returns ``(tp, fp, fn, events)`` where events is a list of
    ``(label, detection_index, gt_index)`` tuples.
    """
    chunk = 4096
    tol = int(TOLERANCE_MS/1000*DATA_FS)
    gt_idx = np.fromiter(gt_map.keys(), int)
    matched = np.zeros(gt_idx.size, bool)
    tp = fp = 0
    events = []

    # Context manager: the bar is closed even if run_chunk raises, so a
    # failed patient does not leave a stale nested bar behind.
    with tqdm(range(0, len(sig), chunk), leave=False, desc=desc,
              unit="chunk", position=position) as bar:
        for off in bar:
            out, _ = proc.run_chunk(sig[off:off+chunk].tolist())
            for o in out:
                if o.get("detectors:slow_wave_detector:detected") != 1:
                    continue
                det = int(o.get("detectors:slow_wave_detector:wave_start_index", -1))

                hit = False
                if gt_idx.size:
                    i = int(np.abs(gt_idx - det).argmin())  # nearest marker
                    hit = abs(gt_idx[i] - det) <= tol and not matched[i]

                if hit:
                    matched[i] = True
                    tp += 1
                    events.append(("TP", det, gt_idx[i]))
                else:
                    fp += 1
                    events.append(("FP", det, None))
            bar.set_postfix(tp=tp, fp=fp)

    fn_idx = gt_idx[~matched]
    events.extend([("FN", None, x) for x in fn_idx])
    return tp, fp, len(fn_idx), events

# ─────────────────────—— YAML helper ───────────────────────────────────────
# Default processor configuration; make_cfg() deep-copies it and overrides
# individual fields for each grid combination.
BASE_CFG = {
    "processor": {
        "verbose": False,
        "fs": DATA_FS,
        "channel": 1,
        "enable_debug_logging": False,
    },
    "filters": {
        "bandpass_filters": [
            {"id": "slow_wave_filter", "f_low": 0.25, "f_high": 4.0},
            {"id": "ied_filter", "f_low": 80.0, "f_high": 120.0},
        ],
    },
    "detectors": {
        "wave_peak_detectors": [
            {
                "id": "slow_wave_detector",
                "filter_id": "slow_wave_filter",
                "z_score_threshold": 2.5,
                "sinusoidness_threshold": 0.6,
                "check_sinusoidness": True,
                "wave_polarity": "downwave",
                "min_wave_length_ms": 250.0,
                "max_wave_length_ms": 1000.0,
            },
            {
                "id": "ied_detector",
                "filter_id": "ied_filter",
                "z_score_threshold": 2.5,
                "sinusoidness_threshold": 0.0,
                "check_sinusoidness": False,
                "wave_polarity": "upwave",
            },
        ],
    },
    "triggers": {
        "pulse_triggers": [
            {
                "id": "pulse_trigger",
                "activation_detector_id": "slow_wave_detector",
                "inhibition_detector_id": "ied_detector",
                "inhibition_cooldown_ms": 2500.0,
                "pulse_cooldown_ms": 0,
            },
        ],
    },
}

def make_cfg(over:dict):
    """Return a deep copy of BASE_CFG with the grid overrides in *over* applied.

    Recognised keys are written to the slow-wave filter, the slow-wave
    detector, the IED detector or the pulse trigger; any other key is ignored.
    """
    cfg = deepcopy(BASE_CFG)
    detectors = cfg["detectors"]["wave_peak_detectors"]
    slow_det = detectors[0]
    ied_det = detectors[1]
    slow_filter = cfg["filters"]["bandpass_filters"][0]
    trigger = cfg["triggers"]["pulse_triggers"][0]

    # grid key -> (config section, field name inside that section)
    targets = {
        "z_score_threshold": (slow_det, "z_score_threshold"),
        "sinusoidness_threshold": (slow_det, "sinusoidness_threshold"),
        "check_sinusoidness": (slow_det, "check_sinusoidness"),
        "f_low": (slow_filter, "f_low"),
        "f_high": (slow_filter, "f_high"),
        "min_wave_ms": (slow_det, "min_wave_length_ms"),
        "max_wave_ms": (slow_det, "max_wave_length_ms"),
        "z_ied_threshold": (ied_det, "z_score_threshold"),
        "refrac_ms": (trigger, "inhibition_cooldown_ms"),
    }

    for name, value in over.items():
        target = targets.get(name)
        if target is None:
            continue
        section, field = target
        section[field] = value
    return cfg

# ─────────────────────—— worker function ───────────────────────────────────
def process_patient(args):
    """Run one (combo, patient) evaluation and return its metrics as a dict.

    *args* is the tuple ``(cid, params, pid, cfg_path, position)``.  The
    returned dict holds the TP/FP/FN counts, precision and recall, the first
    MAX_EVENT_PRINT events (empty when PRINT_EVENTS is off) and, when
    SHOW_PLOTS is on, the data needed to draw the event plots later.
    """
    cid, params, pid, cfg_path, position = args

    # Build the processor first, then load the data it will consume.
    proc = dnb.PySignalProcessor.from_config_file(cfg_path)
    signal_data, gt = load_patient(pid, FRACTION)

    tp, fp, fn, events = eval_patient(
        proc, signal_data, gt, f"P{pid} combo{cid}", position
    )

    # Precision / recall, falling back to 0 when the denominator is empty.
    n_detected = tp + fp
    n_truth = tp + fn
    prec = tp / n_detected if n_detected else 0
    rec = tp / n_truth if n_truth else 0

    result = {
        'cid': cid,
        'params': params,
        'pid': pid,
        'tp': tp,
        'fp': fp,
        'fn': fn,
        'prec': prec,
        'rec': rec,
        'events': events[:MAX_EVENT_PRINT] if PRINT_EVENTS else [],
        'plot_data': None,
    }

    if SHOW_PLOTS:
        # Plotting itself is left to the caller; only the inputs are packed.
        result['plot_data'] = {
            'sig': signal_data,
            'gt': gt,
            'events': events[:MAX_PLOTS + 1],  # one slot beyond MAX_PLOTS
            'pid': pid,
            'cid': cid,
            'params': params,
        }

    return result

# ─────────────────────—— grid driver ───────────────────────────────────────
def run_grid():
    """Evaluate every PARAM_GRID combination on every patient in PATIENT_IDS.

    For each combination the function
      1. builds a config with make_cfg() and dumps it to a temporary YAML file,
      2. runs process_patient() for all patients on a thread pool,
      3. appends one row per patient to OUT_CSV and prints a summary,
      4. draws the queued event plots, and
      5. removes the temporary YAML file.

    A patient whose worker raises is reported on stdout and gets no CSV row.
    Nothing is returned.

    NOTE(review): the summary script elsewhere in this file reads
    results/grid_metrics.csv, whereas OUT_CSV here is grid_metrics_mt.csv —
    confirm which file name is intended.
    """
    # The header row is written only when the CSV does not exist yet.
    fresh = not OUT_CSV.exists()

    # Open CSV file and write header if needed
    with OUT_CSV.open("a", newline="") as fh:
        wr = csv.writer(fh)
        if fresh:
            wr.writerow(["combo","params","patient","tp","fp","fn",
                         "precision","recall"])

    keys = list(PARAM_GRID.keys())
    combos = list(itertools.product(*PARAM_GRID.values()))
    outer = tqdm(combos, desc="param-sets")

    # cid is the 1-based combo number used in the CSV and in plot titles.
    for cid, vals in enumerate(outer, 1):
        params = dict(zip(keys, vals))
        outer.set_postfix(params=params)
        cfg = make_cfg(params)

        # Create temporary config file (delete=False: workers reopen it by
        # path; it is removed in the finally block below)
        with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as tmp:
            yaml.dump(cfg, tmp)
            cfg_path = tmp.name

        try:
            # Prepare tasks for all patients; i+1 is the tqdm position of the
            # patient's inner bar (position 0 belongs to the outer bar)
            tasks = []
            for i, pid in enumerate(PATIENT_IDS):
                tasks.append((cid, params, pid, cfg_path, i+1))

            # Process patients in parallel
            with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
                # Submit all tasks
                future_to_task = {executor.submit(process_patient, task): task
                                  for task in tasks}

                # Collect results as they complete
                results = []
                for future in as_completed(future_to_task):
                    try:
                        result = future.result()
                        results.append(result)
                    except Exception as e:
                        # A failed patient is reported and skipped; the other
                        # patients of this combo are still recorded.
                        task = future_to_task[future]
                        print(f"Error processing patient {task[2]}: {e}")

                # Sort results by patient ID to maintain consistent order
                results.sort(key=lambda x: x['pid'])

                # Process results (this loop runs on the calling thread, after
                # all futures have completed)
                for result in results:
                    # Append one row per patient; the file is reopened for
                    # every row while csv_lock is held
                    with csv_lock:
                        with OUT_CSV.open("a", newline="") as fh:
                            wr = csv.writer(fh)
                            wr.writerow([
                                result['cid'],
                                json.dumps(result['params']),
                                result['pid'],
                                result['tp'],
                                result['fp'],
                                result['fn'],
                                result['prec'],
                                result['rec']
                            ])

                    # Print summary
                    print(f"\n▶ combo {cid}/{len(combos)} params={params}"
                          f"\n   patient {result['pid']}: TP={result['tp']} FP={result['fp']} FN={result['fn']} "
                          f"prec={result['prec']:.3f} rec={result['rec']:.3f}")

                    # Print events
                    if PRINT_EVENTS:
                        for lab, det, gt_i in result['events']:
                            if lab == "TP":
                                print(f"     TP det={det} gt={gt_i}")
                            elif lab == "FP":
                                print(f"     FP det={det}")
                            else:
                                print(f"     FN miss gt={gt_i}")

                    # Queue plots for later (matplotlib isn't thread-safe)
                    if result['plot_data']:
                        plot_queue.put(result['plot_data'])

                # Process all queued plots after parallel processing
                while not plot_queue.empty():
                    plot_data = plot_queue.get()
                    sig = plot_data['sig']
                    gt = plot_data['gt']
                    events = plot_data['events']
                    pid = plot_data['pid']
                    # NOTE: rebinds the loop's cid / params; the queued values
                    # were produced by this same combo.
                    cid = plot_data['cid']
                    params = plot_data['params']

                    # Detections (TP / FP) first, at most MAX_PLOTS of them
                    plotted = 0
                    for lab, det, gt_i in events:
                        if plotted >= MAX_PLOTS:
                            break
                        if lab in ("TP", "FP"):
                            _plot(sig, det, np.fromiter(gt.keys(), int),
                                  f"{lab}   P{pid}   cfg{cid}",
                                  params.get('f_low', 0.25),
                                  params.get('f_high', 4.0))
                            plotted += 1

                    # Plot one FN if we have room
                    if plotted < MAX_PLOTS:
                        for lab, det, gt_i in events:
                            if lab == "FN":
                                _plot(sig, gt_i, np.fromiter(gt.keys(), int),
                                      f"FN   P{pid}   cfg{cid}",
                                      params.get('f_low', 0.25),
                                      params.get('f_high', 4.0))
                                break

        finally:
            # Clean up temp file
            os.remove(cfg_path)

if __name__ == "__main__":
    # Script entry point: run the whole grid, then report where the rows went.
    print(f"Starting multi-threaded grid search with {MAX_WORKERS} workers...")
    run_grid()
    print(f"\n✅ finished – rows saved to {OUT_CSV}")

In [None]:
#!/usr/bin/env python3
"""
Summarise grid-search results and pick the most selective parameter set
======================================================================

*   Reads **results/grid_metrics.csv** (rows = patient × param-combo).
*   Aggregates TP/FP/FN across all patients.
*   Computes precision (= selectivity), recall, and F1.
*   Ranks by **precision first** (tie-break by recall, then F1).
*   Prints the top 20 and writes a full summary CSV.
*   Dumps the best parameters as JSON for easy reuse.

Usage
-----

    python summarise_grid.py
"""

import json
from pathlib import Path

import pandas as pd

# ─── paths ──────────────────────────────────────────────────────────────────
# Input metrics file and the two outputs derived from its name.
# NOTE(review): the grid-search script in this file writes
# results/grid_metrics_mt.csv — confirm this is the file meant to be read.
CSV_PATH = Path("results/grid_metrics.csv")
SUMMARY_CSV = CSV_PATH.with_suffix(".combo_summary.csv")
BEST_JSON = CSV_PATH.with_suffix(".best_params.json")

# ─── load ───────────────────────────────────────────────────────────────────
if not CSV_PATH.exists():
    raise FileNotFoundError(f"{CSV_PATH} not found – run the grid search first.")

df = pd.read_csv(CSV_PATH)

# Re-serialise each params JSON with sorted keys so that identical parameter
# sets always map to the same grouping string.
df["params_str"] = df["params"].apply(
    lambda raw: json.dumps(json.loads(raw), sort_keys=True)
)

# ─── aggregate ──────────────────────────────────────────────────────────────
# Sum the counts over all patients of a combo, then derive the pooled metrics.
agg = (
    df.groupby("params_str")[["tp", "fp", "fn"]]
      .sum()
      .assign(
          precision=lambda t: t["tp"] / (t["tp"] + t["fp"]),
          recall=lambda t: t["tp"] / (t["tp"] + t["fn"]),
          f1=lambda t: 2 * t["precision"] * t["recall"]
                       / (t["precision"] + t["recall"]),
      )
)

# ─── rank: precision first, then recall, then F1 (all descending) ───────────
ranked = agg.sort_values(by=["precision", "recall", "f1"], ascending=False)

# ─── display top 20 ─────────────────────────────────────────────────────────
pd.set_option("display.max_rows", 20)
print("\n🏆  TOP 20 PARAMETER SETS (by precision)\n")
top20 = ranked[["tp", "fp", "fn", "precision", "recall", "f1"]].head(20)
print(top20.to_string(float_format="%.3f"))

# ─── write full summary ─────────────────────────────────────────────────────
ranked.to_csv(SUMMARY_CSV)
print(f"\n📄  Full combo summary saved to {SUMMARY_CSV}")

# ─── extract & save the best parameter set ──────────────────────────────────
# The index of `ranked` is the canonical params string, so the first label is
# the top-ranked combination.
best_str = ranked.index[0]
best_params = json.loads(best_str)
with BEST_JSON.open("w") as f:
    json.dump(best_params, f, indent=2)
print(f"⭐  Best (most selective) params written to {BEST_JSON}\n")
print("Best params:\n", json.dumps(best_params, indent=2))
