# TUH EEG Corpus - Exploratory Analysis

This notebook provides a comprehensive exploration of all 7 TUH EEG corpora:
- **TUEG** v2.0.1 - Full corpus (1,639 GB)
- **TUAB** v3.0.1 - Abnormal EEG detection (58 GB)
- **TUAR** v3.0.1 - Artifact detection (5.4 GB)
- **TUEP** v3.0.0 - Epilepsy diagnosis (35 GB)
- **TUEV** v2.0.1 - Event classification (19 GB)
- **TUSZ** v2.0.3 - Seizure detection (81 GB)
- **TUSL** v2.0.1 - Slowing vs seizure (1.5 GB)

**Goal:** Understand data formats, distributions, and characteristics before model training.

In [None]:
import glob
import os
import subprocess
from collections import Counter, defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import mne
import numpy as np
import pandas as pd
from tqdm import tqdm

from neurofisio import (
    apply_tcp_montage,
    inspect_edf,
    inventory_corpus,
    load_csv_annotations,
    load_rec_annotations,
    load_tse_annotations,
    parse_tuh_path,
    plot_eeg_segment,
    safe_read_raw_edf,
)

# Notebook-wide defaults: only show MNE warnings and errors, and use wide figures.
mne.set_log_level("WARNING")
plt.rcParams.update({"figure.figsize": (16, 6), "figure.dpi": 100})

# Root directory that holds one sub-folder per corpus (tusl/, tuar/, ...).
# Export TUH_EEG_ROOT before starting the kernel to point the notebook at a
# different location; the default keeps the original machine-specific path.
DATA_ROOT = Path(
    os.environ.get("TUH_EEG_ROOT", "/home/carlos/workspace/neurofisio/data/tuh_eeg")
).expanduser()


def is_download_running(pattern="rsync.*tuh_eeg"):
    """Report whether a process matching ``pattern`` is currently running.

    Runs ``pgrep -af <pattern>`` and treats any non-blank output as a match.

    Args:
        pattern: Regular expression handed to ``pgrep -f`` (matched against the
            full command line). The default looks for an rsync command line
            that mentions ``tuh_eeg``.

    Returns:
        True when pgrep printed at least one matching process. False when
        nothing matched, when pgrep exited with an error, or when pgrep could
        not be launched at all.
    """
    try:
        out = subprocess.check_output(
            ["pgrep", "-af", pattern], text=True, stderr=subprocess.DEVNULL
        )
    except subprocess.CalledProcessError:
        # Non-zero exit status from pgrep: treated as "no matching process".
        return False
    except OSError:
        # pgrep is missing or not executable on this host. The check is
        # best-effort, so report "not running" instead of crashing the cell.
        return False
    return bool(out.strip())


# Probe once so the result can be reported below.
DOWNLOAD_ACTIVE = is_download_running()

root_found = DATA_ROOT.exists()
print(f"Data root exists: {root_found}")
if DOWNLOAD_ACTIVE:
    print(
        "⚠ Active rsync download detected — notebook will use lightweight scanning (no heavy I/O)."
    )
if not root_found:
    print(
        "⚠ Data root not found — download corpora first. Notebook will skip data-dependent cells."
    )
else:
    # Corpus folders are the sub-directories whose name starts with "tu".
    available = sorted(
        entry.name
        for entry in DATA_ROOT.iterdir()
        if entry.is_dir() and entry.name.startswith("tu")
    )
    print(f"Available corpora: {available}")

In [None]:
# Corpus label -> directory under DATA_ROOT. Later cells look corpora up here
# by the label stored in each inventory record.
corpora_config = {
    "TUSL": DATA_ROOT / "tusl",
    "TUAR": DATA_ROOT / "tuar",
    "TUEV": DATA_ROOT / "tuev",
    "TUEP": DATA_ROOT / "tuep",
    "TUAB": DATA_ROOT / "tuab",
    "TUSZ": DATA_ROOT / "tusz",
    "TUEG": DATA_ROOT / "tueg",
}

# One inventory record per corpus that already contains at least one EDF file.
inventories = []
if DATA_ROOT.exists():
    for name, path in corpora_config.items():
        if not path.exists():
            print(f"Skipping {name} (not yet downloaded)")
            continue
        try:
            # any() stops at the first hit, so this is a cheap presence check.
            has_edfs = any(path.rglob("*.edf"))
        except Exception as exc:
            # Still best-effort (the corpus is skipped), but say why instead of
            # blaming an in-progress download for e.g. a permission problem.
            print(f"Skipping {name} (could not scan directory: {exc})")
            continue
        if not has_edfs:
            print(f"Skipping {name} (download in progress or empty)")
            continue
        print(f"Scanning {name}...")
        inv = inventory_corpus(path, name)
        inventories.append(inv)
        print(f"  {inv['edf_files']} EDF files, {inv['subjects']} subjects")
else:
    print("Data root not found — skipping inventory.")

if inventories:
    # Build the indexed frame in one expression rather than mutating in place.
    inv_df = pd.DataFrame(inventories).set_index("corpus")
    display(
        inv_df[
            [
                "edf_files",
                "csv_files",
                "csv_bi_files",
                "tse_files",
                "lab_files",
                "rec_files",
                "subjects",
            ]
        ]
    )
