# EEG Rest vs Left — Drive Loader (EDF) · Conv1d + Transformer · PyTorch Lightning
*Generated: 2025-10-22 15:06:44*

**Fix release:** recursive, case-insensitive EDF scanner; corrected imports; absolute Google Drive path.

In [None]:
#@title 1) Environment & Drive
!pip -q install mne==1.7.1 pytorch-lightning==2.4.0 torchmetrics==1.4.0

import os, sys, math, json, random, glob, time, shutil, itertools
import numpy as np
import mne
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger

# Treat an already-imported `google.colab` package as the signal that this
# notebook runs on Colab; only then is Google Drive mounted (BASE_DIR lives there).
IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    from google.colab import drive
    drive.mount("/content/drive", force_remount=True)

# Choose the compute device once; later cells derive the Lightning accelerator from it.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

print("Torch:", torch.__version__, "| PL:", pl.__version__, "| MNE:", mne.__version__)
print("Device:", DEVICE)

In [None]:
#@title 2) Config
# Run configuration. Lines tagged `#@param` are Colab form fields; keep their layout.
# VAL_SPLIT is applied to what remains after the TEST_SPLIT share is removed.
BASE_DIR = "/content/drive/MyDrive/Colab-Dataset/eegmidb/organized_data"  #@param {type:"string"}
USE_RIGHT_AS_NEGATIVE = False  #@param {type:"boolean"}
INCLUDE_PATTERNS = ["*.edf", "*.EDF"]  # accepted extensions
LEFT_DIRNAMES = ["left"]   #@param {type:"raw"}
REST_DIRNAMES = ["rest","baseline","idle"]  #@param {type:"raw"}
RIGHT_DIRNAMES = ["right"]  # used only if USE_RIGHT_AS_NEGATIVE
BANDPASS_LOWER = 1.0  #@param {type:"number"}
BANDPASS_UPPER = 40.0  #@param {type:"number"}
APPLY_NOTCH = True     #@param {type:"boolean"}
NOTCH_FREQ = 50.0      #@param {type:"number"}
RESAMPLE_HZ = 128      #@param {type:"number"}
WINDOW_SEC = 2.0       #@param {type:"number"}
WINDOW_OVERLAP = 0.5   #@param {type:"number"}
MAX_FILES_PER_CLASS = 0  # 0 = no limit  #@param {type:"number"}
VAL_SPLIT = 0.15       #@param {type:"number"}
TEST_SPLIT = 0.15      #@param {type:"number"}
RANDOM_SEED = 42       #@param {type:"number"}
BATCH_SIZE = 64        #@param {type:"number"}
MAX_EPOCHS = 20        #@param {type:"number"}
LR = 1e-3              #@param {type:"number"}
DROPOUT = 0.1          #@param {type:"number"}

seed_everything(RANDOM_SEED, workers=True)

CHECKPOINT_DIR = "checkpoints"; os.makedirs(CHECKPOINT_DIR, exist_ok=True)
LOGS_DIR = "logs"; os.makedirs(LOGS_DIR, exist_ok=True)

# Fail fast on an unusable configuration before any EDF is loaded. These are
# explicit raises rather than `assert` so they cannot be stripped by `python -O`.
if not os.path.isdir(BASE_DIR):
    raise FileNotFoundError(f"Path not found: {BASE_DIR}")
if not BANDPASS_LOWER < BANDPASS_UPPER:
    raise ValueError(
        f"BANDPASS_LOWER ({BANDPASS_LOWER}) must be below BANDPASS_UPPER ({BANDPASS_UPPER})."
    )
if not 0.0 <= WINDOW_OVERLAP < 1.0:
    raise ValueError(f"WINDOW_OVERLAP must be in [0, 1), got {WINDOW_OVERLAP}.")
if not (0.0 <= VAL_SPLIT < 1.0 and 0.0 <= TEST_SPLIT < 1.0):
    raise ValueError(
        f"VAL_SPLIT and TEST_SPLIT must be in [0, 1), got {VAL_SPLIT} and {TEST_SPLIT}."
    )

In [None]:
#@title 3) Robust recursive scan (case-insensitive by folder name)
from fnmatch import fnmatch

