# Training and evaluating decoding models with PyTorch

Acknowledgments: This tutorial draws heavily on the encoding tutorial by Samuel A. Nastase.

In [None]:
# Only run this cell in Colab: it installs PyTorch from the CUDA 12.4 wheel
# index, followed by the other third-party packages listed below.
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
!pip install mne mne_bids himalaya scikit-learn pandas matplotlib nilearn gensim

In [None]:
import mne
import h5py
import torch
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch.nn as nn
from torch.nn import functional as F
from gensim.models import KeyedVectors
import re

from tqdm import tqdm

from mne_bids import BIDSPath

from himalaya.backend import set_backend
from sklearn.decomposition import PCA

from decoding_utils import run_training_over_lags

In [None]:
# Choose the torch device used by the rest of the notebook. When a GPU is
# present, himalaya is switched to its "torch_cuda" backend as well.
if not torch.cuda.is_available():
    device = torch.device("cpu")
    print("Using CPU")
else:
    set_backend("torch_cuda")
    device = torch.device("cuda")
    print("Using cuda!")

## Loading features

In [None]:
bids_root = ""  # if using a local dataset, set this variable accordingly

# Download the embedding, if required
embedding_path = f"{bids_root}stimuli/gpt2-xl/features.hdf5"
if not len(bids_root):
    !wget -nc https://s3.amazonaws.com/openneuro.org/ds005574/$embedding_path
    embedding_path = "features.hdf5"
embedding_path = "features.hdf5"

print(f"Using embedding file path: {embedding_path}")

In [None]:
# Language model and layer whose features are loaded.
modelname = 'gpt2-xl'
layer = 24

# Read the complete dataset stored for the chosen layer into memory.
with h5py.File(embedding_path, "r") as f:
    layer_key = f"layer-{layer}"
    contextual_embeddings = f[layer_key][...]
print(f"LLM embedding matrix has shape: {contextual_embeddings.shape}")

In [None]:
# Download the transcript, if required
transcript_path = f"{bids_root}stimuli/gpt2-xl/transcript.tsv"
if not len(bids_root):
    !wget -nc https://s3.amazonaws.com/openneuro.org/ds005574/$transcript_path
    transcript_path = "transcript.tsv"

# Load transcript
df_contextual = pd.read_csv(transcript_path, sep="\t", index_col=0)
if "rank" in df_contextual.columns:
    model_acc = (df_contextual["rank"] == 0).mean()
    print(f"Model accuracy: {model_acc*100:.3f}%")

df_contextual.head()

When we extracted features, some words were split into several tokens. Since we only have start and end times for whole words, we align the features from tokens to words for the decoding models. Here, we simply average the token features that belong to the same word. The resulting features are a numpy array of shape (number of words, feature dimensions). We then reduce these word-level embeddings to 50 dimensions with PCA; the reduced embeddings are the decoding targets used below.

In [None]:
# Average the token-level features of every word so that each word has exactly
# one embedding.
# NOTE(review): the dataframe index is used as a row position into
# contextual_embeddings, which assumes the transcript rows are indexed 0..N-1
# in the same order as the embedding rows - confirm.
aligned_embeddings = np.stack([
    contextual_embeddings[group.index.to_numpy()].mean(0)  # average features
    for _, group in df_contextual.groupby("word_idx")  # group by word index
])

# Reduce the word embeddings to 50 principal components. The data is passed as
# a float64 array rather than through .tolist(): the values and dtype seen by
# scikit-learn are the same, but no nested Python list with one float object
# per element has to be created.
pca = PCA(n_components=50, svd_solver='auto')
pca_embeddings = pca.fit_transform(np.asarray(aligned_embeddings, dtype=np.float64))
print(f"LLM embeddings matrix has shape: {aligned_embeddings.shape}")
print(f"PCA embeddings matrix has shape: {pca_embeddings.shape}")

We will also construct a dataframe containing words with their start and end timestamps.

In [None]:
# One row per word: the word text and start time come from the first row of
# each word's group, the end time from its last row.
word_aggregations = {"word": "first", "start": "first", "end": "last"}
df_word = df_contextual.groupby("word_idx").agg(word_aggregations)
df_word.head()

## Loading brain data

In [None]:
# Path of subject 03's high-gamma iEEG recording of the podcast task inside the
# ecogprep derivatives of the dataset.
file_path = BIDSPath(root=f"{bids_root}derivatives/ecogprep",
                    subject="03", task="podcast", datatype="ieeg", description="highgamma",
                    suffix="ieeg", extension=".fif")
