In [95]:
from mne.io import RawArray
from mne import create_info
from mne.time_frequency import psd_welch
from mne import Epochs
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import (
    FunctionTransformer,
    StandardScaler,
)
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report



import numpy as np
import joblib
import mne

In [2]:
srate = 250
epoch_len = srate
CHANNELS = ['Cz', 'FC2', 'CP2', 'C4', 'FC6', 'CP6', 'T8']
FREQ_BANDS_RANGE = {
    'DELTA': [0.5, 4.5],
    'THETA': [4.5, 8.5],
    'ALPHA': [8.5, 11.5],
    'SIGMA': [11.5, 15.5],
    'BETA': [15.5, 30]
}
RAW_INFO = create_info(sfreq=srate, ch_types='eeg', ch_names=CHANNELS)

In [21]:
filename = 'dataset/OpenBCI-RAW-2020-10-12_08-41-25_JD_02.txt'
skiprows = 7 + srate*7

# filename = 'dataset/OpenBCI-RAW-2020-10-10_20-58-47-Josquin-Duchaine.txt'
# skiprows = 7

openbci_dataset = np.loadtxt(filename, skiprows=skiprows, delimiter=',', usecols=(1,2,3,4,5,6,7))
openbci_dataset = openbci_dataset.T

In [22]:
print(openbci_dataset.shape)

(7, 193495)


In [23]:
raw_array = RawArray(openbci_dataset, info=RAW_INFO, verbose=False)
# raw_array.notch_filter(60)
raw_array.filter(0.5,30)


<RawArray | 7 x 193495 (774.0 s), ~10.4 MB, data loaded>

In [37]:
raw_array.n_times

193495

In [24]:
events = np.arange(raw_array.n_times)
result = np.zeros((events.shape[0], 3))
result[:,0] = events
result = result.astype('int')
result

array([[     0,      0,      0],
       [     1,      0,      0],
       [     2,      0,      0],
       ...,
       [193492,      0,      0],
       [193493,      0,      0],
       [193494,      0,      0]])

In [25]:
epochs = Epochs(
    raw=raw_array,
    events=result,
    tmin=0,
    tmax=1-(1/srate),
    preload=True,
    baseline=None,
    verbose=False
)
epochs

<Epochs |  193246 events (all good), 0 - 0.996 sec, baseline off, ~2.52 GB, data loaded,
 '0': 193246>

In [38]:
psds, freqs = psd_welch(
    epochs[0:10000],
    fmin=0.5,
    fmax=30.,
    n_fft=250,
    verbose=False,
)

In [39]:
features = []
for fmin, fmax in FREQ_BANDS_RANGE.values():
    psds_band = psds[:,:, (freqs >= fmin) & (freqs < fmax)].mean(axis=-1)
    features.append(psds_band.reshape(len(psds), -1))

features = np.concatenate(features, axis=1)

print(features.shape)
# print('Hand is: ', pipeline.predict(features))

(10000, 35)


In [40]:
pipeline = joblib.load("./models/rf.joblib")




In [44]:
np.unique(pipeline.predict(features), return_counts=True)

(array([0, 1], dtype=int32), array([9759,  241]))

We see that we predict some epochs as hand closed...

## Train & test on OpenBCI data

In [82]:
FREQ_BANDS_RANGE = {
    'DELTA': [0.5, 4.5],
    'THETA': [4.5, 8.5],
    'ALPHA': [8.5, 11.5],
    'SIGMA': [11.5, 15.5],
    'BETA': [15.5, 30]
}

def get_psds_from_epochs(epochs):
    """Extracts power spectrum densities from epochs
    Returns
    --------
    psds with associated frequencies calculated with the welch method.
    """
    psds, freqs = psd_welch(epochs, fmin=0.5, fmax=30.,n_fft=250,)
    return psds, freqs

def get_mean_psds(psds_with_freqs, are_relative=False):
    """EEG power band feature extraction.
    Input
    -------
    psds_with_freqs: tuple which contains
            - (nb_epochs, nb_chan=1, nb_freqs) psds amplitudes
            - (nb_freqs,) corresponding frequency values
            
    are_relative: boolean which indicates if the mean band powers
        for each subband are relative to the total power or not.
    
    Returns
    -------
    X : numpy array of shape [n_samples, nb_subband=5]
        Transformed data.
    """
    psds = psds_with_freqs[0]
    freqs = psds_with_freqs[1]
    
    if are_relative:
        psds /= np.sum(psds, axis=-1, keepdims=True)

    X = []
    for fmin, fmax in FREQ_BANDS_RANGE.values():
        psds_band = psds[:, :, (freqs >= fmin) & (freqs < fmax)].mean(axis=-1)
        X.append(psds_band.reshape(len(psds), -1))

    return np.concatenate(X, axis=1)

frequency_domain_pipeline = Pipeline([
    ('get_psds_from_epochs', FunctionTransformer(get_psds_from_epochs, validate=False)),
    ('frequency_domain_features', FunctionTransformer(get_mean_psds, validate=False))
])

