# In [None]:
# -*- coding: utf-8 -*-


from __future__ import annotations

import hashlib
import json
import os
from pathlib import Path

import numpy as np
import pandas as pd

# ====== SETTINGS ======
# Data root holding the subject directories. Set the BIOROB_DATA_ROOT
# environment variable to point elsewhere without editing this file.
ROOT_DIR = Path(os.environ.get("BIOROB_DATA_ROOT") or r"/home/tsultan1/BioRob(Final)/Data")
SUBJECT_GLOB = "Sub-*"                        # subject directories under ROOT_DIR
TARGET_SUBDIR = "cleaned"                     # per-subject folder whose *.csv files are pruned
REPORT_NAME = "trainschema_prune_report.csv"  # per-file report, written to ROOT_DIR
SCHEMA_LOCK = "train_schema.json"             # schema snapshot, written to ROOT_DIR

# What to do with a file whose canonical/required columns are missing:
#   "add_missing"  - add the missing canonical columns (FILL_DEFAULTS or NaN) and keep the file
#   "keep_bad"     - keep the file; missing canonical columns come out all-NaN
#   "missing_only" - DELETE the file if any REQUIRED_MIN column is missing
# WARNING: every processed CSV is rewritten (or deleted) in place.
DELETE_POLICY = "add_missing"

# Real vs motor-imagery (MI): EMG columns are treated as required only while
# this flag is True. It is a single global switch for the whole run, not a
# per-trial decision, so set it to match the trials being processed.
REQUIRE_EMG = True

# --- Metadata: retained in the saved files, but NOT fed to the model directly ---
META_ONLY = ["subject_id", "task", "trial", "Timestamp_seconds"]

# --- Core training inputs ---
EEG_COLS = list(map("Ch{}".format, range(1, 9)))          # Ch1..Ch8 (µV)
EMG_RAW = list(map("Ch{} EMG raw".format, range(1, 5)))   # Ch1..Ch4 raw EMG

# Eye-tracker core signals: gaze (left/right eye, x/y), pupils, validity, state flags.
ET_CORE = (
    [f"ET_Gaze{eye}{axis}" for eye in ("Left", "Right") for axis in ("x", "y")]
    + ["ET_PupilLeft", "ET_PupilRight"]
    + ["ET_ValidityLeftEye", "ET_ValidityRightEye"]
    + ["ET_Blink", "ET_Fixation", "ET_Worn"]
)

# Head motion from the eye-tracker's IMU: gyro, accelerometer, head rotation.
IMU_HEAD = (
    [f"ET_Gyro{axis}" for axis in "XYZ"]
    + [f"ET_Acc{axis}" for axis in "XYZ"]
    + [f"ET_HeadRotation{angle}" for angle in ("Pitch", "Yaw", "Roll")]
)

# Eye-to-tracker distances: always part of the schema, filled with -1 when absent.
ET_DIST = [f"ET_Distance{side}" for side in ("Left", "Right")]

# Columns removed from every file regardless of policy: logging/timekeeping
# fields, stray event columns (including pandas' ".1"/".2"/".3" duplicates of
# EventSource), and the eye-tracker's 3D eye-geometry outputs.
_SIDES = ("Left", "Right")
ALWAYS_DROP = (
    {
        "Row", "Timestamp", "SampleNumber", "ET_TimeSignal", "LSL Timestamp",
        "EventSource", "SlideEvent", "StimType", "Duration",
        "CollectionPhase", "SourceStimuliName",
    }
    | {f"EventSource.{k}" for k in (1, 2, 3)}
    | {f"ET_Camera{side}{axis}" for side in _SIDES for axis in "XY"}
    | {f"ET_Gaze3dEyeball{axis}{side}" for side in _SIDES for axis in "XYZ"}
    | {f"ET_Gaze3dOpticalAxis{axis}{side}" for side in _SIDES for axis in "XYZ"}
    | {f"ET_Gaze3dEyelidAngle{edge}{side}" for side in _SIDES for edge in ("Top", "Bottom")}
    | {f"ET_Gaze3dEyelidAperture{side}" for side in _SIDES}
)
del _SIDES

# Canonical training schema. Every saved file is reindexed to exactly this
# column list, so the order below is the on-disk column order.
CANONICAL_ORDER = [
    *META_ONLY,
    *EEG_COLS,
    *EMG_RAW,
    *ET_CORE,
    *ET_DIST,
    *IMU_HEAD,
]

