In [1]:
import glob
import argparse
import os
import pickle
import logging

In [3]:
import numpy as np
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.ensemble import RandomForestClassifier
# from sklearn.ensemble import StackingClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import cross_val_score

In [4]:
from brainflow.board_shim import BoardShim
from brainflow.data_filter import DataFilter

In [5]:
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType

In [6]:
def write_model(intercept, coefs, model_type):
    coefficients_string = '%s' % (','.join([str(x) for x in coefs[0]]))
    file_content = '''
#include "%s"
// clang-format off
const double %s_coefficients[%d] = {%s};
double %s_intercept = %lf;
// clang-format on
''' % (f'{model_type}_model.h', model_type, len(coefs[0]), coefficients_string, model_type, intercept)
    file_name = f'{model_type}_model.cpp'
    file_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'generated', file_name)
    with open(file_path, 'w') as f:
        f.write(file_content)

In [7]:
def prepare_data(first_class, second_class, blacklisted_channels=None):
    # use different windows, its kinda data augmentation
    window_sizes = [4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
    overlaps = [0.5, 0.475, 0.45, 0.425, 0.4, 0.375, 0.35] # percentage of window_size
    dataset_x = list()
    dataset_y = list()
    for data_type in (first_class, second_class):
        for file in glob.glob(os.path.join('data', data_type, '*', '*.csv')):
            logging.info(file)
            board_id = os.path.basename(os.path.dirname(file))
            try:
                board_id = int(board_id)
                data = DataFilter.read_file(file)
                sampling_rate = BoardShim.get_sampling_rate(board_id)
                eeg_channels = get_eeg_channels(board_id, blacklisted_channels)
                for num, window_size in enumerate(window_sizes):
                    cur_pos = sampling_rate * 10
                    while cur_pos + int(window_size * sampling_rate) < data.shape[1]:
                        data_in_window = data[:, cur_pos:cur_pos + int(window_size * sampling_rate)]
                        data_in_window = np.ascontiguousarray(data_in_window)
                        bands = DataFilter.get_avg_band_powers(data_in_window, eeg_channels, sampling_rate, True)
                        feature_vector = bands[0]
                        feature_vector = feature_vector.astype(float)
                        dataset_x.append(feature_vector)
                        if data_type == first_class:
                            dataset_y.append(0)
                        else:
                            dataset_y.append(1)
                        cur_pos = cur_pos + int(window_size * overlaps[num] * sampling_rate)
            except Exception as e:
                logging.error(str(e), exc_info=True)

    logging.info('1st Class: %d 2nd Class: %d' % (len([x for x in dataset_y if x == 0]), len([x for x in dataset_y if x == 1])))

    with open('dataset_x.pickle', 'wb') as f:
        pickle.dump(dataset_x, f, protocol=3)
    with open('dataset_y.pickle', 'wb') as f:
        pickle.dump(dataset_y, f, protocol=3)

    return dataset_x, dataset_y

In [8]:
def get_eeg_channels(board_id, blacklisted_channels):
    eeg_channels = BoardShim.get_eeg_channels(board_id)
    try:
        eeg_names = BoardShim.get_eeg_names(board_id)
        selected_channels = list()
        if blacklisted_channels is None:
            blacklisted_channels = set()
        for i, channel in enumerate(eeg_names):
            if not channel in blacklisted_channels:
                selected_channels.append(eeg_channels[i])
        eeg_channels = selected_channels
    except Exception as e:
        logging.warn(str(e))
    logging.info('channels to use: %s' % str(eeg_channels))
    return eeg_channels

In [9]:
def print_dataset_info(data):
    x, y = data
    first_class_ids = [idx[0] for idx in enumerate(y) if idx[1] == 0]
    second_class_ids = [idx[0] for idx in enumerate(y) if idx[1] == 1]
    x_first_class = list()
    x_second_class = list()
    
    for i, x_data in enumerate(x):
        if i in first_class_ids:
            x_first_class.append(x_data.tolist())
        elif i in second_class_ids:
            x_second_class.append(x_data.tolist())
    second_class_np = np.array(x_second_class)
    first_class_np = np.array(x_first_class)

    logging.info('1st Class Dataset Info:')
    logging.info('Mean:')
    logging.info(np.mean(first_class_np, axis=0))
    logging.info('2nd Class Dataset Info:')
    logging.info('Mean:')
    logging.info(np.mean(second_class_np, axis=0))

In [10]:
def train_svm_mindfulness(data):
    model = SVC(kernel='linear', verbose=True, random_state=1, class_weight='balanced', probability=True)
    logging.info('#### SVM ####')
    model.fit(data[0], data[1])
    initial_type = [('mindfulness_input', FloatTensorType([1, 5]))]
    onx = convert_sklearn(model, initial_types=initial_type, target_opset=11, options={type(model): {'zipmap': False}})
    with open('svm_mindfulness.onnx', 'wb') as f:
        f.write(onx.SerializeToString())

In [21]:
def main():
    logging.basicConfig(level=logging.INFO)
    #parser = argparse.ArgumentParser()
    #parser.add_argument('--reuse-dataset', action='store_true')
    #args = parser.parse_args()

    #if args.reuse_dataset:
    #    with open('dataset_x.pickle', 'rb') as f:
    #        dataset_x = pickle.load(f)
    #    with open('dataset_y.pickle', 'rb') as f:
    #        dataset_y = pickle.load(f)
    #    data = dataset_x, dataset_y
    #else:
    #    data = prepare_data('relaxed', 'focused', blacklisted_channels={'T3', 'T4'})
    data = prepare_data('relaxed', 'focused', blacklisted_channels={'T3', 'T4'})
    print_dataset_info(data)
    #train_regression_mindfulness(data)
    train_svm_mindfulness(data)
    #train_knn_mindfulness(data)
    #train_random_forest_mindfulness(data)
    #train_mlp_mindfulness(data)
    #train_stacking_classifier(data)

In [20]:
main()

INFO:root:1st Class: 0 2nd Class: 0
INFO:root:1st Class Dataset Info:
INFO:root:Mean:
  `ndarray`, however any non-default value will be.  If the
  return False
INFO:root:nan
INFO:root:2nd Class Dataset Info:
INFO:root:Mean:
INFO:root:nan


NameError: name 'train_regression_mindfulness' is not defined