In [1]:
import os
import numpy as np
import pandas as pd
import torch

# MNE modules
import mne
from mne.time_frequency import psd_array_multitaper

# Filter warnings
import warnings
import matplotlib.pyplot as plt
warnings.filterwarnings("ignore")

mne.set_log_level(verbose='CRITICAL')

In [3]:
def define_bands():
    # Frequency bands
    bands = [(0.9, 4, 'Delta (0.9-4 Hz)', 'D'), (4, 8, 'Theta (4-8 Hz)', 'T'), (8, 14, 'Alpha (8-14 Hz)', 'A'),
             (14, 25, 'Beta (14-25 Hz)', 'B'), (25, 40, 'Gamma (25-40 Hz)', 'G')]

    str_freq = [bands[i][3] for i in range(len(bands))]

    # Localization by scalp regions
    regions = [(['Fp1', 'Fp2'], 'Fp', 'Pre-frontal'), (['F7', 'F3'], 'LF', 'Left Frontal'),
               (['Fz'], 'MF', 'Midline Frontal'), (['F4', 'F8'], 'RF', 'Right Frontal'),
               (['C3'], 'LT', 'Left Temporal'), (['P8'], 'RT', 'Right Temporal'),
               (['C3', 'Cz', 'C4'], 'Cen', 'Central'), (['P3', 'Pz', 'P4'], 'Par', 'Parietal'),
               (['O1', 'O2'], 'Occ', 'Occipital')]

    SLICE_LEN = 10  # number of epochs to measure features, coherence and PLV

    n_freq = len(str_freq)
    n_regions = len(regions)

    return bands, str_freq, regions, SLICE_LEN, n_freq, n_regions

