In [1]:
import mne

import ssvepy

from autoreject import Ransac

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import os
import glob
import re
import collections
from datetime import date

from tqdm import tqdm_notebook as tqdm
from ipywidgets import interact

# Set MNE's logging threshold to WARNING for the rest of the notebook.
mne.utils.set_log_level('WARNING')



In [2]:
# Locate the recordings: try each known data directory in turn and use the
# first one that exists on this machine.
candidate_folders = [
    '/Users/jan/Documents/eeg-data/cancan-saturation/',  # on laptop
    '/data/group/FANS/cancan/eeg/',  # On NaN
]

for candidate in candidate_folders:
    if os.path.isdir(candidate):
        datafolder = candidate
        break
else:
    # Without this, a missing data directory only surfaces as a NameError on
    # `datafolder` in the glob call below.
    raise IOError('None of the known data folders exist: '
                  + ', '.join(candidate_folders))

# glob returns files in arbitrary order; sort so that the subject order (and
# therefore the row order of everything derived from `files`) is reproducible.
files = sorted(glob.glob(os.path.join(datafolder, '*saturation*.vhdr')))

In [73]:
# Parse the participant ID and the visit letter out of every file name.
# `ids[i]` / `visits[i]` describe `files[i]`, so the three lists must stay
# the same length.
ids = []
visits = []
unparsed = []

# Digits, optional '_' / '-' separators, then the visit letter (A, B or C).
id_visit_pattern = re.compile(r'(\d+)[_-]*([ABCabc])')

for file in files:
    # Only look at the file name itself so that digits in a directory name
    # can never be mistaken for a participant ID.
    m = id_visit_pattern.search(os.path.basename(file))
    if m:
        # Keep the last three digits, zero-padded, as the participant ID.
        ids.append(m.group(1)[-3:].zfill(3))
        visits.append(m.group(2).lower())
    else:
        unparsed.append(file)

if unparsed:
    # Skipping a file silently would shift every later ID onto the wrong
    # recording, so stop and name the offending files instead.
    raise ValueError('Could not parse participant ID / visit from: '
                     + ', '.join(unparsed))


In [74]:
# Read every BrainVision recording with a standard 10-20 montage.
# The montage is the same for all files, so load it once instead of
# re-reading it for every recording.
montage = mne.channels.read_montage('standard_1020')

raws = [mne.io.read_raw_brainvision(file,
                                    event_id={'DCC': 199, 'actiCAP Data On': 200},
                                    montage=montage)
        for file in files]

# The labels are matched to the recordings purely by position; make sure the
# lists really line up before tagging anything.
if not (len(raws) == len(ids) == len(visits)):
    raise ValueError('Got {} recordings but {} ids and {} visits'.format(
        len(raws), len(ids), len(visits)))

# Tag each recording with "<id><visit>" (e.g. '017c'); the last character is
# split off again as the visit when the results table is built.
for raw, subject_id, visit in zip(raws, ids, visits):
    raw.info['subject_info'] = subject_id + visit

In [76]:
# Identify Epochs.
# For every recording, cut one set of 0-10 s epochs per condition trigger.
# this discards all files that hold no
# events (trigger cable not connected -.-)
# Because recordings can be dropped here, `epochs` is not guaranteed to be
# index-aligned with `raws`.
conditions = [16, 32, 64, 100]

epochs = []
for raw in tqdm(raws, desc='Subject'):
    # Extract the events once per recording and reuse them below.
    events = mne.find_events(raw)
    # Only the third column holds the event codes. Testing against the whole
    # array would also match sample indices that happen to equal a code.
    if not np.in1d(conditions, events[:, 2]).all():
        continue
    eeg_picks = mne.pick_types(raw.info, eeg=True)
    epochs.append([mne.Epochs(raw,
                              events,
                              event_id=event,
                              tmin=0, tmax=10,
                              picks=eeg_picks)
                   for event in conditions])




In [77]:
# Load each condition's epochs into memory and resample them to 256
# (the value is passed straight to MNE's `resample`).
for subject_epochs in tqdm(epochs):
    for condition_epochs in subject_epochs:
        condition_epochs.load_data()
        condition_epochs.resample(256)





In [78]:
%%capture
# Clean the data using autoreject's ransac
cleaners = [Ransac(verbose=False) for raw in raws]

cleanepochs = [[cleaner.fit_transform(epoch) for epoch in epochlist]
               for epochlist, cleaner in zip(tqdm(epochs), cleaners)]

In [79]:
%%capture
ssveps = [[ssvepy.Ssvep(epoch, 5.0, fmin=2, fmax=30) for epoch in epochlist]
           for epochlist in tqdm(cleanepochs)]

In [85]:
# NOTE(review): these indices are looked up in the first *raw* recording but
# are used to index arrays derived from EEG-only epochs; this assumes the
# channel order is identical in both -- confirm.
occipital_indices = [raws[0].ch_names.index(ch)
                     for ch in ['Oz', 'O1', 'O2', 'POz']]

dataarrays = collections.OrderedDict()


