In [None]:
# Standard library
import os
import os.path as op
import pickle
from sys import argv
from warnings import filterwarnings

# Third-party
import matplotlib.pyplot as plt
import mne
import mne_connectivity
import numpy as np
import pandas as pd
import scipy
from stormdb.access import Query

# Local project modules
from src.decoding_functions import smooth_data
import src.preprocessing as pfun

# Hide DeprecationWarnings for the rest of the session
filterwarnings("ignore", category=DeprecationWarning)

In [None]:
##################### Define relevant variables ################################
# Project info
project = 'MINDLAB2020_MEG-AuditoryPatternRecognition'
project_dir = '/projects/' + project
os.environ['MINDLABPROJ'] = project
# Environment variables are not shell-expanded, so expand '~' here instead of
# handing a literal '~/...' path to whatever reads MNE_ROOT.
os.environ['MNE_ROOT'] = op.expanduser('~/miniconda3/envs/mne')
os.environ['MESA_GL_VERSION_OVERRIDE'] = '3.2'

# Paths
suffix = ''  # appended to the output filenames below
raw_path = project_dir + '/scratch/maxfiltered_data/tsss_st16_corr96'
ica_path = project_dir + '/scratch/working_memory/ICA'
avg_path = project_dir + '/scratch/working_memory/averages'
log_path = project_dir + '/misc/working_memory_logs'

subjects_dir = project_dir + '/scratch/fs_subjects_dir'  # FreeSurfer subjects dir for parcellation and source localization
fwd_path = project_dir + '/scratch/forward_models'

# Subjects info:
qy = Query(project)
subs = qy.get_subjects()

# Subject: 1-based code, used as an index into `subs`
scode = 11
# Command-line override (disabled in the notebook):
# if len(argv) > 1:
#     scode = int(argv[1])
sub = subs[scode - 1]

print('output will be saved to the following filename:\n\n{}{}'.format(sub, suffix))

# Create subject-specific output directories. makedirs also creates missing
# parent directories and is a no-op when the directory already exists.
for out_dir in (op.join(avg_path, 'data', sub), op.join(avg_path, 'figures', sub)):
    os.makedirs(out_dir, exist_ok=True)

# Define output paths
conn_path = avg_path + '/data/{}/{}_conn{}.p'.format(sub, sub, suffix)
fig_path = avg_path + '/figures/{}/{}_conn{}.pdf'.format(sub, sub, suffix)

# Define block names (original MEG names, new condition names and logfile names)
conds_orig = ['main', 'inv']  # MEG block code
conds = ['maintenance', 'manipulation']  # New block code
lnames = ['recognize', 'invert']  # logfile name part for each block

In [None]:
################################ Epoch data #########################################

# Epoching parameters (all of these are actually passed to WM_epoching below)
reject = dict(mag=4e-12, grad=4000e-13)  # rejection thresholds
events_fun = pfun.main_task_events_fun  # Event function (see src/preprocessing.py)
tmin = -2  # epoch start (s)
tmax = 8  # epoch end (s)
l_freq = .05  # HP filter
h_freq = None  # LP filter (None: no low-pass)
# Baseline passed to WM_epoching. None is what this analysis has been using;
# TODO confirm in src/preprocessing.py how WM_epoching treats None (the call
# site was previously annotated "Demean baseline").
baseline = None
resample_sfreq = 100  # resampling rate passed to WM_epoching
notch_freq = 50  # notch filter frequency passed to WM_epoching

# Initialize: one Epochs object per condition, keyed by the new condition name
epochs = {}
print('\n############### EPOCHING #################\n')
for c, nc, lname in zip(conds_orig, conds, lnames):
    # c: original MEG block code, nc: new condition name, lname: logfile name

    # Files to retrieve
    fname = os.path.join(raw_path, sub, c + '_raw_tsss.fif')
    icaname = os.path.join(ica_path, sub, c + '_raw_tsss-ica.fif')
    lfname = op.join(log_path, sub[0:4] + '_' + lname + '_MEG.csv')
    # input to the events function (new condition name and logfile)
    events_fun_kwargs = {'cond': nc, 'lfname': lfname}

    # Epoching proper:
    epochs[nc] = pfun.WM_epoching(data_path=fname,  # raw data path
                                  ica_path=icaname,  # ICA components path
                                  tmin=tmin, tmax=tmax,  # Epoch times
                                  l_freq=l_freq, h_freq=h_freq,  # Filtering options
                                  resample=resample_sfreq, bads=[],  # Resample and bad channels to reject
                                  baseline=baseline, notch_filter=notch_freq,
                                  events_fun=events_fun,  # Event function to use for epoching
                                  events_fun_kwargs=events_fun_kwargs,  # Arguments for event function
                                  reject=reject)  # thresholds to reject artifacts

