In [59]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm, trange
import mne, re, os

from utils import *

In [60]:
# Root folder of one TUH subject. Set the TUH_DATA_PATH environment variable to
# point elsewhere instead of editing this absolute, machine-specific default.
DATA_PATH_TUH = os.environ.get(
    'TUH_DATA_PATH',
    '/home/williamtheodor/Documents/DL for EEG Classification/113/aaaaaqtx/',
)
# os.path.join tolerates a root with or without a trailing slash; the final ''
# keeps EDF_FOLDER's trailing separator so plain string concatenation still works.
EDF_FOLDER = os.path.join(DATA_PATH_TUH, 's002_2014_02_08', '01_tcp_ar', '')
RANDOM_FILE = os.path.join(EDF_FOLDER, 'aaaaaqtx_s002_t000.edf')

# The forward solution (`fwd`) is deliberately not built in this cell: it needs
# the channel info of a loaded recording plus the fsaverage `trans`/`src_path`/
# `bem_path`, none of which exist yet at this point (a `get_fwd(raw.info, ...)`
# call here referenced an undefined `raw` and failed on a fresh kernel).

# Parcellation name (an alternative is kept commented out).
#PARCELLATION = 'aparc.a2009s'
PARCELLATION = 'HCPMMP1_combined'

In [61]:
def fix_average_naming(raw):
    """Strip TUH-style prefixes/suffixes from channel names, renaming in place.

    A channel whose name contains ``'EEG '`` followed later by ``'-'``
    (e.g. ``'EEG C3-REF'``) is renamed to the text between them (``'C3'``).
    The regex is greedy, so the cut happens at the *last* ``'-'``.
    If the extracted name is FP1/FP2/FZ/CZ/PZ, the character at index 5 of the
    full name is lower-cased first (``'EEG FP1-REF'`` -> ``'Fp1'``); this
    assumes the name starts with ``'EEG '``.
    ``'PHOTIC-REF'`` is renamed to ``'PHOTIC'``. Every other channel is left
    untouched.

    NOTE(review): the original author left the open question "is the average
    naming a problem?" — not resolved here.

    Returns the same ``raw`` object.
    """
    electrode_pattern = r"(?<=EEG )(.*)(?=-)"
    mixed_case_electrodes = ['FP1', 'FP2', 'FZ', 'CZ', 'PZ']

    for old_name in raw.ch_names:
        found = re.search(electrode_pattern, old_name)

        if found is not None and found.group() in mixed_case_electrodes:
            adjusted = old_name[:5] + old_name[5].lower() + old_name[6:]
            new_name = re.findall(electrode_pattern, adjusted)[0]
        elif old_name == "PHOTIC-REF":
            new_name = "PHOTIC"
        elif found is not None:
            new_name = found.group()
        else:
            continue

        mne.channels.rename_channels(raw.info, {old_name: new_name})

    return raw

def pick_and_rename_TUH_channels(raw):
    """Normalise TUH channel names and keep 19 scalp channels in a fixed order.

    Steps, all applied to ``raw`` in place:
      1. ``fix_average_naming`` strips the ``'EEG '`` / ``'-REF'`` decorations.
      2. Fp1/Fp2/Fz/Cz/Pz are renamed to upper case.
      3. T3/T4 are renamed to T7/T8 (T5/T6 are left unchanged).
      4. Only the listed channels are kept, reordered to match the list.

    Returns the same ``raw`` object.
    """
    raw = fix_average_naming(raw)

    upper_case_names = {'Fp1': 'FP1', 'Fp2': 'FP2', 'Fz': 'FZ', 'Cz': 'CZ', 'Pz': 'PZ'}
    mne.channels.rename_channels(raw.info, upper_case_names)

    temporal_names = {'T3': 'T7', 'T4': 'T8'}
    mne.channels.rename_channels(raw.info, temporal_names)

    # 19 channels laid out front-to-back, left-to-right (previously named EEG_20_div).
    scalp_channels = [
        'FP1', 'FP2',
        'F7', 'F3', 'FZ', 'F4', 'F8',
        'T7', 'C3', 'CZ', 'C4', 'T8',
        'T5', 'P3', 'PZ', 'P4', 'T6',
        'O1', 'O2',
    ]

    raw.pick_channels(ch_names=scalp_channels)
    raw.reorder_channels(scalp_channels)
    return raw