else:
    inv_df = pd.DataFrame()
    print("No corpora with EDF files found yet.")

## 3. EDF File Structure Deep Dive

Inspect EDF headers across corpora: channels, sampling rates, durations.

In [None]:
# Sample N files from each available corpus to characterize EDF properties
N_SAMPLE = 50
# Fixed seed so Restart & Run All inspects the same files (and produces the
# same summary table and histograms) on every run.
SAMPLE_SEED = 42
sample_rng = np.random.default_rng(SAMPLE_SEED)

edf_stats = []
for inv in inventories:
    corpus_name = inv["corpus"]
    corpus_path = corpora_config[corpus_name]
    # Sorted so the seeded draw always starts from the same candidate order.
    all_edfs = sorted(glob.glob(str(corpus_path / "**/*.edf"), recursive=True))
    if not all_edfs:
        continue
    sample = sample_rng.choice(all_edfs, size=min(N_SAMPLE, len(all_edfs)), replace=False)

    for edf_path in tqdm(sample, desc=corpus_name):
        info = inspect_edf(edf_path)
        if "error" in info:
            # Unreadable header: leave the file out of the statistics.
            continue
        sample_freqs = info["sample_freqs"]
        edf_stats.append(
            {
                "corpus": corpus_name,
                "n_channels": info["n_channels"],
                "duration_min": info["duration_sec"] / 60,
                # Only the first listed sampling frequency is recorded.
                "sfreq": sample_freqs[0] if sample_freqs else None,
                "path": edf_path,
            }
        )

edf_df = pd.DataFrame(edf_stats)
if not edf_df.empty:
    print(f"\nSampled {len(edf_df)} EDF files across {edf_df['corpus'].nunique()} corpora")
    display(
        edf_df.groupby("corpus")
        .agg(
            {
                "n_channels": ["mean", "min", "max"],
                "duration_min": ["mean", "min", "max"],
                "sfreq": ["mean", "min", "max"],
            }
        )
        .round(1)
    )
else:
    print("No EDF files could be read — corpora may still be downloading.")

In [None]:
# Visualize distributions
if edf_df.empty:
    print("No EDF data to plot — skipping distributions.")
else:
    # One histogram panel per sampled property: (column, x-axis label, title).
    hist_panels = [
        ("n_channels", "Number of Channels", "Channel Count Distribution"),
        ("duration_min", "Duration (minutes)", "Recording Duration Distribution"),
        ("sfreq", "Sampling Frequency (Hz)", "Sampling Rate Distribution"),
    ]
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # sort=False keeps corpora in first-appearance order.
    for corpus, subset in edf_df.groupby("corpus", sort=False):
        for ax, (column, _xlabel, _title) in zip(axes, hist_panels):
            ax.hist(subset[column], bins=20, alpha=0.6, label=corpus)

    for ax, (_column, xlabel, title) in zip(axes, hist_panels):
        ax.set_xlabel(xlabel)
        ax.set_title(title)
    # Only the first panel carries the corpus legend.
    axes[0].legend()

    plt.tight_layout()
    plt.savefig(str(DATA_ROOT.parent / "edf_distributions.png"), bbox_inches="tight")
    plt.show()

## 4. Signal Visualization

Load and plot actual EEG signals from each corpus.

In [None]:
# Plot one sample from each available corpus
PREVIEW_START_SEC = 10.0  # preferred offset into the recording
PREVIEW_DURATION_SEC = 10  # length of the plotted window

for inv in inventories:
    edf_path = inv["sample_edf"]
    if not edf_path:
        continue
    corpus = inv["corpus"]
    print(f"\n{'=' * 80}")
    print(f"Corpus: {corpus} | File: {Path(edf_path).name}")

    raw = safe_read_raw_edf(edf_path)
    if raw is None:
        print(f"  Could not read {Path(edf_path).name} (file may be incomplete)")
        continue
    print(
        f"  Channels: {raw.info['nchan']}, Sfreq: {raw.info['sfreq']} Hz, Duration: {raw.times[-1]:.0f}s"
    )

    # Prefer the TCP montage; fall back to the raw channels when it cannot be built.
    tcp_raw = apply_tcp_montage(raw)
    if tcp_raw is not None:
        target, tag = tcp_raw, "TCP montage"
    else:
        print("  Could not apply TCP montage, plotting raw channels")
        target, tag = raw, "raw channels"

    # Clamp the start so the window does not begin past the end of a short
    # recording (same rule the interactive explorer's plot callback uses).
    latest_start = max(0, target.times[-1] - PREVIEW_DURATION_SEC)
    start_sec = min(PREVIEW_START_SEC, latest_start)

    fig = plot_eeg_segment(
        target,
        start_sec=start_sec,
        duration_sec=PREVIEW_DURATION_SEC,
        title=f"{corpus}: {Path(edf_path).stem} ({tag}, {PREVIEW_DURATION_SEC}s window)",
    )
    plt.show()

if not inventories:
    print("No corpora available to plot.")

## 5. Annotation Analysis by Corpus

### 5.1 TUAR - Artifact Labels

In [None]:
# TUAR: tally artifact labels and event durations from the CSV annotations.
tuar_path = DATA_ROOT / "tuar"
TUAR_MAX_FILES = 100  # cap on annotation files parsed, to keep the cell quick

