# In [None]:
import re
import random
import warnings
from pathlib import Path

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

from sklearn.model_selection import KFold
from sklearn.metrics import (
    precision_score, recall_score, f1_score, accuracy_score,
    confusion_matrix, balanced_accuracy_score, cohen_kappa_score,
    matthews_corrcoef, log_loss
)
from sklearn.metrics import top_k_accuracy_score
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_auc_score, average_precision_score


# ============================================================
# CONFIG
# ============================================================

# Root directory expected to contain one "final_exports-sub<N>" folder per subject.
DATA_ROOT = Path("/home/tsultan1/paper-2/Dataset-2")

# Matches subject folder names; capture group 1 is the numeric subject id.
SUB_DIR_PATTERN = re.compile(r"final_exports-sub(\d+)$")

# Output folder for the saved best-model checkpoint.
# NOTE: the directory is created as a side effect of executing this module.
OUT_DIR = DATA_ROOT / "train_all_subjects_outputs"
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Training hyper-parameters (passed to KFold / train_model in the main script)
K_FOLDS = 5          # number of cross-validation folds
EPOCHS = 50          # maximum epochs per fold
PATIENCE = 5         # epochs without validation-loss improvement before early stopping
BATCH_SIZE = 256     # DataLoader batch size for training and validation
LR = 5e-4            # optimizer learning rate
WEIGHT_DECAY = 1e-2  # optimizer weight decay

# Optional online adaptation (fine-tuning on part of each validation fold)
DO_ONLINE_ADAPT = True    # toggle for the adaptation stage
ONLINE_ADAPT_FRAC = 0.30  # leading fraction of the validation fold used as adaptation buffer
ONLINE_CYCLES = 5         # maximum number of adaptation passes
ONLINE_BS = 128           # batch size during adaptation
ONLINE_LR = 1e-5          # peak learning rate during adaptation
REPLAY_N = 100            # number of training samples replayed alongside the buffer

# GPU speed
USE_AMP = True     # request mixed-precision (autocast + GradScaler) in the training loops
NUM_WORKERS = 4    # DataLoader worker processes
PIN_MEMORY = True  # DataLoader pin_memory flag

# Seed used for RNG seeding and for the KFold shuffle.
SEED = 42


# ============================================================
# SEED
# ============================================================

def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch (CPU and all-CUDA-device) RNGs with `seed`."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Seed every RNG once, as soon as the module is executed, using the configured SEED.
set_seed(SEED)


# ============================================================
# CSV HELPERS
# ============================================================

def read_numeric_csv(path: Path) -> np.ndarray:
    """Read a CSV file into a float32 matrix.

    Every column is coerced to numeric; cells that cannot be parsed (and
    missing cells) are replaced by 0.0.
    """
    raw = pd.read_csv(path)
    numeric = raw.apply(pd.to_numeric, errors="coerce")
    filled = numeric.fillna(0.0)
    return filled.to_numpy(dtype=np.float32)

def load_labels_csv(lbl_path: Path, sid: int):
    """Read class labels (and optional per-row subject ids) from a labels CSV.

    Column names are whitespace-stripped and matched case-insensitively:
      - "label"      (required) -> int64 label array
      - "subject_id" (optional) -> int64 subject array; when the column is
        absent, every row is assigned `sid`.

    Returns:
        (labels, subject_ids) as two int64 NumPy arrays.

    Raises:
        ValueError: if no label column is present.
    """
    table = pd.read_csv(lbl_path)
    table = table.rename(columns={name: name.strip() for name in table.columns})

    def _first_column_named(wanted):
        # First column whose lower-cased name equals `wanted`, else None.
        return next((name for name in table.columns if name.lower() == wanted), None)

    label_name = _first_column_named("label")
    if label_name is None:
        raise ValueError(f"{lbl_path} missing 'Label' column")
    labels = table[label_name].astype(np.int64).to_numpy()

    subject_name = _first_column_named("subject_id")
    if subject_name is not None:
        subjects = table[subject_name].astype(np.int64).to_numpy()
    else:
        subjects = np.full(len(labels), sid, dtype=np.int64)

    return labels, subjects


# ============================================================
# DATA LOADING (per subject folder: eeg_sub{sid}.csv etc.)
# ============================================================

