In [None]:
import mne

import ssvepy

from autoreject import Ransac

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import os
import glob
import re
import collections
from datetime import date

from tqdm import tqdm_notebook as tqdm
from ipywidgets import interact

# Set MNE's logging threshold to 'WARNING' for the rest of the notebook.
mne.utils.set_log_level('WARNING')

In [None]:
# Candidate data locations, checked in order; the first one that exists wins.
candidate_folders = [
    '/Users/jan/Documents/eeg-data/cancan-saturation/',  # on laptop
    '/data/group/FANS/cancan/eeg/',  # On NaN
]
datafolder = next(
    (folder for folder in candidate_folders if os.path.isdir(folder)), None
)
if datafolder is None:
    # Fail here with a clear message instead of a NameError further down.
    raise FileNotFoundError(
        'None of the known data folders exist: ' + ', '.join(candidate_folders)
    )

# glob returns files in arbitrary (filesystem) order; sorting keeps the
# subject order, and therefore the row order of the final table, stable.
files = sorted(glob.glob(datafolder + '*saturation*.vhdr'))

In [None]:
# Parse the participant id (a run of digits) and the visit letter (A/B/C, any
# case) that directly follows it from each file name.
ids = []
visits = []

# Raw string: '\d' inside a normal string literal is an invalid escape sequence.
id_visit_pattern = re.compile(r'(\d+)([ABCabc])')

for file in files:
    # Match against the file name only, so digits in a parent directory name
    # can never be mistaken for a participant id.
    m = id_visit_pattern.search(os.path.basename(file))
    if m is None:
        raise ValueError(
            'Could not find "<digits><A|B|C>" in file name: ' + file
        )
    ids.append(m.group(1))
    visits.append(m.group(2))


In [None]:
# Read every BrainVision recording found above, attaching the standard 10-20
# montage and mapping the two named markers to integer event codes.
# The montage is identical for all files, so it is read once rather than once
# per recording.
# NOTE(review): `read_montage` and the `montage=` / `event_id=` reader arguments
# appear to belong to an older MNE API — confirm against the installed version.
montage = mne.channels.read_montage('standard_1020')
raws = [mne.io.read_raw_brainvision(file,
                                    event_id={'DCC': 199, 'actiCAP Data On': 200},
                                    montage=montage)
        for file in files]

In [None]:
# Tag each recording with "<id><visit>" so it can be traced back to its file.
# NOTE(review): a plain string is stored under 'subject_info' — confirm the
# installed MNE version accepts that (it may expect a dict).
for raw, subject_id, visit in zip(raws, ids, visits):
    raw.info['subject_info'] = subject_id + visit

In [None]:
# Build one Epochs object per (recording, condition): epochs[subject][condition].
# Each epoch spans tmin=0 to tmax=10 relative to the event and keeps the
# channels selected by pick_types(eeg=True).
event_codes = [16, 32, 64, 100]

epochs = []
for raw in tqdm(raws, desc='Subject'):
    # The events and channel picks depend only on the recording, so compute
    # them once instead of once per condition.
    events = mne.find_events(raw)
    eeg_picks = mne.pick_types(raw.info, eeg=True)
    epochs.append([mne.Epochs(raw,
                              events,
                              event_id=event,
                              tmin=0, tmax=10,
                              picks=eeg_picks)
                   for event in event_codes])

In [None]:
# Load each condition's epochs and resample them to 256.
# The return values are discarded, so both calls are relied on to modify the
# Epochs objects held in `epochs` in place.
for subject_epochs in tqdm(epochs):
    for condition_epochs in subject_epochs:
        condition_epochs.load_data()
        condition_epochs.resample(256)


In [None]:
%%capture
# Clean the data with autoreject's Ransac.
# One Ransac instance is created per recording; it is re-fitted (fit_transform)
# on each of that recording's condition-wise epochs in turn.
cleaners = [Ransac(verbose='tqdm_notebook') for _ in raws]

cleanepochs = []
for epochlist, cleaner in tqdm(zip(epochs, cleaners), desc='Subjects'):
    cleanepochs.append([cleaner.fit_transform(epoch) for epoch in epochlist])

In [None]:
%%capture
# Build one Ssvep object per (subject, condition) from the cleaned epochs:
# ssveps[subject][condition].
# NOTE(review): 5.0 is presumably the stimulation frequency and fmin/fmax the
# analysed frequency range — confirm against the ssvepy API.
ssveps = []
for epochlist in tqdm(cleanepochs):
    ssveps.append([ssvepy.Ssvep(epoch, 5.0, fmin=2, fmax=30)
                   for epoch in epochlist])

