In [None]:
# =============================================================================
# Phase 3 — Manifest & TRUE LOSO splits (sliding + onset-anchor windows)
# Uses ALL available columns including label_action and task_target
# =============================================================================
from __future__ import annotations
import re, glob, hashlib
from pathlib import Path
import numpy as np
import pandas as pd
from collections import defaultdict

# ---------------- CONFIG ----------------
# Input layout: <ROOT_DIR>/sub-XX/<LABEL_SUBPATH>/<files matching CSV_GLOB>
ROOT_DIR: Path = Path(r"/home/tsultan1/BioRob(Final)/Data")
LABEL_SUBPATH: Path = Path(r"cleaned/synchronized_proper_lite_union_v3/labelonly")
CSV_GLOB: str = "*_icml_consensus_labels.csv"

# Outputs. The dataset directory is created when this cell runs (its parent must exist).
DATASET_DIR: Path = ROOT_DIR / "_dataset_icml_v1"
DATASET_DIR.mkdir(exist_ok=True)
MANIFEST_OUT: Path = DATASET_DIR / "manifest_v1.csv"
SPLITS_OUT: Path = DATASET_DIR / "splits_v1.csv"

# Windowing, in seconds; each file converts these to sample counts with its median fs.
SLIDE_WIN_S: float = 2.0        # sliding-window length
SLIDE_STRIDE_S: float = 0.25    # step between consecutive sliding windows
ANCHOR_PRE_S: float = 2.0       # onset windows: context kept before each onset
ANCHOR_POST_S: float = 3.0      # onset windows: context kept after each onset
ACTIVE_MAJ_FRAC: float = 0.50   # window gets label_action=1 when mean per-sample label >= this

# Validation subject per LOSO fold: "smallest_other" takes the lowest remaining subject id;
# "fixed" takes VAL_SUBJECT_ID when it is among the remaining subjects.
VAL_SELECTION_MODE: str = "smallest_other"
VAL_SUBJECT_ID: int = 1

# When True, the manifest gets an md5 column (every file is read in full to compute it).
WRITE_FILE_MD5: bool = False

# Columns every label CSV must contain; files lacking any of them are skipped.
REQ: list[str] = [
    "Timestamp_seconds",
    "active",
    "task",
    "trial",
    "subject_id",
    "label_action",
    "task_target",
]

# ---------------- HELPERS ----------------
def _is_subdir_name(name: str) -> bool:
    """Return True if *name* looks like a subject folder: "sub", an optional "-", then digits.

    Matching is case-insensitive, so ``sub01``, ``Sub-7`` and ``SUB-12`` all qualify.
    """
    pattern = r"^sub-?\d+$"
    return re.search(pattern, name, flags=re.IGNORECASE) is not None

def _collect_files(root_dir: Path | None = None,
                   label_subpath: Path | None = None,
                   csv_glob: str | None = None) -> list[Path]:
    """Return the label CSVs of every subject directory under *root_dir*.

    Subject directories are the immediate sub-directories of *root_dir* whose names pass
    ``_is_subdir_name``; they are visited in sorted order. For each one, files matching
    *csv_glob* inside ``<subject>/<label_subpath>`` are collected (sorted, absolute).
    Subjects missing the label directory or any matching file are reported and skipped.

    Each argument falls back to the module-level setting with the same role
    (``ROOT_DIR``, ``LABEL_SUBPATH``, ``CSV_GLOB``) when left as ``None``.

    Raises ``FileNotFoundError`` if *root_dir* does not exist (from ``Path.iterdir``).
    """
    root = ROOT_DIR if root_dir is None else Path(root_dir)
    subpath = LABEL_SUBPATH if label_subpath is None else Path(label_subpath)
    pattern = CSV_GLOB if csv_glob is None else csv_glob

    files: list[Path] = []
    subject_dirs = sorted(p for p in root.iterdir() if p.is_dir() and _is_subdir_name(p.name))
    for sub in subject_dirs:
        label_dir = sub / subpath
        if not label_dir.exists():
            print(f"[skip] no labelonly here: {label_dir}")
            continue
        # Escape the directory part so glob metacharacters in the data path (e.g. "[")
        # are matched literally; only `pattern` is treated as a wildcard.
        search = str(Path(glob.escape(str(label_dir.resolve()))) / pattern)
        got = sorted(Path(p) for p in glob.glob(search))
        if not got:
            print(f"[skip] no CSV in: {label_dir}")
            continue
        print(f"[use] {sub.name} → labelonly ({len(got)} files)")
        files.extend(got)
    return files

