# **I Know What You Will Do: Forecasting Motor Behaviour from EEG Time Series**

**[Brainhack Rome 2025](https://brainhackrome.github.io/) - Project #3**

Matteo De Matola<sup>1</sup><sup>°</sup>, Anna Notaro<sup>2</sup><sup>°</sup>, Emanuele Di Giorgio<sup>3</sup>, Matteo Mancini<sup>4</sup>

<sup>1</sup> Center for Mind/Brain Sciences, University of Trento

<sup>2</sup> Bocconi University Milan, Bocconi AI & Neuroscience Student Association

<sup>3</sup> LUMSA University Rome

<sup>4</sup> Centro Ricerche Enrico Fermi (CREF) Rome

<sup>°</sup> equal contributions

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/matteo-d-m/brainhack-rome-forecasting/blob/main/eeg-forecasting-notebook.ipynb)

In [None]:
colab = True            # put False if you work locally

import json
import os
import re
import pickle
from pathlib import Path
from itertools import product
from collections import namedtuple

import numpy as np
import matplotlib.pyplot as plt


import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import Dataset, DataLoader

if colab:
    from google.colab import drive

## Load & Segment Data

The data are stored as `.npy` files, so we read them into `NumPy` arrays. The recordings are continuous, so we must segment them into windows of interest centered on the `LEDon` event. Each resulting epoch is further split into a _past_ window (the 1000 samples before `LEDon`, i.e. two seconds at the 500 Hz sampling rate used below) and a _future_ window (the 1000 samples, i.e. two seconds, starting at `LEDon`).

In [None]:
def _mu_law(x, mu=255):
    """Apply mu-law companding elementwise: sign(x) * ln(1 + mu*|x|) / ln(1 + mu).

    This maps [-1, 1] onto [-1, 1]. Inputs with |x| > 1 map to outputs with
    magnitude > 1, so the result is only bounded if the input already is.
    """
    return np.sign(x) * np.log(1 + mu * np.abs(x)) / np.log(mu + 1)


def windows(folder, filename, sampling_rate, past_samples, future_samples, quantize=False):
    """
    Extract (past, future) EEG windows around each marker event.

    For every ``.npy`` file in ``folder`` whose name contains ``_S<run>``, the
    marker rows belonging to that run are taken from ``filename``. For each
    such row with a non-null ``StartTime``, two windows are cut along the
    second axis of the loaded array:
    - a window of ``past_samples`` samples ending at the event;
    - a window of ``future_samples`` samples starting at the event.
    Events whose windows would fall outside the recording are skipped, with a
    message.

    Parameters:
    folder (str | Path): Directory containing the EEG ``.npy`` files.
    filename (str | Path): JSON marker file with a ``"columns"`` list and a
        ``"data"`` list of rows. Only the ``Run`` and ``StartTime`` columns
        are used.
    sampling_rate (float): Samples per second. Used to convert ``StartTime``
        (seconds) into a sample index.
    past_samples (int): Length of the window before the event, in samples.
    future_samples (int): Length of the window starting at the event, in samples.
    quantize (bool): If True, apply mu-law companding (mu=255) to both windows.

    Returns:
    all_sequences (list): List of ``(past_window, future_window)`` tuples.
        Files are visited in sorted filename order, so the result is
        reproducible across machines.
    """
    with open(filename, 'r') as f:
        marker_data = json.load(f)
    columns = marker_data["columns"]
    data_rows = marker_data["data"]

    # Resolve the needed column positions once rather than on every row.
    run_col = columns.index("Run")
    start_col = columns.index("StartTime")

    all_sequences = []

    # os.listdir returns entries in arbitrary order; sorting makes the order
    # of the returned pairs (and anything derived from it) reproducible.
    for eeg_filename in sorted(os.listdir(folder)):
        if not eeg_filename.endswith('.npy'):
            continue

        # The run number is encoded in the file name as "_S<digits>".
        m = re.search(r'_S(\d+)', eeg_filename)
        if m is None:
            print(f"Could not extract run number from file name {eeg_filename}.")
            continue
        run = int(m.group(1))

        eeg_data = np.load(os.path.join(folder, eeg_filename))

        for row in data_rows:
            if int(row[run_col]) != run:
                continue
            start_time_sec = row[start_col]
            if start_time_sec is None:  # rows without a start time carry no event
                continue

            # StartTime is the LEDon event; convert seconds to a sample index.
            led_on_sample = int(start_time_sec * sampling_rate)

            # Keep the event only if both windows lie fully inside the recording.
            if led_on_sample - past_samples < 0 or led_on_sample + future_samples > eeg_data.shape[1]:
                print(f"Skipping trial in run {run}: LEDOn sample {led_on_sample} out of bounds.")
                continue

            past_window = eeg_data[:, led_on_sample - past_samples:led_on_sample]
            future_window = eeg_data[:, led_on_sample:led_on_sample + future_samples]
            if quantize:
                past_window = _mu_law(past_window)
                future_window = _mu_law(future_window)
            all_sequences.append((past_window, future_window))

    return all_sequences

