In [6]:
import os
import time
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import load_model

from joblib import load
from utils.visualize import showMe
from IPython.display import clear_output
import datetime
import brainflow
from scipy import signal

from brainflow.board_shim import BoardShim, BrainFlowInputParams, BoardIds, BrainFlowError

from tqdm.notebook import tqdm


from utils.svm import DCFilter, clip, remove2channel, cut_out, show_me_cut

from matplotlib import pyplot as plt

%matplotlib inline



from utils.visualize import showMe, stat
from utils.ros import connect, commands

from config.default import *

%load_ext autoreload
%autoreload 2

# Class labels, indexed by the model's output position (predict() returns classes[argmax]).
# NOTE(review): `settings` presumably comes from the `config.default` star import above — confirm.
classes = settings['classes']

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


**NOT TESTED:** the ROS connection cell below has not been verified yet.

In [None]:
# Open the ROS connection: `talker` is what predict() publishes movement commands on,
# and both handles are torn down in done(). Marked NOT TESTED in the note above.
ros, talker = connect()

In [7]:
####      INIT BOARD        #######
# Single source of truth for the board id: the session and the sampling-rate
# lookup must refer to the same device (previously the rate used a literal 16).
BOARD_ID = BoardIds.MINDROVE_WIFI_BOARD

BoardShim.enable_dev_board_logger()
params = BrainFlowInputParams()
board = BoardShim(BOARD_ID, params)

# Best-effort cleanup in case a previous run of this cell left the board
# streaming and/or prepared. The two calls are attempted independently so that
# "no stream running" does not prevent releasing a still-prepared session, and
# only BrainFlow errors are ignored (Ctrl-C and real bugs still propagate).
try:
    board.stop_stream()
except BrainFlowError:
    pass
try:
    board.release_session()
except BrainFlowError:
    pass

board.prepare_session()
sample_rate = board.get_sampling_rate(BOARD_ID)
print("Device ready (sampling rate: {}hz)".format(sample_rate))

Device ready (sampling rate: 500hz)


In [8]:
# Trained Keras classifier consumed by predict() inside the JustDoIt() loop.
# `load_model` (imported at the top) is the same function as keras.models.load_model.
model = load_model('saved_models/04_02_val_acc_92')

In [9]:
def predict(model, data):
    """Classify one board window, publish the matching ROS command, and plot it.

    Args:
        model: Keras model; row 0 of ``model.predict(...)`` is taken as the
            per-class scores and its argmax as the predicted class index.
        data: Board window. It is passed through ``remove2channel`` and ``clip``
            and then reshaped to ``(-1, 4, 500)``. NOTE(review): presumably
            4 channels x 500 samples after channel removal — confirm against
            utils.svm.

    Returns:
        The label ``classes[class_index]`` for the predicted class.

    Class indices 2, 1 and 3 publish the 'forward', 'left' and 'right' ROS
    commands respectively; any other index publishes nothing. If publishing
    fails (e.g. ``talker`` is undefined because the ROS cell was not run),
    "ROS unavailable" is printed and the prediction still completes.
    """
    # Which ROS command key each predicted class index triggers.
    command_for_class = {2: 'forward', 1: 'left', 3: 'right'}

    data = remove2channel(data)
    data = clip(data)

    scores = model.predict(data.reshape(-1, 4, 500))[0]
    print(scores)
    class_index = int(np.argmax(scores))

    #######################################
    try:
        command = command_for_class.get(class_index)
        if command is not None:
            talker.publish(commands[command])
    except Exception:
        # Best-effort: ROS may not be connected. `Exception` (not a bare
        # except) so that a KeyboardInterrupt still stops the loop.
        print("ROS unavailable")
    #######################################

    prediction_class = classes[class_index]

    showMe(data)
    print(f'Prediction: {prediction_class}')
    return prediction_class

def JustDoIt():
    """Run the live detect-and-classify loop on the board stream.

    Starts the board stream, waits 2 s, then polls the board roughly every
    millisecond. Each poll takes the latest
    ``sample_rate * settings['DC_length_control']`` samples, applies
    ``DCFilter``, keeps rows 0-5 and the last 500 samples, and computes the
    standard deviation of that window with ``stat``.

    When the std exceeds ``settings['std_threshold']``, a 500 ms capture delay
    starts. Once it has elapsed, the current window is classified with
    ``predict`` and the returned label is appended to (and prints) the command
    history. Detection is then suspended for ``settings['block_time']`` ms.

    The loop exits only when a window with zero std is read (treated as "no
    data"). The stream is intentionally left running; the ``done()`` cell
    stops it and releases the session.
    """
    std_threshold = settings['std_threshold']
    capture_delay = 0.5                            # seconds from trigger to classification
    block_delay = settings['block_time'] / 1000.0  # settings value is in milliseconds
    command_history = []
    print("Starting stream...")

    board.start_stream(450000)  # ring-buffer capacity, in samples

    time.sleep(2)   # wait for the DC average to settle before detecting
    print("Go ahead!")

    # time.monotonic() is used for intervals because, unlike datetime.now(),
    # it cannot jump when the wall clock is adjusted.
    stop_time = None    # end of the pending capture delay; None while idle
    block_until = None  # detection stays suspended until this moment
    while True:
        time.sleep(0.001)
        data = board.get_current_board_data(sample_rate * settings['DC_length_control'])
        data = DCFilter(data)
        data = data[:6, -500:]  # keep the EEG channels only, and only the latest trial-length window
        _, _, std = stat(data)

        if std == 0:
            print("[ERROR] No data collected! Check the MindRove device and try again.")
            break

        now = time.monotonic()
        if block_until is not None and now < block_until:
            continue  # post-prediction cool-down: ignore everything
        if stop_time is None:
            if std > std_threshold:
                stop_time = now + capture_delay
        elif now >= stop_time:
            # ">=" instead of the former |now - stop_time| < 10 ms window: a slow
            # iteration could skip that window, after which it never fired again.
            clear_output()
            command_history.append(predict(model, data))
            print(command_history)

            stop_time = None
            block_until = time.monotonic() + block_delay


In [None]:
# Blocks until JustDoIt()'s loop breaks on a zero-std (no data) window, or the kernel is interrupted.
JustDoIt()

In [None]:
def done():
    """Tear down the ROS connection and the board session.

    ROS teardown is attempted first. A failure there (for example the ROS
    connect cell was never run, so ``talker``/``ros`` are undefined) is
    reported but does not prevent the board from being released. Stopping the
    stream may fail if it is not running; ``release_session()`` is still
    called in every case.
    """
    try:
        talker.unadvertise()
        ros.terminate()
    except Exception as exc:
        print(f"ROS teardown skipped: {exc!r}")

    try:
        board.stop_stream()
    except BrainFlowError as exc:
        print(f"Board stream was not stopped: {exc}")
    finally:
        board.release_session()

done()