def _median_fs(t: np.ndarray) -> float:
    """Estimate a sampling rate (Hz) as 1 / median of the positive, finite steps in *t*.

    NaN steps, repeated timestamps and time reversals are ignored; NaN is returned when
    no usable step remains (e.g. fewer than two valid samples).
    """
    steps = np.diff(t)
    usable = steps[np.isfinite(steps) & (steps > 0)]
    return 1.0 / np.median(usable) if usable.size else np.nan

def _file_md5(path: Path, chunk: int = 1<<20) -> str:
    """Return the hex MD5 digest of the file at *path*, read in *chunk*-byte pieces."""
    digest = hashlib.md5()
    with open(path, "rb") as stream:
        # iter() with a b"" sentinel stops at EOF, when read() returns an empty bytes.
        for piece in iter(lambda: stream.read(chunk), b""):
            digest.update(piece)
    return digest.hexdigest()

def _choose_val_subject(all_subj: list[int], held_out: int) -> int:
    """Pick the validation subject for a fold whose test subject is *held_out*.

    With ``VAL_SELECTION_MODE == "fixed"`` and ``VAL_SUBJECT_ID`` among the remaining
    subjects, that id is used; otherwise the smallest remaining id is chosen. When no
    other subject exists, *held_out* itself is returned.
    """
    candidates = sorted(s for s in all_subj if s != held_out)
    if len(candidates) == 0:
        return held_out
    fixed_ok = (VAL_SELECTION_MODE == "fixed"
                and VAL_SUBJECT_ID in candidates
                and VAL_SUBJECT_ID != held_out)
    return VAL_SUBJECT_ID if fixed_ok else candidates[0]

def _onsets_from_active(active: np.ndarray) -> np.ndarray:
    """Return the indices where *active* switches from inactive to active.

    A sample counts as active when its integer-truncated value is > 0. If the first
    sample is already active, index 0 is reported as an onset.
    """
    on = active.astype(int) > 0
    was_on = np.zeros_like(on)
    was_on[1:] = on[:-1]          # state of the previous sample (False before the start)
    return np.where(on & ~was_on)[0]

def _to_int_safe(x):
    """Coerce a scalar (number or numeric string) to a Python ``int``.

    The value is parsed with ``pd.to_numeric`` and then truncated toward zero by ``int``.
    "Safe" refers only to the tolerant parsing: missing, unparseable or infinite input
    raises ``ValueError`` naming the offending value, rather than numpy's generic
    "cannot convert float NaN to integer" (or ``OverflowError`` for infinities).
    """
    val = pd.to_numeric(x, errors="coerce")
    if pd.isna(val) or not np.isfinite(val):
        raise ValueError(f"cannot convert {x!r} to int")
    return int(val)

def _majority_nonzero(values: np.ndarray) -> int:
    """Return the most frequent non-zero value in *values* (smallest wins ties), or 0."""
    nonzero = values[values != 0]
    if nonzero.size == 0:
        return 0
    uniq, counts = np.unique(nonzero, return_counts=True)
    # np.unique sorts ascending and argmax returns the first maximum, so ties resolve to
    # the smallest code (the np.bincount(...).argmax() rule), but negative codes work and
    # no array sized by the largest code is allocated.
    return int(uniq[np.argmax(counts)])