print(f"File path within the dataset: {file_path}")

# You only need to run this if using Colab (i.e. if you did not set bids_root to a local directory)
if not len(bids_root):
    !wget -nc https://s3.amazonaws.com/openneuro.org/ds005574/$file_path
    # From here on the recording is read by its base name, i.e. relative to the
    # working directory.
    file_path = file_path.basename

raw = mne.io.read_raw_fif(file_path, verbose=False)
# Restrict the recording to the channels whose names match the pattern.
# NOTE(review): in a regular expression "[AB]*" also matches zero characters,
# so the pattern does not require an "A" or "B" after "LG" - confirm that
# selecting every "LG..." channel is intended.
picks = mne.pick_channels_regexp(raw.ch_names, "LG[AB]*")
raw = raw.pick(picks)
raw

We'll set up GloVe embeddings as well for comparison to GPT-2.

In [None]:
if not os.path.exists('glove'):
    !mkdir -p glove
    !mkdir -p glove

if not os.path.exists('glove/glove.6B.zip'):
    # Wikipedia 2014 + Gigaword 5 (6B tokens, 50d, 100d, 200d, & 300d vectors)
    !wget https://nlp.stanford.edu/data/glove.6B.zip -P glove/
    
    # Twitter (27B tokens, 25d, 50d, 100d, & 200d vectors)
    # wget https://nlp.stanford.edu/data/glove.twitter.27B.zip -P glove/
    
    # Extract the downloaded zip file
    !unzip glove/glove.6B.zip -d glove/

glove_file = 'glove/glove.6B.50d.txt'
glove_vectors = KeyedVectors.load_word2vec_format(glove_file, binary=False, no_header=True)

### Decoding Straight from Neural Data


