In [1]:
import mne
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

In [2]:
# read one sample data
mne_filename = "mne_data/C10_32Ch_48Subjects_al_raw_eeg.fif"
raw = mne.io.read_raw_fif(mne_filename, verbose=False)
events = mne.make_fixed_length_events(raw, start=0, stop=20, duration=2., overlap=1.5, id = 1)

In [3]:
# make epochs
event_id = dict(epoch=1)  # event trigger and conditions
tmin = 0  # start of each epoch
tmax = 2  # end of each epoch
baseline = (0, 0)  # means from the first instant to t = 0
reject = dict(grad=4000e-13, mag=4e-12, eog=150e-6)
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True,
                    picks="all", baseline=baseline, verbose=False)

In [4]:
evoked = epochs.average()

In [5]:
epo_spectrum = evoked.compute_psd(method="welch")
psds, freqs = epo_spectrum.get_data(return_freqs=True)

Effective window size : 1.000 (s)


In [16]:
fmin = (0., 4., 8., 13., 30.)
fmax = (4., 8., 13., 30., 45.)
freqs = ["delta", "theta", "alpha", "beta", "gamma"]

In [25]:
psds[:, 4:8].mean(axis=1)

array([0.0109173 , 0.0127085 , 0.01569351, 0.01944721, 0.03386117,
       0.00346549, 0.01177934, 0.00791516, 0.00550541, 0.00826978,
       0.02057691, 0.02434618, 0.00400109, 0.06045927, 0.02189855,
       0.05764478, 0.01243475, 0.03420854, 0.02491956, 0.01569328,
       0.00966183, 0.00526133, 0.01602709, 0.01924308, 0.00912095,
       0.0113318 , 0.00494964, 0.02002104, 0.00783928, 0.01423841,
       0.02412435, 0.0175747 ])

# Full pipeline

In [27]:
mne_data = pd.read_csv("mne_data.csv")
mne_data

Unnamed: 0,pid,al,fa
0,1,mne_data/C1_32Ch_48Subjects_al_raw_eeg.fif,mne_data/C1_32Ch_48Subjects_fa_raw_eeg.fif
1,2,mne_data/C2_32Ch_48Subjects_al_raw_eeg.fif,mne_data/C2_32Ch_48Subjects_fa_raw_eeg.fif
2,3,mne_data/C3_32Ch_48Subjects_al_raw_eeg.fif,mne_data/C3_32Ch_48Subjects_fa_raw_eeg.fif
3,4,mne_data/C4_32Ch_48Subjects_al_raw_eeg.fif,mne_data/C4_32Ch_48Subjects_fa_raw_eeg.fif
4,5,mne_data/C5_32Ch_48Subjects_al_raw_eeg.fif,mne_data/C5_32Ch_48Subjects_fa_raw_eeg.fif
5,6,mne_data/C6_32Ch_48Subjects_al_raw_eeg.fif,mne_data/C6_32Ch_48Subjects_fa_raw_eeg.fif
6,7,mne_data/C7_32Ch_48Subjects_al_raw_eeg.fif,mne_data/C7_32Ch_48Subjects_fa_raw_eeg.fif
7,8,mne_data/C8_32Ch_48Subjects_al_raw_eeg.fif,mne_data/C8_32Ch_48Subjects_fa_raw_eeg.fif
8,9,mne_data/C9_32Ch_48Subjects_al_raw_eeg.fif,mne_data/C9_32Ch_48Subjects_fa_raw_eeg.fif
9,10,mne_data/C10_32Ch_48Subjects_al_raw_eeg.fif,mne_data/C10_32Ch_48Subjects_fa_raw_eeg.fif


In [28]:
mne_files = mne_data[["al", "fa"]].values.reshape(-1)

In [37]:
X = []
Y = []

for mne_filename in mne_files:
    print(f"Running: {mne_filename}")
    # read file and make events
    p_id = mne_filename[10:12].replace("_","")
    if "_al_" in mne_filename: 
        condition = "al"
    elif "_fa_" in mne_filename:
        condition = "fa"
    raw = mne.io.read_raw_fif(mne_filename, verbose=False)
    events = mne.make_fixed_length_events(raw, start=0, stop=20, duration=2., overlap=1.5, id = 1)
    # make epochs
    event_id = dict(epoch=1)  # event trigger and conditions
    tmin = 0  # start of each epoch
    tmax = 2  # end of each epoch
    baseline = (0, 0)  # means from the first instant to t = 0
    reject = dict(grad=4000e-13, mag=4e-12, eog=150e-6)
    epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True,
                        picks="all", baseline=baseline, verbose=False)
    # compute psd
    evoked = epochs.average()
    epo_spectrum = evoked.compute_psd(method="welch")
    psds, freqs = epo_spectrum.get_data(return_freqs=True)
    fmin = (0., 4., 8., 13., 30.)
    fmax = (4., 8., 13., 30., 45.)
    freqs = ["delta", "theta", "alpha", "beta", "gamma"]

    # get the mean of the psd for each frequency band
    pds_band = []
    for i, (fmin, fmax) in enumerate(zip(fmin, fmax)):
        pds_band.append(psds[:, int(fmin):int(fmax)].mean(axis=1))
    
    # append to X and Y
    X.append(pds_band)
    Y.append(condition)

