In [1]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm, trange
import mne, re, os, pickle
import torch

from utils import *

In [2]:
def fix_average_naming(raw):
    """Strip TUH's ``EEG <name>-<reference>`` decoration from channel names, in place.

    Renaming rules, applied per channel:

    * a name containing ``EEG <name>-`` becomes ``<name>``; for the electrodes in
      ``reLowC`` the second letter is lower-cased (``FP1`` -> ``Fp1``, ``FZ`` -> ``Fz``);
    * ``PHOTIC-REF`` becomes ``PHOTIC``;
    * every other channel is left unchanged.

    NOTE(review): the original author asked whether the average-reference naming is a
    problem; that question is still open.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording whose channels are renamed in place (anything exposing ``ch_names``
        and ``rename_channels(mapping)`` works).

    Returns
    -------
    The same ``raw`` object, for chaining.
    """
    reSTR = r"(?<=EEG )(.*)(?=-)"
    reLowC = ['FP1', 'FP2', 'FZ', 'CZ', 'PZ']

    # Iterate over a snapshot: each rename rebuilds the channel-name list.
    for channel_name in list(raw.ch_names):
        match = re.search(reSTR, channel_name)  # evaluate the pattern once per channel
        if match and match.group() in reLowC:
            name = match.group()
            # Lower-case relative to the match itself rather than a fixed string offset.
            raw.rename_channels({channel_name: name[0] + name[1].lower() + name[2:]})
        elif channel_name == "PHOTIC-REF":
            raw.rename_channels({channel_name: "PHOTIC"})
        elif match:
            raw.rename_channels({channel_name: match.group()})

    return raw

def pick_and_rename_TUH_channels(raw):
    """Return a copy of ``raw`` reduced to 19 standard 10-20 EEG channels in a fixed order.

    The copy is first passed through ``fix_average_naming``. The mixed-case names it
    produces (``Fp1``, ``Fz``, ...) are then switched back to upper case, and the old
    temporal names ``T3``/``T4`` are mapped to ``T7``/``T8``. A rename is applied only
    if its source name is present and its target name is not. A recording that already
    uses the target spelling is therefore left as is.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording to select from; it is not modified (a copy is made).

    Returns
    -------
    mne.io.Raw
        A new object containing only the channels of ``EEG_20_div``, in that order.
        If any of them is missing, MNE's channel picking/reordering is expected to raise.
    """
    raw = raw.copy()

    raw = fix_average_naming(raw)

    target_names = {
        'Fp1': 'FP1', 'Fp2': 'FP2', 'Fz': 'FZ', 'Cz': 'CZ', 'Pz': 'PZ',
        'T3': 'T7', 'T4': 'T8',
    }
    present = set(raw.ch_names)
    # Filter each rename individually. This replaces the former bare try/except, which
    # silently skipped the T3/T4 renames whenever the first rename call failed.
    mapping = {old: new for old, new in target_names.items()
               if old in present and new not in present}
    if mapping:
        raw.rename_channels(mapping)

    # 19 electrodes, despite the "20" in the variable name.
    EEG_20_div = [
                'FP1', 'FP2',
        'F7', 'F3', 'FZ', 'F4', 'F8',
        'T7', 'C3', 'CZ', 'C4', 'T8',
        'T5', 'P3', 'PZ', 'P4', 'T6',
                 'O1', 'O2'
    ]

    raw.pick_channels(ch_names=EEG_20_div)
    raw.reorder_channels(EEG_20_div)

    return raw