def subject_id_from_folder(folder: Path) -> int:
    """Return the numeric subject id encoded in a subject folder's name.

    Raises:
        ValueError: if the folder name does not match SUB_DIR_PATTERN.
    """
    hit = SUB_DIR_PATTERN.match(folder.name)
    if hit is None:
        raise ValueError(f"Folder does not match pattern final_exports-subX: {folder}")
    digits = hit.group(1)
    return int(digits)

def find_subject_files(sub_dir: Path, sid: int):
    """Locate the EEG, EMG and label CSVs inside one subject folder.

    For each kind the subject-specific name ("<kind>_sub{sid}.csv") is
    preferred, with "<kind>_final.csv" as the fallback.

    Returns:
        (eeg_path, emg_path, lbl_path); an entry is None when neither
        candidate file exists.
    """
    def _first_existing(kind):
        for filename in (f"{kind}_sub{sid}.csv", f"{kind}_final.csv"):
            candidate = sub_dir / filename
            if candidate.exists():
                return candidate
        return None

    return _first_existing("eeg"), _first_existing("emg"), _first_existing("labels")

def load_one_subject_folder(sub_dir: Path):
    """Load one subject's EEG, EMG and labels and fuse them row-wise.

    Returns:
        None when any of the three files is missing or no rows remain after
        alignment; otherwise (X, y, sid_arr, input_dim) where X is the EEG
        columns followed by the EMG columns and `input_dim` is the number of
        EEG columns (the column index at which X splits into EEG | EMG).
    """
    sid = subject_id_from_folder(sub_dir)
    paths = find_subject_files(sub_dir, sid)
    if any(p is None for p in paths):
        return None
    eeg_path, emg_path, lbl_path = paths

    eeg = read_numeric_csv(eeg_path)
    emg = read_numeric_csv(emg_path)
    y, sid_arr = load_labels_csv(lbl_path, sid)

    # The sources may differ in length: keep only the common leading rows.
    n_rows = min(len(y), eeg.shape[0], emg.shape[0], len(sid_arr))
    if n_rows <= 0:
        return None

    eeg, emg = eeg[:n_rows], emg[:n_rows]
    fused = np.concatenate([eeg, emg], axis=1)  # (N, Feeg+Femg)
    eeg_width = eeg.shape[1]                    # split point
    return fused, y[:n_rows], sid_arr[:n_rows], eeg_width

def load_all_subjects_combined(data_root: Path):
    """Load every subject folder under `data_root` and stack them into one dataset.

    Folders are visited in ascending subject-id order. A folder is skipped
    (with a message or warning) when its files are missing/empty, when its
    feature widths disagree with the first successfully loaded subject, or
    when loading raises.

    Returns:
        (X, y, sid, input_dim): stacked features, labels, per-row subject ids
        and the EEG column count shared by all loaded subjects.

    Raises:
        RuntimeError: if no subject folder exists or none could be loaded.
    """
    subject_dirs = sorted(
        (entry for entry in data_root.iterdir()
         if entry.is_dir() and SUB_DIR_PATTERN.match(entry.name)),
        key=subject_id_from_folder,
    )
    if len(subject_dirs) == 0:
        raise RuntimeError(f"No folders found like final_exports-subX under {data_root}")

    feature_parts, label_parts, subject_parts = [], [], []
    ref_eeg_dim = None    # EEG width of the first loaded subject
    ref_total_dim = None  # total feature width of the first loaded subject

    for folder in subject_dirs:
        try:
            loaded = load_one_subject_folder(folder)
            if loaded is None:
                print(f"[SKIP] {folder.name}: missing files or empty")
                continue

            X, y, sid_arr, input_dim = loaded

            if ref_eeg_dim is None:
                ref_eeg_dim, ref_total_dim = input_dim, X.shape[1]
            elif (input_dim, X.shape[1]) != (ref_eeg_dim, ref_total_dim):
                warnings.warn(
                    f"Skipping {folder.name}: feature mismatch "
                    f"(got input_dim={input_dim}, total={X.shape[1]} vs "
                    f"ref input_dim={ref_eeg_dim}, total={ref_total_dim})"
                )
                continue

            feature_parts.append(X)
            label_parts.append(y)
            subject_parts.append(sid_arr)

            print(f"[OK] {folder.name}: N={len(y)} | X={X.shape}")

        except Exception as e:
            # Best effort: one broken subject must not abort the whole load.
            warnings.warn(f"[SKIP] {folder.name}: {e}")

    if len(feature_parts) == 0:
        raise RuntimeError("No subject data loaded. Check file names and folder structure.")

    return (
        np.vstack(feature_parts),
        np.concatenate(label_parts),
        np.concatenate(subject_parts),
        ref_eeg_dim,
    )


