In [None]:
def get_pds(activations, start_idx, end_idx, angles=None):
    """Fit a cosine tuning curve to every unit and return its phase and depth.

    For each unit the activation is averaged over time indices
    ``start_idx:end_idx`` (end exclusive) separately for every direction.
    ``amp * cos(angle + phase) + offset`` is then fit to those means with
    ``least_squares``.

    Parameters
    ----------
    activations : ndarray
        Indexed ``activations[direction, time, unit]``. Axis 0 must line up
        with ``angles``.
    start_idx, end_idx : int
        Slice bounds on the time axis (axis 1) to average over.
    angles : array-like, optional
        Angle of each direction, fed directly to ``np.cos`` (so radians).
        Defaults to the module-level ``degs``.
        NOTE(review): despite its name, ``degs`` must therefore hold radians
        — confirm.

    Returns
    -------
    pds : ndarray, shape (n_units,)
        Fitted ``phase`` per unit, bounded to [-pi, pi].
        NOTE(review): ``cos(angle + phase)`` peaks at ``angle = -phase``, so
        this is the negated peak direction. It is kept as-is for
        compatibility — confirm the intended sign convention.
    depths : ndarray, shape (n_units,)
        Fitted amplitude ``amp`` per unit, bounded to [0, 3].
    """
    if angles is None:
        angles = degs  # module-level direction angles defined elsewhere
    t = np.asarray(angles)

    # mean activation over the time window, per direction and unit
    mean_activation = np.mean(activations[:, start_idx:end_idx, :], axis=1)
    n_units = mean_activation.shape[1]
    pds = np.zeros((n_units,))
    depths = np.zeros((n_units,))

    x0 = [1, 0, 0]  # initial guess: amplitude, phase, offset
    bounds = ([0, -np.pi, -1], [3, np.pi, 1])

    def residuals(params, target):
        amp, phase, offset = params
        return amp * np.cos(t + phase) + offset - target

    for unit in range(n_units):
        fit = least_squares(residuals, x0=x0, bounds=bounds,
                            args=(mean_activation[:, unit],))
        # the fitted phase (theta) is used as the preferred direction
        pds[unit] = fit.x[1]
        depths[unit] = fit.x[0]
    return pds, depths

def make_activation_grid(activations, cell, args):
    """Scatter per-unit activations onto a 2-D grid laid out by unit position.

    Parameters
    ----------
    activations : ndarray
        Indexed ``activations[i, j, unit]``. Only the first
        ``len(cell.pos)`` units are placed.
    cell : object
        Must expose ``cell.pos``, an array of shape (n_units, 2). Positions
        are truncated to ``int`` and used as (row, col) grid indices.
    args : object
        Must expose ``args.latent_shape``; its first two entries give the
        grid size.

    Returns
    -------
    ndarray, shape (activations.shape[0], activations.shape[1],
                    latent_shape[0], latent_shape[1])
        Zero everywhere except at the unit positions.
    """
    pos = cell.pos.astype("int")
    n_units = pos.shape[0]
    activation_grid = np.zeros((activations.shape[0], activations.shape[1],
                                args.latent_shape[0], args.latent_shape[1]))
    # A single vectorised scatter replaces the per-element triple loop.
    # Assumes positions are unique: with duplicates, NumPy does not define
    # which value ends up in the shared cell.
    activation_grid[:, :, pos[:, 0], pos[:, 1]] = activations[:, :, :n_units]
    return activation_grid

def shuffle_map(act_map):
    """Return a copy of ``act_map`` with entries along its LAST axis permuted.

    ``np.transpose`` reverses the axes, so ``np.random.permutation`` (which
    shuffles along axis 0) acts on the original last axis. For a 2-D map,
    whole columns move together and row order is kept. Draws from NumPy's
    global RNG.
    """
    return np.transpose(np.random.permutation(np.transpose(act_map)))

def get_lateral_loss(act_map):
    """Return the negated mean diagonal of ``act_map @ cell.lateral_effect``.

    ``act_map`` is cast to a float32 tensor and matrix-multiplied with the
    module-level ``cell.lateral_effect``. Only the diagonal of the product
    contributes to the loss.
    NOTE(review): the full product is computed even though only its diagonal
    is used; this is worth vectorising if it becomes a bottleneck.
    """
    act_tensor = tf.convert_to_tensor(act_map, dtype="float32")
    product = tf.matmul(act_tensor, cell.lateral_effect)
    diag_mean = tf.reduce_mean(tf.linalg.diag_part(product))
    # loss is -1 * mean of the diagonal elements
    return -1 * diag_mean