def _window_rows_for_file(file_path: Path,
                          df: pd.DataFrame,
                          fs: float,
                          fold_id: int,
                          split_tag: str) -> list[dict]:
    """Build sliding-window and onset-anchored window rows for one recording.

    Parameters
    ----------
    file_path : Path
        Source CSV; its resolved path is stored in the ``file`` field of every row.
    df : DataFrame
        Per-sample table with ``subject_id``, ``task``, ``trial`` and ``active``;
        ``label_action`` and ``task_target`` are used when present.
    fs : float
        Sampling rate in Hz, used to convert the ``SLIDE_*`` / ``ANCHOR_*`` durations
        (seconds) into sample counts.
    fold_id, split_tag
        Stored in every row.

    Returns
    -------
    list[dict]
        One dict per window with sample bounds ``[start_idx, end_idx)``; empty when the
        recording has fewer than two samples.

    Action signal: the ``label_action`` column if present, else ``active``.

    Sliding windows:
      * ``label_action`` is 1 when the mean of the action signal over the window is
        >= ``ACTIVE_MAJ_FRAC``.
      * ``task_target`` is the most frequent non-zero ``task_target`` in the window (0 if
        none). Without that column, it is the file's ``task`` code when ``label_action``
        is 1, else 0.

    Onset windows span ``ANCHOR_PRE_S`` before to ``ANCHOR_POST_S`` after each
    inactive→active transition of the same action signal, and always carry
    ``label_action=1``. Windows that would extend past either end of the recording are
    dropped.
    """
    rows: list[dict] = []
    n = len(df)
    if n < 2:
        return rows

    file_str = str(file_path.resolve())  # resolve once rather than once per window
    subj = _to_int_safe(df["subject_id"].iloc[0])
    task_code = _to_int_safe(df["task"].iloc[0])
    trial_id = _to_int_safe(df["trial"].iloc[0])
    active = df["active"].fillna(0).astype(int).to_numpy()

    # Each label column falls back independently, so sliding and onset windows always
    # agree on which per-sample signal defines "action".
    if "label_action" in df.columns:
        action = df["label_action"].fillna(0).astype(int).to_numpy()
    else:
        action = active
    task_col = (df["task_target"].fillna(0).astype(int).to_numpy()
                if "task_target" in df.columns else None)

    win_n    = max(1, int(round(SLIDE_WIN_S    * fs)))
    stride_n = max(1, int(round(SLIDE_STRIDE_S * fs)))
    pre_n    = max(0, int(round(ANCHOR_PRE_S   * fs)))
    post_n   = max(1, int(round(ANCHOR_POST_S  * fs)))

    def _row(kind: str, start: int, end: int, task_target, label_action) -> dict:
        # Field order defines the downstream column order.
        return dict(
            file=file_str,
            subject_id=subj,
            split=split_tag,
            type=kind,
            task_target=int(task_target),
            label_action=int(label_action),
            start_idx=int(start),
            end_idx=int(end),
            task_code=int(task_code),
            trial_id=int(trial_id),
            fold_id=int(fold_id),
        )

    # Sliding windows; win_n >= 1, so every slice is non-empty.
    for s in range(0, n - win_n + 1, stride_n):
        e = s + win_n
        label_action = 1 if float(np.mean(action[s:e])) >= ACTIVE_MAJ_FRAC else 0
        if task_col is not None:
            task_target = _majority_nonzero(task_col[s:e])
        else:
            task_target = task_code if label_action == 1 else 0
        rows.append(_row("sliding", s, e, task_target, label_action))

    # Onset-anchored windows; post_n >= 1 guarantees e > s.
    for o in _onsets_from_active(action):
        s, e = o - pre_n, o + post_n
        if s < 0 or e > n:
            continue
        # NOTE(review): task_target at the onset sample may be 0 even though the window
        # is labelled action=1 — confirm that pairing is intended.
        onset_task = task_col[o] if task_col is not None else task_code
        rows.append(_row("onset_anchor", s, e, onset_task, 1))
    return rows

