In [11]:
from pathlib import Path
import pickle

import pandas as pd
import mne
import mne_bids
import torch
from tqdm.auto import tqdm

In [82]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append(str(Path(".").resolve().parent.parent))

from berp.datasets import BerpDataset, NestedBerpDataset

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [48]:
# Session to process: one story, one subject, one recording session.
story_name = "lw1"
subject = "01"
session = 0

# Presentation logs (sound / word / phoneme events) for this story, subject and session.
# File names are derived from the identifiers above so they cannot drift out of sync.
presentation_prefix = f"{story_name}.{subject}.{session}"
presentation_sounds = f"{presentation_prefix}.sound.csv"
presentation_words = f"{presentation_prefix}.word.csv"
presentation_phonemes = f"{presentation_prefix}.phoneme.csv"

# Presentation tables carrying stimulus-alignment columns (the word table has word_idx / token_idx).
aligned_words = "word.csv"
aligned_phonemes = "phoneme.csv"

# Sound onsets for every story/subject/session; later narrowed to the session above.
global_session_alignment = "session_alignment.csv"

# Pickled stimulus object for this story.
stimulus = f"{story_name}.pkl"

# Raw MEG recording in the BIDS tree.
# NOTE(review): the BIDS task index is not derived from story_name -- confirm task 0 is this story.
bids_task = "0"
bids = (f"../../workflow/meg-masc/raw-data/sub-{subject}/ses-{session}/meg/"
        f"sub-{subject}_ses-{session}_task-{bids_task}_meg.con")

# Sampling rate (Hz) the filtered signal is resampled to.
target_sample_rate = 128

## Load and process natural language stimulus and time series features

In [33]:
# Load the preprocessed stimulus object for this story.
# NOTE: unpickling can execute arbitrary code -- only load files produced by this project's own pipeline.
story_stim = pickle.loads(Path(stimulus).read_bytes())

# Guard against pairing this session with another story's stimulus.
assert story_stim.name == story_name

In [19]:
# Variable-onset design matrix, stacked column-wise:
#   [all-ones "recognition onset" intercept | word_features columns | word surprisal]
# Assumes word_surprisals is 1-D and word_features is 2-D with the same number of rows -- TODO confirm.
X_variable = torch.cat(
    (
        torch.ones_like(story_stim.word_surprisals)[:, None],
        story_stim.word_features,
        story_stim.word_surprisals[:, None],
    ),
    dim=1,
)
variable_feature_names = ["recognition_onset", "word_frequency", "word_surprisal"]

# The names above presume word_features contributes exactly one column (word frequency).
assert X_variable.shape[1] == len(variable_feature_names)

In [None]:
# Load other stimulus time-series features.
# TODO

## Load aligned word/phoneme presentation data

In [93]:
# Word / phoneme presentation tables with stimulus-alignment columns (e.g. word_idx, token_idx).
word_aligned_df, phoneme_aligned_df = (
    pd.read_csv(csv_path, index_col=0) for csv_path in (aligned_words, aligned_phonemes)
)
word_aligned_df