def evaluate_topography(activations):
    """Compare the lateral loss of each time step with a shuffled null.

    For every index ``step`` on axis 1 of ``activations``:
      * the lateral loss of ``activations[:, step, :]`` is computed;
      * the same map is shuffled along its last axis ``num_shuffles`` times
        (module-level constant), and each shuffle's loss is recorded;
      * the percentile of the real loss within the shuffled losses is
        computed with ``percentileofscore``.

    Returns
    -------
    act_loss : ndarray, shape (n_steps,)
    shuffled_loss : ndarray, shape (n_steps, num_shuffles)
    act_loss_perc : ndarray, shape (n_steps,)
    """
    n_steps = activations.shape[1]
    shuffled_loss = np.zeros((n_steps, num_shuffles))
    act_loss = np.zeros((n_steps,))
    act_loss_perc = np.zeros_like(act_loss)

    for step in range(n_steps):
        step_map = activations[:, step, :]
        act_loss[step] = get_lateral_loss(step_map)
        # build the null distribution from shuffled copies of this map
        for k in range(num_shuffles):
            shuffled_loss[step, k] = get_lateral_loss(shuffle_map(step_map))
        act_loss_perc[step] = percentileofscore(shuffled_loss[step, :], act_loss[step])

    return act_loss, shuffled_loss, act_loss_perc


# stimulation experiments    
def mean_squared_difference(x, y):
    """Mean of the squared elementwise difference of ``x`` and ``y`` over axis 2.

    ``x`` and ``y`` are expected to be indexed [target, time, signal]. The
    result is then indexed [target, time], averaged across signals.
    """
    squared_diff = np.square(x - y)
    return squared_diff.mean(axis=2)

def get_recovery_time(msd, stim_time=None):  # from end of stim
    """Return, per target, how long msd takes to drop back under threshold.

    The threshold is 25% of the largest value anywhere in ``msd`` (over all
    targets and all time steps). The search starts at index ``stim_time[1]``
    and the result is an offset from that index.

    Parameters
    ----------
    msd : ndarray, shape (n_targets, n_time)
        Per-target mean squared difference over time.
    stim_time : sequence of two ints, optional
        ``[start, end]`` stimulation indices; only ``end`` is used. Defaults to
        the module-level ``stim_params.stim_time``.

    Returns
    -------
    ndarray of int, shape (n_targets,)
        Offset from ``stim_time[1]`` of the first sample with msd <= threshold.
        Targets that never get there are assigned the sentinel
        ``n_time - stim_time[1] + 1``, which (for a non-empty window) is larger
        than any real offset.
    """
    if stim_time is None:
        stim_time = stim_params.stim_time
    stim_end = stim_time[1]

    threshold = 0.25 * np.max(msd)
    below = msd[:, stim_end:] <= threshold

    # Start every target at the "never recovered" sentinel, then fill in the
    # first under-threshold sample for targets that do recover.
    # Builtin int replaces np.int, which was removed in NumPy 1.24.
    recov_time = np.full(msd.shape[0], msd.shape[1] - stim_end + 1, dtype=int)
    if below.shape[1] > 0:  # argmax is undefined on an empty window
        recovered = below.any(axis=1)
        # argmax of a boolean row is the index of its first True
        recov_time[recovered] = below.argmax(axis=1)[recovered]
    return recov_time
   
def get_stim_effect(msd, stim_time=None):
    """Average msd over the stimulation window, separately for each target.

    Parameters
    ----------
    msd : ndarray, shape (n_targets, n_time)
    stim_time : sequence of two ints, optional
        ``[start, end]`` stimulation indices. Defaults to the module-level
        ``stim_params.stim_time``.

    Returns
    -------
    ndarray, shape (n_targets,)
        Mean of ``msd[:, start:end]``.
        NOTE(review): the slice stop is exclusive, so the sample at
        ``stim_time[1]`` is not included. ``run_stim_exp`` sets
        ``stim_time[1]`` to ``start + duration - 1``, which looks like an
        inclusive end — confirm whether the last stimulated sample should be
        counted.
    """
    if stim_time is None:
        stim_time = stim_params.stim_time
    start, end = stim_time[0], stim_time[1]
    return np.mean(msd[:, start:end], axis=1)
    
