In [None]:
# -*- coding: utf-8 -*-
"""SE_transformer_best.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1o8XLVIFzTrY-_yrcbAWGiVU3Vzm9eTu-
"""

# ================================================================
# Early-MI classification with SE-enhanced Transformer (mix T+E, 8-fold CV)
# ================================================================

import os, glob
import random
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import KFold
from scipy.io import loadmat


# --------------------------------------------------
# 0) Reproducibility
# --------------------------------------------------
def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to every generator (default 42).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # cuDNN: deterministic kernels only, and no autotuning (which can pick
    # different algorithms from run to run).
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

set_seed(42)

# --------------------------------------------------
# 1) Mount Drive  (Colab only; skipped automatically elsewhere)
# --------------------------------------------------
# google.colab only exists inside Colab. Fall back gracefully so the script
# also runs locally, instead of dying with ImportError.
try:
    from google.colab import drive
except ImportError:
    drive = None
if drive is not None:
    drive.mount('/content/drive', force_remount=True)

# --------------------------------------------------
# 2) Dataset  (same code, but we’ll instantiate twice: full vs early)
# --------------------------------------------------
# Set EEG_DATA_DIR to point at a local copy of the data; the default is the
# Colab Drive location.
data_dir   = os.environ.get('EEG_DATA_DIR', '/content/drive/My Drive/conferencedata')
subject_id = 2
tasks      = ['left','right']           # only these two MI classes

class EEGGraphDataset(Dataset):
    """
    Per-trial sequences of connectivity-graph feature vectors for one subject.

    Each item is ``(seq, label)``:

    * ``seq`` is a float32 array of shape ``(T_i, 231)``. Each row is the
      strict upper triangle (``k=1``) of one segment's 22x22 connectivity
      matrix.
    * ``label`` is the index of the trial's task in the class list.

    Files are read from ``data_dir`` for both the training ("T") and
    evaluation ("E") sessions, matching
    ``connectivity_graphs_A{subject_id:02d}{T|E}_*.mat``. The task name is the
    last ``_``-separated token of the file stem, lower-cased. Files whose task
    is not in the class list are skipped. When a file has ``trial_indices``,
    its segments are grouped into one trial per distinct index. Otherwise the
    whole file is one trial.

    Args:
        data_dir: directory containing the ``.mat`` files.
        subject_id: subject number, formatted as ``A{subject_id:02d}``.
        test_early: if True, keep only segments whose ``segment_end_ms`` is
            ``<= early_ms``.
        early_ms: cut-off in ms applied when ``test_early`` is True.
        class_names: ordered task names defining the label indices. Defaults
            to the module-level ``tasks`` list.

    Raises:
        ValueError: if a file's segment count disagrees with its
            ``segment_end_ms`` (or ``trial_indices``) length.

    Trials left with no segments after filtering are dropped.
    """
    def __init__(self, data_dir, subject_id, test_early=False, early_ms=800,
                 class_names=None):
        self.trials, self.labels = [], []
        names = tasks if class_names is None else class_names
        task_to_lab = {t: i for i, t in enumerate(names)}
        subj_tag = f"A{subject_id:02d}"
        iu = np.triu_indices(22, k=1)  # strict upper triangle -> 231 features

        for sess in ['T', 'E']:  # files from both training ("T") and evaluation ("E") sessions
            pattern = Path(data_dir) / f"connectivity_graphs_{subj_tag}{sess}_*.mat"
            for fpath in sorted(glob.glob(str(pattern))):
                task = Path(fpath).stem.split('_')[-1].lower()
                if task not in task_to_lab:
                    continue
                lab = task_to_lab[task]

                mat = loadmat(fpath)
                segs = self._split_segments(mat['connectivity_graphs'])
                # ravel (not squeeze) so a single-segment file stays 1-D
                segment_end_ms = np.ravel(mat['segment_end_ms'])
                if len(segment_end_ms) != len(segs):
                    raise ValueError(
                        f"{fpath}: {len(segs)} segments but "
                        f"{len(segment_end_ms)} segment_end_ms values"
                    )

                trials_idx = mat.get('trial_indices', None)
                if trials_idx is None:
                    # No per-segment trial ids: the whole file is one trial.
                    groups = [np.arange(len(segs))]
                else:
                    trials_idx = np.ravel(trials_idx).astype(int)
                    if len(trials_idx) != len(segs):
                        raise ValueError(
                            f"{fpath}: {len(segs)} segments but "
                            f"{len(trials_idx)} trial_indices values"
                        )
                    # one group per distinct trial id, in sorted id order
                    groups = [np.flatnonzero(trials_idx == t)
                              for t in np.unique(trials_idx)]

                for idx in groups:
                    # Early filtering uses the caller's cut-off (`early_ms`).
                    kept = [segs[i] for i in idx
                            if not test_early or segment_end_ms[i] <= early_ms]
                    if kept:
                        self.trials.append(
                            np.stack([s[iu].astype(np.float32) for s in kept])
                        )
                        self.labels.append(lab)

    @staticmethod
    def _split_segments(segs_mat):
        """Turn a loaded ``connectivity_graphs`` array into a list of 2-D matrices.

        Two layouts are accepted:

        * An object array (MATLAB cell array): one matrix per element.
        * A numeric stack indexed as ``[:, :, i]``. A 2-D numeric array is
          treated as a single segment.
        """
        if segs_mat.dtype == np.object_:
            return [np.asarray(g) for g in segs_mat.ravel()]
        if segs_mat.ndim == 2:
            segs_mat = segs_mat[:, :, np.newaxis]
        return [segs_mat[:, :, i] for i in range(segs_mat.shape[2])]

    def __len__(self):
        return len(self.trials)

    def __getitem__(self, idx):
        return self.trials[idx], self.labels[idx]

