In [44]:
from pathlib import Path
import pickle

import pandas as pd
import mne
import mne_bids
import torch
from tqdm.auto import tqdm

In [45]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append(str(Path(".").resolve().parent.parent))

from berp.datasets import BerpDataset, NestedBerpDataset

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [46]:
# Identify the recording being processed.
story_name = "lw1"
subject = "01"
session = 0

# Per-session presentation logs. Derived from the parameters above so that they
# cannot silently drift out of sync when story/subject/session change.
presentation_sounds = f"{story_name}.{subject}.{session}.sound.csv"
presentation_words = f"{story_name}.{subject}.{session}.word.csv"
presentation_phonemes = f"{story_name}.{subject}.{session}.phoneme.csv"

# Word/phoneme alignment tables.
aligned_words = "word.csv"
aligned_phonemes = "phoneme.csv"

# Table of sound onsets, keyed by (story_name, subject, session) further below.
global_session_alignment = "session_alignment.csv"

# Pickled stimulus object for this story.
stimulus = f"{story_name}.pkl"

# Raw MEG recording for this subject/session.
# NOTE(review): the task index is hardcoded; presumably it identifies the story
# -- confirm before changing `story_name`.
bids = ("../../workflow/meg-masc/raw-data/"
        f"sub-{subject}/ses-{session}/meg/sub-{subject}_ses-{session}_task-0_meg.con")

# Rate (Hz) the MEG signal is resampled to during preprocessing.
target_sample_rate = 128

## Load and process natural language stimulus and time series features

In [47]:
# Deserialize the stimulus object for this story.
# NOTE: unpickling executes arbitrary code -- only load trusted stimulus files.
story_stim = pickle.loads(Path(stimulus).read_bytes())

# Guard against pairing the wrong stimulus file with this story.
assert story_stim.name == story_name

In [48]:
# Variable-onset design matrix, built column-wise from:
#   - an all-ones intercept column (named "recognition_onset"),
#   - the stimulus' word features,
#   - the word surprisals.
X_variable = torch.cat(
    (
        torch.ones_like(story_stim.word_surprisals)[:, None],
        story_stim.word_features,
        story_stim.word_surprisals[:, None],
    ),
    dim=1,
)
variable_feature_names = ["recognition_onset", "word_frequency", "word_surprisal"]

# One name per column. This implies `word_features` contributes exactly one
# column (word frequency); the assertion enforces it.
assert X_variable.shape[1] == len(variable_feature_names)

In [49]:
# Load other stimulus time-series features.
# TODO

## Load aligned word/phoneme presentation data

In [50]:
# Load the word- and phoneme-level alignment tables. In the word table, one
# `word_idx` may span several rows (one per `token_idx`).
word_aligned_df, phoneme_aligned_df = (
    pd.read_csv(csv_path, index_col=0)
    for csv_path in (aligned_words, aligned_phonemes)
)
word_aligned_df