def get_metrics(x, y):
    """Compute stimulation-comparison metrics between two activity arrays.

    Returns
    -------
    tuple
        ``(msd, recov_time, stim_effect)``, where ``msd`` comes from
        ``mean_squared_difference(x, y)``, and ``recov_time`` and
        ``stim_effect`` are derived from ``msd`` by ``get_recovery_time`` and
        ``get_stim_effect``.
    """
    msd = mean_squared_difference(x, y)
    return msd, get_recovery_time(msd), get_stim_effect(msd)


def get_stim_data():
    """Run the model without and then with stimulation, and compare the outputs.

    Uses the module-level ``stim_params``, ``cell``, ``activation_model`` and
    ``x_test``. Output index 1 of ``activation_model.predict`` is treated as
    the RNN activations and the last output as the EMG prediction. On return,
    ``stim_params.is_stim`` is left set to True.

    Returns
    -------
    emg_metrics, rnn_metrics : tuple
        Each is the ``(msd, recov_time, stim_effect)`` tuple from
        ``get_metrics(no_stim, stim)``.
    """
    def _simulate(stim_on):
        # reset the cell's counter before every forward pass
        stim_params.is_stim = stim_on
        cell.reset_counter()
        outputs = activation_model.predict(x_test)
        return outputs[1], outputs[-1]

    rnn_no_stim, emg_no_stim = _simulate(False)
    rnn_stim, emg_stim = _simulate(True)

    emg_metrics = get_metrics(emg_no_stim, emg_stim)
    rnn_metrics = get_metrics(rnn_no_stim, rnn_stim)
    return emg_metrics, rnn_metrics
    
def run_stim_exp():
    """Sweep stimulation position, distance decay constant, and duration.

    For every combination, the module-level ``stim_params`` is updated and
    ``get_stim_data`` is run. Progress is printed as the fraction of
    positions done, followed by "done".

    Returns
    -------
    emg_msd, rnn_msd : ndarray
        Shape (n_pos, n_tau, n_dur, x_test.shape[0], x_test.shape[1]).
    emg_recov, emg_stim_effect, rnn_recov, rnn_stim_effect : ndarray
        Shape (n_pos, n_tau, n_dur, x_test.shape[0]).
    """
    n_pos = stim_pos_test.shape[0]
    sweep_shape = (n_pos, len(stim_dist_tau_test), len(stim_duration_test), x_test.shape[0])
    emg_recov = np.zeros(sweep_shape)
    rnn_recov = np.zeros_like(emg_recov)
    emg_stim_effect = np.zeros_like(emg_recov)
    rnn_stim_effect = np.zeros_like(emg_recov)
    emg_msd = np.zeros(sweep_shape + (x_test.shape[1],))
    rnn_msd = np.zeros_like(emg_msd)

    for i_pos in range(n_pos):
        print(i_pos / n_pos)
        stim_params.stim_pos = stim_pos_test[i_pos, :]
        for i_tau, tau in enumerate(stim_dist_tau_test):
            stim_params.stim_dist_tau = tau
            for i_dur, duration in enumerate(stim_duration_test):
                # window starts 10 samples after the go cue; the end index
                # is start + duration - 1
                stim_start = go_cue_idx + 10
                stim_params.stim_time = [stim_start, stim_start + duration - 1]

                emg_metrics, rnn_metrics = get_stim_data()
                cell_idx = (i_pos, i_tau, i_dur)
                emg_msd[cell_idx], emg_recov[cell_idx], emg_stim_effect[cell_idx] = emg_metrics
                rnn_msd[cell_idx], rnn_recov[cell_idx], rnn_stim_effect[cell_idx] = rnn_metrics

    print("done")
    return emg_msd, rnn_msd, emg_recov, emg_stim_effect, rnn_recov, rnn_stim_effect


