In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.special import logsumexp
from scipy.stats import norm
import random

# Set some defaults
plt.rc("axes.spines", top=False, right=False)

sns.set_theme(context="paper", font_scale=1.2)
sns.set_style("ticks")

# Import python library/function for function optimization
import scipy.optimize

%config InlineBackend.figure_format = "retina"

## The Plan

### Parameter and model recovery (Can I do both at same time?)
1) Simulate data using best-fit parameters from each model (or samples from within the range of best-fits). For example, 100 simulations of each model within range of best-fit parameters.
2) Re-estimate parameters.
3) Model comparison with AIC/BIC. 
4) Correlate parameters used for simulation with recovered parameters.


In [2]:
# Read in vmr data
df = pd.read_csv("../results/vmr_all.csv")

# Read in csv with MLEs from all models
# Need a function to convert string back to numpy array
def converter(input_str):
    return np.fromstring(input_str[1:-1], sep=' ')

fits = pd.read_csv("../results/params_mle.csv", converters={"theta":converter})
fits.head()


Unnamed: 0,subj_num,model,theta,loglik,bic,delta_bic
0,1,pea,"[7.49231206, 0.22499953]",-2997.492434,6009.209524,7.784242
1,1,premo,"[1.0, 2.73091915, 0.5, 4.63727082, 1.0]",-2982.931822,6001.425282,0.0
2,1,rem,"[1.0, 17.73350857, 10.0, 1.14074609]",-2997.678305,6023.80592,22.380638
3,1,piece,"[5.64750593, 0.05, 4.05411942]",-2752.690198,5526.717378,-474.707904
4,2,pea,"[2.5712943, 0.57002164]",-2979.736177,5973.697008,-90.044514


In [4]:
from scipy.stats import norm

def pea(sigma_int, B, sigma_motor, num_trials, vis_fb, rotation, fit=False, x_hand=None):
    '''
    Returns:
        x_stl : state estimate
        x_hand : motor output
    '''
    T = 0
    x_stl = np.zeros(num_trials)
    
    if fit == True:
        x_hand = x_hand
    else:
        x_hand = np.zeros(num_trials)
    
    for i in range(num_trials - 1):
        if vis_fb[i] == 0:
            x_v = 0
            sigma_v = np.inf
        else:
            x_v = x_hand[i] + rotation[i]
            sigma_v = 1.179 + 0.384 * np.abs(x_v)  # after Zhang et al
        
        # Compute estimated hand position
        w_int = (1 / sigma_int**2) / (1 / sigma_int**2 + 1 / sigma_v**2)
        w_v = (1 / sigma_v**2) / (1 / sigma_int**2 + 1 / sigma_v**2)
        xhat_hand = w_v * x_v 
        
        # Update rule
        x_stl[i + 1] = B * (T - xhat_hand)
        if fit == False:
            x_hand[i + 1] = x_stl[i + 1] + np.random.normal(0, sigma_motor)
        
    return x_stl, x_hand


def premo(B, sigma_v, sigma_p, sigma_pred, eta_p, sigma_motor, 
          num_trials, vis_fb, rotation, fit=False, x_hand=None):
    '''
    Model parameters
        B
        sigma_v
        sigma_p
        sigma_pred
        eta_p
    '''
    
    beta_p_sat = 5
    T = 0
    x_stl = np.zeros(num_trials)
    sigma_u = sigma_pred
    
    if fit == True:
        x_hand = x_hand
    else:
        x_hand = np.zeros(num_trials)
    
    for i in range(num_trials - 1):
        if vis_fb[i] == 0:
            x_v = 0
            sigma_v = 1e2
        else:
            x_v = x_hand[i] + rotation[i]
        x_p = x_hand[i]
        
        # Precisions
        J_v = 1 / sigma_v**2
        J_p = 1 / sigma_p**2
        J_u = 1 / sigma_u**2
  
        # Weights for each modality
        w_v = J_v / (J_v + J_u)
        w_p = J_p / (J_p + J_u)
    
        # beta_p is proprio shift due to crossmodal recal from vision
        if x_v > x_p:
            beta_p = np.min([np.abs(beta_p_sat), np.abs(eta_p * (w_v * x_v - w_p * x_p))])
        else:    
            beta_p = -np.min([np.abs(beta_p_sat), np.abs(eta_p * (w_v * x_v - w_p * x_p))])
        x_prop_per = w_p * x_p + beta_p  # Perceived hand position 

        x_stl[i + 1] = B * (T - x_prop_per)
        if fit == False:
            x_hand[i + 1] = x_stl[i + 1] + np.random.normal(0, sigma_motor)
    
    return x_stl, x_hand