In [None]:
# Combine the per-condition epochs into a single Epochs object.
# Guarded so re-running this cell is a no-op: after the first run `epochs`
# is already the concatenated Epochs object, not the per-condition dict.
if isinstance(epochs, dict):
    epochs = mne.concatenate_epochs(list(epochs.values()))

In [None]:
### Source localization
# Sensor data covariance for the beamformer: magnetometers only, estimated
# over 0-6.25 s with the rank taken from the measurement info.
# load_data() loads `epochs` into memory; the channel pick runs on a copy,
# so `epochs` itself keeps every channel type.
data_cov = mne.compute_covariance(
    epochs.load_data().copy().pick_types(meg='mag'),
    tmin=0,
    tmax=6.25,
    rank='info',
)

In [None]:
# smooth_tstep = 0.025
# smooth_twin = 0.08
# if smooth_tstep:
#     new_data, new_times = smooth_data(epochs.get_data(), tstart=epochs.times[0],
#                                       tstep=smooth_tstep, twin=smooth_twin,
#                                       Fs=epochs.info['sfreq'], taxis=2)

# new_info = epochs.info.copy()
# new_info['sfreq'] = 1/smooth_tstep
# epochs = mne.EpochsArray(new_data, info = new_info, events = epochs.events,
#                          event_id = epochs.event_id,tmin = epochs.tmin)

In [None]:
####### Compute sources proper
print('\n computing sources \n')
fwd_fn = op.join(fwd_path, sub + '_vol-fwd.fif')
fwd = mne.read_forward_solution(fwd_fn)

## LCMV beamformer
# Every subset of the concatenated epochs shares the same channel info, so take
# it from `epochs` directly. Selecting epochs['manip'] added nothing except a
# KeyError when no event name carries a 'manip' tag (the condition is named
# 'manipulation').
# A noise covariance (e.g. from the -1..0 s pre-stimulus window) could be
# supplied via noise_cov=; it is currently not used.
inv = mne.beamformer.make_lcmv(epochs.info, fwd, data_cov, reg=0.05,
                               pick_ori='max-power',
                               weight_norm='nai', rank='info')

In [None]:
### Apply inverse solution
# Apply the LCMV spatial filters to each epoch; the result is indexed per
# epoch below (src_epochs[i]).
src_epochs = mne.beamformer.apply_lcmv_epochs(epochs, filters=inv)

In [None]:
### Load parcellation for the specific subject
# Subject's FreeSurfer 'aparc.a2009s+aseg' volume atlas and its label names
label_file = op.join(subjects_dir, sub, 'mri', 'aparc.a2009s+aseg.mgz')
labels = mne.get_volume_labels_from_aseg(label_file)

In [None]:
### Read subject-specific volume source space
src_fname = op.join(subjects_dir, sub, 'bem', sub + '_vol-src.fif')
src = mne.read_source_spaces(src_fname)

In [None]:
# Extract label time courses (see atlas for label names)
clabels = ['ctx_rh_G_temp_sup-G_T_transv',
           'Right-Thalamus-Proper',
           'ctx_rh_G_and_S_cingul-Mid-Post',
           'Right-Hippocampus',
           'ctx_rh_G_precuneus',
           'ctx_lh_G_temp_sup-G_T_transv',
           'Left-Thalamus-Proper',
           'ctx_lh_G_and_S_cingul-Mid-Post',
           'Left-Hippocampus',
           'ctx_lh_G_precuneus'
           ]

# Extract all epochs in one call. mne.extract_label_time_course accepts a
# list of source estimates and returns a list with one array per epoch, so the
# atlas labels are resolved once rather than once per epoch.
print('extracting sources for {} epochs'.format(len(src_epochs)))
stc_labels = mne.extract_label_time_course(src_epochs,
                                           labels=[label_file, clabels],
                                           src=src, mode='auto')

In [None]:
## Convert list of per-epoch ROI time courses to one array
roi_data = np.array(stc_labels)
# Expected layout: (n_epochs, n_rois, n_times), one ROI row per entry of clabels
assert roi_data.ndim == 3 and roi_data.shape[1] == len(clabels), roi_data.shape
## We may want to get rid of 0-lag correlations (use with care, it may get rid of signal):
roi_data = mne_connectivity.symmetric_orth(roi_data)

