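Prepare the data for a generalized additive regression model (GAM) explaining N400 amplitude on the given datasets. This notebook epochs the EEG around word onsets, averages each epoch within the baseline and N400 time windows, joins the word-level predictors, and exports the resulting table to CSV; the model itself is fit outside this notebook (in R).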

In [1]:
import io
from itertools import product
import logging
L = logging.getLogger(__name__)
from pathlib import Path
import pickle
import yaml

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import scipy.stats as st
import torch
from tqdm.auto import tqdm

In [2]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell executes, so edits to imported source files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import sys
sys.path.append("..")
from berp.datasets import BerpDataset, NestedBerpDataset, get_metadata
from berp.datasets.eeg import load_eeg_dataset
from berp.viz import make_word_onset_epochs

In [4]:
# ---- Data selection ---------------------------------------------------------
# Workflow name and language-model identifier; both are interpolated into the
# dataset / stimulus directory paths.
workflow = "heilbron2022"
lm = "EleutherAI/gpt-neo-2.7B/n10000"

# Subject and run numbers 1..19 (inclusive), and the stories to load.
subjects = [*range(1, 20)]
runs = [*range(1, 20)]
stories = ["old-man-and-the-sea"]

# Sensors to keep. A smaller earlier selection was: ["B19", "B20", "B22"]
target_sensors = ["B22", "D19", "A19", "C22", "C10", "C32", "A7", "B4"]

# ---- Dataset loading options ------------------------------------------------
# All normalization switches passed to the dataset loader are disabled.
normalize_X_ts = normalize_X_variable = normalize_Y = False

# ---- Epoching ---------------------------------------------------------------
# Epoch and baseline windows relative to word onset.
# NOTE(review): values are written as sample counts over 128, presumably
# seconds at a 128 Hz sampling rate -- confirm against the dataset.
epoch_tmin, epoch_tmax = -12 / 128, 96 / 128
baseline_tmin, baseline_tmax = -12 / 128, 0.

# Temporal regions of interest as (label, start, end); each epoch is averaged
# within every window listed here.
temporal_rois = [
    ("baseline", baseline_tmin, baseline_tmax),
    ("n400", 0.3, 0.5),
]

# ---- Output -----------------------------------------------------------------
# Destination of the exported design table.
out_path = "n400_gam_data.csv"

# Disabled: list of a priori tests of interest.
# # Only run quantitative tests on items we have a priori interest in. We want to avoid
# # fishing for p-values here, and also degrading any positive results due to multiple
# # comparisons.
# # Semantics: (feature, troi, sensor) where we expect amplitude/latency-to-peak differences
# a_priori_interest = [
#     ("var_word_surprisal", "n125", "D19"),
#     ("var_word_surprisal", "n125", "C10"),
#     ("var_word_surprisal", "n400", "A19"),
#     ("var_word_surprisal", "early", "C10"),
# ]

In [5]:
# Directories holding the per-subject dataset pickles and the per-run stimulus
# pickles for the configured workflow and language model.
dataset_dir, stimulus_dir = (
    f"../workflow/{workflow}/data/dataset/{lm}",
    f"../workflow/{workflow}/data/stimulus/{lm}",
)

## Prepare dataset

In [6]:
# Load datasets.
#
# Stimulus pickles are keyed by "<story>/run<run>"; dataset pickles live at
# <dataset_dir>/<story>/sub<subject>/run<run>.pkl.
stimulus_paths = {f"{story}/run{run}": Path(stimulus_dir) / f"{story}/run{run}.pkl"
                  for story in stories for run in runs}
combs = list(product(stories, subjects, runs))

# Record which (story, subject, run) combinations actually have a dataset file.
# Missing files are skipped, so the keys used to label the loaded datasets must
# come from this filtered list rather than from `combs`; otherwise a single
# missing file would shift every later key onto the wrong dataset.
found_combs = []
ds_paths = []
for story, subject, run in tqdm(combs):
    ds_path = Path(dataset_dir) / story / f"sub{subject}" / f"run{run}.pkl"
    if not ds_path.exists():
        print(f"Could not find dataset {ds_path}")
        continue
    found_combs.append((story, subject, run))
    ds_paths.append(ds_path)

if not ds_paths:
    # Fail early with a clear message instead of handing the loader nothing.
    raise FileNotFoundError(f"No dataset files found under {dataset_dir}")

nested_ds = load_eeg_dataset(
    ds_paths,
    subset_sensors=target_sensors,
    stimulus_paths=stimulus_paths,
    normalize_X_ts=normalize_X_ts,
    normalize_X_variable=normalize_X_variable,
    normalize_Y=normalize_Y)

# The loader must return exactly one dataset per path we passed in.
# NOTE(review): the zip below also assumes `nested_ds.datasets` preserves the
# order of `ds_paths` -- confirm against load_eeg_dataset.
assert len(found_combs) == len(nested_ds.datasets), \
    f"Expected {len(found_combs)} datasets, loader returned {len(nested_ds.datasets)}"
ds = dict(zip(found_combs, nested_ds.datasets))

  0%|          | 0/361 [00:00<?, ?it/s]

## Epoch

In [7]:
# Cut word-onset epochs out of every loaded dataset (no baseline correction
# here; the baseline window is aggregated separately later) and stack them,
# labelling each dataset's rows with its (story, subject, run) key.
df = pd.concat(
    list(map(
        lambda dataset: make_word_onset_epochs(
            dataset,
            tmin=epoch_tmin,
            tmax=epoch_tmax,
            baseline=False,
        ),
        tqdm(ds.values()),
    )),
    keys=ds.keys(),
    names=["story", "subject", "run"],
)

# The "epoch" index level is renamed to "word_idx", and the column of the same
# name is dropped.
df.index = df.index.rename("word_idx", level="epoch")
df = df.drop(columns=["word_idx"])

# Translate the positional sensor index into a sensor name.
# NOTE(review): assumes sensor_idx follows the order of `target_sensors` -- confirm.
sensor_name_by_idx = dict(enumerate(target_sensors))
df["sensor_name"] = df.index.get_level_values("sensor_idx").map(sensor_name_by_idx)
df

  0%|          | 0/361 [00:00<?, ?it/s]

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,value,epoch_time,sensor_name
story,subject,run,word_idx,sample,sensor_idx,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
old-man-and-the-sea,1,1,0,0,0,0.031092,-0.093750,B22
old-man-and-the-sea,1,1,0,0,1,0.056925,-0.093750,D19
old-man-and-the-sea,1,1,0,0,2,0.044012,-0.093750,A19
old-man-and-the-sea,1,1,0,0,3,0.032322,-0.093750,C22
old-man-and-the-sea,1,1,0,0,4,-0.000803,-0.093750,C10
old-man-and-the-sea,...,...,...,...,...,...,...,...
old-man-and-the-sea,19,19,588,107,3,0.026362,0.742188,C22
old-man-and-the-sea,19,19,588,107,4,0.039911,0.742188,C10
old-man-and-the-sea,19,19,588,107,5,0.054915,0.742188,C32
old-man-and-the-sea,19,19,588,107,6,-0.127584,0.742188,A7


### Prepare design matrix

In [7]:
# Collect the metadata table of every dataset and stack them under the same
# (story, subject, run) keys used for the epochs.
metadata_df = pd.concat(
    list(map(get_metadata, ds.values())),
    keys=ds.keys(),
    names=["story", "subject", "run"],
)
metadata_df

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,word_length,word_duration,word_surprisal,word_frequency,word_prior_entropy,word
story,subject,run,word_idx,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
old-man-and-the-sea,1,1,0,3,0.156250,-2.381337,7.429644,5.486012,wʌz
old-man-and-the-sea,1,1,1,2,0.078125,-3.896170,9.030593,5.764207,ʌn
old-man-and-the-sea,1,1,2,3,0.289062,-3.420624,10.644726,6.296887,oʊld
old-man-and-the-sea,1,1,3,3,0.382812,-0.872183,9.044897,4.189036,mæn
old-man-and-the-sea,1,1,4,2,0.210938,-1.986570,8.776630,3.307850,hu
old-man-and-the-sea,...,...,...,...,...,...,...,...,...
old-man-and-the-sea,19,19,584,4,0.250000,-4.542065,12.977191,4.395442,lʊkt
old-man-and-the-sea,19,19,585,2,0.156250,-1.088804,8.243340,2.250363,æt
old-man-and-the-sea,19,19,586,2,0.109375,-0.945295,5.689068,1.269512,ɪt
old-man-and-the-sea,19,19,587,2,0.117188,-2.847154,6.640238,2.992393,ɪn


## Aggregate

In [9]:
# Grouping key that picks out one epoch: a single word occurrence at a single
# sensor, within a given story / subject / run.
id_key = [
    "story",
    "subject",
    "run",
    "word_idx",
    "sensor_name",
]

In [10]:
# For every temporal ROI, keep the samples whose epoch_time lies inside the
# window (both endpoints included) and average the signal per epoch. The
# per-ROI results are stacked with the ROI label as the outermost "toi" level.
roi_df = pd.concat(
    {
        label: (
            df[df["epoch_time"].between(tstart, tend)]
            .groupby(id_key)["value"]
            .mean()
        )
        for label, tstart, tend in temporal_rois
    },
    names=["toi"],
)
roi_df

toi       story                subject  run  word_idx  sensor_name
baseline  old-man-and-the-sea  1        1    0         A19            0.041212
                                                       A7             0.107339
                                                       B22            0.026291
                                                       B4             0.105514
                                                       C10            0.010029
                                                                        ...   
n400      old-man-and-the-sea  19       19   588       B4            -0.179150
                                                       C10           -0.026348
                                                       C22           -0.246394
                                                       C32           -0.053439
                                                       D19           -0.223328
Name: value, Length: 3267392, dtype: float64

In [11]:
# Pivot the "toi" index level into columns: one row per epoch, one column per
# temporal ROI.
agg_df = roi_df.unstack(level="toi")
agg_df

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,toi,baseline,n400
story,subject,run,word_idx,sensor_name,Unnamed: 5_level_1,Unnamed: 6_level_1
old-man-and-the-sea,1,1,0,A19,0.041212,-0.002146
old-man-and-the-sea,1,1,0,A7,0.107339,0.094173
old-man-and-the-sea,1,1,0,B22,0.026291,0.006058
old-man-and-the-sea,1,1,0,B4,0.105514,0.067242
old-man-and-the-sea,1,1,0,C10,0.010029,0.034790
old-man-and-the-sea,...,...,...,...,...,...
old-man-and-the-sea,19,19,588,B4,0.202290,-0.179150
old-man-and-the-sea,19,19,588,C10,0.026011,-0.026348
old-man-and-the-sea,19,19,588,C22,0.012474,-0.246394
old-man-and-the-sea,19,19,588,C32,0.021017,-0.053439


In [12]:
# Attach the word-level predictors to every epoch row by joining on the index
# levels the two frames share (default inner join).
full_df = agg_df.merge(metadata_df, left_index=True, right_index=True)
full_df

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,baseline,n400,word_length,word_duration,word_surprisal,word_frequency,word_prior_entropy,word
story,subject,run,word_idx,sensor_name,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
old-man-and-the-sea,1,1,0,A19,0.041212,-0.002146,3,0.156250,-2.381337,7.429644,5.486012,wʌz
old-man-and-the-sea,1,1,0,A7,0.107339,0.094173,3,0.156250,-2.381337,7.429644,5.486012,wʌz
old-man-and-the-sea,1,1,0,B22,0.026291,0.006058,3,0.156250,-2.381337,7.429644,5.486012,wʌz
old-man-and-the-sea,1,1,0,B4,0.105514,0.067242,3,0.156250,-2.381337,7.429644,5.486012,wʌz
old-man-and-the-sea,1,1,0,C10,0.010029,0.034790,3,0.156250,-2.381337,7.429644,5.486012,wʌz
old-man-and-the-sea,...,...,...,...,...,...,...,...,...,...,...,...
old-man-and-the-sea,19,19,588,B4,0.202290,-0.179150,7,0.632812,-3.234899,18.427759,3.286217,dɪsgʌst
old-man-and-the-sea,19,19,588,C10,0.026011,-0.026348,7,0.632812,-3.234899,18.427759,3.286217,dɪsgʌst
old-man-and-the-sea,19,19,588,C22,0.012474,-0.246394,7,0.632812,-3.234899,18.427759,3.286217,dɪsgʌst
old-man-and-the-sea,19,19,588,C32,0.021017,-0.053439,7,0.632812,-3.234899,18.427759,3.286217,dɪsgʌst


In [13]:
# DEV: for now, export the table so the model can be fit in plain R.
full_df.to_csv(path_or_buf=out_path)