def rem(B, sigma_comb, s, c, sigma_motor, num_trials, vis_fb, rotation, 
        fit=False, x_hand=None):
    '''
    Model parameters
        B
        sigma_Comb
        s 
        c
    '''
    
    T = 0
    x_stl = np.zeros(num_trials)
    
    if fit == True:
        x_hand = x_hand
    else:
        x_hand = np.zeros(num_trials)
    
    for i in range(num_trials - 1):
        if vis_fb[i] == 0:
            x_v = 0
        else:
            x_v = x_hand[i] + rotation[i]
        x_p = x_hand[i]

        p_rel = s * norm.pdf(x_v, loc=x_v, scale=sigma_comb) / (norm.pdf(x_v, loc=x_v, scale=sigma_comb) + c)
        x_stl[i + 1] = B * (T - p_rel * x_v)
        if fit == False:
            x_hand[i + 1] = x_stl[i + 1] + np.random.normal(0, sigma_motor)
        
    return x_stl, x_hand


def piece(sigma_pert, sigma_pred, sigma_p, sigma_motor, num_trials, vis_fb, rotation, fit=False, x_hand=None):
    '''
    Model parameters
        sigma_pert
        sigma_pred
        sigma_p
        sigma_motor
         
    Returns
        x_state : state estimate
        x_f : motor output
    '''
    
    # Function for computing Gaussian log-probabilities
    f = lambda x, mu, sigma: -0.5 * np.log(2 * np.pi * sigma**2) - 0.5 * (x - mu)**2 / sigma**2
    
    # Possible endpoint locations
    x_grid = np.arange(-15, 15, 0.1)  

    # For vectorized code
    x_fs = x_grid.reshape((len(x_grid), 1))  # possible finger endpoint locations (col vec)
    d_xvs = x_grid.reshape((1, len(x_grid)))  # possible rotation sizes (row vec)
    x_vs = x_grid  # possible locations of visual cues
    x_ps = x_grid  # possible locations of proprioceptive cues
    x_us = x_grid  # possible locations of predictive cues

    # Ideal observer     
    x_f_hat = np.zeros(num_trials)
    x_state = np.zeros(num_trials)
    sigma_v = np.zeros(num_trials)
    K = np.zeros(num_trials)
    like_nopert_analytical = []
    mu_pert = 0
    b = 0
    prior_pert = 0.5
    
    # If fitting model to data, use actual hand data
    if fit == True:
        x_f = x_hand
    else: 
        x_f = np.zeros(num_trials) 
        x_f[0] = np.random.normal(0, sigma_motor)
    
    # Loop through trials
    for i in range(num_trials - 1):
        xhat_pred = x_f[i]
        xhat_p = x_f[i]

        if vis_fb[i] == 0:
            xhat_v = 0
            sigma_v[i] = 1e2
        else:
            xhat_v = x_f[i] + rotation[i]
            sigma_v[i] = 1.179 + 0.384 * np.abs(xhat_v)  # After Zhang et al 
        J_v = 1 / sigma_v[i]**2
        J_pert = 1 / sigma_pert**2
        K[i] = J_v / (J_v + J_pert)

        # Compute no perturbation likelihood (working with log-probs for numerical accuracy)
        loglik_nopert = (
            f(xhat_pred, x_fs, sigma_pred) + f(xhat_v, x_fs, sigma_v[i]) 
            + f(xhat_p, x_fs, sigma_p) + f(x_fs, b, sigma_motor) 
        )
        loglik_nopert = logsumexp(loglik_nopert.flatten(), b=0.1)
        likelihood_nopert = np.exp(loglik_nopert)

        # Compute perturbation likelihood
        loglik_pert = (
            f(xhat_pred, x_fs, sigma_pred) + f(xhat_v, x_fs + d_xvs, sigma_v[i]) 
            + f(xhat_p, x_fs, sigma_p) + f(d_xvs, mu_pert, sigma_pert) 
            + f(x_fs, b, sigma_motor)  
        )
        loglik_pert = logsumexp(loglik_pert.flatten(), b=0.01)
        
        # To account for no vis fb trials
        if vis_fb[i] == 1:
            likelihood_pert = np.exp(loglik_pert)
        else:
            likelihood_pert = 0

        # Posterior over Causal node
        normalization_const = prior_pert * likelihood_pert + ((1 - prior_pert) * likelihood_nopert)
        post_pert = (prior_pert * likelihood_pert) / normalization_const  # Posterior over Cause

        # Simulate trial-by-trial adaptation
        x_state[i + 1] = post_pert * K[i] * (rotation[i]) * -1
        if fit == False: 
            x_f[i + 1] = x_state[i + 1] + np.random.normal(0, sigma_motor)
        
    return x_state, x_f


