In [None]:
from glob import glob
import os
import mne
from mne.preprocessing import annotate_muscle_zscore
from IPython.display import display, HTML
from pathlib import Path  

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
#import torch
import seaborn as sns
sns.set_theme()
mne.set_log_level("critical")


asd_file_path=sorted(glob('data/asd/*.set'))
td_file_path=sorted(glob('data/td/*.set'))

asd_files = [i for i in asd_file_path]
td_files = [i for i in td_file_path]

In [None]:
# Find bad spans of data using mne.preprocessing.annotate_muscle_zscore
def markMuscleArtifacts(raw, threshold, plot=False):
    #print("markMuscleArtifacts")
    threshold_muscle = threshold  # z-score
    annot_muscle, scores_muscle = annotate_muscle_zscore(
    raw, ch_type="eeg", threshold=threshold_muscle, min_length_good=0.2,
    filter_freq=[0, 60])
    raw.set_annotations(annot_muscle)

# Drop epochs marked as bad
def dropBadEpochs(epochs, plotLog=False):
    reject_criteria = dict(eeg=150e-6) # 150 µV
    flat_criteria = dict(eeg=1e-6) # 1 µV
    epochs.drop_bad(reject=reject_criteria, flat=flat_criteria)
    if plotLog: epochs.plot_drop_log()

# Get raw data and return filtered version
def get_filtered_data(file_path, l_freq = 0.5, h_freq = 40, muscle_thresh = 5):
    print("loading", file_path)

    # Load raw data fronm file
    raw = mne.io.read_raw_eeglab(file_path,preload=True)

    # Mark bad data
    markMuscleArtifacts(raw, muscle_thresh)
  
    # Create even length epochs
    epochs = mne.make_fixed_length_epochs(raw,duration=1,overlap=0, preload=True)

    # Drop Bad Epochs
    dropBadEpochs(epochs, False)

    #print(epochs)

    # Apply filter
    filtered = epochs.copy().filter(l_freq=l_freq, h_freq=h_freq)

    return filtered

def checkLeftNumbers(channel):
  return (("1" in channel) or ("3" in channel) or ("5" in channel) or ("7" in channel) or ("9" in channel)) and ("10" not in channel)

def checkRightNumbers(channel):
    return ("2" in channel) or ("4" in channel) or ("6" in channel) or ("8" in channel) or ("10" in channel)

def getFrontalChannels(raw):
  _channels = [i for i in raw.ch_names if "F" in i and "C" not in i]
  return raw.copy().pick_channels(_channels)

def getCentralChannels(raw):
  _channels = [i for i in raw.ch_names if "C" in i and "F" not in i]
  return raw.copy().pick_channels(_channels)

def getPosteriorChannels(raw):
  _channels = [i for i in raw.ch_names if ("P" in i or "O" in i) and "C" not in i]
  return raw.copy().pick_channels(_channels) 

def getLeftChannels(raw):
  _channels = [i for i in raw.ch_names if checkLeftNumbers(i)]
  return raw.copy().pick_channels(_channels) 

def getRightChannels(raw):
  _channels = [i for i in raw.ch_names if checkRightNumbers(i)]
  return raw.copy().pick_channels(_channels)

def getPSD(epochs, fmax = 40):
   return epochs.compute_psd(fmax=fmax)

def split(file):
  raw = get_filtered_data(file)
  frontal = getFrontalChannels(raw)
  central = getCentralChannels(raw)
  posterior = getPosteriorChannels(raw)
  left = getLeftChannels(raw)
  right = getRightChannels(raw)
  return [ frontal, central,  posterior,  left, right, raw]

In [185]:
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GroupKFold,GridSearchCV,cross_val_score,cross_validate 

# ML Approach
debug_asd_file_path=sorted(glob('data_debug/asd/*.set'))
debug_td_file_path=sorted(glob('data_debug/td/*.set'))

def extractFetures(psd):
    print(psd.shape)
    features = []
    for epoch in psd:
        #print(epoch.shape)
        alpha = np.mean(scaleEEGPower(epoch[:, 8:12]), axis=(1,0)) #get mean alpha value for this epoch across all channels
        delta = np.mean(scaleEEGPower(epoch[:, 0:4]), axis=(1,0))
        theta = np.mean(scaleEEGPower(epoch[:, 4:8]), axis=(1,0))
        features.append([alpha])
        #print(features)
    return features
    
