Fig 5 - comparing penalty parameters

Note to Maneeshika: as of the May 2024 draft, the boxplots are computed across the entire trial.

In [None]:
import pickle
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import wilcoxon as wilcoxon
from scipy.stats import ttest_rel
import scipy


# meta analysis functions
import sys
sys.path.append('/code/')
from util import analysis
from util import plotting
from util import util_continuous as utils

In [None]:
# Root directory of the pickled input data used by the loading cell.
PATH = "/data/"

In [None]:
# Load the pre-computed inputs for this figure.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.

# time-domain task error (the file name says it is computed in cm)
with open(PATH + 'time-domain-error/time-domain-error-30sec-in-cm.pkl', 'rb') as handle:
    (td_error, td_error_first, td_error_last,
     t0_start, t0_end, t1_end,
     td_diff, td_diff_slow, td_diff_fast,
     td_diff_pos, td_diff_neg,
     td_diff_pD3, td_diff_pD4) = pickle.load(handle)

# decoded and intended velocities, stored per block, plus the condition list
with open(PATH + 'trial-related-data/decoded-intended-vels.pickle', 'rb') as handle:
    (dec_vels_block1, dec_vels_block2,
     int_vel_block1, int_vel_block2,
     conds) = pickle.load(handle)

# encoder estimates, stacked decoded velocities and decoders
with open(PATH + 'encoder-estimation-data/encoder-decoder-data.pickle', 'rb') as handle:
    (encoder, encoder_r2, idx_dict, pos_vel_model,
     pos, dec_vels, decoders) = pickle.load(handle)

# subject identifiers (14 subjects)
keys = [
    'METACPHS_S106', 'METACPHS_S107', 'METACPHS_S108', 'METACPHS_S109',
    'METACPHS_S110', 'METACPHS_S111', 'METACPHS_S112', 'METACPHS_S113',
    'METACPHS_S114', 'METACPHS_S115', 'METACPHS_S116', 'METACPHS_S117',
    'METACPHS_S118', 'METACPHS_S119',
]


In [None]:
# Sanity check: td_error must be indexed as (block, subject, condition, time sample).
assert td_error.shape == (
    utils.n_blocks,
    utils.n_keys,
    utils.n_conds,
    utils.min_time,
)

In [None]:
# Configure the global seaborn/matplotlib theme used by every panel below.
import seaborn as sns

sns.set_theme(
    style="ticks",
    rc=utils.sns_custom_params,
    font_scale=0.6,
)


In [None]:
import matplotlib.ticker as ticker

label_size = 6
## SETUP THE FIGURE HERE
## Re-run from this cell to "clear" the plot: the panel cells below draw onto
## the axes created here.
fig_penalty = plt.figure(figsize = (8, 5), layout='constrained') # total figure size (inches)
# alternative layout tried earlier:
# mosaic = """
#     aabbc.
#     ddefgg
#     """

# Row 1: b (|D| over time), c (|D| high vs low), d (|E| over time), e (|E| high vs low)
# Row 2: a (% change in error), f (cursor speed over time), g (|v| high vs low), h (|E| vs |v|)
mosaic = """
    bbcdde
    affghh
    """


# alternative layout tried earlier:
# mosaic = """
#     aabccd
#     effghh
#     """

# set up the axes with shared styling and x-tick spacing
ax_dict = fig_penalty.subplot_mosaic(mosaic)
for ii in ax_dict:
    plotting.remove_and_set_axes(ax_dict[ii], bottom=True, left=True)
    # one call sets the label size for both major and minor ticks
    ax_dict[ii].tick_params(axis='both', which='both', labelsize = label_size)
    ax_dict[ii].xaxis.set_major_locator(ticker.MultipleLocator(base=5))
    ax_dict[ii].xaxis.set_minor_locator(ticker.MultipleLocator(base=1))
fig_penalty.patch.set_facecolor('white')


# a - time-domain error / task performance
ax_dict['a'].set_ylabel('% Change in Error')


# b - decoder norm over time
ax_dict['b'].set_title("|D|")
ax_dict['b'].set_ylabel('$|D|_F$')
ax_dict['b'].set_xlabel('Time (min)')

