<a href="https://colab.research.google.com/github/nhnain/eegchallenge/blob/main/Challenge1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#1. Preprocessing

In [1]:
!pip install braindecode
!pip install eegdash

Collecting braindecode
  Downloading braindecode-1.2.0-py3-none-any.whl.metadata (7.1 kB)
Collecting mne>=1.10.0 (from braindecode)
  Downloading mne-1.10.1-py3-none-any.whl.metadata (20 kB)
Collecting mne_bids>=0.16 (from braindecode)
  Downloading mne_bids-0.17.0-py3-none-any.whl.metadata (7.3 kB)
Collecting skorch>=1.2.0 (from braindecode)
  Downloading skorch-1.2.0-py3-none-any.whl.metadata (11 kB)
Collecting torchinfo~=1.8 (from braindecode)
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Collecting wfdb (from braindecode)
  Downloading wfdb-4.3.0-py3-none-any.whl.metadata (3.8 kB)
Collecting linear_attention_transformer (from braindecode)
  Downloading linear_attention_transformer-0.19.1-py3-none-any.whl.metadata (787 bytes)
Collecting docstring_inheritance (from braindecode)
  Downloading docstring_inheritance-2.2.2-py3-none-any.whl.metadata (11 kB)
Collecting axial-positional-embedding (from linear_attention_transformer->braindecode)
  Downloading axial_position

In [9]:
from pathlib import Path
import math
import os
import random
from joblib import Parallel, delayed

import torch
from torch.utils.data import DataLoader
from torch import optim
from torch.nn.functional import l1_loss
from braindecode.preprocessing import preprocess, Preprocessor, create_fixed_length_windows
from braindecode.datasets.base import EEGWindowsDataset, BaseConcatDataset, BaseDataset
from braindecode.models import EEGNeX
from eegdash import EEGChallengeDataset
from typing import Optional
from torch.nn import Module
from torch.optim.lr_scheduler import LRScheduler
from tqdm import tqdm
from eegdash.dataset import EEGChallengeDataset
from eegdash.hbn.windows import annotate_trials_with_target, add_aux_anchors, add_extras_columns, keep_only_recordings_with

In [3]:
# Identify whether a CUDA-enabled GPU is available
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    msg ='CUDA-enabled GPU found. Training should be faster.'
else:
    msg = (
        "No GPU found. Training will be carried out on CPU, which might be "
        "slower.\n\nIf running on Google Colab, you can request a GPU runtime by"
        " clicking\n`Runtime/Change runtime type` in the top bar menu, then "
        "selecting \'T4 GPU\'\nunder \'Hardware accelerator\'."
    )
print(msg)

CUDA-enabled GPU found. Training should be faster.


In [5]:
release_list = ['R1', 'R2', 'R3', 'R4', 'R5', 'R6',
                'R7', 'R8', 'R9', 'R10', 'R11']
train_set = []
valid_set = []
test_set = []

DATA_DIR = Path('data')
DATA_DIR.mkdir(parents=True, exist_ok=True)


In [10]:
for release in tqdm(release_list):
  dataset_ccd = EEGChallengeDataset(
      task = 'contrastChangeDetection',
      release = release,
      cache_dir = DATA_DIR,
      mini = True
  )
  raws = Parallel(n_jobs=-1)(
      delayed(lambda d: d.raw)(d) for d in dataset_ccd.datasets
  )

  EPOCH_LEN_S = 2.0
  SFREQ = 100

  transformation_offline = [
      Preprocessor(
          annotate_trials_with_target,
          target_field = 'rt_from_stimulus', epoch_length =  EPOCH_LEN_S,
          require_stimulus = True,
          require_response = True,
          apply_on_array = True,
      ),
      Preprocessor(add_aux_anchors, apply_on_array=False),
  ]
  preprocess (dataset_ccd, transformation_offline, n_jobs=1)

  ANCHOR = 'stimulus_anchor'
  SHIFT_AFTER_STIM = 0.5
  WINDOW_LEN = 2.0

  dataset = keep_only_recordings_with(ANCHOR, dataset_ccd)

  single_windows = create_windows_from_events(
      dataset,
      mapping={ANCHOR: 0},
      trial_start_offset_samples = int(SHIFT_AFTER_STIM * SFREQ),
      trial_stop_offset_samples = INT((SHIFT_AFTER_STIM + WINDOW_LEN) * SFREQ),
      window_size_samples = int(EPOCH_LEN_S * SFREQ),
      window_stride_samples = SFREQ,
      preload = True,
  )

  single_windows = add_extras_columns(
      single_windows,
      dataset,
      desc = ANCHOR, keys = ("target", "rt_from_stimulus", "rt_from_trialstart",
            "stimulus_onset", "response_onset", "correct", "response_type"))

  meta_information = single_windows.get_metadata()

  subjects = meta_information['subject'].unique()
  sub_rm = ["NDARWV769JM7", "NDARME789TD2", "NDARUA442ZVF", "NDARJP304NK1",
            "NDARTY128YLU", "NDARDW550GU6", "NDARLD243KRE", "NDARUJ292JXV", "NDARBA381JGH"]
  subjects = [s for s in subjects if s not in sub_rm]

  subject_split = single_windows.split('subject')

  if release != 'R5':
    for s in subject_split:
      train_set.append(subject_split[s])
  else:
    for s in subject_split:
      valid_set.append(subject_split[s])


