In [1]:
# boilerplate for quick jupyter experiments
import sys
sys.path.append('/usr0/home/drschwar/src/paradigms')
sys.path.append('/usr0/home/drschwar/src/subplot_artist')
import socket
import os
import string
from functools import partial
from itertools import chain
import mne
import numpy
from scipy.stats import boxcox
import tqdm
%matplotlib inline
from matplotlib import gridspec, pyplot as plt
from paradigms import Loader

current_machine = socket.gethostname()
current_machine = current_machine.lower() if current_machine is not None else ''

if current_machine == 'drschwar-xps13':
    data_root = r'C:\Users\danrs\Documents\BrainGroupData\meg'
    inv_path = os.path.join(
        r'C:\Users\danrs\Documents\BrainGroupData\meg', 
        '{experiment}' 'inv_{subject}_{experiment}_trans-D_nsb-5_cb-0_raw-{structural}-7-0.2-0.8-limitTrue-rankNone-inv.fif')
    struct_dir = r'C:\Users\danrs\Documents\BrainGroupData\meg\structural' 
    session_stimuli_path = os.path.join(
        r'C:\Users\danrs\Documents\BrainGroupData\meg', 
        '{experiment}', '{subject}_sentenceBlock.mat')
    word2vec_path = None
    glove_path = None
else:
    data_root = '/share/volume0/newmeg/'
    inv_path = '/share/volume0/newmeg/{experiment}/data/inv/{subject}/{subject}_{experiment}_trans-D_nsb-5_cb-0_raw-{structural}-7-0.2-0.8-limitTrue-rankNone-inv.fif'
    struct_dir = '/share/volume0/drschwar/structural'
    session_stimuli_path = '/share/volume0/newmeg/{experiment}/meta/{subject}/sentenceBlock.mat'
    word2vec_path = '/share/volume0/language_representation/models/googlenews/GoogleNews-vectors-negative300.bin'
    glove_path = '/share/volume0/drschwar/GloVe/glove.840B.300d.fmt_w2v.bin'
    
recording_tuple_regex = Loader.make_standard_recording_tuple_regex(
    'trans-D_nsb-5_cb-0_empty-4-10-2-2_band-1-150_notch-60-120_beats-head-meas_blinks-head-meas')
loader = Loader(session_stimuli_path, data_root, recording_tuple_regex, inv_path, struct_dir)


In [2]:
def plot_average_data(mne_raw, stimuli, time_before_stimulus=0.2, duration=3.0, picks=None):
    
    stimulus_data = list()
    for index_stimulus in range(len(stimuli)):
        start = stimuli[index_stimulus][Stimulus.time_stamp_attribute_name]
        start_sample = numpy.searchsorted(mne_raw.times, start-time_before_stimulus)    
        end_sample = int(start_sample + (duration + time_before_stimulus) * mne_raw.info['sfreq'])
        data, times = mne_raw[:, start_sample:end_sample]
        if picks is not None:
            data = data[picks]
        stimulus_data.append(data)
        
    stimulus_data = numpy.array(stimulus_data)
    stimulus_data = numpy.mean(stimulus_data, axis=0)
    grid = gridspec.GridSpec(1, 1)
    fig = plt.figure(figsize=(16, 4))
    axes = fig.add_subplot(grid[0, 0])
    axes.matshow(stimulus_data, interpolation=None)
    time = numpy.arange(int(-time_before_stimulus * mne_raw.info['sfreq']), int(duration * mne_raw.info['sfreq'])) / float(mne_raw.info['sfreq'])
    for e in range(0, int(duration * mne_raw.info['sfreq']), int(0.5 * mne_raw.info['sfreq'])):
        ev = e + time_before_stimulus * mne_raw.info['sfreq']
        axes.axvline(ev, color='blue')
        axes.axvline(ev + 50, color='red')
    axes.set_xticklabels([time[int(t)] if 0 <= t < len(time) else t for t in axes.get_xticks()])
    plt.show()

In [3]:
def _stimulus_to_100ms_slices(s):
    return [(
        (s['master_stimulus_index'], idx_time * 0.1), 
        s['time_stamp'] + 0.1 * idx_time) for idx_time in range(5)]