X = np.array(X)
Y = np.array([0 if y == "al" else 1 for y in Y])

Running: mne_data/C1_32Ch_48Subjects_al_raw_eeg.fif
Effective window size : 1.000 (s)
Running: mne_data/C1_32Ch_48Subjects_fa_raw_eeg.fif
Effective window size : 1.000 (s)
Running: mne_data/C2_32Ch_48Subjects_al_raw_eeg.fif
Effective window size : 1.000 (s)
Running: mne_data/C2_32Ch_48Subjects_fa_raw_eeg.fif
Effective window size : 1.000 (s)
Running: mne_data/C3_32Ch_48Subjects_al_raw_eeg.fif
Effective window size : 1.000 (s)
Running: mne_data/C3_32Ch_48Subjects_fa_raw_eeg.fif
Effective window size : 1.000 (s)
Running: mne_data/C4_32Ch_48Subjects_al_raw_eeg.fif
Effective window size : 1.000 (s)
Running: mne_data/C4_32Ch_48Subjects_fa_raw_eeg.fif
Effective window size : 1.000 (s)
Running: mne_data/C5_32Ch_48Subjects_al_raw_eeg.fif
Effective window size : 1.000 (s)
Running: mne_data/C5_32Ch_48Subjects_fa_raw_eeg.fif
Effective window size : 1.000 (s)
Running: mne_data/C6_32Ch_48Subjects_al_raw_eeg.fif
Effective window size : 1.000 (s)
Running: mne_data/C6_32Ch_48Subjects_fa_raw_eeg.fif
Ef

# classification

In [42]:
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, GridSearchCV

from sklearn.metrics import accuracy_score, classification_report
from airtable.airtable import AirTableClient
from airtable.config import config
from pyairtable import Table

In [43]:
# logging to airtable
atc = AirTableClient(**config["psd"])

In [39]:
models = {
    "svm": {
        "model": SVC(kernel='linear', C=1),
        "params": {
            "C": [0.1, 1, 10, 100, 1000],
            "kernel": ['linear', 'rbf']
        }
    },
    "logistic": {
        "model": LogisticRegression(solver='liblinear', multi_class='auto'),
        "params": {
            "C": [0.1, 1, 10, 100, 1000],
            "solver": ['liblinear', 'lbfgs']
        }
    },
    "random_forest": {
        "model": RandomForestClassifier(),
        "params": {
            "n_estimators": [100, 200],
            "max_features": ['sqrt', 'log2'],
            "max_depth": [4, 5, 6],
            "criterion": ['gini', 'entropy']
        }
    },
    "decision_tree": {
        "model": DecisionTreeClassifier(),
        "params": {
            "criterion": ["gini", "entropy"],
            "splitter": ["best", "random"],
            "max_depth": [2, 3, 5],
            "min_samples_split": [2, 3, 5],
            "min_samples_leaf": [2, 3, 5]
        }
    },
}

In [41]:
def run_grid_search(model_name, model, params, X, y):
    """
    run grid search for a model
    """
    grid_search = GridSearchCV(model, params, cv=5, return_train_score=False)
    grid_search.fit(X, y)
    # print("Best params: ", grid_search.best_params_)
    # print("Best score: ", grid_search.best_score_)
    return grid_search.best_params_, grid_search.best_score_

In [44]:
for i, freq in enumerate(freqs):
    X_freq = X[:, i, :]
    X_train, X_test, Y_train, Y_test = train_test_split(X_freq, Y, test_size=0.5, random_state=12)
    for m in models.keys():
        print(f"Processing: model: {m}, freq: {freq}")
        best_params, best_score = run_grid_search(m, models[m]["model"], models[m]["params"], X_freq, Y)
        # run best model
        model = models[m]["model"].set_params(**best_params)
        model.fit(X_train, Y_train)
        Y_pred = model.predict(X_test)
        score = round(accuracy_score(Y_test, Y_pred),2)
        # print classification report
        print(classification_report(Y_test, Y_pred))
        
        full_classification_report = classification_report(Y_test, Y_pred)

        res = {
            "frequency": freq,
            "model": m,
            "best_params": str(best_params),
            "accuracy": str(score),
            "full accuracy report": full_classification_report
        }
        atc.add_row(res)