def analyze_class_distribution(splits_df: pd.DataFrame):
    """Print label_action / task_target distributions overall, per subject and per split.

    Every row of *splits_df* is counted. When the frame stacks several LOSO folds (as the
    main script's splits table does), each window is therefore counted once per fold.

    Only the first five subjects (in ascending id order) are printed individually; all
    subjects appear in the returned statistics.

    Returns
    -------
    list[dict]
        One entry per subject, sorted by id, with ``subject_id``, ``total_samples``,
        ``action_0``, ``action_1``, ``action_1_ratio`` (0 when the subject has no 0/1
        labels) and ``unique_tasks``. An empty input yields an empty list.
    """
    print("\n=== CLASS DISTRIBUTION ANALYSIS ===")
    if splits_df.empty:
        print("(no windows to analyze)")
        return []

    n_total = len(splits_df)

    def _pct(part, whole) -> float:
        # Percentage that tolerates an empty denominator (e.g. no 0/1 labels at all).
        return part / whole * 100 if whole else 0.0

    print("Overall label_action distribution:")
    for label, count in splits_df["label_action"].value_counts().sort_index().items():
        print(f"  Class {label}: {count} samples ({_pct(count, n_total):.1f}%)")

    print("\nOverall task_target distribution:")
    for task, count in splits_df["task_target"].value_counts().sort_index().items():
        print(f"  Task {task}: {count} samples ({_pct(count, n_total):.1f}%)")

    subjects = sorted(splits_df["subject_id"].unique())
    print(f"\nPer-subject distribution ({len(subjects)} subjects):")

    subject_stats = []
    for rank, subject in enumerate(subjects):
        subject_data = splits_df[splits_df["subject_id"] == subject]
        action_dist = subject_data["label_action"].value_counts()
        action_0 = action_dist.get(0, 0)
        action_1 = action_dist.get(1, 0)
        total_actions = action_0 + action_1

        subject_stats.append({
            "subject_id": subject,
            "total_samples": len(subject_data),
            "action_0": action_0,
            "action_1": action_1,
            "action_1_ratio": action_1 / total_actions if total_actions > 0 else 0,
            "unique_tasks": len(subject_data["task_target"].value_counts()),
        })

        if rank < 5:  # sample printout: the first five subjects by rank, not by id value
            print(f"  Subject {subject}: {len(subject_data)} samples, "
                  f"action=0: {action_0}, action=1: {action_1} "
                  f"({_pct(action_1, total_actions):.1f}% positive)")

    print("\nPer-split distribution:")
    for split in ["train", "val", "test"]:
        split_data = splits_df[splits_df["split"] == split]
        if split_data.empty:
            continue
        action_dist = split_data["label_action"].value_counts()
        action_0 = action_dist.get(0, 0)
        action_1 = action_dist.get(1, 0)
        print(f"  {split}: {len(split_data)} samples, "
              f"action=0: {action_0}, action=1: {action_1} "
              f"({_pct(action_1, action_0 + action_1):.1f}% positive)")

    return subject_stats

