In [None]:
# ============================================================
# DATASET-1 (BMIS) — ALL SUBJECTS (sub1..sub33) COMBINED TRAINING

#
# Expected per-subject folder layout:
#   /home/tsultan1/paper-2/dataset-1/final_exports-sub{K}/
#       eeg_sub{K}.csv
#       emg_sub{K}.csv
#       labels_sub{K}.csv   (columns: subject_id, Label)
# ============================================================

import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import KFold
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import (
    precision_score, recall_score, f1_score, accuracy_score,
    confusion_matrix, balanced_accuracy_score,
    cohen_kappa_score, matthews_corrcoef, log_loss,
    roc_auc_score, average_precision_score
)
from scipy.signal import correlate, resample
from sklearn.metrics import hamming_loss, top_k_accuracy_score

# -----------------------------
# Compute device: CUDA when available, otherwise CPU
# -----------------------------
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"[DEVICE] {device}")

# -----------------------------
# Dataset root and highest subject index (subjects are numbered 1..SUBJECT_MAX)
# -----------------------------
SUBJECT_MAX = 33
ROOT = "/home/tsultan1/paper-2/dataset-1"

# -----------------------------
# Robust CSV read (numeric only)
# -----------------------------
def read_numeric_csv(path: str) -> np.ndarray:
    """Read a CSV file into a float32 matrix.

    The header row is used as column names and is not part of the result.
    Every column is parsed with ``pd.to_numeric(errors="coerce")``; cells that
    cannot be parsed, and empty cells, become 0.0.
    """
    raw = pd.read_csv(path)
    as_numbers = raw.apply(pd.to_numeric, errors="coerce")
    filled = as_numbers.fillna(0.0)
    return filled.to_numpy(dtype=np.float32, copy=False)

def read_labels_csv(path: str) -> pd.DataFrame:
    """Read a labels CSV and normalise its ``Label`` / ``subject_id`` columns to ints.

    ``Label`` is required. Non-numeric or empty labels become 0.
    ``subject_id`` is optional: unparsable values become -1, and when the
    column is absent it is created and filled with -1.

    Raises:
        ValueError: if the file has no ``Label`` column.
    """
    frame = pd.read_csv(path)
    if "Label" not in frame.columns:
        raise ValueError(f"'Label' column not found in {path}. Columns={list(frame.columns)}")

    def _as_int(series, fill):
        # Coerce to numbers, replace failures with `fill`, then cast to int.
        return pd.to_numeric(series, errors="coerce").fillna(fill).astype(int)

    frame["Label"] = _as_int(frame["Label"], 0)
    frame["subject_id"] = _as_int(frame["subject_id"], -1) if "subject_id" in frame.columns else -1
    return frame