In [None]:
## Cross-correlation connectivity for different time periods
# `scipy.signal` is a module, not a function: the call below must be
# signal.correlate (the bare module call raised TypeError).
from scipy import signal

periods = {'baseline': [-2, 0], 'listen': [0, 2], 'transition': [1, 3], 'imagine': [2, 4]}
src_times = src_epochs[0].times
n_trials, n_rois = roi_data.shape[0], roi_data.shape[1]
Xcorr = {}
for p, (pstart, pend) in periods.items():
    print('Calculating Xcorr for period ', p, ' ', periods[p])
    # Samples inside the half-open window [pstart, pend)
    tix = (src_times >= pstart) & (src_times < pend)
    ctimes = src_times[tix]
    print(ctimes)
    # Output shape: nTrials x nROIs x nROIs x nTimeLags ('full' mode gives 2n-1 lags).
    # Only the upper triangle (r2 > r1) is filled; everything else stays NaN.
    Xcorr[p] = np.full((n_trials, n_rois, n_rois, 2 * ctimes.shape[0] - 1), np.nan)
    # Loop over trials and unique pairs of regions
    for t in range(n_trials):
        print('Xcorr trial ', t + 1)
        for r1 in range(n_rois):
            a = roi_data[t, r1, tix]
            # Standardize and divide one side by its length so the zero-lag
            # value of the cross-correlation equals Pearson's r
            aa = (a - a.mean()) / (np.std(a) * len(a))
            for r2 in range(r1 + 1, n_rois):
                b = roi_data[t, r2, tix]
                bb = (b - b.mean()) / np.std(b)
                Xcorr[p][t, r1, r2, :] = signal.correlate(aa, bb, mode='full')

