Concatenate a sequence of `BerpDataset`s and their corresponding `NaturalLanguageStimulus`es into a single `BerpDataset`.

Why do this? So that downstream consumers (e.g. the TRF pipeline) see the dataset as one single time series, and don't
unnecessarily chop it up and create many invalid boundaries.

In [1]:
from copy import deepcopy
import itertools
from pathlib import Path
import pickle
import sys
from typing import Optional

import numpy as np
import torch
from torch.nn.functional import pad
from tqdm.auto import tqdm

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
sys.path.append("../..")
from berp.datasets import BerpDataset, NaturalLanguageStimulus, Vocabulary
from berp.datasets.base import assert_compatible
from berp.datasets.eeg import load_eeg_dataset
from berp.util import sample_to_time

In [4]:
datasets = [
    "../../workflow/heilbron2022/data/dataset/distilgpt2/old-man-and-the-sea/sub17/run1.pkl",
    "../../workflow/heilbron2022/data/dataset/distilgpt2/old-man-and-the-sea/sub17/run2.pkl",
    "../../workflow/heilbron2022/data/dataset/distilgpt2/old-man-and-the-sea/sub17/run3.pkl",
]

stimuli = {
    "old-man-and-the-sea/run1": "../../workflow/heilbron2022/data/stimulus/distilgpt2/old-man-and-the-sea/run1.pkl",
    "old-man-and-the-sea/run2": "../../workflow/heilbron2022/data/stimulus/distilgpt2/old-man-and-the-sea/run2.pkl",
    "old-man-and-the-sea/run3": "../../workflow/heilbron2022/data/stimulus/distilgpt2/old-man-and-the-sea/run3.pkl",
}

target_dataset_name = "old-man-and-the-sea/sub17"
target_stimulus_name = "old-man-and-the-sea"

out_dataset = "merged.old-man-and-the-sea.sub17.pkl"
out_stimulus = "merged.old-man-and-the-sea.pkl"
# If True, pickle stimulus representation
save_stimulus = True
# If True load stimulus at given out path and check equality
check_stimulus = True

In [5]:
if save_stimulus and check_stimulus:
    import warnings
    warnings.warn("Both save_stimulus and check_stimulus are enabled. Are you sure?")



## Load datasets

In [6]:
nested_datasets = load_eeg_dataset(
    datasets,
    # NB don't do any normalization!
    normalize_X_ts=False, normalize_X_variable=False, normalize_Y=False,
    stimulus_paths=stimuli)

In [7]:
assert len(nested_datasets.datasets) == len(datasets)
assert all(ds.global_slice_indices == None for ds in nested_datasets.datasets)

In [8]:
ds1 = nested_datasets.datasets[0]
for ds2 in nested_datasets.datasets[1:]:
    assert_compatible(ds1, ds2)

In [9]:
# Compute onset of each dataset in the eventual merged data.
dataset_durations = [sample_to_time(len(dataset) + 1, dataset.sample_rate)
                     for dataset in nested_datasets.datasets]
dataset_onsets = np.concatenate([[0], np.cumsum(dataset_durations[:-1])])
dataset_onsets

array([  0.      , 170.359375, 334.09375 ])

In [10]:
all_ds = nested_datasets.datasets
ds1 = all_ds[0]

## Load stimulus representations

In [11]:
stims = {}
for name, path in stimuli.items():
    with open(path, "rb") as f:
        stims[name] = pickle.load(f)

In [12]:
stim_order = [ds.stimulus_name for ds in all_ds]
stims_list = [stims[name] for name in stim_order]

## Merge datasets

In [13]:
# Pad phoneme_onsets as necessary
max_num_phonemes = max(ds.max_n_phonemes for ds in all_ds)
assert max_num_phonemes == max(stim.max_n_phonemes for stim in stims_list)

# NB we don't apply dataset-level onset here, since phoneme_onsets are all relative
# to word onset
merged_phoneme_onsets = [
    pad(ds.phoneme_onsets, (0, max_num_phonemes - ds.phoneme_onsets.shape[1]), value=0)
    if ds.phoneme_onsets.shape[1] < max_num_phonemes
    else ds.phoneme_onsets
    for ds in all_ds
]
merged_phoneme_onsets = torch.cat(merged_phoneme_onsets)