# Canonical columns that must be present (or addable). EMG joins this set only
# when REQUIRE_EMG is enabled.
REQUIRED_MIN = {*META_ONLY, *EEG_COLS, *ET_CORE, *IMU_HEAD, *ET_DIST} | (
    set(EMG_RAW) if REQUIRE_EMG else set()
)

# Placeholder values used when a missing column has to be added; columns not
# listed here are filled with NaN.
FILL_DEFAULTS = dict(
    # eye-to-tracker distances (numeric sentinel when absent)
    ET_DistanceLeft=-1.0,
    ET_DistanceRight=-1.0,
    # binary-ish eye state flags
    ET_Blink=0,
    ET_Fixation=0,
    ET_Worn=1,
    # validity flags: the meaning of 0/1 varies by vendor; this is only a
    # numeric placeholder
    ET_ValidityLeftEye=0,
    ET_ValidityRightEye=0,
)

# Every canonical column except the metadata is coerced to a numeric dtype.
NUMERIC_LIKE = {name for name in CANONICAL_ORDER if name not in META_ONLY}

def should_delete_for_missing(missing_required: list, policy: str | None = None) -> bool:
    """Decide whether a file must be deleted because required columns are missing.

    Args:
        missing_required: Names of REQUIRED_MIN columns still absent from the
            file (after any columns the "add_missing" policy inserted).
        policy: Policy to apply. ``None`` (the default) means the module-level
            ``DELETE_POLICY``, read at call time.

    Returns:
        True only under "missing_only" when ``missing_required`` is non-empty.
        False under "keep_bad" and "add_missing".

    Raises:
        ValueError: if the policy is not one of the three known values. An
            unrecognised value (e.g. a typo) must never silently enable the
            destructive delete branch.
    """
    if policy is None:
        policy = DELETE_POLICY
    if policy in ("keep_bad", "add_missing"):
        return False
    if policy == "missing_only":
        return bool(missing_required)
    raise ValueError(
        f"Unknown DELETE_POLICY {policy!r}; expected 'keep_bad', 'add_missing' or 'missing_only'"
    )

def read_csv_safely(path: Path) -> pd.DataFrame | None:
    """Load a CSV on a best-effort basis, returning None instead of raising.

    Tries pandas' default parser first, then retries once with the Python
    engine. Both attempts decode with "utf-8-sig" (so a leading BOM is
    stripped) and skip malformed lines. Returns None if both attempts fail.
    """
    attempts = ({}, {"engine": "python"})
    for extra_kwargs in attempts:
        try:
            return pd.read_csv(path, encoding="utf-8-sig", on_bad_lines="skip", **extra_kwargs)
        except Exception:
            # Deliberately broad: any failure just moves on to the next attempt.
            continue
    return None

def coerce_types(df: pd.DataFrame) -> pd.DataFrame:
    """Convert the NUMERIC_LIKE columns of ``df`` to numeric dtypes.

    Each such column goes through ``pd.to_numeric(errors="coerce")``, so
    unparseable entries become NaN. Columns outside NUMERIC_LIKE are left
    untouched. ``df`` is modified in place and also returned.
    """
    numeric_targets = [name for name in df.columns if name in NUMERIC_LIKE]
    for name in numeric_targets:
        df[name] = pd.to_numeric(df[name], errors="coerce")
    return df

def add_missing_columns(df: pd.DataFrame, missing_cols: list) -> pd.DataFrame:
    """Set each column named in ``missing_cols`` to a constant placeholder.

    The placeholder is taken from FILL_DEFAULTS, or NaN for columns with no
    entry there. ``df`` is modified in place and also returned.
    """
    for name in missing_cols:
        df[name] = FILL_DEFAULTS[name] if name in FILL_DEFAULTS else np.nan
    return df