def scan_edf_paths(base_dir, include_patterns):
    """Recursively collect files under ``base_dir`` whose names match any pattern.

    Matching is case-insensitive on every platform. Plain ``fnmatch`` only folds
    case on case-insensitive filesystems such as Windows, so on Linux/Colab a
    file named ``rec.Edf`` used to be skipped by ``["*.edf", "*.EDF"]``.

    Args:
        base_dir: Root directory to walk; sub-directories are included.
        include_patterns: Iterable of shell-style glob patterns, e.g. ``["*.edf"]``.

    Returns:
        Sorted list of matching file paths. Each file appears at most once, even
        if it matches several patterns.
    """
    # Casefold patterns once; names are casefolded per file below.
    patterns = [pat.casefold() for pat in include_patterns]
    files = []
    for root, _, filenames in os.walk(base_dir):
        for fn in filenames:
            name = fn.casefold()
            if any(fnmatch(name, pat) for pat in patterns):
                files.append(os.path.join(root, fn))
    return sorted(files)

def classify_by_parent_folder(path, left_names, rest_names, right_names):
    """Label a file by the nearest ancestor folder whose name is a known class.

    Folder names are compared case-insensitively. Ancestors are searched from
    the file's immediate parent upwards, and the first match wins. A class
    folder close to the file therefore takes precedence over an unrelated
    ancestor (for example part of BASE_DIR) that happens to share a class
    name: ``/data/left/proj/rest/x.edf`` is "rest". The file name itself is
    never treated as a folder.

    Args:
        path: File path to classify.
        left_names: Folder names that mean "left".
        rest_names: Folder names that mean "rest".
        right_names: Folder names that mean "right".

    Returns:
        "left", "rest", "right", or None when no ancestor folder matches.
    """
    # Build the lookup once instead of rebuilding a set for every path segment.
    # Insert in reverse priority so that a name listed under several classes
    # still resolves with the original left > rest > right precedence.
    label_of = {}
    for label, names in (("right", right_names), ("rest", rest_names), ("left", left_names)):
        for name in names:
            label_of[name.casefold()] = label
    folders = os.path.normpath(path).split(os.sep)[:-1]  # drop the file name
    for folder in reversed(folders):
        label = label_of.get(folder.casefold())
        if label is not None:
            return label
    return None

all_edfs = scan_edf_paths(BASE_DIR, INCLUDE_PATTERNS)

# Route each scanned EDF into a per-class bucket. Files under no known class
# folder are ignored, as are "right" files unless they are wanted as negatives.
buckets = {"left": [], "rest": [], "right": []}
for path in all_edfs:
    label = classify_by_parent_folder(path, LEFT_DIRNAMES, REST_DIRNAMES, RIGHT_DIRNAMES)
    if label == "right" and not USE_RIGHT_AS_NEGATIVE:
        continue
    if label in buckets:
        buckets[label].append(path)

# Optional per-class cap (MAX_FILES_PER_CLASS <= 0 means "no limit"); the first
# files in sorted path order are kept.
limit = MAX_FILES_PER_CLASS if MAX_FILES_PER_CLASS > 0 else None
left_files = buckets["left"][:limit]
rest_files = buckets["rest"][:limit]
right_files = buckets["right"][:limit]

print(f"Found EDFs total: {len(all_edfs)}")
print(f"→ left={len(left_files)} | rest={len(rest_files)} | right(neg)={len(right_files)}")
print("Sample left:", left_files[:3])
print("Sample rest:", rest_files[:3])

assert left_files and rest_files, "Need at least one EDF in both left and rest."

In [None]:
#@title 4) EDF loader & preprocessing (MNE)
def load_edf_preprocess(path, band=(BANDPASS_LOWER, BANDPASS_UPPER), notch=NOTCH_FREQ if APPLY_NOTCH else None, target_hz=RESAMPLE_HZ):
    """Load one EDF recording and return band-passed, optionally notched/resampled data.

    Steps:
        1. Read the file with MNE.
        2. Keep the EEG-typed channels, or every channel if none is typed EEG.
        3. Apply a FIR band-pass.
        4. Apply an optional notch.
        5. Optionally resample.

    Args:
        path: Path to the EDF file.
        band: ``(low, high)`` band-pass edges in Hz.
        notch: Notch frequency in Hz. None or a value <= 0 disables it. It is
            also skipped when it is not below the Nyquist frequency of the
            recording, where there is no signal content for it to remove.
        target_hz: Resampling rate in Hz. None keeps the file's native rate.

    Returns:
        ``(data, fs, ch_names)``: a float32 array of shape ``(n_channels,
        n_samples)``, the sampling rate of ``data`` in Hz, and the kept channel
        names.
    """
    raw = mne.io.read_raw_edf(path, preload=True, verbose=False)
    eeg_picks = mne.pick_types(raw.info, meg=False, eeg=True, eog=False, ecg=False, stim=False, misc=False)
    if len(eeg_picks) == 0:
        # The header types no channel as EEG: fall back to using all channels.
        eeg_picks = list(range(len(raw.ch_names)))
    raw.pick(eeg_picks)
    raw.filter(band[0], band[1], fir_design='firwin', verbose=False)
    if notch is not None and 0 < notch < raw.info["sfreq"] / 2:
        raw.notch_filter(freqs=[notch], verbose=False)
    if target_hz is not None:
        raw.resample(target_hz, npad="auto", verbose=False)
    data = raw.get_data().astype(np.float32)  # (C, T)
    # Report the rate the returned data actually has. Previously target_hz=None
    # returned fs=None, which broke window sizing downstream.
    fs = target_hz if target_hz is not None else raw.info["sfreq"]
    return data, fs, raw.ch_names