# c - decoder norm, high vs low penalty (significance)
ax_dict['c'].set_ylabel('$|D|_F$')


# d - encoder norm over time
ax_dict['d'].set_ylabel('$|E|_F$')
ax_dict['d'].set_xlabel('Time (min)')

# e - encoder norm, high vs low penalty (significance)
ax_dict['e'].set_ylabel('$|E|_F$')

# f - cursor speed over time
ax_dict['f'].set_ylabel('cursor speed (cm/s)')

# g - cursor speed, high vs low penalty (significance)
ax_dict['g'].set_ylabel('$|v|_2$')

# h - encoder norm vs cursor speed
ax_dict['h'].set_ylabel('$|v|$')
ax_dict['h'].set_xlabel('$|E|$')

# NOTE(review): with layout='constrained', Matplotlib >= 3.6 does not apply
# subplots_adjust (it warns and returns); spacing is controlled by the
# constrained-layout engine instead. Kept for older Matplotlib versions.
plt.subplots_adjust(wspace=.7, hspace=0.5)
# fig_penalty.suptitle("$\lambda high = D low = v low = E high$ \n $\lambda low = D high = v high = E low $")


In [None]:
# The late window (last t1_end samples) must be as long as the early window
# t0_start:t0_end, so that early and late summaries are comparable.
assert td_error[:, :, utils.pD_3, -t1_end:].shape == (
    utils.n_blocks,
    utils.n_keys,
    len(utils.pD_3),
    t0_end - t0_start,
)

In [None]:
# inspect the shape of the per-subject high-penalty error change
np.shape(td_diff_pD3)

In [None]:
## a - no difference in performance between high (pD_3) and low (pD_4) penalty

axs = ax_dict['a']

## Per-subject mean task error (averaged over blocks, conditions and time) in the
## early window t0_start:t0_end and the late window (last t1_end samples).
## NOTE: despite "_med_" in the names these are means, not medians. They are
## shape-checked below, but panel a itself plots td_diff_pD3 / td_diff_pD4.
# early window
td_error_first_med_pd3 = np.mean(td_error[:, :, utils.pD_3, t0_start: t0_end], axis = (0, 2, 3))
td_error_first_med_pd4 = np.mean(td_error[:, :, utils.pD_4, t0_start: t0_end], axis = (0, 2, 3))

# late window
td_error_last_med_pd3 = np.mean(td_error[:, :, utils.pD_3, -t1_end:], axis = (0, 2, 3))
td_error_last_med_pd4 = np.mean(td_error[:, :, utils.pD_4, -t1_end:], axis = (0, 2, 3))


# paired comparisons must have one value per subject (N = n_keys)
assert(td_error_first_med_pd3.shape == (utils.n_keys, ))
assert(td_error_first_med_pd4.shape == (utils.n_keys, ))
assert(td_error_last_med_pd3.shape == (utils.n_keys, ))
assert(td_error_last_med_pd4.shape == (utils.n_keys, ))


data1 = np.ndarray.flatten(td_error_first_med_pd3)
data2 = np.ndarray.flatten(td_error_first_med_pd4)

data3 = np.ndarray.flatten(td_error_last_med_pd3)
data4 = np.ndarray.flatten(td_error_last_med_pd4)

# per-subject change in error for each penalty level: these are what panel a plots
data5 = np.ndarray.flatten(td_diff_pD3)
data6 = np.ndarray.flatten(td_diff_pD4)

data_groups = [data5, data6]
data_labels = ['high', 'low']
data_pos = [0, 0.4]
bplot = axs.boxplot(data_groups, 
                    showfliers=False,
                    patch_artist=True,
                    positions=data_pos,
                    widths = 0.3,
                    boxprops=dict(edgecolor="none"),
                    medianprops=dict(color='k', lw=1))


# Fill each box with its penalty colour; both boxes are drawn translucent (alpha 0.4).
# NOTE(review): the translucency looks like a leftover from an earlier 4-box layout
# (early/late x high/low) where only the first pair was translucent. Panels c, e
# and g use opaque boxes. Confirm this is intended.
if utils.colors is not None:
    for patch, color in zip(bplot['boxes'], [utils.colors['pD_3'], utils.colors['pD_4']]):
        patch.set_facecolor(color)
        patch.set_alpha(0.4)

