In [3]:
import numpy as np
import scipy as sp
import mne
from functools import reduce
from time import time
from models.preprocessing import utils
from joblib import Parallel, delayed
from pickle import dump

In [4]:
# Directory holding the epoched data, and the file suffix of epochs files in it.
d = "../data/reinhartlab/multimodal/cg/Epochs/"
ext = "-epo.fif.gz"
# Checkbox widget (see rendered output) built by the project helper; presumably one box
# per dataset found under `d` -- the selection feeds `utils.get_selection(box)` below.
box = utils.select_dataset(d, ext)
box

HBox(children=(Checkbox(value=True, description='multimodal_eeg_062CG.stimulus', indent=False), Checkbox(value…

In [17]:
def get_mask(raw, evts, times, by_trial=False):
    """Sample indices of time windows placed around each event.

    Parameters
    ----------
    raw : object exposing ``raw.times`` (only its length is used) and ``raw.info['sfreq']``.
    evts : iterable of event positions in samples (floats are truncated with ``int``).
    times : iterable of ``(tmin, tmax)`` windows in seconds relative to each event.
    by_trial : if False, the result is flattened to 1-D.

    Returns
    -------
    numpy.ndarray of integer indices with shape ``(n_events, n_windows, n_samples)`` when
    ``by_trial`` is True, otherwise the same values flattened.  Both window bounds are
    inclusive (adjacent windows such as (0, .5) and (.5, 1) share their boundary sample)
    and are clipped to ``[0, len(raw.times) - 1]``.  A window clipped at a recording edge
    is shorter than the others, so the windows no longer stack into a rectangular array.
    """
    n_samples = len(raw.times)
    sfreq = raw.info['sfreq']
    # Build each window's indices directly instead of scanning np.arange(n_samples) per
    # event: same indices, but O(window length) instead of O(recording length) per event.
    mask = np.array([[np.arange(max(int(e + tmin * sfreq), 0),
                                min(int(e + tmax * sfreq), n_samples - 1) + 1,
                                dtype=np.intp)
                      for tmin, tmax in times]
                     for e in evts])
    if not by_trial:
        mask = mask.reshape(-1)
    return mask


def phase_fn(x):
    """Instantaneous phase of ``x`` shifted into [0, 2*pi].

    The phase is the angle of the analytic signal returned by ``scipy.signal.hilbert``
    (computed along the last axis).  ``np.angle`` yields values in [-pi, pi]; adding pi
    moves them onto the [0, 2*pi] range spanned by the phase bins of the PAC estimators.
    """
    return np.pi + np.angle(sp.signal.hilbert(x))


def amplitude_fn(x):
    """Instantaneous amplitude (envelope) of ``x``.

    Returns the magnitude of the analytic signal from ``scipy.signal.hilbert``,
    computed along the last axis; the output has the same shape as ``x``.
    """
    return np.abs(sp.signal.hilbert(x))


def pac_tort_epsilon(lo, hi, nbin=20):
    """Tort modulation index with empty phase bins floored at machine epsilon.

    Parameters
    ----------
    lo : array (trials x intervals x samples) of phases in radians, binned over [0, 2*pi].
    hi : array of the same shape holding amplitudes.
    nbin : number of equal-width phase bins.

    Returns
    -------
    (mi, amps) : ``mi`` is (trials x intervals); ``amps`` is the per-bin mean amplitude
    normalised to sum to 1 over bins, shape (trials x bins x intervals).  A bin with no
    samples for a given trial/interval gets a NaN mean, which ``np.fmax`` turns into eps.
    """
    edges = np.linspace(0, 2 * np.pi, nbin + 1)
    # Both edges inclusive; bins hold different sample counts, so the masks are kept separate.
    in_bin = [(lo >= left) & (lo <= right) for left, right in zip(edges[:-1], edges[1:])]
    n_trials, n_intervals = lo.shape[0], lo.shape[1]

    per_interval = []
    for i in range(n_intervals):
        per_bin = []
        for sel in in_bin:
            per_bin.append([np.mean(hi[k, i, :][sel[k, i, :]]) for k in range(n_trials)])
        per_interval.append(per_bin)

    # (intervals, bins, trials) -> (trials, bins, intervals)
    amps = np.fmax(np.array(per_interval).T, np.finfo(float).eps)
    amps = amps / amps.sum(axis=1, keepdims=True)
    neg_entropy = np.sum(amps * np.log(amps), axis=1)
    return (np.log(nbin) + neg_entropy) / np.log(nbin), amps


def pac_tort_skip(lo, hi, nbin=20):
    """Tort modulation index computed only over phase bins that are occupied at all.

    A phase bin is dropped when no sample of ``lo`` (over every trial and interval) falls
    into it, and the index is normalised by the number of remaining bins rather than
    ``nbin``.  A kept bin that is empty for one particular trial/interval gets a NaN mean,
    which is floored at machine epsilon.

    Parameters
    ----------
    lo : array (trials x intervals x samples) of phases in radians, binned over [0, 2*pi].
    hi : array of the same shape holding amplitudes.
    nbin : number of equal-width candidate phase bins.

    Returns
    -------
    (mi, amps) : ``mi`` is (trials x intervals); ``amps`` is (trials x kept_bins x intervals),
    normalised to sum to 1 over bins.
    """
    edges = np.linspace(0, 2 * np.pi, nbin + 1)
    occupied = []
    for left, right in zip(edges[:-1], edges[1:]):
        sel = (lo >= left) & (lo <= right)  # both edges inclusive
        if sel.any():
            occupied.append(sel)
    n_kept = len(occupied)
    n_trials, n_intervals = lo.shape[0], lo.shape[1]

    bin_means = np.array([[[np.mean(hi[k, i, :][sel[k, i, :]]) for k in range(n_trials)]
                           for sel in occupied]
                          for i in range(n_intervals)])
    # (intervals, bins, trials) -> (trials, bins, intervals)
    amps = np.fmax(bin_means.T, np.finfo(float).eps)
    amps = amps / amps.sum(axis=1, keepdims=True)
    neg_entropy = np.sum(amps * np.log(amps), axis=1)
    return (np.log(n_kept) + neg_entropy) / np.log(n_kept), amps

def pac_tort_avg(lo, hi, nbin=20):
    """Tort modulation index with empty phase bins imputed by the mean of occupied bins.

    Parameters
    ----------
    lo : array (trials x intervals x samples) of phases in radians, binned over [0, 2*pi].
    hi : array of the same shape holding amplitudes.
    nbin : number of equal-width phase bins.

    Returns
    -------
    (mi, amps) : ``mi`` is (trials x intervals); ``amps`` is the per-bin mean amplitude
    normalised to sum to 1 over bins, shape (trials x bins x intervals).  A bin with no
    samples for a given trial/interval is filled with the mean over that same
    trial/interval's occupied bins before normalisation.
    """
    n_trials, n_intervals = lo.shape[0], lo.shape[1]
    edges = np.linspace(0, 2 * np.pi, nbin + 1)
    # Both edges inclusive; bins hold different sample counts, so the masks are kept separate.
    in_bin = [np.logical_and(lo >= a, lo <= b) for a, b in zip(edges[:-1], edges[1:])]

    def _bin_mean(k, i, sel):
        # Mean amplitude of trial k / interval i inside one bin; NaN marks an empty bin
        # (checked explicitly to avoid numpy's "mean of empty slice" warning).
        vals = hi[k, i, :][sel[k, i, :]]
        return vals.mean() if vals.size else np.nan

    # amps => trials x bins x intervals
    amps = np.array([[[_bin_mean(k, i, sel) for i in range(n_intervals)]
                      for sel in in_bin]
                     for k in range(n_trials)], dtype=float)

    # Impute each empty bin with the mean over the occupied bins of the SAME trial and
    # interval.  The mean array is (trials x intervals), so it must be indexed by both;
    # a flat np.take with the interval index alone would always read trial 0.
    t_idx, b_idx, i_idx = np.nonzero(np.isnan(amps))
    if t_idx.size:
        amps[t_idx, b_idx, i_idx] = np.nanmean(amps, axis=1)[t_idx, i_idx]

    amps = amps / amps.sum(axis=1, keepdims=True)
    hs = np.sum(amps * np.log(amps), axis=1)  # hs => trials x intervals
    return (np.log(nbin) + hs) / np.log(nbin), amps


In [14]:
def get_filtered_channel(raw,lofrq,hifrq,ch,lo_func,hi_func,n_jobs=1):
    """Band-pass one channel into several bands and transform each filtered band.

    Parameters
    ----------
    raw : MNE Raw-like object; only ``raw.get_data(ch)`` and ``raw.info['sfreq']`` are used.
    lofrq, hifrq : sequences of ``(l_freq, h_freq)`` bands in Hz (phase / amplitude bands).
    ch : channel selector passed to ``raw.get_data``.
    lo_func, hi_func : transforms applied to each filtered phase / amplitude band.
    n_jobs : number of joblib workers; one task per band.

    Returns
    -------
    (hi, lo) : lists of squeezed transformed signals, in ``hifrq`` / ``lofrq`` order.
    Note the return order is (hi, lo) while the arguments are (lofrq, hifrq).
    """
    data = raw.get_data(ch)
    # Read sfreq once so each worker task only has to carry the channel data and a float,
    # instead of a closure over the whole (preloaded) Raw object.
    sfreq = raw.info['sfreq']

    def _filter_band(band, transform):
        # FIR band-pass with 1 Hz transition bands on both sides, then the transform.
        filtered = mne.filter.filter_data(data, sfreq, band[0], band[1],
                                          l_trans_bandwidth=1, h_trans_bandwidth=1)
        return np.squeeze(transform(filtered))

    hi = Parallel(n_jobs=n_jobs)(delayed(_filter_band)(band, hi_func) for band in hifrq)
    lo = Parallel(n_jobs=n_jobs)(delayed(_filter_band)(band, lo_func) for band in lofrq)
    return hi, lo

def get_masked_metric(hi,lo,hifrq,lofrq,mask,pac_func):
    """Apply ``pac_func`` to every (phase band, amplitude band) pair after masking.

    Parameters
    ----------
    hi, lo : lists of 1-D transformed signals, aligned with ``hifrq`` / ``lofrq``.
    hifrq, lofrq : the band tuples used as dictionary keys.
    mask : integer index array applied to every signal (e.g. trials x intervals x samples).
    pac_func : callable ``pac_func(lo_masked, hi_masked)``.

    Returns
    -------
    dict mapping ``(lofrq_band, hifrq_band)`` to ``pac_func``'s result, inserted with the
    amplitude band as the outer loop and the phase band as the inner loop.
    """
    hi_masked = [h[mask] for h in hi]
    lo_masked = [l[mask] for l in lo]
    # Plain sequential evaluation; the former Parallel(n_jobs=1) wrapper was a no-op.
    return {(lband, hband): pac_func(l, h)
            for hband, h in zip(hifrq, hi_masked)
            for lband, l in zip(lofrq, lo_masked)}

def get_pac_per_channel(raw, masks, lofrq, hifrq, ch, lo_func=phase_fn, hi_func=amplitude_fn, pac_func=lambda l, h: (l, h), n_jobs=1):
    """PAC for one channel, once per mask.

    The channel is filtered/transformed once (``get_filtered_channel``), then
    ``get_masked_metric`` is evaluated for every mask in ``masks``.

    ``pac_func`` is called as ``pac_func(lo_masked, hi_masked)``; the default returns the
    masked pair unchanged.  (The previous one-argument default ``lambda x: x`` always
    raised TypeError because it is called with two arguments.)

    Returns
    -------
    list, one dict per mask, mapping ``(phase_band, amplitude_band)`` to ``pac_func``'s result.
    """
    hi, lo = get_filtered_channel(raw, lofrq, hifrq, ch, lo_func, hi_func, n_jobs)
    return [get_masked_metric(hi, lo, hifrq, lofrq, mask, pac_func) for mask in masks]

def get_pac(raw, mask, lofrq, hifrq, chs, lo_func=phase_fn, hi_func=amplitude_fn, pac_func=lambda l, h: (l, h), n_jobs=1):
    """PAC for every channel in ``chs``.

    ``mask`` is a *list* of index masks (one per epochs file), forwarded unchanged to
    ``get_pac_per_channel``.  ``pac_func`` is called with two arguments
    ``(lo_masked, hi_masked)``; the default returns that pair unchanged.

    Returns
    -------
    ``{channel: [ {(phase_band, amplitude_band): result} for each mask ]}``.
    """
    return {ch: get_pac_per_channel(raw, mask, lofrq, hifrq, ch, lo_func, hi_func, pac_func, n_jobs)
            for ch in chs}

def get_pac_from_file(files, lofrq, hifrq, chs=None, condition=None, times=((0., .5), (.5, 1.)), lo_func=phase_fn, hi_func=amplitude_fn, pac_func=pac_tort_avg, n_jobs=1, decim=1):
    """Trial-wise PAC for one subject: a continuous recording plus its epochs files.

    Parameters
    ----------
    files : ``[raw_fif, epochs_fif_1, epochs_fif_2, ...]``; the raw file is filtered, the
        epochs files only supply event positions (and channel names when ``chs`` is None).
    lofrq, hifrq : phase / amplitude bands ``(l_freq, h_freq)`` in Hz.
    chs : channels to analyse; None means every channel of the last epochs file whose
        name does not contain 'EOG'.
    condition : optional epochs selector (``epochs[condition]``); None uses all events.
    times : ``(tmin, tmax)`` windows in seconds after each event (see ``get_mask``).
    lo_func, hi_func, pac_func : forwarded to ``get_pac``.  These used to be ignored
        (``pac_tort_avg`` was hard-coded); the default keeps that behaviour.
    n_jobs : joblib workers for the band filtering.
    decim : the raw data is resampled to ``sfreq / decim`` and events are scaled to match.

    Returns
    -------
    ``{epochs_file: {channel: {(phase_band, amplitude_band): pac_func result}}}``.

    Raises
    ------
    ValueError : if ``files`` lists no epochs file after the raw file.
    """
    epoch_files = files[1:]
    if not epoch_files:
        raise ValueError("files must list the raw file followed by at least one epochs file")

    raw = mne.io.read_raw_fif(files[0], preload=True)
    raw.resample(raw.info['sfreq'] / decim)

    masks = []
    for f in epoch_files:
        epochs = mne.read_epochs(f, preload=False)
        selected = epochs if condition is None else epochs[condition]
        evts = selected.events[:, 0]
        # NOTE(review): assumes event samples index raw from 0 (raw.first_samp == 0, as in
        # the logged "Range : 0 ..."); otherwise subtract raw.first_samp first -- confirm.
        masks.append(get_mask(raw, np.round(evts / decim), times, by_trial=True))

    if chs is None:
        chs = [c for c in epochs.ch_names if 'EOG' not in c]
    pacs = get_pac(raw, masks, lofrq, hifrq, chs, lo_func=lo_func, hi_func=hi_func,
                   pac_func=pac_func, n_jobs=n_jobs)
    # Re-key from {channel: [per-file results]} to {file: {channel: result}}.
    return {f: {ch: per_file[i] for ch, per_file in pacs.items()} for i, f in enumerate(epoch_files)}

In [15]:
 
# Phase (low) and amplitude (high) frequency bands in Hz, as (l_freq, h_freq).
# Full analysis grid: 2 Hz phase bands over 1-13 Hz, 4 Hz amplitude bands over 14-90 Hz.
f_phase_full = [(f, f + 2) for f in range(1, 13, 2)]
f_amp_full = [(f, f + 4) for f in range(14, 90, 4)]

# Reduced grid for quick runs (previously the full grid was silently overwritten);
# set QUICK_RUN = False to analyse the full grid above.
QUICK_RUN = True
f_phase = [(1, 3), (3, 5), (5, 7)] if QUICK_RUN else f_phase_full
f_amp = [(22, 26), (26, 30)] if QUICK_RUN else f_amp_full

chs = ['FCz']   # channels passed to get_pac_from_file
cond = None     # epochs condition selector; None = use all events

def compute_and_save_pac(files, n_jobs=20, decim=4):
    """Compute trial-wise PAC for one subject and save one file per epochs file.

    Uses the module-level band/channel configuration (``f_phase``, ``f_amp``, ``chs``,
    ``cond``) and ``pac_tort_avg``.  Each result goes next to its epochs file, with the
    ``Epochs`` path component replaced by ``PAC``; missing output directories are created.

    Parameters
    ----------
    files : ``[raw_fif, epochs_fif, ...]`` as accepted by ``get_pac_from_file``.
    n_jobs, decim : forwarded to ``get_pac_from_file`` (defaults match the previous
        hard-coded values).
    """
    import os

    res = get_pac_from_file(files, f_phase, f_amp, chs, cond, pac_func=pac_tort_avg,
                            n_jobs=n_jobs, decim=decim)
    for epoch_file, data in res.items():
        # NOTE(review): the payload is a pickle (protocol 4), despite the ".json" suffix;
        # kept for compatibility with existing readers -- load it with pickle, not json.
        fname = epoch_file.replace("-epo.fif.gz", "-bytrial.pac.json").replace("Epochs", "PAC")
        os.makedirs(os.path.dirname(fname) or ".", exist_ok=True)
        # Context manager so the handle is flushed and closed even if dump() fails.
        with open(fname, "wb") as fh:
            dump(data, fh, 4)

# NOTE(review): the selection comes from the interactive checkbox widget above, so which
# files get processed is not reproducible from code alone.
datasets = list(utils.get_selection(box))
# Subject ID = dataset name up to the first '.', e.g. '<subject>.stimulus' -> '<subject>'.
subjects = np.unique([name.split(".")[0] for name in datasets])
# One file set per subject: the continuous raw file first, then that subject's epochs files.
# Match the exact subject ID; a substring test would also pick up longer IDs containing it.
filesets = [[d.replace('Epochs', 'Raw') + s + ".raw.fif.gz"]
            + [d + b + "-epo.fif.gz" for b in datasets if b.split(".")[0] == s]
            for s in subjects]

In [20]:
# Process one subject per iteration.  The wall-clock time in seconds is printed after
# each subject, also when compute_and_save_pac raises (the exception still propagates
# and stops the loop).
for files in filesets:
    s = time()
    try:
        compute_and_save_pac(files)
    finally:
        print(time() - s)

Opening raw data file ../data/reinhartlab/multimodal/cg/Raw/MulitModal_EEG_078CG.raw.fif.gz...
    Range : 0 ... 5749559 =      0.000 ...  5749.559 secs
Ready.
Reading 0 ... 5749559  =      0.000 ...  5749.559 secs...
Reading ../data/reinhartlab/multimodal/cg/Epochs/MulitModal_EEG_078CG.stimulus-epo.fif.gz ...
    Found the data of interest:
        t =   -2500.00 ...    3500.00 ms
        0 CTF compensation matrices available
1216 matching events found
No baseline correction applied
Adding metadata with 9 columns
0 projection items activated
Reading ../data/reinhartlab/multimodal/cg/Epochs/MulitModal_EEG_078CG.feedback-epo.fif.gz ...
    Found the data of interest:
        t =   -2500.00 ...    3500.00 ms
        0 CTF compensation matrices available
1216 matching events found
No baseline correction applied
Adding metadata with 9 columns
0 projection items activated
Reading ../data/reinhartlab/multimodal/cg/Epochs/MulitModal_EEG_078CG.response-epo.fif.gz ...
    Found the data of inte

  epochs = mne.read_epochs(f,preload=False)
  epochs = mne.read_epochs(f,preload=False)
  epochs = mne.read_epochs(f,preload=False)
  epochs = mne.read_epochs(f,preload=False)
  epochs = mne.read_epochs(f,preload=False)
  epochs = mne.read_epochs(f,preload=False)
  epochs = mne.read_epochs(f,preload=False)
  epochs = mne.read_epochs(f,preload=False)
  epochs = mne.read_epochs(f,preload=False)
  epochs = mne.read_epochs(f,preload=False)
  epochs = mne.read_epochs(f,preload=False)
  epochs = mne.read_epochs(f,preload=False)
  epochs = mne.read_epochs(f,preload=False)
  epochs = mne.read_epochs(f,preload=False)
  epochs = mne.read_epochs(f,preload=False)
  epochs = mne.read_epochs(f,preload=False)
  epochs = mne.read_epochs(f,preload=False)
  epochs = mne.read_epochs(f,preload=False)
