# LSTM decoding exploration

This notebook explores using LSTMs (potentially with some changes to the cost function) to decode EMGs from cortical data. It is all based on work from Steph and Josh.


Going to start with the old Jango data, then update from there as necessary.

In [1]:
import numpy as np
import pandas as pd
from scipy.io import loadmat
from scipy.optimize import least_squares
from matplotlib import pyplot as plt
# import ipympl


# we'll use ridge regression as a comparison
from sklearn import linear_model
from sklearn.model_selection import train_test_split, KFold
from sklearn import metrics

from tkinter import Tk
from tkinter import filedialog as fd # just so I don't have to repeatedly manually enter filenames

# and tf stuff
import tensorflow as tf
from tensorflow.keras.layers import Dense, Dropout, LSTM
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras import backend as K # this was in Josh's version. Not sure why can't use numpy

# datetime stuff for logs
from datetime import datetime as dt

# utility functions that I moved to the LSTM_utils.py file
from LSTM_utils import *

# for saving things
from os import path
import pickle

In [2]:
# %matplotlib inline
# import ipympl
# IPython magic selecting the Qt5 matplotlib backend (the inline alternative is left commented out above)
%matplotlib qt5

In [3]:
# list the GPU devices that TensorFlow can see
tf.config.list_physical_devices('GPU')

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

### Import data

In [4]:
# request the filename
# (a throwaway Tk root is created only to host the file dialog, then destroyed)
root = Tk()
mat_fn = fd.askopenfilename(master=root,filetypes=[('matlab data','*.mat')])
root.destroy()

data = loadmat(mat_fn)

# load_josh_mat comes from LSTM_utils; it returns five values per dataset, of which
# only the first three (timestamps, firing, EMG) are kept here

# Training
Iso_train_timestamps, Iso_train_firing, Iso_train_EMG, _, _, = load_josh_mat(data['IsoTrain'])
Spr_train_timestamps, Spr_train_firing, Spr_train_EMG, _, _, = load_josh_mat(data['SprTrain'])
Mov_train_timestamps, Mov_train_firing, Mov_train_EMG, _, _, = load_josh_mat(data['WmTrain'])

# Testing
Iso_test_timestamps, Iso_test_firing, Iso_test_EMG, _, _, = load_josh_mat(data['IsoTest'])
Spr_test_timestamps, Spr_test_firing, Spr_test_EMG, _, _, = load_josh_mat(data['SprTest'])
Mov_test_timestamps, Mov_test_firing, Mov_test_EMG, _, _, = load_josh_mat(data['WmTest'])


# EMG names -- makes things easier later on
# NOTE(review): taken from the isometric test set; presumably every dataset shares the same EMG columns -- confirm
EMG_name = Iso_test_EMG.columns

## Summarize the data

Plot out some good information on the max and variances of each muscle. Should also look at something for the cortex, maybe depth of modulation or avg firing rates?

Bar plots of representative values (max and 95th percentile) of each EMG, for each condition

In [5]:
# --------------------------------------------
# Summary bar plots of the training EMGs: one group of bars per muscle,
# one bar per condition. The max-value and 95th-percentile figures share
# the same layout, so both are built by the helper below.

def _emg_summary_bar_plot(emg_by_condition, summary_fn, y_label, title, bar_width=.25):
    """Draw a grouped bar chart of one summary statistic per muscle and condition.

    Parameters
    ----------
    emg_by_condition : dict
        Maps the legend label of each condition to its EMG DataFrame. The
        x tick labels are taken from the columns of the first DataFrame.
    summary_fn : callable
        Reduces an EMG DataFrame to one value per column (muscle).
    y_label, title : str
        Y axis label and axes title.
    bar_width : float
        Width of each bar; the bars of successive conditions are offset by it.

    Returns
    -------
    (fig, ax) : the new matplotlib figure and axes.
    """
    fig, ax = plt.subplots()
    first_emg = next(iter(emg_by_condition.values()))
    muscle_idx = np.arange(first_emg.shape[1]) # indexing on the x axis

    for i_cond, (cond_label, emg) in enumerate(emg_by_condition.items()):
        ax.bar(muscle_idx + i_cond*bar_width, summary_fn(emg), width = bar_width, label=cond_label)

    # put each tick under the middle of its group of bars
    ax.set_xticks(muscle_idx + bar_width*(len(emg_by_condition) - 1)/2)
    ax.set_xticklabels(first_emg.columns)

    _ = ax.legend()
    ax.set_xlabel('Muscle')
    ax.set_ylabel(y_label)

    # label every bar with its height, in the bar's own colour
    for bar in ax.patches:
        bar_value = bar.get_height()
        text = f'{bar_value:.02f}' # two decimal places
        text_x = bar.get_x() + bar.get_width() / 2 # middle of the bar on the x axis
        # get_y() is where the bar starts so we add the height to it
        text_y = np.max([bar.get_y() + bar_value, 0.01]) # keep it from going too far below zero, and disappearing
        ax.text(text_x, text_y, text, ha='center', va='bottom', color=bar.get_facecolor(),
                size=12)

    for spine in ['right','top','bottom','left']:
        ax.spines[spine].set_visible(False)

    ax.set_title(title)
    return fig, ax


i_muscles = np.arange(Mov_train_EMG.shape[1]) # indexing on the x axis
bar_width = .25
_train_EMG_by_condition = {'Movement': Mov_train_EMG, 'Isometric': Iso_train_EMG, 'Spring': Spr_train_EMG}

# --------------------------------------------
# Max EMG values bar plot
fig_emg_max, ax_emg_max = _emg_summary_bar_plot(
    _train_EMG_by_condition, lambda emg: np.max(emg, axis=0),
    'Max Value', 'Maximum EMG Value', bar_width)
fig_emg_max.show() # shown once the figure is complete

# --------------------------------------------
# 95th percentile
fig_emg_95, ax_emg_95 = _emg_summary_bar_plot(
    _train_EMG_by_condition, lambda emg: np.percentile(emg, 95, axis=0),
    '95th Percentile', '95th Percentile of EMG', bar_width)


Text(0.5, 1.0, '95th Percentile of EMG')

Comparing mean firing rates of different conditions

In [6]:
fig_cort, ax_cort = plt.subplots(ncols=3)

# mean of every column of the firing DataFrames (presumably one column per neuron), per condition
Wm_means = np.mean(Mov_train_firing, axis=0)
Iso_means = np.mean(Iso_train_firing, axis=0)
Spr_means = np.mean(Spr_train_firing, axis=0)

# (x data, y data, x label, y label) for each pairwise comparison, one per subplot
mean_pairs = [
    (Wm_means, Iso_means, 'Movement', 'Isometric'),
    (Wm_means, Spr_means, 'Movement', 'Spring'),
    (Iso_means, Spr_means, 'Isometric', 'Spring'),
]

for axis, (x_means, y_means, x_name, y_name) in zip(ax_cort, mean_pairs):
    # scatter between mean firing rates, plus a unity line over a fixed 0-40 range
    axis.scatter(x_means, y_means, s=2)
    axis.plot([0,40],[0,40],c='k', alpha=.3)

    axis.set_xlabel(x_name)
    axis.set_ylabel(y_name)

    # make the axes square and remove the spines of *this* subplot
    # (this loop used to hide the spines of ax_emg_95 from the EMG bar plot instead)
    axis.set_aspect('equal',adjustable='box')
    for spine in ['right','top','bottom','left']:
        axis.spines[spine].set_visible(False)
    axis.set_title('Mean Firing Rates')


fig_cort.set_label('Compared Mean Firing Rates across conditions')



Same with the variance of the firing rates (thinking depth of modulation sort of thing)

In [7]:
fig_cort, ax_cort = plt.subplots(ncols=3)

# variance of every column of the firing DataFrames (presumably one column per neuron), per condition
Wm_vars = np.var(Mov_train_firing, axis=0)
Iso_vars = np.var(Iso_train_firing, axis=0)
Spr_vars = np.var(Spr_train_firing, axis=0)

# (x data, y data, x label, y label) for each pairwise comparison, one per subplot
var_pairs = [
    (Wm_vars, Iso_vars, 'Movement', 'Isometric'),
    (Wm_vars, Spr_vars, 'Movement', 'Spring'),
    (Iso_vars, Spr_vars, 'Isometric', 'Spring'),
]

for axis, (x_vars, y_vars, x_name, y_name) in zip(ax_cort, var_pairs):
    # scatter between the firing-rate variances
    axis.scatter(x_vars, y_vars, s=2)
    axis.set_xlabel(x_name)
    axis.set_ylabel(y_name)

    # make the axes square with a common range on x and y
    axis.set_aspect('equal',adjustable='box')
    x_lo, x_hi = axis.get_xlim()
    y_lo, y_hi = axis.get_ylim()
    max_lim = np.max([x_hi, y_hi])
    # lower bound is the smaller of the two *lower* limits
    # (this used to read the upper y limit, get_ylim()[1])
    min_lim = np.min([x_lo, y_lo])
    axis.set_xlim([min_lim, max_lim])
    axis.set_ylim([min_lim, max_lim])

    for spine in ['right','top','bottom','left']:
        axis.spines[spine].set_visible(False)
    axis.set_title('Variance of Firing Rates')

    # unity line across the whole visible range
    axis.plot([min_lim, max_lim],[min_lim, max_lim],c='k', alpha=.3)

