In [None]:
import sys

sys.path.append("../")

import matplotlib.pyplot as plt
from src.file_loader import FileLoader
import pandas as pd
import numpy as np


# Line colour per raw EEG channel, used when the channel passes the validity check.
color_map = {
    "RAW_TP9": "tab:blue",
    "RAW_AF7": "tab:orange",
    "RAW_AF8": "tab:green",
    "RAW_TP10": "tab:red",
}
# Fallback gray tone per raw EEG channel, used when the channel fails the validity
# check; a distinct shade per channel keeps the traces distinguishable.
gray_shade = {
    "RAW_TP9": "gray",
    "RAW_AF7": "dimgray",
    "RAW_AF8": "darkgray",
    "RAW_TP10": "slategray",
}


# ---------------------------------------------------------------------
def _mad(x: pd.Series) -> float:
    """Median absolute deviation scaled to σ."""
    return 1.4826 * np.median(np.abs(x - np.median(x)))


# ---------------------------------------------------------------------
def is_eeg_record_valid(
    eeg_df: pd.DataFrame,
    *,
    # ----------- thresholds you may want to tune ------------------
    flat_mad_thr: float = 2.5,  #  ≲  2.5  Muse-raw units ⇒ flat
    wild_mad_thr: float = 120,  #  ≳ 120 units ⇒ wild
    wild_jump_thr: float = 150,  # 95-th pct(|Δ|) ≳ 150 ⇒ wild
) -> dict:
    """
    Classify each raw EEG channel of a recording as usable or not.

    For every channel the NaN samples are dropped, then two statistics are
    computed: the scaled median absolute deviation (``_mad``) and the 95th
    percentile of the absolute sample-to-sample difference.  A channel is
    rejected when it is

    * empty (no non-NaN samples, or the MAD is not a number),
    * flat  (MAD below ``flat_mad_thr``), or
    * wild  (MAD above ``wild_mad_thr`` or jump percentile above
      ``wild_jump_thr``).

    Parameters
    ----------
    eeg_df : pd.DataFrame
        Must contain the columns RAW_TP9, RAW_AF7, RAW_AF8 and RAW_TP10 with
        values convertible to float.
    flat_mad_thr, wild_mad_thr, wild_jump_thr : float
        Thresholds for the checks described above.

    Returns
    -------
    {channel_name: (is_valid: bool, explanation: str)}

    Raises
    ------
    KeyError
        If one of the four expected columns is missing from ``eeg_df``.
    """
    chans = ["RAW_TP9", "RAW_AF7", "RAW_AF8", "RAW_TP10"]
    sig = eeg_df[chans].astype(float)

    results = {}
    for ch in chans:
        s = sig[ch].dropna()

        # An empty channel would yield NaN statistics; every comparison with
        # NaN is False, so without this guard it would be reported "normal".
        if s.empty:
            results[ch] = (False, "no samples")
            continue

        mad = _mad(s)
        p95_jump = s.diff().abs().quantile(0.95)

        # e.g. infinities in the data make the MAD undefined
        if np.isnan(mad):
            results[ch] = (False, "MAD is undefined")
            continue

        # ---------------- failure modes ----------------------------
        if mad < flat_mad_thr:
            results[ch] = (False, f"MAD {mad:.1f} < {flat_mad_thr}")
            continue

        # Report only the criteria that actually tripped, so the explanation
        # never states an inequality that does not hold.
        wild_reasons = []
        if mad > wild_mad_thr:
            wild_reasons.append(f"MAD {mad:.0f} > {wild_mad_thr}")
        if p95_jump > wild_jump_thr:
            wild_reasons.append(f"95-pct|Δ| {p95_jump:.0f} > {wild_jump_thr}")

        if wild_reasons:
            results[ch] = (False, " and ".join(wild_reasons))
        else:
            results[ch] = (True, "normal")

    return results