def extract_stimuli_from_raw_time_courses(mne_raw, time_courses, stimuli, stimulus_to_name_time_pairs, num_samples):
    names = list()
    result = list()
    for item in chain.from_iterable(map(stimulus_to_name_time_pairs, stimuli)):
        if len(item) != 2:
            raise ValueError('Expected stimulus_to_name_time_pairs to return a list of pairs for each '
                             'stimulus. Are you returning just a single pair? Got: {}'.format(item))
        name, time = item
        sample_index = numpy.searchsorted(mne_raw.times, time, side='left')
        result.append(numpy.expand_dims(time_courses[:, sample_index:(sample_index + num_samples)], 0))
    return names, numpy.concatenate(result)   

# picks = mne.pick_types(mne_raw.info, meg=True)
# mne_raw.load_data()
# plot_average_data(mne_raw, stimuli, time_before_stimulus=0.2, picks=picks)

subject_labels = dict()
subject_rois = dict()

with mne.utils.use_log_level(False):

    for subject in ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I']:

        print('processing {}'.format(subject))

        structurals = {
            'A': 'struct4',
            'B': 'struct5',
            'C': 'struct6',
            'D': 'krns5D',
            'E': 'struct2',
            'F': 'krns5A',
            'G': 'struct1',
            'H': 'struct3',
            'I': 'krns5C'
        }

        inv, labels = loader.load_structural('harryPotter', subject, structurals[subject])
        
        roi_means = list()
        for block in ['1', '2', '3', '4']:
            print('loading block {}'.format(block))
            mne_raw, stimuli, _ = loader.load_block('harryPotter', subject, block)
            # According to the mne people, SNR of 1 (i.e. setting lambda2=1) is what you want for single trial data
            # https://github.com/mne-tools/mne-python/issues/4131
            print('computing sources')
            source_estimate = mne.minimum_norm.apply_inverse_raw(mne_raw, inv, lambda2=1.0)
            # time_course is (labels, time)
            print('extracting time course')
            time_course = source_estimate.extract_label_time_course(labels, inv['src'], mode='pca_flip', verbose=False)
            source_estimate = None  # release memory
            # time_course is now (num_keys, labels, 100)
            print('getting 100ms slices')
            keys, time_course = extract_stimuli_from_raw_time_courses(
                mne_raw, time_course, stimuli, _stimulus_to_100ms_slices, 100)
            # take the mean of the 100ms -> (num_keys, labels)
            time_course = numpy.mean(time_course, axis=2)
            roi_means.append(time_course)
            

#         epochs, keys = loader.load_epochs(
#             'harryPotter', subject, ['1', '2', '3', '4'], _stimulus_to_100ms_slices, tmin=0, tmax=0.099, add_eeg_ref=False)

#         # According to the mne people, SNR of 1 (i.e. setting lambda2=1) is what you want for single trial data
#         # https://github.com/mne-tools/mne-python/issues/4131
#         roi_means = list()
#         for source_estimate in tqdm.tqdm(mne.minimum_norm.apply_inverse_epochs(
#                 epochs, inv, lambda2=1.0, return_generator=True), 
#                 leave=False, total=len(keys), miniters=0):
#             current_means = numpy.mean(
#                 source_estimate.extract_label_time_course(labels, inv['src'], mode='mean', verbose=False), axis=1)
#             roi_means.append(numpy.expand_dims(current_means, 0))

        subject_rois[subject] = numpy.concatenate(roi_means, axis=0)
        subject_labels[subject] = [label.name for label in labels]

processing A
loading block 1
computing sources
extracting time course
getting 100ms slices
loading block 2
computing sources
extracting time course
getting 100ms slices
loading block 3
computing sources
extracting time course
getting 100ms slices
loading block 4
computing sources
extracting time course
getting 100ms slices
processing B
loading block 1
computing sources
extracting time course
getting 100ms slices
loading block 2
computing sources
extracting time course
getting 100ms slices
loading block 3
computing sources
extracting time course
getting 100ms slices
loading block 4
computing sources
extracting time course
getting 100ms slices
processing C
loading block 1
computing sources
extracting time course
getting 100ms slices
loading block 2
computing sources
extracting time course
getting 100ms slices
loading block 3
computing sources
extracting time course
getting 100ms slices
loading block 4
computing sources
extracting time course
getting 100ms slices
processing D
loading bloc

In [4]:
first_subject_labels = None
for subject in subject_labels:
    if first_subject_labels is None:
        first_subject_labels = subject_labels[subject]
    else:
        assert(numpy.array_equal(first_subject_labels, subject_labels[subject]))

In [4]:
len(subject_rois)

0

In [5]:
subjects = sorted(subject_rois)
rois = [subject_rois[subject] for subject in subjects]