fig_cort.set_label('Compared var Firing Rates across conditions')



## Model building 
Look through a couple of different model types, look to see how they are trained


### Linear with Static Non-linearity

Using the lab default -- build a wiener filter, fit it, then fit with a static polynomial on top of the form

$Ax^2 + Bx + C$

Where $x$ is the original EMG prediction, and the output is the new EMG prediction

Alternatively, the prediction can be passed through an exponential nonlinearity

$Ae^{Bx} + C$

There is also the option of a sigmoid

$\frac{B}{1 + e^{-10(x - A)}}$


**Starting with defining our nonlinearity methods**

In [None]:
# using scipy's least_squares:
def non_linearity(p, y_pred, nonlinear_type):
    """Apply a static nonlinearity to a prediction.

    Parameters
    ----------
    p : sequence of float
        Parameters of the nonlinearity:
        'poly'        -> p[0] + p[1]*y + p[2]*y**2
        'exponential' -> p[0]*exp(p[1]*y) + p[2]
        'sigmoid'     -> p[1] / (1 + exp(-10*(y - p[0])))  (the slope of 10 is fixed)
    y_pred : float or array
        Value(s) to transform.
    nonlinear_type : str
        One of 'poly', 'exponential' or 'sigmoid'.

    Returns
    -------
    The transformed value(s), same shape as y_pred.

    Raises
    ------
    ValueError
        If nonlinear_type is not recognised (previously this silently returned None).
    """
    if nonlinear_type == 'poly':
        return p[0] + p[1]*y_pred + p[2]*y_pred**2
    if nonlinear_type == 'exponential':
        return p[0]*np.exp(p[1]*y_pred) + p[2]
    if nonlinear_type == 'sigmoid':
        return p[1] * 1/(1 + np.exp(-10*(y_pred-p[0])))
    raise ValueError(f"Unknown nonlinear_type {nonlinear_type!r}; expected 'poly', 'exponential' or 'sigmoid'")


def non_linearity_residuals(p, y_pred, y_act, nonlinear_type):
    """Residuals (actual minus transformed prediction), in the form scipy's least_squares expects.

    Parameters are those of non_linearity, plus y_act, the actual value(s)
    that the transformed prediction is compared against.

    Raises
    ------
    ValueError
        If nonlinear_type is not recognised.
    """
    # delegate so that the fitted function and the applied function cannot drift apart
    return y_act - non_linearity(p, y_pred, nonlinear_type)



Next, a function that compares Wiener filter models with Wiener cascades, reports the VAF (coefficient of determination), and plots the validation predictions

In [None]:
# Set up a function that we can just call multiple times, so then we can just quickly run through all of the different combinations


def _lagged_design_matrix(firing, n_lags):
    """Return n_lags shifted copies of the firing DataFrame, side by side.

    Copy ii is firing.shift(-ii, fill_value=0) with "_lag{ii}" appended to
    every column name; the copies are ordered lag 0, lag 1, ...
    """
    # NOTE(review): shift(-ii) moves rows from *later* in the frame up to the current row,
    # so each row sees the current and the following samples -- confirm this is the intended
    # alignment for the decoder.
    return pd.concat(
        [firing.shift(-ii, fill_value=0).add_suffix(f"_lag{ii}") for ii in range(n_lags)],
        axis=1)


def basic_decoder_comparison(train_firing, train_EMG, retrain_firing, retrain_EMG, test_firing, test_EMG, test_timestamps, nonlinear_type, save_plot=True):
    """Compare a linear (Wiener) decoder with Wiener cascades on a test set.

    A linear regression from 10 lagged copies of the firing rates to the EMGs
    is fit on the training data. For each muscle, a static nonlinearity is
    then fit on top of the linear prediction twice: once on the training
    data ("pre-built") and once on the retrain data ("re-built"). All three
    decoders are evaluated on the test data, the scores are printed, and the
    predictions and scores are plotted.

    The scores labelled "VAF" are sklearn's explained_variance_score.

    Parameters
    ----------
    train_firing, retrain_firing, test_firing : pandas.DataFrame
        Firing rates for the training, retrain and test sets.
    train_EMG, retrain_EMG, test_EMG : pandas.DataFrame
        EMGs matching the firing DataFrames, one column per muscle.
        The muscles are matched across the three DataFrames by column position.
    test_timestamps : array-like
        X values used when plotting the test predictions.
    nonlinear_type : str
        'poly', 'exponential' or 'sigmoid' (see non_linearity).
    save_plot : bool
        If True, save the figures as 'Wiener_Cascade_Comparison.svg' and
        'Wiener_VAF.svg' in the current directory.

    Returns
    -------
    numpy.ndarray
        Pre-built nonlinear predictions for the *training* set, same shape
        as the linear training prediction.

    Raises
    ------
    ValueError
        If nonlinear_type is not recognised.
    """
    # starting parameter guesses handed to least_squares for each nonlinearity
    initial_guesses = {
        'poly': [.1, .1, .1],
        'exponential': [1, .1, .2],
        'sigmoid': [.1, .5],
    }
    if nonlinear_type not in initial_guesses:
        raise ValueError(f"Unknown nonlinear_type {nonlinear_type!r}; expected one of {sorted(initial_guesses)}")
    init_pred = initial_guesses[nonlinear_type]

    n_lags = 10 # number of lags
    wiener_input = _lagged_design_matrix(train_firing, n_lags)
    wiener_test = _lagged_design_matrix(test_firing, n_lags)
    wiener_retrain = _lagged_design_matrix(retrain_firing, n_lags)

    n_muscles = len(train_EMG.columns)

    mdl_A = linear_model.LinearRegression(fit_intercept=True)
    mdl_B = {} # per-muscle nonlinearity parameters, fit on the training set

    mdl_A.fit(wiener_input, train_EMG) # fit the first model
    prefit_EMG = mdl_A.predict(wiener_input) # get the initially predicted EMGs
    retrain_pred = mdl_A.predict(wiener_retrain) # from a training set on a separate condition -- to properly hold out test data
    prefit_VAF = metrics.explained_variance_score(train_EMG, prefit_EMG, multioutput='raw_values')
    nonlin_EMG = np.zeros(prefit_EMG.shape)

    print('Training Results\n')

    for ii, muscle in enumerate(train_EMG.columns):
        mdl_B[muscle] = least_squares(non_linearity_residuals, init_pred, args=(prefit_EMG[:,ii], train_EMG.iloc[:,ii].to_numpy(), nonlinear_type)).x
        print(muscle)
        print(f"\tLinear VAF: {prefit_VAF[ii]:.03f}")
        nonlin_EMG[:,ii] = non_linearity(mdl_B[muscle],(prefit_EMG[:,ii]), nonlinear_type)
        print(f"\tNonLinear VAF: {metrics.explained_variance_score(train_EMG.iloc[:,ii],nonlin_EMG[:,ii]):.03f}")

    # predicting the test set
    prefit_test = mdl_A.predict(wiener_test)
    prefit_test_VAF = metrics.explained_variance_score(test_EMG, prefit_test, multioutput='raw_values')
    nonlin_test = np.zeros(prefit_test.shape)
    nonlin_within_test = np.zeros(prefit_test.shape)
    nonlin_VAF = np.zeros(prefit_test_VAF.shape)
    nonlin_within_VAF = np.zeros(prefit_test_VAF.shape)
    mdl_C = {} # for a separate non-linearity, built for the second condition.

    print('\n------------------------------------------------------------\n')
    print('Testing Results\n')
    for ii, train_muscle in enumerate(train_EMG.columns):
        muscle = test_EMG.columns[ii]
        mdl_C[muscle] = least_squares(non_linearity_residuals, init_pred, args=(retrain_pred[:,ii], retrain_EMG.iloc[:,ii].to_numpy(), nonlinear_type)).x
        print(muscle)
        print(f"\tLinear VAF: {prefit_test_VAF[ii]:.03f}")
        # mdl_B is keyed by the *training* column names, so look it up with those
        nonlin_test[:,ii] = non_linearity(mdl_B[train_muscle],(prefit_test[:,ii]), nonlinear_type)
        nonlin_within_test[:,ii] = non_linearity(mdl_C[muscle],(prefit_test[:,ii]), nonlinear_type)
        nonlin_VAF[ii] = metrics.explained_variance_score(test_EMG.iloc[:,ii],nonlin_test[:,ii])
        nonlin_within_VAF[ii] = metrics.explained_variance_score(test_EMG.iloc[:,ii],nonlin_within_test[:,ii])
        print(f"\tPre-built Nonlinearity VAF: {nonlin_VAF[ii]:.03f}")
        print(f"\tRe-built Nonlinearity VAF: {nonlin_within_VAF[ii]:.03f}")

    # Plotting the test data -- so that we can see how  the non-linearities act between types
    n_rows = int(np.ceil(np.sqrt(n_muscles)))
    # squeeze=False keeps the axes a 2-D array even for a 1x1 grid, so [row, col] indexing always works
    fig_nl_test, ax_nl_test = plt.subplots(nrows=n_rows, ncols=n_rows, sharex=True, constrained_layout=True, squeeze=False)

    for muscle_ii in range(n_muscles):
        row_i, col_i = divmod(muscle_ii, n_rows)
        ax = ax_nl_test[row_i,col_i]
        ax.plot(test_timestamps, test_EMG.iloc[:,muscle_ii], label='Recorded')
        ax.plot(test_timestamps,prefit_test[:,muscle_ii], label=f'Linear VAF: {prefit_test_VAF[muscle_ii]:.03f}')
        ax.plot(test_timestamps,nonlin_test[:,muscle_ii], label=f'NonLin VAF: {nonlin_VAF[muscle_ii]:.03f}')
        ax.plot(test_timestamps,nonlin_within_test[:,muscle_ii], label=f'Rebuilt NonLin VAF: {nonlin_within_VAF[muscle_ii]:.03f}')
        ax.set_title(f"{test_EMG.columns[muscle_ii]}")
        ax.set_xlabel(f"Time (s)")
        ax.set_ylabel("EMG envelope")

        _ = ax.legend()

        # turn off the spines
        for spine in ['right','top','bottom','left']:
            ax.spines[spine].set_visible(False)

    if save_plot:
        fig_nl_test.savefig('Wiener_Cascade_Comparison.svg')

    # let's also plot the VAFs in a clean manner so that it's easy to compare
    fig_vaf, ax_vaf = plt.subplots()
    i_muscles = np.arange(len(test_EMG.columns)) # indexing on the x axis
    bar_width = .25

    ax_vaf.bar(i_muscles, prefit_test_VAF, width = bar_width, label='Linear')
    ax_vaf.bar(i_muscles + bar_width, nonlin_VAF, width = bar_width, label='Nonlinear')
    ax_vaf.bar(i_muscles + 2*bar_width, nonlin_within_VAF, width = bar_width, label='Rebuilt Nonlinear')

    ax_vaf.set_xticks(i_muscles + 1*bar_width)
    ax_vaf.set_xticklabels(test_EMG.columns)

    ax_vaf.set_ylim([-.05, 1.05])
    ax_vaf.set_xlabel('Muscle')
    ax_vaf.set_ylabel('Coefficient of Determination')

    ax_vaf.legend()

    # label every bar with its height, in the bar's own colour
    for bar in ax_vaf.patches:
        bar_value = bar.get_height()
        text = f'{bar_value:.02f}' # two decimal places
        text_x = bar.get_x() + bar.get_width() / 2 # middle of the bar on the x axis
        # get_y() is where the bar starts so we add the height to it
        text_y = np.max([bar.get_y() + bar_value, 0.01]) # keep it from going too far below zero, and disappearing
        ax_vaf.text(text_x, text_y, text, ha='center', va='bottom', color=bar.get_facecolor(),
                size=12)

    # turn off the spines
    for spine in ['right','top','bottom','left']:
        ax_vaf.spines[spine].set_visible(False)

    if save_plot:
        fig_vaf.savefig('Wiener_VAF.svg')

    return nonlin_EMG