# -----------------------------
# Sliding Window Cross-Correlation (same logic)
# -----------------------------
def sliding_window_cross_correlation(eeg, emg, window_size=100, overlap=50):
    """Align EEG and EMG window by window at the lag of peak cross-correlation.

    Windows of ``window_size`` rows start every ``window_size - overlap`` rows.
    One extra window is anchored at the very end so that every row of the
    common length is covered. Later windows overwrite the overlapping rows
    written by earlier ones.

    Args:
        eeg, emg: arrays indexed by row along axis 0 (2-D ``(rows, channels)``
            in this script).
        window_size: rows per window.
        overlap: rows shared by consecutive windows; must be < ``window_size``.

    Returns:
        ``(eeg_aligned, emg_aligned, avg_sync_score)``:
        - aligned copies with the input shapes;
        - the mean per-window peak normalised correlation (flat windows
          contribute 0.0).
        If either input is shorter than ``window_size``, the inputs are
        returned unchanged with a score of 0.0.

    Raises:
        ValueError: if ``overlap >= window_size`` (the window would not advance).
    """
    n = min(len(eeg), len(emg))
    if n < window_size:
        return eeg, emg, 0.0

    step = window_size - overlap
    if step <= 0:
        raise ValueError(f"overlap ({overlap}) must be smaller than window_size ({window_size})")

    # Start from copies (not zeros) so no row can be silently zeroed out.
    eeg_aligned = np.array(eeg, copy=True)
    emg_aligned = np.array(emg, copy=True)

    # Include the last full window; range() alone would stop one step short
    # and leave the tail unprocessed.
    last_start = n - window_size
    starts = list(range(0, last_start + 1, step))
    if starts[-1] != last_start:
        starts.append(last_start)

    sync_scores = []
    for i in starts:
        rows = slice(i, i + window_size)
        eeg_seg = eeg[rows]
        emg_seg = emg[rows]
        eeg_window = eeg_seg.flatten()
        emg_window = emg_seg.flatten()

        eeg_std = np.std(eeg_window)
        emg_std = np.std(emg_window)
        if eeg_std < 1e-12 or emg_std < 1e-12:
            # Flat window: correlation is undefined, so copy it through unchanged.
            eeg_aligned[rows] = eeg_seg
            emg_aligned[rows] = emg_seg
            sync_scores.append(0.0)
            continue

        eeg_window = (eeg_window - np.mean(eeg_window)) / eeg_std
        emg_window = (emg_window - np.mean(emg_window)) / emg_std

        correlation = correlate(eeg_window, emg_window, mode="full")
        # Lags for scipy's 'full' correlate(in1, in2): -(len(in2) - 1) .. len(in1) - 1.
        lags = np.arange(-len(emg_window) + 1, len(eeg_window))

        denom = np.linalg.norm(eeg_window) * np.linalg.norm(emg_window)
        if denom < 1e-12:
            correlation = np.zeros_like(correlation)
        else:
            correlation = correlation / denom

        peak = int(np.argmax(correlation))
        lag = lags[peak]
        sync_scores.append(float(correlation[peak]))

        # NOTE(review): `lag` is measured on the flattened (rows x channels)
        # window, i.e. in element units. Below it is applied as a row shift
        # along axis 0 (np.roll wraps within the window). For multi-channel
        # input these units differ; confirm this is intended.
        if lag > 0:
            eeg_aligned[rows] = np.roll(eeg_seg, lag, axis=0)
            emg_aligned[rows] = emg_seg
        else:
            eeg_aligned[rows] = eeg_seg
            emg_aligned[rows] = np.roll(emg_seg, -lag, axis=0)

    return eeg_aligned, emg_aligned, float(np.mean(sync_scores))

# -----------------------------
# Dataset
# -----------------------------
class EEGEMGDataset(Dataset):
    """In-memory map-style dataset of (feature row, class label) pairs.

    Features are stored as a float32 tensor and labels as an int64 (long)
    tensor, so batches can go straight into ``nn.CrossEntropyLoss``.
    """

    def __init__(self, X, y):
        self.X, self.y = (
            torch.tensor(X, dtype=torch.float32),
            torch.tensor(y, dtype=torch.long),
        )

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        features = self.X[idx]
        target = self.y[idx]
        return features, target