# ============================================================
# DATASET
# ============================================================

class EEGEMGDataset(Dataset):
    """In-memory dataset pairing float32 feature rows with int64 class labels."""

    def __init__(self, X: np.ndarray, y: np.ndarray):
        # Convert once up front so __getitem__ is a plain tensor index.
        features = torch.tensor(X, dtype=torch.float32)
        targets = torch.tensor(y, dtype=torch.long)
        self.X, self.y = features, targets

    def __len__(self):
        # Number of samples = rows of the feature tensor.
        return self.X.size(0)

    def __getitem__(self, idx):
        sample = self.X[idx]
        target = self.y[idx]
        return sample, target


# ============================================================
# MODEL
# ============================================================

class EEGEMGTransformer(nn.Module):
    """Two-branch Transformer classifier fusing EEG and EMG feature sequences.

    Each modality goes through a linear "align" layer, a 2-layer Transformer
    encoder and a projection to `hidden_dim`. The two projected streams are
    mixed with a learnable pair of scalar weights (initialised to 0.7 / 0.3),
    passed through multi-head self-attention, mean-pooled over the sequence
    axis and classified by a linear layer.

    Args:
        input_dim: EEG feature width; also the model width of both encoders.
        hidden_dim: width after projection / of the fusion attention.
        num_heads: attention heads; must divide both `input_dim` and `hidden_dim`.
        num_classes: number of output logits.
        dropout_rate: dropout used in the encoders and after the projections.
        emg_dim: EMG feature width. Defaults to `input_dim` (the original
            behaviour); pass it explicitly when EMG and EEG widths differ so
            that EMG is linearly mapped onto the shared `input_dim` width.

    Raises:
        ValueError: if `num_heads` does not divide `input_dim` or `hidden_dim`.
    """

    def __init__(self, input_dim, hidden_dim, num_heads, num_classes, dropout_rate=0.5, emg_dim=None):
        super().__init__()
        if emg_dim is None:
            emg_dim = input_dim
        # Fail early with a readable message instead of an assertion raised
        # deep inside nn.MultiheadAttention.
        if input_dim % num_heads != 0 or hidden_dim % num_heads != 0:
            raise ValueError(
                f"num_heads={num_heads} must divide input_dim={input_dim} "
                f"and hidden_dim={hidden_dim}"
            )

        # NOTE: module creation order is kept unchanged so that a fixed seed
        # yields the same initial weights as before when emg_dim == input_dim.
        self.align_eeg = nn.Linear(input_dim, input_dim)
        self.align_emg = nn.Linear(emg_dim, input_dim)

        self.eeg_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=input_dim, nhead=num_heads, batch_first=True, dropout=dropout_rate),
            num_layers=2
        )
        self.emg_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=input_dim, nhead=num_heads, batch_first=True, dropout=dropout_rate),
            num_layers=2
        )

        self.eeg_projector = nn.Linear(input_dim, hidden_dim)
        self.emg_projector = nn.Linear(input_dim, hidden_dim)
        self.dropout = nn.Dropout(p=dropout_rate)

        # Learnable modality mixing weights: row 0 scales EEG, row 1 scales EMG.
        self.cross_attention_weights = nn.Parameter(
            torch.tensor([[0.7], [0.3]], dtype=torch.float32), requires_grad=True
        )
        # Applied as self-attention over the already-fused sequence (see forward).
        self.cross_attention = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads, batch_first=True)

        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, eeg, emg):
        """Classify a batch.

        Args:
            eeg: tensor of shape (batch, seq, input_dim).
            emg: tensor of shape (batch, seq, emg_dim).

        Returns:
            (logits, fusion_weights): logits of shape (batch, num_classes) and
            the (2, 1) learnable modality-weight parameter.
        """
        eeg = self.align_eeg(eeg)
        emg = self.align_emg(emg)

        eeg_features = self.eeg_encoder(eeg)
        emg_features = self.emg_encoder(emg)

        eeg_features = self.dropout(self.eeg_projector(eeg_features))
        emg_features = self.dropout(self.emg_projector(emg_features))

        # Weighted sum of the two modality streams.
        combined = (
            self.cross_attention_weights[0] * eeg_features +
            self.cross_attention_weights[1] * emg_features
        )

        combined, _ = self.cross_attention(combined, combined, combined)
        # Mean-pool over the sequence axis before the classifier head.
        out = self.fc(combined.mean(dim=1))
        return out, self.cross_attention_weights


