In [2]:
import mne
mne.set_log_level('CRITICAL')

#import torch
import os
from pathlib import Path 
import logging
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from mne.preprocessing import annotate_muscle_zscore
from mne.io import concatenate_raws, read_raw_edf
from mne.time_frequency import tfr_morlet
import math
from mne.preprocessing import (create_eog_epochs, create_ecg_epochs,
                               corrmap)
#matplotlib.use('Qt5Agg')
#mne.set_log_level('warning')
print("eeg")

# Dataset root; the raw recordings are read from its "raw_eeg" subfolder.
data_path = Path("data/")
raw_eeg_path = data_path.joinpath("raw_eeg")

# Integer event code used for each group label.
label_ids = dict(asd=1, td=2)

# Per-label lists of cleaned epochs; addEpochs appends to these.
study_epochs = {group: [] for group in ("td", "asd")}

# Assumes file is formatted as {ID}_{type}_{XXXHz}.{extension}. Example TD100_raw_512Hz.asc
def csvToRaw(file, fMax=40):
    """Load a tab-separated EEG export into MNE and band-pass filter it.

    Parameters
    ----------
    file : pathlib.Path
        Recording to load. The third underscore-separated field of the file
        stem must be the sampling rate, e.g. ``TD100_raw_512Hz.asc``.
    fMax : float
        Upper edge (Hz) of the band-pass filter; the lower edge is 1 Hz.

    Returns
    -------
    raw : mne.io.RawArray
        Unfiltered recording, scaled to volts, annotated by
        ``markMuscleArtifacts``.
    filtered : mne.io.RawArray
        Band-pass filtered copy of ``raw``.
    """
    print(f"Processing {file.stem}")
    data = pd.read_csv(file, sep='\t')

    # Drop every non-EEG column that is present. Each column is handled
    # independently so that a file without EOG channels still loses the
    # spurious 'Unnamed: 34' column instead of keeping it as an EEG channel.
    eog_columns = ['VEOG - LkE', 'HEOG - LkE']
    if not any(column in data.columns for column in eog_columns):
        print("No EOG Channels found")
    data = data.drop(columns=eog_columns + ['Unnamed: 34'], errors='ignore')

    # Format channel names: keep the part before the reference suffix
    # (e.g. "Fp1 - LkE" -> "Fp1") and strip spaces.
    channels = [name.split("-")[0].replace(" ", "") for name in data.columns]

    # Load data; MNE expects one row per channel.
    sfreq = int(file.stem.split("_")[2].replace("Hz", ""))
    info = mne.create_info(ch_names=channels, sfreq=sfreq, ch_types="eeg")
    raw = mne.io.RawArray(data.transpose(), info)

    # Format data date for annotations later
    raw.set_meas_date(0)
    raw.set_montage("standard_1020")

    # Convert from uV to V for MNE
    raw.apply_function(lambda x: x * 1e-6)

    # Mark bad data on the unfiltered recording before filtering a copy.
    markMuscleArtifacts(raw, 2)

    filtered = raw.copy().filter(l_freq=1.0, h_freq=fMax)

    return raw, filtered

# Find bad spans of data using mne.preprocessing.annotate_muscle_zscore
def markMuscleArtifacts(raw, threshold, plot=False):
    """Annotate spans of ``raw`` whose muscle z-score exceeds ``threshold``.

    The annotations returned by ``annotate_muscle_zscore`` are set on ``raw``
    in place; nothing is returned.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording to annotate (modified in place).
    threshold : float
        Z-score threshold passed to ``annotate_muscle_zscore``.
    plot : bool
        If True, plot the z-scores of the first 20 seconds with the
        threshold drawn as a horizontal line.
    """
    threshold_muscle = threshold  # z-score
    annot_muscle, scores_muscle = annotate_muscle_zscore(
        raw, ch_type="eeg", threshold=threshold_muscle, min_length_good=0.2,
        filter_freq=[0, 60])
    raw.set_annotations(annot_muscle)

    if plot:
        # Show the first 20 s; take the sample count from the recording's own
        # sampling rate so the window is right for non-512 Hz files too.
        end = int(raw.info['sfreq'] * 20)
        fig, ax = plt.subplots()
        ax.plot(raw.times[:end], scores_muscle[:end])
        ax.axhline(y=threshold_muscle, color='r')
        ax.set(xlabel='time, (s)', ylabel='zscore', title='Muscle activity')
        plt.show()