# -----------------------------
# MODEL (UNCHANGED)
# -----------------------------
class EEGEMGTransformer(nn.Module):
    """Two-branch Transformer classifier that fuses EEG and EMG features.

    Pipeline for each modality:
    - a linear "align" layer;
    - a 2-layer batch-first ``nn.TransformerEncoder``;
    - a linear projection to ``hidden_dim``, followed by dropout.

    The two branches are then mixed with a learnable (2, 1) weight parameter,
    initialised to 0.7 (EEG) / 0.3 (EMG). The fused sequence passes through
    one ``nn.MultiheadAttention`` layer. Despite the ``cross_attention``
    name, query = key = value, i.e. this is self-attention. The result is
    mean-pooled over dim 1 and fed to a linear classifier.

    Args:
        input_dim: feature width of EACH modality. It is ``d_model`` for both
            encoders, so ``nn.TransformerEncoderLayer`` requires it to be
            divisible by ``num_heads``.
        hidden_dim: width after projection and ``embed_dim`` of the fusion
            attention; must also be divisible by ``num_heads``.
        num_heads: attention heads for the encoder layers and the fusion attention.
        num_classes: number of output logits.
        dropout_rate: dropout inside the encoder layers and after the projections.

    NOTE: attribute names define the ``state_dict`` keys used when saving and
    loading weights, so they must not be renamed.
    """

    def __init__(self, input_dim, hidden_dim, num_heads, num_classes, dropout_rate=0.5):
        super(EEGEMGTransformer, self).__init__()
        # Per-modality linear maps applied before the encoders (width unchanged).
        self.align_eeg = nn.Linear(input_dim, input_dim)
        self.align_emg = nn.Linear(input_dim, input_dim)

        # Independent 2-layer, batch-first encoders (d_model = input_dim) per modality.
        self.eeg_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=input_dim, nhead=num_heads, batch_first=True, dropout=dropout_rate),
            num_layers=2
        )
        self.emg_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=input_dim, nhead=num_heads, batch_first=True, dropout=dropout_rate),
            num_layers=2
        )

        # Project both branches into a shared hidden_dim space so they can be summed.
        self.eeg_projector = nn.Linear(input_dim, hidden_dim)
        self.emg_projector = nn.Linear(input_dim, hidden_dim)
        self.dropout = nn.Dropout(p=dropout_rate)

        # Learnable fusion weights: row 0 scales EEG, row 1 scales EMG (not normalised).
        self.cross_attention_weights = nn.Parameter(torch.tensor([[0.7], [0.3]]), requires_grad=True)
        self.cross_attention = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads, batch_first=True)

        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, eeg, emg):
        """Return ``(logits, fusion_weights)``.

        Args:
            eeg, emg: float tensors shaped (batch, seq, input_dim), batch-first
                per the encoder config. Both need the same batch and seq sizes
                because their projections are summed element-wise. Callers in
                this file pass seq == 1 (``unsqueeze(1)``).

        Returns:
            Logits of shape (batch, num_classes), and the raw (2, 1)
            ``cross_attention_weights`` parameter itself.
        """
        eeg = self.align_eeg(eeg)
        emg = self.align_emg(emg)

        eeg_features = self.eeg_encoder(eeg)
        emg_features = self.emg_encoder(emg)

        eeg_features = self.eeg_projector(eeg_features)
        emg_features = self.emg_projector(emg_features)

        eeg_features = self.dropout(eeg_features)
        emg_features = self.dropout(emg_features)

        # Weighted sum; each weight row has shape (1,) and broadcasts over all features.
        combined_features = (
            self.cross_attention_weights[0] * eeg_features + self.cross_attention_weights[1] * emg_features
        )
        # Self-attention over the fused sequence (query = key = value).
        combined, _ = self.cross_attention(combined_features, combined_features, combined_features)
        # Mean-pool over the sequence dimension before classification.
        output = self.fc(combined.mean(dim=1))
        return output, self.cross_attention_weights

