In [None]:
!pip install torcheeg

In [None]:
!pip install torch-scatter


In [None]:
!pip install --quiet mne scipy                  # CSP helper tools


In [1]:
# Standard library & basic third‑party imports
import os
import argparse
from typing import Optional, Dict, List

import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score

import torch
from torch.utils.data import Dataset, DataLoader

# TorchEEG
import numpy as np
from scipy.signal import butter, filtfilt

from torcheeg.models import CSPNet
from tqdm.auto import tqdm  # ← progress bars

In [2]:
class ButterBandpass:
    """Callable transform: band‑pass 8–32 Hz for each (C, T) EEG array.
    Ensures *float32* output so tensors match the model weights.
    """
    def __init__(self, low=8, high=32, fs=250, order=4):
        self.b, self.a = butter(order, [low, high], btype="bandpass", fs=fs)

    def __call__(self, arr):  # arr shape (8, T)
        filtered = filtfilt(self.b, self.a, arr, axis=-1,
                            padlen=3 * max(len(self.a), len(self.b)))
        return filtered.astype(np.float32, copy=False)


class SegReconstruct:
    """Segmentation & Reconstruction (S&R) augmentation from EEG‑Conformer.

    Splits each trial into *n_seg* equal chunks along the time axis, shuffles
    them, then re‑concatenates.  Applied with probability *p*.
    """
    def __init__(self, n_seg: int = 4, p: float = 0.5):
        assert 0 < p <= 1, "p must be in (0,1]"
        assert n_seg > 1, "n_seg must be > 1"
        self.n_seg, self.p = n_seg, p
    def __call__(self, arr):  # arr: (8, T)
        if np.random.rand() > self.p:
            return arr.astype(np.float32, copy=False)
        C, T = arr.shape
        seg_len = T // self.n_seg
        idxs = np.arange(self.n_seg)
        np.random.shuffle(idxs)
        segments = [arr[:, i*seg_len:(i+1)*seg_len] for i in idxs]
        shuffled = np.concatenate(segments, axis=-1)
        # Append remainder if T isn't divisible by n_seg
        if seg_len * self.n_seg < T:
            shuffled = np.concatenate([shuffled, arr[:, seg_len*self.n_seg:]], axis=-1)
        return shuffled.astype(np.float32, copy=False)


In [4]:

# -------------------------------------------------
# 1. Imports
# -------------------------------------------------
import os
from typing import Optional, Dict, List

import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score

import torch
from torch.utils.data import Dataset, DataLoader

from torcheeg.models import CSPNet

# %% [markdown]
"""
## 2. Dataset Class
Maps the Kaggle index CSVs to raw `EEGdata.csv` recordings and slices the requested trial.
"""

# %%
class MTCBCIDataset(Dataset):
    """PyTorch Dataset for the MTC‑AIC3 BCI Competition.

    Handles both `S4` and `4` style subject IDs and caches session files to
    minimise disk I/O.
    """

    EEG_CHANNELS: List[str] = [
        "FZ", "C3", "CZ", "C4", "PZ", "PO7", "OZ", "PO8",
    ]

    LABEL_MAP: Dict[str, Dict[str, int]] = {
        "MI": {"Left": 0, "Right": 1},
        "SSVEP": {"Left": 0, "Right": 1, "Forward": 2, "Backward": 3},
    }

    def __init__(
        self,
        csv_path: str,
        root_dir: str,
        cache_eeg: bool = True,
        transform: Optional[callable] = None,
    ) -> None:
        super().__init__()
        self.df = pd.read_csv(csv_path)
        self.root_dir = root_dir.rstrip("/\\")  # ← FIXED back‑slash escape
        self.cache_eeg = cache_eeg
        self.transform = transform
        self._eeg_cache: Dict[str, pd.DataFrame] = {}

    # ------------------------------------------------------------------
    # Dataset API
    # ------------------------------------------------------------------
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        task = row["task"]
        split = self._split_by_id(row["id"])

        # --- subject directory: avoid double "S" ---
        subj_field = str(row["subject_id"])
        subject_dir = subj_field if subj_field.upper().startswith("S") else f"S{subj_field}"

        eeg_csv = os.path.join(
            self.root_dir,
            task,
            split,
            subject_dir,
            str(row["trial_session"]),
            "EEGdata.csv",
        )

        # Load / cache CSV
        if self.cache_eeg and eeg_csv in self._eeg_cache:
            eeg_df = self._eeg_cache[eeg_csv]
        else:
            eeg_df = pd.read_csv(eeg_csv)
            if self.cache_eeg:
                self._eeg_cache[eeg_csv] = eeg_df

        # Slice out the requested trial
        samples_per_trial = 2250 if task == "MI" else 1750
        start = (int(row["trial"]) - 1) * samples_per_trial
        end = start + samples_per_trial

        seg = eeg_df.loc[start:end - 1, self.EEG_CHANNELS].to_numpy(np.float32).T  # (8, T)
        if self.transform:
            seg = self.transform(seg)

        x = torch.from_numpy(seg.copy()).unsqueeze(0)  # ensure positive strides
        
        # Label handling (‑1 for unlabeled test rows)
        if "label" in row and not pd.isna(row["label"]):
            y = torch.tensor(self.LABEL_MAP[task][row["label"]], dtype=torch.long)
        else:
            y = torch.tensor(-1, dtype=torch.long)
        return x, y

    # ------------------------------------------------------------------
    # Helper
    # ------------------------------------------------------------------
    @staticmethod
    def _split_by_id(idx: int) -> str:
        if idx <= 4800:
            return "train"
        elif idx <= 4900:
            return "validation"
        return "test"