if tuar_path.exists():
    tuar_csvs = sorted(tuar_path.rglob("*.csv"))
    # Files with "_seiz" in the name are left out of the artifact tally.
    tuar_csvs = [f for f in tuar_csvs if "_seiz" not in f.name]

    if tuar_csvs:
        tuar_sample = tuar_csvs[:TUAR_MAX_FILES]
        required_cols = {"label", "start_time", "stop_time"}
        all_labels = []
        all_durations = []
        for csv_path in tqdm(tuar_sample, desc="TUAR annotations"):
            df, _ = load_csv_annotations(csv_path)
            if required_cols.issubset(df.columns):
                all_labels.extend(df["label"].tolist())
                # Compute durations without adding a column to the loaded frame.
                all_durations.extend((df["stop_time"] - df["start_time"]).tolist())

        label_counts = Counter(all_labels)
        # Report how many files were actually parsed: the event total below
        # only covers the sampled subset, not every annotation file on disk.
        print(
            f"TUAR: {len(tuar_csvs)} annotation files "
            f"({len(tuar_sample)} sampled), {len(all_labels)} total events"
        )

        if label_counts and all_durations:
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 5))
            labels, counts = zip(*label_counts.most_common(15))
            ax1.barh(range(len(labels)), counts, color="steelblue")
            ax1.set_yticks(range(len(labels)))
            ax1.set_yticklabels(labels)
            ax1.set_xlabel("Count")
            ax1.set_title("TUAR: Artifact Type Distribution")
            ax1.invert_yaxis()

            ax2.hist(all_durations, bins=50, color="coral", edgecolor="white")
            ax2.set_xlabel("Duration (seconds)")
            ax2.set_title("TUAR: Artifact Event Duration")
            # Zoom to the bulk of the distribution (99th percentile, at most 30 s).
            # Skip the zoom when the bound is unusable (zero or NaN), which
            # matplotlib would warn about or reject.
            x_upper = min(30, np.percentile(all_durations, 99))
            if np.isfinite(x_upper) and x_upper > 0:
                ax2.set_xlim(0, x_upper)
            plt.tight_layout()
            plt.show()
    else:
        print("TUAR directory exists but no CSV annotations yet (download in progress?)")
else:
    print("TUAR not available yet")

### 5.2 TUSZ - Seizure Annotations

In [None]:
# TUSZ: label counts per split from the .csv_bi files, then a tally of the
# non-background labels found in the .csv files.
tusz_path = DATA_ROOT / "tusz"
if tusz_path.exists():
    csv_bi_candidates = list(tusz_path.rglob("*.csv_bi"))
    if csv_bi_candidates:
        seizure_stats = {
            "train": defaultdict(int),
            "dev": defaultdict(int),
            "eval": defaultdict(int),
        }

        for split in ["train", "dev", "eval"]:
            split_path = tusz_path / "edf" / split
            if not split_path.exists():
                continue
            csv_bi_files = sorted(split_path.rglob("*.csv_bi"))

            # Only the first 200 files per split are parsed (hence "sampled" below).
            for csv_path in tqdm(csv_bi_files[:200], desc=f"TUSZ {split}"):
                df, _ = load_csv_annotations(csv_path)
                if "label" not in df.columns:
                    continue
                for label in df["label"].unique():
                    # Cast to a plain int: the pandas sum is a NumPy integer,
                    # which would otherwise print as np.int64(...) on NumPy 2.
                    seizure_stats[split][label] += int((df["label"] == label).sum())

        print("\nTUSZ Label Distribution (sampled):")
        for split, counts in seizure_stats.items():
            total = sum(counts.values())
            if total > 0:
                print(f"  {split}: {dict(counts)} (total: {total})")

        # Also check multi-class annotations
        edf_dir = tusz_path / "edf"
        if edf_dir.exists():
            # "*.csv" only matches names ending in .csv, so .csv_bi files are
            # already excluded and no extra filtering is needed.
            csv_files = sorted(edf_dir.rglob("*.csv"))
            csv_sample = csv_files[:300]
            seizure_types = Counter()
            for csv_path in tqdm(csv_sample, desc="TUSZ seizure types"):
                df, _ = load_csv_annotations(csv_path)
                if "label" in df.columns:
                    # Count every annotated event except background.
                    seizure_types.update(label for label in df["label"] if label != "bckg")

            if seizure_types:
                print(f"\nSeizure type distribution (from {len(csv_sample)} files):")
                for label, count in seizure_types.most_common():
                    print(f"  {label}: {count}")
    else:
        print("TUSZ directory exists but no .csv_bi annotations yet (download in progress?)")
else:
    print("TUSZ not available yet")

### 5.3 TUAB - Normal/Abnormal Distribution

In [None]:
# TUAB: count EDF files per split and class folder, then compare recording
# durations between the two classes.
tuab_path = DATA_ROOT / "tuab"
if not tuab_path.exists():
    print("TUAB not available yet")