In [13]:
def LSTM_comparisons(train_firing_dict:dict, train_EMG_dict:dict, 
        test_firing_dict:dict, test_EMG_dict:dict, EMG_name, 
        plot_sig_flag:bool=False, plot_VAF_flag:bool = True, 
        return_VAF_flag:bool = False):
    """Train a plain (MSE loss) LSTM decoder on the 'Combined' data and score it per condition.

    Parameters
    ----------
    train_firing_dict, test_firing_dict : dict
        Condition name -> firing array handed to the LSTM; the second and
        third dimensions of the 'Combined' training entry set the network's
        input shape.
    train_EMG_dict, test_EMG_dict : dict
        Condition name -> EMG array; the first len(EMG_name) columns are scored.
    EMG_name : sequence
        Muscle names; its length sets the number of outputs.
    plot_sig_flag : bool
        If True, plot recorded vs predicted signals with plot_rec_pred.
    plot_VAF_flag : bool
        If True, plot the scores with plot_VAFs.
    return_VAF_flag : bool
        If True, return (train_VAFs, test_VAFs); otherwise return None.

    The reported "VAF" values are sklearn r2_score values, one per muscle.
    The 'Combined' entry of the test dictionaries is not evaluated.
    """
    # hyper params
    n_lstm_units = 300
    input_dropout = .25      # input dropout fraction for the LSTM layer
    recurrent_dropout = 0    # recurrent dropout for the LSTM layer
    dense_dropout = .15      # rate of the Dropout layer after the LSTM (skipped if 0)

    n_EMGs = len(EMG_name)

    # the "Combined" condition is what the network is trained on
    train_i = train_firing_dict['Combined']
    train_o = train_EMG_dict['Combined']
    seq_len = train_i.shape[1] # num of lags
    n_neurons = train_i.shape[2]

    # LSTM -> (optional) dropout -> linear read-out
    mdl = tf.keras.models.Sequential()
    lstm_layer = tf.keras.layers.LSTM(
        n_lstm_units,
        input_shape = (seq_len,n_neurons),
        dropout=input_dropout,
        recurrent_dropout=recurrent_dropout)
    mdl.add(lstm_layer)
    if dense_dropout:
        mdl.add(tf.keras.layers.Dropout(dense_dropout))
    mdl.add(tf.keras.layers.Dense(n_EMGs))
    mdl.compile(loss='mse', optimizer='rmsprop', metrics=['MeanAbsoluteError'])

    mdl.fit(train_i, train_o, epochs=50, verbose=False)

    # predict and score every training condition
    train_preds, train_VAFs = {}, {}
    for cond, cond_firing in train_firing_dict.items():
        cond_pred = mdl.predict(cond_firing)
        train_preds[cond] = cond_pred
        train_VAFs[cond] = metrics.r2_score(
            train_EMG_dict[cond][:,:n_EMGs], cond_pred[:,:n_EMGs], multioutput='raw_values')

    # same for the test conditions, except "Combined"
    test_preds, test_VAFs = {}, {}
    for cond, cond_firing in test_firing_dict.items():
        if cond == 'Combined':
            continue
        cond_pred = mdl.predict(cond_firing)
        test_preds[cond] = cond_pred
        test_VAFs[cond] = metrics.r2_score(
            test_EMG_dict[cond][:,:n_EMGs], cond_pred[:,:n_EMGs], multioutput='raw_values')

    # per-muscle report
    print('----------------------------------------')
    for i_muscle, muscle in enumerate(EMG_name):
        print(muscle)
        for cond, vafs in train_VAFs.items():
            print(f"\tTrain VAF for {cond}: {vafs[i_muscle]}")
        for cond, vafs in test_VAFs.items():
            print(f"\tTest VAF for {cond}: {vafs[i_muscle]}")

    if plot_sig_flag:
        plot_rec_pred(train_EMG_dict, train_preds, EMG_name, train_VAFs, title_append = 'training set')
        plot_rec_pred(test_EMG_dict, test_preds, EMG_name, test_VAFs, title_append = 'testing set')

    if plot_VAF_flag:
        plot_VAFs(train_VAFs, test_VAFs, EMG_name, sup_title='Vanilla LSTM')

    if return_VAF_flag:
        return train_VAFs, test_VAFs


### Define Custom Loss functions

First, the weighted loss function. This one calculates the MSE, but weights the squared error at each point in time by the inverse of the variance of that particular task (i.e. it divides by the variance). This balances the training so that the model is able to train for conditions with different EMG ranges (the whole idea behind our system...)

In [8]:
# loss for the hybrid decoders: squared error scaled per muscle and per time point
# by the weights that travel alongside the targets
def hybrid_weight_loss(target, pred):
    """Weighted mean squared error, callable as a Keras loss.

    inputs:
        target  recorded data with the weights appended, so that both reach
                the loss through the single `target` argument: the first half
                of the columns is EMG, the second half the matching weights
        pred    current predictions; only the first half of its columns is used

    returns:
        one value per sample: the mean over muscles of (squared error / weight)
    """
    n_emg = K.shape(target)[1]//2 # half of the columns are EMG
    recorded = target[:, :n_emg]
    weights = target[:, n_emg:]

    squared_err = K.square(recorded - pred[:, :n_emg])
    # the squared error is *divided* by the weights
    return K.mean(tf.divide(squared_err, weights), axis=-1)
    