def get_TUH_raw(FILE, high_pass=0.5, low_pass=70, SAMPLNIG_FREQ=256):
    """Load one TUH EDF recording and bring it into the pipeline's standard form.

    Steps: read (preloaded), keep the 19 standard channels, standardise the names,
    add an average-reference projector (``projection=True``), attach the
    ``standard_1020`` montage, resample to ``SAMPLNIG_FREQ`` and band-pass filter
    between ``high_pass`` and ``low_pass`` Hz.

    Parameters
    ----------
    FILE : str
        Path to the ``.edf`` file.
    high_pass, low_pass : float
        Band-pass edges in Hz.
    SAMPLNIG_FREQ : int
        Target sampling rate in Hz.

    Returns
    -------
    mne.io.Raw
        The processed recording.
    """
    recording = pick_and_rename_TUH_channels(
        mne.io.read_raw_edf(FILE, verbose=False, preload=True)
    )

    # Harmonise channel labels before looking them up in the montage.
    mne.datasets.eegbci.standardize(recording)
    layout_1020 = mne.channels.make_standard_montage('standard_1020')

    recording = recording.set_eeg_reference(ref_channels='average', projection=True, verbose=False)
    recording = recording.set_montage(layout_1020)
    recording = recording.resample(SAMPLNIG_FREQ)

    # filter() works in place; its return value is not needed here.
    recording.filter(high_pass, low_pass, verbose=False)
    return recording

def get_TUH_annotation(raw, WINDOW_LENGTH=4, SAMPLNIG_FREQ=256):
    """Tile the recording with back-to-back, non-overlapping windows as MNE annotations.

    Window ``i`` starts at ``i * WINDOW_LENGTH`` seconds, lasts ``WINDOW_LENGTH``
    seconds and is described as ``'chop i'``. A trailing partial window is dropped.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording to tile; only its number of samples is read.
    WINDOW_LENGTH : float
        Window length in seconds.
    SAMPLNIG_FREQ : float
        Sampling rate (Hz) used to convert samples to seconds. It should match
        ``raw.info['sfreq']``.

    Returns
    -------
    mne.Annotations
    """
    # raw.n_times gives the sample count without materialising the full data array.
    number_windows = int(raw.n_times / SAMPLNIG_FREQ / WINDOW_LENGTH)

    new_onset = np.arange(number_windows, dtype=np.float64) * WINDOW_LENGTH
    new_duration = np.full(number_windows, WINDOW_LENGTH, dtype=np.float64)
    new_description = [f'chop {window_idx}' for window_idx in range(number_windows)]

    return mne.Annotations(onset=new_onset, duration=new_duration, description=new_description)

In [3]:
def process_raw(raw, fwd, bands, SNR=100):
    """Window a recording and build one band-filtered copy plus inverse operator per band.

    Parameters
    ----------
    raw : mne.io.Raw
        Preprocessed recording (e.g. from ``get_TUH_raw``).
    fwd
        Forward solution shared by all bands.
    bands : dict[str, tuple]
        Band name -> (low, high) edges in Hz; the dict order fixes the band index.
    SNR : float
        Passed as ``snr`` to ``make_fast_inverse_operator``.

    Returns
    -------
    window_dict_full
        ``get_window_dict`` result for the unfiltered recording.
    window_dict_bands : list
        One ``get_window_dict`` result per band, in ``bands`` order.
    inverse_operators : list
        One inverse operator per band, in ``bands`` order.
    """
    # Use the recording's real rate: the annotation helper's default (256 Hz) would
    # give the wrong window count for any other rate.
    annotations = get_TUH_annotation(raw, SAMPLNIG_FREQ=raw.info['sfreq'])

    raw_bands = [raw.copy().filter(edges[0], edges[1], verbose=False) for edges in bands.values()]

    covs = [get_cov(raw_band) for raw_band in raw_bands]
    inverse_operators = [
        make_fast_inverse_operator(raw_band.info, fwd, cov, snr=SNR)
        for raw_band, cov in zip(raw_bands, covs)
    ]

    window_dict_full = get_window_dict(raw, annotations)
    window_dict_bands = [get_window_dict(raw_band, annotations) for raw_band in raw_bands]

    return window_dict_full, window_dict_bands, inverse_operators