else:
    edf_candidates = list(tuab_path.rglob("*.edf"))
    if not edf_candidates:
        print("TUAB directory exists but no EDF files yet (download in progress?)")
    else:
        # Every (split, class, folder) combination visited by both passes below.
        tuab_class_dirs = [
            (split, label, tuab_path / "edf" / split / label)
            for split in ("train", "eval")
            for label in ("normal", "abnormal")
        ]

        # Pass 1: number of EDF files per split and class.
        tuab_stats = defaultdict(lambda: defaultdict(int))
        for split, label, label_path in tuab_class_dirs:
            if label_path.exists():
                tuab_stats[split][label] = sum(1 for _ in label_path.rglob("*.edf"))

        if tuab_stats:
            print("TUAB Normal/Abnormal Distribution:")
            tuab_summary = pd.DataFrame(tuab_stats).T
            if {"abnormal", "normal"}.issubset(tuab_summary.columns):
                tuab_summary["total"] = tuab_summary.sum(axis=1)
                abnormal_share = tuab_summary["abnormal"] / tuab_summary["total"] * 100
                tuab_summary["abnormal_pct"] = abnormal_share.round(1)
            display(tuab_summary)

        # Pass 2: durations of the first 50 readable files in each folder.
        dur_data = []
        for split, label, label_path in tuab_class_dirs:
            if not label_path.exists():
                continue
            for edf in sorted(label_path.rglob("*.edf"))[:50]:
                info = inspect_edf(edf)
                if "error" in info:
                    continue
                dur_data.append(
                    dict(split=split, label=label, duration_min=info["duration_sec"] / 60)
                )

        if dur_data:
            dur_df = pd.DataFrame(dur_data)
            fig, ax = plt.subplots(figsize=(10, 5))
            for label in ("normal", "abnormal"):
                durations = dur_df.loc[dur_df["label"] == label, "duration_min"]
                if len(durations) > 0:
                    ax.hist(durations, bins=30, alpha=0.6, label=label)
            ax.set_xlabel("Duration (minutes)")
            ax.set_title("TUAB: Recording Duration by Class")
            ax.legend()
            plt.tight_layout()
            plt.show()

### 5.4 TUEV - Event Classification

In [None]:
# TUEV: tally event labels from the .rec annotation files.
tuev_path = DATA_ROOT / "tuev"
TUEV_MAX_FILES = 200  # cap on .rec files parsed, to keep the cell quick

if tuev_path.exists():
    rec_files = sorted(tuev_path.rglob("*.rec"))
    if rec_files:
        rec_sample = rec_files[:TUEV_MAX_FILES]
        event_counts = Counter()

        for rec_path in tqdm(rec_sample, desc="TUEV events"):
            df = load_rec_annotations(rec_path)
            if not df.empty:
                event_counts.update(df["label"].tolist())

        # State how many files the counts are based on, and actually print the
        # distribution that the header announces.
        print(f"TUEV: {len(rec_files)} .rec files ({len(rec_sample)} sampled)")
        print("Event distribution:")
        if not event_counts:
            print("  (no events parsed)")
        for label, count in event_counts.most_common():
            print(f"  {label}: {count}")

        if event_counts:
            fig, ax = plt.subplots(figsize=(10, 4))
            labels, counts = zip(*event_counts.most_common())
            colors = {
                "spsw": "#e74c3c",
                "gped": "#e67e22",
                "pled": "#f39c12",
                "eyem": "#3498db",
                "artf": "#95a5a6",
                "bckg": "#2ecc71",
            }
            # Labels without a dedicated colour fall back to grey.
            bar_colors = [colors.get(name, "#95a5a6") for name in labels]
            ax.bar(labels, counts, color=bar_colors)
            ax.set_ylabel("Count")
            ax.set_title("TUEV: Event Type Distribution")
            # Write each count just above its bar (offset: 1% of the tallest bar).
            text_offset = max(counts) * 0.01
            for i, c in enumerate(counts):
                ax.text(i, c + text_offset, str(c), ha="center", fontsize=9)
            plt.tight_layout()
            plt.show()
    else:
        print("TUEV directory exists but no .rec files yet (download in progress?)")
else:
    print("TUEV not available yet")

### 5.5 TUSL - Slowing Corpus

In [None]:
# TUSL: tally the labels found in every .tse annotation file.
tusl_path = DATA_ROOT / "tusl"
if not tusl_path.exists():
    print("TUSL not available")
else:
    tse_files = sorted(tusl_path.rglob("*.tse"))
    tse_agg_files = sorted(tusl_path.rglob("*.tse_agg"))

    if not tse_files:
        print("TUSL directory exists but no .tse files yet (download in progress?)")
    else:
        label_counts = Counter()
        for tse_path in tqdm(tse_files, desc="TUSL annotations"):
            annotations = load_tse_annotations(tse_path)
            if annotations.empty:
                continue
            label_counts.update(annotations["label"].tolist())

        print(f"TUSL: {len(tse_files)} .tse files, {len(tse_agg_files)} .tse_agg files")
        print(f"Labels: {dict(label_counts)}")

        if label_counts:
            fig, ax = plt.subplots(figsize=(8, 4))
            labels, counts = zip(*label_counts.most_common())
            palette = {"seiz": "#e74c3c", "slow": "#f39c12", "bckg": "#2ecc71"}
            # Labels outside the palette are drawn in grey.
            bar_colors = [palette.get(name, "#95a5a6") for name in labels]
            ax.bar(labels, counts, color=bar_colors)
            ax.set_title("TUSL: Label Distribution")
            ax.set_ylabel("Count")
            plt.tight_layout()
            plt.show()