# ============================================================
# TRAIN
# ============================================================

def train_model(model, train_loader, val_loader, criterion, device, input_dim,
                epochs=50, patience=5, lr=5e-4, weight_decay=1e-2, use_amp=True):
    """Train `model` with Adam and early stopping on the mean validation loss.

    Every batch `Xb` is split column-wise at `input_dim` into an EEG part
    (`Xb[:, :input_dim]`) and an EMG part (`Xb[:, input_dim:]`); each part
    gets a length-1 sequence axis before being passed to `model(eeg, emg)`.

    Args:
        model: module called as `model(eeg, emg)` returning `(logits, extra)`.
        train_loader / val_loader: iterables of `(Xb, yb)` batches.
        criterion: loss called as `criterion(logits, yb)`.
        device: torch.device (or device string) the batches are moved to.
        input_dim: column index separating EEG from EMG features.
        epochs: maximum number of epochs.
        patience: number of consecutive epochs without validation-loss
            improvement after which training stops.
        lr, weight_decay: Adam hyper-parameters.
        use_amp: request mixed precision; only honoured on CUDA devices.

    Returns:
        The same `model`, with the weights of the best validation epoch loaded.
    """
    # Mixed precision as used here (CUDA autocast + GradScaler) only makes
    # sense on a CUDA device. Gate it so CPU runs neither emit "CUDA is not
    # available" warnings nor depend on CUDA-only code paths.
    amp_on = bool(use_amp) and torch.device(device).type == "cuda"
    # A disabled autocast context is a no-op; "cpu" is always a valid type.
    autocast_device = "cuda" if amp_on else "cpu"

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scaler = torch.cuda.amp.GradScaler(enabled=amp_on)

    best_loss = float("inf")
    # CPU snapshot so the best weights survive later in-place updates.
    best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
    bad = 0  # consecutive epochs without improvement

    for ep in range(epochs):
        model.train()
        tr_loss = 0.0

        for Xb, yb in train_loader:
            Xb = Xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)

            eeg = Xb[:, :input_dim].unsqueeze(1)
            emg = Xb[:, input_dim:].unsqueeze(1)

            optimizer.zero_grad(set_to_none=True)

            with torch.autocast(device_type=autocast_device, enabled=amp_on):
                logits, _ = model(eeg, emg)
                loss = criterion(logits, yb)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            tr_loss += float(loss.detach().cpu())

        # Validation pass
        model.eval()
        va_loss = 0.0
        with torch.no_grad():
            for Xb, yb in val_loader:
                Xb = Xb.to(device, non_blocking=True)
                yb = yb.to(device, non_blocking=True)

                eeg = Xb[:, :input_dim].unsqueeze(1)
                emg = Xb[:, input_dim:].unsqueeze(1)

                with torch.autocast(device_type=autocast_device, enabled=amp_on):
                    logits, _ = model(eeg, emg)
                    loss = criterion(logits, yb)

                va_loss += float(loss.detach().cpu())

        # Mean over batches; max(1, ...) guards against an empty loader.
        tr_loss /= max(1, len(train_loader))
        va_loss /= max(1, len(val_loader))
        print(f"Epoch {ep+1}/{epochs} | train={tr_loss:.4f} | val={va_loss:.4f}")

        if va_loss < best_loss:
            best_loss = va_loss
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            bad = 0
        else:
            bad += 1
            if bad >= patience:
                print("Early stopping.")
                break

    model.load_state_dict(best_state)
    return model


# ============================================================
# ONLINE ADAPT
# ============================================================