def balance_classes_stratified(splits_df: pd.DataFrame, min_samples_per_class=100):
    """Undersample ``label_action`` classes within each (fold, subject) group.

    Within a group, each class with more than
    ``target = max(min_samples_per_class, <smallest class count in the group>)`` rows is
    randomly reduced to ``target`` rows (``DataFrame.sample`` with ``random_state=42``).
    Smaller classes are kept whole; nothing is oversampled. Groups whose classes all have
    at most ``min_samples_per_class`` rows are therefore left untouched.

    Groups are ``(fold_id, subject_id)`` when *splits_df* has a ``fold_id`` column, and
    ``subject_id`` alone otherwise. Grouping by fold matters for stacked LOSO tables,
    which repeat each subject's windows once per fold. Pooling those copies would keep a
    different random subset of the subject in every fold and shrink each fold's share.

    Rows whose ``label_action`` is missing are dropped, because ``value_counts`` ignores them.

    Returns a new DataFrame with a fresh RangeIndex; an empty input yields an empty frame.
    """
    print("\n=== APPLYING CLASS BALANCING ===")
    if splits_df.empty:
        print("  (no windows to balance)")
        return splits_df.reset_index(drop=True)

    group_cols = ["fold_id", "subject_id"] if "fold_id" in splits_df.columns else ["subject_id"]

    balanced_parts = []
    for _, group in splits_df.groupby(group_cols, sort=False, dropna=False):
        class_counts = group["label_action"].value_counts()
        if class_counts.empty:  # every label in this group is missing
            continue
        target_count = max(min_samples_per_class, class_counts.min())
        for class_label in class_counts.index:
            class_data = group[group["label_action"] == class_label]
            if len(class_data) > target_count:
                # Undersample the larger class(es)
                class_data = class_data.sample(n=target_count, random_state=42)
            balanced_parts.append(class_data)

    if not balanced_parts:
        print("  (no labelled windows to keep)")
        return splits_df.iloc[0:0].reset_index(drop=True)
    balanced_df = pd.concat(balanced_parts, ignore_index=True)

    # One summary line per fold (or per subject when there is no fold column).
    summary_col = group_cols[0]
    before = splits_df.groupby(summary_col, dropna=False).size()
    after = balanced_df.groupby(summary_col, dropna=False).size()
    for key, count in before.items():
        print(f"  {summary_col} {key}: {count} → {after.get(key, 0)} samples")

    print("\nAfter balancing:")
    overall_action = balanced_df["label_action"].value_counts().sort_index()
    for label, count in overall_action.items():
        print(f"  Class {label}: {count} samples ({count/len(balanced_df)*100:.1f}%)")

    return balanced_df

# ---------------- MAIN ----------------
files = _collect_files()
if not files:
    raise SystemExit("[stop] No label CSVs found in any subject's labelonly/")

req_cols = set(REQ)  # built once; read_csv calls the usecols predicate per column

# Manifest: one row per file. Each CSV is read exactly once. Its windows are built here
# with placeholder fold/split values and re-stamped for every LOSO fold further down.
# Windowing does not depend on the fold, so this avoids re-reading and re-windowing
# every file once per fold.
manifest_rows, bad = [], []
file_windows: dict[str, list[dict]] = {}
for p in files:
    try:
        df = pd.read_csv(p, low_memory=False, usecols=lambda c: c in req_cols)
        missing = [c for c in REQ if c not in df.columns]
        if missing:
            bad.append((str(p), f"missing cols: {missing}"))
            continue

        t  = pd.to_numeric(df["Timestamp_seconds"], errors="coerce").to_numpy()
        fs = _median_fs(t)
        dur = float((t[-1] - t[0])) if len(t) else 0.0

        subj  = _to_int_safe(df["subject_id"].iloc[0])
        task  = _to_int_safe(df["task"].iloc[0])
        trial = _to_int_safe(df["trial"].iloc[0])

        row = dict(
            file=str(p.resolve()),
            subject_id=subj,
            fs_hz=round(fs, 6) if np.isfinite(fs) else np.nan,
            duration_s=round(dur, 6),
            task_code=task,
            trial_id=trial,
            fold_id=subj,   # LOSO: fold id = subject id
        )
        if WRITE_FILE_MD5:
            row["md5"] = _file_md5(p)
    except Exception as e:
        bad.append((str(p), str(e)))
        continue
    manifest_rows.append(row)

    # A file whose windows cannot be built stays in the manifest but contributes none.
    if not np.isfinite(fs) or fs <= 0:
        print(f"[warn] fs not finite for {p}; skipping")
        continue
    try:
        file_windows[row["file"]] = _window_rows_for_file(p, df, fs, fold_id=-1, split_tag="")
    except Exception as e:
        print(f"[warn] windowing failed for {p}: {e}")

# Save skip reasons before any early exit below, so they are never lost.
if bad:
    skipped = DATASET_DIR / "manifest_v1_skipped.csv"
    pd.DataFrame(bad, columns=["file","reason"]).to_csv(skipped, index=False)
    print(f"[info] skipped manifest files → {skipped}")
