# **I Know What You Will Do: Forecasting Motor Behaviour from EEG Time Series**

**[Brainhack Rome 2025](https://brainhackrome.github.io/) - Project #3**

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/matteo-d-m/brainhack-rome-forecasting/blob/main/eeg-forecasting-notebook.ipynb)

In [None]:
colab = True            # put False if you work locally

import json
import os
import re
import pickle
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt

if colab:
    !pip install mne
import mne

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

if colab:
    from google.colab import drive

## Load & Segment Data

The data are stored as `.npy` files, so we read them into `NumPy` arrays. They are continuous recordings, so we must segment them into windows of interest. To begin, we use the two seconds before event `LEDon` as the input window and the two seconds after it as the window to forecast.

In [None]:
def windows(folder, filename, sampling_rate, past_samples, future_samples):
    """
    Extract (past, future) EEG windows around marker events.

    For every ``.npy`` file in ``folder`` whose name contains ``_S<number>``,
    that number is taken as the run index and matched against the ``"Run"``
    column of the marker file. For each matching marker row with a non-null
    ``"StartTime"`` (in seconds), the event sample is
    ``int(StartTime * sampling_rate)`` and two windows are cut:
    ``past = eeg[:, event - past_samples : event]`` and
    ``future = eeg[:, event : event + future_samples]``.
    Events whose windows would fall outside the recording are skipped (a
    message is printed).

    Parameters:
    folder (str or Path): Directory containing the EEG ``.npy`` files.
    filename (str or Path): JSON marker file with top-level keys ``"columns"``
        (list of column names) and ``"data"`` (list of rows).
    sampling_rate (float): Sampling rate of the EEG recordings, in Hz.
    past_samples (int): Number of samples in each past window.
    future_samples (int): Number of samples in each future window.

    Returns:
    all_sequences (list): List of ``(past_window, future_window)`` tuples of
        NumPy arrays; axis 1 is the time axis that is sliced.

    Raises:
    ValueError: If the marker file has no ``"Run"`` or ``"StartTime"`` column.
    """
    run_pattern = re.compile(r'_S(\d+)')

    # Load the marker table once and resolve the column positions once,
    # instead of calling ``columns.index`` for every row of every file.
    with open(filename, 'r') as f:
        marker_data = json.load(f)
    columns = marker_data["columns"]
    data_rows = marker_data["data"]
    run_col = columns.index("Run")
    start_col = columns.index("StartTime")

    all_sequences = []

    # os.listdir order is arbitrary; sorting makes the order of the returned
    # windows (and any seeded split built on them) reproducible.
    for eeg_filename in sorted(os.listdir(folder)):
        if not eeg_filename.endswith('.npy'):  # Process only .npy files
            continue

        m = run_pattern.search(eeg_filename)
        if m is None:
            print(f"Could not extract run number from file name {eeg_filename}.")
            continue
        run = int(m.group(1))

        eeg_data = np.load(os.path.join(folder, eeg_filename))
        n_samples = eeg_data.shape[1]

        for row in data_rows:
            if int(row[run_col]) != run:  # marker row belongs to another run
                continue
            start_time_sec = row[start_col]
            if start_time_sec is None:  # Skip rows with no start time
                continue

            # NOTE(review): named after the LEDOn event but computed from the
            # "StartTime" column — confirm this is the intended event time.
            led_on_sample = int(start_time_sec * sampling_rate)

            if led_on_sample - past_samples >= 0 and led_on_sample + future_samples <= n_samples:
                past_window = eeg_data[:, led_on_sample - past_samples : led_on_sample]
                future_window = eeg_data[:, led_on_sample : led_on_sample + future_samples]
                all_sequences.append((past_window, future_window))
            else:
                print(f"Skipping trial in run {run}: LEDOn sample {led_on_sample} out of bounds.")

    return all_sequences  # Return the list of (past, future) pairs

In [None]:
# Pick the dataset root: Google Drive on Colab, a local folder otherwise.
if colab:
    drive.mount('/content/drive')
    data_dir = Path("/content/drive/MyDrive/Brainhack/Dataset")