def plot(raw, start=0, duration=10.0):
    """Plot ``raw`` without scrollbars or scalebars.

    Parameters
    ----------
    raw : object with a ``plot`` method (e.g. ``mne.io.Raw``)
        Recording to display.
    start : float
        Start time of the displayed window, forwarded to ``raw.plot``.
    duration : float
        Length of the displayed window, forwarded to ``raw.plot``.
    """
    # Forward the caller's window instead of a hard-coded one.
    raw.plot(show_scrollbars=False, show_scalebars=False,
             duration=duration, start=start)

# Create epoch data, get ICs, and add to study_epochs object
def addEpochs(data, raw, start, file_length_seconds, label, file="",epochDuration=1):
    """Epoch ``data``, remove EOG-related ICA components, and store the result.

    Fixed-length events are generated from ``start`` to
    ``file_length_seconds``, bad epochs are dropped via ``dropBadEpochs``,
    and an ICA fitted by ``runICA`` is applied with the components flagged by
    ``find_bads_eog`` excluded. The cleaned epochs are appended to
    ``study_epochs[label]`` and also returned.

    ``raw`` is only used to build the EOG epochs; ``file`` is accepted but
    not used.
    """
    event_code = label_ids[label]
    fixed_events = mne.make_fixed_length_events(
        data,
        event_code,
        start=start,
        stop=file_length_seconds,
        duration=epochDuration,
    )
    epochs = mne.Epochs(
        data,
        fixed_events,
        tmin=0,
        tmax=epochDuration,
        event_id={label: event_code},
        baseline=(0, 0),
        preload=True,
    )
    dropBadEpochs(epochs)

    # Fit ICA, then flag the components that track the EOG epochs.
    ica = runICA(epochs)
    blink_epochs = create_eog_epochs(raw, "Fp1")
    bad_components, _scores = ica.find_bads_eog(blink_epochs, "Fp2", threshold=1.25)
    ica.exclude = bad_components

    # Apply to a copy so the uncleaned epochs object is left untouched.
    cleaned = ica.apply(epochs.copy())
    study_epochs[label].append(cleaned)
    return cleaned

# Reject epochs based on maximum acceptable peak-to-peak amplitude 
# https://mne.tools/stable/auto_tutorials/preprocessing/20_rejecting_bad_data.html#sphx-glr-auto-tutorials-preprocessing-20-rejecting-bad-data-py
def dropBadEpochs(epochs, plotLog=False):
    """Drop epochs in place using fixed EEG reject/flat amplitude criteria.

    Set ``plotLog`` to True to also call ``epochs.plot_drop_log()``.
    """
    epochs.drop_bad(
        reject={"eeg": 150e-6},  # 150 µV
        flat={"eeg": 1e-6},  # 1 µV
    )
    if not plotLog:
        return
    epochs.plot_drop_log()

# Get ICs 
def runICA(epochs):
    """Fit an ICA on ``epochs`` with fixed settings and return it."""
    ica_settings = {
        "n_components": 0.99,  # Should normally be higher, like 0.999!!
        # Picard method requires python 3.7; 'fastica' is the alternative.
        "method": 'picard',
        "fit_params": {"fastica_it": 5},
        "random_state": 42,
    }
    decomposition = mne.preprocessing.ICA(**ica_settings)
    decomposition.fit(epochs)
    return decomposition
 