def handlePSD(raw):
    psd = getPSD(raw)
    psd_np = psd.get_data() 
    features = extractFetures(psd_np)
    return features

def getFrontalPSDFeatures(file):
    filtered_data = get_filtered_data(file)
    frontal = getPosteriorChannels(filtered_data)
    return handlePSD(frontal)

def getPosteriorPSDFeatures(file):
    filtered_data = get_filtered_data(file)
    frontal = getFrontalChannels(filtered_data)
    return handlePSD(frontal)

def getFeatureLabels(extractor, control_file_path, experiment_file_path):
    control_epochs = [extractor(subject) for subject in control_file_path] # subject, epochs, features(2) 
    experiment_epochs = [extractor(subject) for subject in experiment_file_path] # subject, epochs, features(2)
    control_epochs_labels = [len(i) * [0] for i in control_epochs]
    experiment_epochs_labels = [len(i) * [1] for i in experiment_epochs ]
    data_list = control_epochs + experiment_epochs
    label_list = control_epochs_labels + experiment_epochs_labels
    groups_list = [[i]*len(j) for i, j in enumerate(data_list)]
    data_array = np.vstack(data_list)
    label_array = np.hstack(label_list)
    group_array = np.hstack(groups_list)
    return [data_array, label_array, group_array] 

def evalLogisticRegressionML(data_array, label_array, group_array, splits=5):     
    clf=LogisticRegression()
    gkf=GroupKFold(n_splits=splits)
    param_grid = {'classifier__C': [0.01,0.05,0.1,0.5, 1,2,3,4,5,8, 10,12,15]}
    pipe=Pipeline([('scaler',StandardScaler()),('classifier',clf)])
    gscv=GridSearchCV(pipe, param_grid, cv=gkf, n_jobs=16)
    gscv.fit(data_array, label_array, groups=group_array)
    print(gscv.best_score_)
    return gscv.best_score_

#data_array, label_array, group_array = getFeatureLabels(getFrontalPSDFeatures, debug_td_file_path, debug_asd_file_path)
#evalLogisticRegressionML(data_array, label_array, group_array, 5)
data_array, label_array, group_array = getFeatureLabels(getPosteriorPSDFeatures, debug_td_file_path, debug_asd_file_path)
evalLogisticRegressionML(data_array, label_array, group_array, 2)

print(data_array.shape, label_array.shape, group_array.shape)

#print(len(control_epochs), len(control_epochs[0]))
#print("data list", data_list)
#print(groups_list)
#print(group_array)
#print(len(control_epochs[0]), control_epochs[0])

'''

'''

loading data_debug/td/24Abby_Resting.set
(59, 19, 41)
loading data_debug/td/30Abby_Resting.set
(110, 54, 41)
loading data_debug/asd/1Abby_Resting.set
(103, 18, 41)
loading data_debug/asd/2Abby_Resting.set
(138, 16, 41)
0.2570124639546245
(410, 1) (410,) (410,)


'\n\n'

In [None]:

# Prepare Data Frame for plotting PSD

df = pd.DataFrame({})
regions_id = ["frontal", "central", "posterior", "left", "right"]

groups = [{"data": asd_files, "group_id": "ASD"}, {"data": td_files, "group_id": "TD"}]

# Find a better way to format this at some point. Nested loops are very slow
# Prepare data frame to plot PSD comparison
for type_id, group in enumerate(groups):
    for subject_id, subject in enumerate(group['data']):
        print(subject_id/len(group['data']))
        frontal, central, posterior, left, right, all = split(subject)
        regions = [frontal, central, posterior, left, right]
        for id_region, region in enumerate(regions):
            psd = getPSD(region)
            psd_np = psd.get_data()
            _psd = np.mean(psd_np, axis=(0))
            for id_channel, channel in enumerate(_psd):
                for id_freq, sample in enumerate(channel):
                    # Data type, Subject ID, Brain Region, Channel Name, Freq Bin, Sample
                    #print(data_type, subject_id, brain_region, frontal.ch_names[idx], idj, sample)
                    new_row = pd.DataFrame.from_records([{
                        "Type":group["group_id"],  
                        "Subject": subject_id,
                        "Region": regions_id[id_region],
                        "Channel": region.ch_names[id_channel],
                        "Freq": id_freq,
                        "Sample": sample
                    }])
                    df = pd.concat([df, new_row])
                    