# -----------------------------
# Online adaptation (same behavior) + GPU
# -----------------------------
def online_adaptation_with_regularization(
    model, optimizer, buffer_X, buffer_y, criterion, val_loader, num_cycles=2, batch_size=8, lr=0.00001,
    replay_X=None, replay_y=None, replay_size=100,
):
    """Fine-tune ``model`` on ``buffer_X``/``buffer_y`` mixed with replayed training rows.

    Procedure:
    - Append the first ``replay_size`` rows of the replay set to the buffer.
    - Train for up to ``num_cycles`` passes over the buffer with Adam
      (weight decay 0.01), a triangular2 ``CyclicLR`` between ``lr / 10``
      and ``lr``, and gradients clipped to norm 1.0.
    - After each pass, compare the mean ``val_loader`` loss with the best so
      far. The baseline is the loss of the model BEFORE adaptation, so an
      adaptation that only hurts is rolled back.
    - Stop at the first pass that does not improve. The best weights seen are
      loaded back into ``model`` in place.

    Args:
        model: called as ``model(eeg, emg)`` returning ``(logits, _)``. Each
            feature row is split at its midpoint into EEG | EMG halves, matching
            the ``input_dim = n_features // 2`` used to build the model.
        optimizer: ignored; a fresh Adam optimizer is created here. The
            parameter is kept so existing calls keep working.
        buffer_X, buffer_y: new samples (feature rows, integer class labels).
        criterion: loss called as ``criterion(logits, labels)``.
        val_loader: loader used for the keep/stop decision.
        num_cycles: maximum number of passes over the buffer.
        batch_size: batch size for the buffer loader.
        lr: peak learning rate of the cyclic schedule.
        replay_X, replay_y: replay source. When either is ``None``, the
            notebook-level ``X_train`` / ``y_train`` are used (previous behavior).
        replay_size: number of leading replay rows appended to the buffer.
    """
    if replay_X is None or replay_y is None:
        # Backward-compatible fallback: the fold loop defines these at module scope.
        replay_X, replay_y = X_train, y_train
    buffer_X = np.concatenate([buffer_X, replay_X[:replay_size]], axis=0)
    buffer_y = np.concatenate([buffer_y, replay_y[:replay_size]], axis=0)

    buffer_loader = DataLoader(EEGEMGDataset(buffer_X, buffer_y), batch_size=batch_size, shuffle=True)

    # The `optimizer` argument is intentionally replaced by a fresh one.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.01)
    # cycle_momentum=False: Adam has no `momentum` default. Depending on the
    # PyTorch version, CyclicLR either rejects it or silently cycles Adam's
    # beta1 between 0.8 and 0.9. Only the learning rate should cycle here.
    scheduler = torch.optim.lr_scheduler.CyclicLR(
        optimizer, base_lr=lr / 10, max_lr=lr, step_size_up=5, mode="triangular2",
        cycle_momentum=False,
    )

    def _split(X_batch):
        # Feature rows are [EEG | EMG]; add a length-1 sequence axis to each half.
        half = X_batch.shape[1] // 2
        return X_batch[:, :half].unsqueeze(1), X_batch[:, half:].unsqueeze(1)

    def _mean_val_loss():
        model.eval()
        total = 0.0
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch = X_batch.to(device)
                y_batch = y_batch.to(device)
                eeg, emg = _split(X_batch)
                outputs, _ = model(eeg, emg)
                total += criterion(outputs, y_batch).item()
        return total / max(len(val_loader), 1)

    # Baseline = unadapted model, so its snapshot can actually win.
    best_val_loss = _mean_val_loss()
    best_model_state = {name: t.detach().clone() for name, t in model.state_dict().items()}

    for _cycle in range(num_cycles):
        model.train()
        for X_batch, y_batch in buffer_loader:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)
            eeg, emg = _split(X_batch)

            optimizer.zero_grad()
            outputs, _ = model(eeg, emg)
            loss = criterion(outputs, y_batch)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

        val_loss = _mean_val_loss()
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_state = {name: t.detach().clone() for name, t in model.state_dict().items()}
        else:
            break

    model.load_state_dict(best_model_state)

# -----------------------------
# Metrics (NO plots)
# -----------------------------
def calculate_hamming_loss(y_true, y_pred):
    """Return sklearn's Hamming loss (fraction of mismatched labels) for the predictions."""
    loss = hamming_loss(y_true, y_pred)
    return loss

def calculate_top_k_accuracy(y_true, y_scores, k=3, labels=None):
    """Return the top-``k`` accuracy of class-probability scores.

    ``k`` is capped at the number of score columns.

    ``labels`` defaults to ``0..n_classes-1``, i.e. ``y_true`` is assumed to
    hold column indices of ``y_scores`` (as produced by the LabelEncoder in
    this script). Passing the full label set lets sklearn score folds in which
    some class never occurs in ``y_true``.

    With exactly two columns, sklearn's binary path requires a 1-D
    positive-class score, so column 1 is passed. With ``k >= 2`` that score
    is trivially 1.0.

    Raises:
        ValueError: propagated from sklearn when the score is undefined.
    """
    y_scores = np.asarray(y_scores)
    n_classes = y_scores.shape[1]
    k = min(k, n_classes)
    if labels is None:
        labels = np.arange(n_classes)
    if n_classes == 2:
        return top_k_accuracy_score(y_true, y_scores[:, 1], k=k, labels=labels)
    return top_k_accuracy_score(y_true, y_scores, k=k, labels=labels)