Here we decode straight from the neural data using a convolutional network per [https://www.nature.com/articles/s41593-022-01026-4]. To adapt this to your custom model for decoding you need two primary pieces.

1. Define a decoding model and a constructor function.
2. Define a data preprocessing function for preparing the data for your decoding model.

Below I demonstrate how to do this with the raw neural data. In the Decoding Foundation Model section I demonstrate how to adapt this to a model's output which will be useful if you're trying to decode from something like BrainBert output.

First we need to define a decoding model. This model is adapted from the 2021 paper linked above and is essentially a convolutional model. Below that I define an ensemble model which averages the outputs of several of these models which improves performance and was used in the original paper.

In [None]:
class PitomModel(nn.Module):
    """Convolutional decoder mapping windowed electrode activity to an embedding.

    The network is a stack of 1-D convolutions over time with one max-pooling
    step, followed by a global max pool over the remaining time steps and a
    linear head whose output is layer-normalised and passed through tanh.
    """

    def __init__(
        self,
        input_channels,
        output_dim,
        conv_filters=128,
        reg=0.35,
        reg_head=0,
        dropout=0.2
    ):
        """
        Build the PITOM decoding model.

        Args:
            input_channels: Number of electrodes in data (int)
            output_dim: Dimension of output vector (int)
            conv_filters: Number of convolutional filters (default: 128)
            reg: L2 regularization factor for convolutional layers (default: 0.35).
                Only stored on the instance; no layer built here applies it.
            reg_head: L2 regularization factor for dense head (default: 0).
                Only stored on the instance, like ``reg``.
            dropout: Dropout rate (default: 0.2)
        """
        super().__init__()

        # Keep the hyper-parameters around for inspection.
        self.conv_filters = conv_filters
        self.reg = reg
        self.reg_head = reg_head
        self.dropout = dropout
        self.output_dim = output_dim

        # Architecture description: (filters, kernel_size) for a conv block,
        # ('max', pool_size) for a max-pooling step.
        self.desc = [(conv_filters, 3), ('max', 2), (conv_filters, 2)]

        self.layers = nn.ModuleList()
        in_ch = input_channels
        for width, size in self.desc:
            if width == 'max':
                self.layers.append(
                    nn.MaxPool1d(kernel_size=size, stride=size, padding=size // 2)
                )
                continue

            # Conv block: unpadded convolution ('valid' in Keras) without bias,
            # then ReLU, batch norm and dropout.
            self.layers.extend([
                nn.Conv1d(
                    in_channels=in_ch,
                    out_channels=width,
                    kernel_size=size,
                    stride=1,
                    padding=0,
                    bias=False
                ),
                nn.ReLU(),
                nn.BatchNorm1d(width),
                nn.Dropout(dropout),
            ])
            in_ch = width

        # Final convolution block. This is an ordinary Conv1d standing in for a
        # locally connected layer, which PyTorch does not provide as a standard
        # module; it would need customising for exact equivalence.
        self.final_conv = nn.Conv1d(
            in_channels=conv_filters,
            out_channels=conv_filters,
            kernel_size=2,
            stride=1,
            padding=0,  # 'valid' in Keras
            bias=True
        )
        self.final_bn = nn.BatchNorm1d(conv_filters)
        self.final_act = nn.ReLU()

        # Output head
        self.dense = nn.Linear(conv_filters, output_dim)
        self.layer_norm = nn.LayerNorm(output_dim)
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Map a batch shaped (batch, input_channels, time) to (batch, output_dim)."""
        # Convolutional trunk
        for block in self.layers:
            x = block(x)

        # Final conv block: convolution -> batch norm -> ReLU
        x = self.final_act(self.final_bn(self.final_conv(x)))

        # Global max pooling over the time axis
        x = F.adaptive_max_pool1d(x, 1).squeeze(-1)

        # Output head: linear -> layer norm -> tanh
        return self.tanh(self.layer_norm(self.dense(x)))

class EnsemblePitomModel(nn.Module):
    """Ensemble of separately constructed PitomModel members whose outputs are averaged."""

    def __init__(
        self,
        num_models: int,
        input_channels,
        output_dim: int,
        conv_filters=128,
        reg=0.35,
        reg_head=0,
        dropout=0.2
    ):
        """
        Build an ensemble of PITOM decoding models.

        Args:
            num_models: The number of models to include in the ensemble. The outputs will be averaged at the end.
            input_channels: Number of electrodes in data (int)
            output_dim: Dimensionality of output (int)
            conv_filters: Number of convolutional filters (default: 128)
            reg: L2 regularization factor for convolutional layers (default: 0.35)
            reg_head: L2 regularization factor for dense head (default: 0)
            dropout: Dropout rate (default: 0.2)
        """
        super().__init__()

        # Every member is built with the same hyper-parameters.
        member_kwargs = dict(
            conv_filters=conv_filters,
            reg=reg,
            reg_head=reg_head,
            dropout=dropout,
        )
        self.models = nn.ModuleList(
            [PitomModel(input_channels, output_dim, **member_kwargs) for _ in range(num_models)]
        )

    def forward(self, x):
        """Run every member on ``x`` and return the mean of their outputs."""
        member_outputs = [member(x) for member in self.models]
        return torch.stack(member_outputs).mean(0)

Our training function needs a constructor function which takes in a dictionary of model_params which you provide when you call the training function and returns a model object. This can be very simple as you see here.

The reason I'm using a constructor here is that it allows us to extend our existing code to any decoding model we desire without having to change our code for every use case. Now the code for a BrainBert decoder can exist separate from the neural decoder and any other model we choose to use, but they all still use the same backbone.

In [None]:
def pitom_model(model_params):
    """Construct a single PitomModel from the ``model_params`` dictionary.

    Reads the keys 'input_channels', 'embedding_dim' (used as the output
    dimension), 'conv_filters', 'reg', 'reg_head' and 'dropout'.
    """
    input_channels = model_params['input_channels']
    output_dim = model_params['embedding_dim']
    conv_filters = model_params['conv_filters']
    reg = model_params['reg']
    reg_head = model_params['reg_head']
    dropout = model_params['dropout']
    return PitomModel(
        input_channels=input_channels,
        output_dim=output_dim,
        conv_filters=conv_filters,
        reg=reg,
        reg_head=reg_head,
        dropout=dropout,
    )

def ensemble_pitom_model(model_params):
    """Construct an EnsemblePitomModel from the ``model_params`` dictionary.

    Reads the keys 'num_models', 'input_channels', 'embedding_dim' (used as
    the output dimension), 'conv_filters', 'reg', 'reg_head' and 'dropout'.
    """
    num_models = model_params['num_models']
    input_channels = model_params['input_channels']
    output_dim = model_params['embedding_dim']
    conv_filters = model_params['conv_filters']
    reg = model_params['reg']
    reg_head = model_params['reg_head']
    dropout = model_params['dropout']
    return EnsemblePitomModel(
        num_models=num_models,
        input_channels=input_channels,
        output_dim=output_dim,
        conv_filters=conv_filters,
        reg=reg,
        reg_head=reg_head,
        dropout=dropout,
    )

Next we need a preprocessing function to transform the neural data as we would like for this particular model. Here I want to average over windows of 32 data points to essentially resample the data from 512 hz to 16 hz. 

The preprocessing function will always be passed data in the shape (num_words, num_electrodes, num_timesteps). num_words is defined by the df_word dataframe passed into data_params below, num_electrodes is the number of electrodes in the mne.Raw object passed into data_params, and num_timesteps is defined by the window_width * sampling_frequency of your mne.Raw. window_width is passed into data_params below in seconds. The preprocessing function requires that the returned data has shape [num_words, ...] where '...' is any arbitrary shape you need for your purposes. The only requirement is that data.shape[0] == num_words.

Like the model constructor, we use this preprocessing function as an input to our training code because it doesn't require us to change the backbone code for every use case. Now all of our use cases can define their custom code in a logical place without having to change our backbone code and potentially lead to bloated functions which are prone to bugs.

In [None]:
def preprocess_neural_data(data, bin_size=32):
    """Downsample the time axis by averaging non-overlapping bins of samples.

    Args:
        data: Array of shape (num_words, num_electrodes, num_timesteps).
        bin_size: Number of consecutive samples averaged into one output
            time step (default: 32).

    Returns:
        Array of shape (num_words, num_electrodes, num_timesteps // bin_size).

    Raises:
        ValueError: If ``bin_size`` is not positive, or if the number of
            samples per electrode is not a multiple of ``bin_size``.
    """
    if bin_size <= 0:
        raise ValueError(f"bin_size must be positive, got {bin_size}")

    # Everything after the first two axes is folded into bins by the reshape
    # below, so that is the quantity which has to be divisible.
    samples_per_channel = int(np.prod(data.shape[2:]))
    if samples_per_channel % bin_size:
        raise ValueError(
            f"Cannot average {samples_per_channel} samples per electrode in bins of "
            f"{bin_size}: the number of samples must be a multiple of bin_size"
        )

    return data.reshape(data.shape[0], data.shape[1], -1, bin_size).mean(-1)

With that we can now train a decoding model. Results will be written to your specified output directory. Currently we only write out the model's weighted ROC-AUC scores for the specified lags, but you can expand on the existing code to write out new metrics as needed.

In [None]:
# Lags to evaluate: -900 to 900 in steps of 100 (labelled as ms in the plots below).
lags = np.arange(-900, 1000, 100)

# Train the 10-member ensemble to decode the PCA-reduced GPT-2 embeddings.
weighted_roc_means = run_training_over_lags(
    lags,
    ensemble_pitom_model,
    model_params=dict(
        conv_filters=128,
        reg=0.35,
        reg_head=0,
        dropout=0.2,
        num_models=10,
        input_channels=len(raw.ch_names),
        embedding_dim=pca_embeddings.shape[1],
    ),
    training_params=dict(
        batch_size=32,
        epochs=100,
        learning_rate=0.001,
        weight_decay=0.0001,
        early_stopping_patience=10,
        n_folds=5,
    ),
    data_params=dict(
        raw=raw,
        df_word=df_word,
        word_embeddings=pca_embeddings,
        window_width=0.625,
        preprocessing_fn=preprocess_neural_data,
    ),
    trial_name='ensemble_model_10',
)

One other thing you may want to do is decode into different sets of embeddings. Here we create arbitrary embeddings with no semantic information encoded. All we have to do to train over them is pass the embeddings of shape [num_words, embedding_dim] into our data_params. num_words must be the same as the number of words in the passed df_word.

In [None]:
# Generate arbitrary embeddings: every distinct word is assigned one random
# 50-dimensional vector that is shared by all of its occurrences.
words = df_word.word.tolist()
unique_words = list(set(words))

# Map each word to the list of positions at which it occurs.
word_to_idx = {}
for position, token in enumerate(words):
    word_to_idx.setdefault(token, []).append(position)

arbitrary_embeddings_per_word = np.random.uniform(low=-1.0, high=1.0, size=(len(unique_words), 50))
arbitrary_embeddings = np.zeros((len(words), 50))
for vector, token in zip(arbitrary_embeddings_per_word, unique_words):
    # Write the word's vector to every position where the word occurs.
    arbitrary_embeddings[word_to_idx[token]] = vector

lags = np.arange(-900, 1000, 100)
weighted_roc_means_arbitrary = run_training_over_lags(
    lags,
    ensemble_pitom_model,
    model_params=dict(
        conv_filters=128,
        reg=0.35,
        reg_head=0,
        dropout=0.2,
        num_models=10,
        input_channels=len(raw.ch_names),
        embedding_dim=arbitrary_embeddings.shape[1],
    ),
    training_params=dict(
        batch_size=32,
        epochs=100,
        learning_rate=0.001,
        weight_decay=0.0001,
        early_stopping_patience=10,
        n_folds=5,
    ),
    data_params=dict(
        raw=raw,
        df_word=df_word,
        word_embeddings=arbitrary_embeddings,
        window_width=0.625,
        preprocessing_fn=preprocess_neural_data,
    ),
    trial_name='ensemble_model_arbitrary_embeddings',
)

We can do the same with GloVe embeddings here.

In [None]:
def preprocess_word(word):
    """Lowercase ``word`` and drop every character that is neither a word character nor whitespace."""
    return re.sub(r'[^\w\s]', '', word.lower())

# Normalise every transcript word before looking it up in the GloVe vocabulary.
words = df_word.word.tolist()
preprocessed_words = [preprocess_word(word) for word in words]

# Flag, per word, whether a GloVe vector exists, and record it on the dataframe.
in_glove = [token in glove_vectors for token in preprocessed_words]
df_word['in_glove'] = in_glove

# Stack the vectors of the words that were found, in transcript order.
glove_embeddings = np.vstack([
    glove_vectors[token]
    for token, found in zip(preprocessed_words, in_glove)
    if found
])

lags = np.arange(-900, 1000, 100)
weighted_roc_means_glove = run_training_over_lags(
    lags,
    ensemble_pitom_model,
    model_params=dict(
        conv_filters=128,
        reg=0.35,
        reg_head=0,
        dropout=0.2,
        num_models=10,
        input_channels=len(raw.ch_names),
        embedding_dim=glove_embeddings.shape[1],
    ),
    training_params=dict(
        batch_size=32,
        epochs=100,
        learning_rate=0.001,
        weight_decay=0.0001,
        early_stopping_patience=10,
        n_folds=5,
    ),
    data_params=dict(
        raw=raw,
        # Only the words that have a GloVe vector, matching glove_embeddings row for row.
        df_word=df_word[df_word['in_glove']],
        word_embeddings=glove_embeddings,
        window_width=0.625,
        preprocessing_fn=preprocess_neural_data,
    ),
    trial_name='ensemble_model_glove',
)

In [None]:
# Load the per-lag scores stored for each of the three trials.
gpt_data = pd.read_csv("results/ensemble_model_10_roc_means.csv")
arbitrary_data = pd.read_csv("results/ensemble_model_arbitrary_embeddings_roc_means.csv")
glove_data = pd.read_csv("results/ensemble_model_glove_roc_means.csv")

# One curve per embedding type.
curves = (
    (gpt_data, 'GPT-2'),
    (arbitrary_data, 'Arbitrary Embeddings'),
    (glove_data, 'GloVe'),
)
for frame, curve_label in curves:
    plt.plot(frame.lags, frame.rocs, label=curve_label)

# Vertical marker at lag 0.
plt.axvline(0, color='red', alpha=0.5)
plt.xlabel('Lags (ms)')
plt.ylabel('AUC-ROC')
plt.title('AUC-ROC as a function of Lags')
plt.legend()
plt.show()

### Decoding from Foundation Model

Here we can go through the same process but now we want to decode from the outputs of our foundation model. The model I've supplied here is essentially randomly initialized so performance will not be good.

First we'll load data for the subject we pretrain over.

In [None]:
file_path = BIDSPath(root=f"{bids_root}derivatives/ecogprep",
                    subject="09", task="podcast", datatype="ieeg", description="highgamma",
                    suffix="ieeg", extension=".fif")
print(f"File path within the dataset: {file_path}")

# You only need to run this if using Colab (i.e. if you did not set bids_root to a local directory)
if not len(bids_root):
    !wget -nc https://s3.amazonaws.com/openneuro.org/ds005574/$file_path
    file_path = file_path.basename

raw = mne.io.read_raw_fif(file_path, verbose=False)
# picks = mne.pick_channels_regexp(raw.ch_names, "LG[AB]*")
grid_ch_names = []
for i in range(64):
    channel = "G" + str(i + 1)
    if np.isin(channel, raw.info.ch_names):
        grid_ch_names.append(channel)


raw = raw.pick(grid_ch_names,)
raw

In [None]:
import sys

from foundation_model.models_mae import MaskedAutoencoderViT
from foundation_model.config import create_video_mae_experiment_config_from_file
from foundation_model.utils import create_model

Create the model.

In [None]:
# First load the model and set it in eval mode.
model_dir = "foundation_model/models"
ecog_config = create_video_mae_experiment_config_from_file(os.path.join(model_dir, "experiment_config.ini"))

model = create_model(ecog_config)
# map_location remaps the stored tensors onto the device selected earlier, so a
# checkpoint that was saved on a GPU can still be loaded on a CPU-only machine.
state_dict = torch.load(
    os.path.join(model_dir, "model.pth"),
    map_location=device,
    weights_only=True,
)
model.load_state_dict(state_dict)
model = model.to(device)

# Inference only from here on; as the last expression this also displays the model.
model.eval()

Now this is the important part for decoding. As above we have the same steps to accomplish:

1. Define a decoding model and a constructor function.
2. Define a data preprocessing function for preparing the data for your decoding model.

First we'll define a decoding model. Because the model currently only outputs a 16 dimensional embedding of the neural data, we'll instead just use an MLP to go from our neural embedding to the 50 dimensional word embedding.

In [None]:
class MLP(nn.Module):
    """Fully connected network with optional LayerNorm after every linear layer."""

    def __init__(self, layer_sizes, activation=F.relu, dropout_rate=0.2, use_layer_norm=True):
        """
        Initialize a Multi-Layer Perceptron with configurable architecture and LayerNorm.

        Args:
            layer_sizes (list): List of integers specifying the size of each layer.
                               First element is input size, last element is output size.
            activation (function): Activation function to use between layers (default: ReLU).
            dropout_rate (float): Dropout probability for regularization (default: 0.2).
            use_layer_norm (bool): Whether to apply LayerNorm after every linear layer,
                               including the output layer (default: True).

        Raises:
            ValueError: If fewer than two layer sizes are given.
        """
        super().__init__()

        if len(layer_sizes) < 2:
            raise ValueError("layer_sizes must contain at least input and output sizes")

        # (in_features, out_features) of each consecutive pair of sizes.
        shapes = [
            (layer_sizes[k], layer_sizes[k + 1])
            for k in range(len(layer_sizes) - 1)
        ]

        self.layers = nn.ModuleList([nn.Linear(n_in, n_out) for n_in, n_out in shapes])
        # One LayerNorm per linear layer; the output layer gets one as well.
        self.layer_norms = (
            nn.ModuleList([nn.LayerNorm(n_out) for _, n_out in shapes])
            if use_layer_norm
            else None
        )
        self.activation = activation
        self.dropout = nn.Dropout(dropout_rate)
        self.use_layer_norm = use_layer_norm

    def forward(self, x):
        """
        Forward pass through the MLP.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, layer_sizes[0]]

        Returns:
            torch.Tensor: Output tensor of shape [batch_size, layer_sizes[-1]]
        """
        last = len(self.layers) - 1
        for idx, linear in enumerate(self.layers):
            x = linear(x)

            # Normalisation is applied after every layer, the final one included.
            if self.use_layer_norm:
                x = self.layer_norms[idx](x)

            # Activation and dropout follow every layer except the final one.
            if idx != last:
                x = self.dropout(self.activation(x))

        return x

Now we can define our very simple constructor function.

In [None]:
def mlp_fn(model_params):
    """Construct the MLP decoder from ``model_params['layer_sizes']``, leaving the other MLP options at their defaults."""
    layer_sizes = model_params['layer_sizes']
    return MLP(layer_sizes)

Here we can define our preprocessing function. As above it takes in data of shape (num_words, num_electrodes, num_timesteps). First I reshape the data into the shape expected by our foundation model (num_words, num_frequency_bands, resampled_time_steps, 8, 8) for our grid data. I then pass the data through the foundation model in batches to get data of shape (num_words, embedding_dim). This is now ready to be sent to my decoding model so I return that data.

If you're decoding from something like BrainBert you'd likely want your preprocessing function to look something like this where you

1. Get the data into the expected shape and sampling rate.
2. Feed the data through the model in batches to reduce memory overhead.
3. Combine the embeddings in any way you see fit and then return the result.

In [None]:
def foundation_model_preprocessing_fn(data):
    """Turn windowed neural data into foundation-model embeddings.

    Relies on the notebook-level ``ecog_config``, ``raw``, ``model`` and
    ``device`` objects.

    Args:
        data: Array of shape (num_words, num_electrodes, num_timesteps).

    Returns:
        numpy array with one row of foundation-model features per word.
    """
    data_config = ecog_config.ecog_data_config

    # Downsample by averaging non-overlapping bins of original_fs // new_fs samples.
    data = data.reshape(data.shape[0], data.shape[1], -1, data_config.original_fs // data_config.new_fs)
    data = data.mean(-1)

    # Insert an all-zero channel for every grid electrode G1..G64 that is not in
    # the recording, so the electrode axis ends up with 64 entries. A scalar
    # fill value is used so that an electrode missing at the very end of the
    # grid (insert position == current number of channels) is handled too.
    # NOTE(review): assumes the channels of `data` are ordered G1..G64 ascending - confirm.
    recorded = set(raw.info.ch_names)
    for i in range(64):
        if f"G{i + 1}" not in recorded:
            data = np.insert(data, i, 0.0, axis=1)

    # Reshape to [num_examples, frequency bands (currently 1), time, 8, 8]:
    # move time in front of the electrodes, lay the 64 electrodes out as an
    # 8 x 8 grid, then add the singleton frequency-band axis.
    data = np.einsum('bet->bte', data).reshape(data.shape[0], data.shape[2], 8, 8)
    data = np.expand_dims(data, axis=1)

    # Run the model in small batches to limit memory use.
    batch_size = 16
    foundation_embeddings = []

    with torch.no_grad():
        for i in tqdm(range(0, len(data), batch_size)):
            batch = torch.tensor(data[i:i+batch_size], dtype=torch.float32).to(device)
            batch_embeddings = model(batch, forward_features=True)  # Shape: [batch_size, 16]
            foundation_embeddings.append(batch_embeddings.cpu().numpy())

    # Stack the per-batch results back into one row per word.
    return np.vstack(foundation_embeddings)

We can decode with essentially the same setup as before with minimal changes. I just have to change a few parameters to fit the expectations of my foundation model.

In [None]:
# Lags to evaluate: -900 to 900 in steps of 100 (labelled as ms in the plots below).
lags = np.arange(-900, 1000, 100)

weighted_roc_means = run_training_over_lags(lags, mlp_fn,
        model_params={
            # 16 input features -> 32 hidden units -> one output per target
            # embedding dimension. The output width is taken from the target
            # embeddings instead of being hard-coded, so it stays in sync if
            # the number of PCA components changes.
            'layer_sizes': [16, 32, pca_embeddings.shape[1]],
        },
        training_params={
            'batch_size': 32,
            'epochs': 100,
            'learning_rate': 0.001,
            'weight_decay': 0.0001,
            'early_stopping_patience': 10,
            'n_folds': 5
        },
        data_params={
            'raw': raw,
            'df_word': df_word,
            'word_embeddings': pca_embeddings,
            # Window length comes from the foundation model's data configuration.
            'window_width': ecog_config.ecog_data_config.sample_length,
            'preprocessing_fn': foundation_model_preprocessing_fn,
        },
        trial_name='foundation_model_trial')

Here we can see that, as expected for an essentially randomly initialized model, it does not perform well.

In [None]:
# Load the per-lag scores of the foundation-model trial. The file name follows
# the pattern used for the other trials in this notebook,
# results/<trial_name>_roc_means.csv.
# NOTE(review): the writer lives in decoding_utils - confirm the output location.
data = pd.read_csv("results/foundation_model_trial_roc_means.csv")
lags = data.lags
weighted_roc_means = data.rocs

# The curve shows decoding of the GPT-2 (PCA) embeddings from the foundation-model features.
plt.plot(lags, weighted_roc_means, label='GPT-2')
# Vertical marker at lag 0.
plt.axvline(0, color='red', alpha=0.5)
plt.xlabel('Lags (ms)')
plt.ylabel('AUC-ROC')
plt.title('AUC-ROC as a function of Lags')
plt.legend()
plt.show()