# --------------------------------------------------
# 3) Pad-collate  (unchanged)
# --------------------------------------------------
def pad_collate(batch):
    """
    Collate ``(sequence, label)`` pairs into one zero-padded batch.

    Each sequence is a 2-D array-like of shape ``(T_i, F)``; all share ``F``.

    Returns:
        padded:   float32 tensor ``(B, T_max, F)``, zero after each sequence's end.
        labs:     int64 tensor ``(B,)`` of labels.
        pad_mask: bool tensor ``(B, T_max)``, True at padded positions (the
                  convention ``key_padding_mask`` uses).
    """
    seqs, labels = zip(*batch)
    tensors = [torch.tensor(seq, dtype=torch.float32) for seq in seqs]
    padded = torch.nn.utils.rnn.pad_sequence(
        tensors, batch_first=True, padding_value=0.0
    )
    lengths = torch.tensor([t.shape[0] for t in tensors])
    # position j of row b is padding iff j >= length of sequence b
    pad_mask = torch.arange(padded.shape[1]).unsqueeze(0) >= lengths.unsqueeze(1)
    return padded, torch.tensor(list(labels), dtype=torch.long), pad_mask

# --------------------------------------------------
# 4) Squeeze-and-Excitation block
# --------------------------------------------------
class SEBlock(nn.Module):
    """
    Squeeze-and-Excitation gate over the feature (d_model) axis.

    Steps:

    * Squeeze: take the mean over the time axis.
    * Excite: apply a bias-free bottleneck MLP
      (d_model -> d_model // r -> d_model), ReLU, then sigmoid. This gives one
      gate per (sample, feature).
    * Rescale: multiply every time step by its gate.

    Input/output shape: (B, T, d_model).
    """
    def __init__(self, d_model, r=4):
        super().__init__()
        # At least one hidden unit, so d_model < r cannot yield a zero-width layer.
        hidden = max(1, d_model // r)
        self.fc1 = nn.Linear(d_model, hidden, bias=False)
        self.fc2 = nn.Linear(hidden, d_model, bias=False)

    def forward(self, x, pad_mask=None):
        """
        Args:
            x: (B, T, d_model) tensor.
            pad_mask: optional (B, T) mask, True (non-zero) at padded positions.
                This is the ``key_padding_mask`` convention. Padded positions
                are excluded from the squeeze mean. If None, all T positions
                are averaged.
        Returns:
            x rescaled per (sample, feature); same shape as x.
        """
        if pad_mask is None:
            s = x.mean(dim=1)                                   # (B, d_model)
        else:
            pad = pad_mask.bool().unsqueeze(-1)                 # (B, T, 1)
            # masked_fill (not multiply) so non-finite padded values cannot leak
            total = x.masked_fill(pad, 0.0).sum(dim=1)          # (B, d_model)
            # clamp guards a row that is entirely padding
            count = (~pad).sum(dim=1).clamp(min=1).to(x.dtype)  # (B, 1)
            s = total / count
        z = F.relu(self.fc1(s))                                 # (B, hidden)
        z = torch.sigmoid(self.fc2(z)).unsqueeze(1)             # (B, 1, d_model)
        return x * z                                            # broadcast -> (B, T, d_model)

# --------------------------------------------------
# 5) Encoder layer with SE after FFN
# --------------------------------------------------
class SEEncoderLayer(nn.Module):
    """
    Post-norm Transformer encoder layer with an SE gate on the FFN output.

    The layer computes:
        x = LayerNorm(x + Dropout(SelfAttention(x)))
        x = LayerNorm(x + Dropout(SE(FFN(x))))
    where FFN = Linear -> GELU -> Dropout -> Linear. Input and output shape
    is (B, T, d_model) (``batch_first=True``).
    """
    def __init__(self, d_model, nhead, dim_ff, dropout):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=True
        )
        # feedforward network
        self.ff1   = nn.Linear(d_model, dim_ff)
        self.ff2   = nn.Linear(dim_ff, d_model)
        self.se    = SEBlock(d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop  = nn.Dropout(dropout)

    def forward(self, x, key_padding_mask=None):
        """
        Args:
            x: (B, T, d_model) tensor.
            key_padding_mask: optional (B, T) mask, True at padded positions.
                It is used by the attention and by the SE squeeze, so padded
                positions do not influence the outputs at real positions.
        """
        attn_out, _ = self.self_attn(
            x, x, x, key_padding_mask=key_padding_mask, need_weights=False
        )
        x = self.norm1(x + self.drop(attn_out))

        y = self.ff2(self.drop(F.gelu(self.ff1(x))))
        y = self.se(y, pad_mask=key_padding_mask)   # SE ignores padding
        return self.norm2(x + self.drop(y))

# --------------------------------------------------
# 6) Transformer classifier
# --------------------------------------------------
class EEGTransformerClassifier(nn.Module):
    """
    Sequence classifier over per-segment feature vectors.

    Pipeline:
        1. Linear projection to d_model.
        2. Prepend a learned [CLS] token.
        3. Add fixed sinusoidal positional embeddings.
        4. Run ``num_layers`` SEEncoderLayer blocks.
        5. Apply dropout and a linear head to the [CLS] output.

    Args:
        feature_dim: size of each input vector.
        d_model: model width. Must be divisible by ``nhead``; odd values are
            supported.
        nhead: attention heads per layer.
        num_layers: number of encoder layers.
        num_classes: number of output logits.
        dim_ff: hidden width of each layer's feed-forward network.
        dropout: dropout probability used throughout.
        max_seq_len: longest input sequence (excluding [CLS]) supported by the
            positional table.
    """
    def __init__(
        self,
        feature_dim=231,
        d_model=16,
        nhead=2,
        num_layers=2,
        num_classes=2,
        dim_ff=256,
        dropout=0.1,
        max_seq_len=100
    ):
        super().__init__()
        self.proj = nn.Linear(feature_dim, d_model)
        self.cls  = nn.Parameter(torch.zeros(1, 1, d_model))

        # Fixed sinusoidal positional embeddings for positions 0..max_seq_len
        # (position 0 is the [CLS] slot).
        pos = torch.arange(max_seq_len + 1, dtype=torch.float32).unsqueeze(1)
        div = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) * -(np.log(1e4) / d_model)
        )
        pe  = torch.zeros(max_seq_len + 1, d_model)
        pe[:, 0::2] = torch.sin(pos * div)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])
        self.register_buffer("pe", pe)

        self.layers = nn.ModuleList([
            SEEncoderLayer(d_model, nhead, dim_ff, dropout)
            for _ in range(num_layers)
        ])
        self.drop = nn.Dropout(dropout)
        self.fc   = nn.Linear(d_model, num_classes)

    def forward(self, x, pad_mask=None):
        """
        Args:
            x: (B, T, feature_dim) tensor with T <= max_seq_len.
            pad_mask: optional (B, T) bool tensor, True at padded positions.

        Returns:
            (B, num_classes) logits computed from the [CLS] output.

        Raises:
            ValueError: if T exceeds the positional table built in __init__.
        """
        B, T, _ = x.shape
        if T + 1 > self.pe.shape[0]:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len={self.pe.shape[0] - 1}"
            )
        x = self.proj(x)                           # → (B, T, d_model)
        cls = self.cls.expand(B, -1, -1)           # → (B, 1, d_model)
        x = torch.cat([cls, x], dim=1) + self.pe[:T + 1]
        if pad_mask is not None:
            # the [CLS] slot is never padding
            pad_mask = torch.cat([pad_mask.new_zeros((B, 1)), pad_mask], dim=1)
        for lyr in self.layers:
            x = lyr(x, key_padding_mask=pad_mask)
        # classify based on the [CLS] token output
        return self.fc(self.drop(x[:, 0, :]))

