__Purpose:__ Before implementing Federated Learning, re-implement the original simulations on the continuous data task we have been working from, to confirm that we use the correct loss functions in the federated learning task. Simulated decoders should match the decoders stored in `Ws_block1` in the CPHS data pickle file.
<br>
1. The decoder (`dec`) matrix holds the weights that will be passed back and forth, although it goes through SmoothBatch first.
1. We assume we can test on the second half of the updates (roughly updates 10-19), since learning should be complete by then.
1. `scipy.optimize.minimize()` iterates until its convergence criterion is met (or an iteration cap is reached). The cap can be changed (e.g., with the `maxiter` option), although AFAIK you can't know in advance how many iterations convergence takes. This is still a good setup for FL.
1. `minimize()` is currently running BFGS, not SGD; not sure whether that matters. SGD could be implemented by hand or taken from a library. BFGS is a quasi-Newton method (it builds up a Hessian approximation from gradients), but we don't have many parameters (the decoder is 2 x 64), I don't think. Plus we may be able to (or already have?) solve for the Hessian analytically.
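
The point about capping iterations can be sketched on a toy problem. This is a minimal sketch, not the study's cost: `toy_cost` and `toy_grad` are hypothetical stand-ins for `cost_l2` and `gradient_cost_l2`, and the data are random. It shows that `maxiter` limits the BFGS iterations and that the returned result reports how many iterations were actually used (`nit`).

```python
import numpy as np
from scipy.optimize import minimize

# Toy stand-in for the decoder cost: a convex quadratic ||A w - b||^2.
# A, b, toy_cost and toy_grad are hypothetical, for illustration only.
rng = np.random.default_rng(0)
A = rng.standard_normal((50, 8))
b = rng.standard_normal(50)

def toy_cost(w):
    r = A @ w - b
    return float(r @ r)

def toy_grad(w):
    return 2.0 * A.T @ (A @ w - b)

w0 = np.zeros(8)

# Cap BFGS at a single iteration, as one local step in FL might.
capped = minimize(toy_cost, w0, method='BFGS', jac=toy_grad, options={'maxiter': 1})
# Let BFGS run until its own stopping test is met.
full = minimize(toy_cost, w0, method='BFGS', jac=toy_grad)

print(f"capped: nit={capped.nit}, cost={capped.fun:.4f}")
print(f"full:   nit={full.nit}, cost={full.fun:.4f}")
```

So the iteration count is not known beforehand, but `out.nit` gives it after the fact, which may be enough to pick a sensible cap per FL round.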

In [1]:
import pandas as pd
import os
import numpy as np
#from numpy.matlib import repmat
#from matplotlib import pyplot as plt
#from scipy.signal import detrend, firwin, freqz, lfilter
#from sklearn.model_selection import train_test_split, ShuffleSplit
from scipy.optimize import minimize, least_squares
import copy
from itertools import permutations

In [2]:
from experiment_params import *
from simulations import *
import time
# Do the below if you're in the pytch environment
#import pickle5 as pickle
import pickle

# Reminder of Conditions Order

NOTE: 

* **CONDITIONS** = array(['D_1', 'D_2', 'D_5', 'D_6', 'D_3', 'D_4', 'D_7', 'D_8'])
* **LEARNING RATES:** alpha = 0.25 and 0.75; alpha = 0.25 for D1, D2, D5, D6; alpha = 0.75 for D3, D4, D7, D8
* **SMOOTHBATCH:** W_next = alpha*W_old + ((1 - alpha) * W_calc)

* **DECODER INIT:** pos for D1 - D4, neg for D5 - D8

* **PENALTY TERM:** $\lambda_E$ = 1e-6 for all, $\lambda_F$ = 1e-7 for all, $\lambda_D$ = 1e-3 for 1, 3, 5, 7 and 1e-4 for 2, 4, 6, 8 


| DECODER | ALPHA | PENALTY ($\lambda_D$) | DEC INIT |
| --- | --- | --- | --- |
| 1 | 0.25 | 1e-3 | + |
| 2 | 0.25 | 1e-4 | + |
| 3 | 0.75 | 1e-3 | + |
| 4 | 0.75 | 1e-4 | + |
| 5 | 0.25 | 1e-3 | - |
| 6 | 0.25 | 1e-4 | - |
| 7 | 0.75 | 1e-3 | - |
| 8 | 0.75 | 1e-4 | - |
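
The SmoothBatch rule and the per-condition settings above can be written out as a small sketch. The names `CONDITION_PARAMS` and `smoothbatch` are hypothetical (they do not come from `simulations.py`); the values are transcribed from the table.

```python
import numpy as np

# Per-decoder settings transcribed from the table: (alpha, lambda_D, init sign).
# CONDITION_PARAMS is a hypothetical name used only for this sketch.
CONDITION_PARAMS = {
    1: (0.25, 1e-3, +1),
    2: (0.25, 1e-4, +1),
    3: (0.75, 1e-3, +1),
    4: (0.75, 1e-4, +1),
    5: (0.25, 1e-3, -1),
    6: (0.25, 1e-4, -1),
    7: (0.75, 1e-3, -1),
    8: (0.75, 1e-4, -1),
}

def smoothbatch(W_old, W_calc, alpha):
    """W_next = alpha*W_old + (1 - alpha)*W_calc; higher alpha keeps more of the old decoder."""
    return alpha * W_old + (1.0 - alpha) * W_calc

# Made-up decoders with the study's (2, 64) shape.
W_old = np.zeros((2, 64))
W_calc = np.ones((2, 64))

alpha, lambda_D, init_sign = CONDITION_PARAMS[1]
W_next = smoothbatch(W_old, W_calc, alpha)
print(W_next[0, 0])  # 0.75, i.e. 75% of the newly calculated decoder when alpha = 0.25
```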


## Load Our Data In

In [5]:
t0 = time.time()

with open(os.path.join('Data', 'continuous_full_data_block1.pickle'), 'rb') as handle:
    #refs_block1, poss_block1, dec_vels_block1, int_vel_block1, emgs_block1, Ws_block1, Hs_block1, alphas_block1, pDs_block1, times_block1, conditions_block1 = pickle.load(handle)
    refs_block1, _, _, _, emgs_block1, Ws_block1, _, _, _, _, _ = pickle.load(handle)

#with open('Data\continuous_full_data_block2.pickle', 'rb') as handle:
    #refs_block2, poss_block2, dec_vels_block2, int_vel_block2, emgs_block2, Ws_block2, Hs_block2, alphas_block2, pDs_block2, times_block2, conditions_block2 = pickle.load(handle)
    #refs_block2, _, _, _, emgs_block2, Ws_block2, _, _, _, _, _ = pickle.load(handle)

t1 = time.time()
total = t1-t0  
print(total)

10.384710788726807


In [6]:
# (conditions, time points, x/y, channels); the 20770 time points hold only 19 unique decoder sets!
Ws_block1[keys[0]].shape

(8, 20770, 2, 64)

In [7]:
update_ix

array([    0,  1200,  2402,  3604,  4806,  6008,  7210,  8412,  9614,
       10816, 12018, 13220, 14422, 15624, 16826, 18028, 19230, 20432,
       20769])
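
Consecutive entries of `update_ix` delimit the samples belonging to each decoder update. A quick sketch, with the array copied from the output above, of the resulting batch sizes:

```python
import numpy as np

# update_ix values copied from the notebook output.
update_ix = np.array([    0,  1200,  2402,  3604,  4806,  6008,  7210,  8412,  9614,
                      10816, 12018, 13220, 14422, 15624, 16826, 18028, 19230, 20432,
                      20769])

# Number of samples between consecutive update boundaries.
batch_sizes = np.diff(update_ix)
print(len(batch_sizes))   # 18 intervals from 19 boundaries
print(batch_sizes)
```

The first interval has 1200 samples, the following sixteen have 1202 each, and the last has 337. Since `learning_batch` is later set to `update_ix[1]` (1200), it does not equal the size of most batches; whether that matters depends on how `cost_l2` uses it, which is not shown here.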

In [8]:
dec_cond0_user1_update0 = Ws_block1[keys[0]][0,0,:,:]
dec_cond0_user1_update1 = Ws_block1[keys[0]][0,update_ix[1],:,:]
dec_cond0_user1_update2 = Ws_block1[keys[0]][0,update_ix[2],:,:]

print(f"Shape of decoder: {dec_cond0_user1_update0.shape}")
print()
print(f"Total difference between dec0 and dec1: {(dec_cond0_user1_update0 - dec_cond0_user1_update1).sum()}")
print("I.e., as previously shown, the first two decs are the same")
print()
print(f"Total difference between dec0 and dec2: {(dec_cond0_user1_update0 - dec_cond0_user1_update2).sum()}")

Shape of decoder: (2, 64)

Total difference between dec0 and dec1: 0.0
I.e., as previously shown, the first two decs are the same

Total difference between dec0 and dec2: 3.1981579823181594


In [9]:
#emg_cond0_user1_update0 = emg_data_df.iloc[:64,:].shape

# (Condition, datapoints, channels)
print(emgs_block1[keys[0]][:,:,:].shape)

# Condition 0 of subject 1 ("0")
print(emgs_block1[keys[0]][0,:,:].shape)

(8, 20770, 64)
(20770, 64)


## Run One Iteration On Above Data and Check Decoders Are the Same
1. Modifying Simulations Code

In [10]:
# Just 1 person
filtered_signals = emgs_block1[keys[0]][0,:,:]
p_reference_full = refs_block1[keys[0]][0,:,:]

print(filtered_signals.shape)
print(p_reference_full.shape)

(20770, 64)
(20770, 2)


In [11]:
# The simulations code previously drew a random initial decoder; here we start from the recorded one so the rerun can be compared
#D_0 = np.random.rand(2,64)
D_0 = Ws_block1[keys[0]][0,0,:,:]
total_datapoints = emgs_block1[keys[0]][0,:,:].shape[0]

# Original learning batch was 8
learning_batch = update_ix[1] 

In [12]:
# Values from the original simulations file; they do not match the study conditions
#alpha = .95 # higher alpha means more old decoder (slower update)
#alphaF = 1e-1
#alphaD = 1e-1

# For condition 1:
alpha = .25 # higher alpha means more old decoder (slower update)
# Assuming alphaF and alphaD correspond to lambda_F and lambda_D, the decoder cost penalties
alphaF = 1e-7
alphaD = 1e-3
# Open question: where does lambda_E (1e-6) enter?

In [14]:
D = []
D.append(D_0)

# Added 2 new parameters
#def simulation(D,learning_batch,alpha,alphaF=1e-2,alphaD=1e-2,display_info=False,num_iters=False):
#D  # Already defined
#learning_batch  # Already defined
#alpha  # Already defined
#alphaF=1e-2  #defined as something else earlier...
#alphaD=1e-2  #defined as something else earlier...
display_info=True

#num_updates = int(np.floor((filtered_signals.shape[0]-1)/learning_batch)) # how many times can we update decoder based on learning batch    
num_updates = 19  # This is 19 for us

dt = 1/60

# Split the trial into update batches using the recorded update_ix boundaries
# update_ix has 19 entries, so range(num_updates-1) covers all 18 intervals;
# the last interval is only 337 samples, and the single trailing sample (index 20769) is never used
for ix in range(num_updates-1):
    #print(ix)
    # Set to False for less cluttered output when debugging
    display_info = False
    
    # Instead of using learning_batch, we should get the same results just using update_ix values
    lower_bound = update_ix[ix]
    # ix never reaches num_updates-1 in this loop, so update_ix[ix+1] is always a valid index
    upper_bound = update_ix[ix+1]

    s = np.transpose(filtered_signals[lower_bound:upper_bound,:])
    v_actual = D[-1]@s
    # Numerical integration of v_actual to get p_actual: summing over the whole batch gives one end-of-batch position per axis
    p_actual = np.sum(v_actual, axis=1)*dt  # dt=1/60
    print(f"p_actual: {p_actual}")
    p_actual = np.reshape(p_actual, (p_actual.shape[0], 1))
    # Update decoder
    p_reference = np.transpose(p_reference_full[lower_bound:upper_bound,:])
    #(r-y)/60=D_new@s  # This is the optimization problem
    V = (p_reference - p_actual)*dt
    F = copy.deepcopy(s[:,:-1]) # note: truncate F for estimate_decoder # why?
    print()
    
    # H is passed to the cost as all zeros
    H = np.zeros((2,2))
    # Minimize with scipy's BFGS, supplying the pre-computed analytical gradient for speed
    out = minimize(lambda D: cost_l2(F,D,H,V,learning_batch,alphaF,alphaD), D[-1], method='BFGS', jac=lambda D: gradient_cost_l2(F,D,H,V,learning_batch,alphaF,alphaD), options={'disp': display_info})

    # reshape to decoder parameters
    W_hat = np.reshape(out.x,(2, 64))

    # DO SMOOTHBATCH
    W_new = alpha*D[-1] + ((1 - alpha) * W_hat)
    D.append(W_new)

p_actual: [115.36797626 113.00405666]

p_actual: [22.72360033 16.40996318]


  out = minimize(lambda D: cost_l2(F,D,H,V,learning_batch,alphaF,alphaD), D[-1], method='BFGS', jac=lambda D: gradient_cost_l2(F,D,H,V,learning_batch,alphaF,alphaD), options={'disp': display_info})



p_actual: [-7.50248981 -2.26672241]

p_actual: [1.24928953 0.50110169]

p_actual: [ 4.32007666 -3.13001625]

p_actual: [-3.09600118  2.87251262]

p_actual: [3.69345638 0.89407124]

p_actual: [-3.56062261 -1.50971456]

p_actual: [ 1.18209308 -0.99187499]

p_actual: [ 2.06120276 -0.45183494]

p_actual: [-3.36211091 -0.34978566]

p_actual: [-4.77076278  0.34285687]

p_actual: [11.59465639 -3.45077105]

p_actual: [1.81763164 0.03320953]

p_actual: [-0.59338261 -0.30132625]

p_actual: [-1.00304442 -0.08678   ]

p_actual: [-2.62390182  0.42086859]

p_actual: [ 1.59015409 -0.2304428 ]
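
The loop above integrates velocity with `np.sum`, which collapses each batch to a single end-of-batch position. If the cost is meant to see a position at every sample (an assumption; the original simulations code is not shown here), a cumulative sum would produce that trace instead. A minimal sketch with a made-up constant velocity:

```python
import numpy as np

dt = 1 / 60

# Hypothetical velocity: 60 units/s in x and -30 units/s in y, over 5 samples.
v = np.tile(np.array([[60.0], [-30.0]]), (1, 5))   # shape (2, 5)

# Total sum: one position per axis, reached at the end of the batch.
p_final = np.sum(v, axis=1) * dt        # shape (2,)
# Cumulative sum: the position after each sample.
p_trace = np.cumsum(v, axis=1) * dt     # shape (2, 5)

print(p_final.shape, p_trace.shape)
print(np.allclose(p_trace[:, -1], p_final))  # the last column of the trace equals the total
```

Which of the two the CPHS code used is worth checking, since it changes `V` and therefore the decoder that `minimize()` returns.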



In [15]:
# The first instance where they could conceivably have the same dec value is the 3rd dec in Ws_block1 (i.e., index 2)
print((D[1] - Ws_block1[keys[0]][0,update_ix[2],:,:]).sum())

2.412033777795463


In [16]:
# Check how different the final decs are; this is all we really care about
# (although if the earlier decs differ, it is hard to see how the last ones could match)
print((D[-1] - Ws_block1[keys[0]][0,update_ix[-1],:,:]).sum())

2.4766577196241712
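
One caveat on the metric: a signed elementwise sum can be close to zero for two very different matrices, because positive and negative differences cancel. A small sketch with made-up matrices (not the study's decoders) comparing it with a few alternatives:

```python
import numpy as np

# Made-up matrices whose elementwise differences cancel exactly.
A = np.array([[1.0, -1.0],
              [2.0, -2.0]])
B = np.zeros((2, 2))

signed_sum = (A - B).sum()        # 0.0, even though A and B clearly differ
frob = np.linalg.norm(A - B)      # Frobenius norm: sqrt(1 + 1 + 4 + 4)
max_abs = np.abs(A - B).max()     # largest single-entry difference: 2.0
same = np.allclose(A, B)          # False

print(signed_sum, frob, max_abs, same)
```

For checking whether the simulated and CPHS decoders match, the norm or `np.allclose` is probably the safer test; a nonzero signed sum still shows that they differ, but a zero would not prove they agree.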


In [17]:
# Differences between consecutive decoders

# From this file
print(f"Length of D (sims code): {len(D)}")
print(f"Length of Ws_block1 (cphs code): {len(update_ix)}")
print()
print("Labels;       D (Sims);     Ws (CPHS);     Sim - CPHS")
for i in range(len(D)-2):
    print(f"Dec{i+1} - Dec{i}: {(D[i+1] - D[i]).sum():9.5f};    {(Ws_block1[keys[0]][0,update_ix[i+1],:,:] - Ws_block1[keys[0]][0,update_ix[i],:,:]).sum():9.5f};      {(D[i] - Ws_block1[keys[0]][0,update_ix[i],:,:]).sum():9.5f}")

Length of D (sims code): 19
Length of Ws_block1 (cphs code): 19

Labels;       D (Sims);     Ws (CPHS);     Sim - CPHS
Dec1 - Dec0:  -0.78612;      0.00000;        0.00000
Dec2 - Dec1:   0.06560;     -3.19816;       -0.78612
Dec3 - Dec2:   0.05867;      8.21960;        2.47763
Dec4 - Dec3:   0.01251;     -7.00649;       -5.68330
Dec5 - Dec4:  -0.02612;      2.21186;        1.33571
Dec6 - Dec5:   0.06320;     -3.56196;       -0.90227
Dec7 - Dec6:  -0.02860;     10.80750;        2.72288
Dec8 - Dec7:  -0.01792;    -12.59955;       -8.11321
Dec9 - Dec8:  -0.01211;     11.45546;        4.46842
Dec10 - Dec9:  -0.04393;    -12.38370;       -6.99915
Dec11 - Dec10:   0.05112;     -0.43172;        5.34062
Dec12 - Dec11:   0.00597;     -1.10106;        5.82346
Dec13 - Dec12:   0.02963;      7.55098;        6.93049
Dec14 - Dec13:  -0.07686;      0.16666;       -0.59086
Dec15 - Dec14:   0.01526;     -4.90057;       -0.83438
Dec16 - Dec15:   0.00900;     -0.14916;        4.08145
Dec17 - Dec16:   0.0

In [18]:
# Shift the CPHS index by one (compare D[i] with Ws at update_ix[i+1]), since the first two decoders in Ws_block1 are identical
for i in range(len(D)-2):
    print(f"{(D[i] - Ws_block1[keys[0]][0,update_ix[i+1],:,:]).sum():9.5f}")

  0.00000
  2.41203
 -5.74196
  1.32320
 -0.87615
  2.65969
 -8.08461
  4.48634
 -6.98704
  5.38455
  5.77234
  6.92452
 -0.62049
 -0.75752
  4.06619
  4.23061
 -0.05777