# rotate labels  
axs.set_xticks(data_pos,data_labels, rotation=40)

# paired Wilcoxon signed-rank test (annotated on the plot) ...
w1 = wilcoxon(data5, data6) 
print(w1)
plotting.plot_significance(pvalue = w1.pvalue, data1=data5, data2 = data6, data_pos = data_pos, 
                           ax=axs, lw=0.5, fontsize = label_size, y_bar = 1, y_asterix = 2)

# ... and a paired t-test, printed only for reference
pt = ttest_rel(data5, data6) 
print(pt)


fig_penalty

In [None]:
# number of decoder update indices; the norm arrays below use update_len - 1 / - 2
update_len = len(utils.update_ix)
update_len

In [None]:
def calc_matrix_norm(M, squared=False):
    '''
    Calculate the Frobenius norm of a 2-D matrix M.

    Parameters
    ----------
    M : array_like, 2-D
        Matrix whose norm is computed.
    squared : bool, optional
        If True, return the squared Frobenius norm ||M||_F ** 2 instead.
        Defaults to False, which is what the figure panels use.

    Returns
    -------
    float
        ||M||_F, or ||M||_F ** 2 when ``squared`` is True.
    '''
    M_norm = np.linalg.norm(M, 'fro')
    if squared:
        return M_norm ** 2
    return M_norm

In [None]:
# def test_calc_matrix_norm():
#     # using example from: https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html
#     a = np.arange(9) - 4
#     b = a.reshape((3, 3))
#     ans = np.linalg.norm(b, 'fro')
#     assert(calc_matrix_norm(b) == ans**2)

#     # using definition from: https://mathworld.wolfram.com/FrobeniusNorm.html
#     b_inner = [[b[i][j]**2 for j in range(b.shape[0])] for i in range(b.shape[1])]
#     assert( (np.sqrt(np.sum(b_inner))**2) == calc_matrix_norm(b) )

In [None]:
# b - decoder norm over time, high vs low penalty

ax = ax_dict['b']
# One decoder norm per (block, subject, condition, update). The last update is
# dropped (update_len - 1) because it is not evenly spaced.
D_effort = np.zeros((utils.n_blocks, utils.n_keys, utils.n_conds, update_len - 1))

# Decoder "effort": the Frobenius norm (not squared) of each decoder update,
# computed identically for every block.
for iK in range(len(utils.keys)):
    for iC in range(len(conds)):
        for b in range(utils.n_blocks):
            W = decoders[b, iK, iC]
            D_effort[b, iK, iC, :] = np.array([calc_matrix_norm(W[ii]) for ii in range(update_len - 1)])

# check the shape
assert(D_effort.shape == (utils.n_blocks, utils.n_keys, utils.n_conds, update_len - 1))

# every entry must have been filled in (np.all is False if any norm is still 0)
assert(np.all(D_effort))

# mean over blocks and the conditions of each penalty level -> one time course per subject
D_norm_pd3 = np.mean(D_effort[:, :, utils.pD_3], axis = (0, 2))
D_norm_pd4 = np.mean(D_effort[:, :, utils.pD_4], axis = (0, 2))

# make sure the N's are correctly compared (one row per subject)
assert(D_norm_pd3.shape == (utils.n_keys, update_len-1))
assert(D_norm_pd4.shape == (utils.n_keys, update_len-1))


# Median and interquartile range across subjects.
# NOTE: the first update (index 0) is the initial decoder; it is much smaller
# than the rest, so it is dropped ([:, 1:]) to keep the plot readable.
D_norm_pd3_25, D_norm_pd3_50, D_norm_pd3_75 = np.percentile(D_norm_pd3[:, 1:], 
                                                            [25, 50, 75] , axis=0)
D_norm_pd4_25, D_norm_pd4_50, D_norm_pd4_75 = np.percentile(D_norm_pd4[:, 1:], 
                                                            [25, 50, 75] , axis=0)