# quick check
# Smoke-test the loader on the first "left" recording before loading everything.
example_path = left_files[0]
d, fs, ch = load_edf_preprocess(example_path)
print(
    "Loaded example:", os.path.basename(example_path),
    "| shape:", d.shape,
    "| fs:", fs,
    "| EEG ch:", len(ch),
)

In [None]:
#@title 5) Sliding window epoching
def make_windows(data, fs, window_sec=2.0, overlap=0.5):
    """Cut a ``(channels, samples)`` recording into fixed-length, possibly overlapping windows.

    Args:
        data: Array of shape ``(C, T)``.
        fs: Sampling rate of ``data`` in Hz.
        window_sec: Window length in seconds, i.e. ``W = int(window_sec * fs)`` samples.
        overlap: Fraction of each window shared with the next one. The hop is
            ``int(W * (1 - overlap))`` samples, clamped to at least 1.

    Returns:
        float32 array of shape ``(n_windows, C, W)``. Only complete windows are
        returned: a trailing remainder shorter than ``W`` is dropped, and a
        recording shorter than one window yields shape ``(0, C, W)``. The
        previous code emitted one short, ragged window in that case.

    Raises:
        ValueError: if ``window_sec * fs`` is less than one sample.
    """
    C, T = data.shape
    W = int(window_sec * fs)
    if W < 1:
        raise ValueError(f"window_sec={window_sec} at fs={fs} gives an empty window")
    step = max(1, int(W * (1 - overlap)))
    # range() is empty when T < W, so no partial window is ever produced.
    out = [data[:, start:start + W] for start in range(0, T - W + 1, step)]
    return np.stack(out).astype(np.float32) if out else np.zeros((0, C, W), dtype=np.float32)

# Window the example recording to check the (n_windows, C, W) output shape.
example_windows = make_windows(d, fs, WINDOW_SEC, WINDOW_OVERLAP)
print("Windows shape example:", example_windows.shape)

In [None]:
#@title 6) Dataset
class EEGDriveDataset(Dataset):
    """In-memory dataset of labelled EEG windows cut from EDF files.

    Each file is loaded with ``load_edf_preprocess`` and split with
    ``make_windows`` (using WINDOW_SEC / WINDOW_OVERLAP). Labels are:
        - 1 for windows from ``left_paths``;
        - 0 for windows from ``rest_paths``;
        - 0 for windows from ``right_paths``, when given.
    The collected windows are shuffled with the global ``random`` module.

    ``__getitem__`` returns ``(x, y)``. ``x`` is one window as a tensor of
    shape ``(C, W)``; ``y`` is a float32 tensor of shape ``(1,)``.

    Raises:
        RuntimeError: if no windows were produced.
        ValueError: if files yield windows of different ``(C, W)`` shape, e.g.
            recordings with different channel counts. Without this check the
            error would only surface later, when the DataLoader stacks a batch.
    """

    def __init__(self, left_paths, rest_paths, right_paths=None):
        self.items = []
        # Same source order as before (left, rest, right), so the shuffle
        # below sees the same list for the same inputs.
        sources = [(left_paths, 1), (rest_paths, 0)]
        if right_paths:
            sources.append((right_paths, 0))  # "right" files act as extra negatives
        ref_shape = ref_path = None
        for paths, label in sources:
            for p in paths:
                x, fs, _ = load_edf_preprocess(p)
                windows = make_windows(x, fs, WINDOW_SEC, WINDOW_OVERLAP)
                if len(windows):
                    shape = windows.shape[1:]
                    if ref_shape is None:
                        ref_shape, ref_path = shape, p
                    elif shape != ref_shape:
                        raise ValueError(
                            f"{p}: window shape {shape} differs from {ref_shape} "
                            f"(first seen in {ref_path}); all files must yield "
                            "windows with the same (channels, samples) shape."
                        )
                self.items.extend((win, label) for win in windows)
        random.shuffle(self.items)
        if not self.items:
            raise RuntimeError("Empty dataset.")
        self.C, self.W = self.items[0][0].shape
        print(f"Dataset: {len(self.items)} windows | C={self.C} | W={self.W}")

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        x, y = self.items[idx]
        return torch.from_numpy(x), torch.tensor([y], dtype=torch.float32)