In [9]:
def LSTM_across_weighted(train_firing_dict: dict, train_EMG_orig: dict,
        train_var_dict: dict, test_firing_dict: dict, test_EMG_dict: dict, 
        EMG_name, plot_sig_flag:bool=False, plot_VAF_flag:bool=True,
        return_VAF_flag:bool = False):
    """Train an LSTM decoder with hybrid_weight_loss and score it per condition.

    Parameters
    ----------
    train_firing_dict, test_firing_dict : dict
        Condition name -> firing array handed to the LSTM; the second and
        third dimensions of the 'Combined' training entry set the network's
        input shape.
    train_EMG_orig : dict
        Condition name -> training EMG array (left unmodified).
    train_var_dict : dict
        Condition name -> array appended column-wise to the EMGs, where
        hybrid_weight_loss reads it as the per-sample weights.
    test_EMG_dict : dict
        Condition name -> test EMG array.
    EMG_name : sequence
        Muscle names; the network has 2*len(EMG_name) outputs, of which only
        the first len(EMG_name) are scored.
    plot_sig_flag, plot_VAF_flag : bool
        Toggle the plot_rec_pred / plot_VAFs figures.
    return_VAF_flag : bool
        If True, return (train_VAFs, test_VAFs); otherwise return None.

    The reported "VAF" values are sklearn r2_score values, one per muscle.
    The 'Combined' entry of the test dictionaries is not evaluated.
    """
    # hyper params
    n_lstm_units = 300
    input_dropout = .25      # input dropout fraction for the LSTM layer
    recurrent_dropout = 0    # recurrent dropout for the LSTM layer
    dense_dropout = .15      # rate of the Dropout layer after the LSTM (skipped if 0)

    n_EMGs = len(EMG_name)

    # append the variance to the EMG values, so the loss can see both
    train_EMG_dict = {
        cond: np.append(cond_EMG.copy(), train_var_dict[cond], axis=1)
        for cond, cond_EMG in train_EMG_orig.items()
    }

    # the "Combined" condition is what the network is trained on
    train_i = train_firing_dict['Combined']
    train_o = train_EMG_dict['Combined']
    seq_len = train_i.shape[1] # num of lags
    n_neurons = train_i.shape[2]

    # LSTM -> (optional) dropout -> linear read-out, as wide as the target (EMGs + weights)
    mdl = tf.keras.models.Sequential()
    lstm_layer = tf.keras.layers.LSTM(
        n_lstm_units,
        input_shape = (seq_len,n_neurons),
        dropout=input_dropout,
        recurrent_dropout=recurrent_dropout)
    mdl.add(lstm_layer)
    if dense_dropout:
        mdl.add(tf.keras.layers.Dropout(dense_dropout))
    mdl.add(tf.keras.layers.Dense(n_EMGs*2))
    mdl.compile(loss=hybrid_weight_loss, optimizer='rmsprop', metrics=['MeanAbsoluteError'])

    mdl.fit(train_i, train_o, epochs=50, verbose=False)

    # predict and score every training condition
    train_preds, train_VAFs = {}, {}
    for cond, cond_firing in train_firing_dict.items():
        cond_pred = mdl.predict(cond_firing)
        train_preds[cond] = cond_pred
        train_VAFs[cond] = metrics.r2_score(
            train_EMG_dict[cond][:,:n_EMGs], cond_pred[:,:n_EMGs], multioutput='raw_values')

    # same for the test conditions, except "Combined"
    test_preds, test_VAFs = {}, {}
    for cond, cond_firing in test_firing_dict.items():
        if cond == 'Combined':
            continue
        cond_pred = mdl.predict(cond_firing)
        test_preds[cond] = cond_pred
        test_VAFs[cond] = metrics.r2_score(
            test_EMG_dict[cond][:,:n_EMGs], cond_pred[:,:n_EMGs], multioutput='raw_values')

    # per-muscle report
    print('----------------------------------------')
    for i_muscle, muscle in enumerate(EMG_name):
        print(muscle)
        for cond, vafs in train_VAFs.items():
            print(f"\tTrain VAF for {cond}: {vafs[i_muscle]}")
        for cond, vafs in test_VAFs.items():
            print(f"\tTest VAF for {cond}: {vafs[i_muscle]}")

    if plot_sig_flag:
        plot_rec_pred(train_EMG_dict, train_preds, EMG_name, train_VAFs, title_append = 'training set')
        plot_rec_pred(test_EMG_dict, test_preds, EMG_name, test_VAFs, title_append = 'testing set')

    if plot_VAF_flag:
        plot_VAFs(train_VAFs, test_VAFs, EMG_name, sup_title = f'LSTM with weighting')

    if return_VAF_flag:
        return train_VAFs, test_VAFs

In [None]:
def LSTM_grid_search(firing, EMG, n_iter = 150, n_fold = 10, n_epochs = 20, unit_range = (100, 400), drop_in_range = (0,.5), drop_rec_range = (0,.5), drop_lay_range = (0,.5), seq_range=(5,20), plot=True):
    """Random (Monte-Carlo) search over LSTM hyper-parameters with k-fold cross-validation.

    Each iteration draws a random number of LSTM units, input / recurrent /
    layer dropout rates and a sequence length, then trains and scores one
    model per fold. Cross-validation is built in, so no separate training
    group is needed.

    Parameters
    ----------
    firing : pandas.DataFrame
        Firing rates, one column per neuron.
    EMG : pandas.DataFrame
        EMGs, one column per muscle, row-aligned with firing.
    n_iter : int
        Number of random hyper-parameter draws.
    n_fold : int
        Number of contiguous (unshuffled) cross-validation folds.
    n_epochs : int
        Training epochs per fold.
    unit_range, seq_range : (low, high)
        Integer ranges for the number of LSTM units and the sequence length;
        as with numpy's Generator.integers, high is exclusive.
    drop_in_range, drop_rec_range, drop_lay_range : (low, high)
        Uniform ranges for the input, recurrent and layer dropout rates.
    plot : bool
        Currently unused.

    Returns
    -------
    pandas.DataFrame
        One row per iteration: the fold-averaged train and test score of
        every muscle ("<muscle>_train_VAF" / "<muscle>_test_VAF", computed
        with sklearn's explained_variance_score) followed by the
        hyper-parameters that were drawn.
    """
    EMG_names = EMG.columns
    cols = [f"{name}_train_VAF" for name in EMG_names]
    cols += [f'{name}_test_VAF' for name in EMG_names]
    cols += ['n_units','drop_in','drop_rec','drop_lay','seq_len']
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='./logs/LSTM_exploration', histogram_freq=1)

    n_neurons = firing.shape[1] # number of neurons
    n_EMGs = EMG.shape[1]

    # working in chunks, not random indices
    kf = KFold(n_splits=n_fold, random_state=None, shuffle=False)

    # initialize the random number generator for the monte carlo
    rng = np.random.default_rng()

    # one dict per iteration; turned into a DataFrame once at the end instead of
    # re-concatenating a growing frame on every iteration
    log_entries = []

    for i_iter in range(n_iter):
        layer_0_units = rng.integers(unit_range[0],unit_range[1])
        drop_in = rng.uniform(drop_in_range[0], drop_in_range[1])
        drop_rec = rng.uniform(drop_rec_range[0], drop_rec_range[1])
        drop_lay = rng.uniform(drop_lay_range[0], drop_lay_range[1])
        seq_len = rng.integers(seq_range[0], seq_range[1])

        # the sequences depend on seq_len, so they are rebuilt for every draw
        # NOTE(review): shift(-ii) moves rows from *later* in the frame up to the current row,
        # so each sequence holds the current and the following samples -- confirm this is intended
        rnn_i = np.empty((firing.shape[0], seq_len, n_neurons))
        for ii in range(seq_len):
            rnn_i[:,ii,:] = firing.shift(-ii, fill_value=0).to_numpy()

        train_VAFs = np.zeros((n_fold, n_EMGs))
        test_VAFs = np.zeros((n_fold, n_EMGs))

        log_entry = {'n_units':layer_0_units, 'drop_in':drop_in, 'drop_rec':drop_rec,
                        'drop_lay':drop_lay, 'seq_len':seq_len}

        for fold_idx, (train_idx, test_idx) in enumerate(kf.split(firing)):
            # split training and testing inputs
            rnn_train_i = rnn_i[train_idx,:,:]
            rnn_test_i = rnn_i[test_idx,:,:]

            # normalize the EMGs by the std of the *training* fold; dividing into new
            # arrays gives a float result even if the EMG columns are stored as integers
            raw_train_o = EMG.iloc[train_idx,:].to_numpy()
            EMG_std = np.std(raw_train_o, axis=0)
            rnn_train_o = raw_train_o / EMG_std
            rnn_test_o = EMG.iloc[test_idx,:].to_numpy() / EMG_std

            # Set up the LSTM: LSTM -> (optional) dropout -> dense combination layer
            mdl = tf.keras.models.Sequential()
            mdl.add(tf.keras.layers.LSTM(layer_0_units, input_shape = (seq_len,n_neurons), dropout=drop_in, recurrent_dropout=drop_rec))
            if drop_lay:
                mdl.add(tf.keras.layers.Dropout(drop_lay)) # dropout layer if wanted
            mdl.add(tf.keras.layers.Dense(n_EMGs)) # dense combination layer
            # NOTE(review): 'accuracy' is not informative for a regression target; kept as in the original
            mdl.compile(loss='mse', optimizer='adam', metrics=['accuracy'])

            mdl.fit(rnn_train_i, rnn_train_o, epochs=n_epochs, verbose=False, callbacks=[tensorboard_callback])

            train_pred = mdl.predict(rnn_train_i, verbose=False)
            test_pred = mdl.predict(rnn_test_i, verbose=False)
            train_VAFs[fold_idx,:] = metrics.explained_variance_score(rnn_train_o, train_pred, multioutput='raw_values')
            test_VAFs[fold_idx,:] = metrics.explained_variance_score(rnn_test_o, test_pred, multioutput='raw_values')

        # store the fold-averaged VAFs in the log entry
        for emg_iter, emg_name in enumerate(EMG_names):
            log_entry[f"{emg_name}_train_VAF"] = np.mean(train_VAFs[:,emg_iter])
            log_entry[f"{emg_name}_test_VAF"] = np.mean(test_VAFs[:,emg_iter])

        log_entries.append(log_entry)
        print(f"Looped iteration {i_iter} of {n_iter}")

    # columns=cols fixes the column order (VAFs first, then hyper-parameters)
    return pd.DataFrame(log_entries, columns=cols)