In [None]:
# Locate the dataset: on Colab it lives on the mounted Google Drive,
# otherwise edit the placeholder below to point at a local copy.
if not colab:
    data_dir = Path("insert your local directory")
else:
    drive.mount('/content/drive')
    data_dir = Path("/content/drive/MyDrive/Brainhack/Dataset")

# Subject-wise split: 8 patients for training, two for validation and two for testing.
training_patients = list(range(1, 9))
validation_patients = [9, 10]
test_patients = [11, 12]

splits = {"training": training_patients,
          "validation": validation_patients,
          "test": test_patients}

# For each split, concatenate the (past, future) windows of all its patients.
datasets = {}
for dataset_type, patients_list in splits.items():
    all_data = []
    for patient in patients_list:
        print(f"PATIENT P{patient}")
        patient_dir = data_dir / f"P{patient}"
        all_data.extend(windows(folder=patient_dir,
                                filename=patient_dir / f"P{patient}_AllLifts.json",
                                sampling_rate=500,    # Hz
                                past_samples=1000,    # 1000 samples = 2 s at 500 Hz
                                future_samples=1000,  # 1000 samples = 2 s at 500 Hz
                                quantize=True))
    datasets[dataset_type] = all_data

## Check that all data samples have the expected shape

In [None]:
NUMBER_OF_CHANNELS = 14
NUMBER_OF_TIMEPOINTS = 1000  # must match past_samples / future_samples used when windowing

expected_shape = (NUMBER_OF_CHANNELS, NUMBER_OF_TIMEPOINTS)

# A pair is faulty if EITHER window has the wrong shape. Faulty pairs are
# counted so that "Everything OK" is printed only when none were found.
for dataset_name, dataset in datasets.items():
    shape_mismatches = 0
    for sample_number, (past, future) in enumerate(dataset):
        if past.shape != expected_shape or future.shape != expected_shape:
            shape_mismatches += 1
            print(f"[{dataset_name}] Sample {sample_number} has shape {past.shape} in the past and {future.shape} in the future, instead of {expected_shape} in both")
    if shape_mismatches == 0:
        print(f"{dataset_name}: Everything OK")
    else:
        print(f"{dataset_name}: {shape_mismatches} sample(s) with unexpected shape")

## Create PyTorch datasets and check that everything went well

In [None]:
class DatasetFromList(Dataset):
    """PyTorch Dataset wrapping an in-memory list of (past, future) array pairs.

    Each item is returned as an (input, target) pair of float32 tensors. The
    input is the past window and the target is the future window.
    """

    def __init__(self, data, normalise=False):
        """Create the Dataset.

        Parameters:
        data -- a sequence of (past, future) pairs whose elements are array-like,
                e.g. the list of np.ndarray tuples returned by `windows`
                (type: list[tuple[np.ndarray, np.ndarray]])
        normalise -- whether to L2-normalise the input (past) window with
                     torch.nn.functional.normalize; the target is never
                     normalised (type: bool) (default: False)
        """
        self.data = data
        self.normalise = normalise

    def __len__(self):
        """Return the number of (past, future) pairs in the dataset.

        Usage:
        len(dataset_name)
        """
        return len(self.data)

    def __getitem__(self, index):
        """Return the (input, target) pair located at a given index.

        Parameters:
        index -- the index of interest (type: int)

        Returns:
        sample -- the past window as a float32 tensor (type: torch.Tensor)
        label -- the future window as a float32 tensor (type: torch.Tensor)

        Usage:
        sample, label = dataset_name[index]
        """
        pair = self.data[index]
        sample = torch.as_tensor(pair[0], dtype=torch.float32)
        label = torch.as_tensor(pair[1], dtype=torch.float32)
        if self.normalise:
            # F.normalize defaults to p=2, dim=1. For a 2-D (channels, time)
            # window, each channel's time course is scaled to unit L2 norm.
            sample = torch.nn.functional.normalize(sample)
        return sample, label