In [102]:
features = []
labels = []
groups = []

def load_data(filename):
    eeg_data = np.loadtxt(filename, skiprows=skiprows, delimiter=',', usecols=(1,2,3,4,5,6,7))
    
    npnts = len(eeg_data)
    timevec = np.arange(0, npnts)/srate   

    return eeg_data,timevec

def create_raw(eeg_data):
    raw_array = RawArray(eeg_data, info=RAW_INFO)

#     raw_array.notch_filter(60)
#     raw_array.filter(0.5, 60)

    return raw_array

def set_annotations(raw_array, timevec,event_onset):
    actif_onset_index = np.where(event_onset == 1)
    actif_onsets = timevec[actif_onset_index]
    actif_duration = 3
    actif_description = 'Actif'

    repos_onsets = actif_onsets + actif_duration
    repos_duration = 4
    repos_description = 'Repos'

    print(repos_onsets)
    onsets = np.concatenate((actif_onsets, repos_onsets))
    onsets = np.sort(onsets)

    durations = []
    descriptions = []
    for i in range(0, int(len(onsets)/2)):
        durations.append(actif_duration)
        descriptions.append(actif_description)

        durations.append(repos_duration)
        descriptions.append(repos_description)

    annotations = mne.Annotations(onsets, durations, descriptions)
    raw_array.set_annotations(annotations)
  
    
    return raw_array 

def create_epochs(raw_array, srate):
    event_id = {'Repos':0, 'Actif':1}
    events, annot_event_id = mne.events_from_annotations(raw_array, event_id=event_id, chunk_duration=1)
    epochs = mne.Epochs(raw=raw_array, events=events, event_id=event_id, tmin=0, tmax=1-(1/srate), preload=True, baseline=None, verbose=False)
    y = events[:,2]
    return epochs, y
    
eeg_data, timevec = load_data(filename)
eeg_data = eeg_data.T

In [103]:
event_onset_idx = np.insert(
    np.arange(srate*2+7*srate-1, len(eeg_data[0]), 7*srate),
    0,
    srate*2-1
)
print(event_onset_idx)
event_onset = np.zeros(timevec.shape[0])
event_onset[event_onset_idx] = 1
print(event_onset)

eeg_raw = create_raw(eeg_data)
eeg_raw = set_annotations(eeg_raw, timevec, event_onset)
eeg_epochs, y = create_epochs(eeg_raw, srate)


[   499   2249   3999   5749   7499   9249  10999  12749  14499  16249
  17999  19749  21499  23249  24999  26749  28499  30249  31999  33749
  35499  37249  38999  40749  42499  44249  45999  47749  49499  51249
  52999  54749  56499  58249  59999  61749  63499  65249  66999  68749
  70499  72249  73999  75749  77499  79249  80999  82749  84499  86249
  87999  89749  91499  93249  94999  96749  98499 100249 101999 103749
 105499 107249 108999 110749 112499 114249 115999 117749 119499 121249
 122999 124749 126499 128249 129999 131749 133499 135249 136999 138749
 140499 142249 143999 145749 147499 149249 150999 152749 154499 156249
 157999 159749 161499 163249 164999 166749 168499 170249 171999 173749
 175499 177249 178999 180749 182499 184249 185999 187749 189499 191249
 192999]
[0. 0. 0. ... 0. 0. 0.]
Creating RawArray with float64 data, n_channels=7, n_times=193495
    Range : 0 ... 193494 =      0.000 ...   773.976 secs
Ready.
[  4.996  11.996  18.996  25.996  32.996  39.996  46.996

  raw_array.set_annotations(annotations)
  raw_array.set_annotations(annotations)


In [104]:
features = frequency_domain_pipeline.transform(eeg_epochs)
labels = y

Effective window size : 1.000 (s)


In [105]:
features.shape, labels.shape

((769, 35), (769,))

In [106]:
X_train, X_test, y_train, y_test = train_test_split(features, labels, stratify=labels, random_state=0)

forest = Pipeline([
    ('std_scaler', StandardScaler()),
    ('clf', 
         RandomForestClassifier(n_estimators=100, random_state=0, max_depth=10)
#          RandomForestClassifier(n_estimators=800, random_state=0, max_depth=20, min_samples_split=2, min_samples_leaf=1, max_features='sqrt', bootstrap=False)
    ),
]).fit(X_train, y_train)

y_pred = forest.predict(X_test)
print(classification_report(y_test, y_pred))

              precision    recall  f1-score   support

           0       0.66      0.85      0.74       110
           1       0.67      0.42      0.52        83

    accuracy                           0.66       193
   macro avg       0.67      0.63      0.63       193
weighted avg       0.67      0.66      0.65       193



In [107]:
joblib.dump(forest, 'models/rf_openbci_trained.joblib')

['models/rf_openbci_trained.joblib']