Unnamed: 0,word_idx,story,story_uid,sound_id,kind,meg_file,start,sound,word,sequence_id,condition,word_index,speech_rate,voice,pronounced,onset,duration,value,sample,token_idx
0,0,lw1,0.0,0.0,word,A0167_MASC_1_16Mar17_01.con,0.000000,stimuli/audio/lw1_0.wav,Tara,0.0,sentence,0.0,205.0,Allison,1.0,23.506,0.30,697,23506,0
1,0,lw1,0.0,0.0,word,A0167_MASC_1_16Mar17_01.con,0.000000,stimuli/audio/lw1_0.wav,Tara,0.0,sentence,0.0,205.0,Allison,1.0,23.506,0.30,697,23506,1
2,1,lw1,0.0,0.0,word,A0167_MASC_1_16Mar17_01.con,0.310000,stimuli/audio/lw1_0.wav,stood,0.0,sentence,1.0,205.0,Allison,1.0,23.816,0.24,698,23816,2
3,2,lw1,0.0,0.0,word,A0167_MASC_1_16Mar17_01.con,0.550000,stimuli/audio/lw1_0.wav,stock,0.0,sentence,2.0,205.0,Allison,1.0,24.056,0.37,699,24056,3
4,3,lw1,0.0,0.0,word,A0167_MASC_1_16Mar17_01.con,1.080000,stimuli/audio/lw1_0.wav,still,0.0,sentence,3.0,205.0,Allison,1.0,24.586,0.40,700,24586,5
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
634,663,lw1,0.0,3.0,word,A0167_MASC_1_16Mar17_01.con,50.070000,stimuli/audio/lw1_3.wav,end,52.0,sentence,15.0,205.0,Allison,1.0,361.097,0.17,3119,361097,918
635,664,lw1,0.0,3.0,word,A0167_MASC_1_16Mar17_01.con,50.250000,stimuli/audio/lw1_3.wav,for,52.0,sentence,16.0,205.0,Allison,1.0,361.277,0.14,3120,361277,919
636,665,lw1,0.0,3.0,word,A0167_MASC_1_16Mar17_01.con,50.460000,stimuli/audio/lw1_3.wav,project,52.0,sentence,18.0,205.0,Allison,1.0,361.487,0.58,3121,361487,921
637,666,lw1,0.0,3.0,word,A0167_MASC_1_16Mar17_01.con,51.179999,stimuli/audio/lw1_3.wav,and,52.0,sentence,19.0,205.0,Allison,1.0,362.207,0.15,3122,362207,923


In [96]:
# Check that the aligned word table covers as many distinct words as the stimulus contains.
# The message reports both counts so a failure is diagnosable without re-running by hand.
n_aligned_words = len(set(word_aligned_df.word_idx))
n_stimulus_words = len(story_stim.word_lengths)
assert n_aligned_words == n_stimulus_words, (
    f"aligned table has {n_aligned_words} distinct word_idx values, "
    f"stimulus has {n_stimulus_words} words"
)

# TODO why the mismatch? Note that word_idx repeats across rows (e.g. word_idx 0 appears
# with token_idx 0 and 1), so the aligned table's row count is not its word count.

AssertionError: 

## Load sound/word/phoneme presentation data

In [21]:
# Sound-, word- and phoneme-level presentation logs for this session.
sound_presentation_df, word_presentation_df, phoneme_presentation_df = [
    pd.read_csv(csv_path, index_col=0)
    for csv_path in (presentation_sounds, presentation_words, presentation_phonemes)
]

In [34]:
# All three presentation logs must name the same stories, and that must be exactly story_name.
presented_stories = set(sound_presentation_df.story)
assert all(set(other_df.story) == presented_stories
           for other_df in (word_presentation_df, phoneme_presentation_df))
assert presented_stories == {story_name}

In [90]:
# Inspect the word-level presentation log (one row per presented word event).
word_presentation_df

Unnamed: 0,story,story_uid,sound_id,kind,meg_file,start,sound,word,sequence_id,condition,word_index,speech_rate,voice,pronounced,onset,duration,value,sample
0,lw1,0.0,0.0,word,A0167_MASC_1_16Mar17_01.con,0.000000,stimuli/audio/lw1_0.wav,Tara,0.0,sentence,0.0,205.0,Allison,1.0,23.506,0.30,697,23506
1,lw1,0.0,0.0,word,A0167_MASC_1_16Mar17_01.con,0.310000,stimuli/audio/lw1_0.wav,stood,0.0,sentence,1.0,205.0,Allison,1.0,23.816,0.24,698,23816
2,lw1,0.0,0.0,word,A0167_MASC_1_16Mar17_01.con,0.550000,stimuli/audio/lw1_0.wav,stock,0.0,sentence,2.0,205.0,Allison,1.0,24.056,0.37,699,24056
3,lw1,0.0,0.0,word,A0167_MASC_1_16Mar17_01.con,1.080000,stimuli/audio/lw1_0.wav,still,0.0,sentence,3.0,205.0,Allison,1.0,24.586,0.40,700,24586
4,lw1,0.0,0.0,word,A0167_MASC_1_16Mar17_01.con,1.630000,stimuli/audio/lw1_0.wav,waiting,0.0,sentence,4.0,205.0,Allison,1.0,25.136,0.41,701,25136
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
663,lw1,0.0,3.0,word,A0167_MASC_1_16Mar17_01.con,50.070000,stimuli/audio/lw1_3.wav,end,52.0,sentence,15.0,205.0,Allison,1.0,361.097,0.17,3119,361097
664,lw1,0.0,3.0,word,A0167_MASC_1_16Mar17_01.con,50.250000,stimuli/audio/lw1_3.wav,for,52.0,sentence,16.0,205.0,Allison,1.0,361.277,0.14,3120,361277
665,lw1,0.0,3.0,word,A0167_MASC_1_16Mar17_01.con,50.460000,stimuli/audio/lw1_3.wav,project,52.0,sentence,18.0,205.0,Allison,1.0,361.487,0.58,3121,361487
666,lw1,0.0,3.0,word,A0167_MASC_1_16Mar17_01.con,51.179999,stimuli/audio/lw1_3.wav,and,52.0,sentence,19.0,205.0,Allison,1.0,362.207,0.15,3122,362207