[EEGChallengeDataset] EEG 2025 Competition Data Notice:
-------------------------------------------------------
This object loads the HBN dataset that has been preprocessed for the EEG Challenge:
  - Downsampled from 500Hz to 100Hz
  - Bandpass filtered (0.5–50 Hz)

For full preprocessing details, see:
  https://github.com/eeg2025/downsample-datasets

IMPORTANT: The data accessed via `EEGChallengeDataset` is NOT identical to what you get from `EEGDashDataset` directly.
If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.


  warn(


Reading 0 ... 23499  =      0.000 ...   234.990 secs...


  0%|          | 0/11 [11:19<?, ?it/s]


AttributeError: 'numpy.ndarray' object has no attribute 'filenames'

In [None]:
train_set = BaseConcatDataset(train_set)
valid_set = BaseConcatDataset(valid_set)

print(f'Train set: {len(train_set)}')
print(f'Valid set: {len(valid_set)}')

#Model training

In [None]:
lr = 1e-3
weight_decay = 1e-5
n-epochs = 100
early_stopping_patience = 50

In [None]:
def train_one_epoch(
    dataloader: DataLoader,
    model: Module,
    loss_fn,
    optimizer,
    scheduler: Optional[LRScheduler],
    epoch: int,
    device,
    print_batch_stats: bool = True,):
  model.train()

  total_loss = 0
  sum_sq_err = 0
  n_samples = 0

  progress_bar = tqdm(
      enumerate(dataloader),
      total=len(dataloader),
      disable=not print_batch_stats)
  for batch_idx, batch in progress_bar:
    X, y = batch[0], batch[1]
    X, y = X.to(device).float(), y.to(device).float()

    optimizer.zero_grad(set_to_none=True)
    preds = model(X)
    loss = loss_fn(preds, y)
    loss.backward()
    optimizer.step()

    total_loss += loss.item()

    preds_flat = preds.detach().view(-1)
    y_flat = y.detach().view(-1)
    sum_sq_err += torch.sum((preds_flat - y_flat)**2).item()
    n_samples += y_flat.numel()

    if print_batch_stats:
      running_rmse = (sum_sq_err / max(n_samples, 1)) ** 0.5
      progress_bar.set_description(
          f"Epoch {epoch}, Batch {batch_idx + 1}/{len(dataloader)},"
          f"Loss: {loss.itemt():.6f}, RMSE: {running_rmse:.6f}"
      )
  if scheduler is not None:
    scheduler.step()
  avg_loss = total_loss/len(dataloader)
  rmse = (sum_sq_err/max(n_samples, 1))** 0.5
  return avg_loss, rmse

In [None]:
with torch.inference_mode():
  def valid_model(
      dataloader: DataLoader,
      model: Module,
      loss_fn,
      device,
      print_batch_stats: bool = True):

    model.eval()

    total_loss = 0
    sum_sq_err = 0
    n_batches = len(dataloader)
    n_samples = 0

    iterator = tqdm(
        enumerate(dataloader),
        total=len(dataloader),
        disable=not print_batch_stats)

    for batch_idx, batch in iterator:
      X, y = batch[0], batch[1]
      X, y = X.to(device).float(), y.to(device).float()

      preds = model(X)
      batch_loss = loss_fn(preds, y).item()
      total_loss += batch_loss

      preds_flat = preds.detach().view(-1)
      y_flat = y.detach().view(-1)
      sum_sq_err += torch.sum((preds_flat - y_flat)**2).item()
      n_samples += y_flat.numel()

      if print_batch_stats:
        running_rmse = (sume_sq_err/max(n_samples,1)) ** 0.5
        iterator.set_description(
            f"Val Batch {batch_idx + 1}/{len(dataloader)},"
            f"Loss: {batch_loss:.6f}, RMSE: {running_rmse:.6f}"
        )
    print(f"Val RMSE: {rmse:.6f}, Val Loss: {avg_loss:.6f}\n")
    return avg_loss, rmse

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = n_epochs -1)
loss_fn = torch.nn.MSELoss()

patience = 5
min_delta = 1e-4
best_rmse = float('inf')
epochs_no_improve = 0
best_state, best_epoch = None, None

for epoch in range(1, n_epochs+1):
  print(f'Epoch {epoch}/{n_epochs}: ', end='')

  train_loss, train_rmse = train_one_epoch(
      train_loader, model, loss_fn, optimizer, scheduler, epoch, device
  )
  valid_loss, valid_rmse = valid_model(
      valid_loader, model, loss_fn, device
  )
  print(
      f'Train RMSE: {train_rmse:.6f}, '
      f'Average Train Loss: {train_loss:.6f}, '
      f'Val RMSE: {valid_rmse:.6f}, '
      f'Average Valid Loss: {valid_loss:.6f}'
  )
if best_state is not None:
  model.load_state_dict(best_state)

In [None]:
torch.save(model.state_dict(), 'model.pth')