# Encoding Model Tutorial

This tutorial introduces a typical encoding framework for mapping different features onto human brain activity during natural language comprehension. In previous notebooks, we extracted two types of features from the stimulus transcript: syntactic features from spaCy ([Honnibal et al., 2020](https://github.com/explosion/spaCy)) and contextual word embeddings from GPT-2 ([Radford et al., 2019](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)). Encoding models ([Naselaris et al., 2011](https://www.sciencedirect.com/science/article/pii/S1053811910010657?via%3Dihub)) map those features onto brain activity. Here, they are estimated using ridge regression as implemented in the [Himalaya](https://gallantlab.org/himalaya/index.html) package ([Dupré La Tour et al., 2022](https://doi.org/10.1016/j.neuroimage.2022.119728)).

Acknowledgments: This tutorial draws heavily on the [encling tutorial](https://github.com/snastase/encling-tutorial/blob/main/encling_tutorial.ipynb) by Samuel A. Nastase.

______________

First, we'll import some general-purpose Python packages.

In [None]:
import mne
import h5py
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from nilearn.plotting import plot_markers
from mne_bids import BIDSPath

from himalaya.backend import set_backend, get_backend
from himalaya.ridge import RidgeCV
from himalaya.scoring import correlation_score

from sklearn.model_selection import KFold
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

We will set the [Himalaya backend](https://gallantlab.org/himalaya/_generated/himalaya.backend.set_backend.html#himalaya.backend.set_backend) to `torch_cuda` when a CUDA-capable GPU is available, so that the encoding models can be trained on the GPU. Otherwise, Himalaya keeps its default NumPy backend.

In [None]:
if torch.cuda.is_available():
    set_backend("torch_cuda")

## Loading features

We will now load two sets of features. The first contains syntactic features constructed using spaCy ([Honnibal et al., 2020](https://github.com/explosion/spaCy)). The second contains contextual word embeddings generated from GPT-2 ([Radford et al., 2019](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)). Each set of loaded features should be a numpy array with a shape of (number of tokens * feature dimensions). Note that the number of tokens differs between the two feature sets because they use different tokenization schemes.

In [None]:
embedding_path = "/scratch/gpfs/kw1166/247/monkey-data/stimuli/%s/states.hdf5"

# Syntactic features
with h5py.File(embedding_path % "syntactic", "r") as f:
    syntactic_features = f["vectors"][...]
print(syntactic_features.shape)

# Contextual word embeddings
modelname, layer = 'gpt2', 6

with h5py.File(embedding_path % modelname, "r") as f:
    contextual_embeddings = f[f"layer-{layer}"][...]
print(contextual_embeddings.shape)

We will also load the stimulus transcripts associated with these features. Both transcripts should contain information about the word, token, start (onset), and end (offset). The contextual word embedding transcript should also include other prediction information extracted from GPT-2, such as rank, probability, and entropy. For instance, we can calculate how accurately the model predicts the next token in the transcript from the `rank` column, an integer that gives the rank of the actual token among all possible GPT-2 tokens. The code below counts a prediction as correct when the rank is 0.

Note: Check that the transcript contains the same number of tokens as the features we loaded before.

In [None]:
transcript_path = "/scratch/gpfs/kw1166/247/monkey-data/stimuli/%s/transcript.tsv"

# Syntactic features transcript
df_syntactic = pd.read_csv(transcript_path % "syntactic", sep='\t', index_col=0)
print(df_syntactic.shape)
print(df_syntactic.head())

# Contextual word embeddings transcript
df_contextual = pd.read_csv(transcript_path % modelname, sep="\t", index_col=0)

if "rank" in df_contextual.columns:
    model_acc = (df_contextual["rank"] == 0).mean()
    print(f"Model accuracy: {model_acc*100:.3f}%")

print(df_contextual.shape)
print(df_contextual.head())
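
The same `rank` column can also be used to look beyond top-1 accuracy. The following is a minimal sketch on made-up ranks, not on the tutorial's transcript. It assumes ranks are 0-indexed, consistent with the `rank == 0` check above, and the helper `top_k_accuracy` is a hypothetical name introduced only for illustration.

```python
import numpy as np

# Hypothetical ranks of the true next token under the model's predictions.
# Assumption: ranks are 0-indexed, so 0 means the top prediction was correct.
ranks = np.array([0, 3, 0, 12, 1, 0, 7, 0])

def top_k_accuracy(ranks, k):
    """Fraction of tokens whose true rank falls within the top k predictions."""
    return float((ranks < k).mean())

top1 = top_k_accuracy(ranks, k=1)  # equivalent to (ranks == 0).mean()
top5 = top_k_accuracy(ranks, k=5)
print(f"top-1: {top1:.3f}, top-5: {top5:.3f}")
```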

When we extracted features, some words were split into several tokens. Since start and end times are only available for words, we will align the features from tokens to words before fitting encoding models. Here, we simply average the token features that belong to the same word. After this step, each set of features should be a numpy array with a shape of (number of words * feature dimensions).

In [None]:
# Syntactic features
aligned_syntactic = []
for _, group in df_syntactic.groupby("word_idx"): # group by word index
    indices = group.index.to_numpy()
    average_syntactic = syntactic_features[indices].mean(0) # average features
    aligned_syntactic.append(average_syntactic)
aligned_syntactic = np.stack(aligned_syntactic)
print(aligned_syntactic.shape)

# Contextual word embeddings
aligned_embeddings = []
for _, group in df_contextual.groupby("word_idx"): # group by word index
    indices = group.index.to_numpy()
    average_emb = contextual_embeddings[indices].mean(0) # average features
    aligned_embeddings.append(average_emb)
aligned_embeddings = np.stack(aligned_embeddings)
print(aligned_embeddings.shape)
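
To make the token-to-word averaging concrete, here is a small self-contained sketch on toy data. The arrays `token_features` and `word_idx` are made up for illustration, and the vectorized `np.add.at` approach is an alternative to the `groupby` loop above, not the method used in this tutorial.

```python
import numpy as np

# Hypothetical toy data: 5 tokens with 3-dimensional features.
token_features = np.array([
    [1.0, 0.0, 2.0],  # token 0 -> word 0
    [3.0, 2.0, 0.0],  # token 1 -> word 1 (first piece)
    [5.0, 4.0, 2.0],  # token 2 -> word 1 (second piece)
    [0.0, 1.0, 1.0],  # token 3 -> word 2
    [2.0, 2.0, 2.0],  # token 4 -> word 3
])
word_idx = np.array([0, 1, 1, 2, 3])  # word index of each token

# Sum the token vectors belonging to each word, then divide by token counts.
n_words = word_idx.max() + 1
sums = np.zeros((n_words, token_features.shape[1]))
np.add.at(sums, word_idx, token_features)
counts = np.bincount(word_idx, minlength=n_words)
word_features = sums / counts[:, None]

print(word_features.shape)  # one row per word
print(word_features[1])     # mean of tokens 1 and 2
```

Words made of a single token keep their original vector, while multi-token words get the mean of their pieces.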

We will also construct a dataframe containing words with their start and end information.

In [None]:
df_word = df_contextual.groupby("word_idx").agg(dict(word="first", start="first", end="last"))
df_word.head()

Since we need word start information for encoding, we will filter out words where the `start` column information is missing. We apply the same mask to the word dataframe so that it stays aligned with the features.

In [None]:
good_mask = df_word['start'].notna().to_numpy()
print(good_mask.sum())

df_word = df_word[good_mask] # keep the word table aligned with the features
aligned_syntactic = aligned_syntactic[good_mask]
aligned_embeddings = aligned_embeddings[good_mask]
print(aligned_syntactic.shape)
print(aligned_embeddings.shape)

## Loading brain data

Next, we will load the ECoG data using MNE. Here, we will demonstrate loading data from our third subject.

In [None]:
edf_path = BIDSPath(
    root="/scratch/gpfs/zzada/ecog-narratives/monkey/derivatives/ecogprep",
    subject="03",
    datatype="ieeg",
    description="highgamma",
    extension=".fif",
)
edf_path = edf_path.match()[0]

raw = mne.io.read_raw_fif(edf_path)
raw

We will map the start time (in seconds) of each word in the dataframe onto the brain signal data by multiplying it by the sampling rate. The first column of `events` then marks the sample at which each word starts in the brain signal data.

In [None]:
events = np.zeros((len(df_word), 3), dtype=int)
events[:, 0] = (df_word.start * raw.info['sfreq']).astype(int)
events.shape
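
As a quick illustration of the seconds-to-samples conversion, the sketch below builds an events array from a few made-up onsets. It assumes a 512 Hz sampling rate, and the names `onsets_sec` and `toy_events` are hypothetical.

```python
import numpy as np

sfreq = 512.0  # assumed sampling rate in Hz
onsets_sec = np.array([0.5, 1.25, 2.0])  # hypothetical word onsets in seconds

# MNE-style events array with three integer columns; only the first column
# (the sample index of each event) is filled in here, as in the cell above.
toy_events = np.zeros((len(onsets_sec), 3), dtype=int)
toy_events[:, 0] = (onsets_sec * sfreq).astype(int)  # truncates toward zero
print(toy_events[:, 0])
```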

Then we'll take advantage of MNE's tools for creating epochs around stimulus events, which here are the starts (onsets) of each word, to visualize brain signals that respond to word onsets. Here, we take a fixed-width window ranging from -2 seconds to +2 seconds relative to word onset. Since the sampling rate is 512 Hz (512 samples per second) and both endpoints of the window are included, we have 4 * 512 + 1 = 2049 lags in total. The ECoG data is a numpy array with the shape of (number of words * number of ECoG electrodes * number of lags).

In [None]:
epochs = mne.Epochs(
    raw,
    events,
    tmin=-2.0,
    tmax=2.0,
    baseline=None,
    proj=False,
    event_id=None,
    preload=True,
    event_repeated="merge",
)
print(f"ECoG data matrix shape: {epochs._data.shape}")
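
The window arithmetic can be checked with plain NumPy. This is only a sketch of what epoching does conceptually, using random noise as a stand-in recording. It is not how `mne.Epochs` is implemented, and names such as `signal` and `toy_epochs` are made up for the example.

```python
import numpy as np

sfreq = 512  # samples per second
tmin, tmax = -2.0, 2.0

# Both endpoints are included, hence the extra sample.
n_lags = int(round((tmax - tmin) * sfreq)) + 1
print(n_lags)

# Hypothetical continuous recording: 4 channels, 10 seconds of noise.
rng = np.random.default_rng(0)
signal = rng.standard_normal((4, 10 * sfreq))
onsets = np.array([3 * sfreq, 5 * sfreq])  # onset samples of two "words"

# Cut a window around each onset: (words, channels, lags).
start, stop = int(tmin * sfreq), int(tmax * sfreq) + 1
toy_epochs = np.stack([signal[:, o + start : o + stop] for o in onsets])
print(toy_epochs.shape)
```

The middle sample of each window (index 1024) corresponds to the word onset itself.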

Next, we'll downsample the temporal resolution to 32 Hz, which reduces the number of lags to 32 * 4 = 128.

In [None]:
epochs = epochs.resample(sfreq=32, npad='auto', method='fft', window='hamming')
# epochs = epochs.resample(sfreq=32, npad='auto', method='polyphase')
epochs.get_data(copy=False).shape

We can average each electrode's signal across all words, yielding event-related potentials (ERPs) for word starts (onsets). Here, we visualize how one particular electrode, `LGA10` from subject 3, responds to word onsets (Time = 0s).

In [None]:
evoked = epochs.average()
evoked = evoked.pick("LGA10")
evoked.plot()
plt.show()

## Setting up feature and brain data

Now we have both the features and the ECoG data ready. We plan to fit encoding models at each electrode and for each lag, so we'll reshape our target matrix `Y` to horizontally stack both electrodes and lags along the second dimension.

In [None]:
epochs_data = epochs.get_data(copy=True) # Get ECoG data
epochs_data = epochs_data.reshape(len(epochs), -1) # Reshape ECoG data
print(f"ECoG data matrix shape: {epochs_data.shape}")
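
The reshape works because NumPy flattens trailing axes in row-major order, so each column of the flattened matrix corresponds to one (electrode, lag) pair and per-column results can later be folded back. A toy sketch with made-up sizes:

```python
import numpy as np

n_words, n_electrodes, n_lags = 6, 3, 4  # hypothetical toy sizes
toy = np.arange(n_words * n_electrodes * n_lags, dtype=float)
toy = toy.reshape(n_words, n_electrodes, n_lags)

# Flatten electrodes and lags into one axis: one column per (electrode, lag).
Y_flat = toy.reshape(n_words, -1)
print(Y_flat.shape)

# Column e * n_lags + lag holds electrode e at that lag (row-major order).
e, lag = 2, 1
same = np.array_equal(Y_flat[:, e * n_lags + lag], toy[:, e, lag])
print(same)

# Per-column results can be folded back to (electrodes, lags).
col_means = Y_flat.mean(axis=0).reshape(n_electrodes, n_lags)
print(col_means.shape)
```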

We will also align our features with the ECoG data.

In [None]:
selected_df = df_word.iloc[epochs.selection]
averaged_syntactic = aligned_syntactic[epochs.selection]
averaged_embeddings = aligned_embeddings[epochs.selection]
print(averaged_syntactic.shape)
print(averaged_embeddings.shape)

We will convert all data to float32 when a torch backend is in use, which reduces GPU memory use and speeds up computation.

In [None]:
X1 = averaged_syntactic
X2 = averaged_embeddings
Y = epochs_data

if "torch" in get_backend().__name__:
    X1 = X1.astype(np.float32)
    X2 = X2.astype(np.float32)
    Y = Y.astype(np.float32)

X1.shape, X2.shape, Y.shape

## Building encoding models

Now, we will use ridge regression to estimate the encoding model. We create a model pipeline using `sklearn`, which includes a [`StandardScaler`](https://scikit-learn.org/dev/modules/generated/sklearn.preprocessing.StandardScaler.html) that standardizes features (X), and a [`RidgeCV`](https://gallantlab.org/himalaya/_generated/himalaya.ridge.RidgeCV.html#himalaya.ridge.RidgeCV) model, which performs ridge regression with cross-validation over our specified alpha values.

In [None]:
alphas = np.logspace(1, 10, 10) # specify alpha values
inner_cv = KFold(n_splits=5, shuffle=False) # inner 5-fold cross-validation setup
model = make_pipeline(
    StandardScaler(), RidgeCV(alphas, fit_intercept=True, cv=inner_cv) # pipeline
)
model
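
For intuition about what the alpha values control, here is a minimal NumPy sketch of closed-form ridge regression on synthetic data. It illustrates the textbook solution without an intercept term and makes no claim about how Himalaya solves the problem internally. The function `ridge_fit` and the toy arrays are hypothetical.

```python
import numpy as np

def ridge_fit(X, Y, alpha):
    """Closed-form ridge weights (no intercept): (X'X + alpha*I)^-1 X'Y."""
    n_features = X.shape[1]
    gram = X.T @ X + alpha * np.eye(n_features)
    return np.linalg.solve(gram, X.T @ Y)

rng = np.random.default_rng(0)
X_toy = rng.standard_normal((200, 10))   # hypothetical features
W_true = rng.standard_normal((10, 5))    # hypothetical true weights
Y_toy = X_toy @ W_true + 0.1 * rng.standard_normal((200, 5))  # 5 targets

# Larger alpha shrinks the weights toward zero.
for alpha in (1e-3, 1e1, 1e5):
    W = ridge_fit(X_toy, Y_toy, alpha)
    print(f"alpha={alpha:g}  weight norm={np.linalg.norm(W):.3f}")
```

With a very small alpha the estimate is close to ordinary least squares, while a very large alpha pushes all weights toward zero. Cross-validating over a logarithmic grid, as in the cell above, searches for a value between these extremes.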

## Training encoding models

While `RidgeCV` contains an inner cross-validation setup to find the best alpha, we will also set up an outer cross-validation loop to evaluate our encoding model. Here, we will use k = 2, meaning we will train on half of the data and evaluate on the other half. Within each fold, we will split the data into training and test sets. Then we will standardize `Y` the same way we standardize `X` in the pipeline, using statistics computed on the training set only. We will then fit our model on the training set and use it to predict the test set. For evaluation, we will calculate correlation scores between `Y_preds`, the ECoG signal predicted by our model, and `Y_test`, the actual ECoG signal. The encoding model is trained and evaluated for each electrode and each lag.

Note: this chunk of code takes a while to run...

In [None]:
epochs_shape = epochs.get_data(copy=False).shape[1:] # (number of electrodes, number of lags)

def train_encoding(X, Y):
    """Fit and evaluate the encoding model with 2-fold cross-validation."""
    corrs = [] # list to store per-fold correlation results
    kfold = KFold(2, shuffle=False) # outer 2-fold cross-validation setup
    for train_index, test_index in kfold.split(X): # loop through folds

        # Split train and test datasets
        X_train, X_test = X[train_index], X[test_index]
        Y_train, Y_test = Y[train_index], Y[test_index]

        # Standardize Y using statistics from the training set only
        scaler = StandardScaler()
        Y_train = scaler.fit_transform(Y_train)
        Y_test = scaler.transform(Y_test)

        model.fit(X_train, Y_train) # Fit pipeline with transforms and ridge estimator
        Y_preds = model.predict(X_test) # Use trained model to predict on test set
        corr = correlation_score(Y_test, Y_preds).reshape(epochs_shape) # Compute correlation score

        if "torch" in get_backend().__name__: # if using torch, convert tensor back to numpy
            corr = corr.numpy(force=True)

        corrs.append(corr) # append fold correlation results to final results
    return np.stack(corrs) # shape: (folds, electrodes, lags)

# set_backend("torch") # resort to torch or numpy if cuda out of memory
corrs_syntactic = train_encoding(X1, Y)
corrs_embedding = train_encoding(X2, Y)
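
The evaluation metric is a correlation computed separately for every column of `Y`, that is, for every (electrode, lag) pair. The sketch below shows a column-wise Pearson correlation in plain NumPy on synthetic data. It is meant as an illustration of the quantity being scored, under the assumption that `correlation_score` behaves this way, and `columnwise_correlation` is a hypothetical helper, not Himalaya's code.

```python
import numpy as np

def columnwise_correlation(y_true, y_pred):
    """Pearson r between matching columns of two (samples, targets) arrays."""
    yt = y_true - y_true.mean(axis=0)
    yp = y_pred - y_pred.mean(axis=0)
    num = (yt * yp).sum(axis=0)
    den = np.sqrt((yt ** 2).sum(axis=0) * (yp ** 2).sum(axis=0))
    return num / den

rng = np.random.default_rng(0)
y_true = rng.standard_normal((100, 3))
y_pred = np.column_stack([
    y_true[:, 0],              # perfect prediction
    -y_true[:, 1],             # perfectly anti-correlated prediction
    rng.standard_normal(100),  # unrelated noise
])
r = columnwise_correlation(y_true, y_pred)
print(np.round(r, 3))
```

Because correlation ignores the scale and offset of the prediction, it rewards a model for tracking the shape of the signal rather than its absolute amplitude.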


## Plotting encoding lag results

Now we can plot the correlations for one electrode across all lags, averaged over the two folds. Again, we use `LGA10` as an example. We will plot the correlation for the syntactic features in blue and the correlation for the contextual word embeddings in red.

In [None]:
lags = np.arange(-2 * 512, 2 * 512, 16) / 512 # specify the lags

electrode = "LGA10"
electrode_id = raw.info["ch_names"].index(electrode) # get electrode index

plt.axvline(0, c="k", alpha=0.3, ls=":")
plt.axhline(0, c="k", alpha=0.3, ls=":")
plt.plot(lags, corrs_syntactic.mean(0)[electrode_id], c="b") # blue line
plt.plot(lags, corrs_embedding.mean(0)[electrode_id], c="r") # red line
plt.xlabel("Lag (s)")
plt.ylabel("Encoding performance (r)")
plt.show()
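
The lag axis above is hard-coded from the original 512 Hz sampling rate. As a hedged alternative, the same axis can be derived from the epoch start time, the downsampled rate, and the number of lags. The variable names below are hypothetical, and the values are the ones used in this tutorial.

```python
import numpy as np

tmin, sfreq_new, n_lags = -2.0, 32, 128  # values used in this tutorial

# Time of each lag relative to word onset, in seconds.
lags_alt = tmin + np.arange(n_lags) / sfreq_new

# Matches the hard-coded version based on the original 512 Hz rate.
lags_ref = np.arange(-2 * 512, 2 * 512, 16) / 512
print(lags_alt[:3], lags_alt[-1])
print(np.allclose(lags_alt, lags_ref))
```

MNE also exposes the time axis of the resampled epochs as `epochs.times`, which may be a safer source when the window or sampling rate changes.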

## Plotting encoding brainplots

We can also plot the max correlations for all electrodes on the brain. We will first get the coordinates of all electrodes from the raw ECoG data, resulting in `coords` with the shape of (number of electrodes * 3), where each electrode has a 3-dimensional MNI coordinate relative to the brain surface.

In [None]:
ch2loc = {ch['ch_name']: ch['loc'][:3] for ch in raw.info['chs']}
coords = np.vstack([ch2loc[ch] for ch in raw.info['ch_names']])
coords *= 1000  # MNE stores positions in meters; nilearn plots in mm
coords.shape

We will take the max correlation across lags for each electrode, after averaging over folds. Then, we will use the nilearn [`plot_markers`](https://nilearn.github.io/dev/modules/generated/nilearn.plotting.plot_markers.html#nilearn.plotting.plot_markers) function to plot electrodes on top of a glass brain schematic. To compare the encoding performance of our syntactic features and contextual word embeddings, we use the same colorbar scale for both brainplots.

In [None]:
scores_syntactic = corrs_syntactic.mean(0).max(-1)
scores_embedding = corrs_embedding.mean(0).max(-1)
print(scores_syntactic.shape, scores_embedding.shape)
vmax = np.quantile(np.concatenate((scores_syntactic, scores_embedding)), .99)
print(f"Colorbar max correlation: {vmax}")

fig, axes = plt.subplots(2, 1, dpi=300, figsize=(8, 6))
order = scores_syntactic.argsort()
plot_markers(scores_syntactic[order], coords[order],
             node_size=15, display_mode='lyr',
             node_vmin=0, node_vmax=vmax,
             figure=fig, axes=axes[0], alpha=0.8,
             node_cmap='magma_r', colorbar=True)
plot_markers(scores_embedding[order], coords[order],
             node_size=15, display_mode='lyr',
             node_vmin=0, node_vmax=vmax,
             figure=fig, axes=axes[1], alpha=0.8,
             node_cmap='magma_r', colorbar=True)
fig.show()