<div class="alert alert-block alert-info">
    <center><h1> Real-time Inference </h1></center>
</div>

# Sections

**Requirements:**
- [pyqtgraph](https://github.com/pyqtgraph/pyqtgraph)

**Resources:**
- https://www.learnpyqt.com/courses/start/creating-your-first-window/

## Visualization Interface

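1. Empty the game's log folder (`BrainDriver_V1.0/StandaloneLinux64/log/`) so that only the new race log matches `raceLog*.txt`
2. Launch the game (BrainDriver executable)
3. Launch the interface (run the next cell)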

In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('..')

import glob
import numpy as np
import pyqtgraph as pg
from pyqtgraph.Qt import QtGui, QtCore
from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
import scipy.signal
import time
import socket

from preprocessing_functions.preproc_functions import filtering, standardize, rereference
from fake_signal_functions.fake_signals import fake_EEG_dataset
from data_loading_functions.data_loader import EEGDataset
from feature_extraction_functions.convnets import ShallowConvNet, DeepConvNet


class GameLogReader(QtCore.QRunnable):
    ''' Follow the game's race log and mirror the expected input as groundtruth.

    Meant to run in a QThreadPool worker. Every appended log line containing
    "p<player_idx>_expectedInput" updates the classifier widget's groundtruth
    label and its `groundtruth_idx` (0=Left, 1=Right, 2=Light, 3=Rest).
    '''
    # Race logs written by the game, relative to this notebook's directory.
    LOG_GLOB = "../../USB/BrainDriver_V1.0/StandaloneLinux64/log/raceLog*.txt"

    def __init__(self, parent, player_idx):
        super(GameLogReader, self).__init__()
        self.parent = parent
        self.classifier = self.parent.classifier_widget

        self.player_idx = player_idx
        # Log action token -> label text / action index understood by GamePlayer.
        self.actions = {"leftWinker" : "Left", "rightWinker" : "Right", "headlight" : "Light"}
        self.action_idxs = { "leftWinker" : 0, "rightWinker" : 1, "headlight" : 2 }

        # glob() returns files in arbitrary directory order: sort so "[-1]" is
        # deterministic (last in name order), and fail clearly if there is no log.
        logfiles = sorted(glob.glob(self.LOG_GLOB))
        if not logfiles:
            raise FileNotFoundError(
                "No game log matching {}; launch the game first.".format(self.LOG_GLOB))
        self.logfilename = logfiles[-1]

    def follow(self, thefile):
        ''' Yield lines appended to `thefile` after the call, like `tail -f`.

        Starts at the current end of file and polls every 100 ms; never returns.
        '''
        thefile.seek(0, 2)  # jump to end of file: only new lines are of interest
        while True:
            line = thefile.readline()
            if not line:
                time.sleep(0.1)
                continue
            yield line

    @QtCore.pyqtSlot()
    def run(self):
        ''' Thread entry point: update the groundtruth for each matching log line.

        NOTE(review): this mutates Qt widgets from a pool thread; Qt requires GUI
        objects to be touched from the GUI thread only. Consider emitting a signal.
        '''
        expected_key = "p{}_expectedInput".format(self.player_idx)
        with open(self.logfilename, "r") as logfile:
            for line in self.follow(logfile):
                if expected_key not in line:
                    continue
                if "none" in line:
                    self.classifier.groundtruth_label.setText('Groundtruth: Rest')
                    self.classifier.groundtruth_idx = 3
                    continue
                token = line.split(" ")[-1].strip()
                if token not in self.action_idxs:
                    # Unknown action name: skip it rather than kill the reader thread.
                    continue
                self.classifier.groundtruth_label.setText('Groundtruth: {}'.format(self.actions[token]))
                self.classifier.groundtruth_idx = self.action_idxs[token]

                    
class GamePlayer:
    ''' Forward action indices to the game as UDP commands.

    Action 0/1/2 map to single control characters sent to 127.0.0.1:5555;
    action 3 (rest) and None send nothing. Each send is delegated to a
    CommandSender job on this object's thread pool.
    '''
    def __init__(self, parent, player_idx):
        self.parent = parent
        self.classifier = parent.classifier_widget

        # UDP endpoint of the game, the datagram socket and the per-action payload.
        self.UDP_IP, self.UDP_PORT = "127.0.0.1", 5555
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.commands = {0: "\x0B", 1: "\x0D", 2: "\x0C", 3: ""}

        # Pool running the delayed CommandSender jobs.
        self.threadpool = QtCore.QThreadPool()

    def sendCommand(self, action_idx):
        ''' Send the command to the game after a delay in a separate thread. '''
        print(action_idx)
        if action_idx in (None, 3):
            return  # nothing to send for "unknown" or "rest"
        self.threadpool.start(CommandSender(self, action_idx))
            
            
class CommandSender(QtCore.QRunnable):
    ''' One-shot pool job: sleep a random time in [0, 1) s, then send one UDP command. '''
    def __init__(self, parent, action_idx):
        super(CommandSender, self).__init__()
        self.parent = parent  # GamePlayer owning the socket, endpoint and command table
        self.action_idx = action_idx

    @QtCore.pyqtSlot()
    def run(self):
        ''' Wait, encode the action's command as UTF-8 and send it to the game. '''
        player = self.parent
        time.sleep(np.random.random_sample())
        payload = player.commands[self.action_idx].encode("utf-8")
        player.sock.sendto(payload, (player.UDP_IP, player.UDP_PORT))
        
        
class MainInterface(QtGui.QWidget):
    ''' Main window of the real-time inference demo.

    Widget 1: temporal plot of the signal (channel selection, preprocessed overlay).
    Widget 2: inference panel (ClassifierWidget).
    Game I/O: GameLogReader (groundtruth from the game log) and GamePlayer (UDP commands).
    A statistics viewer is planned but not implemented.
    '''
    def __init__(self):
        # Call __init__ function of class QtGui.QWidget
        super(MainInterface, self).__init__()

        # Game mode: ClassifierWidget sends the model's inference only when a
        # uniform draw in [0, 1) exceeds this threshold, so 1. means the
        # groundtruth is always replayed (pure self-playing).
        self.self_playing_threshold = 1.
        self.player_idx = 1

        # Initialize main window
        self.setWindowTitle('Main Interface')
        self.setGeometry(200, 200, 1000, 800)

        # Populate main window (widgets & layout)
        self.__controls()
        self.__layout()

        # Keyboard shortcuts. Each QShortcut is parented to self, so Qt keeps it alive.
        shortcuts = [
            ("R", self.shortcut_reref),
            ("F", self.shortcut_filter),
            ("S", self.shortcut_standardize),
            ("Ctrl+Left", self.shortcut_win_size_minus),
            ("Ctrl+Right", self.shortcut_win_size_plus),
            ("Ctrl+Up", self.shortcut_nb_channels_plus),
            ("Ctrl+Down", self.shortcut_nb_channels_minus),
            ("=", self.shortcut_t_inc_plus),
            ("-", self.shortcut_t_inc_minus),
            ("Q", self.shortcut_quit),
        ]
        for key, handler in shortcuts:
            shortcut = QtGui.QShortcut(QtGui.QKeySequence(key), self)
            shortcut.activated.connect(handler)

        # Execute game_reader in separate thread (update groundtruth by reading the game logs).
        # NOTE(review): GameLogReader.run never returns, and QThreadPool's destructor
        # waits for running jobs, so process exit may block; consider a stop flag.
        self.threadpool = QtCore.QThreadPool()
        self.threadpool.start(self.game_reader)

    @staticmethod
    def _toggle_checkbox(checkbox):
        ''' Check an unchecked box, uncheck it otherwise (emits stateChanged -> update_preproc). '''
        if checkbox.checkState() == QtCore.Qt.Unchecked:
            checkbox.setCheckState(QtCore.Qt.Checked)
        else:
            checkbox.setCheckState(QtCore.Qt.Unchecked)

    def shortcut_reref(self):
        self._toggle_checkbox(self.signal_widget.check_reference)

    def shortcut_filter(self):
        self._toggle_checkbox(self.signal_widget.check_filter)

    def shortcut_standardize(self):
        self._toggle_checkbox(self.signal_widget.check_standardize)

    def shortcut_win_size_minus(self):
        self.signal_widget.update_win_size(self.signal_widget.win_size - 10)

    def shortcut_win_size_plus(self):
        self.signal_widget.update_win_size(self.signal_widget.win_size + 10)

    def shortcut_nb_channels_plus(self):
        self.signal_widget.update_nbr_chs(self.signal_widget.nbr_ch_displayed + 1)

    def shortcut_nb_channels_minus(self):
        self.signal_widget.update_nbr_chs(self.signal_widget.nbr_ch_displayed - 1)

    def shortcut_t_inc_plus(self):
        self.signal_widget.update_t_increment(self.signal_widget.t_increment + 1)

    def shortcut_t_inc_minus(self):
        self.signal_widget.update_t_increment(self.signal_widget.t_increment - 1)

    def shortcut_quit(self):
        self.close()

    def __load_data(self):
        ''' Load the training trials and concatenate them into one continuous recording.

        Side effects: sets self.fs (sampling rate used by the loader) and
        self.trial_len (samples per trial).
        Returns (X, y_train), X of shape (n_channels, n_trials * n_samples).
        '''
        dataloader_params = {
        'data_path' : '../../Datasets/BCI_IV_2a/formatted_raw/',
        # Alternative: '../../Datasets/Pilots/Pilot_2/Session_3/formatted_filt_250Hz/'
        'fs' : 250,
        'start' : 2.,
        'end' : 6.,
        'load_test' : False
        }

        # Loading dataset
        X_train, y_train, _, _, _, _ = EEGDataset(pilot_idx=1, **dataloader_params, valid_ratio=0.2).load_dataset()
        _, _, n_samples = X_train.shape  # (n_trials, n_channels, n_samples)

        # Create a continuous signal by joining the trials along time:
        # iterating X_train yields (n_channels, n_samples) trials, so the
        # result is (n_channels, n_trials * n_samples).
        X = np.concatenate(X_train, axis=-1)

        self.fs = dataloader_params['fs']
        self.trial_len = n_samples

        return X, y_train

    def __controls(self):
        # Load data (also sets self.fs and self.trial_len)
        self.signal_init, self.groundtruth = self.__load_data()

        # Load signal viewer (multi channel plots + hyperparameters widgets)
        self.widgets = []
        self.signal_widget = MultiChannelSignalWidget(self, self.signal_init, self.fs, nbr_ch_displayed=1,
                                                      title='Temporal signal')
        self.widgets.append(self.signal_widget)

        # Load inference viewer (select model, see inference); it reads signal_widget,
        # so it must be created after it.
        self.classifier_widget = ClassifierWidget(self)
        self.widgets.append(self.classifier_widget)

        # Initialize game communication (both read self.classifier_widget)
        self.game_reader = GameLogReader(self, self.player_idx)
        self.game_player = GamePlayer(self, self.player_idx)

        # TODO: stats viewer (nbr correct/incorrect inference, accuracy, confusion matrix)

    def __layout(self):
        self.mainLayout = QtGui.QHBoxLayout()
        for widget in self.widgets:
            self.mainLayout.addWidget(widget)
        self.setLayout(self.mainLayout)

        
class ClassifierWidget(QtGui.QWidget):
    ''' Inference panel: timestamp, groundtruth, latest inference and inference counter.

    Every 2 s (and whenever "Infer" is pressed) `update()` either runs the
    ShallowConvNet on the first `inference_win_size` samples of the displayed
    signal, or replays the groundtruth read from the game log. It then sends the
    chosen action to the game through the parent's GamePlayer.
    '''
    # Stylesheet of the inference label for each predicted class index.
    PRED_STYLES = ['color: yellow', 'color: green', 'color: blue', 'color: red']

    def __init__(self, parent=None):
        super(ClassifierWidget, self).__init__(parent)
        self.parent = parent

        # Inference parameters
        self.inference_cnt = 0
        self.inference_win_size = 1125  # number of samples fed to the model
        self.self_playing_threshold = self.parent.self_playing_threshold
        self.groundtruth_idx = None  # written by GameLogReader; None until the first log line

        self.__controls()
        self.__layout()

    def __controls(self):
        # Widget title
        self.block_title = QtGui.QLabel('Inference Block')
        self.block_title.setAlignment(QtCore.Qt.AlignCenter)
        self.block_title.setFont(QtGui.QFont('SansSerif', 12, QtGui.QFont.Bold))

        # Inference labels
        self.timestamp_label = QtGui.QLabel('Timestamp: {}'.format(self.parent.signal_widget.t))
        self.groundtruth_label = QtGui.QLabel('Groundtruth: {}'.format(self.parent.groundtruth[0]))
        self.inference_label = QtGui.QLabel('Inference: {}'.format(None))
        self.counter_label = QtGui.QLabel('Inference counts: {}'.format(self.inference_cnt))

        # Inference button
        self.inference_button = QtGui.QPushButton('Infer')
        self.inference_button.pressed.connect(self.update)

        # Loading trained model (4 output classes)
        self.model = ShallowConvNet(4, n_channels=self.parent.signal_init.shape[0], n_samples=self.inference_win_size)
        self.model.load_weights('../online_pipeline/best_model.h5')

        # Inference timer (ms)
        self.timer = QtCore.QTimer(self)
        self.timer.timeout.connect(self.update)
        self.timer.start(2000)

    def __layout(self):
        self.mainLayout = QtGui.QVBoxLayout()

        groupBox = QtGui.QGroupBox("Model inference")
        vbox = QtGui.QVBoxLayout()
        vbox.addWidget(self.timestamp_label)
        vbox.addWidget(self.groundtruth_label)
        vbox.addWidget(self.inference_label)
        vbox.addWidget(self.counter_label)
        vbox.addWidget(self.inference_button)
        vbox.addStretch(1)
        groupBox.setLayout(vbox)
        groupBox.setStyleSheet('QGroupBox:title {color: green;} QGroupBox {font-size: 14px; font-weight: bold;}')

        self.mainLayout.addWidget(groupBox)
        self.setLayout(self.mainLayout)

    def update(self):
        ''' Run one inference step and send the resulting action to the game.

        NOTE(review): this name shadows QWidget.update() (repaint request);
        Python callers of widget.update() will trigger an inference instead.
        '''
        signal_widget = self.parent.signal_widget

        # Update inference count and timestamp
        self.inference_cnt += 1
        self.counter_label.setText('Inference counts: {}'.format(self.inference_cnt))
        self.timestamp_label.setText('Timestamp: {}'.format(signal_widget.t))

        # With probability (1 - self_playing_threshold) send the model's inference,
        # otherwise send the groundtruth command taken from the game log.
        if np.random.random_sample() > self.self_playing_threshold:
            # Model input: (1, 1, n_channels, inference_win_size)
            window = signal_widget.signal[np.newaxis, np.newaxis, :, :self.inference_win_size]
            pred = int(np.argmax(self.model.predict(window)))
            self.inference_label.setText('Inference: {}'.format(pred))
            self.inference_label.setStyleSheet(self.PRED_STYLES[pred])
            self.parent.game_player.sendCommand(pred)
        else:
            self.parent.game_player.sendCommand(self.groundtruth_idx) # Groundtruth with random delay (up to 1s)
            self.inference_label.setText('Inference: Self-playing')
            self.inference_label.setStyleSheet('color:gray')
        
        
class MultiChannelSignalWidget(QtGui.QWidget):
    ''' Scrolling multi-channel signal viewer with an optional preprocessed overlay.

    Shows `nbr_ch_displayed` stacked plots (the first channels of `signal`), a
    timestamp label and three preprocessing checkboxes (re-reference, filter,
    standardize). A 10 ms QTimer calls `update()`, which shifts the signal left by
    `t_increment` samples and redraws its first `win_size` samples.

    Args:
        parent: owning widget.
        signal: 1-D (samples,) or 2-D (channels, samples) array; None -> 1000 zeros.
        fs: sampling frequency (samples per second of the time axis).
        nbr_ch_displayed: number of channels plotted initially.
        title: currently unused.

    Raises:
        ValueError: if `signal` is neither 1-D nor 2-D.
    '''
    PREPROC_LEGEND = "Preprocessed signal"

    def __init__(self, parent=None, signal=None, fs=250, nbr_ch_displayed=1, title=''):
        super(MultiChannelSignalWidget, self).__init__(parent)
        self.parent = parent

        # Signal properties. None sentinel avoids a shared array default argument.
        if signal is None:
            signal = np.zeros(1000)
        signal = np.asarray(signal)
        if signal.ndim not in (1, 2):
            raise ValueError('Expected a 1-D or 2-D (channels, samples) signal, got shape {}.'
                             .format(signal.shape))
        # A 1-D signal becomes a single row so all indexing is [channel, sample].
        self.signal = np.atleast_2d(signal)
        self.fs = fs
        self.n_channels, self.n_samples = self.signal.shape

        # Time display properties (win_size must leave room for x[win_size])
        self.win_size = min(fs, self.n_samples - 1)
        self.t = self.win_size
        self.t_max = self.n_samples
        self.x = np.arange(self.n_samples)/fs
        self.nbr_ch_displayed = nbr_ch_displayed
        self.t_increment = 1
        self.now = time.time()

        # Init
        pg.setConfigOptions(background='w', foreground='k', antialias=True)
        self.__controls()
        self.__layout()

        self.createSignalWidget()

    def __controls(self):
        # Signal plots
        self.graph = pg.GraphicsWindow('Temporal domain')  # Automatically generates grids with multiple items
        self.plots = []    # one PlotItem per displayed channel
        self.curves = []   # per channel: [raw curve] or [raw curve, preprocessed curve]
        self.win_right = []  # inference-window boundary lines, one per channel
        self.win_left = []

        # Preprocessing
        self.check_reference = QtGui.QCheckBox(self)
        self.check_reference.setText("Re-reference")
        self.check_reference.stateChanged.connect(self.update_preproc)

        self.check_filter = QtGui.QCheckBox(self)
        self.check_filter.setText("Filter")
        self.check_filter.stateChanged.connect(self.update_preproc)

        self.check_standardize = QtGui.QCheckBox(self)
        self.check_standardize.setText("Standardize")
        self.check_standardize.stateChanged.connect(self.update_preproc)

        # Initialize timer & connect to update function (period in ms)
        self.timer = QtCore.QTimer(self)
        self.timer.timeout.connect(self.update)
        self.timer.start(10)

    def __layout(self):
        self.mainLayout = QtGui.QVBoxLayout()

        # Timestamp
        self.timestamp_label = QtGui.QLabel("Timestamp: {}".format(self.t))
        self.mainLayout.addWidget(self.timestamp_label)

        # Signal plots
        self.mainLayout.addWidget(self.graph)

        # Preprocessing
        groupBox = QtGui.QGroupBox("Signal preprocessing")
        groupBox.setStyleSheet('QGroupBox:title {color: green;} QGroupBox {font-size: 14px; font-weight: bold;}')

        vbox = QtGui.QVBoxLayout()
        vbox.addWidget(self.check_reference)
        vbox.addWidget(self.check_filter)
        vbox.addWidget(self.check_standardize)
        vbox.addStretch(1)
        groupBox.setLayout(vbox)
        self.mainLayout.addWidget(groupBox)

        # Set main layout
        self.setLayout(self.mainLayout)

    def _preproc_enabled(self):
        ''' True when at least one preprocessing checkbox is ticked. '''
        return (self.check_reference.isChecked() or self.check_filter.isChecked()
                or self.check_standardize.isChecked())

    def _add_preproc_curves(self):
        ''' Add an orange overlay curve to every displayed plot and its legend entry. '''
        for c in range(self.nbr_ch_displayed):
            curve = self.plots[c].plot()
            curve.setPen((255, 165, 0), width=1.5)
            self.curves[c].append(curve)
        self.legend.addItem(self.curves[0][1], self.PREPROC_LEGEND)

    def update_win_size(self, val):
        ''' Set the displayed window length in samples, clamped to [fs/10, 5*fs] and the signal length. '''
        val = max(val, self.fs // 10)
        val = min(val, 5 * self.fs, self.n_samples - 1)  # x[win_size] must stay a valid index
        self.win_size = val

    def update_nbr_chs(self, val):
        ''' Rebuild the plots to show `val` channels, clamped to [1, n_channels]. '''
        val = min(max(val, 1), self.n_channels)

        # Remove all plots (their curves and ROI lines go with them) and reset bookkeeping.
        # The ROI lists may still be empty if update() has not run yet.
        for plot in self.plots:
            self.graph.removeItem(plot)
        self.plots.clear()
        self.curves.clear()
        self.win_left.clear()
        self.win_right.clear()

        self.nbr_ch_displayed = val

        # Create plots, restoring the preprocessed overlay if it was active
        self.createSignalWidget()
        if self._preproc_enabled():
            self._add_preproc_curves()

    def update_t_increment(self, val):
        ''' Set the scroll step in samples per timer tick, clamped to [0, 50]. '''
        self.t_increment = min(max(val, 0), 50)

    def update_preproc(self):
        ''' Update curve & legend for preprocessed signal. '''
        if self._preproc_enabled():
            # Display preprocessed signal, unless the overlay already exists
            if len(self.curves[0]) < 2:
                self._add_preproc_curves()
        elif len(self.curves[0]) > 1:
            # Don't display preprocessed signal - Remove curves from plots and legend
            for c in range(self.nbr_ch_displayed):
                self.plots[c].removeItem(self.curves[c][1])
                del self.curves[c][1]
            self.legend.removeItem(self.PREPROC_LEGEND)

    def createSignalWidget(self):
        ''' Initialize plotItems, axisItems & dataItems for each channel & legend. '''
        nb_chs = self.nbr_ch_displayed

        # Init plots (plotItems); y-range taken from the first 1000 samples
        for c in range(nb_chs):
            plot = self.graph.addPlot(row=c, col=0)
            plot.setLabel('left', text='Ch {}'.format(c))
            head = self.signal[c, :1000]
            plot.getViewBox().setRange(yRange=[np.min(head), np.max(head)])
            plot.setClipToView(True)
            self.plots.append(plot)
        # Only the bottom plot shows the time axis
        for plot in self.plots[:-1]:
            plot.hideAxis('bottom')
        self.plots[-1].setLabel('bottom', text='Time [s]')
        self.plots[-1].getAxis('bottom').setTickSpacing(2,1)

        # Init signal curves (dataItems)
        for c in range(nb_chs):
            curve = self.plots[c].plot()
            curve.setPen((0, 142, 204), width=1.5)
            self.curves.append([curve])

        # Init legend
        self.legend = self.plots[0].addLegend()
        self.legend.addItem(self.curves[0][0], "Raw signal")

    def update(self):
        ''' Timer tick: scroll the signal by t_increment samples and redraw. '''
        nb_chs = self.nbr_ch_displayed
        win_size = self.win_size

        # Reached end of signal: stop scrolling and skip drawing
        if self.t > self.t_max:
            self.timer.stop()
            print("End of signal reached !")
            return

        # Update time position (rate = samples advanced per wall-clock second)
        now = time.time()
        elapsed = max(now - self.now, 1e-9)  # guard against a zero interval
        self.t += self.t_increment
        self.timestamp_label.setText("Timestamp: {} - FPS: {:.4}" \
                                     .format(self.t, self.t_increment/elapsed))
        self.now = now

        # Shift time axis and signal left. axis=-1 is required: without it np.roll
        # flattens the array and moves samples from one channel into the previous one.
        self.x = np.roll(self.x, -self.t_increment)
        self.signal = np.roll(self.signal, -self.t_increment, axis=-1)
        for c in range(nb_chs):
            self.plots[c].getViewBox().setRange(xRange=[self.x[0], self.x[win_size]])
            self.curves[c][0].setData(self.x[:win_size], self.signal[c,:win_size], autoDownsample=True)

        # Preprocessed overlay: filter -> standardize -> re-reference (enabled steps only)
        if self._preproc_enabled():
            signal_proc = self.signal[:,:win_size]
            if self.check_filter.isChecked():
                signal_proc = filtering(signal_proc, fs=self.fs, f_order=3, f_low=0, f_high=10)
            if self.check_standardize.isChecked():
                signal_proc = standardize(signal_proc)
            if self.check_reference.isChecked():
                signal_proc = rereference(signal_proc)
            for c in range(nb_chs):
                self.curves[c][1].setData(self.x[:win_size], signal_proc[c,:])

        # Inference window lines: create them on the first tick, then move them
        if not self.win_left and not self.win_right:
            for c in range(nb_chs):
                self.win_left.append(self.plots[c].addLine(x=0, pen=pg.mkPen(color='w', width=0)))
            for c in range(nb_chs):
                self.win_right.append(self.plots[c].addLine(x=0, pen=pg.mkPen(color='w', width=0)))
        else:
            pos_right = self.x[win_size]
            pos_left = pos_right - 0.5
            for c in range(nb_chs):
                self.win_left[c].setValue(pos_left)
            for c in range(nb_chs):
                self.win_right[c].setValue(pos_right)

def main():
    ''' Start (or reuse) the Qt application and show the interface until it is closed. '''
    # Re-running this notebook cell in the same kernel would otherwise try to
    # create a second QApplication, which Qt does not allow.
    app = QtGui.QApplication.instance()
    if app is None:
        app = QtGui.QApplication(sys.argv)
    myapp = MainInterface()
    myapp.show()
    app.exec_()

if __name__ == '__main__':
    main()

Using TensorFlow backend.


Properties: 288 train trials - 22 channels - 7.5s trial length
Selecting classes [0, 1, 2, 3] & balancing...
Selecting 22 channels
Selecting time-window [2.0 - 6.0]s - (4.0s)...
Output shapes:  (230, 22, 1000) (58, 22, 1000) (58, 22, 1000)
Output classes:  [0 1 2 3]
Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
None
None
None
None
None
0
0
0
3
3
3
1
1
1
3
2
2
2
2
0
0
0
0
3
3
3
3
1
1
1
1
2
2
2
2
2
0
0
0
0
3
3
3
3
1