filepath = Path('data/asd_td_3.csv')  
filepath.parent.mkdir(parents=True, exist_ok=True)
df.to_csv(filepath, index=False)   


data/asd/10Abby_Resting.set
data/asd/11Abby_Resting.set
data/asd/12Abby_Resting.set
data/asd/13Abby_Resting.set
data/asd/14Abby_Resting.set
data/asd/15Abby_Resting.set
data/asd/16Abby_Resting.set
data/asd/17Abby_Resting.set
data/asd/18Abby_Resting.set
data/asd/19Abby_Resting.set
data/asd/1Abby_Resting.set
data/asd/20Abby_Resting.set
data/asd/21Abby_Resting.set
data/asd/22Abby_Resting.set
data/asd/23Abby_Resting.set
data/asd/25Abby_Resting.set
data/asd/26Abby_Resting.set
data/asd/27Abby_Resting.set
data/asd/28Abby_Resting.set
data/asd/29Abby_Resting.set
data/asd/2Abby_Resting.set
data/asd/3Abby_Resting.set
data/asd/4Abby_Resting.set
data/asd/5Abby_Resting.set
data/asd/6Abby_Resting.set
data/asd/7Abby_Resting.set
data/asd/8Abby_Resting.set
data/asd/9Abby_Resting.set


In [79]:
# Creating Feature Structure


def getTotalSubjects(groups):
    subject_count = 0
    for i in groups:
        subject_count += len(i["data"])
    return subject_count

def extractFetures(signal):
    delta = np.mean(signal[:4])
    theta = np.mean(signal[4:8])
    alpha = np.mean(signal[8:12])
    beta = np.mean(signal[13:30])
    gamma = np.mean(signal[30:40])
    return [gamma]

# https://mne.discourse.group/t/psd-multitaper-output-conversion-to-db-welch/4445
def scaleEEGPower(powerArray):
    powerArray = powerArray * 1e6**2 
    powerArray = (10 * np.log10(powerArray))
    return powerArray

def extractRegionFeatures(epoch, ranges=[[1,4], [8,12]]):
    psd = getPSD(epoch)
    epochs_features = []
    psd_np = psd.get_data()

    for i in ranges:
        #print(i)
        band = scaleEEGPower(psd_np[:, :, i[0]:i[1]])
        band_channel_mean = np.mean(band, axis=(1, 2)) # average across channels
        #print(band_channel_mean)
        #print(band_channel_mean.shape)
        epochs_features.append(band_channel_mean)
        #alpha_reshape = alpha.reshape(95, -1)
        #print(alpha_reshape.shape)
        #print(alpha_reshape)
    return np.stack(epochs_features, axis=1)

regions_id = ["frontal", "central", "posterior", "left", "right"]
groups = [{"data": asd_files, "group_id": "ASD"}, {"data": td_files, "group_id": "TD"}]
frontal, central, posterior, left, right, all = split(groups[0]['data'][0])
featureCount = 2
regions = [frontal, central, posterior, left, right]
regionCount = len(regions)
epochs = len(frontal)
X_2d = np.empty([epochs, regionCount, featureCount], dtype=float)
print("X_2d", X_2d.shape)



def main():
    x = []
    for region_id, region_epoch in enumerate(regions):
        x.append(extractRegionFeatures(region_epoch))
    t = np.stack(x, axis=1)
    print(t.shape)
    print(t)

main()