In [None]:
# Positions of the occipital channels within the first recording's channel list.
# NOTE(review): these positions are used below to index the ssvep arrays, which
# assumes every recording (and the epoched data) shares this channel order —
# confirm.
occipital_channels = ['Oz', 'O1', 'O2', 'POz']
occipital_indices = list(map(raws[0].ch_names.index, occipital_channels))

# Measure name -> (subjects x conditions) array, kept in insertion order.
dataarrays = collections.OrderedDict()

# Per subject: find the occipital electrode with the highest SNR (averaged over
# axis 0 of `snr`, then over conditions) and store the mean stimulation power
# at that electrode for every condition.
dataarrays['maxamp_occipital'] = np.zeros((len(ssveps), len(ssveps[0])))
for subject, ssveplist in enumerate(ssveps):
    # One SNR profile over the occipital electrodes per condition.
    occipital_snr = [ssvep.stimulation.snr[:, occipital_indices].mean(axis=0)
                     for ssvep in ssveplist]
    # Conditions are stacked on a new last axis and averaged away again.
    snr_across_conditions = np.stack(occipital_snr, axis=-1).mean(axis=-1)
    # nanargmax indexes into the occipital subset; map it back to a channel index.
    maxelec = occipital_indices[np.nanargmax(snr_across_conditions)]
    for condition, ssvep in enumerate(ssveplist):
        dataarrays['maxamp_occipital'][subject, condition] = (
            ssvep.stimulation.power[:, maxelec].mean()
        )

# Per subject: find the electrode (out of all of them) with the highest SNR,
# averaged over axis 0 of `snr` and then over conditions, and store the mean
# stimulation power at that electrode for every condition.
dataarrays['maxamp_all'] = np.zeros((len(ssveps), len(ssveps[0])))
for subject, ssveplist in enumerate(ssveps):
    snr_per_condition = [ssvep.stimulation.snr.mean(axis=0)
                         for ssvep in ssveplist]
    snr_across_conditions = np.stack(snr_per_condition, axis=-1).mean(axis=-1)
    maxelec = np.nanargmax(snr_across_conditions)
    for condition, ssvep in enumerate(ssveplist):
        dataarrays['maxamp_all'][subject, condition] = (
            ssvep.stimulation.power[:, maxelec].mean()
        )

# NaN-ignoring mean of the stimulation SNR over the occipital electrodes,
# one value per (subject, condition).
avsnr_occipital = np.zeros((len(ssveps), len(ssveps[0])))
dataarrays['avsnr_occipital'] = avsnr_occipital
for subject, ssveplist in enumerate(ssveps):
    for condition, ssvep in enumerate(ssveplist):
        occipital_snr = ssvep.stimulation.snr[:, occipital_indices]
        avsnr_occipital[subject, condition] = np.nanmean(occipital_snr)

# Stimulation power averaged over all electrodes, weighted by SNR; one value
# per (subject, condition).
datatype = 'weightedamp_all'
dataarrays[datatype] = np.zeros((len(ssveps), len(ssveps[0])))
for subject, ssveplist in enumerate(ssveps):
    for condition, ssvep in enumerate(ssveplist):
        power = ssvep.stimulation.power
        # Masked array with NaN power values excluded, singleton axes removed.
        tmpdata = np.squeeze(np.ma.array(power, mask=np.isnan(power)))
        # Remove outliers: mask values more than 4 standard deviations from
        # the mean.
        tmpdata[np.abs(tmpdata - tmpdata.mean()) > 4*np.std(tmpdata)] = np.ma.masked
        # np.fmax clips negative SNRs to a weight of 0 and, because it ignores
        # NaN operands, also gives NaN SNRs a weight of 0.
        # The weights are squeezed exactly like the data: np.ma.average
        # refuses data and weights of different shapes when no axis is given.
        weights = np.squeeze(np.fmax(ssvep.stimulation.snr, 0))
        dataarrays[datatype][subject, condition] = np.nanmean(
            np.ma.average(tmpdata, weights=weights)
        )


In [None]:
# Display the SNR-weighted amplitudes (rows: subjects, columns: conditions).
dataarrays['weightedamp_all']

## Save data to CSV file

In [None]:
# Assemble the table: one row per recording, with the id and visit followed by
# one column per (measure, condition) named "<measure>_<condition>".
condition_codes = [16, 32, 64, 100]

datadict = collections.OrderedDict()  # keeps the column order
datadict['id'] = ids
datadict['visit'] = visits
for label, data in dataarrays.items():
    datadict.update(
        (label + '_' + str(condition), data[:, column])
        for column, condition in enumerate(condition_codes)
    )

df = pd.DataFrame(datadict)

# Written to the current working directory, prefixed with today's date.
csv_name = date.today().strftime('%Y-%m-%d') + '_alldata.csv'
df.to_csv(csv_name)

In [None]:
# Display the assembled table that was written to the CSV file.
df