## Encoder Estimation Verification Figure

### Author: Maneeshika Madduri

### Goal
Using the encoder model described below, this code runs the following tests:

(1) Compares the $r^2$ of the encoder estimate to that of an encoder estimated from time-shuffled EMG data


(2) Reconstructs the cursor position and velocity from the encoder model and checks the $r^2$ of the reconstruction

### Encoder Model

$$ EMG = W*u $$

$$ u = 
\begin{bmatrix}
r_x \\ r_y \\ \dot{r_x} \\ \dot{r_y}  \\ r_x - p_x \\ r_y - p_y \\ \dot{r_x} - \dot{p_x}   \\ \dot{r_y} - \dot{p_y} \\ 1
\end{bmatrix}
$$

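where $u = \begin{bmatrix} FF \\ FB \\ b \end{bmatrix}$ stacks the three blocks below (in this notation $\dot{r}$ is written $v_{int}$ and $\dot{p}$ is written $v_{dec}$):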
$$
FF = \begin{bmatrix}
r_x \\ r_y \\ v_{int, x} \\ v_{int, y} 
\end{bmatrix}
,
FB = \begin{bmatrix}
r_x - p_x \\ r_y - p_y \\ v_{int,x} - v_{dec,x}  \\ v_{int,y} - v_{dec,y} 
\end{bmatrix}
,
b = \begin{bmatrix}
 1
\end{bmatrix}
$$

So $W \in \mathbb{R}^{64 \times 9}$: one row per EMG channel (64) and one column per entry of $u$ (9).

In [None]:
import pickle
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import wilcoxon as wilcoxon


# meta analysis functions
import sys
sys.path.append('/code/')
from util import analysis
from util import plotting
from util import util_continuous as utils


In [None]:
# Root directory of the processed input data; the pickles below are read relative to it.
PATH = "/data/"

In [None]:
def _load_pickle(relative_path):
    """Open ``PATH + relative_path`` in binary mode and return the unpickled object."""
    with open(PATH + relative_path, 'rb') as fh:
        return pickle.load(fh)


# decoder-update indices and times (per the file name)
update_ix, update_times = _load_pickle('trial-related-data/times-decoder-update.pickle')
# order of the trials in terms of when the subject saw the trial
trials_list, conds_trial_list = _load_pickle('trial-related-data/decoder_trials_in_order.pickle')
# r^2 of the encoder reconstructions and of the time-shuffled controls
r2_recon_vel, r2_recon_pos, r2_rand_vel, r2_rand_pos = _load_pickle(
    'encoder-estimation-data/encoder-recon-r2-data.pickle')
# one example subject / trial used for the sample traces
(ex_subject, ex_cond, ex_block, vel_rec, pos_rec, int_vel_rec, vel_rand, pos_rand,
 int_vel_rand, vel_dec_test, pos_act_test, ref_act_test, emg_orig, emg_rec, emg_rand2,
 D_test, ex_time) = _load_pickle('encoder-estimation-data/encoder-recon-sample-subject.pickle')

In [None]:
# Shape check for the encoder tensor (disabled; `encoder` is not loaded in this notebook):
# assert(encoder.shape == (utils.n_blocks, utils.n_keys, utils.n_conds, len(update_ix) - 2, utils.n_ch, 9))

# Time axis for the sample trial: one evenly spaced point per sample of
# pos_act_test, running from 0 to 5 (plotted against a "Time (minutes)" label).
n_samples = len(pos_act_test)
time_x = np.linspace(start=0, stop=5, num=n_samples)

Set up the figure.

In [None]:
# Seaborn styling applied to every figure below: "ticks" style, the project's
# custom rc parameters, and fonts scaled to 60%.
import seaborn as sns

sns.set_theme(rc=utils.sns_custom_params, style="ticks", font_scale=0.6)


In [None]:
label_size = 6
## SET UP THE FIGURE HERE
## Re-run from this cell to "clear" the plot: it rebuilds fig_enc and ax_dict from scratch.
fig_enc = plt.figure(figsize = (6.3, 4), layout='constrained')  # total figure size (inches)
mosaic = """
    aaab.
    cccd.
    """
# Panels as used by the cells below: 'a' = sample position trace,
# 'b' = position r^2 boxplot, 'c' = sample velocity trace,
# 'd' = velocity r^2 boxplot; '.' leaves an empty slot.
ax_dict = fig_enc.subplot_mosaic(mosaic)
for panel in ax_dict.values():
    plotting.remove_and_set_axes(panel, bottom=True, left=True)
    for tick_group in ('major', 'minor'):
        panel.tick_params(axis='both', which=tick_group, labelsize=label_size)

# y-axis label for each panel
panel_ylabels = {
    'a': "Cursor Position",
    'b': "$r^2$ of Cursor Positions",
    'c': "Cursor Velocity",
    'd': "$r^2$ of Cursor Velocities",
}
for key, text in panel_ylabels.items():
    ax_dict[key].set_ylabel(text)

ax_dict['a'].set_title("Sample Cursor Position/Velocity Reconstruction from Encoder Estimates")
# only the lower trace panel carries the time label
ax_dict['c'].set_xlabel("Time (minutes)")

# panel (a) sits directly above (c), so hide its x-axis
ax_dict['a'].spines["bottom"].set_visible(False)
ax_dict['a'].set_xticks([])

fig_enc.patch.set_facecolor('white')


Add the sample reconstructed trial to panels (a) and (c).

In [None]:
# Rebuild the time axis (identical to the earlier one) so this section can be re-run on its own:
# one evenly spaced point per sample of pos_act_test, from 0 to 5.
n_samples = len(pos_act_test)
time_x = np.linspace(start=0, stop=5, num=n_samples)


In [None]:
fig_ax = ax_dict['a']

print("sample subject, trial:", ex_subject, ex_block, ex_cond)

# Column 0 (x-component) of the actual, reconstructed and shuffled positions.
# Each entry: (series, optional plot format string, remaining line properties).
for series, fmt, style in (
    (pos_act_test, (), dict(label='actual', color=utils.colors['cursor'])),
    (pos_rec, ('--',), dict(label='reconstruction', color='black')),
    (pos_rand, ('--',), dict(label='shuffled', color='dimgray', alpha=0.3, zorder=1)),
):
    fig_ax.plot(time_x, series[:, 0], *fmt, lw=1, **style)

fig_enc

In [None]:
# Panel (c): column 0 (x-component) of the actual velocity (vel_dec_test),
# its encoder reconstruction, and the reconstruction from the shuffled control.
fig_ax = ax_dict['c']

fig_ax.plot(time_x, vel_dec_test[:, 0], lw=1, label='actual', color=utils.colors['cursor'])
fig_ax.plot(time_x, vel_rec[:, 0], '--', lw=1, label='reconstruction', color='black', alpha=0.5)
# Column 0 as well: the shuffled trace must show the same (x) component as the
# traces it is compared against (previously column 1 was plotted here).
fig_ax.plot(time_x, vel_rand[:, 0], '--', lw=1, label='shuffled', color='dimgray', alpha=0.5, zorder=1)
fig_ax.set_ylim(-50, 50)

# To align the time axes of panels (a) and (c), share the x-axis, e.g.:
# ax_dict['c'].sharex(ax_dict['a'])

fig_enc


In [None]:
## WILCOXON
# Panel (b): position r^2 of the reconstruction vs. the time-shuffled control,
# compared with a paired Wilcoxon signed-rank test.
fig_ax = ax_dict['b']

# Average r^2 over axes 0, 2 and 3, then flatten whatever remains.
data1, data2 = (np.mean(r2, axis=(0, 2, 3)).flatten() for r2 in (r2_recon_pos, r2_rand_pos))
data_groups = [data1, data2]
data_labels = ['Reconstructed', 'Random']
data_pos = [0, 0.4]

box_style = dict(edgecolor="white", alpha=0.5)
median_style = dict(color='k', lw=0.1)
bplot = fig_ax.boxplot(data_groups, positions=data_pos, widths=0.3,
                       patch_artist=True, showfliers=False,
                       boxprops=box_style, medianprops=median_style)

# dark fill for the reconstruction, grey for the shuffled control
for box, face in zip(bplot['boxes'], ('black', 'dimgray')):
    box.set_facecolor(face)

fig_ax.set_xticks(data_pos, data_labels, rotation=40)  # slanted tick labels

w = wilcoxon(data1, data2)  # paired: data1[i] is compared with data2[i]
print(w)

plotting.plot_significance(pvalue=w.pvalue, data1=data1, data2=data2, data_pos=data_pos,
                           fig=fig_enc, ax=fig_ax, fontsize=10, lw=0.5, y_asterix=0.6, y_bar=0.5)

fig_enc

In [None]:
## WILCOXON
# Panel (d): velocity r^2 of the reconstruction vs. the time-shuffled control,
# compared with a paired Wilcoxon signed-rank test.
fig_ax = ax_dict['d']

# Average r^2 over axes 0, 2 and 3, then flatten whatever remains.
data1, data2 = (np.mean(r2, axis=(0, 2, 3)).flatten() for r2 in (r2_recon_vel, r2_rand_vel))
data_groups = [data1, data2]
data_labels = ['Reconstructed', 'Random']
data_pos = [0, 0.4]

box_style = dict(edgecolor="white", alpha=0.5)
median_style = dict(color='k', lw=0.1)
bplot = fig_ax.boxplot(data_groups, positions=data_pos, widths=0.3,
                       patch_artist=True, showfliers=False,
                       boxprops=box_style, medianprops=median_style)

# dark fill for the reconstruction, grey for the shuffled control
for box, face in zip(bplot['boxes'], ('black', 'dimgray')):
    box.set_facecolor(face)

fig_ax.set_xticks(data_pos, data_labels, rotation=40)  # slanted tick labels

w = wilcoxon(data1, data2)  # paired: data1[i] is compared with data2[i]
print(w)

plotting.plot_significance(pvalue=w.pvalue, data1=data1, data2=data2, data_pos=data_pos,
                           fig=fig_enc, ax=fig_ax, fontsize=10, lw=0.5, y_asterix=0.6, y_bar=0.5)

fig_enc

In [None]:
# image_format = 'pdf' # e.g .png, .svg, etc.
# image_name = 'fig2-encoder-r2-mean.pdf'
# PATH = '/Users/mmadduri/OneDrive - UW/PhD_Research/Figures/myo-coadapt-2023/nov2023/fig2/python-figs/'
# fig_enc.savefig(PATH + image_name, format=image_format, dpi=300)

In [None]:
## a zoomed-in version

In [None]:
label_size = 6
## SET UP THE ZOOMED-IN FIGURE HERE
## Re-run from this cell to "clear" the plot: it rebuilds fig_enc and ax_dict from scratch.
fig_enc = plt.figure(figsize = (5, 1.3), layout='constrained')  # total figure size (inches)
mosaic = """
    aabccd
    """
# One row, as used by the cells below: 'a' = zoomed position trace,
# 'b' = position r^2 boxplot, 'c' = zoomed velocity trace, 'd' = velocity r^2 boxplot.
ax_dict = fig_enc.subplot_mosaic(mosaic)
for panel in ax_dict.values():
    plotting.remove_and_set_axes(panel, bottom=True, left=True)
    for tick_group in ('major', 'minor'):
        panel.tick_params(axis='both', which=tick_group, labelsize=label_size)

# trace panels get y-labels; the boxplot panels get titles instead
trace_ylabels = {'a': "Cursor Position", 'c': "Cursor Velocity"}
boxplot_titles = {'b': "$r^2$ of Cursor Positions", 'd': "$r^2$ of Cursor Velocities"}
for key, text in trace_ylabels.items():
    ax_dict[key].set_ylabel(text)
for key, text in boxplot_titles.items():
    ax_dict[key].set_title(text)
ax_dict['c'].set_xlabel("Time (minutes)")

fig_enc.patch.set_facecolor('white')


In [None]:
label_size = 6
## SET UP THE r^2 SUMMARY FIGURE HERE
## Re-run from this cell to "clear" the plot: it rebuilds fig_r2 and ax_r2 from scratch.
fig_r2 = plt.figure(figsize = (5, 1.3), layout='constrained')  # total figure size (inches)
mosaic = """
    abcd
    """
# Labels set below: 'a' = r^2 of cursor position, 'b' = distribution of
# position r^2, 'c' = r^2 of cursor velocity, 'd' = distribution of velocity r^2.
ax_r2 = fig_r2.subplot_mosaic(mosaic)
for ii in ax_r2:
    plotting.remove_and_set_axes(ax_r2[ii], bottom=True, left=True)
    ax_r2[ii].tick_params(axis='both', which='major', labelsize = label_size)
    ax_r2[ii].tick_params(axis='both', which='minor', labelsize = label_size)

ax_r2['a'].set_ylabel("$r^2$ of Cursor Position")
ax_r2['b'].set_title("Distribution $r^2$ of Cursor Positions")
ax_r2['c'].set_ylabel("$r^2$ of Cursor Velocities")
# typo fixed in the rendered title: "Velocites" -> "Velocities"
ax_r2['d'].set_title("Distribution $r^2$ of Cursor Velocities")

fig_r2.patch.set_facecolor('white')

In [None]:
# Zoomed-in sample traces: samples t0 .. t1-1 of the example trial,
# column 0 (x-component) only, each scaled by utils.x_cm_to_au.
t0 = 17000
t1 = 22000
window = slice(t0, t1)
t_win = time_x[window]
scale = utils.x_cm_to_au

# Panel (a): position
fig_ax = ax_dict['a']
fig_ax.plot(t_win, pos_act_test[window, 0] * scale, lw=1, label='actual',
            color=utils.colors['cursor'])
fig_ax.plot(t_win, pos_rec[window, 0] * scale, '--', lw=1, label='reconstruction',
            color=utils.colors['E'], zorder=10)
fig_ax.plot(t_win, pos_rand[window, 0] * scale, '--', lw=1, label='shuffled',
            color='dimgray', alpha=0.3, zorder=1)
fig_ax.set_ylim(-25, 25)

# Panel (c): velocity
fig_ax = ax_dict['c']
fig_ax.plot(t_win, vel_dec_test[window, 0] * scale, lw=1, label='actual',
            color=utils.colors['cursor'])
fig_ax.plot(t_win, vel_rec[window, 0] * scale, '--', lw=1, label='reconstruction',
            color=utils.colors['E'], alpha=0.5, zorder=10)
fig_ax.plot(t_win, vel_rand[window, 0] * scale, '--', lw=1, label='shuffled',
            color='dimgray', alpha=0.3, zorder=1)
fig_ax.set_ylim(-25, 25)

# To align the time axes of panels (a) and (c): ax_dict['c'].sharex(ax_dict['a'])

fig_enc



In [None]:
## WILCOXON
# Zoomed figure, panel (b): position r^2 of the reconstruction vs. the shuffled
# control, compared with a paired Wilcoxon signed-rank test.
fig_ax = ax_dict['b']

# Average r^2 over axes 0, 2 and 3, then flatten whatever remains.
data1, data2 = (np.mean(r2, axis=(0, 2, 3)).flatten() for r2 in (r2_recon_pos, r2_rand_pos))
data_groups = [data1, data2]
data_labels = ['Reconstructed', 'Random']
data_pos = [0, 0.4]

box_style = dict(edgecolor="white", alpha=0.5)
median_style = dict(color='k', lw=0.1)
bplot = fig_ax.boxplot(data_groups, positions=data_pos, widths=0.3,
                       patch_artist=True, showfliers=False,
                       boxprops=box_style, medianprops=median_style)

# encoder colour for the reconstruction, grey for the shuffled control
for box, face in zip(bplot['boxes'], (utils.colors['E'], 'dimgray')):
    box.set_facecolor(face)

fig_ax.set_xticks(data_pos, data_labels, rotation=0)  # horizontal tick labels

w = wilcoxon(data1, data2)  # paired: data1[i] is compared with data2[i]
print(w)

# fix the y-range before the significance marker is drawn
fig_ax.set_ylim(-2, 1)

plotting.plot_significance(pvalue=w.pvalue, data1=data1, data2=data2, data_pos=data_pos,
                           fig=fig_enc, ax=fig_ax, fontsize=10, lw=0.5, y_asterix=0.6, y_bar=0.5)

fig_enc

In [None]:
## WILCOXON
# fig_r2 panel (a): position r^2 of the reconstruction vs. the shuffled control,
# compared with a paired Wilcoxon signed-rank test.
fig_ax = ax_r2['a']

# Average r^2 over axes 0, 2 and 3, then flatten whatever remains.
data1 = np.mean(r2_recon_pos, axis = (0, 2, 3)).flatten()
data2 = np.mean(r2_rand_pos, axis = (0, 2, 3)).flatten()
print("recont r^2 pos: ", np.mean(r2_recon_pos))
print("rand r^2 pos: ", np.mean(r2_rand_pos))

data_groups = [data1, data2]
data_labels = ['Reconstructed', 'Random']
data_pos = [0, 0.4]


bplot = fig_ax.boxplot(data_groups,
                    showfliers=False,
                    patch_artist=True,
                    positions=data_pos,
                    widths = 0.3,
                    boxprops=dict(edgecolor="white", alpha=0.5),
                    medianprops=dict(color='k', lw=0.1))


for patch, color in zip(bplot['boxes'], [utils.colors['E'], 'dimgray']):
    patch.set_facecolor(color)

fig_ax.set_xticks(data_pos, data_labels, rotation=0)
w = wilcoxon(data1, data2)
print(w)
fig_ax.set_ylim(-2, 1)

# fig_ax lives in fig_r2, so pass fig_r2 (previously fig_enc, a different figure)
plotting.plot_significance(pvalue = w.pvalue, data1=data1, data2 = data2, data_pos = data_pos,
                  fig=fig_r2, ax=fig_ax, fontsize=10, lw=0.5, y_asterix=0.6, y_bar=0.5)

fig_r2

In [None]:
## WILCOXON
# Zoomed figure, panel (d): velocity r^2 of the reconstruction vs. the shuffled
# control, compared with a paired Wilcoxon signed-rank test.
fig_ax = ax_dict['d']

# Average r^2 over axes 0, 2 and 3, then flatten whatever remains.
data1, data2 = (np.mean(r2, axis=(0, 2, 3)).flatten() for r2 in (r2_recon_vel, r2_rand_vel))

# grand means over every entry of each r^2 array
for prefix, r2 in (("recont r^2 vel: ", r2_recon_vel), ("rand r^2 vel: ", r2_rand_vel)):
    print(prefix, np.mean(r2))

data_groups = [data1, data2]
data_labels = ['Reconstructed', 'Random']
data_pos = [0, 0.4]

box_style = dict(edgecolor="white", alpha=0.5)
median_style = dict(color='k', lw=0.1)
bplot = fig_ax.boxplot(data_groups, positions=data_pos, widths=0.3,
                       patch_artist=True, showfliers=False,
                       boxprops=box_style, medianprops=median_style)

# encoder colour for the reconstruction, grey for the shuffled control
for box, face in zip(bplot['boxes'], (utils.colors['E'], 'dimgray')):
    box.set_facecolor(face)

fig_ax.set_xticks(data_pos, data_labels, rotation=20)  # slightly slanted tick labels

w = wilcoxon(data1, data2)  # paired: data1[i] is compared with data2[i]
print(w)

# fix the y-range before the significance marker is drawn
fig_ax.set_ylim(-2, 1)

plotting.plot_significance(pvalue=w.pvalue, data1=data1, data2=data2, data_pos=data_pos,
                           fig=fig_enc, ax=fig_ax, fontsize=10, lw=0.5, y_asterix=0.6, y_bar=0.5)

fig_enc

In [None]:
# Save the zoomed-in encoder figure.
# The output directory goes in its own variable. Do not reassign PATH, which is
# the input-data root: overwriting it would make a re-run of the data-loading
# cell read its pickles from /results/.
image_format = 'pdf'  # e.g. png, svg, ...
image_name = 'fig2-encoder-zoom-mean-cm.pdf'
results_path = '/results/'
fig_enc.savefig(results_path + image_name, format=image_format, dpi=300)

In [None]:
## WILCOXON
# fig_r2 panel (c): velocity r^2 of the reconstruction vs. the shuffled control,
# compared with a paired Wilcoxon signed-rank test.
fig_ax = ax_r2['c']

# Average r^2 over axes 0, 2 and 3, then flatten whatever remains.
data1 = np.mean(r2_recon_vel, axis = (0, 2, 3)).flatten()
data2 = np.mean(r2_rand_vel, axis = (0, 2, 3)).flatten()
data_groups = [data1, data2]
data_labels = ['Reconstructed', 'Random']
data_pos = [0, 0.4]


bplot = fig_ax.boxplot(data_groups,
                    showfliers=False,
                    patch_artist=True,
                    positions=data_pos,
                    widths = 0.3,
                    boxprops=dict(edgecolor="white", alpha=0.5),
                    medianprops=dict(color='k', lw=0.1))


for patch, color in zip(bplot['boxes'], [utils.colors['E'], 'dimgray']):
    patch.set_facecolor(color)

fig_ax.set_xticks(data_pos, data_labels, rotation=20)
w = wilcoxon(data1, data2)
print(w)

# Same y-range as the position panel ax_r2['a'], so the two panels are comparable.
fig_ax.set_ylim(-2, 1)

# fig_ax lives in fig_r2, so pass fig_r2 (previously fig_enc, a different figure)
plotting.plot_significance(pvalue = w.pvalue, data1=data1, data2 = data2, data_pos = data_pos,
                  fig=fig_r2, ax=fig_ax, fontsize=10, lw=0.5, y_asterix=0.6, y_bar=0.5)

fig_r2

In [None]:
# Histograms of every reconstruction r^2 value (not just the per-group means),
# with a vertical line at each distribution's mean; mean and std are printed.
# Position goes in panel 'b' and velocity in panel 'd', matching those panels'
# "Distribution $r^2$ ..." titles.
# NOTE(review): the line height 150 is a fixed constant and may not match the
# tallest histogram bar.
pos_r2_all = r2_recon_pos.flatten()
fig_ax = ax_r2['b']
fig_ax.hist(pos_r2_all, color = utils.colors['E'], alpha = 0.2)
fig_ax.vlines(np.mean(pos_r2_all), 0, 150, color = 'gray')
print(np.mean(pos_r2_all))
print(np.std(pos_r2_all))

vel_r2_all = r2_recon_vel.flatten()
# Velocity belongs in panel 'd'. It was previously drawn on 'b', which overlaid
# it on the position histogram and left 'd' empty.
fig_ax = ax_r2['d']
fig_ax.hist(vel_r2_all, color = utils.colors['E'], alpha = 0.2)
fig_ax.vlines(np.mean(vel_r2_all), 0, 150, color = 'black')
print(np.mean(vel_r2_all))
print(np.std(vel_r2_all))

fig_r2

In [None]:
# Show (as the cell output) the colour this notebook uses for the encoder-reconstruction traces and boxes.
utils.colors['E']

In [None]:
# image_format = 'pdf' # e.g .png, .svg, etc.
# image_name = 'fig2-encoder-zoom-mean.pdf'
# PATH = '/Users/mmadduri/OneDrive - UW/PhD_Research/Figures/myo-coadapt-2023/nov2023/fig2/python-figs/'
# fig_ax.savefig(PATH + image_name, format=image_format, dpi=300)