In [62]:
def get_TUH_raw(FILE, high_pass=0.5, low_pass=70):
    """Load a TUH EDF recording as a filtered, montage-equipped MNE Raw.

    The file is read with ``preload=True``. Channels are selected/renamed by
    ``pick_and_rename_TUH_channels`` and names are normalised with
    ``mne.datasets.eegbci.standardize``. An average-reference *projector* is
    added (projection=True, so it is not applied to the data here), the
    ``standard_1020`` montage is attached, and the data are band-pass filtered
    between ``high_pass`` and ``low_pass`` (Hz).
    """
    recording = pick_and_rename_TUH_channels(
        mne.io.read_raw_edf(FILE, verbose=False, preload=True)
    )
    mne.datasets.eegbci.standardize(recording)  # renames channels in place

    recording = recording.set_eeg_reference(
        ref_channels='average', projection=True, verbose=False
    )
    recording = recording.set_montage(
        mne.channels.make_standard_montage('standard_1020')
    )

    recording.filter(high_pass, low_pass, verbose=False)
    return recording

def get_TUH_annotation(raw, WINDOW_LENGTH=4, SAMPLNIG_FREQ=None):
    """Build back-to-back, fixed-length window annotations covering a recording.

    Parameters
    ----------
    raw : MNE Raw
        Recording to segment. Only its length in samples (``raw.n_times``) and,
        by default, its sampling rate (``raw.info['sfreq']``) are read.
    WINDOW_LENGTH : float
        Window length in seconds.
    SAMPLNIG_FREQ : float or None
        Sampling rate in Hz used to convert samples to seconds. ``None``
        (default) uses ``raw.info['sfreq']``; a fixed value such as 256 would
        mis-size the windows for recordings sampled at another rate.
        (Name kept as-is, typo included, for backward compatibility.)

    Returns
    -------
    mne.Annotations
        One annotation per *complete* window, with onsets
        ``0, WINDOW_LENGTH, 2*WINDOW_LENGTH, ...`` (float64), durations
        ``WINDOW_LENGTH`` and descriptions ``'chop <index>'``. A trailing
        partial window is dropped; a recording shorter than one window yields
        an empty Annotations object.
    """
    sfreq = raw.info['sfreq'] if SAMPLNIG_FREQ is None else SAMPLNIG_FREQ

    # raw.n_times gives the sample count without materialising the data array.
    number_windows = int(raw.n_times / sfreq / WINDOW_LENGTH)

    onsets = np.arange(number_windows, dtype=np.float64) * WINDOW_LENGTH
    durations = np.full(number_windows, WINDOW_LENGTH, dtype=np.float64)
    descriptions = [f'chop {window_idx}' for window_idx in range(number_windows)]

    return mne.Annotations(onset=onsets, duration=durations, description=descriptions)

