# 1. Import Required Libraries

In [59]:
import mne
import torch
import os
from pathlib import Path 
import logging
import random
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from mne.preprocessing import annotate_muscle_zscore
from mne.time_frequency import tfr_morlet
import math

# Epoch length in seconds; used by addEpochs() both as the spacing between
# fixed-length events and as the tmax of each epoch.
epochDuration = 1

# Accumulator filled by addEpochs(): one Epochs object is appended per call,
# under the label passed in (process() uses the file's parent folder name).
study_epochs = {
    'td': [],
    'asd': []
}

# Select the Qt5Agg matplotlib backend for plotting.
matplotlib.use('Qt5Agg')
# Limit MNE log output to the 'warning' level.
mne.set_log_level('warning')

# Set seed
#random.seed(42) 

# torch.cuda.is_available()

# 2. Load Raw EEG Data

In [60]:
# Root data folder and the raw EEG sub-folder inside it (both relative paths,
# so they resolve against the current working directory).
data_path = Path("data/")
raw_eeg_path = data_path / "raw_eeg"

def checkPath(dir):
    """Check that ``dir`` is an existing directory, logging an error if not.

    Args:
        dir: Directory to check, given as a ``str`` or ``pathlib.Path``.

    Returns:
        ``True`` when ``dir`` exists and is a directory, ``False`` otherwise.
    """
    # Validate the directory that was actually passed in, not the module-level
    # default, so the check is meaningful for any caller-supplied path.
    target = Path(dir)
    if target.is_dir():
        return True
    logging.error(f"{target} directory DOES NOT exists.")
    return False

def walkThroughDir(dir_path):
    """Print a one-line summary for every folder found under ``dir_path``.

    ``checkPath`` is called first; the walk is then attempted regardless of
    its outcome. Each printed line reports how many sub-directories and files
    the visited folder contains.

    Args:
        dir_path: Root folder to walk with ``os.walk``.
    """
    checkPath(dir_path)
    for current_dir, subdirs, files in os.walk(dir_path):
        n_dirs = len(subdirs)
        n_files = len(files)
        print(f"There are {n_dirs} directories and {n_files} file(s) in '{current_dir}'.")

# Assumes the file name is formatted as {ID}_{type}_{XXX}Hz.{extension}, e.g. TD100_raw_512Hz.asc
def csvToRaw(file, fMax=40):
    """Load a tab-separated EEG export into MNE and band-pass filter it.

    The sampling rate is parsed from the file name, whose stem must look like
    ``{ID}_{type}_{XXX}Hz`` (for example ``TD100_raw_512Hz``).

    Args:
        file: ``pathlib.Path`` of the recording to read.
        fMax: Upper edge, in Hz, of the band-pass filter (lower edge is 1 Hz).

    Returns:
        A ``(raw, filtered)`` tuple: the unfiltered ``RawArray`` (scaled from
        uV to V) and a filtered copy of it.

    Raises:
        ValueError: If the sampling rate cannot be parsed from the file name.
    """
    print(f"Processing {file.stem}")
    data = pd.read_csv(file, sep='\t')

    # Drop each unwanted column that is present. Dropping all three labels in
    # one call raises as soon as one is missing and then removes none of them.
    eog_columns = ['VEOG - LkE', 'HEOG - LkE']
    # NOTE(review): 'Unnamed: 34' is presumably the empty column pandas
    # creates for a trailing tab on each row - confirm against the data files.
    extra_columns = ['Unnamed: 34']
    present = [col for col in eog_columns + extra_columns if col in data.columns]
    if not any(col in data.columns for col in eog_columns):
        print("No EOG Channels found")
    if present:
        data = data.drop(columns=present)

    # Channel names: keep the part before the first "-" and strip spaces.
    channels = [name.split("-")[0].replace(" ", "") for name in data.columns]

    # Sampling rate comes from the third "_"-separated token of the file stem.
    try:
        sfreq = int(file.stem.split("_")[2].replace("Hz", ""))
    except (IndexError, ValueError) as err:
        raise ValueError(
            f"Cannot parse sampling rate from file name '{file.name}'; "
            "expected {ID}_{type}_{XXX}Hz.{extension}"
        ) from err

    # Load Data: RawArray expects one row per channel, hence the transpose.
    data = data.transpose()
    ch_types = ["eeg"] * len(channels)
    info = mne.create_info(ch_names=channels, sfreq=sfreq, ch_types=ch_types)
    raw = mne.io.RawArray(data, info)

    # Format data date for annotations later
    raw.set_meas_date(0)
    raw.set_montage("standard_1020")

    # Convert from uV to V for MNE
    raw.apply_function(lambda x: x * 1e-6)

    # Muscle-artifact marking (markMuscleArtifacts) is intentionally not
    # applied here; bad data is handled later in the pipeline.

    filtered = raw.copy().filter(l_freq=1.0, h_freq=fMax)

    return raw, filtered