In [91]:
# Number of words in the stimulus, for comparison with the presentation / alignment tables.
len(story_stim.word_lengths)

596

In [88]:
# Every stimulus word should correspond to exactly one presented word event.
# The message reports both counts so a failure is diagnosable without re-running by hand.
# TODO: currently fails -- the presentation log has more rows than the stimulus has words.
n_presented_words = len(word_presentation_df)
n_stimulus_words = len(story_stim.word_lengths)
assert n_presented_words == n_stimulus_words, (
    f"{n_presented_words} presented words vs. {n_stimulus_words} stimulus words"
)

AssertionError: 

## Load session alignment data

In [76]:
# Keep only this story/subject/session's sound onsets from the global alignment table.
# A boolean mask always yields a DataFrame; the previous chained `.loc[...]` collapsed to a
# Series when exactly one row matched and raised an opaque KeyError when none did.
session_alignment_all_df = pd.read_csv(global_session_alignment, index_col=0,
                                       dtype={"subject": str})
session_mask = (
    (session_alignment_all_df.story_name == story_name)
    & (session_alignment_all_df.subject == subject)
    & (session_alignment_all_df.session == session)
)
# Same shape as before: story_name/subject removed, `session` as the index.
# set_index returns a new frame, so later column assignments do not hit a slice.
session_alignment_df = (
    session_alignment_all_df.loc[session_mask]
    .drop(columns=["story_name", "subject"])
    .set_index("session")
)
assert len(session_alignment_df) > 0, (
    f"no session alignment rows for story={story_name!r}, "
    f"subject={subject!r}, session={session!r}"
)

In [77]:
# The alignment table must list exactly the sounds found in the presentation log.
assert len(session_alignment_df) == len(sound_presentation_df)
assert set(session_alignment_df.sound_id) == set(sound_presentation_df.sound_id)

# The first sound must start no later than the first word/phoneme event, and the last
# sound no later than the last word/phoneme event.
first_sound_onset = session_alignment_df.onset.min()
last_sound_onset = session_alignment_df.onset.max()
assert all(first_sound_onset <= events_df.onset.min()
           for events_df in (word_presentation_df, phoneme_presentation_df))
assert all(last_sound_onset <= events_df.onset.max()
           for events_df in (word_presentation_df, phoneme_presentation_df))

In [78]:
# Inspect the per-sound onsets selected for this session.
session_alignment_df

Unnamed: 0_level_0,sound_id,onset
session,Unnamed: 1_level_1,Unnamed: 2_level_1
0,0.0,23.506
0,1.0,127.185
0,2.0,210.048
0,3.0,311.027


## Load signal data

In [44]:
# Parse the BIDS entities from the configured file name, then load the raw MEG recording
# through mne_bids (which also reads the sidecar events/channels files, per the log below).
bids_path = mne_bids.get_bids_path_from_fname(bids)
raw = mne_bids.read_raw_bids(bids_path)