loading data/asd/10Abby_Resting.set
X_2d (95, 5, 2)
(95, 5, 2)
[[[29.41869241 23.55221255]
  [29.59297324 21.5986365 ]
  [30.48340109 23.15271954]
  [29.99696621 22.0311355 ]
  [29.43130665 23.47489277]]

 [[32.38116904 25.9677469 ]
  [33.10560248 22.83242218]
  [33.57315291 24.09642152]
  [33.79233419 24.88147251]
  [32.03962241 23.84113524]]

 [[30.05619125 24.80275104]
  [28.81464362 23.55476347]
  [29.47758947 25.49498887]
  [30.35370728 24.4386963 ]
  [29.31396835 25.0349984 ]]

 [[31.93598392 28.86114097]
  [32.58539139 24.95787665]
  [32.32100977 30.86075529]
  [31.95000951 27.73653123]
  [32.62552974 29.37728294]]

 [[28.95232229 27.71923735]
  [27.68797821 23.89078334]
  [28.95967624 28.14673871]
  [27.8110229  26.517613  ]
  [29.45523694 27.03491202]]

 [[31.74908212 29.83622613]
  [30.51753159 26.86302393]
  [30.63977491 30.50776835]
  [31.84960462 28.33927623]
  [30.00468652 30.23378131]]

 [[28.51191847 29.03023241]
  [29.94504892 25.15732632]
  [30.10383585 28.86163993]
 

In [None]:
# https://pandas.pydata.org/docs/getting_started/intro_tutorials/03_subset_data.html
#df = pd.read_csv(filepath) 
#display(df)
#df = df[(df["Freq"] < 12) & (df["Freq"] > 8)]display(df)


# Type, Subject, Region, Channel, Band, Value
#sns.lineplot(data=df, x="Freq", y="Sample", hue="Type")
#plt.show()

In [None]:
# Misc

print(len(asd_files))
print(len(td_files))

def channel_in_file(channel, file_path):
    #print("loading", file_path)
    datax=mne.io.read_raw_eeglab(file_path,preload=True)
    return channel in datax.ch_names

def checkChannels(channels, file_path):
  results = []
  for idx, channel in enumerate(channels):
    results.append(0)
    for file in file_path: 
      print(file)
      results[idx] += channel_in_file(channel, file)
      print(results)


#ASD
#channels = np.array(['Fp1', 'AF7', 'AF3', 'F1', 'F3', 'F5', 'F7', 'FT7', 'FC5', 'FC3', 'FC1', 'C1', 'C3', 'C5', 'T7', 'TP7', 'CP5', 'CP3', 'CP1', 'P1', 'P3', 'P5', 'P7', 'P9', 'PO7', 'PO3', 'O1', 'Iz', 'Oz', 'POz', 'Pz', 'CPz', 'Fpz', 'Fp2', 'AF8', 'AF4', 'AFz', 'Fz', 'F2', 'F4', 'F6', 'F8', 'FT8', 'FC6', 'FC4', 'FC2', 'FCz', 'C2', 'C4', 'C6', 'T8', 'TP8', 'CP6', 'CP4', 'CP2', 'P2', 'P4', 'P6', 'P8', 'P10', 'PO8', 'PO4', 'O2'])
#channel_count = np.array([27, 27, 26, 25, 25, 24, 27, 24, 26, 23, 26, 23, 26, 26, 18, 23, 23, 25, 25, 27, 25, 26, 23, 27, 26, 26, 27, 24, 25, 27, 26, 25, 26, 27, 25, 24, 24, 26, 26, 22, 24, 25, 22, 21, 24, 24, 25, 23, 24, 22, 20, 23, 26, 26, 24, 27, 24, 24, 26, 27, 25, 25, 27])
#greater_25 = ['Fp1', 'AF7', 'AF3', 'F7', 'FC5', 'FC1', 'C3', 'C5', 'P1', 'P5', 'P9', 'PO7', 'PO3', 'O1', 'POz', 'Pz', 'Fpz', 'Fp2', 'Fz', 'F2', 'CP6', 'CP4', 'P2', 'P8', 'P10', 'O2']
#output = channels[channel_count > 25]
#print(output)
#checkChannels(channels, asd_file_path)

def read_data(file_path):
    pass
    #print("loading", file_path)
    #raw = mne.io.read_raw_eeglab(file_path,preload=True)
    #return raw
    #epochs=mne.make_fixed_length_epochs(datax,duration=1,overlap=0, preload=True)
    #print("original ", epochs.ch_names)
    #epochs.pick_channels(['Fp1', 'Fpz', 'Fp2', 'O1', 'Oz', 'O2', 'Fz', 'FCz'])

    '''
    datax.set_eeg_reference()
    datax.filter(l_freq=1,h_freq=45)
    epochs=mne.make_fixed_length_epochs(datax,duration=1,overlap=0)
    epochs_np=epochs.get_data()
    return epochs, epochs_np #trials,channel,length
    '''