# x-axis in minutes.
# NOTE(review): assumes the plotted updates are evenly spaced from 20 s to 300 s
# (hard-coded); confirm against utils.update_ix.
xn = np.linspace(20, 300, len(D_norm_pd3_50))/60


ax.fill_between(xn, D_norm_pd3_25, D_norm_pd3_75, 
                alpha=0.1, color = utils.colors['pD_3'], edgecolor = None)
ax.fill_between(xn, D_norm_pd4_25, D_norm_pd4_75, 
                alpha=0.1, color = utils.colors['pD_4'], edgecolor = None)

# raw strings: '\l' is not a valid Python escape sequence
ax.plot(xn, D_norm_pd3_50, '-o', alpha=1, linewidth=1, markersize=1,
        color = utils.colors['pD_3'], label = r'high $\lambda$')

ax.plot(xn, D_norm_pd4_50, '--o', alpha=1, linewidth=1, markersize=1, 
        color = utils.colors['pD_4'], label = r'low $\lambda$')


ax.legend(labelcolor='linecolor', handlelength=2, frameon=False,  loc='upper right', fontsize = label_size)

fig_penalty

In [None]:
# c - per-subject decoder norm, high vs low penalty
axs = ax_dict['c']

# Mean over blocks, conditions and all updates -> one scalar per subject.
# (D_effort already excludes the last, unevenly spaced decoder update.)
D_norm_pd3_subj = np.mean(D_effort[:, :, utils.pD_3, :], axis = (0, 2, -1))
D_norm_pd4_subj = np.mean(D_effort[:, :, utils.pD_4, :], axis = (0, 2, -1))
assert(D_norm_pd3_subj.shape == D_norm_pd4_subj.shape == (utils.n_keys, ))

data1 = np.ndarray.flatten(D_norm_pd3_subj) # high
data2 = np.ndarray.flatten(D_norm_pd4_subj) # low


data_groups = [data1, data2]
data_labels = ['high', 'low',]
data_pos = [0, 0.4]
bplot = axs.boxplot(data_groups, 
                    showfliers=False,
                    patch_artist=True,
                    positions=data_pos,
                    widths = 0.3,
                    boxprops=dict(edgecolor="none"),
                    medianprops=dict(color='k', lw=1))


# fill each box with its penalty colour
if utils.colors is not None:
    for patch, color in zip(bplot['boxes'], [utils.colors['pD_3'], utils.colors['pD_4']]):
        patch.set_facecolor(color)


# rotate labels  
axs.set_xticks(data_pos,data_labels, rotation=40)

# paired Wilcoxon signed-rank test (annotated on the plot) ...
w1 = wilcoxon(data1, data2) 
print(w1)
plotting.plot_significance(pvalue = w1.pvalue, data1=data1, data2 = data2, data_pos = data_pos[:2], 
                           ax=axs, lw=0.5, fontsize = 10, y_bar = 0.3, y_asterix = 0.5)

# ... and a paired t-test, printed only for reference
pt = ttest_rel(data1, data2) 
print(pt)


fig_penalty

In [None]:
# Drop the last entry of the encoder's final axis (presumably an offset/bias
# column, given the name "linear" — confirm), leaving 8 columns per channel.
enc_linear = encoder[..., :-1]
assert enc_linear.shape == (
    utils.n_blocks, utils.n_keys, utils.n_conds,
    update_len - 2, utils.n_ch, 8,
)

In [None]:
# d - encoder norm over time, high vs low penalty

ax = ax_dict['d']
# One encoder norm per (block, subject, condition, encoder estimate).
E_effort = np.zeros((utils.n_blocks, utils.n_keys, utils.n_conds, update_len - 2))

# Linear part of the encoder: drop the last entry of the final axis.
# Recomputed here so this cell does not depend on the previous one.
enc_linear = encoder[:, :, :, :, :, :-1]
# Encoder "effort": the Frobenius norm (not squared) of each encoder estimate.
for iK in range(len(utils.keys)):
    for iB in range(utils.n_blocks):
        for iC in range(len(conds)):
            enc = enc_linear[iB, iK, iC]  # per-update encoder matrices, indexed enc[ii, :, :]
            E_norm = np.array([calc_matrix_norm(enc[ii, :, :]) for ii in range(update_len - 2)])
            assert(E_norm.shape == (update_len - 2, ))

            E_effort[iB, iK, iC, :] = E_norm