In [63]:
def process_raw(raw, bands, SNR=100, fwd=None):
    """Band-filter a recording and prepare one inverse operator per band.

    Parameters
    ----------
    raw : MNE Raw
        Preprocessed recording (e.g. the output of ``get_TUH_raw``); it is
        copied per band, not modified.
    bands : dict
        Band name -> (low, high) cut-offs in Hz, passed to ``Raw.filter``.
    SNR : float
        Signal-to-noise ratio forwarded to ``make_fast_inverse_operator``.
    fwd : forward solution, optional
        If ``None`` (default), it is computed from ``raw.info`` with
        ``get_fwd`` and the notebook-level fsaverage ``trans``/``src_path``/
        ``bem_path`` (produced by ``get_fsaverage()``).

    Returns
    -------
    window_dict_bands : list
        One ``get_window_dict`` result per band, in ``bands`` order, all built
        from the same ``get_TUH_annotation(raw)`` windows.
    inverse_operators : list
        One ``make_fast_inverse_operator`` result per band, in ``bands`` order.
    """
    if fwd is None:
        # Derive the forward model from this recording rather than relying on
        # a global `fwd` that may have been built from some other recording.
        fwd = get_fwd(raw.info, trans, src_path, bem_path)

    annotations = get_TUH_annotation(raw)

    raw_bands = [raw.copy().filter(edges[0], edges[1], verbose=False) for edges in bands.values()]

    covs = [get_cov(raw_band) for raw_band in raw_bands]
    inverse_operators = [
        make_fast_inverse_operator(raw_band.info, fwd, cov, snr=SNR)
        for raw_band, cov in zip(raw_bands, covs)
    ]

    window_dict_bands = [get_window_dict(raw_band, annotations) for raw_band in raw_bands]

    return window_dict_bands, inverse_operators

In [64]:
# Frequency bands (Hz) used to split each recording; the Gamma upper edge
# matches get_TUH_raw's default low_pass of 70.
bands = {
    'Delta': (0.5, 4),
    'Theta': (4, 8),
    'Alpha': (8, 12),
    'Beta': (12, 30),
    'Gamma': (30, 70)
}

subjects_dir, subject, trans, src_path, bem_path = get_fsaverage()
src = get_src(src_path)

# Use the PARCELLATION constant from the config cell so changing it there
# actually takes effect (the name used to be hard-coded here as well).
labels = get_labels(subjects_dir, parcellation_name=PARCELLATION)
label_names = [label.name for label in np.array(labels).flatten()]

In [33]:
# Load the example recording and build per-band windows and inverse operators.
raw_full = get_TUH_raw(FILE=RANDOM_FILE)
(window_dict_bands, inverse_operators) = process_raw(raw=raw_full, bands=bands, SNR=100)

In [44]:
# Source-space power per (window, band, label): activity[i, j] holds the
# concatenated per-label power of window i in band j.
# All band window dicts were built from the same annotations in process_raw,
# so the first band's keys index every band (keys presumably 'chop <i>' —
# confirm against get_window_dict). This replaces an undefined `window_dict`.
window_keys = list(window_dict_bands[0].keys())
n_bands = len(bands)
activity = np.zeros((len(window_keys), n_bands, len(label_names)))

for i, window_key in enumerate(tqdm(window_keys)):
    for j in range(n_bands):
        # NOTE(review): [0] is assumed to select the windowed Raw from the
        # per-window entry — confirm get_window_dict's value layout.
        raw_window = window_dict_bands[j][window_key][0]
        stc = inverse_operators[j](raw_window)
        activity[i, j] = np.concatenate(get_power_per_label(stc, labels, standardize=False))

    

100%|██████████| 75/75 [03:10<00:00,  2.55s/it]


In [51]:
# Centre every (band, label) value on its mean across all windows, so the
# entries read as deviations from this recording's average.
baseline_activity = activity.mean(axis=0)
activity = activity - baseline_activity

In [55]:
activity[0].mean(axis=1).argmax()  # band index with the highest label-averaged activity in window 0

2

In [57]:
# Describe window 0 as "<band>_<region>": the band with the highest mean
# activity across labels, then the label with the highest activity in that band.
a = activity[0]

most_active_band_idx = np.argmax(a.mean(axis=1))
most_active_band = [*bands][most_active_band_idx]

brain_region_idx = np.argmax(a[most_active_band_idx])
brain_region = label_names[brain_region_idx]

concept = '_'.join((most_active_band, brain_region))

In [58]:
# Show the window-0 label, formatted as "<most active band>_<most active region in that band>".
concept

'Alpha_Inferior Frontal Cortex-lh'

In [41]:
np.array([1, 2])[np.newaxis, :].shape  # a leading axis turns shape (2,) into (1, 2)

(1, 2)