Visualize learned TRF coefficients.

In [None]:
from argparse import ArgumentParser, Namespace
from collections import defaultdict
from copy import deepcopy
import io
from itertools import product
from pathlib import Path
import pickle
import sys

from matplotlib.animation import FuncAnimation
import matplotlib.pyplot as plt
import mne
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.base import clone
import torch
from tqdm.auto import tqdm, trange

In [None]:
import logging
# Logger for this notebook, named after the current module (`__name__`).
L = logging.getLogger(__name__)

In [None]:
from IPython.display import HTML

In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules
# before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
# Add the directory two levels above the current working directory to the
# import path; this must happen before the `berp` imports below.
# NOTE(review): presumably that directory is the repository root -- this depends
# on the directory the notebook is launched from.
sys.path.append(str(Path(".").resolve().parent.parent))
from berp.models import load_model
import berp.models.reindexing_regression as rr
from berp.viz.trf import trf_to_dataframe, plot_trf_coefficients
from berp.viz.trf_em import pipeline_to_dataframe

In [None]:
# Run selection. These values are interpolated into `model_dir` as
# workflow/<workflow>/results<paradigm>/<model>/<trf_run_name>.
workflow = "heilbron2022"
model = "EleutherAI/gpt-neo-2.7B/n10000"

# Suffix appended directly to the "results" directory name.
paradigm = "_wide"
trf_run_name = "trf-berp-fixed-t075"

# Extra keyword arguments forwarded to every plotting call in this notebook.
plot_kwargs = dict(errorbar="se")

In [None]:
# Results directory of the selected run, as a path relative to the current
# working directory.
model_dir = f"workflow/{workflow}/results{paradigm}/{model}/{trf_run_name}"

## Load results

In [None]:
# Load the saved model from `model_dir`, requesting the CPU device.
pipe = load_model(model_dir, device="cpu")

## Preprocessing

In [None]:
# Convert the loaded pipeline into a DataFrame of coefficients. The plotting
# cells below read its `predictor_name`, `sensor_name`, `subject`,
# `epoch_time` and `coef` columns.
coef_df = pipeline_to_dataframe(pipe)

In [None]:
# Display the coefficient table.
coef_df

In [None]:
# List the distinct predictor names available for plotting.
coef_df.predictor_name.unique()

### Recognition-locked responses

In [None]:
# Use seaborn's "talk" plotting context for the figures below.
sns.set_theme(context="talk")

In [None]:
def plot_variable(data, **_facet_kwargs):
    """Plot the coefficients of predictors matching ``"var_"`` for one facet.

    Intended as a ``FacetGrid.map_dataframe`` callback: ``data`` is the
    DataFrame subset for the current facet. Any additional keyword arguments
    supplied by the grid are accepted but deliberately not forwarded; only the
    notebook-wide ``plot_kwargs`` are passed to ``plot_trf_coefficients``.
    """
    plot_trf_coefficients(
        data,
        predictor_match_patterns=["var_"],
        **plot_kwargs,
    )

# One panel per sensor (two panels per row), each with independent axes.
g = sns.FacetGrid(
    data=coef_df,
    col="sensor_name",
    col_wrap=2,
    height=7,
    sharex=False,
    sharey=False,
)
# Draw every facet, then attach the legend; both calls return the grid.
g.map_dataframe(plot_variable).add_legend()

#### Surprisal modulation by subject

In [None]:
# Restrict to the word-surprisal predictor and draw one line per subject,
# with one panel per sensor (two panels per row, independent axes).
g = sns.FacetGrid(
    data=coef_df.loc[coef_df["predictor_name"] == "var_word_surprisal"],
    col="sensor_name",
    col_wrap=2,
    height=7,
    sharex=False,
    sharey=False,
)
g.map_dataframe(
    sns.lineplot,
    x="epoch_time",
    y="coef",
    hue="subject",
    **plot_kwargs,
)
g.add_legend()

#### Word onset

In [None]:
# Single figure with the coefficients of predictors matching "word_onset".
sns.set_theme(context="talk")
plt.figure(figsize=(10, 7))
plot_trf_coefficients(
    coef_df,
    predictor_match_patterns=["word_onset"],
    **plot_kwargs,
)
plt.title("Word onset-locked responses")
# End on None so the cell does not echo the object returned by plt.title.
None

In [None]:
# Single figure with the coefficients of predictors matching "phoneme_onset"
# or "all_phons_surprisals".
sns.set_theme(context="talk")
plt.figure(figsize=(10, 7))
plot_trf_coefficients(
    coef_df,
    predictor_match_patterns=[
        "phoneme_onset",
        "all_phons_surprisals",
    ],
    **plot_kwargs,
)
plt.title("Phoneme onset-locked responses")
# End on None so the cell does not echo the object returned by plt.title.
None