Extracting SQD Parameters from ../../workflow/meg-masc/raw-data/sub-01/ses-0/meg/sub-01_ses-0_task-0_meg.con...
Creating Raw.info structure...
Setting channel info structure...
Creating Info structure...
Ready.
Reading events from ../../workflow/meg-masc/raw-data/sub-01/ses-0/meg/sub-01_ses-0_task-0_events.tsv.
Reading channel info from ../../workflow/meg-masc/raw-data/sub-01/ses-0/meg/sub-01_ses-0_task-0_channels.tsv.
The stimulus channel "STI 014" is present in the raw data, but not included in channels.tsv. Removing the channel.


In [49]:
# Preprocessing to match https://github.com/kingjr/meg-masc/blob/main/check_decoding.py:
# keep MEG channels only, band-pass 0.5-30 Hz, then resample to target_sample_rate.
# NOTE: the result is rebound to `raw`, so re-running this cell re-filters already-processed data.
raw = (
    raw.pick_types(meg=True, misc=False, eeg=False, eog=False, ecg=False)
       .load_data()
       .filter(0.5, 30.0, n_jobs=1)
       .resample(target_sample_rate)
)
raw

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 6601 samples (6.601 sec)



0,1
Measurement date,"January 01, 2000 00:00:00 GMT"
Experimenter,mne_anonymize
Digitized points,0 points
Good channels,208 Magnetometers
Bad channels,
EOG channels,Not available
ECG channels,Not available
Sampling frequency,128.00 Hz
Highpass,0.50 Hz
Lowpass,30.00 Hz


In [62]:
# Every presentation and alignment onset must fall within the recorded signal's time range.
min_time, max_time = raw.times.min(), raw.times.max()
presentation_dfs = (sound_presentation_df, word_presentation_df, phoneme_presentation_df)

for presentation_df in presentation_dfs:
    assert min_time <= presentation_df.onset.min()
for presentation_df in presentation_dfs:
    assert max_time >= presentation_df.onset.max()

assert min_time <= session_alignment_df.onset.min()
assert max_time >= session_alignment_df.onset.max()

In [80]:
# Each sound segment runs from its onset to the next sound's onset; the final segment
# runs to the end of the recorded signal (raw.times.max()).
# `.assign` builds a new frame instead of writing columns in place into the alignment
# table (which came from `.loc` slicing and was already displayed above), avoiding
# SettingWithCopy pitfalls. Later kwargs may reference earlier ones, so `duration` sees `offset`.
session_alignment_df = session_alignment_df.assign(
    offset=lambda df: df.onset.shift(-1, fill_value=raw.times.max()),
    duration=lambda df: df.offset - df.onset,
)
session_alignment_df

Unnamed: 0_level_0,sound_id,onset,offset,duration
session,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,0.0,23.506,127.185,103.679
0,1.0,127.185,210.048,82.863
0,2.0,210.048,311.027,100.979
0,3.0,311.027,395.992188,84.965188


## Crop the signal to the sound segments and concatenate them, so each sound starts at a predictable time

In [87]:
# Crop one segment per sound, from its onset up to (but excluding) its offset,
# then splice the segments into a single continuous Raw.
# NB will insert a BAD/EDGE boundary at each concatenation point
segment_raws = [
    raw.copy().crop(tmin=segment.onset, tmax=segment.offset, include_tmax=False)
    for segment in tqdm(session_alignment_df.itertuples(), total=len(session_alignment_df))
]
concat_raw = mne.concatenate_raws(segment_raws)

  0%|          | 0/4 [00:00<?, ?it/s]

In [85]:
# Inspect annotations on the concatenated signal, including the BAD/EDGE boundary
# markers added at each splice point.
concat_raw.annotations

<Annotations | 3139 segments: BAD boundary (3), EDGE boundary (3), ...>

In [None]:
BerpDataset(
    name=f"{story_name}/{subject}_{session}",
    stimulus_name=story_stim.name,
    sample_rate=int(concat_raw.info["sfreq"]),
    
    phonemes=story_stim.phonemes,
    
    word_onsets=