def markMuscleArtifacts(raw, threshold, plot=False):
    """Annotate ``raw`` in place using ``annotate_muscle_zscore``.

    Args:
        raw: MNE raw object to analyse; the resulting annotations are set on
            it with ``set_annotations``.
        threshold: z-score threshold passed to ``annotate_muscle_zscore``.
        plot: When true, plot the z-scores of the first 20 seconds together
            with a horizontal line at ``threshold``.
    """
    print("markMuscleArtifacts")
    annot_muscle, scores_muscle = annotate_muscle_zscore(
        raw, ch_type="eeg", threshold=threshold, min_length_good=0.2,
        filter_freq=[0, 60])
    raw.set_annotations(annot_muscle)

    if plot:
        # Plot the first 20 seconds. The sample count comes from the
        # recording's own sampling rate instead of assuming 512 Hz.
        end = int(raw.info['sfreq'] * 20)
        fig, ax = plt.subplots()
        ax.plot(raw.times[:end], scores_muscle[:end])
        ax.axhline(y=threshold, color='r')
        ax.set(xlabel='time, (s)', ylabel='zscore', title='Muscle activity')
        plt.show()
    
def createEvent(onset, label, sFreq):
    """Build one event row of the form ``[onset * sFreq, 0, label]``.

    Args:
        onset: Event onset; multiplied by ``sFreq`` to give the first entry.
        label: Value stored as the third entry.
        sFreq: Multiplier applied to ``onset``.

    Returns:
        A three-element list ``[onset * sFreq, 0, label]``.
    """
    sample_position = onset * sFreq
    event_row = [sample_position, 0, label]
    return event_row

def addEpochs(processedData, start, file_length_seconds, label):
    """Cut ``processedData`` into fixed-length epochs and store them by label.

    Args:
        processedData: MNE raw object to epoch.
        start: Passed as ``start`` to ``mne.make_fixed_length_events``.
        file_length_seconds: Passed as ``stop`` to
            ``mne.make_fixed_length_events``.
        label: Key of the module-level ``study_epochs`` dict to append to.

    Returns:
        The newly created ``mne.Epochs`` object, which is also appended to
        ``study_epochs[label]``.

    Raises:
        KeyError: If ``label`` is not a key of ``study_epochs``.
    """
    events = mne.make_fixed_length_events(
        processedData,
        start=start,
        stop=file_length_seconds,
        duration=epochDuration,
    )
    # The events carry make_fixed_length_events' default id (1 per the MNE
    # documentation), which the {"event": 1} mapping below selects.
    epochs = mne.Epochs(
        processedData,
        events,
        tmin=0,
        tmax=epochDuration,
        event_id={"event": 1},
        baseline=(0, 0),
    )
    study_epochs[label].append(epochs)
    return epochs

def process(file):
    """Load one recording, epoch its filtered signal and return both.

    Args:
        file: ``pathlib.Path`` of the recording; the name of its parent
            folder is used as the label under which the epochs are stored.

    Returns:
        A ``(filtered, epochs)`` tuple from ``csvToRaw`` and ``addEpochs``.
    """
    _, filtered = csvToRaw(file)

    # Whole seconds of data in the recording (any partial second is dropped).
    n_samples = len(filtered.times)
    sampling_rate = float(filtered.info['sfreq'])
    file_length = math.floor(n_samples / sampling_rate)

    label = file.parent.stem
    # Events start at 1.0 s and run up to the whole-second file length.
    epochs = addEpochs(filtered, 1.0, file_length, label)
    return filtered, epochs

def getRandomFile(data_list):
    """Pick one path at random and derive its label from the parent folder.

    Args:
        data_list: Non-empty sequence of ``pathlib.Path`` objects.

    Returns:
        A ``(path, label)`` tuple where ``label`` is ``path.parent.stem``.
    """
    chosen = random.choice(data_list)
    return chosen, chosen.parent.stem
    

def dropBadEpochs(epochs):
    """Call ``epochs.drop_bad`` with fixed EEG reject and flat thresholds.

    Args:
        epochs: Object exposing ``drop_bad(reject=..., flat=...)``; it is
            modified by that call and nothing is returned here.
    """
    epochs.drop_bad(
        reject={"eeg": 150e-6},  # 150 µV
        flat={"eeg": 1e-6},  # 1 µV
    )