if not manifest_rows:
    raise SystemExit("[stop] No usable label CSVs (every file was skipped)")

manifest = pd.DataFrame(manifest_rows).sort_values(["subject_id","file"]).reset_index(drop=True)
manifest.to_csv(MANIFEST_OUT, index=False)
print(f"[ok] manifest → {MANIFEST_OUT}  ({len(manifest)} ok, {len(bad)} skipped)")

# TRUE LOSO splits - each subject is the test set of exactly one fold
subjects = sorted(manifest["subject_id"].unique().tolist())
print(f"\nFound {len(subjects)} subjects: {subjects}")

manifest_pairs = list(manifest[["file", "subject_id"]].itertuples(index=False, name=None))
splits_rows = []

for test_subject in subjects:
    print(f"\nProcessing TRUE LOSO fold: Subject {test_subject} as test")

    # Training subjects: all except test subject
    train_subjects = [s for s in subjects if s != test_subject]

    # Choose validation subject from training subjects
    val_subject = _choose_val_subject(train_subjects, test_subject)

    # Split mapping: test_subject=test, val_subject=val, others=train
    split_for_subj = {}
    for s in subjects:
        if s == test_subject:
            split_for_subj[s] = "test"
        elif s == val_subject:
            split_for_subj[s] = "val"
        else:
            split_for_subj[s] = "train"

    print(f"  Train subjects: {[s for s in train_subjects if s != val_subject]}")
    print(f"  Val subject: {val_subject}")
    print(f"  Test subject: {test_subject}")

    # Re-stamp the cached, fold-independent windows with this fold's split and id.
    for file_path_str, subj in manifest_pairs:
        split_tag = split_for_subj[int(subj)]
        for w in file_windows.get(file_path_str, ()):
            splits_rows.append({**w, "split": split_tag, "fold_id": int(test_subject)})

splits = pd.DataFrame(splits_rows)
if splits.empty:
    raise SystemExit("[stop] No windows were generated (check sampling rates / window lengths)")

# Analyze class distribution
subject_stats = analyze_class_distribution(splits)

# Apply class balancing
# NOTE(review): this undersamples val/test rows as well as train rows; evaluating on a
# re-balanced test split changes the reported metrics — confirm that is intended.
splits_balanced = balance_classes_stratified(splits, min_samples_per_class=100)

# Column order for downstream
ordered = ["file","subject_id","split","type","task_target","label_action",
           "start_idx","end_idx","task_code","trial_id","fold_id"]
extra = [c for c in splits_balanced.columns if c not in ordered]
splits_final = splits_balanced[ordered + extra]

splits_final.to_csv(SPLITS_OUT, index=False)
print(f"\n[ok] splits  → {SPLITS_OUT}  ({len(splits_final)} windows across {len(subjects)} folds)")

# Save analysis report
analysis_report = DATASET_DIR / "dataset_analysis_report.txt"
with open(analysis_report, "w", encoding="utf-8") as f:
    f.write("DATASET ANALYSIS REPORT\n")
    f.write("======================\n\n")
    f.write(f"Total subjects: {len(subjects)}\n")
    f.write(f"Total windows: {len(splits_final)}\n")
    f.write(f"Subjects: {subjects}\n\n")

    f.write("PER-SUBJECT STATISTICS:\n")
    for stats in subject_stats:
        f.write(f"Subject {stats['subject_id']}: {stats['total_samples']} samples, "
                f"action_1_ratio: {stats['action_1_ratio']:.3f}, "
                f"unique_tasks: {stats['unique_tasks']}\n")

    f.write("\nCLASS DISTRIBUTION AFTER BALANCING:\n")
    action_dist = splits_final["label_action"].value_counts().sort_index()
    for label, count in action_dist.items():
        f.write(f"  Class {label}: {count} samples ({count/len(splits_final)*100:.1f}%)\n")

print(f"[info] analysis report → {analysis_report}")