def negloglik_test(B=None, sigma_motor=None, sigma_int=None, sigma_pert=None, 
                   sigma_pred=None, sigma_p=None, sigma_v=None, sigma_comb=None,
                   s=None, c=None, model=None, num_trials=None, rotation=None, 
                   fit=True,  vis_fb=None, eta_p=None, beta_p_sat=None, x_hand=None):
    if model == "pea":
        mu, _ = pea(sigma_int, B, sigma_motor, num_trials, vis_fb, rotation, 
                    fit, x_hand=x_hand)
    elif model == "piece":
        mu, _ = piece(sigma_pert, sigma_pred, sigma_p, sigma_motor, num_trials, 
                      vis_fb, rotation, fit, x_hand=x_hand)
    elif model == "ssm":
        mu, _ = ssm_ege(B, sigma_motor, rotation)
    elif model == "premo":
        mu, _ = premo(B, sigma_v, sigma_p, sigma_pred, eta_p, sigma_motor, 
                      num_trials, vis_fb, rotation, fit, x_hand=x_hand)
    elif model == "rem":
        mu, _ = rem(B, sigma_comb, s, c, sigma_motor, num_trials, vis_fb, rotation, 
                    fit, x_hand=x_hand)
    
    nll = -np.sum(-1 / 2 * np.log(2 * np.pi * sigma_motor**2) - ((x_hand - mu)**2 / (2 * sigma_motor**2)))
    
    return nll


def calc_bic(ll, num_params, num_trials):
    bic = -2 * ll + num_params * np.log(num_trials)
    
    return bic

In [3]:
# Lambda functions for computing negative log-likelihoods
nll_pea = lambda x: negloglik_test(model="pea", sigma_int=x[0], B=x[1],  
                                   sigma_motor=motor_sd, num_trials=len(subj), 
                                   vis_fb=vis_fb, rotation=rotation, x_hand=x_hand)
nll_premo = lambda x: negloglik_test(model="premo", B=x[0], sigma_pred=x[1], 
                                     sigma_v=x[2], sigma_p=x[3], eta_p=x[4], 
                                     sigma_motor=motor_sd, num_trials=len(subj), 
                                     vis_fb=vis_fb, rotation=rotation, 
                                     x_hand=x_hand)
nll_piece = lambda x: negloglik_test(model="piece", sigma_pert=x[0], sigma_pred=x[1], 
                                     sigma_p=x[2], sigma_motor=motor_sd, 
                                     num_trials=len(subj), vis_fb=vis_fb, 
                                     rotation=rotation, x_hand=x_hand)
