In [1]:
# Select matplotlib's Qt backend for this kernel (figures are not drawn inline).
%matplotlib qt

In [2]:
"""
Collect data in the beginning of the experiment
with a PsychoPy experiment to display
the task, and train a machine learning model with
the processed data. Apply the model to the data,
and see the results. Display the predictions with
PyQt4.
"""
from mne import Epochs, find_events, set_log_level
from mne.time_frequency import psd_multitaper

import numpy as np

from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedShuffleSplit, cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

import rteeg

# Set MNE's logging level to 'error' for the rest of the notebook.
set_log_level('error')

In [3]:
# TODO: should a subset of channels be picked before extracting features?

def create_X_y(inst, tmin=0., tmax=10., fmin=0.5, fmax=7.0, n_components=15):
    """Return X and y for a machine learning model.

    The data are band-pass filtered (0.5-40 Hz), cut into one epoch per
    event found in `inst`, and described by their multitaper power
    spectral density. The PSD of each epoch is flattened into one row,
    standardized, and reduced with PCA.

    Parameters
    ----------
    inst : mne.io.RawArray
        Raw EEG data, which must include markers. `inst.filter` is
        called on this object directly and its return value is not used,
        so the caller's object is the one that gets filtered.
    tmin, tmax : float
        Start and end of each epoch relative to its event, passed to
        `mne.Epochs`. Defaults to 0. and 10.
    fmin, fmax : float
        Lower and upper bounds of the frequency range passed to
        `psd_multitaper`. Defaults to 0.5 and 7.0.
    n_components : int
        Number of principal components kept as features. Defaults to 15.

    Returns
    -------
    X : ndarray of shape (n_samples, n_features)
        Features for machine learning.
    y : ndarray of shape (n_samples,)
        Array of classes (last column of the epochs' event array).

    Raises
    ------
    ValueError
        If X and y do not contain the same number of samples.

    Notes
    -----
    NOTE(review): `StandardScaler` and `PCA` are re-fit on every call,
    on whatever data is passed in. Features produced by two separate
    calls (e.g. training data and later prediction data) are therefore
    not expressed in the same space. Consider moving both steps into a
    `sklearn.pipeline.Pipeline` together with the classifier.
    """
    inst.filter(0.5, 40.)
    events = find_events(inst)

    epochs = Epochs(inst, events, tmin=tmin, tmax=tmax, baseline=None,
                    add_eeg_ref=False)

    # Calculate power. The returned frequency axis is not needed here.
    psd, _freqs = psd_multitaper(epochs, fmin=fmin, fmax=fmax)

    # Flatten everything after the first axis so each epoch becomes one row.
    psd = psd.reshape(len(psd), -1)
    psd = StandardScaler().fit_transform(psd)

    X = PCA(n_components=n_components).fit_transform(psd)
    y = epochs.events[:, -1]

    if len(X) != len(y):
        raise ValueError("Number of samples in X not equal to number of "
                         "samples in y.")

    return X, y

    
def cross_validate(inst, clf, n_splits, test_size=0.2, random_state=None):
    """Cross validate to assess a model.

    Features and labels are built with `create_X_y`, then `clf` is scored
    for accuracy over stratified shuffle splits.

    Parameters
    ----------
    inst : mne.io.RawArray
        Raw EEG data, which must include markers.
    clf : scikit-learn classifier
        Estimator to evaluate. It is passed to `cross_val_score` as is.
    n_splits : int
        Number of shuffle-split iterations.
    test_size : float
        Passed to `StratifiedShuffleSplit` as `test_size`. Defaults to 0.2.
    random_state : int, RandomState instance or None
        Passed to `StratifiedShuffleSplit` so the splits can be made
        reproducible. Defaults to None (unseeded, as before).

    Returns
    -------
    scores.mean() : float
        Mean of the cross-validation scores.
    scores.std() : float
        Standard deviation of the cross-validation scores.

    Notes
    -----
    NOTE(review): `create_X_y` fits its scaler and PCA on all of the data
    before the data are split here, so the held-out samples have already
    influenced the features. The scores may be optimistic.
    """
    X, y = create_X_y(inst)
    cv = StratifiedShuffleSplit(n_splits=n_splits, test_size=test_size,
                                random_state=random_state)
    scores = cross_val_score(clf, X=X, y=y, scoring='accuracy', cv=cv)

    return scores.mean(), scores.std()

# All of the predictions will go in here: predict() appends one
# [y, y_pred] entry to the list it is given on every call.
predictions = []
def predict(data_duration, clf, event_id, predictions, colors=None):
    """Predict on new data.

    Reads from the module-level `stream` object, builds features with
    `create_X_y`, and reports the classifier's prediction for the first
    epoch as an HTML table.

    Parameters
    ----------
    data_duration : int, float
        Duration of data to use (in seconds). If 5, will predict
        on last five seconds of data.
    clf : scikit-learn classifier
        Must be fitted.
    event_id : dict
        Maps a predicted class value to the label shown in the table
        (it is indexed with the predicted value).
    predictions : list
        Nested list of [[y_true, y_pred], [...]]. One entry is appended
        on every call.
        NOTE(review): the first element stored is the whole `y` array
        returned by `create_X_y`, not a single label. Confirm whether
        `y[0]` was intended.
    colors : dict, optional
        Maps a predicted class value to a CSS color used as the
        background of the prediction cell. Classes that are missing, or
        `colors=None`, fall back to 'white'.

    Returns
    -------
    output : str
        HTML and CSS to be displayed following each prediction.
    """

    output_base = """
    <head>
        <style>
        table {{font-size: 30px;}}
        th, td {{
            background-color: white;
            text-align: center;
            padding: 20px;
        }}

        </style>
    </head>
    <body>
        <table align="center">
            <tr>
                <th>Prediction</th>
                <th>iteration</th>
            </tr>
            <tr>
                <td style="background-color:{color}"><b>{prediction}</b></td>
                <td>{i}</td>
            </tr>
        </table>
    </body>
    """
    # Forward the requested duration; it was previously ignored.
    # Assumes Stream.make_raw accepts `data_duration` -- confirm against
    # the installed rteeg version.
    X, y = create_X_y(stream.make_raw(data_duration=data_duration))

    # Only the prediction for the first row of X is reported.
    y_pred = clf.predict(X)[0]

    predictions.append([y, y_pred])

    # The template has a {color} field, so a value must always be supplied;
    # omitting it makes str.format raise KeyError.
    default_color = 'white'
    if colors is None:
        color = default_color
    else:
        color = colors.get(y_pred, default_color)

    output = output_base.format(color=color,
                                prediction=event_id[y_pred],
                                i=len(predictions))

    return output

In [None]:
# Create the rteeg stream object. predict() reads `stream` as a module-level
# global, so this cell must run before predict() is called.
stream = rteeg.Stream()
# Connect both the EEG and the marker inputs, using the 'Enobio32' montage.
stream.connect(eeg=True, markers=True, eeg_montage='Enobio32')

In [None]:
# Display the events MNE finds in the data returned by stream.make_raw().
find_events(stream.make_raw())

In [None]:
# Display the stream's recording duration (presumably in seconds -- TODO confirm).
stream.recording_duration()

In [None]:
# Random forest with scikit-learn's default hyperparameters.
clf = RandomForestClassifier()

# Score the classifier over 10 splits and print the (mean, std) tuple.
# The parenthesized call prints the same tuple on Python 2 and is also
# valid on Python 3, where the bare `print` statement is a SyntaxError.
print(cross_validate(stream.make_raw(), clf, 10))