def prune_to_train_schema(df: pd.DataFrame) -> tuple[pd.DataFrame, dict]:
    """Reduce one recording to the canonical training schema.

    Steps:
      1. Under the "add_missing" policy, add every canonical column the file
         lacks (values from FILL_DEFAULTS, NaN otherwise). The columns are
         written into ``df`` itself, so the caller's frame is modified.
      2. Record which REQUIRED_MIN columns are still absent.
      3. Drop every column outside CANONICAL_ORDER, plus anything in ALWAYS_DROP.
      4. Coerce non-metadata columns to numeric and reorder to CANONICAL_ORDER.
         A canonical column that is still absent comes back as all-NaN.

    Args:
        df: One recording as loaded from CSV.

    Returns:
        ``(pruned, info)``. ``info`` holds the report fields:
        - ``kept_n`` and ``dropped_n``.
        - ``missing_required``, ``added_missing``, ``kept_cols`` and
          ``dropped_cols``, each as a ";"-joined string.
        - ``schema_sig``: the first 8 hex digits of an MD5 over the
          "|"-joined output column names, so identical schemas share a
          signature.
    """
    cols_present = set(df.columns)

    # Every canonical column is kept; ET distances are included for schema
    # uniformity. ALWAYS_DROP is applied separately when building drop_set.
    keep_set = set(CANONICAL_ORDER)

    missing_canonical = [c for c in CANONICAL_ORDER if c not in cols_present]

    # If the policy allows it, add the missing columns before the required check.
    added_missing = []
    if DELETE_POLICY == "add_missing" and missing_canonical:
        df = add_missing_columns(df, missing_canonical)
        added_missing = missing_canonical
        cols_present = set(df.columns)

    missing_required = sorted(c for c in REQUIRED_MIN if c not in cols_present)

    # ALWAYS_DROP wins over keep_set: such a column's data is discarded even if
    # it is also canonical (the reindex below would bring it back as all-NaN).
    # drop_set is a subset of cols_present, so every name exists in df.
    drop_set = (cols_present - keep_set) | (cols_present & ALWAYS_DROP)
    pruned = df.drop(columns=list(drop_set))

    # Coerce types, then enforce the canonical column order.
    pruned = coerce_types(pruned)
    pruned = pruned.reindex(columns=CANONICAL_ORDER, fill_value=np.nan)

    # Short schema signature used for the cross-file consistency check.
    schema_sig = hashlib.md5("|".join(pruned.columns).encode()).hexdigest()[:8]

    info = {
        "kept_n": len(pruned.columns),
        "dropped_n": len(drop_set),
        "missing_required": ";".join(missing_required),
        "added_missing": ";".join(added_missing),
        "kept_cols": ";".join(pruned.columns),
        "dropped_cols": ";".join(sorted(drop_set)),
        "schema_sig": schema_sig,
    }
    return pruned, info

def write_schema_lock():
    """Dump the schema configuration as JSON to ROOT_DIR / SCHEMA_LOCK.

    The file records:
    - the column groups and the canonical order;
    - the required and always-dropped column sets (sorted);
    - the fill defaults and the REQUIRE_EMG flag.

    Any existing file is overwritten. Returns the path written.
    """
    target = ROOT_DIR / SCHEMA_LOCK
    payload = dict(
        meta_only=META_ONLY,
        eeg_cols=EEG_COLS,
        emg_raw=EMG_RAW,
        et_core=ET_CORE,
        et_dist=ET_DIST,
        imu_head=IMU_HEAD,
        canonical_order=CANONICAL_ORDER,
        required_min=sorted(REQUIRED_MIN),
        always_drop=sorted(ALWAYS_DROP),
        defaults=FILL_DEFAULTS,
        require_emg=REQUIRE_EMG,
    )
    with target.open("w", encoding="utf-8") as fh:
        json.dump(payload, fh, indent=2)
    return target

def _atomic_write_csv(df: pd.DataFrame, path: Path) -> None:
    """Write ``df`` to ``path`` via a sibling temp file, then swap it into place.

    The original file is replaced only after the new contents are fully
    written. A failed or interrupted write therefore never truncates it.
    """
    # Sibling of the target, so it is on the same filesystem and the replace
    # is an atomic rename. The ".csv.tmp" suffix keeps it out of "*.csv" globs.
    tmp = path.with_name(path.name + ".tmp")
    try:
        df.to_csv(tmp, index=False, encoding="utf-8-sig")
        tmp.replace(path)
    except BaseException:
        try:
            tmp.unlink()
        except OSError:
            pass  # best-effort cleanup; the original error is what matters
        raise