nll_rem = lambda x: negloglik_test(model="rem", B=x[0], sigma_comb=x[1], s=x[2], c=x[3], 
                                   sigma_motor=motor_sd, num_trials=len(subj), 
                                   vis_fb=vis_fb, rotation=rotation, x_hand=x_hand)

In [8]:
models = ["piece", "pea", "premo", "rem"]
subj_num = []
model = []
winner = []

# Model recovery analysis
for i in range(len(np.unique(df["SN"]))):
    print(i)
    
    # Create dataframe with one subject's data
    sid = i + 1
    subj = df.loc[df["SN"] == sid, :].reset_index(drop=True)

    # Extract important variables
    motor_sd = subj.loc[0, "motor_sd"]
    x_hand = subj["theta_maxradv_clean"].values
    rotation = subj["rotation"].values
    vis_fb = subj["fbi"].values

    # Pull out best-fit params for chosen subject:
    params_idx = fits["subj_num"] == sid 
    params = fits.loc[params_idx, ["theta", "model"]].iloc[0:4, :].reset_index(drop=True)
    
    # Pull out MLEs for each model
    pea_mle = params.loc[params["model"] == "pea", "theta"].values[0]
    premo_mle = params.loc[params["model"] == "premo", "theta"].values[0]
    piece_mle = params.loc[params["model"] == "piece", "theta"].values[0]
    rem_mle = params.loc[params["model"] == "rem", "theta"].values[0]
    
    # Find perturbation trials
    pert_indices = np.flatnonzero(subj["perturbation"])

    # Loop through models
    for j in range(4):
        if j == 0:
            # Simulate with best-fit parameters
            # PIECE model
            _, xhat = piece(
                piece_mle[0], 
                piece_mle[1], 
                piece_mle[2], 
                motor_sd, len(subj), 
                subj["fbi"], 
                subj["rotation"],
                fit=False
            )
            simulated_model = "piece"
        elif j == 1:
            # PEA model
            _, xhat = pea(
                pea_mle[0], 
                pea_mle[1], 
                motor_sd, len(subj), 
                subj["fbi"], 
                subj["rotation"],
                fit=False
            )
            simulated_model = "pea"
        elif j == 2:
            # PReMo 
            _, xhat = premo(
                premo_mle[0],
                premo_mle[1],
                premo_mle[2], 
                premo_mle[3],
                premo_mle[4], 
                motor_sd, len(subj), 
                subj["fbi"], 
                subj["rotation"],
                fit=False
            )
            simulated_model = "premo"
        elif j == 3:
            # REM
            _, xhat = rem(
                rem_mle[0],
                rem_mle[1],
                rem_mle[2], 
                rem_mle[3], 
                motor_sd, len(subj), 
                subj["fbi"], 
                subj["rotation"],
                fit=False
            )
            simulated_model = "rem"
        
        # Assign simulated hand position to correct var name
        x_hand = xhat
        
        # Fit simulated data
        # PIECE model params: sigma_pert, sigma_pred, sigma_p
        bounds = ((0.05, 30), (0.05, 10), (0.05, 25))
        piece_results = scipy.optimize.minimize(
            fun=nll_piece, 
            bounds=bounds,
            x0=np.array([np.random.uniform(low=bounds[0][0], high=bounds[0][1]),
                         np.random.uniform(low=bounds[1][0], high=bounds[1][1]),
                         np.random.uniform(low=bounds[2][0], high=bounds[2][1])])
        )
        bic_piece = calc_bic(piece_results.fun * -1, len(piece_results.x), len(subj))
        print(f"{piece_results.message}")
        
        # PEA model params: sigma_comb, B
        bounds = ((0.5, 25), (0, 1))
        pea_results = scipy.optimize.minimize(
            fun=nll_pea, 
            bounds=bounds,
            x0=np.array([np.random.uniform(low=bounds[0][0], high=bounds[0][1]),
                         np.random.uniform(low=bounds[1][0], high=bounds[1][1])])
        )
        bic_pea = calc_bic(pea_results.fun * -1, len(pea_results.x), len(subj))
        print(f"{pea_results.message}")
        
        # PReMo model params: B, sigma_v, sigma_p, sigma_pred, eta_p
        bounds = ((0, 1), (0.05, 10), (0.5, 25), (0.5, 25), (0, 1))
        premo_results = scipy.optimize.minimize(
            fun=nll_premo,  
            bounds=bounds,
            x0=np.array([np.random.uniform(low=bounds[0][0], high=bounds[0][1]),
                         np.random.uniform(low=bounds[1][0], high=bounds[1][1]),
                         np.random.uniform(low=bounds[2][0], high=bounds[2][1]),
                         np.random.uniform(low=bounds[3][0], high=bounds[3][1]),
                         np.random.uniform(low=bounds[4][0], high=bounds[4][1])])
        )
        bic_premo = calc_bic(premo_results.fun * -1, len(premo_results.x), len(subj))
        print(f"{premo_results.message}")
        
        # REM model params: B, sigma_comb, s, c
        bounds = ((0, 1), (0.5, 25), (0, 10), (0, 10))
        rem_results = scipy.optimize.minimize(
            fun=nll_rem, 
            bounds=bounds,
            x0=np.array([np.random.uniform(low=bounds[0][0], high=bounds[0][1]),
                         np.random.uniform(low=bounds[1][0], high=bounds[1][1]),
                         np.random.uniform(low=bounds[2][0], high=bounds[2][1]),
                         np.random.uniform(low=bounds[3][0], high=bounds[3][1])])
        )
        bic_rem = calc_bic(rem_results.fun * -1, len(rem_results.x), len(subj))
        print(f"{rem_results.message}")
        
        # Store winning model
        winner.append(models[np.argmin([bic_piece, bic_pea, bic_premo, bic_rem])])

        # Update variables
        subj_num.append(sid)
        model.append(simulated_model)