def run_example_stim_exp():
    """Run the model once with stimulation and once without, using current stim_params.

    Uses the module-level ``stim_params``, ``cell``, ``activation_model``,
    ``x_test`` and ``args``. On return, ``stim_params.is_stim`` is left set
    to False.

    Returns
    -------
    tuple
        ``(rnn_stim, rnn_no_stim, grid_stim, grid_no_stim, y_stim, y_no_stim)``
        where the grids come from ``make_activation_grid``.
    """
    def _simulate(stim_on):
        stim_params.is_stim = stim_on
        cell.reset_counter()
        outputs = activation_model.predict(x_test)
        rnn_act = outputs[1]
        return rnn_act, outputs[-1], make_activation_grid(rnn_act, cell, args)

    rnn_stim, y_stim, grid_stim = _simulate(True)
    rnn_no_stim, y_no_stim, grid_no_stim = _simulate(False)

    return (rnn_stim, rnn_no_stim,
            grid_stim, grid_no_stim,
            y_stim, y_no_stim)
    
    
def get_recorded_neuron(stim_pos, cell_pos, max_dist=5):
    """Randomly pick the index of a unit near the stimulation site.

    Units closer than ``max_dist`` (no real units) are sampled with
    probability proportional to ``exp(-distance)``; farther units are never
    picked. Draws from NumPy's global RNG.

    Parameters
    ----------
    stim_pos : array-like, shape (n_dims,)
    cell_pos : array-like, shape (n_units, n_dims)
    max_dist : float, optional
        Exclusive distance cutoff. Defaults to 5.

    Returns
    -------
    int-like
        Row index into ``cell_pos``.

    Raises
    ------
    ValueError
        If no unit has a non-zero sampling weight, e.g. none lies within
        ``max_dist``. Previously this produced NaN probabilities.
    """
    stim_row = np.asarray(stim_pos).reshape(1, -1)
    # direct Euclidean distance (more precise than the dot-product expansion)
    dist_to_stim = np.linalg.norm(np.asarray(cell_pos) - stim_row, axis=1)

    weights = np.exp(-dist_to_stim) * (dist_to_stim < max_dist)
    total = weights.sum()
    if not total > 0:
        raise ValueError(f"no neuron within max_dist={max_dist} of stim_pos")

    # normalise so the probabilities sum to 1, then sample one unit
    return np.random.choice(weights.shape[0], p=weights / total)



def get_null_potent_space(neural_act, muscle_act):
    """Return orthonormal null-space and potent-space bases of a neural->muscle map.

    Both inputs are flattened over their first two axes, keeping the last
    axis as features. A ``LinearRegression`` (default settings) is fit from
    neural to muscle activity. Its coefficient matrix ``W`` is then
    decomposed with ``linalg.null_space(W)`` and ``linalg.orth(W.T)``.

    Returns
    -------
    null_space, potent_space : ndarray
        Column bases in neural-feature space.
    """
    n_rows_neural = neural_act.shape[0] * neural_act.shape[1]
    n_rows_muscle = muscle_act.shape[0] * muscle_act.shape[1]
    flat_neural = np.reshape(neural_act, (n_rows_neural, neural_act.shape[2]))
    flat_muscle = np.reshape(muscle_act, (n_rows_muscle, muscle_act.shape[2]))

    # fit a linear map from neural activity to muscle activity
    regression = LinearRegression()
    regression.fit(flat_neural, flat_muscle)
    weights = regression.coef_

    # null space of W, and orthonormal basis for the row space of W
    return linalg.null_space(weights), linalg.orth(weights.T)
    
    
def project_into_space(neural_act, space):
    """Project 3-D neural activity onto the columns of ``space``.

    ``neural_act`` is flattened over its first two axes, right-multiplied by
    ``space``, and reshaped back so the last axis indexes the columns of
    ``space``.
    """
    lead0, lead1 = neural_act.shape[0], neural_act.shape[1]
    flat = np.reshape(neural_act, (lead0 * lead1, neural_act.shape[2]))
    projected = flat @ space
    return np.reshape(projected, (lead0, lead1, projected.shape[1]))

def flatten_activation_mat(act_mat):
    """Merge the first two axes of a 3-D array into one, keeping the last axis."""
    merged_rows = act_mat.shape[0] * act_mat.shape[1]
    target_shape = (merged_rows, act_mat.shape[2])
    return np.reshape(act_mat, target_shape)

def unflatten_activation_mat(act_mat, orig_mat):
    """Undo ``flatten_activation_mat``: split axis 0 back into ``orig_mat``'s first two axes.

    The last axis keeps ``act_mat``'s own width, which may differ from
    ``orig_mat``'s (e.g. after a projection).
    """
    target_shape = (orig_mat.shape[0], orig_mat.shape[1], act_mat.shape[1])
    return np.reshape(act_mat, target_shape)