### 5.6 TUEP - Epilepsy Corpus with Metadata

In [None]:
# TUEP: show the metadata spreadsheet (if present) and count files per class folder.
tuep_path = DATA_ROOT / "tuep"
if not tuep_path.exists():
    print("TUEP not available")
else:
    # Check for metadata spreadsheet
    metadata_path = tuep_path / "DOCS" / "metadata_v00r.xlsx"
    if metadata_path.exists():
        try:
            meta_df = pd.read_excel(metadata_path)
            print(f"TUEP Metadata: {len(meta_df)} rows, {len(meta_df.columns)} columns")
            print(f"Columns: {list(meta_df.columns)}")
            display(meta_df.head(10))
        except Exception as e:
            print(f"Could not read metadata: {e}")

    # Count files per class
    found_any = False
    for class_name in ("00_epilepsy", "01_no_epilepsy"):
        cls_path = tuep_path / class_name
        if not cls_path.exists():
            continue
        found_any = True
        n_edfs = sum(1 for _ in cls_path.rglob("*.edf"))
        # Each immediate sub-directory of a class folder is counted as one subject.
        if cls_path.is_dir():
            n_subjects = sum(1 for child in cls_path.iterdir() if child.is_dir())
        else:
            n_subjects = 0
        print(f"  {class_name}: {n_edfs} EDF files, {n_subjects} subjects")
    if not found_any:
        print("TUEP directory exists but class folders not found yet (download in progress?)")

## 6. Signal Quality & Spectral Analysis

In [None]:
# Pick a representative file — try TUSL first (smallest), fall back to any available corpus
sample_edfs = sorted(glob.glob(str(DATA_ROOT / "tusl" / "**/*.edf"), recursive=True))
if not sample_edfs and inventories:
    # Fall back to first corpus that has an EDF
    for inv in inventories:
        if inv["sample_edf"]:
            sample_edfs = [inv["sample_edf"]]
            break

if sample_edfs:
    sample_edf = sample_edfs[0]
    raw = safe_read_raw_edf(sample_edf)

    if raw is not None:
        tcp = apply_tcp_montage(raw)

        if tcp is not None:
            fig, axes = plt.subplots(1, 2, figsize=(16, 5))

            # PSD
            psd = tcp.compute_psd(fmin=0.5, fmax=50, method="welch", verbose=False)
            psd.plot(axes=axes[0], show=False, spatial_colors=True)
            axes[0].set_title("Power Spectral Density (TCP montage)")

            # Spectrogram of one channel.
            # MNE's get_data() returns SI units (volts for EEG channels), which
            # puts the dB power far below the fixed colour limits used here
            # (vmin=-30, vmax=10) and would clip the whole image to one colour.
            # Scale to microvolts so those limits bracket the signal.
            # NOTE(review): assumes the montage output has not been rescaled
            # already by apply_tcp_montage — confirm against that helper.
            data = tcp.get_data(picks=[0])[0] * 1e6
            sfreq = tcp.info["sfreq"]
            axes[1].specgram(
                data,
                NFFT=int(sfreq * 2),  # 2-second analysis window
                Fs=sfreq,
                noverlap=int(sfreq),  # 1-second overlap between windows
                cmap="viridis",
                vmin=-30,
                vmax=10,
            )
            axes[1].set_ylabel("Frequency (Hz)")
            axes[1].set_xlabel("Time (s)")
            axes[1].set_ylim(0, 50)
            axes[1].set_title(f"Spectrogram: {tcp.ch_names[0]}")

            plt.tight_layout()
            plt.savefig(str(DATA_ROOT.parent / "spectral_analysis.png"), bbox_inches="tight")
            plt.show()
        else:
            print("Could not apply TCP montage to sample file")
    else:
        print(f"Could not read {sample_edf} (file may be incomplete)")
else:
    print("No EDF files found — skipping spectral analysis.")

## 7. Cross-Corpus Comparison

Compare key properties across all available corpora.

In [None]:
# Summarize all findings
if inventories:
    # Per-file statistics exist only if the sampling cell read at least one EDF.
    has_edf_sample = not edf_df.empty and "corpus" in edf_df.columns

    summary_data = []
    for inv in inventories:
        corpus = inv["corpus"]
        if has_edf_sample:
            corpus_edfs = edf_df[edf_df["corpus"] == corpus]
        else:
            corpus_edfs = pd.DataFrame()

        # The statistics stay None when no sampled file belongs to this corpus.
        avg_duration = avg_channels = primary_sfreq = None
        if not corpus_edfs.empty:
            avg_duration = corpus_edfs["duration_min"].mean()
            avg_channels = corpus_edfs["n_channels"].mean()
            sfreq_modes = corpus_edfs["sfreq"].mode()
            if not sfreq_modes.empty:
                primary_sfreq = sfreq_modes.iloc[0]

        annotation_total = (
            inv["csv_files"] + inv["tse_files"] + inv["rec_files"] + inv["lab_files"]
        )
        montage_names = ", ".join(inv["montages"].keys()) if inv["montages"] else "N/A"

        summary_data.append(
            {
                "Corpus": corpus,
                "EDF Files": inv["edf_files"],
                "Subjects": inv["subjects"],
                "Annotation Files": annotation_total,
                "Avg Duration (min)": avg_duration,
                "Avg Channels": avg_channels,
                "Primary Sfreq (Hz)": primary_sfreq,
                "Montages": montage_names,
            }
        )

    summary_df = pd.DataFrame(summary_data).set_index("Corpus")
    display(summary_df)