# Split by recording FILE rather than by window. Windows overlap
# (WINDOW_OVERLAP) and come from the same continuous recording, so a
# window-level random split puts near-duplicate samples in train and in
# val/test and inflates their scores.
# NOTE(review): this separates files, not subjects. If one subject contributes
# several files per class, a subject-level split would be stricter.
def _split_files(paths, test_frac, val_frac, seed):
    """Deterministically shuffle ``paths`` and return ``(train, val, test)`` lists.

    Sizes follow the same rule as the window-level split:
    ``n_test = int(test_frac * n)``, ``n_val = int(val_frac * (n - n_test))``,
    and train gets the rest.
    """
    shuffled = list(paths)
    random.Random(seed).shuffle(shuffled)
    n_test = int(test_frac * len(shuffled))
    n_val = int(val_frac * (len(shuffled) - n_test))
    return shuffled[n_test + n_val:], shuffled[n_test:n_test + n_val], shuffled[:n_test]


neg_right_files = right_files if USE_RIGHT_AS_NEGATIVE else []
file_splits = [
    _split_files(paths, TEST_SPLIT, VAL_SPLIT, RANDOM_SEED)
    for paths in (left_files, rest_files, neg_right_files)
]
# A file-level split needs at least one left and one rest file in each of
# train / val / test; otherwise fall back to the window-level split.
file_level = all(len(part) > 0 for split in file_splits[:2] for part in split)

if file_level:
    # zip(*file_splits) regroups per-class (train, val, test) into per-split (left, rest, right).
    train_ds, val_ds, test_ds = [
        EEGDriveDataset(left, rest, right or None) for left, rest, right in zip(*file_splits)
    ]
    if len({(ds.C, ds.W) for ds in (train_ds, val_ds, test_ds)}) != 1:
        raise ValueError("train/val/test windows differ in (channels, samples) shape.")
    EEG_CHANNELS, WINDOW_SAMPLES = train_ds.C, train_ds.W
else:
    print("WARNING: too few files per class for a file-level split; falling back to a "
          "window-level split, so val/test scores will be optimistic.")
    full_ds = EEGDriveDataset(left_files, rest_files, right_files if USE_RIGHT_AS_NEGATIVE else None)
    n_total = len(full_ds)
    n_test = int(TEST_SPLIT * n_total)
    n_val = int(VAL_SPLIT * (n_total - n_test))
    n_train = n_total - n_val - n_test
    train_ds, val_ds, test_ds = random_split(full_ds, [n_train, n_val, n_test], generator=torch.Generator().manual_seed(RANDOM_SEED))
    EEG_CHANNELS, WINDOW_SAMPLES = full_ds.C, full_ds.W
print(f"Splits -> train:{len(train_ds)}  val:{len(val_ds)}  test:{len(test_ds)}")

# Pinned host memory only pays off for copies to a CUDA device.
pin = DEVICE == "cuda"
# Keep a final partial batch when the train set is smaller than one batch;
# drop_last=True would otherwise yield zero training steps.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=len(train_ds) > BATCH_SIZE, num_workers=2, pin_memory=pin)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=2, pin_memory=pin)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=2, pin_memory=pin)

