In [None]:
import sys

sys.path.append("../")

import matplotlib.pyplot as plt
from src.file_loader import FileLoader
import pandas as pd
import numpy as np


# Line colour per raw EEG channel, used by the plotting loop below when the
# channel passed validation.
color_map = {
    "RAW_TP9": "tab:blue",
    "RAW_AF7": "tab:orange",
    "RAW_AF8": "tab:green",
    "RAW_TP10": "tab:red",
}
# Fallback line colour per raw EEG channel, used by the plotting loop below
# when the channel failed validation; each channel gets its own gray tone so
# the traces can still be told apart.
gray_shade = {
    "RAW_TP9": "gray",
    "RAW_AF7": "dimgray",
    "RAW_AF8": "darkgray",
    "RAW_TP10": "slategray",
}


# ---------------------------------------------------------------------
def _mad(x: pd.Series) -> float:
    """Return the median absolute deviation of *x*, rescaled towards σ.

    The raw MAD is multiplied by 1.4826 so that the result is comparable
    to a standard deviation. NaN values are not removed here.
    """
    center = np.median(x)
    abs_deviation = np.abs(x - center)
    raw_mad = np.median(abs_deviation)
    return 1.4826 * raw_mad


# ---------------------------------------------------------------------
def is_eeg_record_valid(
    eeg_df: pd.DataFrame,
    *,
    # ----------- thresholds you may want to tune ------------------
    flat_range_thr: float = 25,  #  ≲ 25  Muse-raw units ⇒ flat
    flat_mad_thr: float = 5,  #  ≲  5  Muse-raw units ⇒ flat
    wild_mad_thr: float = 120,  #  ≳ 120 units ⇒ wild
    wild_jump_thr: float = 150,  # 95-th pct(|Δ|) ≳ 150 ⇒ wild
) -> dict:
    """
    Classify each raw EEG channel of a recording as usable or not.

    For every channel the NaN samples are dropped, then two statistics are
    computed: the σ-scaled median absolute deviation (``_mad``) and the
    95th percentile of the absolute sample-to-sample difference. A channel
    is rejected when it is

    * empty after dropping NaNs, or its MAD is not finite,
    * "flat":  MAD < ``flat_mad_thr``,
    * "wild":  MAD > ``wild_mad_thr`` or 95-pct |Δ| > ``wild_jump_thr``.

    Parameters
    ----------
    eeg_df
        Frame that must contain the columns ``RAW_TP9``, ``RAW_AF7``,
        ``RAW_AF8`` and ``RAW_TP10`` (a missing column raises ``KeyError``);
        the values must be convertible to float.
    flat_range_thr
        Accepted for interface compatibility; no check in this function
        reads it at present.
    flat_mad_thr, wild_mad_thr, wild_jump_thr
        Thresholds for the checks listed above.

    Returns
    -------
    {channel_name: (is_valid: bool, explanation: str)}
    """
    chans = ["RAW_TP9", "RAW_AF7", "RAW_AF8", "RAW_TP10"]
    sig = eeg_df[chans].astype(float)

    results = {}
    for ch in chans:
        s = sig[ch].dropna()

        # Without samples every statistic is NaN and every threshold
        # comparison below would be False, so the channel would wrongly
        # fall through to "normal". Reject it explicitly instead.
        if s.empty:
            results[ch] = (False, "no valid samples")
            continue

        mad = _mad(s)
        if not np.isfinite(mad):
            # Same fall-through hazard as above (e.g. infinite samples).
            results[ch] = (False, "MAD is not finite")
            continue

        p95_jump = s.diff().abs().quantile(0.95)

        # ---------------- failure modes ----------------------------
        if mad < flat_mad_thr:
            results[ch] = (False, f"MAD {mad:.1f} < {flat_mad_thr}")
            continue

        # Report only the criteria that actually fired, so the explanation
        # never states a comparison that is false.
        wild_reasons = []
        if mad > wild_mad_thr:
            wild_reasons.append(f"MAD {mad:.0f} > {wild_mad_thr}")
        if p95_jump > wild_jump_thr:
            wild_reasons.append(f"95-pct|Δ| {p95_jump:.0f} > {wild_jump_thr}")

        if wild_reasons:
            results[ch] = (False, " and ".join(wild_reasons))
        else:
            results[ch] = (True, "normal")

    return results


def data_generator():
    """Lazily yield ``(loaded_data, data_unit)`` for every "Read" data unit.

    A ``FileLoader`` rooted at ``../ufal_emmt`` lists the data units of
    category "Read" (no participant or sentence filter); each unit is only
    loaded when the consumer asks for the next item.
    """
    file_loader = FileLoader(base_path="../ufal_emmt")
    read_units = file_loader.get_data_files(
        category_filter="Read",
        participant_id_filter=None,
        sentence_id_filter=None,
    )
    yield from ((file_loader.load_data(unit), unit) for unit in read_units)


%matplotlib inline
# Walk through the "Read" recordings, validate the four raw EEG channels of
# each one and plot them, graying out the channels that failed validation.
g = data_generator()
n = 0
# Each item is unpacked as ((df, gaze_df, audio_io), du); only the EEG frame
# `df` is used below, `gaze_df`, `audio_io` and `du` are ignored.
for (df, gaze_df, audio_io), du in g:
    # NOTE(review): the check happens before the increment, so up to 101
    # recordings are plotted, and one extra recording is loaded from the
    # generator just to hit the break — confirm whether 100 was intended.
    if n > 100:
        break
    n += 1

    results = is_eeg_record_valid(df)
    # The recording counts as valid only if every channel passed.
    overall_valid = all(ok for ok, _ in results.values())

    # Parse the time-of-day strings in place so the x axis is a datetime axis.
    df["TimeStamp"] = pd.to_datetime(df["TimeStamp"], format="%H:%M:%S.%f")
    plt.figure(figsize=(15, 6))

    # Plot the four raw EEG channels, one line per channel.
    for ch, label in zip(["RAW_TP9", "RAW_AF7", "RAW_AF8", "RAW_TP10"], ["TP9", "AF7", "AF8", "TP10"]):
        ok, reason = results[ch]
        # NOTE(review): the field width 7 is shorter than "RAW_TP10"
        # (8 characters), so the printed columns do not line up.
        print(f"  {ch:7s} -> {'OK' if ok else 'BAD'} | {reason}")
        # Valid channels keep their own colour; rejected ones are drawn gray.
        color = color_map[ch] if ok else gray_shade[ch]
        plt.plot(df["TimeStamp"], df[ch], label=label, color=color)

    plt.xlabel("Time")
    plt.ylabel("Amplitude")
    plt.title("EEG Signals Over Time - overall_valid: " + str(overall_valid))
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()