banner = "=" * 80
print("\n" + banner)
print("EXPLORATION COMPLETE")
print(banner)
if not inventories:
    print("No corpora available yet. Re-run after downloading data.")
else:
    total_edfs = sum(entry["edf_files"] for entry in inventories)
    total_subjects = sum(entry["subjects"] for entry in inventories)
    print(f"Corpora explored: {len(inventories)}")
    print(f"Total EDF files across corpora: {total_edfs}")
    print(f"Total unique subjects: {total_subjects} (with overlap between corpora)")

## 8. Interactive EEG Explorer

Select a corpus, browse subjects, preview file info, and open recordings — all from dropdowns.
- **Plot in Notebook**: inline 10-second TCP montage preview
- **Open Qt Viewer**: full scrollable EEG browser in a separate window (requires display/X11)

In [None]:
import ipywidgets as widgets
from IPython.display import HTML, clear_output, display

# ── Corpus metadata ──────────────────────────────────────────────────────────
# Static description of each corpus, kept as a table: one row per corpus, with
# the values in the order given by _CORPUS_FIELDS.
_CORPUS_FIELDS = ("desc", "size", "version", "labels", "ann_ext")
_CORPUS_ROWS = [
    ("TUSL", "Slowing vs Seizure differentiation", "1.5 GB", "v2.0.1", "seiz, slow, bckg", ".tse"),
    (
        "TUAR",
        "Per-channel artifact detection",
        "5.4 GB",
        "v3.0.1",
        "eyem, chew, shiv, musc, elec, ...",
        ".csv",
    ),
    (
        "TUEV",
        "6-class EEG event classification",
        "19 GB",
        "v2.0.1",
        "spsw, gped, pled, eyem, artf, bckg",
        ".rec",
    ),
    ("TUEP", "Epilepsy / no-epilepsy diagnosis", "35 GB", "v3.0.0", "epilepsy, no_epilepsy", ".csv"),
    ("TUAB", "Binary normal / abnormal EEG", "58 GB", "v3.0.1", "normal, abnormal (by folder)", None),
    (
        "TUSZ",
        "Seizure detection benchmark",
        "81 GB",
        "v2.0.3",
        "seiz, bckg (+ seizure types)",
        ".csv_bi",
    ),
    ("TUEG", "Full TUH EEG Corpus (unlabeled)", "1,639 GB", "v2.0.1", "N/A", None),
]

# corpus label -> {"desc", "size", "version", "labels", "ann_ext"}
CORPUS_META = {name: dict(zip(_CORPUS_FIELDS, values)) for name, *values in _CORPUS_ROWS}

# ── Lazy per-subject index with cache ────────────────────────────────────────
# corpus label -> {subject_id: [parsed path records]}, filled on first scan.
_subject_cache = {}


def _scan_subjects(corpus_name):
    """Group a corpus's EDF files by subject, scanning the disk at most once.

    Returns a dict mapping each subject id to the list of records that
    ``parse_tuh_path`` produced for that subject's files. Files for which
    ``parse_tuh_path`` returns a falsy value are skipped. An unknown corpus
    or a missing directory yields an empty dict, which is not cached.
    """
    try:
        return _subject_cache[corpus_name]
    except KeyError:
        pass

    corpus_dir = corpora_config.get(corpus_name)
    if not (corpus_dir and corpus_dir.exists()):
        return {}

    by_subject = {}
    for edf_file in sorted(corpus_dir.rglob("*.edf")):
        parsed = parse_tuh_path(edf_file)
        if not parsed:
            continue
        by_subject.setdefault(parsed["subject_id"], []).append(parsed)

    _subject_cache[corpus_name] = by_subject
    return by_subject


# ── Available corpora (from inventory, already computed) ─────────────────────
_available = [entry["corpus"] for entry in inventories] if inventories else []

# ── Widgets ──────────────────────────────────────────────────────────────────
_style = {"description_width": "70px"}


def _make_dropdown(label, width, options=("",)):
    """Build a dropdown that starts on the blank placeholder entry."""
    return widgets.Dropdown(
        options=list(options),
        value="",
        description=label,
        style=_style,
        layout=widgets.Layout(width=width),
    )


def _make_button(label, icon, button_style):
    """Build one of the fixed-size action buttons."""
    return widgets.Button(
        description=label,
        icon=icon,
        button_style=button_style,
        layout=widgets.Layout(width="180px", height="36px"),
    )


def _make_bordered_output():
    """Build a framed output area used for the two information panels."""
    return widgets.Output(
        layout=widgets.Layout(
            border="1px solid #ccc",
            padding="10px",
            min_height="120px",
            width="100%",
        )
    )


corpus_dd = _make_dropdown("Corpus:", "260px", ["", *_available])
subject_dd = _make_dropdown("Subject:", "260px")
file_dd = _make_dropdown("File:", "420px")

plot_btn = _make_button(" Plot in Notebook", "line-chart", "info")
viewer_btn = _make_button(" Open Qt Viewer", "desktop", "success")