Unnamed: 0,word_idx,story,story_uid,sound_id,kind,meg_file,start,sound,word,sequence_id,condition,word_index,speech_rate,voice,pronounced,onset,duration,value,sample,token_idx
0,0,lw1,0.0,0.0,word,A0167_MASC_1_16Mar17_01.con,0.000000,stimuli/audio/lw1_0.wav,Tara,0.0,sentence,0.0,205.0,Allison,1.0,23.506,0.30,697,23506,0
1,0,lw1,0.0,0.0,word,A0167_MASC_1_16Mar17_01.con,0.000000,stimuli/audio/lw1_0.wav,Tara,0.0,sentence,0.0,205.0,Allison,1.0,23.506,0.30,697,23506,1
2,1,lw1,0.0,0.0,word,A0167_MASC_1_16Mar17_01.con,0.310000,stimuli/audio/lw1_0.wav,stood,0.0,sentence,1.0,205.0,Allison,1.0,23.816,0.24,698,23816,2
3,2,lw1,0.0,0.0,word,A0167_MASC_1_16Mar17_01.con,0.550000,stimuli/audio/lw1_0.wav,stock,0.0,sentence,2.0,205.0,Allison,1.0,24.056,0.37,699,24056,3
4,3,lw1,0.0,0.0,word,A0167_MASC_1_16Mar17_01.con,1.080000,stimuli/audio/lw1_0.wav,still,0.0,sentence,3.0,205.0,Allison,1.0,24.586,0.40,700,24586,5
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
634,663,lw1,0.0,3.0,word,A0167_MASC_1_16Mar17_01.con,50.070000,stimuli/audio/lw1_3.wav,end,52.0,sentence,15.0,205.0,Allison,1.0,361.097,0.17,3119,361097,918
635,664,lw1,0.0,3.0,word,A0167_MASC_1_16Mar17_01.con,50.250000,stimuli/audio/lw1_3.wav,for,52.0,sentence,16.0,205.0,Allison,1.0,361.277,0.14,3120,361277,919
636,665,lw1,0.0,3.0,word,A0167_MASC_1_16Mar17_01.con,50.460000,stimuli/audio/lw1_3.wav,project,52.0,sentence,18.0,205.0,Allison,1.0,361.487,0.58,3121,361487,921
637,666,lw1,0.0,3.0,word,A0167_MASC_1_16Mar17_01.con,51.179999,stimuli/audio/lw1_3.wav,and,52.0,sentence,19.0,205.0,Allison,1.0,362.207,0.15,3122,362207,923


In [51]:
# Not every presented word survives into the stimulus. Words that landed at the
# start of an input time series to the model were dropped, because the model
# had no target values for them.
#
# By that logic, the missing words should be spread roughly evenly across the
# stimulus.
set(word_aligned_df.word_idx).difference(story_stim.word_ids.numpy())

{25,
 54,
 146,
 183,
 208,
 231,
 257,
 285,
 307,
 317,
 344,
 390,
 427,
 477,
 530,
 593,
 617}

In [52]:
# Dropping is one-directional: every word the stimulus kept must still have an
# alignment row.
assert set(story_stim.word_ids.numpy()).issubset(set(word_aligned_df.word_idx)), \
    "Stim words are present which are missing from the aligned data!"

## Load sound/word/phoneme presentation data

In [53]:
# Load the sound-, word- and phoneme-level presentation logs for this session.
sound_presentation_df, word_presentation_df, phoneme_presentation_df = (
    pd.read_csv(csv_path, index_col=0)
    for csv_path in (presentation_sounds, presentation_words, presentation_phonemes)
)

In [54]:
# All three presentation logs must cover exactly one story: this one.
assert (set(sound_presentation_df.story)
        == set(word_presentation_df.story)
        == set(phoneme_presentation_df.story)
        == {story_name})

In [55]:
# Inspect the word-level presentation log (one row per presented word).
word_presentation_df

Unnamed: 0,story,story_uid,sound_id,kind,meg_file,start,sound,word,sequence_id,condition,word_index,speech_rate,voice,pronounced,onset,duration,value,sample
0,lw1,0.0,0.0,word,A0167_MASC_1_16Mar17_01.con,0.000000,stimuli/audio/lw1_0.wav,Tara,0.0,sentence,0.0,205.0,Allison,1.0,23.506,0.30,697,23506
1,lw1,0.0,0.0,word,A0167_MASC_1_16Mar17_01.con,0.310000,stimuli/audio/lw1_0.wav,stood,0.0,sentence,1.0,205.0,Allison,1.0,23.816,0.24,698,23816
2,lw1,0.0,0.0,word,A0167_MASC_1_16Mar17_01.con,0.550000,stimuli/audio/lw1_0.wav,stock,0.0,sentence,2.0,205.0,Allison,1.0,24.056,0.37,699,24056
3,lw1,0.0,0.0,word,A0167_MASC_1_16Mar17_01.con,1.080000,stimuli/audio/lw1_0.wav,still,0.0,sentence,3.0,205.0,Allison,1.0,24.586,0.40,700,24586
4,lw1,0.0,0.0,word,A0167_MASC_1_16Mar17_01.con,1.630000,stimuli/audio/lw1_0.wav,waiting,0.0,sentence,4.0,205.0,Allison,1.0,25.136,0.41,701,25136
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
663,lw1,0.0,3.0,word,A0167_MASC_1_16Mar17_01.con,50.070000,stimuli/audio/lw1_3.wav,end,52.0,sentence,15.0,205.0,Allison,1.0,361.097,0.17,3119,361097
664,lw1,0.0,3.0,word,A0167_MASC_1_16Mar17_01.con,50.250000,stimuli/audio/lw1_3.wav,for,52.0,sentence,16.0,205.0,Allison,1.0,361.277,0.14,3120,361277
665,lw1,0.0,3.0,word,A0167_MASC_1_16Mar17_01.con,50.460000,stimuli/audio/lw1_3.wav,project,52.0,sentence,18.0,205.0,Allison,1.0,361.487,0.58,3121,361487
666,lw1,0.0,3.0,word,A0167_MASC_1_16Mar17_01.con,51.179999,stimuli/audio/lw1_3.wav,and,52.0,sentence,19.0,205.0,Allison,1.0,362.207,0.15,3122,362207