else:
    data_dir = Path("insert your local directory")


# With sampling_rate=500, 1000 samples correspond to 2 s on each side of the event.
patients_of_interest = [1, 2, 3, 4, 5, 6]
all_sequences = []
for patient in patients_of_interest:
    print(f"PATIENT P{patient}")
    patient_dir = data_dir / f"P{patient}"
    markers_path = patient_dir / f"P{patient}_AllLifts.json"
    all_patient_sequences = windows(
        folder=patient_dir,
        filename=markers_path,
        sampling_rate=500,
        past_samples=1000,
        future_samples=1000,
    )
    all_sequences.extend(all_patient_sequences)

## Check that all data samples have the expected shape

In [None]:
# Expected shape of every (past or future) window: (channels, time points).
# NOTE(review): 14 channels is assumed to match the preprocessed .npy files,
# and 1000 time points must match past_samples/future_samples — confirm.
NUMBER_OF_CHANNELS = 14
NUMBER_OF_TIMEPOINTS = 1000

expected_window_shape = (NUMBER_OF_CHANNELS, NUMBER_OF_TIMEPOINTS)

shape_mismatches = 0
for sample_number, (past, future) in enumerate(all_sequences):
    # Compare each window with the expected shape, not just with each other:
    # two windows of the same wrong shape would otherwise go unnoticed.
    if past.shape != expected_window_shape or future.shape != expected_window_shape:
        print(f"Sample {sample_number} has shape {past.shape} in the past and "
              f"{future.shape} in the future, instead of {expected_window_shape} in both")
        shape_mismatches += 1
if shape_mismatches == 0:
    print("Everything OK")
else:
    print(f"{shape_mismatches} samples have an unexpected shape")

## Create PyTorch dataset and check that everything went well

In [None]:
class DatasetFromList(Dataset):
    """PyTorch ``Dataset`` wrapping an in-memory list of ``(input, target)`` pairs."""

    def __init__(self, data, normalise=False):
        """Store the pairs and the normalisation flag.

        Parameters:
        data -- sequence of ``(input, target)`` array pairs, e.g. the
                ``(past_window, future_window)`` tuples (type: list[tuple[np.array, np.array]])
        normalise -- if True, the input (never the target) is passed through
                     ``torch.nn.functional.normalize`` with its default
                     arguments (type: bool) (default: False)
        """
        self.data = data
        self.normalise = normalise

    def __len__(self):
        """Return how many pairs are stored.

        Usage:
        len(dataset_name)
        """
        return len(self.data)

    def __getitem__(self, index):
        """Return the pair stored at ``index`` as two float32 tensors.

        Parameters:
        index -- position of the pair in ``data`` (type: int)

        Returns:
        (sample, label) -- input and target converted with ``torch.as_tensor``
                           to ``torch.float32`` (type: tuple[torch.Tensor, torch.Tensor])

        Usage:
        sample, label = dataset_name[index]
        """
        raw_input = self.data[index][0]
        raw_target = self.data[index][1]
        sample = torch.as_tensor(raw_input, dtype=torch.float32)
        label = torch.as_tensor(raw_target, dtype=torch.float32)
        if not self.normalise:
            return sample, label
        return torch.nn.functional.normalize(sample), label

In [None]:
# Wrap the (past, future) windows in a PyTorch Dataset; inputs are not normalised.
dataset = DatasetFromList(data=all_sequences, normalise=False)

verdict = "as expected" if len(dataset) == len(all_sequences) else "unexpectedly"
print(f"The PyTorch dataset contains {len(dataset)} samples, {verdict}")

In [None]:
expected_shape = (NUMBER_OF_CHANNELS, NUMBER_OF_TIMEPOINTS)

# Iterate through the Dataset (not the raw list) so that the tensors actually
# produced by DatasetFromList.__getitem__ are what gets checked.
shape_mismatches = 0
for sample_number, (past, future) in enumerate(dataset):
    # A sample is faulty if EITHER window deviates (torch.Size is a tuple
    # subclass, so it compares directly with expected_shape).
    if past.shape != expected_shape or future.shape != expected_shape:
        print(f"Sample {sample_number} has shape {past.shape} in the past and {future.shape} in the future, instead of {expected_shape} in both")
        shape_mismatches += 1