overview_out = _make_bordered_output()
file_info_out = _make_bordered_output()
plot_out = widgets.Output(layout=widgets.Layout(width="100%"))
status_label = widgets.HTML(value="<i>Select a corpus to begin.</i>")

# ── Internal state ───────────────────────────────────────────────────────────
# Shared by the callbacks: the current subject index and file list, plus the
# loaded recording, its TCP montage and the selected path.
_state = dict(subjects={}, files=[], raw=None, tcp=None, path=None)


# ── Callbacks ────────────────────────────────────────────────────────────────
def _on_corpus(change):
    """React to a new selection in the corpus dropdown.

    Clears the loaded recording, the file dropdown and the file/plot panels.
    For a blank selection it also empties the subject dropdown and the
    overview; otherwise it scans the corpus, refills the subject dropdown and
    renders the overview panel.
    """
    corpus = change["new"]
    # Reset downstream
    _state.update(raw=None, tcp=None, path=None)
    file_dd.options = [""]
    file_dd.value = ""
    for panel in (file_info_out, plot_out):
        with panel:
            clear_output()

    if not corpus:
        subject_dd.options = [""]
        with overview_out:
            clear_output()
        status_label.value = "<i>Select a corpus to begin.</i>"
        return

    status_label.value = f"<i>Scanning {corpus} subjects...</i>"
    subjects = _scan_subjects(corpus)
    _state["subjects"] = subjects

    # Populate subject dropdown
    subject_ids = sorted(subjects)
    subject_dd.options = ["", *subject_ids]
    subject_dd.value = ""

    # Overview panel: static metadata plus the matching inventory record, if any.
    corpus_meta = CORPUS_META.get(corpus, {})
    inventory = None
    for candidate in inventories:
        if candidate["corpus"] == corpus:
            inventory = candidate
            break

    with overview_out:
        clear_output(wait=True)
        edf_total = inventory["edf_files"] if inventory else "?"
        html_lines = [
            f"<b>{corpus}</b> {corpus_meta.get('version', '')} &mdash; {corpus_meta.get('size', '?')}",
            f"<i>{corpus_meta.get('desc', '')}</i>",
            f"<b>EDF files:</b> {edf_total} &nbsp;|&nbsp; "
            f"<b>Subjects:</b> {len(subject_ids)}",
            f"<b>Labels:</b> {corpus_meta.get('labels', 'N/A')}",
        ]
        if inventory and inventory["montages"]:
            montage_parts = [f"{k} ({v})" for k, v in inventory["montages"].items()]
            html_lines.append(f"<b>Montages:</b> {', '.join(montage_parts)}")
        display(HTML("<br>".join(html_lines)))

    status_label.value = f"<b>{corpus}</b>: {len(subject_ids)} subjects ready."


def _on_subject(change):
    """React to a new selection in the subject dropdown.

    Clears the loaded recording and the file/plot panels, then lists the
    chosen subject's files in the file dropdown. A blank selection empties
    the file dropdown instead.
    """
    sid = change["new"]
    _state.update(raw=None, tcp=None, path=None)
    for panel in (file_info_out, plot_out):
        with panel:
            clear_output()

    if not sid:
        file_dd.options = [""]
        return

    subject_files = _state["subjects"].get(sid, [])
    _state["files"] = subject_files

    # One entry per file: "<file stem>  (<montage or ?>)".
    entries = []
    for record in subject_files:
        stem = Path(record["path"]).stem
        montage = record["montage"] or "?"
        entries.append(f"{stem}  ({montage})")

    file_dd.options = ["", *entries]
    file_dd.value = ""
    status_label.value = f"Subject <b>{sid}</b>: {len(subject_files)} file(s)."


def _on_file(change):
    """React to a new selection in the file dropdown.

    Clears the loaded recording and the plot, records the selected path in
    ``_state`` and fills the file-details panel from ``inspect_edf``. When the
    header cannot be read, the panel shows the error and the status line says
    so, instead of keeping the previous status message.
    """
    sel = change["new"]
    _state["raw"] = _state["tcp"] = _state["path"] = None
    with plot_out:
        clear_output()

    if not sel:
        with file_info_out:
            clear_output()
        return

    idx = (file_dd.options).index(sel) - 1  # offset for leading ''
    meta = _state["files"][idx]
    edf_path = meta["path"]
    _state["path"] = edf_path
    file_name = Path(edf_path).name

    info = inspect_edf(edf_path)
    if "error" in info:
        with file_info_out:
            clear_output(wait=True)
            display(
                HTML(
                    f"<span style='color:red'>Cannot read file (may be incomplete):</span><br>{info['error']}"
                )
            )
        # Keep the status line in step with the panel; previously it kept
        # showing the message from the subject selection.
        status_label.value = f'<span style="color:red">Cannot read {file_name}.</span>'
        return

    with file_info_out:
        clear_output(wait=True)
        dur_min = info["duration_sec"] / 60
        sfreqs = info["sample_freqs"]
        # Only the first listed sampling frequency is shown.
        sfreq_str = f"{sfreqs[0]:.0f} Hz" if sfreqs else "?"
        lines = [
            f"<b>{file_name}</b>",
            f"<b>Channels:</b> {info['n_channels']} &nbsp;|&nbsp; "
            f"<b>Sfreq:</b> {sfreq_str} &nbsp;|&nbsp; "
            f"<b>Duration:</b> {dur_min:.1f} min",
            f"<b>Montage:</b> {meta['montage'] or 'unknown'}",
            f"<b>Subject:</b> {meta['subject_id']} &nbsp;|&nbsp; "
            f"<b>Session:</b> s{meta['session_num']:03d} &nbsp;|&nbsp; "
            f"<b>Token:</b> t{meta['token_num']:03d}",
        ]
        # Check for annotation files alongside the EDF
        edf_p = Path(edf_path)
        ann_exts = [".csv", ".csv_bi", ".tse", ".tse_agg", ".lbl", ".lbl_agg", ".rec", ".lab"]
        found_ann = [ext for ext in ann_exts if edf_p.with_suffix(ext).exists()]
        if found_ann:
            lines.append(f"<b>Annotations:</b> {', '.join(found_ann)}")
        else:
            lines.append("<b>Annotations:</b> none found")
        display(HTML("<br>".join(lines)))

    status_label.value = f"Ready: {file_name}"