In [56]:
# Inspect the phoneme-level presentation log. `word_index` links each phoneme
# to its word within a `sequence_id`.
phoneme_presentation_df

Unnamed: 0,story,story_uid,sound_id,kind,meg_file,start,sound,phoneme,sequence_id,condition,word_index,speech_rate,voice,pronounced,onset,duration,value,sample
0,lw1,0.0,0.0,phoneme,A0167_MASC_1_16Mar17_01.con,0.00,stimuli/audio/lw1_0.wav,t_B,0.0,sentence,0.0,205.0,Allison,1.0,23.506,0.08,5,23506
1,lw1,0.0,0.0,phoneme,A0167_MASC_1_16Mar17_01.con,0.08,stimuli/audio/lw1_0.wav,eh_I,0.0,sentence,0.0,205.0,Allison,1.0,23.586,0.09,6,23586
2,lw1,0.0,0.0,phoneme,A0167_MASC_1_16Mar17_01.con,0.17,stimuli/audio/lw1_0.wav,r_I,0.0,sentence,0.0,205.0,Allison,1.0,23.676,0.07,7,23676
3,lw1,0.0,0.0,phoneme,A0167_MASC_1_16Mar17_01.con,0.24,stimuli/audio/lw1_0.wav,ah_E,0.0,sentence,0.0,205.0,Allison,1.0,23.746,0.06,8,23746
4,lw1,0.0,0.0,phoneme,A0167_MASC_1_16Mar17_01.con,0.31,stimuli/audio/lw1_0.wav,s_B,0.0,sentence,1.0,205.0,Allison,1.0,23.816,0.06,9,23816
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2457,lw1,0.0,3.0,phoneme,A0167_MASC_1_16Mar17_01.con,51.85,stimuli/audio/lw1_3.wav,p_I,52.0,sentence,22.0,205.0,Allison,1.0,362.877,0.09,2962,362877
2458,lw1,0.0,3.0,phoneme,A0167_MASC_1_16Mar17_01.con,51.94,stimuli/audio/lw1_3.wav,iy_I,52.0,sentence,22.0,205.0,Allison,1.0,362.967,0.09,2963,362967
2459,lw1,0.0,3.0,phoneme,A0167_MASC_1_16Mar17_01.con,52.03,stimuli/audio/lw1_3.wav,sh_I,52.0,sentence,22.0,205.0,Allison,1.0,363.057,0.08,2964,363057
2460,lw1,0.0,3.0,phoneme,A0167_MASC_1_16Mar17_01.con,52.11,stimuli/audio/lw1_3.wav,iy_I,52.0,sentence,22.0,205.0,Allison,1.0,363.137,0.01,2965,363137


In [57]:
# Apart from non-sentence conditions, presentation and alignment data should
# agree exactly: their word and phoneme indices must be identical.
assert (set(word_presentation_df.loc[lambda df: df["condition"] == "sentence"].index)
        == set(word_aligned_df.word_idx))