In [5]:
from torch.utils.data import DataLoader

import random
import torch

def sr_collate(batch, n_seg=4, p=0.5):
    """
    EEG-Conformer Segmentation & Reconstruction (S&R) implementation.

    Parameters
    ----------
    batch : list of (x, y) tuples from the Dataset.
        x shape  = (1, 8, T)
        y scalar = label (torch.long)
    n_seg : int
        Number of equal segments per trial (paper uses 4).
    p : float
        Probability of performing S&R augmentation for *this* batch.

    Returns
    -------
    xs, ys : torch.Tensor
        xs shape = (N (+synthetics), 1, 8, T)
        ys shape = (N (+synthetics),)
    """
    xs, ys = list(zip(*batch))          # tuples -> lists
    xs, ys = list(xs), list(ys)

    if random.random() < p:             # decide whether to augment
        # --- group indices by class ---
        by_class = {}
        for idx, y in enumerate(ys):
            cls = int(y)
            by_class.setdefault(cls, []).append(idx)

        synth_x, synth_y = [], []
        for cls, idxs in by_class.items():
            if len(idxs) < 2:           # need ≥2 source trials
                continue

            # assume all trials same T
            T = xs[idxs[0]].shape[-1]
            seg_len = T // n_seg

            seg_list = []
            for i in range(n_seg):
                src_idx = random.choice(idxs)
                seg = xs[src_idx][..., i*seg_len:(i+1)*seg_len]  # (1,8,seg_len)
                seg_list.append(seg)

            synth_trial = torch.cat(seg_list, dim=-1)            # (1,8,T)
            synth_x.append(synth_trial)
            synth_y.append(torch.tensor(cls, dtype=torch.long))

        if synth_x:                      # append to original batch
            xs.extend(synth_x)
            ys.extend(synth_y)

    return torch.stack(xs, 0), torch.stack(ys, 0)



In [7]:
# %%
root_dir = "/kaggle/input/mtcaic3"  # ⬅️  CHANGE ME

task = "MI"          # "MI"  or  "SSVEP"
epochs = 200
batch_size = 128

# %% [markdown]
"""
## 4. Data Preparation
Creates the training & validation loaders for the chosen task.
"""
# --- Band‑pass transform (8‑32 Hz, 4th‑order Butterworth) ---
bandpass = ButterBandpass(low=8, high=32, fs=250, order=4)
seg_aug  = SegReconstruct(n_seg=4, p=0.5)          # 50 % chance
def combined(arr):
    arr = bandpass(arr)
    # arr = seg_aug(arr)
    return arr

# %%
train_csv = os.path.join(root_dir, "train.csv")
val_csv = os.path.join(root_dir, "validation.csv")

train_full = MTCBCIDataset(train_csv, root_dir, transform=combined)
val_full = MTCBCIDataset(val_csv, root_dir, cache_eeg=False)