In [10]:
def LSTM_rex(train_firing_dict: dict, train_EMG_orig: dict, train_oh_dict:dict,
        test_firing_dict: dict, test_EMG_dict: dict, EMG_name, beta=3,
        plot_sig_flag:bool=False, plot_VAF_flag:bool=True,
        return_VAF_flag:bool=False):
    """Train an LSTM decoder with a V-REx (risk extrapolation) loss and score it per condition.

    The loss is  beta * var(risk) + sum(risk),  where risk holds one value
    per condition: the mean squared error of the samples of that condition
    in the batch.

    Parameters
    ----------
    train_firing_dict, test_firing_dict : dict
        Condition name -> firing array handed to the LSTM; the second and
        third dimensions of the 'Combined' training entry set the network's
        input shape.
    train_EMG_orig : dict
        Condition name -> training EMG array (left unmodified).
    train_oh_dict : dict
        Condition name -> one-hot condition flags, one row per sample;
        appended column-wise to the EMGs so that the loss can read them.
    test_EMG_dict : dict
        Condition name -> test EMG array.
    EMG_name : sequence
        Muscle names; only the first len(EMG_name) outputs are scored.
    beta : float
        Weight of the variance-of-risks penalty.
    plot_sig_flag, plot_VAF_flag : bool
        Toggle the plot_rec_pred / plot_VAFs figures.
    return_VAF_flag : bool
        If True, return (train_VAFs, test_VAFs); otherwise return None.

    The reported "VAF" values are sklearn r2_score values, one per muscle.
    The 'Combined' entry of the test dictionaries is not evaluated.

    NOTE: the per-condition risk is a single number per condition (Cx1).
    Earlier revisions multiplied the one-hot matrix by the per-muscle squared
    error, giving a CxM matrix, so the variance term also penalised
    differences *between muscles*. Losses from the two versions are therefore
    not directly comparable, and a given beta has a different effective scale.
    """
    # hyper params
    layer_0_units = 300
    drop_in = .25     # input dropout percentage for LSTM layer
    drop_rec = 0    # recurrent dropout for LSTM
    drop_lay = .15    # rate of the Dropout layer after the LSTM

    n_EMGs = len(EMG_name)
    seq_len = train_firing_dict['Combined'].shape[1] # num of lags
    n_neurons = train_firing_dict['Combined'].shape[2]

    # append the one-hot condition flags to the EMG values (np.append returns a new array)
    train_EMG_dict = {
        cond: np.append(cond_EMG, train_oh_dict[cond], axis=1)
        for cond, cond_EMG in train_EMG_orig.items()
    }

    # pull out the "Combined" conditions for training
    train_i = train_firing_dict['Combined']
    train_o = train_EMG_dict['Combined']

    # define the vrex loss function here -- so that it closes over beta and n_EMGs
    def vrex_loss(target, pred):
        # from the Risk Extrapolation paper
        err = target[:,:n_EMGs] - pred[:,:n_EMGs] # without the condition flag
        se = K.square(err) # squared error -- TxM
        mse = K.mean(se, axis=-1, keepdims=True) # mean squared error for each sample -- Tx1

        # now pull in the one-hot matrix for flagging
        cond_oh = tf.transpose(target[:,n_EMGs:]) # transposed so the matmul sums over samples -- CxT

        # risk for each condition -- ie MSE for each condition -- Cx1
        summed_risk = tf.matmul(cond_oh, mse)
        n_per_cond = tf.reduce_sum(cond_oh, 1, keepdims=True) # samples of each condition in the batch
        # divide_no_nan: a condition with no samples in the batch gets a risk of 0 instead of 0/0 = NaN
        risk = tf.math.divide_no_nan(summed_risk, n_per_cond)

        return beta*K.var(risk) + K.sum(risk)

    # create the model: LSTM -> dropout -> linear read-out as wide as the target (EMGs + flags)
    mdl = tf.keras.models.Sequential()
    mdl.add(tf.keras.layers.LSTM(layer_0_units, input_shape = (seq_len,n_neurons), dropout=drop_in, recurrent_dropout=drop_rec))
    mdl.add(tf.keras.layers.Dropout(drop_lay))
    mdl.add(tf.keras.layers.Dense(train_o.shape[1]))

    mdl.compile(loss=vrex_loss, optimizer='rmsprop', metrics=['mse'])
    mdl.fit(train_i, train_o, epochs=50, verbose=0)

    # Get the predictions
    train_preds = {}
    train_VAFs = {}
    test_preds = {}
    test_VAFs = {}

    for cond, cond_firing in train_firing_dict.items():
        train_preds[cond] = mdl.predict(cond_firing)
        train_VAFs[cond] = metrics.r2_score(train_EMG_dict[cond][:,:n_EMGs], train_preds[cond][:,:n_EMGs], multioutput='raw_values')

    for cond, cond_firing in test_firing_dict.items():
        if cond != 'Combined': # don't really care about the "combined" test case
            test_preds[cond] = mdl.predict(cond_firing)
            test_VAFs[cond] = metrics.r2_score(test_EMG_dict[cond][:,:n_EMGs], test_preds[cond][:,:n_EMGs], multioutput='raw_values')

    print('----------------------------------------')
    for ii_name, name in enumerate(EMG_name):
        print(name)
        for in_name, in_vaf in train_VAFs.items():
            print(f"\tTrain VAF for {in_name}: {in_vaf[ii_name]}")
        for out_name,out_vaf in test_VAFs.items():
            print(f"\tTest VAF for {out_name}: {out_vaf[ii_name]}")

    if plot_sig_flag:
        plot_rec_pred(train_EMG_dict, train_preds, EMG_name, train_VAFs, title_append = 'training set')
        plot_rec_pred(test_EMG_dict, test_preds, EMG_name, test_VAFs, title_append = 'testing set')

    if plot_VAF_flag:
        plot_VAFs(train_VAFs, test_VAFs, EMG_name, sup_title=f'REx, Beta = {beta}')

    if return_VAF_flag:
        return train_VAFs, test_VAFs