def online_adapt(model, buffer_X, buffer_y, replay_X, replay_y, criterion,
                 val_loader, device, input_dim, num_cycles=5, batch_size=128, lr=1e-5, use_amp=True,
                 weight_decay=1e-2):
    """Fine-tune `model` on an adaptation buffer (plus optional replay samples).

    The buffer and replay arrays are concatenated and iterated for up to
    `num_cycles` passes. After each pass the mean validation loss on
    `val_loader` is computed; adaptation stops at the first pass that does not
    improve on the best loss so far, and the best weights are restored. The
    pre-adaptation validation loss is the initial reference, so the returned
    model is never worse (on `val_loader`) than the one passed in.

    Args:
        model: module called as `model(eeg, emg)` returning `(logits, extra)`.
        buffer_X, buffer_y: adaptation features / labels (NumPy arrays).
        replay_X, replay_y: optional samples mixed in; may be None or empty.
        criterion: loss called as `criterion(logits, yb)`.
        val_loader: iterable of `(Xb, yb)` batches used for model selection.
        device: torch.device (or device string) the batches are moved to.
        input_dim: column index separating EEG from EMG features.
        num_cycles: maximum number of passes over the adaptation data.
        batch_size: adaptation batch size.
        lr: peak learning rate of the cyclic schedule (base is lr / 10).
        use_amp: request mixed precision; only honoured on CUDA devices.
        weight_decay: Adam weight decay (default matches the previous
            hard-coded value).

    Returns:
        The same `model` with the best-on-validation weights loaded.
    """
    on_cuda = torch.device(device).type == "cuda"
    # CUDA autocast + GradScaler are only meaningful on a CUDA device.
    amp_on = bool(use_amp) and on_cuda
    autocast_device = "cuda" if amp_on else "cpu"  # disabled context is a no-op

    if replay_X is not None and len(replay_X) > 0:
        buffer_X = np.concatenate([buffer_X, replay_X], axis=0)
        buffer_y = np.concatenate([buffer_y, replay_y], axis=0)

    # A shuffling DataLoader cannot be built over zero samples; nothing to do.
    if len(buffer_X) == 0:
        return model

    ds = EEGEMGDataset(buffer_X, buffer_y)
    loader = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS,
                        pin_memory=PIN_MEMORY and on_cuda)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    # NOTE(review): CyclicLR defaults to cycle_momentum=True, which for Adam
    # appears to also cycle beta1 (or be rejected by older PyTorch releases)
    # -- confirm this is intended; left unchanged here.
    scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=lr/10, max_lr=lr, step_size_up=5, mode="triangular2")
    scaler = torch.cuda.amp.GradScaler(enabled=amp_on)

    def _mean_val_loss():
        # Mean per-batch loss on val_loader (model left in eval mode).
        model.eval()
        total = 0.0
        with torch.no_grad():
            for Xv, yv in val_loader:
                Xv = Xv.to(device, non_blocking=True)
                yv = yv.to(device, non_blocking=True)
                eeg_v = Xv[:, :input_dim].unsqueeze(1)
                emg_v = Xv[:, input_dim:].unsqueeze(1)

                with torch.autocast(device_type=autocast_device, enabled=amp_on):
                    logits_v, _ = model(eeg_v, emg_v)
                    loss_v = criterion(logits_v, yv)

                total += float(loss_v.detach().cpu())
        return total / max(1, len(val_loader))

    # Reference point: an adaptation pass is only kept if it beats the
    # incoming model. Starting from +inf would accept the first pass
    # unconditionally, even when it makes the model worse.
    best_val = _mean_val_loss()
    if not np.isfinite(best_val):
        best_val = float("inf")
    print(f"  baseline val_loss={best_val:.4f}")
    best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    for c in range(num_cycles):
        print(f"Online cycle {c+1}/{num_cycles}")

        model.train()
        for Xb, yb in loader:
            Xb = Xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)

            eeg = Xb[:, :input_dim].unsqueeze(1)
            emg = Xb[:, input_dim:].unsqueeze(1)

            optimizer.zero_grad(set_to_none=True)
            with torch.autocast(device_type=autocast_device, enabled=amp_on):
                logits, _ = model(eeg, emg)
                loss = criterion(logits, yb)

            scaler.scale(loss).backward()
            # Gradients are still multiplied by the loss scale here; unscale
            # them first so the clipping threshold applies to true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

        val_loss = _mean_val_loss()
        print(f"  val_loss={val_loss:.4f}")

        if val_loss < best_val:
            best_val = val_loss
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            print("  early stop online adapt")
            break

    model.load_state_dict(best_state)
    return model


# ============================================================
# EVAL (NO PLOTS)
# ============================================================