def data_generator():
    """
    Yield the four recordings (Read, Translate, See, Update) of each trial.

    Iterates over all "Read" data units found under ``../ufal_emmt`` (sorted
    by participant id) and, for each one, looks up the "Translate", "See" and
    "Update" data units with the same participant and sentence id.  When
    several matches exist, the first one is used.

    Trials for which any of the three companion units is missing are skipped
    with a printed notice instead of aborting the whole iteration.

    Yields
    ------
    ((rd, td, sd, ud), read_data_unit)
        ``rd``/``td``/``sd``/``ud`` are whatever ``FileLoader.load_data``
        returns for the respective unit; ``read_data_unit`` is the "Read"
        unit's metadata record.
    """
    loader = FileLoader(base_path="../ufal_emmt")
    data_units = loader.get_data_files(
        category_filter="Read", participant_id_filter=None, sentence_id_filter=None, sort_by="participant_id"
    )

    # Debug aid: show what a data unit looks like (guarded so that an empty
    # dataset yields nothing instead of raising IndexError).
    if len(data_units) > 0:
        print(data_units[0])

    for read_data_unit in data_units:
        participant_id = read_data_unit["participant_id"]
        sentence_id = read_data_unit["sentence_id"]

        # Companion units recorded for the same participant and sentence.
        companions = {
            category: loader.get_data_files(
                category_filter=category, participant_id_filter=participant_id, sentence_id_filter=sentence_id
            )
            for category in ("Translate", "See", "Update")
        }

        missing = [category for category, units in companions.items() if len(units) == 0]
        if missing:
            print(
                f"Skipping participant_id={participant_id} sentence_id={sentence_id}: "
                f"no data unit for {', '.join(missing)}"
            )
            continue

        rd = loader.load_data(read_data_unit)
        td = loader.load_data(companions["Translate"][0])
        sd = loader.load_data(companions["See"][0])
        ud = loader.load_data(companions["Update"][0])

        yield (rd, td, sd, ud), read_data_unit


%matplotlib inline
# Plot the four EEG recordings (Read / Translate / See / Update) of each trial
# side by side, graying out channels that fail the validity check.
g = data_generator()

# Hoisted out of the loops: raw column names and their short legend labels.
raw_channels = ["RAW_TP9", "RAW_AF7", "RAW_AF8", "RAW_TP10"]
channel_labels = ["TP9", "AF7", "AF8", "TP10"]
phase_labels = ["Read", "Translate", "See", "Update"]

n = 0
for (rd, td, sd, ud), du in g:
    # Cap the number of figures produced in one run.
    if n > 100:
        break
    n += 1

    # Each loaded unit unpacks into (EEG frame, gaze frame, audio handle);
    # only the EEG frames are used here.
    r_eeg_df, r_gaze_df, r_audio_io = rd
    t_eeg_df, t_gaze_df, t_audio_io = td
    s_eeg_df, s_gaze_df, s_audio_io = sd
    u_eeg_df, u_gaze_df, u_audio_io = ud

    fig, axes = plt.subplots(1, 4, figsize=(20, 6), sharey=True)
    fig.suptitle(
        f"EEG | participant_id: {du['participant_id']} | sentence_id: {du['sentence_id']} | order: {du['order']}",
        fontsize=16,
    )

    for i, (eeg_df, label, ax) in enumerate(zip([r_eeg_df, t_eeg_df, s_eeg_df, u_eeg_df], phase_labels, axes)):
        # Compute validity
        results = is_eeg_record_valid(eeg_df)
        overall_valid = all(ok for ok, _ in results.values())

        # Convert TimeStamp (time-of-day strings) so matplotlib gets a datetime axis
        eeg_df["TimeStamp"] = pd.to_datetime(eeg_df["TimeStamp"], format="%H:%M:%S.%f")

        # Plot each channel: its own colour when valid, a gray shade otherwise
        for ch, ch_label in zip(raw_channels, channel_labels):
            ok, reason = results[ch]
            color = color_map[ch] if ok else gray_shade[ch]
            ax.plot(eeg_df["TimeStamp"], eeg_df[ch], label=ch_label, color=color)
            print(f"{label} - {ch:7s} -> {'OK' if ok else 'BAD'} | {reason}")

        ax.set_title(f"{label} EEG\nValid: {overall_valid}")
        ax.set_xlabel("Time")
        ax.grid(True)
        if i == 0:
            # The y axis is shared, so one label on the left-most panel suffices
            ax.set_ylabel("Amplitude")
        ax.legend()

    # Lay out once, after all artists exist, and keep the top 5 % of the
    # figure free for the suptitle.  (Previously this rect was applied before
    # plotting and then discarded by a second, rect-less tight_layout call.)
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()