In [14]:
merged_dataset = BerpDataset(
    name=target_dataset_name, stimulus_name=target_stimulus_name,
    sample_rate=ds1.sample_rate,
    
    # Shift word onset and offset information based on dataset onset
    word_onsets=torch.cat([ds.word_onsets + dataset_onset for ds, dataset_onset in zip(all_ds, dataset_onsets)]),
    word_offsets=torch.cat([ds.word_offsets + dataset_onset for ds, dataset_onset in zip(all_ds, dataset_onsets)]),
    # NB we don't apply dataset-level onset here, since phoneme_onsets are all relative
    # to word onset
    phoneme_onsets=merged_phoneme_onsets,
    
    X_ts=torch.cat([ds.X_ts for ds in all_ds]),
    X_variable=torch.cat([ds.X_variable for ds in all_ds]),
    Y=torch.cat([ds.Y for ds in all_ds]),
    
    sensor_names=ds1.sensor_names,
    phonemes=ds1.phonemes,
    
    ts_feature_names=ds1.ts_feature_names,
    variable_feature_names=ds1.variable_feature_names)

## Merge stimuli

In [15]:
stim1 = stims_list[0]
for stim2 in stims_list[1:]:
    assert stim1.phonemes == stim2.phonemes
    assert stim1.pad_phoneme_id == stim2.pad_phoneme_id
    assert stim1.word_features.shape[1] == stim2.word_features.shape[1]
    # candidate size should match
    assert stim1.p_candidates.shape[1] == stim2.p_candidates.shape[1]
    assert stim1.candidate_ids.shape[1] == stim2.candidate_ids.shape[1]
    
    assert stim1.word_feature_names == stim2.word_feature_names
    assert stim1.phoneme_feature_names == stim2.phoneme_feature_names

In [16]:
word_id_offsets = np.concatenate([[0], np.cumsum([stim.word_ids.max().item() + 1 for stim in stims_list])[:-1]])
merged_word_ids = torch.cat([stim.word_ids + word_id_offsets[i] for i, stim in enumerate(stims_list)])

# Merged word IDs should be unique.
assert (torch.bincount(merged_word_ids) <= 1).all()

In [17]:
# Merge vocabularies.
merged_vocabulary = deepcopy(stim1.candidate_vocabulary)
print(f"Size of stim1 vocabulary: {len(merged_vocabulary)}")

for stim2 in tqdm(stims_list[1:]):
    for vocab2_idx, vocab2_tok in enumerate(stim2.candidate_vocabulary.idx2tok):
        vocab1_idx = merged_vocabulary.add(vocab2_tok)
        stim2.candidate_ids[stim2.candidate_ids == vocab2_idx] = vocab1_idx
        
print(f"Size of merged vocabulary: {len(merged_vocabulary)}")

Size of stim1 vocabulary: 17099


  0%|          | 0/2 [00:00<?, ?it/s]

Size of merged vocabulary: 19520


In [18]:
merged_stimulus = NaturalLanguageStimulus(
    name=target_stimulus_name,
    phonemes=stim1.phonemes,
    pad_phoneme_id=stim1.pad_phoneme_id,
    
    word_ids=merged_word_ids,
    word_lengths=torch.cat([stim.word_lengths for stim in stims_list]),
    word_features=torch.cat([stim.word_features for stim in stims_list]),
    word_feature_names=stim1.word_feature_names,
    
    phoneme_features=list(itertools.chain.from_iterable(stim.phoneme_features for stim in stims_list)),
    phoneme_feature_names=stim1.phoneme_feature_names,
    
    p_candidates=torch.cat([stim.p_candidates for stim in stims_list]),
    candidate_ids=torch.cat([stim.candidate_ids for stim in stims_list]),
    candidate_vocabulary=merged_vocabulary)

## Save dataset

In [19]:
with open(out_dataset, "wb") as f:
    pickle.dump(merged_dataset, f)

## Check stimulus equality

In [20]:
if check_stimulus:
    with open(out_stimulus, "rb") as f:
        ref_stimulus = pickle.load(f)
    assert ref_stimulus == merged_stimulus

AssertionError: 

## Save stimulus, possibly overwrite

In [None]:
if save_stimulus:
    with open(out_stimulus, "wb") as f:
        pickle.dump(merged_stimulus, f)