# Case study: Reaction times of Schizophrenics

## Objectives and takeaways
1. Take a real-world experiment and write down its model.
2. Write a Metropolis sampler, including the proposal distribution.
3. Perform inference using your sampler.

## Experiment

We shall work with the experiment published by Belin and Rubin [1] in 1995, which analyzed reaction times to visual stimuli in schizophrenia.

A total of 17 volunteers each performed 30 repetitions of a visual task, and their reaction times were measured in milliseconds. There were 6 schizophrenic patients and 11 healthy volunteers.

Note that in [1] the authors do not use a Bayesian approach but instead fit the model with an EM procedure. The priors defined in this notebook are therefore our own construction.

Below is the original data from the experiment, available [here](http://www.stat.columbia.edu/~gelman/book/data/schiz.asc).

[1] T. Belin and D. Rubin, “The analysis of repeated‐measures data on schizophrenic reaction times using mixture models,” Stat Med, vol. 14, no. 8, pp. 747–768, 1995.

In [1]:
orig_data = """
312 272 350 286 268 328 298 356 292 308 296 372 396 402 280 330 254 282 350 328 332 308 292 258 340 242 306 328 294 272
354 346 384 342 302 312 322 376 306 402 320 298 308 414 304 422 388 422 426 338 332 426 478 372 392 374 430 388 354 368
256 284 320 274 324 268 370 430 314 312 362 256 342 388 302 366 298 396 274 226 328 274 258 220 236 272 322 284 274 356
260 294 306 292 264 290 272 268 344 362 330 280 354 320 334 276 418 288 338 350 350 324 286 322 280 256 218 256 220 356
204 272 250 260 314 308 246 236 208 268 272 264 308 236 238 350 272 252 252 236 306 238 350 206 260 280 274 318 268 210
590 312 286 310 778 364 318 316 316 298 344 262 274 330 312 310 376 326 346 334 282 292 282 300 290 302 300 306 294 444
308 364 374 278 366 310 358 380 294 334 302 250 542 340 352 322 372 348 460 322 374 370 334 360 318 356 338 346 462 510
244 240 278 262 266 254 240 244 226 266 294 250 284 260 418 280 294 216 308 324 264 232 294 236 226 234 274 258 208 380
232 262 230 222 210 284 232 228 264 246 264 316 260 266 304 268 384 234 308 266 294 254 222 262 278 290 208 232 206 206
318 324 282 364 286 342 306 302 280 306 256 334 332 336 360 344 480 310 336 314 392 284 292 280 320 322 286 406 352 324
240 292 350 254 396 430 260 320 298 312 290 248 276 364 318 434 400 382 318 298 298 248 250 234 280 306 282 234 424 244

276 272 264 258 278 286 314 340 334 364 286 344 312 380 262 324 310 260 280 262 364 316 270 286 326 302 300 302 344 290
374 466 432 376 360 454 478 382 524 410 520 470 514 354 434 380 416 384 462 386 404 362 420 360 390 356 550 372 386 396
594 1014 1586 1344 610 838 772 264 748 1076 446 314 304 1680 1700 334 256 422 302 296 354 322 276 382 502 428 544 286 650 432
402 466 296 348 680 702 500 500 576 624 406 378 586 826 298 882 564 656 716 380 448 506 1714 748 510 810 984 458 390 642
620 714 414 358 460 598 324 442 372 410 998 636 968 490 696 560 562 720 618 456 502 974 1032 470 462 798 716 300 586 574
454 388 344 226 562 766 502 432 608 516 500 796 542 458 448 404 372 524 400 366 374 350 1154 558 440 348 400 460 514 450"""

In [2]:
import numpy as np

# the first 11 lines are from controls, the last 6 from schizophrenics
def parse_data():
    rts = []
    for line in orig_data.split('\n'):
        if not line.strip():
            continue  # skip blank separator lines
        # split() on any whitespace is robust to repeated or trailing spaces
        rts.append([int(tok) for tok in line.split()])
    # np.float was removed in NumPy 1.24; the builtin float gives float64
    return np.array(rts, dtype=float)

reaction_times = parse_data()

### Data description
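Each row holds one participant's 30 reaction times in milliseconds. Rows 0 to 10 of ```reaction_times``` are the healthy volunteers (control group) and rows 11 to 16 are the schizophrenic patients, so, for example, ```reaction_times[0,:]``` shows the reaction times of a healthy volunteer.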
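
The row layout can be made explicit with a group indicator $X_i$ and a map from participant rows to rows of the latent matrix $Z$ used later. This is a small sketch assuming the 11/6 split described above; the names `X`, `z_row`, `N_CONTROL` and `N_SCHIZO` are hypothetical and are not used elsewhere in the notebook.

```python
import numpy as np

N_CONTROL, N_SCHIZO = 11, 6  # group sizes from the study description

# group indicator per participant row: 0 = control, 1 = schizophrenic
X = np.repeat([0, 1], [N_CONTROL, N_SCHIZO])

def z_row(i):
    """Row of the latent Z matrix belonging to participant row i (schizophrenics only)."""
    if X[i] != 1:
        raise ValueError('participant %d is a control and has no latent Z' % i)
    return i - N_CONTROL
```

For instance, `z_row(13)` is 2, which is why the third schizophrenic patient appears as `Z[2, :]` in the sampler state.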

In [3]:
reaction_times[0,:]

array([312., 272., 350., 286., 268., 328., 298., 356., 292., 308., 296.,
       372., 396., 402., 280., 330., 254., 282., 350., 328., 332., 308.,
       292., 258., 340., 242., 306., 328., 294., 272.])

In [4]:
# reaction times for the first patient with schizophrenia
reaction_times[11,:]

array([276., 272., 264., 258., 278., 286., 314., 340., 334., 364., 286.,
       344., 312., 380., 262., 324., 310., 260., 280., 262., 364., 316.,
       270., 286., 326., 302., 300., 302., 344., 290.])

### Model structure
Below we discuss the first model of Belin and Rubin [1]. The paper presents three additional models with progressively more structure, which are not discussed in this notebook.

In the following, we will use the subscript $i \in \{1, 2, ..., 17\}$ to denote participants and $j \in \{1, 2, ..., 30\}$ to denote trials.

It is assumed that each participant has their own mean reaction time $\alpha_i$ and that the reaction-time variance $\sigma^2$ is common to both groups. We know whether each participant is in the control group ($X_i=0$) or is schizophrenic ($X_i=1$), and we observe the reaction time $Y_{i,j}$ for each participant and trial.

In line with psychological theory, the authors assumed that schizophrenics behave like the control-group participants except that, in a proportion $\lambda$ of the trials, they have difficulty attending to the task (attentional deficit), so their reaction time is delayed by $\tau$ milliseconds.

We may formalize the model for the control group as follows:

$$Y_{i,j} \sim {\cal N}(\alpha_i, \sigma^2)$$

and for the schizophrenia group as

$$Y_{i,j} \sim {\cal N}(\alpha_i + Z_{i,j}\tau, \sigma^2),$$

where $Z_{i,j}$ is a latent (unobserved) variable that indicates whether the attentional deficit was present in the trial. We assume

$$Z_{i,j} \sim \text{Bernoulli}(\lambda).$$
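
To make the generative story concrete, here is a minimal simulation sketch of the model above. The function name `simulate_reaction_times` and the parameter values (means around 300 ms, `tau=200`, `sigma=40`, `lam=0.1`) are hypothetical choices for illustration, not estimates from [1].

```python
import numpy as np

def simulate_reaction_times(alphas, tau, sigma, lam, n_control, n_trials, rng):
    """Draw Y[i, j] from the model: controls ~ N(alpha_i, sigma^2),
    schizophrenics ~ N(alpha_i + Z_ij * tau, sigma^2) with Z_ij ~ Bernoulli(lam)."""
    n_participants = len(alphas)
    n_schizo = n_participants - n_control
    # latent attentional-deficit indicators, one per (schizophrenic patient, trial)
    Z = (rng.uniform(size=(n_schizo, n_trials)) < lam).astype(float)
    # mean for every (participant, trial); only schizophrenic rows receive the delay
    mu = np.repeat(np.asarray(alphas, dtype=float)[:, None], n_trials, axis=1)
    mu[n_control:, :] += Z * tau
    Y = rng.normal(loc=mu, scale=sigma)
    return Y, Z

# hypothetical parameters, chosen only for illustration
rng = np.random.default_rng(0)
alphas = rng.normal(300.0, 30.0, size=17)
Y_sim, Z_sim = simulate_reaction_times(alphas, tau=200.0, sigma=40.0, lam=0.1,
                                       n_control=11, n_trials=30, rng=rng)
```

Comparing histograms of `Y_sim[:11]` and `Y_sim[11:]` should show the long right tail that the delayed trials add to the schizophrenic group.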

### Priors
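Feel free to experiment with weakly informative priors. The original recommendations used uniform and inverse-gamma distributions, but wide Normal or HalfNormal distributions may work as well. Explore different options; the code below uses a wide Normal prior on each $\alpha_i$, half-normal priors on $\tau$ and $\sigma$, and a uniform prior on $\lambda$.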
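
If you try the inverse-gamma option, note that it is usually stated for the variance $\sigma^2$, while the sampler in this notebook works with the standard deviation $\sigma$. The sketch below, with hypothetical hyperparameters `a` and `b`, writes the unnormalized log density on $\sigma^2$ and then adds the Jacobian $\log|d\sigma^2/d\sigma| = \log 2\sigma$ needed to express the same prior on $\sigma$.

```python
import numpy as np

def log_invgamma_sigma2(sigma2, a=2.0, b=1000.0):
    """Unnormalized log density of sigma^2 ~ InvGamma(a, b); a and b are hypothetical."""
    if sigma2 <= 0:
        return -np.inf
    return -(a + 1.0) * np.log(sigma2) - b / sigma2

def log_prior_sigma_from_invgamma(sigma, a=2.0, b=1000.0):
    """The same prior expressed on sigma: add log|d(sigma^2)/d(sigma)| = log(2 sigma)."""
    if sigma <= 0:
        return -np.inf
    return log_invgamma_sigma2(sigma**2, a, b) + np.log(2.0 * sigma)
```

Swapping `log_prior_sigma_from_invgamma(sigma)` in for the half-normal $\sigma$ term of `log_prior` would be one way to try this option without changing the sampler.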

In [None]:
Np, Nt = reaction_times.shape
Np = 17  # number of participants to use; rows 11..Np-1 are the Np-11 schizophrenic patients

In [None]:
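def log_prior(v):
    """Log prior up to an additive constant."""
    alphas, tau, lam, sigma, Z = v['alphas'], v['tau'], v['lambda'], v['sigma'], v['Z']
    # tau and sigma have half-normal priors (positive support); lambda lives in (0, 1)
    if tau <= 0 or sigma <= 0 or lam <= 0 or lam >= 1:
        return -np.inf
    # mean reaction times: alpha_i ~ N(0, 100^2)
    logp = np.sum(np.log(1./100) - alphas**2 / (2*100**2))
    # time delay: tau ~ HalfNormal(100)
    logp += np.log(1./100) - tau**2 / (2*100**2)
    # proportion of attention deficits: lambda ~ Uniform(0, 1) adds a constant inside (0, 1)
    # latent indicators: Z_ij ~ Bernoulli(lambda); this is the only term linking lambda to the data
    logp += np.sum(Z * np.log(lam) + (1 - Z) * np.log(1 - lam))
    # standard deviation (not variance): sigma ~ HalfNormal(10)
    logp += np.log(1./10) - sigma**2 / (2*10**2)
    return logp

def log_likelihood(v, rts):
    """Gaussian log-likelihood of the first Np participants (constants dropped)."""
    tau, sigma, Z, alphas = v['tau'], v['sigma'], v['Z'], v['alphas']
    # expected reaction time for every (participant, trial) pair
    mu = np.repeat(alphas[:, None], Nt, axis=1)
    # schizophrenic patients (rows 11..Np-1) are delayed by tau on deficit trials
    mu[11:Np, :] += Z * tau
    resid = rts[:Np, :] - mu
    return -Np * Nt * np.log(sigma) - np.sum(resid**2) / (2 * sigma**2)

def model_log_posterior_unnorm(v, rts):
    return log_prior(v) + log_likelihood(v, rts)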
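
Because $Z_{i,j}$ is binary, it can also be summed out of the model. For a schizophrenic participant, each observation then follows a two-component mixture

$$p(Y_{i,j} \mid \alpha_i, \tau, \sigma, \lambda) = (1-\lambda)\,{\cal N}(Y_{i,j}; \alpha_i, \sigma^2) + \lambda\,{\cal N}(Y_{i,j}; \alpha_i + \tau, \sigma^2),$$

so a sampler would only have to explore $(\alpha, \tau, \sigma, \lambda)$. The sketch below evaluates this marginal log-likelihood with `np.logaddexp` for numerical stability. It is an alternative to the augmented likelihood in `log_likelihood`, not part of the original notebook; the function name is hypothetical and it assumes the same row layout (the first `n_control` rows are controls).

```python
import numpy as np

def marginal_log_likelihood(alphas, tau, sigma, lam, rts, n_control=11):
    """Log-likelihood with Z summed out; rts has one row per participant.

    The -log(2*pi)/2 constant is kept, so the result is a proper log density.
    """
    rts = np.asarray(rts, dtype=float)
    alphas = np.asarray(alphas, dtype=float)[:, None]
    log_norm = -0.5 * np.log(2 * np.pi) - np.log(sigma)
    # controls: a single Gaussian component
    ctrl = log_norm - (rts[:n_control] - alphas[:n_control])**2 / (2 * sigma**2)
    # schizophrenics: log[(1 - lam) N(y; alpha, sigma^2) + lam N(y; alpha + tau, sigma^2)]
    y, a = rts[n_control:], alphas[n_control:]
    log_off = np.log1p(-lam) + log_norm - (y - a)**2 / (2 * sigma**2)
    log_on = np.log(lam) + log_norm - (y - a - tau)**2 / (2 * sigma**2)
    return np.sum(ctrl) + np.sum(np.logaddexp(log_off, log_on))
```

Using this in place of `log_likelihood` (and dropping `Z` from the state and the prior) would remove the 180 discrete parameters from the chain; draws of $Z_{i,j}$ could then be recovered afterwards from their conditional probabilities.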

In [None]:
# set up the initial point
lam = 0.1

v_init = {'alphas': np.mean(reaction_times, axis=1)[:Np],  # start at the empirical means
          'tau': np.abs(np.random.randn() * 100),           # must be positive
          'sigma': 15.0,
          'lambda': lam,
          # initial deficit indicators drawn from Bernoulli(lam)
          'Z': (np.random.uniform(size=(Np - 11, Nt)) < lam).astype(float)}
v_init

In [None]:
log_prior(v_init), log_likelihood(v_init, reaction_times), model_log_posterior_unnorm(v_init, reaction_times)

In [None]:
scale = 2.0

def _mult_step(x, c, size=None):
    """Multiply x by m ~ U(1-c, 1/(1-c)).

    The proposal density q(x'|x) = 1 / (x * (1/(1-c) - (1-c))) is not symmetric,
    so the Hastings term log q(x|x') - log q(x'|x) equals log(x) - log(x').
    """
    x_new = x * np.random.uniform(low=1.0 - c, high=1.0 / (1.0 - c), size=size)
    return x_new, np.log(x) - np.log(x_new)

def propose(v, accept_ratio=None):
    """Return (v_new, log_q_ratio) with log_q_ratio = log q(v|v_new) - log q(v_new|v)."""
    global scale

    # adaptive region: shrink or grow all step sizes based on the recent acceptance ratio
    if accept_ratio is not None:
        if accept_ratio < 0.1:
            scale *= 0.8
            print('downscale=%g' % scale)
        if accept_ratio > 0.5:
            # cap the scale so that the multipliers 1 - 0.1*scale stay positive
            scale = min(scale / 0.8, 4.9)
            print('upscale=%g' % scale)

    v_new, log_q = {}, 0.0

    # only change a few alphas at a time (each with probability 0.1)
    change = np.random.uniform(size=Np) < 0.1
    alphas_prop, _ = _mult_step(v['alphas'], 0.05 * scale, size=Np)
    v_new['alphas'] = np.where(change, alphas_prop, v['alphas'])
    log_q += np.sum(np.log(v['alphas']) - np.log(v_new['alphas']))  # zero for unchanged alphas

    v_new['tau'], dq = _mult_step(v['tau'], 0.1 * scale)
    log_q += dq
    v_new['sigma'], dq = _mult_step(v['sigma'], 0.05 * scale)
    log_q += dq

    # random walk for lambda in log-odds space; the Jacobian of the logit transform
    # gives the Hastings term log[lam_new (1 - lam_new)] - log[lam_old (1 - lam_old)]
    lam_old = v['lambda']
    logit_new = np.log(lam_old / (1 - lam_old)) + np.random.randn() * 0.1 * scale
    lam_new = 1.0 / (1.0 + np.exp(-logit_new))
    v_new['lambda'] = lam_new
    log_q += np.log(lam_new * (1 - lam_new)) - np.log(lam_old * (1 - lam_old))

    # redraw a random 10% of the Z_ij from Bernoulli(lam_new)
    n_z = (Np - 11) * Nt
    ndx = np.random.choice(n_z, n_z // 10, replace=False)
    rows, cols = ndx // Nt, ndx % Nt
    z_old = v['Z'][rows, cols]
    z_new = (np.random.uniform(size=len(ndx)) < lam_new).astype(float)
    v_new['Z'] = np.copy(v['Z'])
    v_new['Z'][rows, cols] = z_new
    # Hastings term: the reverse move would redraw the same entries from Bernoulli(lam_old)
    log_q += np.sum(z_old * np.log(lam_old) + (1 - z_old) * np.log(1 - lam_old))
    log_q -= np.sum(z_new * np.log(lam_new) + (1 - z_new) * np.log(1 - lam_new))

    return v_new, log_q

def metropolis(v_init, log_posterior, n, status_period=None, n_adapt=None):
    """Metropolis-Hastings sampler; returns (states, acceptance ratio, accept flags)."""
    v, logp = v_init, log_posterior(v_init)
    states, states_logp = [v_init], [logp]
    was_accept = []

    if status_period is None:
        status_period = n // 10
    if n_adapt is None:
        # adapt the step size only in the first half; adapting throughout the run makes
        # the chain depend on its history, so treat this phase as burn-in
        n_adapt = n // 2

    for i in range(1, n):

        # every 500 steps during the adaptation phase, retune the step size
        accept_probability = None
        if i <= n_adapt and i % 500 == 0:
            accept_probability = float(np.mean(was_accept[-500:]))

        # propose new values together with the Hastings correction
        v_new, log_q_ratio = propose(v, accept_probability)

        # compute the log posterior of the proposal
        logp_new = log_posterior(v_new)

        # accept with probability min(1, p(v_new) q(v|v_new) / (p(v) q(v_new|v)))
        u = np.random.uniform()
        if logp_new - logp + log_q_ratio > np.log(u):
            v, logp = v_new, logp_new
            was_accept.append(1)
        else:
            was_accept.append(0)
        # unlike a rejection sampler, a rejected proposal repeats the current state
        states.append(v)
        states_logp.append(logp)

        if i % status_period == 0:
            print('Stats @ %d: accept_ratio=%g avg_logp=%g'
                  % (i, np.mean(was_accept[-status_period:]), np.mean(states_logp[-status_period:])))

    return states, float(np.mean(was_accept)), was_accept
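
The multiplicative proposals in `propose` (multiplying `tau`, `sigma` and some `alphas` by a draw from $U(1-c, 1/(1-c))$) are not symmetric: the proposal density is $q(x' \mid x) \propto 1/x$, so the acceptance ratio needs the Hastings factor $x/x'$. The self-contained sketch below checks this on a toy one-dimensional target, a Gamma(3, 1) density with mean 3. The target, the function name `mh_multiplicative`, the step size `c` and the chain length are assumptions for illustration only.

```python
import numpy as np

def mh_multiplicative(log_target, x0, n, c=0.5, hastings=True, seed=0):
    """Random-walk MH on x > 0 with proposal x' = x * m, m ~ U(1 - c, 1 / (1 - c))."""
    rng = np.random.default_rng(seed)
    x, logp = x0, log_target(x0)
    samples = np.empty(n)
    for k in range(n):
        x_new = x * rng.uniform(1.0 - c, 1.0 / (1.0 - c))
        logp_new = log_target(x_new)
        # q(x'|x) = 1 / (x * (1/(1-c) - (1-c))), so log q(x|x') - log q(x'|x) = log(x / x')
        log_q_ratio = np.log(x) - np.log(x_new) if hastings else 0.0
        if np.log(rng.uniform()) < logp_new - logp + log_q_ratio:
            x, logp = x_new, logp_new
        samples[k] = x
    return samples

# toy target: Gamma(shape=3, rate=1), unnormalized log density 2 log x - x
log_gamma3 = lambda x: 2.0 * np.log(x) - x
with_correction = mh_multiplicative(log_gamma3, 1.0, 50_000)
without_correction = mh_multiplicative(log_gamma3, 1.0, 50_000, hastings=False)
```

Without the correction the chain targets a density proportional to $x \cdot p(x)$, here a Gamma(4, 1) with mean 4, which is the kind of bias the Hastings term removes.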

In [None]:
trace, accept_ratio, proposals = metropolis(v_init, lambda v: model_log_posterior_unnorm(v, reaction_times), 50000, 1000)

In [None]:
accept_ratio

In [None]:
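# discard the first half of the chain as burn-in (it includes the initial step-size tuning)
burn = len(trace) // 2
taus = [v['tau'] for v in trace[burn:]]
sigmas = [v['sigma'] for v in trace[burn:]]
lambdas = [v['lambda'] for v in trace[burn:]]
alphas = np.vstack([v['alphas'] for v in trace[burn:]])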

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns

plt.figure(figsize=(12,6))
plt.subplot(4,2,1)
sns.kdeplot(lambdas)
plt.ylabel(r'$\lambda$', fontsize=14)
plt.subplot(4,2,2)
plt.plot(lambdas)

plt.subplot(4,2,3)
sns.kdeplot(sigmas)
plt.ylabel(r'$\sigma$', fontsize=14)
plt.subplot(4,2,4)
plt.plot(sigmas)

plt.subplot(4,2,5)
sns.kdeplot(taus)
plt.ylabel('$\\tau$', fontsize=14)
plt.subplot(4,2,6)
plt.plot(taus)

plt.subplot(4,2,7)
plt.plot(np.mean(alphas, axis=0) - np.mean(reaction_times[:Np,:], axis=1), 'or-')
plt.ylabel(r'$\bar\alpha_i - \bar Y_i$', fontsize=14)

plt.show()

In [None]:
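# Sanity check for the third schizophrenic patient (row 13 of reaction_times, row 2 of Z):
# trials with very long reaction times should have Z close to 1
Z_2_all = np.vstack([v['Z'][2, :] for v in trace[len(trace) // 2:]])  # drop burn-in

plt.figure(figsize=(16,12))
for i in range(5):
    for j in range(6):
        plt.subplot(5, 6, i*6+j+1)
        plt.hist(Z_2_all[:,i*6+j], bins=[-0.5, 0.5, 1.5], density=True)
        plt.title('Z[13,%d], RT=%g' % (i*6+j, reaction_times[13,i*6+j]))
_ = plt.show()

In [None]:
plt.scatter(reaction_times[13,:], np.mean(Z_2_all, axis=0))
plt.xlabel('reaction time [ms]')
plt.ylabel('posterior mean of $Z$')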
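
For fixed values of the other parameters, the conditional probability that a trial was affected by the deficit has a closed form:

$$P(Z_{i,j}=1 \mid Y_{i,j}, \alpha_i, \tau, \sigma, \lambda) = \frac{\lambda\,{\cal N}(Y_{i,j}; \alpha_i+\tau, \sigma^2)}{(1-\lambda)\,{\cal N}(Y_{i,j}; \alpha_i, \sigma^2) + \lambda\,{\cal N}(Y_{i,j}; \alpha_i+\tau, \sigma^2)}.$$

The posterior means in the scatter plot average this quantity over the posterior of the parameters, so they should increase with the reaction time. A sketch, where the function name `deficit_probability` and the parameter values are hypothetical:

```python
import numpy as np
from scipy.special import expit

def deficit_probability(y, alpha, tau, sigma, lam):
    """P(Z = 1 | y) for one schizophrenic participant with fixed parameters."""
    y = np.asarray(y, dtype=float)
    # log-odds of the deficit component versus the normal one (Gaussian constants cancel)
    log_odds = (np.log(lam) - np.log1p(-lam)
                + ((y - alpha)**2 - (y - alpha - tau)**2) / (2 * sigma**2))
    return expit(log_odds)

# hypothetical parameters, for illustration only
probs = deficit_probability([280.0, 400.0, 900.0], alpha=300.0, tau=300.0, sigma=50.0, lam=0.3)
```

Because both components share $\sigma$, the log-odds are linear in $y$ with slope $\tau/\sigma^2$, which gives the S-shaped curve for any fixed parameter values.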

In [None]:
import pymc3 as pm
import theano.tensor as tt

In [None]:
Np = 12  # use only the first schizophrenic patient (row 11); this overrides the global Np

with pm.Model() as schizo_model:

    tau = pm.HalfNormal('tau', 100)
    lam = pm.Uniform('lambda', 0, 1.0)
    sigma = pm.HalfNormal('sigma', 10)
    alphas = pm.HalfNormal('alphas', 100, shape=Np)

    # one deficit indicator per trial; a single row of Z is only valid
    # because Np = 12 leaves exactly one schizophrenic patient
    Z = pm.Bernoulli('Z', lam, shape=Nt)

    for i in range(11):
        pm.Normal('Y_%d' % i, alphas[i], sigma, observed=reaction_times[i, :])
    for i in range(11, Np):
        pm.Normal('Y_%d' % i, alphas[i] + Z * tau, sigma, observed=reaction_times[i, :])

In [None]:
with schizo_model:
    trace = pm.sample(draws=10000, tune=1000, step=pm.Metropolis(), chains=2)

In [None]:
pm.summary(trace, varnames=['tau', 'lambda', 'sigma'])

In [None]:
with schizo_model:
    _ = pm.traceplot(trace, varnames=['tau', 'lambda', 'sigma', 'alphas'])

In [None]:
alphas1 = trace['alphas']

In [None]:
alphas1.shape

In [None]:
plt.plot(np.mean(alphas1, axis=0) - np.mean(reaction_times[:Np, :], axis=1), 'o')
plt.title('Difference between estimated and empirical mean reaction times')

In [None]:
Z = trace['Z']

plt.figure(figsize=(16,12))
for i in range(5):
    for j in range(6):
        plt.subplot(5, 6, i*6+j+1)
        plt.hist(Z[:,i*6+j], bins=[-0.5, 0.5, 1.5], density=True)
        plt.title('Z[11,%d], RT=%g' % (i*6+j, reaction_times[11,i*6+j]))
_ = plt.show()