In [11]:
def LSTM_rex_weighted(train_firing_dict: dict, train_EMG_orig: dict,
        train_var_dict:dict, train_oh_dict:dict, test_firing_dict: dict,
        test_EMG_dict: dict, EMG_name, beta=3,
        plot_sig_flag:bool=False, plot_VAF_flag:bool=True,
        return_VAF_flag:bool=True):
    """Train an LSTM decoder with a variance-weighted REx loss and report VAFs.

    The training targets are widened before fitting: for every condition the
    columns of ``train_var_dict[cond]`` and then ``train_oh_dict[cond]`` are
    appended to the EMG columns, so the custom loss can read the per-muscle
    weights and the condition flags straight out of ``target``.

    Loss (``vrex_weighted``), evaluated per batch:
      * squared error on the first ``len(EMG_name)`` columns,
      * divided element-wise by the next ``len(EMG_name)`` target columns
        (the appended variances) and averaged across muscles,
      * averaged per condition using the trailing one-hot columns,
      * combined as ``beta * var(risks) + sum(risks)``.

    Args:
        train_firing_dict: inputs per condition; must contain a 'Combined'
            entry whose shape[1] is used as the number of lags and shape[2]
            as the number of neurons.
        train_EMG_orig: EMG targets per condition (not modified).
        train_var_dict: per-condition columns appended after the EMGs; the
            loss divides the squared error by them.
        train_oh_dict: per-condition one-hot columns appended last.
        test_firing_dict: test inputs per condition.
        test_EMG_dict: test EMG targets per condition.
        EMG_name: muscle names; its length sets how many leading columns are EMGs.
        beta: weight on the variance-of-risks term.
        plot_sig_flag: if True, call ``plot_rec_pred`` for train and test sets.
        plot_VAF_flag: if True, call ``plot_VAFs``.
        return_VAF_flag: if True, return the VAF dictionaries.

    Returns:
        ``(train_VAFs, test_VAFs)`` -- dicts keyed by condition holding the
        per-muscle ``r2_score`` values -- when ``return_VAF_flag`` is True,
        otherwise ``None``.
    """
    # hyper params
    layer_0_units = 300
    drop_in = .25     # input dropout percentage for LSTM layer
    drop_rec = 0    # recurrent dropout for LSTM
    drop_lay = .15    # rate of the Dropout layer between the LSTM and the Dense output

    n_EMGs = len(EMG_name)
    seq_len = train_firing_dict['Combined'].shape[1] # num of lags
    n_neurons = train_firing_dict['Combined'].shape[2]
    n_cond = len(train_firing_dict) - 1 # to avoid counting "Combined"

    # append the variances and then the one-hot flags to the EMG values.
    # np.concatenate builds a new array, so the caller's arrays are untouched.
    train_EMG_dict = {
        cond: np.concatenate((emg, train_var_dict[cond], train_oh_dict[cond]), axis=1)
        for cond, emg in train_EMG_orig.items()
    }

    # pull out the "Combined" conditions for training
    train_i = train_firing_dict['Combined']
    train_o = train_EMG_dict['Combined']

    # define the loss function as a closure -- allows us to change beta etc dynamically
    def vrex_weighted(target, pred):
        err = target[:,:n_EMGs] - pred[:,:n_EMGs] # error on the EMG columns only
        se_musc = K.square(err) # square error per muscle per timepoint
        # pull out the variances and divide, then take the mean. Tx1
        se = K.mean(tf.divide(se_musc, target[:,n_EMGs:2*n_EMGs]), axis=1, keepdims=True)

        # pull out the one-hot matrix for flagging
        cond_oh = tf.transpose(target[:,-n_cond:]) # transpose -- CxT

        risk = tf.matmul(cond_oh, se) # summed weighted error per condition -- Cx1
        # mean per condition, to account for different numbers of samples
        # NOTE(review): a batch containing no sample of some condition gives 0/0 here -- confirm this cannot happen
        risks = tf.divide(risk, tf.reduce_sum(cond_oh, 1, keepdims=True))

        return beta*K.var(risks) + K.sum(risks)

    # create the model
    mdl = tf.keras.models.Sequential()

    # add the LSTM layer
    mdl.add(tf.keras.layers.LSTM(layer_0_units, input_shape = (seq_len,n_neurons), dropout=drop_in, recurrent_dropout=drop_rec))
    mdl.add(tf.keras.layers.Dropout(drop_lay))
    # one output per target column, so predictions line up with the widened targets
    mdl.add(tf.keras.layers.Dense(train_o.shape[1]))

    # metrics is documented as a list; a bare string is rejected by newer Keras releases
    mdl.compile(loss=vrex_weighted, optimizer='rmsprop', metrics=['mse'])

    mdl.fit(train_i, train_o, epochs=50, verbose=0)

    # Get the predictions; only the first n_EMGs columns are scored
    train_preds = {}
    train_VAFs = {}
    test_preds = {}
    test_VAFs = {}

    for cond in train_firing_dict:
        train_preds[cond] = mdl.predict(train_firing_dict[cond])
        train_VAFs[cond] = metrics.r2_score(train_EMG_dict[cond][:,:n_EMGs], train_preds[cond][:,:n_EMGs], multioutput='raw_values')

    for cond in test_firing_dict:
        if cond != 'Combined': # don't really care about the "combined" test case
            test_preds[cond] = mdl.predict(test_firing_dict[cond])
            test_VAFs[cond] = metrics.r2_score(test_EMG_dict[cond][:,:n_EMGs], test_preds[cond][:,:n_EMGs], multioutput='raw_values')

    print('----------------------------------------')
    for ii_name, name in enumerate(EMG_name):
        print(name)
        for in_name, in_vaf in train_VAFs.items():
            print(f"\tTrain VAF for {in_name}: {in_vaf[ii_name]}")
        for out_name,out_vaf in test_VAFs.items():
            print(f"\tTest VAF for {out_name}: {out_vaf[ii_name]}")

    if plot_sig_flag:
        plot_rec_pred(train_EMG_dict, train_preds, EMG_name, train_VAFs, title_append = 'training set')
        plot_rec_pred(test_EMG_dict, test_preds, EMG_name, test_VAFs, title_append = 'testing set')

    if plot_VAF_flag:
        plot_VAFs(train_VAFs, test_VAFs, EMG_name, sup_title=f'REx Weighted, Beta = {beta}')

    if return_VAF_flag:
        return train_VAFs, test_VAFs



### Compare different combinations of train/test sets

These cells let us quickly run through every train/test combination.

Set the nonlinearity type, to compare across iterations

In [None]:
# Nonlinearity setting handed to every basic_decoder_comparison call below.
# NOTE(review): 'poly' presumably selects a polynomial static nonlinearity -- confirm in LSTM_utils
nonlinear_type = 'poly'



First train wrist movement, test wrist movement

In [None]:
# Within-task baseline: train on WmTrain, test on WmTest.
train_timestamps, train_firing, train_EMG, train_force, train_kin = load_josh_mat(data['WmTrain'])
test_timestamps, test_firing, test_EMG, test_force, test_kin = load_josh_mat(data['WmTest'])
# using the training set twice as the "refit" gain -- for sanity's sake
mov_mov_linVAF = basic_decoder_comparison(train_firing, train_EMG, train_firing, train_EMG, test_firing, test_EMG, test_timestamps, nonlinear_type)
# NOTE(review): later cells call LSTM_comparisons with condition dicts and EMG_name -- confirm this raw-array call still matches its signature
mov_mov_LSTMVAF = LSTM_comparisons(train_firing, train_EMG, test_firing, test_EMG, test_timestamps)

# summarise each decoder by the mean of its returned values (presumably per-muscle VAFs -- confirm)
print(f"Mean Linear {np.mean(mov_mov_linVAF)}")
print(f"Mean LSTM {np.mean(mov_mov_LSTMVAF)}")

Train wrist movement, test iso

In [None]:
# Cross-task: train on WmTrain, test on IsoTest.
train_timestamps, train_firing, train_EMG, train_force, train_kin = load_josh_mat(data['WmTrain'])
# retrain_* come from the test task's own training file; presumably used to refit the linear decoder's gain -- confirm
_, retrain_firing, retrain_EMG, _, _, = load_josh_mat(data['IsoTrain'])
test_timestamps, test_firing, test_EMG, test_force, test_kin = load_josh_mat(data['IsoTest'])
mov_iso_linVAF = basic_decoder_comparison(train_firing, train_EMG, retrain_firing, retrain_EMG, test_firing, test_EMG, test_timestamps, nonlinear_type)
# the LSTM call receives no retrain data
mov_iso_LSTMVAF = LSTM_comparisons(train_firing, train_EMG, test_firing, test_EMG, test_timestamps)

# summarise each decoder by the mean of its returned values (presumably per-muscle VAFs -- confirm)
print(f"Mean Linear {np.mean(mov_iso_linVAF)}")
print(f"Mean LSTM {np.mean(mov_iso_LSTMVAF)}")

Train WM, test spring

In [None]:
# Cross-task: train on WmTrain, test on SprTest.
train_timestamps, train_firing, train_EMG, train_force, train_kin = load_josh_mat(data['WmTrain'])
# retrain_* come from the test task's own training file; presumably used to refit the linear decoder's gain -- confirm
_, retrain_firing, retrain_EMG, _, _, = load_josh_mat(data['SprTrain'])
test_timestamps, test_firing, test_EMG, test_force, test_kin = load_josh_mat(data['SprTest'])
mov_spr_linVAF = basic_decoder_comparison(train_firing, train_EMG, retrain_firing, retrain_EMG, test_firing, test_EMG, test_timestamps, nonlinear_type)
# the LSTM call receives no retrain data
mov_spr_LSTMVAF = LSTM_comparisons(train_firing, train_EMG, test_firing, test_EMG, test_timestamps)

# summarise each decoder by the mean of its returned values (presumably per-muscle VAFs -- confirm)
print(f"Mean Linear {np.mean(mov_spr_linVAF)}")
print(f"Mean LSTM {np.mean(mov_spr_LSTMVAF)}")

Train Iso, test Iso