def _on_plot(btn):
    """Plot a 10-second preview of the selected file inside the notebook.

    Loads the recording, stores it and its TCP montage in ``_state``, and
    plots the montage when one is available, otherwise the raw channels.
    """
    edf_path = _state["path"]
    if not edf_path:
        with plot_out:
            clear_output(wait=True)
            print("Select a file first.")
        return

    with plot_out:
        clear_output(wait=True)
        status_label.value = "<i>Loading EDF...</i>"
        raw = safe_read_raw_edf(edf_path)
        if raw is None:
            print("Could not read file (may be incomplete).")
            status_label.value = "Load failed."
            return
        _state["raw"] = raw

        tcp = apply_tcp_montage(raw)
        _state["tcp"] = tcp
        if tcp is None:
            target, tag = raw, "raw channels"
        else:
            target, tag = tcp, "TCP montage"

        # Pick a start 10s in if recording is long enough
        start = min(10.0, max(0, target.times[-1] - 10))

        plot_eeg_segment(
            target, start_sec=start, duration_sec=10, title=f"{Path(edf_path).stem} ({tag})"
        )
        plt.show()
        status_label.value = f"Plotted {Path(edf_path).name} ({tag})."


def _on_viewer(btn):
    """Open the selected recording in MNE's Qt browser.

    Reuses the recording and TCP montage already stored in ``_state`` when
    present, loading or deriving (and then storing) them otherwise. Falls
    back to the raw channels when no montage is available. Failures are
    reported on the status line rather than raised.
    """
    edf_path = _state["path"]
    if not edf_path:
        status_label.value = '<span style="color:orange">Select a file first.</span>'
        return

    status_label.value = "<i>Loading into Qt viewer...</i>"
    # Test for None explicitly: the truthiness of a loaded recording is not a
    # reliable "is it loaded?" signal.
    raw = _state.get("raw")
    if raw is None:
        raw = safe_read_raw_edf(edf_path)
    if raw is None:
        status_label.value = '<span style="color:red">Could not read file.</span>'
        return
    _state["raw"] = raw

    tcp = _state.get("tcp")
    if tcp is None:
        tcp = apply_tcp_montage(raw)
        # Remember the montage so a later click does not derive it again.
        _state["tcp"] = tcp
    target = tcp if tcp is not None else raw

    try:
        # NOTE(review): this switches MNE's browser backend for the rest of
        # the session, not just for this window — confirm that is intended.
        mne.viz.set_browser_backend("qt")
        target.plot(block=False, title=Path(edf_path).stem)
        status_label.value = f"Qt viewer opened for {Path(edf_path).name}."
    except Exception as e:
        status_label.value = f'<span style="color:red">Qt viewer failed: {e}</span>'


# ── Wire up ──────────────────────────────────────────────────────────────────
# Each dropdown notifies its callback whenever its selected value changes.
for _dropdown, _handler in (
    (corpus_dd, _on_corpus),
    (subject_dd, _on_subject),
    (file_dd, _on_file),
):
    _dropdown.observe(_handler, names="value")
plot_btn.on_click(_on_plot)
viewer_btn.on_click(_on_viewer)


# ── Layout ───────────────────────────────────────────────────────────────────
def _titled_panel(title, output):
    """Stack a bold heading above an output area in a half-width column."""
    return widgets.VBox(
        [widgets.HTML(f"<b>{title}</b>"), output],
        layout=widgets.Layout(width="50%"),
    )


selectors = widgets.HBox([corpus_dd, subject_dd, file_dd], layout=widgets.Layout(gap="8px"))
buttons = widgets.HBox([plot_btn, viewer_btn], layout=widgets.Layout(gap="8px"))
panels = widgets.HBox(
    [
        _titled_panel("Corpus Overview", overview_out),
        _titled_panel("File Details", file_info_out),
    ],
    layout=widgets.Layout(gap="12px"),
)
action_row = widgets.HBox(
    [buttons, status_label], layout=widgets.Layout(gap="16px", align_items="center")
)

# Top to bottom: selectors, the two info panels, buttons + status, then the plot.
ui = widgets.VBox(
    [selectors, panels, action_row, plot_out],
    layout=widgets.Layout(gap="10px", padding="8px"),
)

display(ui)