In [12]:
import mne
import warnings
import matplotlib.pyplot as plt
import os
import numpy as np
from tqdm import tqdm
import pycwt
warnings.filterwarnings('ignore')

In [13]:
def find_common_electrodes(path):
    mdd_electrodes_set = []
    for elem in os.listdir(path):
        data = mne.io.read_raw(path + elem, verbose = False)
        mdd_electrodes_set.append(set(data.ch_names))
    return set.intersection(*mdd_electrodes_set)

In [14]:
kz_mdd = find_common_electrodes('data/kz_clean/mdd/')
kz_health = find_common_electrodes('data/kz_clean/health/')
electrodes_to_use = set.intersection(kz_mdd, kz_health)
print(electrodes_to_use)

{'T3', 'P3', 'F4', 'C4', 'O2', 'Pz', 'Fp2', 'Fpz', 'Cz', 'T6', 'Oz', 'F3', 'Fp1', 'T4', 'Fz', 'F7', 'C3', 'T5', 'O1', 'P4', 'F8'}


# Cut data into 2s patches (without overlapping)

In [15]:
Fs = 500
electrodes_to_use = set.intersection(kz_mdd, kz_health)
window_size = 2 # in seconds

In [16]:
def cut_data(path, path_to_save, window_size = window_size, 
             window_offset = window_size, 
             Fs = Fs, electrodes_to_use = electrodes_to_use):
    for elem in tqdm(os.listdir(path)):
        data = mne.io.read_raw(path + elem, verbose = False)
        
        for ch in data.ch_names:
            if ch not in electrodes_to_use:
                data.drop_channels(ch)
        np_data = data[:][0].T
        
        idx = 1
        offset = 0
        num_of_samples = Fs * window_size
        while np_data[offset : offset + num_of_samples].shape[0] == num_of_samples:
            name = elem.split('.')[0]
            np.save(path_to_save + f'{name}_patch_{idx}.npy', 
                    np_data[offset : offset + num_of_samples])
            offset += window_offset * Fs
            idx += 1

In [17]:
# cut_data('data/kz_clean/health/', 'data_np/kz_clean/health/')
# cut_data('data/kz_clean/mdd/', 'data_np/kz_clean/mdd/')

In [None]:
# !mkdir data_np_

# ToDo  
* baseline feature generation
* wavelet coherence for each patch for each electrode  
* simple run with gnn (working!)

## Computing wavelet coherence significance

In [18]:
test_signal = np.load('data_np/kz_clean/health/clean_F116_patch_1.npy')

In [19]:
test_signal[:, 1]

array([-1.10052794e-08, -6.72348676e-07, -1.22173844e-06, -1.56941871e-06,
       -1.66339237e-06, -1.49397022e-06, -1.08666029e-06, -4.86569036e-07,
        2.59925713e-07,  1.11994461e-06,  2.08069514e-06,  3.14635531e-06,
        4.32706111e-06,  5.62435207e-06,  7.01818863e-06,  8.45974955e-06,
        9.87204658e-06,  1.11579011e-05,  1.22128549e-05,  1.29397549e-05,
        1.32619871e-05,  1.31333891e-05,  1.25440856e-05,  1.15223884e-05,
        1.01332434e-05,  8.47357387e-06,  6.66464803e-06,  4.84150041e-06,
        3.13980263e-06,  1.68127497e-06,  5.59456851e-07, -1.71858730e-07,
       -5.05344644e-07, -4.80526580e-07, -1.77567273e-07,  2.94555861e-07,
        8.13453369e-07,  1.25975600e-06,  1.53211215e-06,  1.55883129e-06,
        1.30504213e-06,  7.75028639e-07,  1.00188258e-08, -9.17874843e-07,
       -1.91464824e-06, -2.87524495e-06, -3.69507757e-06, -4.28099565e-06,
       -4.56047474e-06, -4.48800165e-06, -4.04797629e-06, -3.25395990e-06,
       -2.14467605e-06, -

In [24]:
def standardize(s, detrend = True, standardize = True, remove_mean = False):
    '''
    Helper function for pre-processing data, specifically for wavelet analysis

    INPUTS:
        s - numpy array of shape (n,) to be normalized
        detrend - boolean on whether to linearly detrend s
        standardize - boolean on whether to divide by the standard deviation
        remove_mean - boolean on whether to remove the mean of s. Exclusive with detrend.

    OUTPUTS:
        snorm - numpy array of shape (n,)
    '''

    # Derive the variance prior to any detrending
    std = s.std()
    smean = s.mean()

    if detrend and remove_mean:
        raise ValueError('Only standardize by either removing secular trend or mean, not both.')

    # Remove the trend if requested
    if detrend:
        arbitrary_x = np.arange(0, s.size)
        p = np.polyfit(arbitrary_x, s, 1)
        snorm = s - np.polyval(p, arbitrary_x)
    else:
        snorm = s

    if remove_mean:
        snorm = snorm - smean

    # Standardize by the standard deviation
    if standardize:
        snorm = (snorm / std)

    return snorm

In [31]:
sig1 = standardize(test_signal[:, 0])
sig2 = standardize(test_signal[:, 2])

In [32]:
wcoh = pycwt.wct(sig1, sig2, dt = 1 / Fs, normalize = True, sig = True) 

Calculating wavelet coherence significance


  0%|▌                                                                                                                                                                      | 1/300 [00:06<33:41,  6.76s/it]


KeyboardInterrupt: 