In [None]:
# Plot the trial-averaged cross-correlation function of each ROI pair,
# comparing the 'listen' and 'imagine' periods.
xc_sfreq = epochs.info['sfreq']
n_rois = roi_data.shape[1]
for r1 in range(n_rois):
    for r2 in range(r1 + 1, n_rois):  # only the upper triangle is computed
        fig, ax = plt.subplots()
        for p in ('listen', 'imagine'):
            mean_xc = Xcorr[p][:, r1, r2, :].mean(axis=0)
            # 'full' correlation of two length-m windows: lags -(m-1) .. (m-1) samples
            n_lags = mean_xc.shape[-1]
            lags = (np.arange(n_lags) - (n_lags - 1) // 2) / xc_sfreq
            ax.plot(lags, mean_xc, label=p)
        ax.set(title='{} vs {}'.format(clabels[r1], clabels[r2]),
               xlabel='lag (s)', ylabel='mean cross-correlation')
        ax.legend()

In [None]:
## Phase slope index (PSI) between ROIs, per frequency band
sfreq = epochs.info['sfreq']
# Morlet wavelet frequencies and number of cycles for each band
cwt_bands = {'delta': np.array([.5, .75, 1, 1.25, 1.5, 1.75, 2]),
             'theta': np.array([4, 5, 6, 7, 8])}
cwt_n_cycles = {'delta': np.array([1, 2, 2, 2, 2, 2, 2]),
                'theta': np.array([3, 3, 3, 3, 3])}
# Band edges derived from the wavelet frequencies so the two cannot drift apart
fmin = {b: freqs.min() for b, freqs in cwt_bands.items()}
fmax = {b: freqs.max() for b, freqs in cwt_bands.items()}
# Named psi_periods so the cross-correlation `periods` dict is not overwritten.
# NOTE(review): roi_data is a plain array, so the PSI time axis presumably
# starts at 0 rather than at the epoch tmin (-2 s); [0, 10] then covers the
# whole epoch. Confirm before adding sub-windows such as listen/imagine.
psi_periods = {'whole': [0, 10]}
conn = {}
for b in cwt_bands:
    conn[b] = {}
    for p, (ptmin, ptmax) in psi_periods.items():
        conn[b][p] = mne_connectivity.phase_slope_index(
            roi_data, names=clabels, mode='cwt_morlet',
            cwt_freqs=cwt_bands[b], cwt_n_cycles=cwt_n_cycles[b],
            sfreq=sfreq, fmin=fmin[b], fmax=fmax[b],
            n_jobs=1, tmin=ptmin, tmax=ptmax)


In [None]:
# Save the PSI connectivity results. The `with` block closes the file even if
# pickling raises.
with open(conn_path, 'wb') as cfile:
    pickle.dump(conn, cfile)

In [None]:
# PSI over time: one row per frequency band, one column per seed ROI.
# Each panel shows that ROI's PSI with every other ROI (rows) over time.
n_labels = len(clabels)
n_bands = len(conn)
fig, ax = plt.subplots(n_bands, n_labels, figsize=(3 * n_labels, 4 * n_bands),
                       sharex=True, sharey=True, squeeze=False)
for bix, b in enumerate(conn):
    # All-to-all PSI reshaped to (seed, target, time)
    bdata = np.squeeze(conn[b]['whole'].get_data()).reshape(n_labels, n_labels, -1)
    for rix, r in enumerate(conn[b]['whole'].names):
        # Assumes PSI is stored only for seed > target (below the diagonal).
        # Entries above the diagonal are filled from the transposed pair with
        # flipped sign, since PSI is antisymmetric.
        cdata = bdata[rix].copy()
        cdata[(rix + 1):] = -bdata[(rix + 1):, rix]
        panel = ax[bix, rix]
        im = panel.imshow(cdata, aspect='auto', vmin=-.1, vmax=.1,
                          interpolation='nearest', cmap='RdBu_r',
                          extent=[epochs.times[0], epochs.times[-1], n_labels, 0])
        if bix == 0:
            panel.set_title(r)
        if bix == n_bands - 1:
            panel.set_xlabel('time (s)')
        if rix == 0:
            panel.set_yticks(np.arange(n_labels) + .5)
            panel.set_yticklabels(clabels)
# Lay out the panels first, then add one colorbar shared by all axes. Adding it
# after tight_layout prevents it from stealing space from the last panel only.
fig.tight_layout()
fig.colorbar(im, ax=ax.ravel().tolist())
fig.savefig(fig_path)

In [None]:
# ## Example Cross correlation functions
# def xcorr(x, y):
#     # x and y should be normalized
#     # Divide by len(x) to obtain normalized values
#     corr = signal.correlate(x / x.shape[0], y, mode="full")
#     lags = signal.correlation_lags(len(x), len(y), mode="full")
#     return corr, lags

# def xcorr_window(W):
#     wcorr = np.full((W.shape[0], W.shape[0], W.shape[1]*2-1), np.nan)
#     for y in range(W.shape[0]):
#         for x in range(W.shape[0]):
#             if x >= y:
#                 wcorr[x, y,:], lags = xcorr(W[x,:],W[y,:])
#     return wcorr, lags

# def xcorr_sliding(T, srate, wmin, wmax, wstep, times, tstart, tend):
#     startix = np.argmin(abs(times-tstart))
#     endix = np.argmin(abs(times-tend))
#     sstep = np.round(srate * wstep)
#     smin = np.round(srate * wmin)
#     smax = np.round(srate * wmax)
#     wcenters = np.arange(startix + smin, endix-smax+1, sstep)
#     scorr = []
#     out_times = []
#     for wc in wcenters:
#         out_times += [times[int(wc)]]
#         wix = np.arange(wc-smin,wc+smax+1,1,dtype=int)
#         #Normalize for values between -1 and 1
#         wnorm = (T[:,wix] - T[:,wix].mean(axis=1,keepdims=True)) / T[:,wix].std(axis=1,keepdims=True)
#         wcorr, lags = xcorr_window(wnorm)
#         scorr += [wcorr.copy()]
#     scorr = np.array(scorr)
#     out_times = np.array(out_times)
#     return scorr, lags / srate, out_times

# def xcorr_trials(D,srate, wmin, wmax, wstep, times, tstart, tend, orth=False):
#     tcorr = []
#     if orth:
#         print("orthogonalizing")
#         D = mne_connectivity.symmetric_orth(D)
#     for d in range(D.shape[0]):
#         print('processing epoch ', d + 1, ' / ', D.shape[0])
#         scorr, lags, out_times = xcorr_sliding(D[d], srate, wmin, wmax, wstep,times, tstart,tend)
#         tcorr += [scorr.copy()]
#     tcorr = np.array(tcorr)
#     return tcorr, lags, out_times

# tcorr, lags, out_times = xcorr_trials(roi_data,srate=100, wmin = .3,wmax=.3, wstep = .02,
#                                      times = src_epochs[0].times, tstart = 0, tend = 4, orth=True)