In [None]:
# Wrap each split's list of (past, future) pairs in a PyTorch Dataset
# (no normalisation), in the order training, validation, test.
training_dataset, validation_dataset, test_dataset = (
    DatasetFromList(data=datasets[split], normalise=False)
    for split in ("training", "validation", "test")
)

# Sanity check: each Dataset must expose exactly as many items as the list it wraps.
for dataset_type, dataset in zip(datasets.keys(), [training_dataset, validation_dataset, test_dataset]):
    verdict = 'as expected' if len(dataset) == len(datasets[dataset_type]) else 'unexpectedly'
    print(f"The PyTorch dataset contains {len(dataset)} samples, {verdict}")

## Create DataLoaders

In [None]:
BATCH_SIZE = 64

# Pinned host memory only speeds up host-to-GPU copies. Without CUDA it brings
# no benefit (and recent PyTorch versions warn about it), so enable it only
# when a GPU is present.
PIN_MEMORY = torch.cuda.is_available()

# Shuffle only the training set: random mini-batches matter for SGD. The
# validation and test losses are averaged over the whole split, so a fixed
# order keeps those numbers reproducible from run to run.
training_loader = DataLoader(dataset=training_dataset,
                             batch_size=BATCH_SIZE,
                             shuffle=True,
                             num_workers=2,
                             pin_memory=PIN_MEMORY)
validation_loader = DataLoader(dataset=validation_dataset,
                               batch_size=BATCH_SIZE,
                               shuffle=False,
                               num_workers=2,
                               pin_memory=PIN_MEMORY)
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=BATCH_SIZE,
                         shuffle=False,
                         num_workers=2,
                         pin_memory=PIN_MEMORY)

## Enter: WaveNet