In [None]:
#@title 7) Model
import math
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence, then apply dropout.

    ``pe[pos, 2i] = sin(pos / 10000^(2i / d_model))`` and
    ``pe[pos, 2i + 1] = cos(pos / 10000^(2i / d_model))``.

    Args:
        d_model: Size of the input's last (feature) dimension. Odd sizes are
            supported: the last column gets a sine term only. Previously an odd
            size crashed on the cosine assignment.
        dropout: Dropout probability applied after the addition.
        max_len: Longest sequence covered by the precomputed table.
    """

    def __init__(self, d_model, dropout=0.1, max_len=10000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model): broadcasts over the batch
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model) with seq_len <= max_len
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

class TransformerBlock(nn.Module):
    """Post-norm Transformer encoder layer for batch-first sequences.

    The layer has two residual sub-layers, each followed by LayerNorm: first
    multi-head self-attention, then a ReLU feed-forward network. ``dropout`` is
    used inside the attention module, inside the feed-forward network, and on
    each sub-layer output before its residual add.

    Args:
        embed_dim: Size of the input's last (feature) dimension.
        num_heads: Number of attention heads; ``nn.MultiheadAttention``
            requires it to divide ``embed_dim``.
        dim_feedforward: Hidden width of the feed-forward network.
        dropout: Dropout probability.
    """

    def __init__(self, embed_dim, num_heads, dim_feedforward, dropout=0.1):
        super().__init__()
        # Sub-modules are created in the same order (attn, mlp, norms) so
        # seeded parameter initialisation is unchanged.
        self.attn = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, batch_first=True
        )
        feedforward_layers = [
            nn.Linear(embed_dim, dim_feedforward),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, embed_dim),
        ]
        self.mlp = nn.Sequential(*feedforward_layers)
        self.norm1, self.norm2 = nn.LayerNorm(embed_dim), nn.LayerNorm(embed_dim)
        # Stored as a plain float; forward() feeds it to functional dropout.
        self.dropout = dropout

    def forward(self, x):
        attended = self.attn(x, x, x)[0]
        x = self.norm1(x + F.dropout(attended, p=self.dropout, training=self.training))
        fed_forward = self.mlp(x)
        return self.norm2(x + F.dropout(fed_forward, p=self.dropout, training=self.training))

class EEGClassificationModel(nn.Module):
    """Binary EEG window classifier: a Conv1d front-end followed by two Transformer blocks.

    ``forward`` takes ``x`` of shape ``(batch, eeg_channels, samples)`` and
    returns raw logits of shape ``(batch, 1)``; apply a sigmoid for
    probabilities.

    Args:
        eeg_channels: Number of input EEG channels.
        dropout: Dropout probability used in every sub-module.
        num_heads: Preferred number of attention heads (default 4).
            The Transformer width is ``2 * eeg_channels``, and
            ``nn.MultiheadAttention`` needs the width to be divisible by the
            head count. When ``num_heads`` does not divide it, the largest
            smaller divisor is used instead. Previously an odd channel count
            (e.g. 19 or 21 electrodes) crashed with 4 heads.
    """

    def __init__(self, eeg_channels, dropout=0.1, num_heads=4):
        super().__init__()
        # Two 'same'-padded temporal convolutions (kernel 11, padding 5,
        # stride 1): channels C -> C -> 2C, with the sequence length unchanged.
        self.conv = nn.Sequential(
            nn.Conv1d(eeg_channels, eeg_channels, kernel_size=11, stride=1, padding=5, bias=False),
            nn.BatchNorm1d(eeg_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Conv1d(eeg_channels, eeg_channels * 2, kernel_size=11, stride=1, padding=5, bias=False),
            nn.BatchNorm1d(eeg_channels * 2),
            nn.ReLU(inplace=True),
        )
        embed_dim = eeg_channels * 2
        # Largest head count <= num_heads that divides embed_dim (1 always does).
        # For even channel counts this is 4, i.e. the architecture is unchanged.
        heads = next(
            h for h in range(max(1, min(num_heads, embed_dim)), 0, -1) if embed_dim % h == 0
        )
        hidden = max(16, eeg_channels // 2)
        self.posenc = PositionalEncoding(embed_dim, dropout=dropout)
        self.tr1 = TransformerBlock(embed_dim, num_heads=heads, dim_feedforward=hidden, dropout=dropout)
        self.tr2 = TransformerBlock(embed_dim, num_heads=heads, dim_feedforward=hidden, dropout=dropout)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        )

    def forward(self, x):
        x = self.conv(x)           # (B, 2C, T)
        x = x.permute(0, 2, 1)     # (B, T, 2C): batch-first sequence for attention
        x = self.posenc(x)
        x = self.tr1(x); x = self.tr2(x)
        x = x.mean(dim=1)          # average over time -> (B, 2C)
        x = self.mlp(x)            # (B, 1) logits
        return x

In [None]:
#@title 8) Lightning wrapper
class LitEEG(pl.LightningModule):
    """LightningModule that trains ``EEGClassificationModel`` as a binary classifier.

    The loss is BCE-with-logits. For each stage it logs ``<stage>_loss`` and
    ``<stage>_acc``, where accuracy thresholds the sigmoid at 0.5. Stages are
    "train", "val" and "test".

    Args:
        eeg_channels: Number of EEG input channels.
        lr: AdamW learning rate.
        dropout: Dropout probability passed to the model.
    """

    def __init__(self, eeg_channels, lr=1e-3, dropout=0.1):
        super().__init__()
        # Store eeg_channels / lr / dropout with checkpoints so that
        # load_from_checkpoint can rebuild the module.
        self.save_hyperparameters()
        self.model = EEGClassificationModel(eeg_channels=eeg_channels, dropout=dropout)
        self.lr = lr

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        # Decay the LR by 10x at 50% and 75% of the (global) MAX_EPOCHS budget.
        milestones = [int(MAX_EPOCHS * frac) for frac in (0.5, 0.75)]
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    def _common(self, batch, stage):
        inputs, targets = batch
        targets = targets.squeeze(1)          # (B, 1) -> (B,)
        logits = self(inputs).squeeze(1)      # (B, 1) -> (B,)
        loss = F.binary_cross_entropy_with_logits(logits, targets)
        predicted = (torch.sigmoid(logits) > 0.5).int()
        acc = (predicted == targets.int()).float().mean()
        for metric, value in (("loss", loss), ("acc", acc)):
            self.log(f"{stage}_{metric}", value, prog_bar=True, on_epoch=True)
        return loss

    def training_step(self, b, i):
        return self._common(b, "train")

    def validation_step(self, b, i):
        return self._common(b, "val")

    def test_step(self, b, i):
        return self._common(b, "test")

In [None]:
#@title 9) Train
lit_model = LitEEG(eeg_channels=EEG_CHANNELS, lr=LR, dropout=DROPOUT)

# Log to TensorBoard and CSV side by side under LOGS_DIR.
logger_tb = TensorBoardLogger(save_dir=LOGS_DIR, name="tb")
logger_csv = CSVLogger(save_dir=LOGS_DIR, name="csv")

# Keep the single best checkpoint by validation accuracy, record the LR every
# epoch, and stop after 5 epochs without a val_acc improvement.
ckpt = ModelCheckpoint(
    dirpath=CHECKPOINT_DIR,
    filename="best",
    monitor="val_acc",
    mode="max",
    save_top_k=1,
)
lr_mon = LearningRateMonitor(logging_interval="epoch")
es = EarlyStopping(monitor="val_acc", mode="max", patience=5, min_delta=0.0)

trainer = pl.Trainer(
    accelerator="cpu" if DEVICE != "cuda" else "gpu",
    max_epochs=MAX_EPOCHS,
    logger=[logger_tb, logger_csv],
    callbacks=[ckpt, lr_mon, es],
    deterministic=True,
    log_every_n_steps=10,
)

trainer.fit(lit_model, train_dataloaders=train_loader, val_dataloaders=val_loader)
print("Best checkpoint:", ckpt.best_model_path)

In [None]:
#@title 10) Test
# Evaluate the best checkpoint written during training; fall back to the
# in-memory model when no checkpoint was saved.
if ckpt.best_model_path:
    best = LitEEG.load_from_checkpoint(ckpt.best_model_path, eeg_channels=EEG_CHANNELS, lr=LR, dropout=DROPOUT)
else:
    best = lit_model

test_trainer = pl.Trainer(accelerator="gpu" if DEVICE == "cuda" else "cpu", logger=False)
res = test_trainer.test(best, dataloaders=test_loader)
print("Test metrics:", res)

In [None]:
#@title 11) Export
EXPORT_DIR = "exports"; os.makedirs(EXPORT_DIR, exist_ok=True)
export_path = os.path.join(EXPORT_DIR, "eeg_rest_vs_left.pt")
# Script the bare nn.Module in eval mode. A scripted module keeps the training
# flag it had when scripted; exporting in train mode would leave dropout active
# and make BatchNorm use per-batch statistics at inference time.
export_model = best.model.cpu().eval()
scripted = torch.jit.script(export_model)
scripted.save(export_path)
print("Saved:", export_path)