# check the shape
assert(E_effort.shape == (utils.n_blocks, utils.n_keys, utils.n_conds, update_len - 2))

# every entry must have been filled in (np.all is False if any norm is still 0)
assert(np.all(E_effort))

# mean over blocks and the conditions of each penalty level -> one time course per subject
E_norm_pd3 = np.mean(E_effort[:, :, utils.pD_3], axis = (0, 2))
E_norm_pd4 = np.mean(E_effort[:, :, utils.pD_4], axis = (0, 2))

# make sure the N's are correctly compared (one row per subject)
assert(E_norm_pd3.shape == (utils.n_keys, update_len-2))
assert(E_norm_pd4.shape == (utils.n_keys, update_len-2))


# median and interquartile range across subjects
E_norm_pd3_25, E_norm_pd3_50, E_norm_pd3_75 = np.percentile(E_norm_pd3, [25, 50, 75] , axis=0)
E_norm_pd4_25, E_norm_pd4_50, E_norm_pd4_75 = np.percentile(E_norm_pd4, [25, 50, 75] , axis=0)


# x-axis in minutes.
# NOTE(review): same hard-coded 20 s to 300 s spacing as panel b; confirm.
xn = np.linspace(20, 300, len(E_norm_pd3_50))/60
ax.fill_between(xn, E_norm_pd3_25, E_norm_pd3_75, 
                alpha=0.1, color = utils.colors['pD_3'], edgecolor = None)
ax.fill_between(xn, E_norm_pd4_25, E_norm_pd4_75, 
                alpha=0.1, color = utils.colors['pD_4'], edgecolor = None)

# raw strings: '\l' is not a valid Python escape sequence
ax.plot(xn, E_norm_pd3_50, '-o', alpha=1, linewidth=1, markersize=1,
        color = utils.colors['pD_3'], label = r'high $\lambda$')

ax.plot(xn, E_norm_pd4_50, '--o', alpha=1, linewidth=1, markersize=1, 
        color = utils.colors['pD_4'], label = r'low $\lambda$')


ax.legend(labelcolor='linecolor', handlelength=2, frameon=False,  loc='upper right', fontsize = label_size)

ax.set_ylim(0, 8)
fig_penalty

In [None]:
# e - per-subject encoder norm, high vs low penalty
axs = ax_dict['e']

# pd3 = 10^-3 > 10^-4 ; pd3 = high, pd4 = low

# mean over blocks, conditions and all encoder estimates -> one scalar per subject
E_norm_pd3_subj = np.mean(E_effort[:, :, utils.pD_3, :], axis = (0, 2, -1))
E_norm_pd4_subj = np.mean(E_effort[:, :, utils.pD_4, :], axis = (0, 2, -1))
assert(E_norm_pd3_subj.shape == E_norm_pd4_subj.shape == (utils.n_keys, ))

data1 = np.ndarray.flatten(E_norm_pd3_subj) # high
data2 = np.ndarray.flatten(E_norm_pd4_subj) # low


# per-subject difference (high - low)
print(E_norm_pd3_subj - E_norm_pd4_subj)

data_groups = [data1, data2]
data_labels = ['high', 'low',]
data_pos = [0, 0.4]
bplot = axs.boxplot(data_groups, 
                    showfliers=False,
                    patch_artist=True,
                    positions=data_pos,
                    widths = 0.3,
                    boxprops=dict(edgecolor="none"),
                    medianprops=dict(color='k', lw=1))


# fill each box with its penalty colour
if utils.colors is not None:
    for patch, color in zip(bplot['boxes'], [utils.colors['pD_3'], utils.colors['pD_4']]):
        patch.set_facecolor(color)


# rotate labels  
axs.set_xticks(data_pos,data_labels, rotation=40)