# --------------------------------------------------
# 7) Load “full” vs “early” datasets (mix T + E now)
# --------------------------------------------------
full_ds  = EEGGraphDataset(data_dir, subject_id, test_early=False)  # all segments
early_ds = EEGGraphDataset(data_dir, subject_id, test_early=True)   # only ≤800 ms

# The CV below trains on full_ds.trials[i] and validates on early_ds.trials[i],
# so both views must list the same trials in the same order. Explicit raises
# (rather than assert) keep these checks active under `python -O`.
if len(full_ds) == 0:
    raise RuntimeError(
        f"No trials found for subject {subject_id} in {data_dir!r}; "
        "check data_dir and the connectivity_graphs_*.mat file names."
    )
if len(full_ds) != len(early_ds):
    raise RuntimeError(
        f"Trial count mismatch: {len(full_ds)} full vs {len(early_ds)} early "
        "(likely a trial with no segment inside the early window)."
    )
if full_ds.labels != early_ds.labels:
    raise RuntimeError("Label order differs between full and early datasets.")

all_X_full  = full_ds.trials   # list of numpy arrays, each full-trial segments
all_y_full  = full_ds.labels
all_X_early = early_ds.trials  # same-length list, but each truncated to ≤800 ms
all_y_early = early_ds.labels