# train_full = MTCBCIDataset(train_csv, root_dir)
# val_full = MTCBCIDataset(val_csv, root_dir, cache_eeg=False)

# List of duplicate indices you provided
duplicates_indices = [
    1500, 1501, 1502, 1503, 1504, 1505, 1506, 1507, 1508, 1509,
    1650, 1651, 1652, 1653, 1654, 1655, 1656, 1657, 1658, 1659,
    2290, 2291, 2292, 2293, 2294, 2295, 2296, 2297, 2298, 2299,
    2300, 2301, 2302, 2303, 2304, 2305, 2306, 2307, 2308, 2309
]


train_indices = [i for i, r in enumerate(train_full.df.itertuples()) if r.task == task]
clean_train_indices = [idx for idx in train_indices if idx not in duplicates_indices]

val_indices = [i for i, r in enumerate(val_full.df.itertuples()) if r.task == task]

# train_ds = torch.utils.data.Subset(train_full, train_indices)
clean_train_ds = torch.utils.data.Subset(train_full, clean_train_indices)

val_ds = torch.utils.data.Subset(val_full, val_indices)

loader_train = DataLoader(clean_train_ds, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True, collate_fn=lambda b: sr_collate(b, n_seg=5, p=0.5))
loader_val = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)

# loader_train = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)
# loader_val = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)

print(f"Loaded {len(clean_train_ds)} training and {len(val_ds)} validation trials for task {task}.")


Loaded 2360 training and 50 validation trials for task MI.


In [8]:
# ==============================================================
# ✦  CELL  A  ✦  -- compute analytic CSP filters for the MI task
# ==============================================================

import numpy as np, torch, scipy.linalg as la
from tqdm.auto import tqdm

NUM_CSP  = 8                   # f in the paper (use an even number)
N_CH     = 8                   # electrodes: FZ,C3,CZ,C4,PZ,PO7,OZ,PO8

def compute_csp_filters(dataloader, n_filters=NUM_CSP):
    """
    Returns an (n_filters, N_CH) numpy array.
    Uses the one‑vs‑one CSP solution for 2‑class MI.
    """
    cov_sum = {0: np.zeros((N_CH, N_CH)),
               1: np.zeros((N_CH, N_CH))}
    n_trial = {0: 0, 1: 0}

    for x, y in tqdm(dataloader, desc="CSP|cov", leave=False):
        # x : (B, 1, 8, T)
        x = x.squeeze(1).numpy()                       # (B,8,T)
        y = y.numpy()
        for trial, label in zip(x, y):
            trial = trial - trial.mean(axis=1, keepdims=True)
            cov   = trial @ trial.T
            cov   = cov / np.trace(cov)                # normalise
            cov_sum[label] += cov
            n_trial[label] += 1

    C0 = cov_sum[0] / n_trial[0]
    C1 = cov_sum[1] / n_trial[1]
    Cc = C0 + C1                                       # composite

    # Generalised eigen‑decomposition  C0 w = λ Cc w
    eigvals, eigvecs = la.eigh(C0, Cc)
    order = np.argsort(eigvals)[::-1]                  # descending
    eigvecs = eigvecs[:, order]

    # First k/2 & last k/2 vectors maximise variance for class 0 and 1
    k = n_filters // 2
    filters = np.concatenate([eigvecs[:, :k],
                              eigvecs[:, -k:]], axis=1).T
    return filters.astype(np.float32)                  # (k*2, 8)

# --------------------------------------------------------------
# # ⚠  Use *exactly* the MI rows of the *training* split
# # --------------------------------------------------------------
# train_mi_loader = DataLoader(train_ds, batch_size=64,
#                              shuffle=False, num_workers=0)

csp_filters = compute_csp_filters(loader_train, NUM_CSP)
print("✓ CSP filters shape:", csp_filters.shape)


CSP|cov:   0%|          | 0/19 [00:00<?, ?it/s]

✓ CSP filters shape: (8, 8)


In [9]:
# ==============================================================
# ✦  CELL  B  ✦  -- model definition (faithful CSP‑Net‑1)
# ==============================================================

import torch.nn as nn
from torcheeg.models import EEGNet
from huggingface_hub import hf_hub_download