def _auroc_auprc(y_true, y_scores, n_classes):
    """Return ``(auroc, auprc)`` for label-encoded targets; ``None`` marks an undefined metric.

    ``y_true`` holds class indices into the columns of ``y_scores``. The two
    metrics are computed independently, so one being undefined (e.g. a class
    absent from this fold) does not discard the other.
    """
    try:
        if n_classes == 2:
            # Binary: sklearn expects the positive-class score as a 1-D array.
            auroc = roc_auc_score(y_true, y_scores[:, 1])
        else:
            auroc = roc_auc_score(
                y_true, y_scores, multi_class="ovr", average="weighted",
                labels=np.arange(n_classes),
            )
    except (ValueError, IndexError):
        auroc = None

    try:
        if n_classes == 2:
            auprc = average_precision_score(y_true, y_scores[:, 1])
        else:
            # One-hot over ALL classes keeps y_true aligned with y_scores' columns
            # even when some class does not occur in this fold.
            one_hot = np.eye(n_classes, dtype=int)[y_true]
            auprc = average_precision_score(one_hot, y_scores, average="weighted")
    except (ValueError, IndexError):
        auprc = None

    return auroc, auprc


def evaluate_model(model, val_loader, sync_score_global):
    """Run ``model`` over ``val_loader`` and return a dict of classification metrics.

    Each feature row is split at its midpoint into EEG / EMG halves, matching
    the ``input_dim = n_features // 2`` used to build the model. A short
    summary is printed.

    Args:
        model: module called as ``model(eeg, emg)`` returning ``(logits, weights)``.
        val_loader: yields ``(features, label)`` batches with label-encoded targets.
        sync_score_global: copied unchanged into the result under ``"sync_score"``.

    Returns:
        A dict with these entries:
        - precision, recall and F1 (all weighted);
        - accuracy and balanced accuracy;
        - Cohen's kappa and MCC;
        - log loss, AUROC and AUPRC;
        - MPCE and Hamming loss;
        - top-3 accuracy;
        - the sync score;
        - the mean fusion weights;
        - the confusion matrix (rows/cols are classes ``0..n_classes-1``).
        Log loss, AUROC, AUPRC and top-k may be ``None`` when undefined.
    """
    model.eval()
    y_true, y_pred, y_scores = [], [], []
    attention_weights = []

    with torch.no_grad():
        for X_batch, y_batch in val_loader:
            X_batch = X_batch.to(device)
            half = X_batch.shape[1] // 2
            eeg = X_batch[:, :half].unsqueeze(1)
            emg = X_batch[:, half:].unsqueeze(1)

            outputs, weights = model(eeg, emg)
            probs = torch.softmax(outputs, dim=1)
            _, predicted = torch.max(outputs, 1)

            y_true.extend(y_batch.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())
            y_scores.extend(probs.cpu().numpy())
            attention_weights.append(weights.detach().cpu().numpy())

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    y_scores = np.array(y_scores)
    attention_weights = np.array(attention_weights)

    # One score column per model output; classes are indexed 0..n_classes-1.
    n_classes = y_scores.shape[1] if y_scores.ndim == 2 else 0
    class_labels = np.arange(n_classes) if n_classes else None

    precision = precision_score(y_true, y_pred, average="weighted", zero_division=0)
    recall = recall_score(y_true, y_pred, average="weighted", zero_division=0)
    f1 = f1_score(y_true, y_pred, average="weighted", zero_division=0)
    acc = accuracy_score(y_true, y_pred)
    balanced_acc = balanced_accuracy_score(y_true, y_pred)
    kappa = cohen_kappa_score(y_true, y_pred)
    mcc = matthews_corrcoef(y_true, y_pred) if len(np.unique(y_true)) > 1 else 0.0

    try:
        log_loss_value = log_loss(y_true, y_scores, labels=class_labels)
    except (ValueError, IndexError):
        log_loss_value = None

    hamming_loss_value = calculate_hamming_loss(y_true, y_pred)

    try:
        top_k_accuracy_value = calculate_top_k_accuracy(y_true, y_scores, k=3)
    except (ValueError, IndexError):
        top_k_accuracy_value = None

    auroc, auprc = _auroc_auprc(y_true, y_scores, n_classes)

    # Fixed label set => same matrix shape in every fold, even if a class is missing.
    cm = confusion_matrix(y_true, y_pred, labels=class_labels)

    with np.errstate(divide="ignore", invalid="ignore"):
        per_class_acc = np.diag(cm) / cm.sum(axis=1)
        per_class_error = 1 - per_class_acc
        # Classes with no true samples give NaN and are excluded from the mean.
        per_class_error = per_class_error[np.isfinite(per_class_error)]
        mpce = float(np.mean(per_class_error)) if len(per_class_error) else 0.0

    avg_attention_weights = attention_weights.mean(axis=0) if len(attention_weights) else None

    print(f"Precision: {precision:.4f} | Recall: {recall:.4f} | F1: {f1:.4f} | Acc: {acc:.4f}")
    print(f"Balanced Acc: {balanced_acc:.4f} | Kappa: {kappa:.4f} | MCC: {mcc:.4f}")
    if log_loss_value is not None:
        print(f"LogLoss: {log_loss_value:.6f}")
    if auroc is not None and auprc is not None:
        print(f"AUROC: {auroc:.4f} | AUPRC: {auprc:.4f}")
    elif auroc is not None:
        print(f"AUROC: {auroc:.4f}")
    elif auprc is not None:
        print(f"AUPRC: {auprc:.4f}")
    print(f"MPCE: {mpce:.6f} | Hamming: {hamming_loss_value:.6f}")

    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "accuracy": acc,
        "balanced_accuracy": balanced_acc,
        "cohen_kappa": kappa,
        "mcc": mcc,
        "log_loss": log_loss_value,
        "auroc": auroc,
        "auprc": auprc,
        "mpce": mpce,
        "hamming_loss": hamming_loss_value,
        "top_k_accuracy": top_k_accuracy_value,
        "sync_score": sync_score_global,
        "attention_weights": avg_attention_weights,
        "confusion_matrix": cm,
    }