# two-sided paired Wilcoxon signed-rank test (annotated on the plot) ...
w1 = wilcoxon(data1, data2) 
print(w1)
plotting.plot_significance(pvalue = w1.pvalue, data1=data1, data2 = data2, data_pos = data_pos[:2], 
                           ax=axs, lw=0.5, fontsize = 10, y_bar = 0.5, y_asterix = 1)


# ... and a paired t-test, printed only for reference
pt = ttest_rel(data1, data2) 
print(pt)


axs.set_ylim(0, 20)
# one-sided paired test: is the encoder norm larger under the high penalty (data1 > data2)?
w2 = wilcoxon(data1, data2, alternative='greater') 
print(w2)


fig_penalty

In [None]:
# check data
for iK, key in enumerate(keys):
    for iC, cond in enumerate(conds):
        # block 1
        assert(np.all(dec_vels[0, iK, iC] == dec_vels_block1[key][iC, :, :]))

        # block 2
        assert(np.all(dec_vels[1, iK, iC] == dec_vels_block2[key][iC, :, :]))

In [None]:
# inspect the shape of the stacked decoded velocities
np.shape(dec_vels)

In [None]:
# scratch calculation (20770 is presumably a sample count; confirm)
20770 / 5.0

In [None]:
# one fifth of the trial length in samples (presumably one minute of a
# 5-minute trial; confirm against utils.min_time)
last_min = int(utils.min_time / 5)
last_min

In [None]:
# g - per-subject cursor speed, high vs low penalty
axs = ax_dict['g']

# Speed = 2-norm over the last axis of the decoded velocity after scaling x/y by
# (utils.x_cm_to_au, utils.y_cm_to_au), averaged over blocks, conditions and the
# entire trial -> one scalar per subject.
# NOTE(review): the factors are named "cm_to_au"; confirm the scaling direction
# matches the cm/s units used on the speed axis of panel f.
dvels_pd3_subj = np.mean(np.linalg.norm(dec_vels[:, :, utils.pD_3, :] * (utils.x_cm_to_au,utils.y_cm_to_au), 
                                    ord = 2, axis = -1), axis = (0, 2, 3)) 
assert(dvels_pd3_subj.shape == (utils.n_keys, ))


dvels_pd4_subj = np.mean(np.linalg.norm(dec_vels[:, :, utils.pD_4, :] * (utils.x_cm_to_au,utils.y_cm_to_au),
                                    ord = 2,axis = -1), axis = (0, 2, 3))

assert(dvels_pd4_subj.shape == (utils.n_keys, ))

data1 = np.ndarray.flatten(dvels_pd3_subj) # high
data2 = np.ndarray.flatten(dvels_pd4_subj) # low


data_groups = [data1, data2]
data_labels = ['high', 'low',]
data_pos = [0, 0.4]
bplot = axs.boxplot(data_groups, 
                    showfliers=False,
                    patch_artist=True,
                    positions=data_pos,
                    widths = 0.3,
                    boxprops=dict(edgecolor="none"),
                    medianprops=dict(color='k', lw=1))


# fill each box with its penalty colour
if utils.colors is not None:
    for patch, color in zip(bplot['boxes'], [utils.colors['pD_3'], utils.colors['pD_4']]):
        patch.set_facecolor(color)


# rotate labels  
axs.set_xticks(data_pos,data_labels, rotation=40)

# paired Wilcoxon signed-rank test (annotated on the plot) ...
w1 = wilcoxon(data1, data2) 
print(w1)
plotting.plot_significance(pvalue = w1.pvalue, data1=data1, data2 = data2, data_pos = data_pos[:2], 
                           ax=axs, lw=0.5, fontsize = 10, y_bar =-4, y_asterix =-4)

# ... and a paired t-test, printed only for reference
pt = ttest_rel(data1, data2) 
print(pt)


fig_penalty

In [None]:
# f - cursor speed over time, high vs low penalty

# Per-subject speed time course: 2-norm of the scaled decoded velocity,
# averaged over blocks and the conditions of each penalty level.
dvels_pd3 = np.mean(np.linalg.norm(dec_vels[:, :, utils.pD_3] * (utils.x_cm_to_au,utils.y_cm_to_au),
                                    ord = 2, axis = -1), axis = (0, 2))