class CSPNet1_EEGNet(nn.Module):
    """
    CSP‑Net‑1 with EEGNet‑v4 backbone.
    Args
    ----
    csp_w      : ndarray (f, 8)   CSP filters you just computed.
    freeze_csp : bool            CSP‑Net‑1‑fix (True) or ‑upd (False).
    """
    def __init__(self, csp_w, freeze_csp=True,
                 hf_variant="EEGNetv4_BNCI2014004"):
        super().__init__()
        f, C = csp_w.shape                      # (8, 8)

        # 1. CSP projection layer  (B,1,C,T) ➜ (B,f,1,T)
        self.csp = nn.Conv2d(1, f,
                             kernel_size=(C, 1),
                             bias=False)
        with torch.no_grad():
            self.csp.weight.copy_(
                torch.tensor(csp_w)             # (f,C)
                      .unsqueeze(1)             # (f,1,C)
                      .unsqueeze(-1)            # (f,1,C,1)
            )
        if freeze_csp:
            for p in self.csp.parameters():
                p.requires_grad = False

        # 2. Complete EEGNet‑v4 backbone
        self.eegnet = EEGNet(chunk_size=chunk_size,
                             num_electrodes=f,
                             num_classes=num_classes)

        # ---- optional HF backbone weights (spatial dims unchanged) ----
        load_hf_eegnet_weights(self.eegnet, hf_variant)


    # ------------------------------------------------------------------
    def forward(self, x):                         # x: (B,1,8,T)
        x = self.csp(x)                           # (B,f,1,T)
        x = x.permute(0, 2, 1, 3).contiguous()    # (B,1,f,T)
        return self.eegnet(x)


# ------------------------------------------------------------------
# Helper: load compatible HF weights into a torcheeg‑EEGNet model
# ------------------------------------------------------------------
def load_hf_eegnet_weights(model: torch.nn.Module,
                           hf_variant: str = "EEGNetv4_BNCI2014004") -> None:
    """
    Copy as many parameters as possible from PierreGtch/EEGNetv4
    into a torcheeg.models.EEGNet instance whose electrode
    dimension may differ (8 vs 22).
    """
    try:
        ckpt_path = hf_hub_download(
            repo_id="PierreGtch/EEGNetv4",
            filename=f"{hf_variant}/model-params.pkl",
            repo_type="model")
        raw = torch.load(ckpt_path, map_location="cpu")
    except Exception as e:
        print("⚠ HF download failed — using random init.\n", e)
        return

    # --- map HF names ➜ torcheeg names ---------------------------------
    rename = {
        "conv_temporal.weight"          : "block1.0.weight",
        "bnorm_temporal.weight"         : "block1.1.weight",
        "bnorm_temporal.bias"           : "block1.1.bias",
        "bnorm_temporal.running_mean"   : "block1.1.running_mean",
        "bnorm_temporal.running_var"    : "block1.1.running_var",
        "conv_spatial.weight"           : "block1.2.weight",
        "bnorm_1.weight"                : "block1.3.weight",
        "bnorm_1.bias"                  : "block1.3.bias",
        "bnorm_1.running_mean"          : "block1.3.running_mean",
        "bnorm_1.running_var"           : "block1.3.running_var",
        "conv_separable_depth.weight"   : "block2.0.weight",
        "conv_separable_point.weight"   : "block2.1.weight",
        "bnorm_2.weight"                : "block2.2.weight",
        "bnorm_2.bias"                  : "block2.2.bias",
        "bnorm_2.running_mean"          : "block2.2.running_mean",
        "bnorm_2.running_var"           : "block2.2.running_var",
        # classifier layers (“conv_classifier.*”, “lin.weight”) are skipped
    }

    tgt = model.state_dict()
    copied = {}
    for src_name, dst_name in rename.items():
        if src_name not in raw or dst_name not in tgt:
            continue
        if raw[src_name].shape != tgt[dst_name].shape:
            # conv_spatial.weight is (16,1,22,1) in HF, but (16,1,8,1) here
            continue
        copied[dst_name] = raw[src_name]

    tgt.update(copied)
    model.load_state_dict(tgt, strict=False)
    print(f"✓ Copied {len(copied)}/{len(rename)} compatible tensors "
          f"from HF checkpoint")