Processing: model: svm, freq: delta
              precision    recall  f1-score   support

           0       0.70      0.59      0.64        27
           1       0.56      0.67      0.61        21

    accuracy                           0.62        48
   macro avg       0.63      0.63      0.62        48
weighted avg       0.64      0.62      0.63        48

Processing: model: logistic, freq: delta


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logist

              precision    recall  f1-score   support

           0       0.00      0.00      0.00        27
           1       0.44      1.00      0.61        21

    accuracy                           0.44        48
   macro avg       0.22      0.50      0.30        48
weighted avg       0.19      0.44      0.27        48

Processing: model: random_forest, freq: delta
              precision    recall  f1-score   support

           0       0.61      0.41      0.49        27
           1       0.47      0.67      0.55        21

    accuracy                           0.52        48
   macro avg       0.54      0.54      0.52        48
weighted avg       0.55      0.52      0.52        48

Processing: model: decision_tree, freq: delta
              precision    recall  f1-score   support

           0       0.41      0.26      0.32        27
           1       0.35      0.52      0.42        21

    accuracy                           0.38        48
   macro avg       0.38      0.39   

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Processing: model: logistic, freq: theta


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.00      0.00      0.00        27
           1       0.44      1.00      0.61        21

    accuracy                           0.44        48
   macro avg       0.22      0.50      0.30        48
weighted avg       0.19      0.44      0.27        48

Processing: model: random_forest, freq: theta
              precision    recall  f1-score   support

           0       0.50      0.33      0.40        27
           1       0.40      0.57      0.47        21

    accuracy                           0.44        48
   macro avg       0.45      0.45      0.44        48
weighted avg       0.46      0.44      0.43        48

Processing: model: decision_tree, freq: theta
              precision    recall  f1-score   support

           0       0.43      0.22      0.29        27
           1       0.38      0.62      0.47        21

    accuracy                           0.40        48
   macro avg       0.41      0.42   

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logist

              precision    recall  f1-score   support

           0       0.73      0.59      0.65        27
           1       0.58      0.71      0.64        21

    accuracy                           0.65        48
   macro avg       0.65      0.65      0.65        48
weighted avg       0.66      0.65      0.65        48

Processing: model: random_forest, freq: alpha
              precision    recall  f1-score   support

           0       0.81      0.63      0.71        27
           1       0.63      0.81      0.71        21

    accuracy                           0.71        48
   macro avg       0.72      0.72      0.71        48
weighted avg       0.73      0.71      0.71        48

Processing: model: decision_tree, freq: alpha
              precision    recall  f1-score   support

           0       0.74      0.52      0.61        27
           1       0.55      0.76      0.64        21

    accuracy                           0.62        48
   macro avg       0.64      0.64   

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Processing: model: random_forest, freq: beta
              precision    recall  f1-score   support

           0       0.65      0.41      0.50        27
           1       0.48      0.71      0.58        21

    accuracy                           0.54        48
   macro avg       0.57      0.56      0.54        48
weighted avg       0.58      0.54      0.53        48

Processing: model: decision_tree, freq: beta
              precision    recall  f1-score   support

           0       0.82      0.33      0.47        27
           1       0.51      0.90      0.66        21

    accuracy                           0.58        48
   macro avg       0.67      0.62      0.56        48
weighted avg       0.68      0.58      0.55        48

Processing: model: svm, freq: gamma
              precision    recall  f1-score   support

           0       0.50      0.26      0.34        27
           1       0.41      0.67      0.51        21

    accuracy                           0.44        48
  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Processing: model: random_forest, freq: gamma
              precision    recall  f1-score   support

           0       0.57      0.30      0.39        27
           1       0.44      0.71      0.55        21

    accuracy                           0.48        48
   macro avg       0.51      0.51      0.47        48
weighted avg       0.51      0.48      0.46        48

Processing: model: decision_tree, freq: gamma
              precision    recall  f1-score   support

           0       0.45      0.33      0.38        27
           1       0.36      0.48      0.41        21

    accuracy                           0.40        48
   macro avg       0.40      0.40      0.40        48
weighted avg       0.41      0.40      0.39        48