dvels_pd4 = np.mean(np.linalg.norm(dec_vels[:, :, utils.pD_4] * (utils.x_cm_to_au,utils.y_cm_to_au),
                                    ord = 2,axis = -1), axis = (0, 2))


# median / IQR across subjects (not used by the smoothed plot below)
dvels_pd3_25, dvels_pd3_50, dvels_pd3_75 = np.percentile(dvels_pd3, [25, 50, 75] , axis=0)
dvels_pd4_25, dvels_pd4_50, dvels_pd4_75 = np.percentile(dvels_pd4, [25, 50, 75] , axis=0)

ax = ax_dict['f']
# Smoothing window as a NUMBER OF SAMPLES (not seconds): min_time / 300 * 10,
# i.e. 10 s worth of samples if the min_time samples span 300 s.
kernal_size = int((utils.min_time)/300 * 10)
# time axis in minutes, assuming the trial spans 0-5 min
t_minutes = np.linspace(0, 5, utils.min_time)
plotting.plot_smooth_time_domain(t_minutes, dvels_pd3, data_len=utils.n_keys, 
                                 axis=0, kernal_size=kernal_size,ax=ax, lw =1, color = utils.colors['pD_3'], remove_axes=False)
plotting.plot_smooth_time_domain(t_minutes, dvels_pd4, data_len=utils.n_keys, 
                                 axis=0, kernal_size=kernal_size,ax=ax, lw =1, color = utils.colors['pD_4'], remove_axes=False, ls = '--')

fig_penalty

In [None]:
from matplotlib import cm


def get_color(slope, cmap, vmin=-40, vmax=20):
    '''
    Return the colour that ``cmap`` assigns to ``slope``.

    ``slope`` is linearly normalised with ``plt.Normalize(vmin, vmax)``, so
    ``vmin`` maps to 0 and ``vmax`` maps to 1. No clipping is applied, so values
    outside the range map outside [0, 1]. The normalised value is passed to ``cmap``.

    Parameters
    ----------
    slope : float or array_like
        Value(s) to colour.
    cmap : callable
        Colormap, e.g. a matplotlib ``Colormap``.
    vmin, vmax : float, optional
        Normalisation range. The defaults (-40, 20) are the previously
        hard-coded range (not -1 to 1 as an old comment claimed).

    Returns
    -------
    Whatever ``cmap`` returns for the normalised value (an RGBA tuple for a
    matplotlib Colormap and scalar input).
    '''
    norm = plt.Normalize(vmin, vmax)
    return cmap(norm(slope))

In [None]:
# inspect the shape of the per-subject high-penalty speed time courses
np.shape(dvels_pd3)

In [None]:
# pd3 = high
# pd4 = low
ax = ax_dict['h']

ax.xaxis.set_major_locator(ticker.MultipleLocator(base=3))
ax.xaxis.set_minor_locator(ticker.MultipleLocator(base=1))


cmap = cm.get_cmap('Greys')  # You can choose any colormap


for iK in range(utils.n_keys):
    x2 = E_norm_pd3_subj[iK] # expect higher norm
    y2 = dvels_pd3_subj[iK] # expect lower velocity

    x1 = E_norm_pd4_subj[iK] # low, expect lower norm
    y1 = dvels_pd4_subj[iK] # low, expect higher velocity
    slope = (y2 - y1)/(x2 - x1)
    print(keys[iK])
    print(slope)

    ax.scatter(x1 - x2, y1 - y2, color = 'black', s = 8)


ax.set_xlabel("Change in E $|E_{low}| - |E_{high}|$")
ax.set_ylabel("Change in v $v_{low} - v_{high}$")

ax.vlines(x=0, ymin=0, ymax=5, color = 'gray', ls= '--')
ax.set_ylim(-.5, 5)


fig_penalty

In [None]:
image_format = 'pdf' # e.g. png, svg, etc.
image_name = 'fig5-penalty-effect-cm.pdf'
# Use a separate name for the output directory: reassigning PATH here would make
# any later re-run of the data-loading cell read from /results/ instead of /data/.
RESULTS_PATH = '/results/'
fig_penalty.savefig(RESULTS_PATH + image_name, format=image_format, dpi=300)