# --------------------------------------------------
# 8) 8-fold cross-validation over all trials (mixing T+E)
# --------------------------------------------------
n_folds    = 8   # also used in the progress / summary messages
kf         = KFold(n_splits=n_folds, shuffle=True, random_state=42)
device     = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
bs, ne, lr = 8, 35, 3e-3
# NOTE(review): patience (50) exceeds the epoch budget ne (35), so early
# stopping cannot fire with these settings; every fold runs all `ne` epochs.
patience   = 50
fold_accs  = []


class TrialDataset(Dataset):
    """In-memory Dataset over parallel lists of sequences ``X`` and labels ``y``."""
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return self.X[i], self.y[i]


def _evaluate(model, loader, device):
    """Return the accuracy (%) of ``model`` over ``loader`` in eval mode, without gradients."""
    model.eval()
    correct, count = 0, 0
    with torch.no_grad():
        for Xb, yb, mb in loader:
            Xb, yb, mb = Xb.to(device), yb.to(device), mb.to(device)
            preds = model(Xb, mb).argmax(dim=1)
            correct += (preds == yb).sum().item()
            count += yb.size(0)
    return 100.0 * correct / count


for fold, (train_idx, val_idx) in enumerate(kf.split(all_X_full), 1):
    print(f"\n=== Fold {fold}/{n_folds} ===")

    # --- Training split: full segments of the training trials ---
    X_train = [all_X_full[i] for i in train_idx]
    y_train = [all_y_full[i] for i in train_idx]

    # --- Validation split: early (≤800 ms) segments of held-out trials ---
    X_val = [all_X_early[i] for i in val_idx]
    y_val = [all_y_early[i] for i in val_idx]

    # --- Per-feature z-scoring with statistics from training segments only ---
    all_train_segments = np.concatenate(X_train, axis=0)  # shape = (sum of T_i, 231)
    μ = all_train_segments.mean(axis=0, keepdims=True)
    σ = all_train_segments.std(axis=0, keepdims=True) + 1e-6
    X_train_norm = [(s - μ) / σ for s in X_train]
    X_val_norm   = [(s - μ) / σ for s in X_val]

    train_loader = DataLoader(
        TrialDataset(X_train_norm, y_train),
        batch_size=bs,
        shuffle=True,
        collate_fn=pad_collate
    )
    val_loader = DataLoader(
        TrialDataset(X_val_norm, y_val),
        batch_size=bs,
        shuffle=False,
        collate_fn=pad_collate
    )

    # --- Positional table must cover the longest sequence in either split ---
    max_T = max(tr.shape[0] for tr in X_train_norm + X_val_norm)
    model = EEGTransformerClassifier(
        feature_dim=231,
        d_model=16,
        nhead=2,
        num_layers=2,
        num_classes=len(tasks),
        dim_ff=256,
        dropout=0.1,
        max_seq_len=max_T
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    criterion = nn.CrossEntropyLoss()

    # Start below any attainable accuracy so epoch 1 always writes a checkpoint.
    # Starting at 0.0 would mean a fold that never rises above 0 % never saves,
    # and torch.load below would then fail or read a stale file from an
    # earlier run.
    best_val_acc = -1.0
    wait = patience
    ckpt_path = f"fold{fold}.pt"

    # --- Training loop ---
    for ep in range(1, ne + 1):
        model.train()
        total_loss, total_correct, total_samples = 0.0, 0, 0

        for Xb, yb, mb in train_loader:
            Xb, yb, mb = Xb.to(device), yb.to(device), mb.to(device)

            optimizer.zero_grad()
            outputs = model(Xb, mb)
            loss = criterion(outputs, yb)
            loss.backward()
            optimizer.step()

            total_correct += (outputs.argmax(dim=1) == yb).sum().item()
            total_samples += yb.size(0)
            total_loss += loss.item() * yb.size(0)

        train_loss = total_loss / total_samples
        train_acc  = 100.0 * total_correct / total_samples

        # --- Validation on "early" segments ---
        val_acc = _evaluate(model, val_loader, device)

        # --- Checkpoint on improvement; otherwise spend patience ---
        if val_acc > best_val_acc:
            best_val_acc, wait = val_acc, patience
            torch.save(model.state_dict(), ckpt_path)
        else:
            wait -= 1
            if wait == 0:
                print(f"⟹ Early stop @ epoch {ep}")
                break

        # --- Print stats at ep 1 and every 5 epochs ---
        if ep == 1 or ep % 5 == 0:
            print(f"Ep{ep:02d}: tr-loss {train_loss:.4f} | "
                  f"tr-acc {train_acc:.2f}% | early-val-acc {val_acc:.2f}%")

    # --- Final evaluation for this fold ---
    # NOTE(review): the checkpoint is selected by accuracy on this same held-out
    # fold. The reported fold accuracy is therefore the best-epoch test
    # accuracy and is optimistically biased; an inner validation split carved
    # from train_idx would give an unbiased estimate.
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    fold_acc = _evaluate(model, val_loader, device)
    fold_accs.append(fold_acc)
    print(f"♦ Fold {fold} final early-800ms acc: {fold_acc:.2f}%")

# --------------------------------------------------
# 9) Summary of CV results
# --------------------------------------------------
print(f"\nMean ± SD early-800ms accuracy over {n_folds} folds: "
      f"{np.mean(fold_accs):.2f}% ± {np.std(fold_accs):.2f}%")

import numpy as np
from sklearn.model_selection import KFold

def dataset_stats(ds, name="dataset"):
    """
    Print trial/segment/token counts for a dataset and return the totals.

    Args:
        ds: object whose ``trials`` attribute is a sequence of arrays. The
            first axis of each array is taken as that trial's segment count.
        name: heading printed (and underlined) above the summary.

    Returns:
        ``(n_trials, n_segments)`` as Python ints. The printed token total
        adds one [CLS] token per trial to the segment count.
    """
    seg_counts = [trial.shape[0] for trial in ds.trials]
    n_trials = len(seg_counts)
    n_segments = int(np.sum(seg_counts))

    print(f"\n{name}")
    print("-" * len(name))
    print(f"Trials               : {n_trials}")
    print(f"Segments             : {n_segments}")
    print(f"[CLS] tokens (1/trial): {n_trials}")
    print(f"Total tokens         : {n_segments + n_trials}")
    if seg_counts:
        print(f"Segments/trial       : "
              f"mean {np.mean(seg_counts):.1f}, "
              f"min {np.min(seg_counts)}, "
              f"max {np.max(seg_counts)}")
    else:
        # np.min / np.max raise ValueError on an empty sequence
        print("Segments/trial       : n/a (no trials)")

    return n_trials, n_segments

# --- Use main code's variables: data_dir, subject_id, tasks ---
# Instantiate datasets
full_ds = EEGGraphDataset(data_dir, subject_id, test_early=False)
early_ds_800 = EEGGraphDataset(data_dir, subject_id, test_early=True, early_ms=800)

# The split statistics index early_ds_800.trials with indices drawn from
# full_ds.trials, so both must describe the same trials in the same order.
if len(full_ds) != len(early_ds_800):
    raise RuntimeError(
        f"full ({len(full_ds)}) and early ({len(early_ds_800)}) datasets "
        "are not trial-aligned"
    )

# --- Overall stats ---
train_trials, train_segs = dataset_stats(full_ds, "TRAIN  (all segments)")
test_800_trials, test_800_segs = dataset_stats(early_ds_800, "TEST   (segments ending ≤800 ms)")

# --- CV split stats ---
# Use the same split as the training CV (8 folds, shuffle, random_state=42),
# so the per-fold table and the means below describe the same folds.
n_folds = 8
splitter = KFold(n_splits=n_folds, shuffle=True, random_state=42)
cv_train_trials, cv_train_segs, cv_test_trials, cv_test_segs_800 = [], [], [], []

print(f"\n{n_folds}-fold split sizes (train / valid per fold, ≤800 ms)")
for fold_no, (tr_idx, val_idx) in enumerate(splitter.split(full_ds.trials), 1):
    fold_train_segs = int(sum(full_ds.trials[j].shape[0] for j in tr_idx))
    fold_test_segs_800 = int(sum(early_ds_800.trials[j].shape[0] for j in val_idx))

    cv_train_trials.append(len(tr_idx))
    cv_train_segs.append(fold_train_segs)
    cv_test_trials.append(len(val_idx))
    cv_test_segs_800.append(fold_test_segs_800)

    print(f"Fold {fold_no}: {len(tr_idx):4d} trials ({fold_train_segs:5d} segs) / "
          f"{len(val_idx):3d} trials ({fold_test_segs_800:4d} segs ≤800 ms)")

# --- Print CV means ---
print("\nCV Mean Statistics:")
print(f"CV Training Trials      : {np.mean(cv_train_trials):.0f}")
print(f"CV Training Segments    : {np.mean(cv_train_segs):.0f}")
print(f"CV Testing Trials       : {np.mean(cv_test_trials):.0f}")
print(f"CV Testing Segments (≤800 ms) : {np.mean(cv_test_segs_800):.0f}")

Mounted at /content/drive


  trials_idx = trials_idx.squeeze().astype(int)



=== Fold 1/8 ===
Ep01: tr-loss 0.7293 | tr-acc 50.00% | early-val-acc 47.22%
Ep05: tr-loss 0.6827 | tr-acc 56.75% | early-val-acc 52.78%
Ep10: tr-loss 0.6961 | tr-acc 48.41% | early-val-acc 55.56%
Ep15: tr-loss 0.6974 | tr-acc 50.00% | early-val-acc 69.44%
Ep20: tr-loss 0.6697 | tr-acc 59.92% | early-val-acc 58.33%
Ep25: tr-loss 0.6759 | tr-acc 60.71% | early-val-acc 61.11%
Ep30: tr-loss 0.6144 | tr-acc 68.25% | early-val-acc 50.00%
Ep35: tr-loss 0.6041 | tr-acc 68.25% | early-val-acc 58.33%
♦ Fold 1 final early-800ms acc: 69.44%

=== Fold 2/8 ===
Ep01: tr-loss 0.7396 | tr-acc 48.02% | early-val-acc 38.89%
Ep05: tr-loss 0.6927 | tr-acc 51.59% | early-val-acc 38.89%
Ep10: tr-loss 0.6954 | tr-acc 50.00% | early-val-acc 47.22%
Ep15: tr-loss 0.5177 | tr-acc 76.19% | early-val-acc 38.89%
Ep20: tr-loss 0.5410 | tr-acc 75.00% | early-val-acc 47.22%
Ep25: tr-loss 0.2994 | tr-acc 88.89% | early-val-acc 58.33%
Ep30: tr-loss 0.2908 | tr-acc 87.30% | early-val-acc 52.78%
Ep35: tr-loss 0.3892 | tr