In [2]:
import pystan
import bebi103
import numpy as np
import bokeh.io
import bokeh.plotting
bokeh.io.output_notebook()
import matplotlib.pyplot as plt
import scipy.io
import trackball_suite
import time

%matplotlib notebook

## Define the dynamic logistic regression model

In [54]:
behavior_inference_model_code = """
data {
  int<lower=1> N;
  int<lower=0, upper=1> y[N];
  vector[N] g1;
  vector[N] g2;
  vector[N] g3;
  vector[N] g4;
}


parameters {
  real<lower=0> sigma;

  vector[N] v1;
  vector[N] v2;
  vector[N] v3;
  vector[N] v4;
}


transformed parameters {
  vector[N] w1;
  vector[N] w2;
  vector[N] w3;
  vector[N] w4;
  vector[N] n1;
  vector[N] n2;
  vector[N] n3;
  vector[N] n4;
  
  w1 = cumulative_sum(n1);
  w2 = cumulative_sum(n2);
  w3 = cumulative_sum(n3);
  w4 = cumulative_sum(n4);
  
    
  n1 = v1 * sigma;
  n2 = v2 * sigma;
  n3 = v3 * sigma;
  n4 = v4 * sigma;
}

model {
  // Priors
  sigma ~ cauchy(0, 2);
  print(1, sigma);
  v1[1] ~ normal(0, 16);
  v2[1] ~ normal(0, 16);
  v3[1] ~ normal(0, 16);
  v4[1] ~ normal(0, 16);

  for (i in 2:N) {
    v1[i] ~ normal(0, 1);
    v2[i] ~ normal(0, 1);
    v3[i] ~ normal(0, 1);
    v4[i] ~ normal(0, 1);
  }

  
  //print(1, prob);
  print(2, target());
  // Likelihood
  y ~ bernoulli_logit(g1 .* w1 + g2 .* w2 + g3 .* w3 + g4 .* w4);
}
"""

beh_inf = pystan.StanModel(model_code=behavior_inference_model_code)

INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_680c9e6c8076b2f07158f9a0cdb1372e NOW.


## Generate fake data to test the inference

In [4]:
# Fake data for simple inference
#np.random.seed(12334)
N = 500
sigma1 = 1 / 32
sigma2 = 1 / 64
sigma3 = 1 / 128
sigma4 = 1 / 64
sigmaG = 1

n1 = np.random.normal(0, sigma1, size=N)
n2 = np.random.normal(0, sigma2, size=N)
n3 = np.random.normal(0, sigma3, size=N)
n4 = np.random.normal(0, sigma4, size=N)

w1 = np.cumsum(n1) + np.random.normal(0, 1)
w2 = np.cumsum(n2) + np.random.normal(0, 1)
w3 = np.cumsum(n3) + np.random.normal(0, 1)
w4 = np.cumsum(n4) + np.random.normal(0, 1)

g1 = np.random.normal(0, sigmaG, size=N)
g2 = np.random.normal(0, sigmaG, size=N)
g3 = np.random.normal(0, sigmaG, size=N)
g4 = np.random.normal(0, sigmaG, size=N)

evidence = g1 * w1 + g2 * w2 + g3 * w3 + g4 * w4;
prob = 1.0/(1.0 + np.exp(-evidence))
prob = prob * 0.998 + 0.001
samples = np.random.rand(len(prob))
yvals = (samples < prob).astype('int')

In [5]:
plt.figure()
plt.subplot(121)
plt.plot(w1, label='w1')
plt.plot(w2, label='w2')
plt.plot(w3, label='w3')
plt.plot(w4, label='w4')
plt.subplot(122)
plt.plot(prob)
plt.legend()

<IPython.core.display.Javascript object>



<matplotlib.legend.Legend at 0x2ad863a9e10>

## Perform the sampling from the posterior

In [45]:
data_inf = dict(N=N,
               y=yvals.astype('int'),
               g1=g1,
               g2=g2, g3=g3, g4=g4)