# Process raw file. 
def process(file, epoch_len_seconds):
    """Load one recording, epoch it, and store its cleaned epochs.

    The group label is taken from the name of the file's parent directory
    and must be a key of ``label_ids``. Epoching starts ``epoch_len_seconds``
    into the recording and runs to its length in whole seconds. The cleaned
    epochs are stored by ``addEpochs``; nothing is returned.

    Parameters
    ----------
    file : pathlib.Path
        Recording to process.
    epoch_len_seconds : float
        Passed to ``addEpochs`` as the start time of the first epoch.
    """
    raw, filtered = csvToRaw(file)

    # Length in whole seconds, so a trailing partial second is not epoched.
    file_length = math.floor(len(filtered.times) / float(filtered.info['sfreq']))

    label = file.parent.stem
    print(label)

    # The result is already stored in study_epochs by addEpochs, so no extra
    # copy of the data is made here.
    addEpochs(filtered, raw, epoch_len_seconds, file_length, label, file)
    
    
def main():
    """Process every training recording and save the cleaned epochs.

    Reads all ``train/<label>/*.asc`` files under ``raw_eeg_path``, then
    writes three files to ``out_data/2023``: the concatenated ASD epochs,
    the concatenated TD epochs, and both groups combined with event counts
    equalized. Existing output files are overwritten.
    """
    split = "train"
    data_path_list = list(raw_eeg_path.glob(f"{split}/*/*.asc"))
    epoch_len_seconds = 1.0

    for file in data_path_list:
        print(file.stem)
        process(file, epoch_len_seconds)

    # Create the output folder up front so saving cannot fail on a missing
    # directory after all the processing has been done.
    out_dir = Path('out_data/2023')
    out_dir.mkdir(parents=True, exist_ok=True)

    asd_concat_epochs = mne.concatenate_epochs(study_epochs['asd'])
    td_concat_epochs = mne.concatenate_epochs(study_epochs['td'])
    asd_concat_epochs.save(out_dir / 'asd_concat_cleaned_1_40hz_epo.fif', overwrite=True)
    td_concat_epochs.save(out_dir / 'td_concat_cleaned_1_40hz_epo.fif', overwrite=True)

    train_all_epochs = mne.concatenate_epochs([asd_concat_epochs, td_concat_epochs])
    train_all_epochs.equalize_event_counts(train_all_epochs.event_id)
    # overwrite=True matches the two saves above, so a re-run does not fail
    # on the file left by the previous run.
    train_all_epochs.save(out_dir / 'all_epo_asd_td.fif', overwrite=True)

main()


eeg
TD109_raw_512Hz
Processing TD109_raw_512Hz
td
TD110_raw_512Hz
Processing TD110_raw_512Hz
td
TD107_raw_512Hz
Processing TD107_raw_512Hz
td
TD108_raw_512Hz
Processing TD108_raw_512Hz
td
TD106_raw_512Hz
Processing TD106_raw_512Hz
td
TD113_raw_512Hz
Processing TD113_raw_512Hz
td
TD112_raw_512Hz
Processing TD112_raw_512Hz
td
TD104_raw_512Hz
Processing TD104_raw_512Hz
td
TD105_raw_512Hz
Processing TD105_raw_512Hz
td
TD114_raw_512Hz
Processing TD114_raw_512Hz
td
A109_raw_512Hz
Processing A109_raw_512Hz
asd
A113_raw_512Hz
Processing A113_raw_512Hz
asd
A103_raw_512Hz
Processing A103_raw_512Hz
asd
A101_raw_512Hz
Processing A101_raw_512Hz
asd
A112_raw_512Hz
Processing A112_raw_512Hz
asd
A111_raw_512Hz
Processing A111_raw_512Hz
asd
A104_raw_512Hz
Processing A104_raw_512Hz
asd
A108_raw_512Hz
Processing A108_raw_512Hz
asd
A114_raw_512Hz
Processing A114_raw_512Hz
asd