# -----------------------------
# Training (KEEP your current settings here)
# -----------------------------
def train_model_with_weight_decay(model, train_loader, val_loader, criterion, epochs=30, patience=5, lr=0.00005):
    """Train ``model`` with Adam + weight decay, early-stopping on validation loss.

    After each epoch the mean ``val_loader`` loss is computed. Training stops
    after ``patience`` consecutive epochs without improvement. The weights
    from the best epoch (the initial weights if none improved) are loaded
    back into ``model`` in place.

    Each feature row is split at its midpoint into EEG / EMG halves, matching
    the ``input_dim = n_features // 2`` used to build the model. The split is
    derived from the batch itself rather than from a global ``input_dim``.

    Args:
        model: module called as ``model(eeg, emg)`` returning ``(logits, _)``.
        train_loader, val_loader: yield ``(features, label)`` batches.
        criterion: loss called as ``criterion(logits, labels)``.
        epochs: maximum number of epochs.
        patience: epochs without validation improvement before stopping.
        lr: Adam learning rate (weight decay fixed at 0.01).
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.01)

    def _split(X_batch):
        # Feature rows are [EEG | EMG]; add a length-1 sequence axis to each half.
        half = X_batch.shape[1] // 2
        return X_batch[:, :half].unsqueeze(1), X_batch[:, half:].unsqueeze(1)

    best_loss = float("inf")
    patience_counter = 0
    best_model_state = {name: t.detach().clone() for name, t in model.state_dict().items()}

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0

        for X_batch, y_batch in train_loader:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)
            eeg, emg = _split(X_batch)

            optimizer.zero_grad()
            outputs, _ = model(eeg, emg)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        val_loss = 0.0
        model.eval()
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch = X_batch.to(device)
                y_batch = y_batch.to(device)
                eeg, emg = _split(X_batch)
                outputs, _ = model(eeg, emg)
                val_loss += criterion(outputs, y_batch).item()

        val_loss /= max(len(val_loader), 1)
        train_loss = total_loss / max(len(train_loader), 1)

        print(f"Epoch {epoch+1}/{epochs} | train_loss={train_loss:.6f} | val_loss={val_loss:.6f}")

        if val_loss < best_loss:
            best_loss = val_loss
            patience_counter = 0
            best_model_state = {name: t.detach().clone() for name, t in model.state_dict().items()}
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break

    model.load_state_dict(best_model_state)

# ============================================================
# LOAD + ALIGN ALL SUBJECTS (SKIP MISSING)
# ============================================================
combined_list = []   # per-subject [EEG | EMG] feature matrices after alignment
labels_list = []     # per-subject "Label" values
subid_list = []      # per-subject "subject_id" values (-1 when the column is absent)
sync_scores = []     # (sync score, row count) per subject, for a row-weighted mean

# Feature widths are fixed by the first subject that reaches the check; later subjects must match.
expected_eeg_dim = None
expected_emg_dim = None

loaded_subjects = []
skipped_subjects = []

for sid in range(1, SUBJECT_MAX + 1):
    sub_dir = os.path.join(ROOT, f"final_exports-sub{sid}")
    eeg_path = os.path.join(sub_dir, f"eeg_sub{sid}.csv")
    emg_path = os.path.join(sub_dir, f"emg_sub{sid}.csv")
    lab_path = os.path.join(sub_dir, f"labels_sub{sid}.csv")

    # Missing folder or any missing file: skip (reported in the summary).
    if not (os.path.isdir(sub_dir) and os.path.exists(eeg_path) and os.path.exists(emg_path) and os.path.exists(lab_path)):
        skipped_subjects.append(sid)
        continue

    try:
        eeg = read_numeric_csv(eeg_path)
        emg = read_numeric_csv(emg_path)
        lab_df = read_labels_csv(lab_path)

        # Truncate EEG, EMG and labels to the shortest source so rows stay paired.
        # Afterwards EEG and EMG always have equal length, so no resampling step is needed.
        n = min(len(eeg), len(emg), len(lab_df))
        if n == 0:
            print(f"[SKIP sub{sid}] no rows after length matching")
            skipped_subjects.append(sid)
            continue
        eeg = eeg[:n]
        emg = emg[:n]
        lab_df = lab_df.iloc[:n].reset_index(drop=True)

        # Check consistent feature dims across subjects (avoid dimension errors)
        if expected_eeg_dim is None:
            expected_eeg_dim = eeg.shape[1]
            expected_emg_dim = emg.shape[1]
        elif eeg.shape[1] != expected_eeg_dim or emg.shape[1] != expected_emg_dim:
            print(f"[SKIP sub{sid}] dim mismatch: eeg {eeg.shape[1]} vs {expected_eeg_dim} OR emg {emg.shape[1]} vs {expected_emg_dim}")
            skipped_subjects.append(sid)
            continue

        # Align per subject (important: do NOT align across subject boundaries)
        eeg_aligned, emg_aligned, ss = sliding_window_cross_correlation(eeg, emg, window_size=100, overlap=50)
        sync_scores.append((ss, len(eeg_aligned)))

        combined = np.concatenate([eeg_aligned, emg_aligned], axis=1)  # (n, eeg+emg)
        y = lab_df["Label"].to_numpy(dtype=int)
        subj_ids = lab_df["subject_id"].to_numpy(dtype=int)

        combined_list.append(combined)
        labels_list.append(y)
        subid_list.append(subj_ids)

        loaded_subjects.append(sid)
        print(f"[LOAD OK] sub{sid}: N={len(combined)} | eeg_dim={eeg.shape[1]} | emg_dim={emg.shape[1]} | sync={ss:.4f}")

    except Exception as e:
        # Deliberate best-effort: one unreadable subject must not abort the whole run.
        print(f"[SKIP sub{sid}] error: {e}")
        skipped_subjects.append(sid)

if not combined_list:
    raise RuntimeError("No subjects loaded. Check folder names and file paths.")

# The training code splits each row at n_features // 2 into EEG and EMG halves.
# That split is only correct when both modalities have the same width.
if expected_eeg_dim != expected_emg_dim:
    raise RuntimeError(
        f"EEG and EMG feature widths differ (eeg={expected_eeg_dim}, emg={expected_emg_dim}); "
        "the model splits rows at n_features // 2 and requires equal widths."
    )

combined_data = np.concatenate(combined_list, axis=0)
labels_raw = np.concatenate(labels_list, axis=0)
subject_ids_all = np.concatenate(subid_list, axis=0)

# Map raw labels to contiguous class indices 0..n_classes-1.
label_encoder = LabelEncoder()
labels = label_encoder.fit_transform(labels_raw.ravel())

# Row-weighted average sync score across subjects (0.0 if nothing to weight).
total_sync_rows = sum(rows for _, rows in sync_scores)
if total_sync_rows > 0:
    sync_score_global = float(sum(score * rows for score, rows in sync_scores) / total_sync_rows)
else:
    sync_score_global = 0.0

label_values, label_counts = np.unique(labels, return_counts=True)

print("\n========== DATA SUMMARY ==========")
print(f"Loaded subjects: {len(loaded_subjects)} -> {loaded_subjects}")
print(f"Skipped subjects: {len(skipped_subjects)} -> {skipped_subjects}")
print(f"Total pooled rows: {combined_data.shape[0]}")
print(f"Total feature dim: {combined_data.shape[1]}")
print(f"Global sync score (weighted): {sync_score_global:.6f}")
print("Label distribution (encoded):", {int(v): int(c) for v, c in zip(label_values, label_counts)})

# ============================================================
# K-FOLD TRAINING ON POOLED DATA (NOT LOSO)
# ============================================================
N = combined_data.shape[0]
if N < 2:
    raise RuntimeError(f"Not enough samples to train. Found N={N} pooled rows.")

# N >= 2 here, so k >= 2 is guaranteed.
k = min(5, N)

kf = KFold(n_splits=k, shuffle=True, random_state=42)

fold_results = []
before_adaptation_metrics = []
after_adaptation_metrics = []
fold_states = []  # CPU copy of each fold's final weights, so the best fold can be saved
online_adaptation_percentage = 0.3
pin_memory = device.type == "cuda"

for fold, (train_index, val_index) in enumerate(kf.split(combined_data), start=1):
    print(f"\n========== Fold {fold}/{k} ==========")

    X_train = combined_data[train_index]
    X_val = combined_data[val_index]
    y_train = labels[train_index]
    y_val = labels[val_index]

    train_dataset = EEGEMGDataset(X_train, y_train)
    val_dataset = EEGEMGDataset(X_val, y_val)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, pin_memory=pin_memory)

    # Rows are [EEG | EMG]; each half is one modality.
    input_dim = X_train.shape[1] // 2
    hidden_dim = 256
    num_heads = 4
    num_classes = len(np.unique(labels))

    model = EEGEMGTransformer(
        input_dim=input_dim, hidden_dim=hidden_dim, num_heads=num_heads, num_classes=num_classes
    ).to(device)

    criterion = nn.CrossEntropyLoss()

    # Train
    train_model_with_weight_decay(model, train_loader, val_loader, criterion)

    # The first `online_data_size` validation rows are used for adaptation.
    online_data_size = int(len(X_val) * online_adaptation_percentage)
    online_X = X_val[:online_data_size]
    online_y = y_val[:online_data_size]

    # Score before/after on the rows NOT used for adaptation. Otherwise the
    # "after" metrics are measured on samples the model was just fine-tuned on.
    if 0 < online_data_size < len(X_val):
        eval_dataset = EEGEMGDataset(X_val[online_data_size:], y_val[online_data_size:])
        eval_loader = DataLoader(eval_dataset, batch_size=32, shuffle=False, pin_memory=pin_memory)
    else:
        eval_loader = val_loader

    print("[EVAL] Before Online Adaptation")
    metrics_before = evaluate_model(model, eval_loader, sync_score_global)
    before_adaptation_metrics.append(metrics_before)

    if online_data_size > 0:
        optimizer = torch.optim.Adam(model.parameters(), lr=0.000001, weight_decay=0.01)
        online_adaptation_with_regularization(
            model, optimizer, online_X, online_y, criterion, val_loader,
            num_cycles=5, batch_size=16, lr=0.00001
        )

    print("[EVAL] After Online Adaptation")
    metrics_after = evaluate_model(model, eval_loader, sync_score_global)
    after_adaptation_metrics.append(metrics_after)

    fold_results.append({"before_adaptation": metrics_before, "after_adaptation": metrics_after})
    fold_states.append({name: t.detach().cpu().clone() for name, t in model.state_dict().items()})

# Save the weights of the fold with the best AFTER-adaptation accuracy.
# Previously the last fold's `model` was saved regardless of which fold won.
best_model_index = int(np.argmax([r["after_adaptation"]["accuracy"] for r in fold_results]))
torch.save(fold_states[best_model_index], "EEGEMGTransformer_best.pth")
print(f"\n[OK] Best model (by AFTER accuracy): Fold {best_model_index+1} saved as EEGEMGTransformer_best.pth")