In [10]:
# 4. Train / eval helpers -------------------------------------------------
def train_one_epoch(model, loader, epoch: int):
    model.train()
    total, processed = 0.0, 0
    bar = tqdm(loader, desc=f"Epoch {epoch} [train]", leave=False)
    for x, y in bar:
        x, y = x.to(device), y.to(device)

        logits = model(x)
        loss = criterion(logits, y)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        total += loss.item() * x.size(0)
        processed += x.size(0)
        bar.set_postfix(loss=f"{total/processed:.4f}")

    return total / processed


@torch.no_grad()
def evaluate(model, loader, epoch: int):
    model.eval()
    preds, labels = [], []
    bar = tqdm(loader, desc=f"Epoch {epoch} [val]", leave=False)
    for x, y in bar:
        logits = model(x.to(device))
        preds.extend(logits.argmax(1).cpu().tolist())
        labels.extend(y.tolist())

    return (
        accuracy_score(labels, preds),
        balanced_accuracy_score(labels, preds),
        f1_score(labels, preds, average="weighted"),
    )

In [11]:
# ==============================================================
# ✦  CELL  C  ✦  -- set up training just like before
# ==============================================================

FREEZE_CSP   = False        # CSP‑Net‑1‑fix; set False for ‑upd variant
LEARNING_RATE = 0.001
EPOCHS        = 200
SAVE_PATH     = "cspnet1_eegnet_best.pth"
chunk_size  = 2250 if task == "MI" else 1750
num_classes = 2    if task == "MI" else 4
device = "cuda"
model = CSPNet1_EEGNet(csp_filters, freeze_csp=FREEZE_CSP).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                    model.parameters()),
                             lr=LEARNING_RATE, weight_decay=5e-4)

print(model)   # sanity‑check: ~0.63 M parameters if CSP frozen

best_bal = 0.0
for epoch in range(1, EPOCHS + 1):
    train_loss = train_one_epoch(model, loader_train, epoch)
    acc, bal, f1 = evaluate(model, loader_val, epoch)
    print(f"Ep {epoch:03d} | loss {train_loss:.4f} | acc {acc:.3f} | bal {bal:.3f}")
    if bal > best_bal:
        best_bal = bal
        torch.save(model.state_dict(), SAVE_PATH)
        print(f"  🚀 saved new best (bal={best_bal:.3f})")


✓ Copied 15/16 compatible tensors from HF checkpoint
CSPNet1_EEGNet(
  (csp): Conv2d(1, 8, kernel_size=(8, 1), stride=(1, 1), bias=False)
  (eegnet): EEGNet(
    (block1): Sequential(
      (0): Conv2d(1, 8, kernel_size=(1, 64), stride=(1, 1), padding=(0, 32), bias=False)
      (1): BatchNorm2d(8, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): Conv2dWithConstraint(8, 16, kernel_size=(8, 1), stride=(1, 1), groups=8, bias=False)
      (3): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (4): ELU(alpha=1.0)
      (5): AvgPool2d(kernel_size=(1, 4), stride=4, padding=0)
      (6): Dropout(p=0.25, inplace=False)
    )
    (block2): Sequential(
      (0): Conv2d(16, 16, kernel_size=(1, 16), stride=(1, 1), padding=(0, 8), groups=16, bias=False)
      (1): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (2): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (3): ELU(alpha

Epoch 1 [train]:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 1 [val]:   0%|          | 0/1 [00:00<?, ?it/s]

Ep 001 | loss 0.6999 | acc 0.560 | bal 0.500
  🚀 saved new best (bal=0.500)


Epoch 2 [train]:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 2 [val]:   0%|          | 0/1 [00:00<?, ?it/s]

Ep 002 | loss 0.6917 | acc 0.560 | bal 0.500


Epoch 3 [train]:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 3 [val]:   0%|          | 0/1 [00:00<?, ?it/s]

Ep 003 | loss 0.6875 | acc 0.480 | bal 0.492


Epoch 4 [train]:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 4 [val]:   0%|          | 0/1 [00:00<?, ?it/s]

Ep 004 | loss 0.6864 | acc 0.520 | bal 0.484


Epoch 5 [train]:   0%|          | 0/19 [00:00<?, ?it/s]

KeyboardInterrupt: 