In [None]:
class CausalConv1d(nn.Conv1d):
    """Conv1d whose output at step t depends only on input steps <= t.

    nn.Conv1d pads both ends of the time axis by (kernel_size - 1) * dilation.
    This class then discards that many trailing output steps, which are the
    ones that looked at the right-hand padding. With the default stride of 1,
    the output therefore has the same length as the input.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, **kwargs):
        left_context = dilation * (kernel_size - 1)
        super().__init__(in_channels, out_channels, kernel_size,
                         padding=left_context, dilation=dilation, **kwargs)

    def forward(self, x):
        y = super().forward(x)
        trim = self.padding[0]
        # The guard is needed because y[..., :-0] would be empty, not a no-op.
        return y[:, :, :-trim] if trim > 0 else y

class WaveNetForecaster(nn.Module):
    """WaveNet-style stack of gated, dilated causal convolutions for multichannel series.

    The network works in three stages:
    1. A 1x1 convolution projects the input to `residual_channels`.
    2. The result passes through `num_layers` gated residual blocks whose
       dilation doubles at every layer (1, 2, 4, ...).
    3. The summed skip outputs are mapped back to `channels_to_forecast`
       channels by two 1x1 convolutions.
    Every convolution is either 1x1 or a CausalConv1d, so output step t only
    sees input steps <= t and the output has the same length as the input.

    Receptive field: 1 + (kernel_size - 1) * (2 ** num_layers - 1) input steps,
    i.e. 256 with the default kernel_size=2 and num_layers=8.
    """

    def __init__(self,
                 channels_to_forecast=14,
                 residual_channels=32,
                 skip_channels=64,
                 kernel_size=2,
                 num_layers=8):
        """Build the WaveNet-based forecaster.

        Args:
            channels_to_forecast: Number of input channels, which is also the
                number of output channels.
            residual_channels: Number of channels in the residual layers.
            skip_channels: Number of channels in the skip connections.
            kernel_size: Size of the dilated causal convolution kernels.
            num_layers: Number of gated dilated causal convolution layers.
        """
        super().__init__()
        self.residual_channels = residual_channels
        self.skip_channels = skip_channels

        self.input_conv = nn.Conv1d(in_channels=channels_to_forecast,
                                    out_channels=residual_channels,
                                    kernel_size=1)

        # One entry per dilated block in each list.
        self.filter_convs = nn.ModuleList()
        self.gate_convs = nn.ModuleList()
        self.residual_convs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()

        for layer_number in range(num_layers):
            # Exponentially increasing dilation rate: 1, 2, 4, ..., 2**(num_layers - 1).
            dilation = 2 ** layer_number
            self.filter_convs.append(CausalConv1d(in_channels=residual_channels,
                                                  out_channels=residual_channels,
                                                  kernel_size=kernel_size,
                                                  dilation=dilation))
            self.gate_convs.append(CausalConv1d(in_channels=residual_channels,
                                                out_channels=residual_channels,
                                                kernel_size=kernel_size,
                                                dilation=dilation))

            self.residual_convs.append(nn.Conv1d(in_channels=residual_channels,
                                                 out_channels=residual_channels,
                                                 kernel_size=1))
            self.skip_convs.append(nn.Conv1d(in_channels=residual_channels,
                                             out_channels=skip_channels,
                                             kernel_size=1))

        self.output_conv1 = nn.Conv1d(in_channels=skip_channels,
                                      out_channels=skip_channels,
                                      kernel_size=1)
        self.output_conv2 = nn.Conv1d(in_channels=skip_channels,
                                      out_channels=channels_to_forecast,
                                      kernel_size=1)

    def forward(self, x):
        """
        Forward pass of the WaveNet model.

        Args:
            x: Input tensor of shape [batch, channels_to_forecast, input_length].
        Returns:
            Output tensor of shape [batch, channels_to_forecast, input_length].
        """
        x = self.input_conv(x)
        skip_sum = None

        for filter_conv, gate_conv, res_conv, skip_conv in zip(self.filter_convs, self.gate_convs, self.residual_convs, self.skip_convs):
            # Gated activation unit: tanh(filter) * sigmoid(gate).
            filt = torch.tanh(filter_conv(x))
            gate = torch.sigmoid(gate_conv(x))
            out = filt * gate

            # Skip outputs from every layer are summed and feed the output head.
            skip_out = skip_conv(out)
            skip_sum = skip_out if skip_sum is None else (skip_sum + skip_out)

            # Residual connection keeps the block's input/output shapes equal.
            x = res_conv(out) + x

        out = torch.relu(skip_sum)
        out = torch.relu(self.output_conv1(out))
        out = self.output_conv2(out)

        return out

### Define functions for training, validation and hyperparameter tuning

In [None]:
def combine(hyperparameters):
    """Build every combination (Cartesian product) of candidate hyperparameter values.

    Parameters:
    hyperparameters -- maps each hyperparameter name to its candidate values (dict[list])

    Returns:
    candidates -- one namedtuple per combination; the namedtuple type is called
                  "combination" and has one field per key, in key order (list[namedtuple])
    """
    Combination = namedtuple("combination", hyperparameters.keys())
    return [Combination._make(values) for values in product(*hyperparameters.values())]

In [None]:
def _run_epoch(model, loader, device, loss_fn, optimizer=None):
    """Run one pass over `loader` and return the mean per-batch loss.

    If `optimizer` is given, the pass trains: train mode, backward pass and
    optimizer step. Otherwise it only evaluates: eval mode, no gradient
    tracking. Returns NaN if `loader` yields no batches.
    """
    training = optimizer is not None
    model.train(training)
    batch_losses = []
    with torch.set_grad_enabled(training):
        for batch in loader:
            past = batch[0].to(device)
            true_future = batch[1].to(device)
            loss = loss_fn(model(past), true_future)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            batch_losses.append(loss.item())
    return float(np.mean(batch_losses)) if batch_losses else float("nan")


def train_and_validate(model, device, combination, epochs, dataloaders,
                       save_path='model_parameters.torch'):
    """Train a forecaster and log its training and validation loss per epoch.

    Parameters:
    model -- a PyTorch model mapping a past window to a forecast with the same
             shape as the future window
    device -- where to run computations (torch.device)
    combination -- hyperparameter values; must expose `learning_rate` and
                   `weight_decay` attributes (namedtuple)
    epochs -- number of passes over the training data (int)
    dataloaders -- (training_loader, validation_loader); each yields
                   (past, future) batches (tuple)
    save_path -- file the final state_dict is written to (str | Path)
                 (default: 'model_parameters.torch')

    Returns:
    (training_loss_log, validation_loss_log) -- mean batch loss per epoch
                                                (list[float], list[float])
    """
    # Forecasting the future window is a regression problem, so compare the
    # raw signals with MSE. Cross-entropy against a softmax over dim=1 would
    # treat the EEG channels as mutually exclusive classes and discard the
    # absolute amplitude that the model is supposed to forecast.
    mse_loss = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(),
                           lr=combination.learning_rate,
                           weight_decay=combination.weight_decay)
    training_loader, validation_loader = dataloaders[0], dataloaders[1]
    training_loss_log = []
    validation_loss_log = []

    for epoch in range(epochs):
        training_loss = _run_epoch(model, training_loader, device, mse_loss, optimizer)
        validation_loss = _run_epoch(model, validation_loader, device, mse_loss)
        training_loss_log.append(training_loss)
        validation_loss_log.append(validation_loss)
        print(f"EPOCH {epoch+1} - TRAINING LOSS: {training_loss: .4f} - VALIDATION LOSS: {validation_loss: .4f}")

    if epochs > 0:
        print("Finished")
    torch.save(model.state_dict(), save_path)
    return training_loss_log, validation_loss_log

In [None]:
def hyperparameter_tuning(combinations, device, dataloaders, epochs=20, model_factory=None):
    """Choose the hyperparameter combination with the lowest validation loss.

    Each combination trains a freshly built model with `train_and_validate`.
    Its score is the lowest per-epoch validation loss it reached. NaN epochs
    (e.g. from a diverged run) are ignored, and a run with no finite loss
    scores +inf. Ties go to the earliest combination.

    Parameters:
    combinations -- hyperparameter combinations to evaluate (list[namedtuple])
    device -- where to run computations (torch.device)
    dataloaders -- (training_loader, validation_loader) (tuple)
    epochs -- training epochs per combination (int) (default: 20)
    model_factory -- zero-argument callable returning a new model
                     (default: WaveNetForecaster)

    Returns:
    winner -- the element of `combinations` with the lowest score

    Raises:
    ValueError -- if `combinations` is empty
    """
    if not combinations:
        raise ValueError("combinations must not be empty")
    if model_factory is None:
        model_factory = WaveNetForecaster

    best_scores = []
    for number, combination in enumerate(combinations, start=1):
        model = model_factory()
        model.to(device)
        print(f"Combination {number} of {len(combinations)}")
        _, validation_log = train_and_validate(model=model,
                                               device=device,
                                               combination=combination,
                                               epochs=epochs,
                                               dataloaders=dataloaders)
        # Reduce the per-epoch log to one number. Comparing whole lists with
        # min() would rank them lexicographically, i.e. by first-epoch loss.
        finite = [v for v in validation_log if not np.isnan(v)]
        best_scores.append(min(finite) if finite else float("inf"))
    print("Model selection finished!")

    winner_index = min(range(len(best_scores)), key=best_scores.__getitem__)
    return combinations[winner_index]

## Perform hyperparameter tuning

In [None]:
# torch.cuda.is_available is a function and must be *called*: the bare
# attribute is always truthy. Device type strings are lowercase ("cpu").
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device is: {device}")

In [None]:
# Call is_available() (the bare function is always truthy); use lowercase "cpu".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device is: {device}")
print(" ")
hyperparameter_values_to_try = dict(learning_rate=[1e-5, 1e-4, 1e-3, 1e-2],
                                    weight_decay=[1e-5, 1e-4, 1e-3])

hyperparameter_combinations = combine(hyperparameter_values_to_try)

print(f"We will try the following {len(hyperparameter_combinations)} hyperparameter combinations:")
print(" ")
for index, combination in enumerate(hyperparameter_combinations):
    print(f"{index+1}:", combination)

# hyperparameter_tuning builds a fresh model for every combination itself and
# has no `model` parameter. Model selection must be scored on held-out data,
# so the second loader is the validation loader, not the training loader.
optimal_hyperparameters = hyperparameter_tuning(combinations=hyperparameter_combinations,
                                                device=device,
                                                dataloaders=(training_loader, validation_loader))

## Train (?) with the best hyperparameter set

In [None]:
# Final run: fix the seed, build a fresh WaveNet on the chosen device and
# train it with the selected hyperparameters, validating on the validation split.
torch.manual_seed(0)

FINAL_EPOCHS = 100
model = WaveNetForecaster().to(device)

outputs = train_and_validate(model=model,
                             device=device,
                             combination=optimal_hyperparameters,
                             epochs=FINAL_EPOCHS,
                             dataloaders=(training_loader, validation_loader))