if shape_mismatches == 0:
    print("Everything OK")

In [None]:
# Hold out part of the trials for validation: previously both loaders wrapped
# the full dataset, so the "validation" loss was measured on training data.
VALIDATION_FRACTION = 0.2
n_validation = int(round(len(dataset) * VALIDATION_FRACTION))
n_training = len(dataset) - n_validation
# A dedicated seeded generator makes the split reproducible regardless of the
# global RNG state.
training_dataset, validation_dataset = random_split(
    dataset, [n_training, n_validation], generator=torch.Generator().manual_seed(0)
)

# Roughly four batches per training epoch; never 0, which DataLoader rejects.
BATCH_SIZE = max(1, len(training_dataset) // 4)
# Pinned host memory only helps host-to-GPU copies; skip it without CUDA.
PIN_MEMORY = torch.cuda.is_available()
training_loader = DataLoader(dataset=training_dataset,
                             batch_size=BATCH_SIZE,
                             shuffle=True,
                             num_workers=2,
                             pin_memory=PIN_MEMORY)
validation_loader = DataLoader(dataset=validation_dataset,
                               batch_size=BATCH_SIZE,
                               shuffle=False,  # fixed order -> comparable loss across epochs
                               num_workers=2,
                               pin_memory=PIN_MEMORY)

## Enter: WaveNet

In [None]:
class CausalConv1d(nn.Conv1d):
    """1D convolution whose output at step t only sees inputs at steps <= t.

    Both ends are padded by ``dilation * (kernel_size - 1)``, and the same
    number of trailing output steps is then dropped, so the output has the
    input's length and no look-ahead.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, **kwargs):
        left_context = dilation * (kernel_size - 1)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            padding=left_context,
            dilation=dilation,
            **kwargs,
        )

    def forward(self, x):
        y = super().forward(x)
        trim = self.padding[0]
        # kernel_size == 1 means no padding, hence nothing to trim.
        return y[:, :, :-trim] if trim > 0 else y

class WaveNetForecaster(nn.Module):
    """WaveNet-style stack of gated, dilated causal convolutions.

    Maps a tensor of shape [batch, in_channels, length] to one of the same
    shape; every layer is causal, so output step t depends only on input
    steps <= t.
    """

    def __init__(self, in_channels=14, residual_channels=32, skip_channels=64,
                 kernel_size=2, num_layers=8):
        """
        Build the network.
        Args:
            in_channels: Number of channels of both the input and the output.
            residual_channels: Width of the residual stream between layers.
            skip_channels: Width of the skip-connection stream fed to the head.
            kernel_size: Kernel size of each dilated causal convolution.
            num_layers: Number of gated layers; layer i uses dilation 2**i.
        """
        super().__init__()
        self.residual_channels = residual_channels
        self.skip_channels = skip_channels

        # 1x1 projection from the input channels into the residual stream.
        self.input_conv = nn.Conv1d(in_channels, residual_channels, kernel_size=1)

        self.filter_convs = nn.ModuleList()
        self.gate_convs = nn.ModuleList()
        self.residual_convs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()

        # Modules are created in a fixed per-layer order (filter, gate,
        # residual, skip) so seeded parameter initialisation is reproducible.
        for dilation in (2 ** layer for layer in range(num_layers)):
            self.filter_convs.append(
                CausalConv1d(residual_channels, residual_channels, kernel_size, dilation=dilation))
            self.gate_convs.append(
                CausalConv1d(residual_channels, residual_channels, kernel_size, dilation=dilation))
            self.residual_convs.append(
                nn.Conv1d(residual_channels, residual_channels, kernel_size=1))
            self.skip_convs.append(
                nn.Conv1d(residual_channels, skip_channels, kernel_size=1))

        # Two-layer 1x1 head mapping the summed skips back to in_channels.
        self.output_conv1 = nn.Conv1d(skip_channels, skip_channels, kernel_size=1)
        self.output_conv2 = nn.Conv1d(skip_channels, in_channels, kernel_size=1)

    def forward(self, x):
        """
        Run the network.
        Args:
            x: Tensor of shape [batch, in_channels, input_length].
        Returns:
            Tensor of shape [batch, in_channels, input_length].
        """
        hidden = self.input_conv(x)
        skips = None

        layers = zip(self.filter_convs, self.gate_convs, self.residual_convs, self.skip_convs)
        for filter_conv, gate_conv, res_conv, skip_conv in layers:
            # Gated activation unit: tanh(filter) * sigmoid(gate).
            activation = torch.tanh(filter_conv(hidden)) * torch.sigmoid(gate_conv(hidden))

            contribution = skip_conv(activation)
            if skips is None:
                skips = contribution
            else:
                skips = skips + contribution

            # Residual connection: the stream keeps its shape across layers.
            hidden = hidden + res_conv(activation)

        head = torch.relu(self.output_conv1(torch.relu(skips)))
        return self.output_conv2(head)


In [None]:
from torch import optim

def train_and_validate(model, device, combination, epochs, dataloaders,
                       save_path='model_parameters.torch'):
    """Train ``model`` with MSE loss and evaluate it after every epoch.

    Parameters:
    model -- a PyTorch model instance; it must already be on ``device``
             (only the batches are moved here)
    device -- where to run computations (torch device object)
    combination -- hyperparameter values (namedtuple or None). Optional fields
                   ``lr`` and ``weight_decay`` configure Adam; when a field is
                   missing (or ``combination`` is None) the defaults 1e-2 and
                   1e-5 are used
    epochs -- number of passes over the training data (int)
    dataloaders -- ``(training_loader, validation_loader)``, each yielding
                   ``(past, future)`` batches (tuple)
    save_path -- file the final ``state_dict`` is written to with
                 ``torch.save`` (default: 'model_parameters.torch')

    Returns:
    (training_loss_log, validation_loss_log) -- per-epoch mean of the batch
    losses (lists of floats); an epoch with no validation batches logs NaN.
    """
    training_loader, validation_loader = dataloaders[0], dataloaders[1]

    loss_fn = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(),
                           lr=getattr(combination, 'lr', 1e-2),
                           weight_decay=getattr(combination, 'weight_decay', 1e-5))
    training_loss_log = []
    validation_loss_log = []

    for epoch in range(epochs):
        model.train()
        training_losses = []
        for batch in training_loader:
            past = batch[0].to(device)
            true_future = batch[1].to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(past), true_future)
            loss.backward()
            optimizer.step()
            training_losses.append(loss.item())

        model.eval()
        validation_losses = []
        with torch.no_grad():
            for batch in validation_loader:
                past = batch[0].to(device)
                true_future = batch[1].to(device)
                validation_losses.append(loss_fn(model(past), true_future).item())

        # Explicit NaN for an empty loader instead of np.mean([]) warnings.
        training_loss = float(np.mean(training_losses)) if training_losses else float('nan')
        validation_loss = float(np.mean(validation_losses)) if validation_losses else float('nan')
        training_loss_log.append(training_loss)
        validation_loss_log.append(validation_loss)
        print(f"EPOCH {epoch+1} - TRAINING LOSS: {training_loss: .2f} - VALIDATION LOSS: {validation_loss: .2f}")
        if epoch == epochs - 1:
            print("Finished")

    torch.save(model.state_dict(), save_path)
    return training_loss_log, validation_loss_log

In [None]:
# Prefer the GPU when one is usable. ``is_available`` must be CALLED: the bare
# function object is always truthy. Device type strings are lowercase ("cpu").
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device is: {device}")

In [None]:
# Seed the global RNG (used for weight initialisation and loader shuffling).
torch.manual_seed(0)

model = WaveNetForecaster().to(device)

losses = train_and_validate(
    model=model,
    device=device,
    combination=None,
    epochs=500,
    dataloaders=(training_loader, validation_loader),
)