assert (set(phoneme_presentation_df.loc[lambda df: df["condition"] == "sentence"].index)
        == set(phoneme_aligned_df.index))

## Load session alignment data

In [None]:
# Select this story/subject/session's rows from the global session-alignment
# table. `subject` is read as str so zero-padded IDs such as "01" match.
#
# A boolean mask is used instead of chained `.loc` lookups on a MultiIndex,
# because chained `.loc` collapses to a Series when a session has only one
# sound. The mask always yields a DataFrame. `.copy()` makes later column
# assignments act on an independent frame rather than a slice.
_session_alignment_all = pd.read_csv(global_session_alignment, index_col=0,
                                     dtype={"subject": str})
_session_mask = (
    (_session_alignment_all["story_name"] == story_name)
    & (_session_alignment_all["subject"] == subject)
    & (_session_alignment_all["session"] == session)
)
session_alignment_df = (
    _session_alignment_all.loc[_session_mask]
    .drop(columns=["story_name", "subject", "session"])
    .copy()
)
assert len(session_alignment_df) > 0, \
    f"No session alignment rows for story={story_name!r}, subject={subject!r}, session={session!r}"

In [None]:
# The session alignment must list exactly the presented sounds.
assert len(session_alignment_df) == len(sound_presentation_df)
assert set(session_alignment_df.sound_id) == set(sound_presentation_df.sound_id)

# The first/last sound onsets must not come after the first/last word and
# phoneme onsets.
first_sound_onset = session_alignment_df.onset.min()
last_sound_onset = session_alignment_df.onset.max()
for presentation_df in (word_presentation_df, phoneme_presentation_df):
    assert first_sound_onset <= presentation_df.onset.min()
    assert last_sound_onset <= presentation_df.onset.max()

In [None]:
# Inspect the selected session-alignment rows (one per presented sound).
session_alignment_df

## Load signal data

In [None]:
# Build a BIDSPath from the raw MEG filename, then load the recording via MNE-BIDS.
bids_path = mne_bids.get_bids_path_from_fname(fname=bids)
raw = mne_bids.read_raw_bids(bids_path=bids_path)

In [None]:
# Preprocessing to match https://github.com/kingjr/meg-masc/blob/main/check_decoding.py
raw = raw.pick_types(meg=True, misc=False, eeg=False, eog=False, ecg=False)
raw = raw.load_data().filter(0.5, 30.0, n_jobs=1)
raw = raw.resample(target_sample_rate)
raw

In [None]:
# Check compatibility with presentation data and alignment data.
min_time, max_time = raw.times.min(), raw.times.max()

assert min_time <= sound_presentation_df.onset.min()
assert min_time <= word_presentation_df.onset.min()
assert min_time <= phoneme_presentation_df.onset.min()

assert max_time >= sound_presentation_df.onset.max()
assert max_time >= word_presentation_df.onset.max()
assert max_time >= phoneme_presentation_df.onset.max()

assert min_time <= session_alignment_df.onset.min()
assert max_time >= session_alignment_df.onset.max()

In [None]:
# Use signal time series to calculate duration of sound segments (including final segment)
session_alignment_df["offset"] = session_alignment_df.onset.shift(-1, fill_value=raw.times.max())
session_alignment_df["duration"] = session_alignment_df.offset - session_alignment_df.onset
session_alignment_df

## Chop the time series into per-sound segments so that each sound starts at a predictable point

In [None]:
# NB will insert a BAD/EDGE boundary at each concatenation point
concat_raw = mne.concatenate_raws([
    raw.copy().crop(tmin=row.onset, tmax=row.offset, include_tmax=False)
    for _, row in tqdm(session_alignment_df.iterrows(), total=len(session_alignment_df))
])

In [None]:
# Inspect the annotations of the concatenated signal, including the boundary
# markers inserted at each concatenation point.
concat_raw.annotations

In [None]:
BerpDataset(
    name=f"{story_name}/{subject}_{session}",
    stimulus_name=story_stim.name,
    sample_rate=int(concat_raw.info["sfreq"]),
    
    phonemes=story_stim.phonemes,
    
    word_onsets=