In [None]:
# Within-task baseline: train on IsoTrain, test on IsoTest.
train_timestamps, train_firing, train_EMG, train_force, train_kin = load_josh_mat(data['IsoTrain'])
test_timestamps, test_firing, test_EMG, test_force, test_kin = load_josh_mat(data['IsoTest'])
# the training set is passed a second time in the retrain slots, since train and test task are the same
iso_iso_linVAF = basic_decoder_comparison(train_firing, train_EMG, train_firing, train_EMG, test_firing, test_EMG, test_timestamps, nonlinear_type)
iso_iso_LSTMVAF = LSTM_comparisons(train_firing, train_EMG, test_firing, test_EMG, test_timestamps)

# summarise each decoder by the mean of its returned values (presumably per-muscle VAFs -- confirm)
print(f"Mean Linear {np.mean(iso_iso_linVAF)}")
print(f"Mean LSTM {np.mean(iso_iso_LSTMVAF)}")

Train Iso, test WM

In [None]:
# Cross-task: train on IsoTrain, test on WmTest.
train_timestamps, train_firing, train_EMG, train_force, train_kin = load_josh_mat(data['IsoTrain'])
# retrain_* come from the test task's own training file; presumably used to refit the linear decoder's gain -- confirm
_, retrain_firing, retrain_EMG, _, _, = load_josh_mat(data['WmTrain'])
test_timestamps, test_firing, test_EMG, test_force, test_kin = load_josh_mat(data['WmTest'])
iso_mov_linVAF = basic_decoder_comparison(train_firing, train_EMG, retrain_firing, retrain_EMG, test_firing, test_EMG, test_timestamps, nonlinear_type)
iso_mov_LSTMVAF = LSTM_comparisons(train_firing, train_EMG, test_firing, test_EMG, test_timestamps)

# the LSTM line must report the LSTM result, not the linear one a second time
print(f"Mean Linear {np.mean(iso_mov_linVAF)}")
print(f"Mean LSTM {np.mean(iso_mov_LSTMVAF)}")

Train Iso, Test Spring

In [None]:
# Cross-task: train on IsoTrain, test on SprTest.
train_timestamps, train_firing, train_EMG, train_force, train_kin = load_josh_mat(data['IsoTrain'])
# retrain_* come from the test task's own training file; presumably used to refit the linear decoder's gain -- confirm
_, retrain_firing, retrain_EMG, _, _, = load_josh_mat(data['SprTrain'])
test_timestamps, test_firing, test_EMG, test_force, test_kin = load_josh_mat(data['SprTest'])
iso_spr_linVAF = basic_decoder_comparison(train_firing, train_EMG, retrain_firing, retrain_EMG, test_firing, test_EMG, test_timestamps, nonlinear_type)
iso_spr_LSTMVAF = LSTM_comparisons(train_firing, train_EMG, test_firing, test_EMG, test_timestamps)

# the LSTM line must report the LSTM result, not the linear one a second time
print(f"Mean Linear {np.mean(iso_spr_linVAF)}")
print(f"Mean LSTM {np.mean(iso_spr_LSTMVAF)}")

Train Spring, test each of the three

In [None]:
# Train once on SprTrain, then evaluate against each of the three test sets.
train_timestamps, train_firing, train_EMG, train_force, train_kin = load_josh_mat(data['SprTrain'])

# within-task case: the retrain set is the same SprTrain file that was just loaded above
print('Test Spring')
_, retrain_firing, retrain_EMG, _, _, = load_josh_mat(data['SprTrain'])
test_timestamps, test_firing, test_EMG, test_force, test_kin = load_josh_mat(data['SprTest'])
spr_spr_linVAF = basic_decoder_comparison(train_firing, train_EMG, retrain_firing, retrain_EMG, test_firing, test_EMG, test_timestamps, nonlinear_type)
spr_spr_LSTMVAF = LSTM_comparisons(train_firing, train_EMG, test_firing, test_EMG, test_timestamps)

# cross-task: retrain_* and test_* are overwritten with the isometric files
print('Test Iso')
_, retrain_firing, retrain_EMG, _, _, = load_josh_mat(data['IsoTrain'])
test_timestamps, test_firing, test_EMG, test_force, test_kin = load_josh_mat(data['IsoTest'])
spr_iso_linVAF = basic_decoder_comparison(train_firing, train_EMG, retrain_firing, retrain_EMG, test_firing, test_EMG, test_timestamps, nonlinear_type)
spr_iso_LSTMVAF = LSTM_comparisons(train_firing, train_EMG, test_firing, test_EMG, test_timestamps)

# cross-task: retrain_* and test_* are overwritten with the wrist-movement files
print('Test Movement')
_, retrain_firing, retrain_EMG, _, _, = load_josh_mat(data['WmTrain'])
test_timestamps, test_firing, test_EMG, test_force, test_kin = load_josh_mat(data['WmTest'])
spr_mov_linVAF = basic_decoder_comparison(train_firing, train_EMG, retrain_firing, retrain_EMG, test_firing, test_EMG, test_timestamps, nonlinear_type)
spr_mov_LSTMVAF = LSTM_comparisons(train_firing, train_EMG, test_firing, test_EMG, test_timestamps)

# NOTE(review): dntxt is not defined in the visible part of this notebook; presumably a notification helper -- confirm it is in scope
dntxt.send('Spring Linear and non weighted done')

train on hybrid, test on each

In [None]:
# Hybrid 3 -- this should be a blend of all three, I think for training
# Train once on the Hybrid3 set, then evaluate against each task in turn.
train_timestamps, train_firing, train_EMG, train_force, train_kin = load_josh_mat(data['Hybrid3'])

# (banner to print, key of the retrain set, key of the test set)
hybrid_test_sets = (
    ("test on movement", 'WmTrain', 'WmTest'),
    ("\n\ntest on Iso", 'IsoTrain', 'IsoTest'),
    ("\n\ntest on Spring", 'SprTrain', 'SprTest'),
)

for banner, retrain_key, test_key in hybrid_test_sets:
    print(banner)
    _, retrain_firing, retrain_EMG, _, _, = load_josh_mat(data[retrain_key])
    test_timestamps, test_firing, test_EMG, test_force, test_kin = load_josh_mat(data[test_key])
    # return values are not kept here; the functions are run for whatever they print/plot
    basic_decoder_comparison(train_firing, train_EMG, retrain_firing, retrain_EMG, test_firing, test_EMG, test_timestamps, nonlinear_type)
    LSTM_comparisons(train_firing, train_EMG, test_firing, test_EMG, test_timestamps)

dntxt.send('Hybrid basic testing complete')

In [None]:
# Hybrid 3 -- this should be a blend of all three, I think for training
# Stand-alone Hybrid3 -> Iso run: same files and calls as the "test on Iso" step of the previous cell.
train_timestamps, train_firing, train_EMG, train_force, train_kin = load_josh_mat(data['Hybrid3'])
_, retrain_firing, retrain_EMG, _, _, = load_josh_mat(data['IsoTrain'])
test_timestamps, test_firing, test_EMG, test_force, test_kin = load_josh_mat(data['IsoTest'])
basic_decoder_comparison(train_firing, train_EMG, retrain_firing, retrain_EMG, test_firing, test_EMG, test_timestamps, nonlinear_type)
# last expression of the cell, so its return value is what the notebook displays
LSTM_comparisons(train_firing, train_EMG, test_firing, test_EMG, test_timestamps)

### Compare Different NN Loss functions

Comparing each type of NN defined above:
1. Vanilla LSTM
1. Vanilla LSTM with "weighting"
1. LSTM with REx
1. LSTM with REx and "weighting"

Here we're defining weighting as the inverse of a muscle's variance for that condition. That means that the squared error for each muscle for each condition is multiplied by that value before the losses are summed etc. This is meant to account for different EMG magnitudes so that we are training to all conditions and muscles


First we'll test and train on iso, and set the Beta value for the Risk Extrapolation conditions to zero. This should make the plain REx equal to the vanilla LSTM, and the weighting REx equal to the vanilla weighting

In [None]:
# Single-condition sanity check: train and test on Iso only, REx beta fixed at 0.
train_firing = {'Iso': Iso_train_firing}
train_EMG = {'Iso': Iso_train_EMG}
test_firing = {'Iso': Iso_test_firing}
test_EMG = {'Iso': Iso_test_EMG}

ret_vals = LSTM_preprocess(train_firing, train_EMG, test_firing, test_EMG)

# the first eight entries of ret_vals, in the order LSTM_preprocess returns them
(train_firing_dict, train_EMG_dict, train_oh_dict, train_var_dict,
 test_firing_dict, test_EMG_dict, test_oh_dict, test_var_dict) = (
    ret_vals[idx] for idx in range(8))

# keyword arguments shared by the three custom-loss trainers
shared_kwargs = dict(
    train_firing_dict=train_firing_dict,
    train_EMG_orig=train_EMG_dict,
    test_firing_dict=test_firing_dict,
    test_EMG_dict=test_EMG_dict,
    EMG_name=EMG_name,
)