def _process_csv(subject: str, csv_path: Path) -> tuple[dict, bool, str | None]:
    """Prune one CSV in place, delete it per policy, or leave it if unreadable.

    Returns:
        ``(row, ok, schema_sig)``:
        - ``row``: the report row.
        - ``ok``: True if the pruned file was saved successfully.
        - ``schema_sig``: the saved file's signature, or None if nothing was saved.
    """
    df = read_csv_safely(csv_path)
    if df is None:
        print(f"[READ-FAIL] {csv_path.name} (kept untouched)")
        row = {
            "subject": subject, "file": csv_path.name,
            "status": "read_failed_kept",
            "reason": "read_failed",
            "missing_required": "", "added_missing": "",
            "kept_n": "", "dropped_n": "", "kept_cols": "", "dropped_cols": "",
            "schema_sig": "",
        }
        return row, False, None

    pruned, info = prune_to_train_schema(df)
    missing_req = info["missing_required"].split(";") if info["missing_required"] else []
    saved_sig = None

    if should_delete_for_missing(missing_req):
        try:
            csv_path.unlink()
            print(f"[DELETED] {csv_path.name} (missing required: {missing_req})")
            status = "deleted_missing_required"
        except OSError as e:
            print(f"[WARN] Could not delete {csv_path.name}: {e}")
            status = "delete_failed"
    else:
        try:
            _atomic_write_csv(pruned, csv_path)
        except OSError as e:
            # Record the failure and keep going; the original file is intact.
            print(f"[WARN] Could not save {csv_path.name}: {e} (original kept)")
            status = "save_failed"
        else:
            add_msg = f", +{info['added_missing']}" if info["added_missing"] else ""
            print(f"[SAVED] {csv_path.name}: kept {info['kept_n']} cols, dropped {info['dropped_n']}{add_msg} | schema {info['schema_sig']}")
            status = "ok" if not missing_req else "ok_with_missing"
            saved_sig = info["schema_sig"]

    if status == "save_failed":
        reason = "save_failed"
    elif status.startswith("ok"):
        reason = ""
    else:
        reason = "missing_required" if missing_req else ""

    row = {"subject": subject, "file": csv_path.name, "status": status, "reason": reason, **info}
    return row, status.startswith("ok"), saved_sig


def verify_prune_save():
    """Prune every CSV in ROOT_DIR/<SUBJECT_GLOB>/<TARGET_SUBDIR> to the training schema.

    Steps:
    - Write the schema lock file.
    - For each CSV: rewrite it in place with the pruned columns, or delete it
      when should_delete_for_missing says so. Unreadable files are left
      untouched.
    - Write a per-file report to ROOT_DIR / REPORT_NAME.
    - Print a summary, including whether every saved file shares one schema
      signature.
    """
    rows = []
    total = good = bad = 0
    schema_sigs = set()

    print(f"DELETE_POLICY = {DELETE_POLICY} | REQUIRE_EMG = {REQUIRE_EMG}")
    lock_path = write_schema_lock()
    print(f"[lock] schema → {lock_path}")

    for subj_dir in sorted(p for p in ROOT_DIR.glob(SUBJECT_GLOB) if p.is_dir()):
        scan_dir = subj_dir / TARGET_SUBDIR
        if not scan_dir.exists():
            continue

        csvs = sorted(scan_dir.glob("*.csv"))
        if not csvs:
            continue

        print(f"\n=== Pruning {subj_dir.name}/{TARGET_SUBDIR}: {len(csvs)} files ===")

        for csv_path in csvs:
            total += 1
            row, ok, sig = _process_csv(subj_dir.name, csv_path)
            rows.append(row)
            if ok:
                good += 1
                schema_sigs.add(sig)
            else:
                bad += 1

    report_path = ROOT_DIR / REPORT_NAME
    pd.DataFrame(rows).to_csv(report_path, index=False, encoding="utf-8-sig")

    print(f"\n=== Summary ===\nTotal: {total} | Good: {good} | Bad: {bad}")
    print(f"Report saved: {report_path}")
    if not schema_sigs:
        print("[schema] No files were saved; no schema signatures to compare.")
    elif len(schema_sigs) == 1:
        print("[schema] All files share ONE canonical schema.")
    else:
        print(f"[schema] WARNING: {len(schema_sigs)} different schema signatures encountered (see report).")

if __name__ == "__main__":
    # Run the full prune. This rewrites (or, per DELETE_POLICY, deletes) the
    # CSVs under ROOT_DIR in place. NOTE(review): a Jupyter kernel also runs
    # with __name__ == "__main__", so executing this cell triggers the prune.
    verify_prune_save()