init_prob = [dict(prob=np.ones(N) * 0.1, sigma=-10)] * 4
samples_beh_inf = beh_inf.sampling(data=data_inf, warmup=1000, iter=1500, control={'max_treedepth':18}, init=init_prob)

RuntimeError: Exception: mismatch in dimension declared and found in context; processing stage=data initialization; variable name=y; position=0; dims declared=(226); dims found=(500)  (in 'unknown file name' at line 4)


In [25]:
def get_mean_std(samples, tag=''):
    plt.figure(5)
    means = np.mean(samples, axis=0)
    std = np.std(samples, axis=0)
    plt.plot(means)
    plt.errorbar(np.arange(len(means)), means, 1.96 * std, label=tag)
    
    return means, std

In [29]:
# Get samples of w1 and plot
n1samp = samples_beh_inf['w1']
n2samp = samples_beh_inf['w2']
n3samp = samples_beh_inf['w3']
n4samp = samples_beh_inf['w4']

plt.figure(5)
plt.subplot('221')
get_mean_std(n1samp)
plt.plot(w1)
plt.subplot('222')
get_mean_std(n2samp)
plt.plot(w2)
plt.subplot('223')
get_mean_std(n3samp)
plt.plot(w3)
plt.subplot('224')
get_mean_std(n4samp)
plt.plot(w4)

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x2ad99e11be0>]

## Apply the model to behavioral data

In [55]:
# Load the data
folder = 'C:\\Users\\Sur lab\\Dropbox (MIT)\\trackball-behavior\\Data\\C13\\laser_analys_FebMar2019_left\\'
filename = '20190220_trackball_0013.mat'
data = scipy.io.loadmat(folder + filename)
group = trackball_suite.SessionGroup(folder)

session = trackball_suite.Session(folder + filename)

stim = session.get_stim() - 1
choice = session.get_choice() - 1
choice[choice == 4] = 0
opp_contrast =session.get_opp_contrast()
laser = session.get_laser() - 1


In [56]:
def get_regressors(session):
    '''Given a session, return the regressors to be used'''
    stim = session.get_stim() - 1
    choice = session.get_choice() - 1
    choice[choice == 4] = 0
    opp_contrast =session.get_opp_contrast()
    
    N = len(choice) - 1
    g1 = np.ones(N) # bias
    g2 = stim[1:] # current stim
    g3 = stim[:-1] # previous stim
    g4 = opp_contrast[1:] #contrast
    y = choice[1:]
     
    return N, g1, g2, g3, g4, y

In [58]:
# Begin to sample
def initfun():
    return dict(sigma=0.05)
#init_prob = [dict(prob=np.ones(N) * 0.1, sigma=-10)] * 4
for i in [1]: #range(len(group.sessions)):
    N, g1, g2, g3, g4, y = get_regressors(group.sessions[i])
    data_inf = dict(N=N,
                   y=y,
                   g1=g1,
                   g2=g2, g3=g3, g4=g4)

    init_prob = [dict(prob=np.ones(N) * 0.1, sigma=0.05)] * 4
    startT = time.time()
    samples_beh_inf = beh_inf.sampling(data=data_inf, warmup=1000, iter=1500, \
                                       control={'max_treedepth':18}, init=initfun)
    endT = time.time()

RuntimeError: Initialization failed.

In [None]:
endT-startT

In [48]:
tausamp =samples_beh_inf['sigma']
plt.figure()

plt.plot(tausamp)

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x2ad860c1048>]

In [36]:
# Get samples of w1 and plot
n1samp = samples_beh_inf['w1']
n2samp = samples_beh_inf['w2']
n3samp = samples_beh_inf['w3']
n4samp = samples_beh_inf['w4']

means1, std1 = get_mean_std(n1samp, tag='w1')
means2, std2 = get_mean_std(n2samp, tag='w2')
means3, std3 = get_mean_std(n3samp, tag='w3')
means4, std4 = get_mean_std(n4samp, tag='w4')
plt.figure(3)
plt.legend()