In [4]:
def get_TUH_activity(window_dict_bands, inverse_operators, bands, labels, label_names):
    """Estimate per-window, per-band source power for every brain label.

    Parameters
    ----------
    window_dict_bands : list
        One window dict per band (from ``process_raw``). All are assumed to share the
        keys of the first one; the item at index 0 of each entry is the window object.
    inverse_operators : list
        Callables, one per band, mapping a window to a source estimate.
    bands : dict
        Only its length (number of bands) is used.
    labels
        Passed through to ``get_power_per_label``.
    label_names : list[str]
        Only its length (number of labels) is used.

    Returns
    -------
    numpy.ndarray
        Shape ``(n_windows, n_bands, n_labels)``.
    """
    window_keys = list(window_dict_bands[0].keys())
    activities = np.zeros((len(window_keys), len(bands), len(label_names)))

    for window_idx, window_key in enumerate(tqdm(window_keys)):
        for band_idx in range(len(bands)):
            window = window_dict_bands[band_idx][window_key][0]

            stc = inverse_operators[band_idx](window)
            # presumably one power array per label; flatten into a single row of length n_labels
            activities[window_idx, band_idx] = np.concatenate(
                get_power_per_label(stc, labels, standardize=False)
            ).ravel()

    return activities

def save_TUH_concepts(window_dict_full, activities, FILE, DATA_PATH_CONCEPTS, bands, label_names, NUMBER_CHANNELS=20, NUMBER_SAMPLES=1024):
    """Label each window with its dominant (band, region) concept and pickle it.

    Activity is taken relative to the mean over all windows of this recording. The
    band with the highest mean relative activity is chosen first. Within that band,
    the label with the highest relative activity is chosen. The concept name is
    ``"<band>_<label name>"``.

    Each window's EEG (first ``NUMBER_SAMPLES`` samples) is written into the first
    ``NUMBER_CHANNELS - 1`` rows of a float32 tensor of shape
    ``(1, NUMBER_CHANNELS, NUMBER_SAMPLES)``. The last row is filled with -1.
    The tensor is pickled to
    ``<DATA_PATH_CONCEPTS>/<concept>/<file stem>_<window index>_<concept>.pkl``.
    The concept folders must already exist.

    Parameters
    ----------
    window_dict_full : dict
        Window key -> entry whose item 0 exposes ``get_data()``; iteration order must
        match the rows of ``activities``.
    activities : numpy.ndarray
        ``(n_windows, n_bands, n_labels)`` activity array; it is not modified.
    FILE : str
        Source recording file name; its stem is used in the output file names.
    DATA_PATH_CONCEPTS : str
        Root folder of the concept sub-folders.
    bands : dict
        Band names, in the same order as axis 1 of ``activities``.
    label_names : list[str]
        Label names, in the same order as axis 2 of ``activities``.
    NUMBER_CHANNELS, NUMBER_SAMPLES : int
        Output tensor size.
    """
    if not window_dict_full:
        return  # nothing to save

    baseline_activity = np.mean(activities, axis=0)
    # Subtract out of place: the previous in-place `-=` on a row view mutated the caller's array.
    relative_activities = np.asarray(activities) - baseline_activity

    band_names = list(bands.keys())
    n_eeg_channels = NUMBER_CHANNELS - 1
    file_stem = os.path.splitext(os.path.basename(FILE))[0]

    for window_idx, window_key in enumerate(window_dict_full.keys()):
        activity = relative_activities[window_idx]

        most_active_band_idx = int(np.argmax(activity.mean(axis=1)))
        most_active_band = band_names[most_active_band_idx]

        brain_region_idx = int(activity[most_active_band_idx].argmax())
        brain_region = label_names[brain_region_idx]

        concept = most_active_band + '_' + brain_region

        eeg = window_dict_full[window_key][0].get_data()[:, :NUMBER_SAMPLES]
        # Last channel is a constant -1 filler; the EEG fills the rows before it.
        x = np.full((1, NUMBER_CHANNELS, NUMBER_SAMPLES), -1.0)
        x[0, :n_eeg_channels, :] = eeg.reshape(n_eeg_channels, NUMBER_SAMPLES)
        x = torch.from_numpy(x).float()

        picklePath = os.path.join(DATA_PATH_CONCEPTS, concept, f'{file_stem}_{window_idx}_{concept}.pkl')
        with open(picklePath, 'wb') as handle:
            pickle.dump(x, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [5]:
# Output root: one sub-folder per "<band>_<label>" concept lives below it.
DATA_PATH_CONCEPTS = '../../data/baseline concepts TUH/'

# Root folder of the raw TUH EDF recordings. Set the TUH_DATA_PATH environment
# variable to point elsewhere instead of editing this machine-specific default.
# os.path.join(..., '') guarantees a trailing separator, so paths built by plain
# string concatenation keep working.
DATA_PATH_TUH = os.path.join(
    os.environ.get('TUH_DATA_PATH', '/home/williamtheodor/Documents/DL for EEG Classification/data/TUH (raw)/'),
    '',
)

# Reference recording: its channel info is used to build the forward model `fwd`,
# which the main loop then reuses for every recording.
RANDOM_FILE = 'aaaaavrx_s004_t000.edf'
RANDOM_FILE_PATH = DATA_PATH_TUH + RANDOM_FILE

subjects_dir, subject, trans, src_path, bem_path = get_fsaverage()

raw = get_TUH_raw(RANDOM_FILE_PATH)
fwd = get_fwd(raw.info, trans, src_path, bem_path)

# Cortical parcellation used for the concept labels ('aparc.a2009s' is an alternative).
PARCELLATION = 'HCPMMP1_combined'

In [6]:
# Frequency bands (Hz) used to split each recording; the upper Gamma edge matches
# the 70 Hz low-pass applied in get_TUH_raw.
bands = {
    'Delta': (0.5, 4),
    'Theta': (4, 8),
    'Alpha': (8, 12),
    'Beta': (12, 30),
    'Gamma': (30, 70)
}

src = get_src(src_path)

# Use the configured parcellation instead of repeating its name as a literal.
labels = get_labels(subjects_dir, parcellation_name=PARCELLATION)
# `labels` may be nested (e.g. per hemisphere) — flatten before collecting names.
label_names = [label.name for label in np.array(labels).flatten()]

In [7]:
# Pre-create one output folder per (band, brain region) concept, named
# "<band>_<label name>" to match the concept names used when pickles are written.
for label_name in label_names:
    for band in bands:
        os.makedirs(os.path.join(DATA_PATH_CONCEPTS, f'{band}_{label_name}'), exist_ok=True)

In [8]:
# Main loop: for each EDF recording, build per-band windows and inverse operators,
# estimate per-window source activity, and write one concept-labelled pickle per window.
for FILE in sorted(os.listdir(DATA_PATH_TUH)):
    # The folder may contain other entries; get_TUH_raw can only read EDF files.
    if not FILE.lower().endswith('.edf'):
        continue
    print(FILE)
    FILE_PATH = os.path.join(DATA_PATH_TUH, FILE)

    raw_full = get_TUH_raw(FILE_PATH)
    window_dict_full, window_dict_bands, inverse_operators = process_raw(raw_full, fwd, bands, SNR=100)

    activities = get_TUH_activity(window_dict_bands, inverse_operators, bands, labels, label_names)

    save_TUH_concepts(window_dict_full, activities, FILE, DATA_PATH_CONCEPTS, bands, label_names)

/home/williamtheodor/Documents/DL for EEG Classification/data/TUH (raw)/aaaaanjj_s006_t005.edf


100%|██████████| 150/150 [06:18<00:00,  2.52s/it]


/home/williamtheodor/Documents/DL for EEG Classification/data/TUH (raw)/aaaaanke_s001_t000.edf


100%|██████████| 141/141 [05:45<00:00,  2.45s/it]


/home/williamtheodor/Documents/DL for EEG Classification/data/TUH (raw)/aaaaanyo_s001_t000.edf


100%|██████████| 17/17 [00:40<00:00,  2.41s/it]


/home/williamtheodor/Documents/DL for EEG Classification/data/TUH (raw)/aaaaanze_s001_t003.edf


100%|██████████| 152/152 [06:21<00:00,  2.51s/it]


/home/williamtheodor/Documents/DL for EEG Classification/data/TUH (raw)/aaaaanzr_s002_t003.edf


 16%|█▌        | 24/150 [00:59<05:26,  2.59s/it]