print('Vanilla LSTM')
LSTM_comparisons(train_firing_dict, train_EMG_dict, test_firing_dict, test_EMG_dict, EMG_name)

print('LSTM w/ REx, Beta=0')
LSTM_rex(train_oh_dict=train_oh_dict, beta=0, **shared_kwargs)

print('LSTM Weighted')
LSTM_across_weighted(train_var_dict=train_var_dict, **shared_kwargs)

print('LSTM w/ REx Weighted, Beta=0')
# kept as the final bare expression so the notebook still displays its return value
LSTM_rex_weighted(train_oh_dict=train_oh_dict, train_var_dict=train_var_dict,
    beta=0, **shared_kwargs)


Train on Movement and Spring, test on everything

Train on Movement and Isometric, test on everything

In [14]:
# Train on Iso + Mov, test on all three tasks, for every loss variant.
train_firing = {'Iso': Iso_train_firing, 'Mov': Mov_train_firing}
train_EMG = {'Iso': Iso_train_EMG, 'Mov': Mov_train_EMG}
test_firing = {'Iso': Iso_test_firing, 'Mov': Mov_test_firing, 'Spr': Spr_test_firing}
test_EMG = {'Iso': Iso_test_EMG, 'Mov':Mov_test_EMG, 'Spr':Spr_test_EMG}

ret_vals = LSTM_preprocess(train_firing, train_EMG, test_firing, test_EMG)

# the first eight entries of ret_vals, in the order LSTM_preprocess returns them
(train_firing_dict, train_EMG_dict, train_oh_dict, train_var_dict,
 test_firing_dict, test_EMG_dict, test_oh_dict, test_var_dict) = (
    ret_vals[idx] for idx in range(8))

line_split = '-' * 20

# keyword arguments shared by the three custom-loss trainers
shared_kwargs = dict(
    train_firing_dict=train_firing_dict,
    train_EMG_orig=train_EMG_dict,
    test_firing_dict=test_firing_dict,
    test_EMG_dict=test_EMG_dict,
    EMG_name=EMG_name,
    return_VAF_flag=True,
    plot_VAF_flag=False,
)

print(line_split)
print(line_split)
print('Vanilla LSTM')
van_train_VAF, van_test_VAF = LSTM_comparisons(
    train_firing_dict, train_EMG_dict, test_firing_dict, test_EMG_dict,
    EMG_name, return_VAF_flag=True, plot_VAF_flag=False)

print('\n')
print(line_split)
print('LSTM Weighted')
weight_train_VAF, weight_test_VAF = LSTM_across_weighted(
    train_var_dict=train_var_dict, **shared_kwargs)

# REx results are stored per beta value (the beta itself is the dict key)
rex_train_VAF, rex_test_VAF = {}, {}
wrex_train_VAF, wrex_test_VAF = {}, {}

# sweep beta over logspace(0, 1, 5) - 1, i.e. from 0 up to 9
for beta in np.logspace(0,1,5) - 1:
    print('\n')
    print(line_split)
    print(line_split)
    print(f'LSTM w/ REx, Beta={beta}')
    rex_train_VAF[beta], rex_test_VAF[beta] = LSTM_rex(
        train_oh_dict=train_oh_dict, beta=beta, **shared_kwargs)

    print('\n')
    print(line_split)
    print(f'LSTM w/ REx Weighted, Beta={beta}')
    wrex_train_VAF[beta], wrex_test_VAF[beta] = LSTM_rex_weighted(
        train_oh_dict=train_oh_dict, train_var_dict=train_var_dict,
        beta=beta, **shared_kwargs)



--------------------
--------------------
Vanilla LSTM
----------------------------------------
FCU
	Train VAF for Iso: 0.2955166201712843
	Train VAF for Mov: -0.121460232653688
	Train VAF for Combined: 0.1781815267459239
	Test VAF for Iso: 0.396220455789414
	Test VAF for Mov: -0.261803774145825
	Test VAF for Spr: -0.06607942177782156
FCR
	Train VAF for Iso: 0.5653002996825884
	Train VAF for Mov: 0.31259742828742254
	Train VAF for Combined: 0.48851708672709404
	Test VAF for Iso: 0.4528618798450986
	Test VAF for Mov: 0.21164712132998476
	Test VAF for Spr: 0.1332298546205043
ECU
	Train VAF for Iso: 0.5631036852004543
	Train VAF for Mov: 0.542739737984893
	Train VAF for Combined: 0.6284191770224433
	Test VAF for Iso: 0.3999265473505047
	Test VAF for Mov: 0.3693289973271976
	Test VAF for Spr: 0.24051933522594326
ECR
	Train VAF for Iso: 0.9086027231210612
	Train VAF for Mov: 0.5802531935298063
	Train VAF for Combined: 0.8977802690141967
	Test VAF for Iso: 0.8386881544202691
	Test VAF for Mo

In [21]:
# Plot test-set VAF as a function of Beta, one figure per (muscle, task) pair.
# The non-REx weighted LSTM has no Beta, so it is drawn as a horizontal
# threshold line across the Beta range.
# NOTE(review): the vanilla LSTM result (van_test_VAF) is not drawn here.
task_names = 'Iso','Mov','Spr'

Betas = np.logspace(0,1,5) - 1 # Beta values from above; also the keys of the REx result dicts
colors = plt.get_cmap('tab10').colors # not used below; kept in case other cells read it

# Betas starts at 0, which a plain log axis cannot place. A symlog axis is
# linear below linthresh and logarithmic above, so the Beta = 0 point stays visible.
min_pos_beta = min(b for b in Betas if b > 0)

# iterate for muscles and tasks
for i_musc, musc in enumerate(EMG_name):
    for task in task_names:
        fig, ax_vaf = plt.subplots()

        # pull out the VAFs of this muscle/task for all of the betas
        temp_rex = [rex_test_VAF[beta][task][i_musc] for beta in Betas]
        temp_wrex = [wrex_test_VAF[beta][task][i_musc] for beta in Betas]
        weighted_vaf = weight_test_VAF[task][i_musc]

        # plot it all
        ax_vaf.plot([min(Betas), max(Betas)], [weighted_vaf, weighted_vaf], label="Simple Weighted")
        ax_vaf.plot(Betas, temp_rex, label="REx")
        ax_vaf.plot(Betas, temp_wrex, label="weighted REx")

        ax_vaf.set_title(f"{musc}, {task}")
        ax_vaf.set_xlabel('Beta value')
        ax_vaf.set_ylabel('VAF')
        ax_vaf.set_xscale('symlog', linthresh=min_pos_beta)
        ax_vaf.legend()


In [17]:
# show the keys of the REx results dict (filled in the beta sweep with one entry per beta)
rex_test_VAF.keys()

dict_keys([0.0, 0.7782794100389228, 2.1622776601683795, 4.623413251903491, 9.0])

In [None]:

# Bundle every train/test VAF result into a single tuple; order matters when unpickling.
vaf_results = (
    van_train_VAF, weight_train_VAF, rex_train_VAF, wrex_train_VAF,
    van_test_VAF, weight_test_VAF, rex_test_VAF, wrex_test_VAF,
)

# Output goes next to the loaded .mat file: its path minus the extension,
# with the suffix appended directly (no separator in between).
file_save = ''.join([path.splitext(mat_fn)[0], 'train_IsoMove_VAFs.pkl'])

with open(file_save, mode='wb') as fid:
    pickle.dump(vaf_results, fid)


# plot the changes in testing error as a function of the Beta value
# compared with the original 


Train on everything, test on each

## LSTM grid search

Time to look at some non-linearities!

Let's run through the grid search on hyper parameters

In [None]:
# Hyper-parameter search on the wrist-movement training file only.
_, firing, EMG, _, _ = load_josh_mat(data['WmTrain'])

# NOTE(review): n_iter / n_fold are presumably the number of sampled settings and CV folds -- confirm in LSTM_grid_search
log = LSTM_grid_search(firing, EMG, n_iter = 30, n_fold = 4)

## LSTM with REx

Run through the options, see what comes out

## Scratch Space
Sometimes the variable viewer isn't very good

Getting bits of code to work

In [None]:
%tensorboard --logdir logs/LSTM_exploration

In [None]:
import tensorboard
%load_ext tensorboard

In [None]:
import pickle

In [None]:
# Dump the REx / weighted-REx results of the beta sweep to the working directory.
exportname = 'BetaComparisons.pkl'

# pickle writes bytes, so the file must be opened in binary mode, and
# pickle.dump needs the file object as its second argument
with open(exportname, 'wb') as fid:
    # order: REx train, REx test, weighted REx train, weighted REx test
    pickle.dump([rex_train_VAF, rex_test_VAF, wrex_train_VAF, wrex_test_VAF], fid)

In [None]:
# scratch: grab matplotlib's 'tab20' colormap to inspect it in the cells below
c = plt.get_cmap('tab20')

In [None]:
# scratch: bare expression so the notebook displays the colormap object
c

In [None]:
# scratch: display the colormap's list of colors
c.colors