#plt.figure(5)

#plt.errorbar(np.arange(N), means, 1.96 * std)
#plt.plot(means, 'r')
#plt.plot(w1, 'r')
#plt.plot(w2, 'g')
#plt.plot(w3, 'b')
#plt.plot(w4, 'k')

#plt.errorbar(np.arange(N), means3, 1.96 * std3, label='Bias')
#plt.errorbar(np.arange(N), means2, 1.96 * std2, label='Curr stim')
#plt.errorbar(np.arange(N), means3, 1.96 * std3, label='Previous stim')
#plt.errorbar(np.arange(N), means4, 1.96 * std4, label='Contrast')
plt.plot(means1, 'r--', label='Bias')
plt.plot(means2, 'g--', label='Curr stim')
plt.plot(means3, 'b--', label='Opp contrast')
plt.plot(means4, 'k--', label='Prev stim')
#plt.vlines(np.where(lasertag==2)[0], -30, 30)
plt.legend()

<IPython.core.display.Javascript object>



<matplotlib.legend.Legend at 0x2ad92d9f470>

In [21]:
# Get samples of w1 and plot
n1samp = samples_beh_inf['w1']
n2samp = samples_beh_inf['w2']
n3samp = samples_beh_inf['w3']
n4samp = samples_beh_inf['w4']

means1, std1 = get_mean_std(n1samp, tag='w1')
means2, std2 = get_mean_std(n2samp, tag='w2')
means3, std3 = get_mean_std(n3samp, tag='w3')
means4, std4 = get_mean_std(n4samp, tag='w4')
plt.figure(4)
plt.legend()

#plt.figure(5)

#plt.errorbar(np.arange(N), means, 1.96 * std)
#plt.plot(means, 'r')
#plt.plot(w1, 'r')
#plt.plot(w2, 'g')
#plt.plot(w3, 'b')
#plt.plot(w4, 'k')

plt.errorbar(np.arange(N), means1, 1.96 * std1, label='Bias')
plt.errorbar(np.arange(N), means2, 1.96 * std2, label='Curr stim')
plt.errorbar(np.arange(N), means3, 1.96 * std3, label='Previous stim')
plt.errorbar(np.arange(N), means4, 1.96 * std4, label='Contrast')
#plt.plot(means1, 'r--', label='Bias')
#plt.plot(means2, 'g--', label='Curr stim')
#plt.plot(means3, 'b--', label='Opp contrast')
#plt.plot(means4, 'k--', label='Prev stim')
plt.legend()

<IPython.core.display.Javascript object>



<matplotlib.legend.Legend at 0x26e973410f0>

## Diagnostics

In [15]:
bebi103.stan.check_all_diagnostics(samples_beh_inf)

n_eff / iter looks reasonable for all parameters.
Rhat looks reasonable for all parameters.
0 of 2000 (0.0%) iterations ended with a divergence.
0 of 2000 (0.0%) iterations saturated the maximum tree depth of 18.
E-BFMI indicated no pathological behavior.


0

In [38]:
transformation = lambda x: (x - np.mean(x)) / np.std(x)

bokeh.io.show(bebi103.viz.parcoord_plot(samples_beh_inf, 
                                        transformation=transformation, 
                                        pars=['w1[1]', 'w2[1]', 'w3[1]', 'w4[1]'],
                                       divergence_alpha=0.1, 
                                        divergence_line_width=0.5))

In [49]:
bokeh.io.show(bebi103.viz.trace_plot(samples_beh_inf, 
                                     pars=['w1[1]', 'w2[1]', 'w3[1]', 'w4[1]', 'sigma'], 
                                     inc_warmup=True))

In [50]:
bokeh.io.show(bebi103.viz.corner(samples_beh_inf, 
                                 pars=['w1[1]', 'w2[1]', 'w3[1]', 'w4[1]', 'sigma'],
                                 labels=['w1[1]', 'w2[1]', 'w3[1]', 'w4[1]', 'sigma']))