def evaluate_no_plots(model, loader, num_classes, device, input_dim):
    """Run `model` over `loader`, print classification metrics and return them.

    Batches are split column-wise at `input_dim` into EEG / EMG parts exactly
    as in training. Class probabilities are the softmax of the logits.

    Args:
        model: module called as `model(eeg, emg)` returning `(logits, extra)`.
        loader: iterable of `(Xb, yb)` batches.
        num_classes: number of classes; labels are assumed to be 0..num_classes-1.
        device: device the feature batches are moved to.
        input_dim: column index separating EEG from EMG features.

    Returns:
        dict with key "accuracy" (as before) plus "precision", "recall", "f1",
        "balanced_accuracy", "kappa", "mcc", "log_loss", "top3", "mpce",
        "auroc", "auprc" (the last two are None when they cannot be computed)
        and "confusion_matrix".

    Raises:
        ValueError: if `loader` yields no batches.
    """
    model.eval()
    true_parts, pred_parts, score_parts = [], [], []

    with torch.no_grad():
        for Xb, yb in loader:
            Xb = Xb.to(device, non_blocking=True)

            eeg = Xb[:, :input_dim].unsqueeze(1)
            emg = Xb[:, input_dim:].unsqueeze(1)

            logits, _ = model(eeg, emg)
            # Softmax in float64 on the CPU so every probability row sums to 1
            # within the tight tolerance scikit-learn's log_loss checks for.
            probs = torch.softmax(logits.detach().cpu().double(), dim=1)
            pred = torch.argmax(logits, dim=1)

            true_parts.append(yb.cpu().numpy())
            pred_parts.append(pred.cpu().numpy())
            score_parts.append(probs.numpy())

    if not true_parts:
        raise ValueError("evaluate_no_plots: loader yielded no batches")

    y_true = np.concatenate(true_parts).astype(np.int64)
    y_pred = np.concatenate(pred_parts).astype(np.int64)
    y_scores = np.concatenate(score_parts, axis=0).astype(np.float64)

    labels_all = np.arange(num_classes, dtype=np.int64)
    is_binary = num_classes == 2
    # scikit-learn's binary ranking metrics expect a 1-D score for the
    # positive class rather than an (N, 2) probability matrix.
    rank_scores = y_scores[:, 1] if is_binary else y_scores

    precision = precision_score(y_true, y_pred, average="weighted", zero_division=0)
    recall = recall_score(y_true, y_pred, average="weighted", zero_division=0)
    f1 = f1_score(y_true, y_pred, average="weighted", zero_division=0)
    acc = accuracy_score(y_true, y_pred)
    bal_acc = balanced_accuracy_score(y_true, y_pred)
    kappa = cohen_kappa_score(y_true, y_pred)
    mcc = matthews_corrcoef(y_true, y_pred)
    ll = log_loss(y_true, y_scores, labels=labels_all)

    try:
        top3 = top_k_accuracy_score(y_true, rank_scores, k=min(3, num_classes), labels=labels_all)
    except ValueError as exc:
        warnings.warn(f"top-k accuracy unavailable: {exc}")
        top3 = float("nan")

    cm = confusion_matrix(y_true, y_pred, labels=labels_all)
    support = cm.sum(axis=1)
    present = support > 0
    # Mean per-class error over the classes that actually occur in y_true; a
    # class with no samples has no defined accuracy and must not count as 100%
    # error.
    if present.any():
        per_class_acc = np.diag(cm)[present] / support[present]
        mpce = float(np.mean(1.0 - per_class_acc))
    else:
        mpce = float("nan")

    auroc = None
    auprc = None
    try:
        if is_binary:
            auroc = roc_auc_score(y_true, rank_scores)
            auprc = average_precision_score(y_true, rank_scores)
        else:
            Yb = label_binarize(y_true, classes=labels_all)
            auroc = roc_auc_score(Yb, y_scores, average="weighted")
            auprc = average_precision_score(Yb, y_scores, average="weighted")
    except Exception as exc:
        # Best effort (e.g. a class missing from this split), but say why the
        # numbers are absent instead of failing silently.
        warnings.warn(f"AUROC/AUPRC unavailable: {exc}")
        auroc = None
        auprc = None

    print(f"Prec={precision:.3f} Rec={recall:.3f} F1={f1:.3f} Acc={acc:.3f} BalAcc={bal_acc:.3f}")
    print(f"Kappa={kappa:.3f} MCC={mcc:.3f} LogLoss={ll:.4f} Top3={top3:.4f} MPCE={mpce:.4f}")
    if auroc is not None:
        print(f"AUROC={auroc:.4f} AUPRC={auprc:.4f}")

    return {
        "accuracy": acc,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "balanced_accuracy": bal_acc,
        "kappa": kappa,
        "mcc": mcc,
        "log_loss": ll,
        "top3": top3,
        "mpce": mpce,
        "auroc": auroc,
        "auprc": auprc,
        "confusion_matrix": cm,
    }