In [4]:
def extract_features(sample, window=219, step=32, samp_rate=100):
    sliced_data = []
    slices_amount = int((sample.shape[0] - window) / step + 1)
    for i in range(slices_amount):
        slicee = sample[0 + i*step :window + i*step, :]
        sliced_data.append(slicee)
    sliced_data = np.array(sliced_data) # events, chanels, window
    sliced_data = sliced_data.reshape(slices_amount, sample.shape[1], window)

    ch_names = ['Fp1', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8', 'C3', 'Cz', 'C4', 'P3', 'Pz', 'P4', 'O1', 'O2']
    n_channels = len(ch_names)
    bands, str_freq, regions, SLICE_LEN, n_freq, n_regions = define_bands()
    
    kwargs = dict(fmin=bands[0][0], fmax=bands[-1][1], sfreq=samp_rate, bandwidth=None, adaptive=True, n_jobs=1)
    loc_masks = [[ch_names[i] in reg for i in range(n_channels)] for (reg, _, _) in regions]
    
    lst_st_psd_raw = []
    lst_st_psd_loc_raw = []
    lst_st_psd_all_raw = []
    
    st_psd_mtaper, st_freq_mtaper = psd_array_multitaper(sliced_data, **kwargs)
    freq_masks = [(fmin < st_freq_mtaper) & (st_freq_mtaper < fmax) for (fmin, fmax, _, _) in bands]
    
    
        # Stages
    st_psd_raw = np.array([np.mean(st_psd_mtaper[:, :, _freq_mask], axis=2) for _freq_mask in freq_masks]).transpose(1,
                                                                                                                     2,
                                                                                                                     0)
    st_psd_loc_raw = np.array([np.mean(st_psd_raw[:, _mask, :], axis=1) for _mask in loc_masks]).transpose(1, 0, 2)
    st_psd_all_raw = np.mean(st_psd_raw, axis=1)

    df_st_raw = pd.DataFrame()
    df_st_loc_raw = pd.DataFrame()
    df_st_all_raw = pd.DataFrame()
    for _fr in range(n_freq):
        for _ch in range(n_channels):
            df_st_raw[str_freq[_fr] + '_psd_' + ch_names[_ch]] = st_psd_raw[:, _ch, _fr]
        for _r in range(n_regions):
            df_st_loc_raw[str_freq[_fr] + '_psd_' + regions[_r][1]] = st_psd_loc_raw[:, _r, _fr]
        df_st_all_raw[str_freq[_fr] + '_psd_All'] = st_psd_all_raw[:, _fr]
    
    
    df = df_st_raw

    lst_st = 10 * np.log10(df[SLICE_LEN // 2:-SLICE_LEN // 2]) #- 10 * np.log10(df_blm_psd_raw.mean(axis=0))
    # need to subtract baseline!!!
    # so need to have baseline file
    
    return lst_st

In [6]:
def features_indices(psd_previous):
    samp_rate = 100
    ch_names = ['Fp1', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8', 'C3', 'Cz', 'C4', 'P3', 'Pz', 'P4', 'O1', 'O2']
    n_channels = len(ch_names)
    
    bands, str_freq, regions, SLICE_LEN, n_freq, n_regions = define_bands()
    print('Defining PSD indices...')

    # PSD special features (EEG indices) (re-referenced data)

    lst_st_psd_ind_raw = []
    lst_st_psd_ind_loc_raw = []
    lst_st_psd_ind_all_raw = []

    str_psd_ind = ['T_D', 'A_D', 'A_T', 'A_DT', 'B_D', 'B_T', 'B_A', 'B_DT', 'B_TA', 'G_D', 'G_T', 'G_A', 'G_B', 'G_DT',
                   'G_TA', 'G_AB']

    df_st_raw = pd.DataFrame()
    df_st_loc_raw = pd.DataFrame()
    df_st_all_raw = pd.DataFrame()

    # Indices per channel (averaged PSD)
    for _ch in range(n_channels):
        for ind in str_psd_ind:
            if (len(ind) == 3):
                df_st_raw[ind + '_psd_' + ch_names[_ch]] = (psd_previous[ind[0] + '_psd_' + ch_names[_ch]] /
                                                            psd_previous[ind[2] + '_psd_' + ch_names[_ch]])
            elif (len(ind) == 4):
                df_st_raw[ind + '_psd_' + ch_names[_ch]] = (psd_previous[ind[0] + '_psd_' + ch_names[_ch]] /
                                                            psd_previous[ind[2] + '_psd_' + ch_names[_ch]] +
                                                             psd_previous[ind[3] + '_psd_' + ch_names[_ch]])
    lst_st_psd_ind_raw = df_st_raw
    lst_st_psd_ind_loc_raw = df_st_loc_raw
    lst_st_psd_ind_all_raw = df_st_all_raw

    # Aggregate all stages in one DataFrame
    df = lst_st_psd_ind_raw
    lst_st = 10 * np.log10(df[SLICE_LEN // 2:-SLICE_LEN // 2]) #- 10 * np.log10(df_blm_psd_ind_raw.mean(axis=0)

    return lst_st # == df_st_psd_ind_db



In [8]:
# ----------------------------------------Data-----------------------------------------------------
class TEST_TUH(torch.utils.data.Dataset):
    def __init__(self, path): #, tuh_filtered_stat_vals):
        super(TEST_TUH, self).__init__()
        self.main_path = path
        self.paths = path
        print(self.paths)
        # self.tuh_filtered_stat_vals = tuh_filtered_stat_vals
        # self.paths = ['{}/{}'.format(self.main_path, i) for i in os.listdir(self.main_path)]

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx, negative=False):
        path = self.paths[idx]
        # take 60s of recording with specified shift
        key = False
        while (key == False):
            try:
                # sample = np.load(path, allow_pickle=True).item()['value']
                sample = np.load(path, allow_pickle=True).item()
                key = True
            except Exception as e:
                print("Path: {} is broken ".format(path), e)
                path = np.random.choice(self.paths, 1)[0]
                # sample = np.load(path, allow_pickle=True).item()['value']
        real_len = min(3000, sample['value_pure'].shape[0])

        HSE_Stage2_channels = ['Fp1', 'FP2', 'F7', 'F3', 'Fz', 'F4', 'F8', 'C3', 'Cz', 'C4', 'P3', 'Pz', 'P4', 'O1', 'O2']
        HSE_Stage2_channels = [i.upper() for i in HSE_Stage2_channels]
        channels_ids = [i for i, val in enumerate(sample['channels']) if val in HSE_Stage2_channels]

        sample = sample['value_pure'][:real_len]

        # choose 2 random channels
        channels_to_train = channels_ids  # np.random.choice(channels_ids, 2, replace=False)
        channels_vector = torch.tensor((channels_to_train))
        sample = sample[:, channels_to_train]

        sample_norm = sample
        if sample_norm.shape[0] < 3000:
            sample_norm = np.pad(sample_norm, ((0, 3000 - sample_norm.shape[0]), (0, 0)))
        print(sample_norm.shape)
        lst_st_feat = extract_features(sample_norm)
        indices = features_indices(lst_st_feat)
        df_st_eeg = pd.concat([lst_st_feat, indices], axis=1).dropna()

        if np.random.choice([0, 1], p=[0.7, 0.3]) and not negative:
            index = np.random.choice(self.__len__() - 1)
            negative_sample = self.__getitem__(index, True)
            negative_path = negative_sample['path']
            negative_sample_norm = negative_sample['current'].numpy()

            negative_person = negative_sample['path'].split('/')[-1]  # .split('_')
            current_person = path.split('/')[-1]  # .split('_')
            if negative_person.split('_')[0] == current_person.split('_')[0] and \
                    abs(int(negative_person.split('_')[1][:-4]) - int(current_person.split('_')[1][:-4])) < 20000:
                negative_label = torch.tensor(0)               # возможно стоит запретить позитивы отличающиеся < 20000 , если состояние реально изменилось то сеть будет учиться странному.
            else:
                negative_label = torch.tensor(1)
        else:
            negative_sample_norm = sample_norm.copy()
            negative_label = torch.tensor(0)
            negative_path = ''

        attention_mask = torch.ones(3000)
        attention_mask[real_len:] = 0
        return {'current': torch.from_numpy(sample_norm).float(),
                'negative': torch.from_numpy(negative_sample_norm).float(),
                'path': path,
                'label': negative_label,
                'channels': channels_vector,
                'attention_mask': attention_mask,
               'features': df_st_eeg.to_numpy()}

In [9]:
path = os.getcwd() + '/example_data/TUH/'

In [10]:
# splitted_paths = [f'{path}/{i}'.format(i) for i in os.listdir(path)]
splitted_paths = ['/media/hdd/data/TUH_splited.examples/{}'.format(i) for i in os.listdir('/media/hdd/data/TUH_splited.examples/')]

In [11]:
train_dataset = TEST_TUH(splitted_paths[1:3])

['/media/hdd/data/TUH_splited.examples/2_0.npy', '/media/hdd/data/TUH_splited.examples/0_0.npy']


In [10]:
train_loader = torch.utils.data.DataLoader(train_dataset)

In [1]:
for batch in train_loader:
    features = batch['features']

NameError: name 'train_loader' is not defined