df_model = pd.DataFrame({
    "subj_num":subj_num, 
    "model":model, 
    "winner":winner  
})

0
CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL
CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH
CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH
CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH
CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH
CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL
CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH
CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH
CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL
CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL
CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH
CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH
CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL
CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH
CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH
CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH
1
CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH
CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH
CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH
CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH
CONVERGENCE: REL_REDUCTION_OF_F

In [9]:
df_model

Unnamed: 0,subj_num,model,winner
0,1,piece,piece
1,1,pea,pea
2,1,premo,premo
3,1,rem,pea
4,2,piece,piece
...,...,...,...
59,15,rem,pea
60,16,piece,piece
61,16,pea,pea
62,16,premo,pea


In [10]:
print(df_model.head(40))

    subj_num  model winner
0          1  piece  piece
1          1    pea    pea
2          1  premo  premo
3          1    rem    pea
4          2  piece  piece
5          2    pea    pea
6          2  premo  premo
7          2    rem    pea
8          3  piece  piece
9          3    pea    pea
10         3  premo    pea
11         3    rem    pea
12         4  piece  piece
13         4    pea    pea
14         4  premo  premo
15         4    rem    pea
16         5  piece  piece
17         5    pea    pea
18         5  premo  premo
19         5    rem    pea
20         6  piece  piece
21         6    pea    pea
22         6  premo  premo
23         6    rem    pea
24         7  piece  piece
25         7    pea    pea
26         7  premo  premo
27         7    rem    pea
28         8  piece  piece
29         8    pea    pea
30         8  premo  premo
31         8    rem    pea
32         9  piece  piece
33         9    pea    pea
34         9  premo  premo
35         9    rem    pea
3

In [11]:
df_model.iloc[40:, :]

Unnamed: 0,subj_num,model,winner
40,11,piece,piece
41,11,pea,pea
42,11,premo,premo
43,11,rem,pea
44,12,piece,piece
45,12,pea,pea
46,12,premo,premo
47,12,rem,pea
48,13,piece,piece
49,13,pea,pea