def runICA(epochs, title):
    """Fit an ICA decomposition on ``epochs`` and plot its components.

    Args:
        epochs: Data passed to ``ICA.fit`` and, as ``inst``, to
            ``plot_components``.
        title: Title passed to ``plot_components``.
    """
    print("Running ICA")

    ica = mne.preprocessing.ICA(
        n_components=0.9,  # Should normally be higher, like 0.999!!
        method='picard',
        fit_params={"fastica_it": 5},
        random_state=42,  # fixed seed passed to ICA
    )
    ica.fit(epochs)
    ica.plot_components(inst=epochs, title=title)

    print("ICA Done")

def rawToEpochs():
    """Epoch every training recording, drop bad epochs and save per group.

    Every ``*.asc`` file matching ``train/*/*.asc`` under ``raw_eeg_path`` is
    run through ``process``, which appends its epochs to ``study_epochs``.
    The 'asd' and 'td' lists are then concatenated, cleaned with
    ``dropBadEpochs`` and written to ``out_data/``.

    Raises:
        FileNotFoundError: If no matching ``.asc`` file is found.
    """
    data_path_list = list(raw_eeg_path.glob("train/*/*.asc"))
    if not data_path_list:
        raise FileNotFoundError(
            f"No .asc files found under {raw_eeg_path / 'train'}"
        )

    # NOTE(review): process() appends to the module-level study_epochs, which
    # is not cleared here, so calling this twice in one session would
    # presumably accumulate duplicate epochs - confirm the intended usage.
    for file in data_path_list:
        process(file)
    count = len(data_path_list)

    asd_concat_epochs = mne.concatenate_epochs(study_epochs['asd'])
    td_concat_epochs = mne.concatenate_epochs(study_epochs['td'])
    print(f"{count} files processed.")
    print(f"{len(study_epochs['asd'])} asd epoch objects")
    print(f"{len(study_epochs['td'])} td epoch objects")

    # Drop Bad Epochs
    dropBadEpochs(asd_concat_epochs)
    dropBadEpochs(td_concat_epochs)

    # Make sure the output folder exists before saving into it.
    out_dir = Path('out_data')
    out_dir.mkdir(parents=True, exist_ok=True)
    asd_concat_epochs.save(out_dir / 'asd_concat_1_40_hz_bad_dropped_epo.fif', overwrite=True)
    td_concat_epochs.save(out_dir / 'td_concat_1_40hz_bad_dropped_epo.fif', overwrite=True)

def icaClean():
    """Load the saved ASD and TD epoch files and run ICA on the ASD epochs.

    Both files are read from ``out_data/``; only the ASD epochs are passed to
    ``runICA``. The TD epochs are loaded but not processed further here.
    """
    out_dir = Path('out_data')
    asd_file = out_dir / 'asd_concat_1_40_hz_bad_dropped_epo.fif'
    td_file = out_dir / 'td_concat_1_40hz_bad_dropped_epo.fif'

    asd_epochs = mne.read_epochs(asd_file)
    td_epochs = mne.read_epochs(td_file)

    runICA(asd_epochs, "ASD")

# Run the ICA step; icaClean() reads the epoch files previously saved under out_data/.
icaClean()

# main()


#filtered.plot()
#epochs.plot()

# epochs.drop_bad(reject=reject_criteria, flat=flat_criteria)

#epochs.plot_drop_log()

# epochs['event'].plot_image()

#epochs.plot_sensors(ch_type='all',title="sensors")


#asd_concat_epochs.save(Path('out_data') / 'asd_concat_epochs.fif', overwrite=True)
#td_concat_epochs.save(Path('out_data') / 'td_concat_epochs.fif', overwrite=True)

Running ICA


  ica.fit(epochs)


In [None]:

#walkThroughDir(data_path)


#print(asd_concat_epochs)

#p.plot(title="Power")

'''
for epochs in study_epochs['asd']:
    #epochs = study_epochs[i]
    p = epochs.compute_psd()
    p.plot()
'''

#asd_concat_epochs.plot(title="ASD")
#td_concat_epochs.plot(title="TD")

#print(asd_concat_epochs)


#asd_concat_epochs.plot_psd_topomap(ch_type='eeg')

#print(td_concat_epochs)

'''
power_asd_concat_epochs = asd_concat_epochs.compute_psd()
power_asd_concat_epochs.plot(title="power_asd_concat_epochs")

power_td_concat_epochs = td_concat_epochs.compute_psd()
power_td_concat_epochs.plot(title="power_td_concat_epochs")
'''