# ============================================================
# MAIN
# ============================================================

if __name__ == "__main__":
    # Pipeline: load all subjects -> K-fold training -> optional online
    # adaptation per fold -> save the weights of the best-scoring fold.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    on_cuda = device.type == "cuda"
    if on_cuda:
        print(f"[INFO] CUDA: {torch.cuda.get_device_name(0)}")
        torch.backends.cudnn.benchmark = True
        torch.set_float32_matmul_precision("high")
    # Pinned host memory only helps (and only avoids a warning) with CUDA.
    pin = PIN_MEMORY and on_cuda

    # Load ALL subjects
    X, y, sid, input_dim = load_all_subjects_combined(DATA_ROOT)
    # Assumes labels are non-negative integers starting at 0.
    num_classes = int(np.max(y)) + 1

    print(f"[INFO] X={X.shape} y={y.shape} input_dim={input_dim} classes={num_classes}")
    print(f"[INFO] label counts: {np.bincount(y)}")
    print(f"[SANITY] mean(feature variance)={np.var(X, axis=0).mean():.6f}")

    # KFold over pooled samples (not LOSO): samples of one subject may land in
    # both the training and the validation split.
    kf = KFold(n_splits=K_FOLDS, shuffle=True, random_state=SEED)

    best_acc = -1.0
    best_state = None
    best_fold = None

    for fold, (tr_idx, va_idx) in enumerate(kf.split(X), start=1):
        print(f"\n===== Fold {fold}/{K_FOLDS} =====")
        X_tr, X_va = X[tr_idx], X[va_idx]
        y_tr, y_va = y[tr_idx], y[va_idx]

        train_ds = EEGEMGDataset(X_tr, y_tr)
        val_ds = EEGEMGDataset(X_va, y_va)

        train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=pin)
        val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=pin)

        model = EEGEMGTransformer(input_dim=input_dim, hidden_dim=256, num_heads=4, num_classes=num_classes).to(device)
        criterion = nn.CrossEntropyLoss()

        model = train_model(model, train_loader, val_loader, criterion, device, input_dim, EPOCHS, PATIENCE, LR, WEIGHT_DECAY, USE_AMP)

        # The first `online_n` validation samples are consumed as adaptation
        # data. Scoring the adapted model on them would leak labels into the
        # reported metrics and into best-fold selection, so whenever adaptation
        # runs, both the "before" and "after" numbers are computed on the
        # remaining held-out slice only.
        online_n = int(len(X_va) * ONLINE_ADAPT_FRAC) if DO_ONLINE_ADAPT else 0
        can_adapt = DO_ONLINE_ADAPT and 0 < online_n < len(X_va)

        if can_adapt:
            holdout_ds = EEGEMGDataset(X_va[online_n:], y_va[online_n:])
            eval_loader = DataLoader(holdout_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=pin)
        else:
            eval_loader = val_loader

        print("Before online adapt:")
        met = evaluate_no_plots(model, eval_loader, num_classes, device, input_dim)

        if can_adapt:
            replay_n = min(REPLAY_N, len(X_tr))
            model = online_adapt(
                model,
                buffer_X=X_va[:online_n], buffer_y=y_va[:online_n],
                replay_X=X_tr[:replay_n], replay_y=y_tr[:replay_n],
                criterion=criterion,
                val_loader=eval_loader,
                device=device,
                input_dim=input_dim,
                num_cycles=ONLINE_CYCLES,
                batch_size=ONLINE_BS,
                lr=ONLINE_LR,
                use_amp=USE_AMP
            )

            print("After online adapt:")
            met = evaluate_no_plots(model, eval_loader, num_classes, device, input_dim)

        if met["accuracy"] > best_acc:
            best_acc = met["accuracy"]
            best_fold = fold
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    if best_state is None:
        raise RuntimeError("No fold produced a model; nothing to save.")

    out_path = OUT_DIR / "EEGEMGTransformer_best_allsubjects.pth"
    torch.save(
        {
            "state_dict": best_state,
            "input_dim": input_dim,
            "num_classes": num_classes,
            "best_fold": best_fold,
            "label_set": np.unique(y).tolist(),
            "seed": SEED,
        },
        out_path
    )
    print(f"\n[OK] Saved: {out_path} | best_fold={best_fold} | acc={best_acc:.4f}")