In [6]:
rois = [numpy.reshape(r, (r.shape[0] // 5, 5, -1)) for r in rois]

In [7]:
for r in rois:
    print(r.shape)

(5174, 5, 68)
(5176, 5, 68)
(5176, 5, 68)
(5176, 5, 68)
(5176, 5, 68)
(5176, 5, 68)
(5176, 5, 68)
(5176, 5, 68)
(5176, 5, 68)


In [9]:
import itertools
all_stimuli = list()
for s in subjects:
    stimuli = list()
    block_id = list()
    for block in ['1', '2', '3', '4']:
        _, block_stimuli, _ = loader.load_block('harryPotter', s, block)
        block_id.extend([block] * len(block_stimuli))
        stimuli.extend([b.text for b in block_stimuli])
    all_stimuli.append((stimuli, block_id))

In [27]:
# correct missing '+' stimuli for 'A'
corrected = rois[0]
for idx in range(max([len(s[0]) for s in all_stimuli])):
    if all_stimuli[1][0][idx] == '+':
        assert(s[0][idx] == '+' for s in all_stimuli[2:])
        if all_stimuli[0][0][idx] != '+':
            print('inserting at {}'.format(idx))
            corrected_stimuli = all_stimuli[0][0][0:idx] + ['+'] + all_stimuli[0][0][idx:]
            corrected_block_id = all_stimuli[0][1][0:idx] + [all_stimuli[0][1][idx - 1]] + all_stimuli[0][1][idx:]
            all_stimuli[0] = corrected_stimuli, corrected_block_id
            corrected = numpy.concatenate(
                [corrected[0:idx], numpy.full((1,) + corrected.shape[1:], numpy.nan), corrected[idx:]])
    assert(all([s[0][idx] == all_stimuli[0][0][idx] for s in all_stimuli[1:]]))
    

inserting at 1302
inserting at 2653


In [29]:
for idx_subject in range(len(all_stimuli)):
    all_stimuli[idx_subject] = all_stimuli[idx_subject][0], [int(b) for b in all_stimuli[idx_subject][1]]

In [30]:
rois[0] = corrected

In [32]:
print(all_stimuli[0][1])

[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

In [33]:
rois = numpy.concatenate([numpy.expand_dims(r, 0) for r in rois], axis=0)

In [35]:
print(rois.shape)

(9, 5176, 5, 68)


In [36]:
rois = numpy.transpose(rois, axes=(0, 1, 3, 2))

In [37]:
rois.shape

(9, 5176, 68, 5)

In [40]:
numpy.savez(
    '/usr0/home/drschwar/src/ulmfit_data/harry_potter_pca.npz',
    stimuli=numpy.array(all_stimuli[0][0]),
    blocks=numpy.array(all_stimuli[0][1], dtype=numpy.int64),
    rois=numpy.array(first_subject_labels),
    subjects=numpy.array(subjects),
    data=rois)

In [2]:
mne_raw, stimuli, _ = loader.load_block('harryPotter', 'A', '1')

In [41]:
data = numpy.load('/usr0/home/drschwar/src/ulmfit_data/harry_potter_pca.npz')

In [51]:
data['data']

array([[[[2.50855019, 4.62644884, 3.11857033, 2.66599053, 2.0550925 ],
         [1.88327035, 2.2761118 , 1.71844959, 1.53066243, 1.27884711],
         [1.87729142, 1.39619447, 2.14151152, 1.57469817, 1.92370568],
         ...,
         [2.65242209, 1.73760765, 3.42613867, 2.45927217, 3.62275474],
         [2.74493767, 4.02696994, 3.49896798, 2.93642855, 2.74324888],
         [1.91200418, 2.49459163, 1.7264115 , 2.4108514 , 1.57914565]],

        [[1.391914  , 2.18727557, 3.25229419, 2.61531819, 2.63281145],
         [1.68505411, 1.445136  , 1.87278115, 1.15040324, 1.3474002 ],
         [1.78733592, 1.41656226, 1.4776016 , 1.66545329, 1.49950228],
         ...,
         [3.10360542, 1.60999585, 3.72334601, 3.31970575, 2.83481224],
         [2.31270869, 2.71260624, 3.22678335, 3.01674638, 2.3585981 ],
         [1.77200711, 1.78594323, 2.13122096, 2.20889252, 1.76494783]],

        [[2.09914668, 2.38242514, 3.56007668, 1.976317  , 3.66097077],
         [1.25402798, 1.30874499, 1.7835789 ,