In [None]:
import pickle
import time
import numpy as np
from scipy import signal
import matplotlib.pyplot as plt
import connectivity_measures as cm

In [None]:
IDs = ['SERT1597'] #, 'SERT1659', 'SERT1678', 'SERT1908',
       #'SERT1984', 'SERT1985', 'SERT2014', 'SERT1668',
       #'SERT1665', 'SERT2018', 'SERT2024', 'SERT2013']

In [None]:
### Contructing filters
### Butterworth filter
def butter_bandpass(lowcut, highcut, fs, order=5):
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    b, a = signal.butter(order, [low, high], btype='band')
    return b, a

fs      = 1000.0

filter_parameters = {'theta':    {'N': 4, 'lowcut': 3.5,  'highcut': 7.5},
                    #'alpha':     {'N': 4, 'lowcut': 8.0,  'highcut': 13.0},
                    'beta':      {'N': 4, 'lowcut': 13.0, 'highcut': 25.0},
                    'low_gamma': {'N': 6, 'lowcut': 25.0, 'highcut': 60}}

butterworths = dict()
for band in filter_parameters.keys():
    lowcut = filter_parameters[band]['lowcut']
    highcut = filter_parameters[band]['highcut']
    N = filter_parameters[band]['N']
    butterworths[band] = butter_bandpass(lowcut, highcut, fs, order=N)


In [1]:
condition = 'epochs'

wPLI = {'theta': {}, 'beta': {}, 'low_gamma': {}}
for mouse in IDs: 
    npys_dir = '/home/maspe/filer/SERT/' + mouse + '/npys/'
    print('\nLoading mouse {}...'.format(mouse))
    
    ### Loading data
    data = pickle.load(open(npys_dir + mouse + '.epochs', 'rb') , encoding='latin1')
       
    ### Loop
    filtered_bands = {'theta': {}, 'beta': {}, 'low_gamma': {}}
    iterator = 0
    
    print('Stacking structures...')
    for structure in ['mPFC', 'NAC', 'BLA', 'vHip']: 
        print('Loading ' + structure + '...')
        if iterator == 0:
            all_structures = data[structure][condition]
        else:
            all_structures = np.vstack((all_structures, data[structure][condition]))
        
        iterator += 1
        
    print('Filtering...') 
    phases = dict()
    for band in filtered_bands.keys():
        filtered = signal.filtfilt(b=butterworths[band][0], a=butterworths[band][1],
                                   x=all_structures, axis=1)
        
        print('Getting phases for {} band...'.format(band))
        transformed = signal.hilbert(filtered)
        phases[band] = np.angle(transformed)
        
        print('Calculating wPLI for {} band...'.format(band))
        clock = time.time()
        
        roi_NAC  = np.array([-1.4, -1.1])
        roi_vHip = np.array([-0,5, 0.5])
        roi_mPFC = np.array([2.1, 2.6])

        roi_NAC  = ((roi_NAC + 3) * 30000).astype(int)
        roi_vHip = ((roi_vHip + 3) * 30000).astype(int)
        roi_mPFC = ((roi_mPFC + 3) * 30000).astype(int)

        
        n_epochs = phases[band].shape[2]
        
        nac2vhip_wpli = \   
        cm.PLI2(phases['theta'][:, roi_NAC[0]:roi_NAC[1], epoch], phases['beta'][:, roi_vHip[0]:pre[1], epoch]
                
                
        for epoch in range(n_epochs):
            if epoch == 0:
                wpli_pre[band] = cm.PLI2(phases[:, pre[0]:pre[1], epoch], average = False, method='wpli')
                wpli_post[band] = cm.PLI2(phases[:, post[0]:post[1], epoch], average = False, method='wpli')
            else:
                wpli_pre[band]  = np.dstack((wpli_pre[band], cm.PLI2(phases[:,pre[0]:pre[1],epoch], average = False, method='wpli')))
                wpli_post[band] = np.dstack((wpli_post[band], cm.PLI2(phases[:,post[0]:post[1],epoch], average = False, method='wpli')))
        
        print('wPLI calculated in {} s.'.format(time.time() - clock))
        
        wPLI[band]['pre'] = wpli_pre
        wPLI[band]['post'] = wpli_post
      
    
    pickle.dump(wPLI, open(npys_dir + mouse + '.wpli3', 'wb'), protocol=2)


print('Done!')        

SyntaxError: invalid syntax (<ipython-input-1-7d5ff8bc6cda>, line 48)