def _best_electrode(ssveplist, candidates=None):
    """Return the index of the electrode with the highest mean stimulation SNR.

    The SNR is averaged over the first axis of each condition's
    `stimulation.snr` and then over all conditions in `ssveplist`.

    Parameters
    ----------
    ssveplist : list
        The Ssvep objects (one per condition) of a single subject.
    candidates : list of int | None
        Electrode indices to restrict the search to. None searches all
        electrodes.

    Returns
    -------
    int | None
        Index into the full electrode axis, or None when every averaged SNR
        value is NaN (np.nanargmax raises ValueError in that case).
    """
    per_condition = []
    for ssvep in ssveplist:
        snr = ssvep.stimulation.snr
        if candidates is not None:
            snr = snr[:, candidates]
        per_condition.append(snr.mean(axis=0))
    try:
        best = np.nanargmax(np.stack(per_condition, axis=-1).mean(axis=-1))
    except ValueError:
        return None
    # Map a position within `candidates` back to the full electrode index.
    return best if candidates is None else candidates[best]


# One row per subject, one column per condition. Cells stay NaN when no best
# electrode could be determined. Indexing `power[:, None]` would not fail: it
# would quietly average over every electrode.
shape = (len(ssveps), len(ssveps[0]))
dataarrays['maxamp_occipital'] = np.full(shape, np.nan)
dataarrays['maxamp_all'] = np.full(shape, np.nan)
dataarrays['avsnr_occipital'] = np.full(shape, np.nan)

for subject, ssveplist in enumerate(ssveps):
    best_occipital = _best_electrode(ssveplist, occipital_indices)
    best_overall = _best_electrode(ssveplist)
    for condition, ssvep in enumerate(ssveplist):
        # Aggregate the amp from the occ. electrode with max SNR
        if best_occipital is not None:
            dataarrays['maxamp_occipital'][subject, condition] = (
                ssvep.stimulation.power[:, best_occipital].mean())
        # Aggregate the amp from the electrode with overall max SNR
        if best_overall is not None:
            dataarrays['maxamp_all'][subject, condition] = (
                ssvep.stimulation.power[:, best_overall].mean())
        # Get the average SNR at the occipital electrodes
        dataarrays['avsnr_occipital'][subject, condition] = np.nanmean(
            ssvep.stimulation.snr[:, occipital_indices])




## Save data to CSV file

In [86]:
# Assemble the results table column by column, then write it to a dated CSV.

datadict = collections.OrderedDict()  # keeps the column order stable

# Each subject's label is "<id><visit>": everything but the last character is
# the id, the last character is the visit.
subject_tags = [ssveplist[0].info['subject_info'] for ssveplist in ssveps]
datadict['id'] = [tag[0:-1] for tag in subject_tags]
datadict['visit'] = [tag[-1] for tag in subject_tags]

# One column per measure and condition, e.g. 'maxamp_all_16'.
condition_codes = [16, 32, 64, 100]
for label, data in dataarrays.items():
    for column, condition in enumerate(condition_codes):
        datadict['{}_{}'.format(label, condition)] = data[:, column]

df = pd.DataFrame(datadict)

csv_name = date.today().strftime('%Y-%m-%d') + '_alldata.csv'
df.to_csv(csv_name)

df

Unnamed: 0,id,visit,maxamp_occipital_16,maxamp_occipital_32,maxamp_occipital_64,maxamp_occipital_100,maxamp_all_16,maxamp_all_32,maxamp_all_64,maxamp_all_100,avsnr_occipital_16,avsnr_occipital_32,avsnr_occipital_64,avsnr_occipital_100
0,017,c,1.886948e-09,1.311694e-09,2.639809e-09,1.642056e-09,1.886948e-09,1.311694e-09,2.639809e-09,1.642056e-09,1.615369,1.496006,1.805192,1.359867
1,110,a,1.947913e-09,5.757273e-09,1.169274e-09,1.120803e-09,1.858458e-10,6.043818e-10,1.271451e-10,2.028098e-10,0.984805,0.748223,0.935111,1.220182
2,109,c,1.182881e-09,1.258869e-09,1.163593e-09,9.259463e-10,7.125155e-10,8.453512e-10,8.079321e-10,7.120175e-10,1.215010,1.225470,1.267641,0.868515
3,117,b,4.199994e-09,3.110587e-09,5.351551e-09,4.073291e-09,4.199994e-09,3.110587e-09,5.351551e-09,4.073291e-09,1.598675,1.471255,1.965310,1.764413
4,018,b,2.075330e-09,8.368422e-10,2.237206e-09,1.679443e-09,1.942377e-09,2.043292e-09,2.364269e-09,1.413356e-09,1.170906,1.140167,1.280196,1.210132
5,004,b,6.976483e-10,6.404484e-10,6.155652e-10,8.540354e-10,4.032217e-10,3.297790e-10,3.732939e-10,3.322946e-10,1.300475,1.042961,1.012387,1.332535
6,002,c,2.639706e-09,2.320391e-09,2.641758e-09,2.343579e-09,5.713987e-11,5.201326e-11,1.080859e-10,8.001559e-11,1.420628,1.800495,1.685467,1.460328
7,111,a,3.643916e-10,3.518915e-10,4.620400e-10,4.455454e-10,1.605201e-09,1.752863e-09,1.736696e-09,3.073524e-09,1.043904,1.012891,1.168637,1.179908
8,019,b,2.765936e-09,3.190180e-09,1.993672e-09,2.067899e-09,2.765936e-09,3.190180e-09,1.993672e-09,2.067899e-09,1.370249,1.126160,0.987119,0.780182
9,104,a,8.312188e-10,7.409420e-10,8.172004e-10,1.371840e-09,5.383580e-10,6.291586e-10,6.934458e-10,9.620184e-10,1.174490,0.894938,1.152100,1.551080
