## Objectives and takeaways
1. Take a real-world experiment, write the model.
2. Write a Metropolis sampler, including the proposal distribution.
3. Perform inference using your sampler.

Remark: with thanks to Charles Lao for consulting on the structure of a suitable model.

## Experiment

We shall work with the experiment published by Belin and Rubin [1] in 1995, which analyzed reaction times to visual stimuli in schizophrenia.

A total of 17 volunteers performed 30 repetitions of a visual task and their reaction time was measured in milliseconds. There were 6 schizophrenics and 11 healthy volunteers.

Note that in the work [1], the authors do not use a Bayesian approach for estimation but apply an EM procedure.  The priors that we define in this work must therefore be our construction.

Below is the original data from the experiment, available [here](http://www.stat.columbia.edu/~gelman/book/data/schiz.asc).

In [None]:
orig_data = """
312 272 350 286 268 328 298 356 292 308 296 372 396 402 280 330 254 282 350 328 332 308 292 258 340 242 306 328 294 272
354 346 384 342 302 312 322 376 306 402 320 298 308 414 304 422 388 422 426 338 332 426 478 372 392 374 430 388 354 368
256 284 320 274 324 268 370 430 314 312 362 256 342 388 302 366 298 396 274 226 328 274 258 220 236 272 322 284 274 356
260 294 306 292 264 290 272 268 344 362 330 280 354 320 334 276 418 288 338 350 350 324 286 322 280 256 218 256 220 356
204 272 250 260 314 308 246 236 208 268 272 264 308 236 238 350 272 252 252 236 306 238 350 206 260 280 274 318 268 210
590 312 286 310 778 364 318 316 316 298 344 262 274 330 312 310 376 326 346 334 282 292 282 300 290 302 300 306 294 444
308 364 374 278 366 310 358 380 294 334 302 250 542 340 352 322 372 348 460 322 374 370 334 360 318 356 338 346 462 510
244 240 278 262 266 254 240 244 226 266 294 250 284 260 418 280 294 216 308 324 264 232 294 236 226 234 274 258 208 380
232 262 230 222 210 284 232 228 264 246 264 316 260 266 304 268 384 234 308 266 294 254 222 262 278 290 208 232 206 206
318 324 282 364 286 342 306 302 280 306 256 334 332 336 360 344 480 310 336 314 392 284 292 280 320 322 286 406 352 324
240 292 350 254 396 430 260 320 298 312 290 248 276 364 318 434 400 382 318 298 298 248 250 234 280 306 282 234 424 244

276 272 264 258 278 286 314 340 334 364 286 344 312 380 262 324 310 260 280 262 364 316 270 286 326 302 300 302 344 290
374 466 432 376 360 454 478 382 524 410 520 470 514 354 434 380 416 384 462 386 404 362 420 360 390 356 550 372 386 396
594 1014 1586 1344 610 838 772 264 748 1076 446 314 304 1680 1700 334 256 422 302 296 354 322 276 382 502 428 544 286 650 432
402 466 296 348 680 702 500 500 576 624 406 378 586 826 298 882 564 656 716 380 448 506 1714 748 510 810 984 458 390 642
620 714 414 358 460 598 324 442 372 410 998 636 968 490 696 560 562 720 618 456 502 974 1032 470 462 798 716 300 586 574
454 388 344 226 562 766 502 432 608 516 500 796 542 458 448 404 372 524 400 366 374 350 1154 558 440 348 400 460 514 450"""

In [None]:
import numpy as np

# the first 11 lines are from controls, the last 6 from schizophrenics
def parse_data():
    rts = []
    for line in orig_data.split('\n'):
        # skip the blank lines at the start and between the two groups
        if not line.strip(): continue
        # one row per participant, one integer reaction time per trial
        rts.append(list(map(int, line.split())))
    return np.array(rts)

reaction_times = parse_data()

In [None]:
Np, Nt = reaction_times.shape
print('Data shows %d patients, %d trials per patient' % (Np, Nt))

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

_, axes = plt.subplots(6, 3, figsize=(16, 8), sharex=True)
bins = np.linspace(reaction_times.min(), reaction_times.max(), 100)
axes = axes.flatten()
for isbj in range(Np):
    axes[isbj].hist(reaction_times[isbj, :], bins)
    axes[isbj].set_title('subject index %d [%s]' %(isbj, 'control' if isbj < 11 else '*patient*'), fontsize=14)
plt.tight_layout();

### Data description
Participants are stored in rows. For example, ```reaction_times[0,:]``` holds the reaction times in milliseconds of the first control group participant.

In [None]:
print(reaction_times[0:11,:].mean(1))

### Model structure
Below we discuss a model similar to model 4 of Belin and Rubin [1], their richest model. Their other models are not discussed in this notebook.

In the following, we will use the subscript $i \in \{1, 2, \ldots, 17\}$ to denote participants and $j \in \{1, 2, \ldots, 30\}$ to denote trials, so for example $Y_{i,j}$ denotes the observed reaction time for participant $i$ in trial $j$.

It is assumed that there is a mean reaction time $\mu$ and standard deviation $\sigma_1$ in the control group but each participant in the study can have a slightly different mean reaction time $\alpha_i$. Schizophrenic participants behave like the control group except in specific trials, where they exhibit an attention deficit that causes their reaction time to be increased.  The (unobserved) variable $Z_{i,j} \in \{0,1\}$ denotes whether the attention deficit was present $Z_{i,j}=1$ or absent $Z_{i,j}=0$ in trial $j$ for patient $i$. The proportion of trials $\lambda_i$ where the deficit manifests varies among patients.

We may formalize the model for the control group as follows:

$$Y_{i,j} \sim {\cal N}(\alpha_i, \sigma_1^2)$$

and for the schizophrenia group as

$$Y_{i,j} \sim {\cal N}(\alpha_i, \sigma_1^2),$$

if the trial had no attention deficit present (so exactly the same as control group) and

$$Y_{i,j} \sim {\cal N}(\alpha_i + \tau, \sigma_2^2),$$

if there was an attention deficit.  Note the different standard deviation used to model the reaction time under attention deficit.  In the model below, we use $Z_{i,j}$ to keep track of which mode is active in the trials performed by schizophrenic patients.  We additionally posit that 

$$\begin{array}{rcl}
\lambda_i &\sim& \text{U}(0,1), \\
Z_{i,j} &\sim& \text{Bernoulli}(\lambda_i), \\
\mu &\sim& {\cal N}(400, 100^2), \\
\alpha_i &\sim& {\cal N}(\mu, 50^2), \\
\tau &\sim& \text{HalfNormal}(50), \\
\sigma_1 &\sim& \text{HalfNormal}(100), \\
\sigma_2 &\sim& \text{HalfNormal}(100). \\
\end{array}$$
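
To make the generative story concrete, here is a minimal sketch that draws one synthetic data set from the priors and likelihood above. It treats the scale arguments of the Normal priors as standard deviations, as the code in this notebook does. The function name `simulate_experiment` and its arguments are our own and not part of the original analysis.

```python
import numpy as np

def simulate_experiment(n_controls=11, n_patients=6, n_trials=30, seed=0):
    """Draw one synthetic data set (Y, Z) from the model's priors and likelihood."""
    rng = np.random.default_rng(seed)
    n_subjects = n_controls + n_patients

    # Population-level parameters (scale arguments are standard deviations).
    mu = rng.normal(400, 100)
    tau = abs(rng.normal(0, 50))        # HalfNormal(50)
    sigma_1 = abs(rng.normal(0, 100))   # HalfNormal(100)
    sigma_2 = abs(rng.normal(0, 100))   # HalfNormal(100)

    # Per-participant means and per-patient deficit proportions.
    alpha = rng.normal(mu, 50, size=n_subjects)
    lam = rng.uniform(0, 1, size=n_patients)

    # Z is zero for every control trial and Bernoulli(lambda_i) for patient trials.
    Z = np.zeros((n_subjects, n_trials), dtype=int)
    Z[n_controls:] = rng.uniform(size=(n_patients, n_trials)) < lam[:, None]

    # Trials with Z = 1 are shifted by tau and use the second standard deviation.
    mean = alpha[:, None] + tau * Z
    sd = np.where(Z == 1, sigma_2, sigma_1)
    Y = rng.normal(mean, sd)
    return Y, Z

Y_sim, Z_sim = simulate_experiment()
print(Y_sim.shape, Z_sim[:11].sum())
```

Histograms of `Y_sim` drawn in the same way as the real data give a rough impression of what these priors imply before any data are seen.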

### Structure of solution
1. Select priors for the problem (see above).
2. Write down the structure of the model, write the ```log_prior``` and ```log_likelihood``` functions.
3. Write a proposal distribution, function ```proposal``` to suggest a new state from current state.
4. Write (or modify, from notebook 3f) the Metropolis sampler.
5. Sample outputs and evaluate.

## Working sampler using PyMC3

So that the structure of the model is clear, we provide a working example of a sampler using the PyMC3 library.

In [None]:
import pymc3 as pm
import theano.tensor as tt

In [None]:
with pm.Model() as schizo_model:
    mu = pm.Normal('mu', 400, 100)
    alphas_ = pm.Normal('alphas_', 0, 50., shape=(Np,1))
    alphas = pm.Deterministic('alphas', mu + alphas_)
    
    sigma_ctrl = pm.HalfNormal('sigma_ctrl', 100)
    sigma_pat = pm.HalfNormal('sigma_pat', 100)
    tau = pm.HalfNormal('tau', 50)  # scale 50, as in the model specification
    
    lambdas = pm.Uniform('lambdas', 0., 1., shape=(6,1))
    Z = pm.Bernoulli('Z', lambdas, shape=(6, Nt))

    controls = pm.Normal('control',
                      alphas[:11,:],
                      sigma_ctrl,
                      observed=reaction_times[:11, :])

    patients = pm.Normal('patients',
                       alphas[11:,:] + Z*tau,
                       tt.switch(Z, sigma_pat, sigma_ctrl),
                       observed=reaction_times[11:,:])

In [None]:
with schizo_model:
    trace = pm.sample(draws=2000,
                      tune=1000,
                      chains=2)

In [None]:
pm.summary(trace, varnames=['mu', 'tau', 'lambdas', 'sigma_ctrl', 'sigma_pat'])

In [None]:
with schizo_model:
    _ = pm.traceplot(trace, varnames=['mu','tau', 'alphas', 'lambdas', 'sigma_ctrl', 'sigma_pat'])

### Analysis of results and comparison to data
We expect that for patients where some reaction times are very high, the model will infer that those are affected by the deficit and model them differently than the rest of the patient trials.

Thus:
- $\sigma_2$ should be much higher than $\sigma_1$ - check
- $\alpha_i$ should be very close to the data mean in the control group and much lower than the data mean for patients with many attention deficit trials
- $\lambda_i$ should vary strongly across patients
- $Z_{i,j}$ should vary across trials at least in patients 13-16
- $Z_{i,j}$ should correlate positively with reaction time ($Z_{i,j}=1$ means the reaction time should be high)

In [None]:
plt.plot(np.mean(trace['alphas'],axis=0)-np.mean(reaction_times[:Np,:],axis=1)[:, np.newaxis], 'o')
plt.title('Difference in $\\alpha_i$ and data mean vs. patient')

In [None]:
Zs = trace['Z']

plt.figure(figsize=(10,4))
labels = []
for k in range(6):
    plt.plot(np.mean(Zs[:,k,:], axis=0), 'o-')
    labels.append('Patient %d' % (k+11))

plt.legend(labels)

In [None]:
plt.plot(np.arange(11, 17), np.mean(np.mean(Zs, axis=0), axis=1), 'o-')
plt.title('Proportion of trials with attn deficit')
plt.xlabel('Patient index [-]')
plt.ylabel('avg(Z)');

We see that patients 13, 14 and 15 have the highest proportion of attention deficit trials, which corresponds well to the initial data plots.

In [None]:
plt.figure(figsize=(12,6))
for k in range(6):
    plt.subplot(2,3,k+1)
    ts = reaction_times[k+11,:]
    Zs_k = np.mean(Zs[:,k,:], axis=0)
    plt.scatter(ts, Zs_k)
    plt.ylim([0,1])
    plt.title('Patient %d' % (k+11))

### Exercise
According to our model, we would expect the probability of $Z_{i,j}=1$ to be low for short reaction times and high for long ones.  The plots above do not bear this out: for some very short reaction times, the model still attributes the trial to an attention deficit. Why is that?

Hint: note that if you replace the ```tt.switch(...)``` statement with ```sigma_ctrl```, this effect disappears.
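
One way to explore this question is to compute the conditional probability of a deficit for a single trial by Bayes' rule, holding all other parameters fixed. The sketch below does this; the function name `prob_deficit` and the parameter values in the usage line are illustrative assumptions, not fitted results.

```python
import numpy as np
from scipy.stats import norm

def prob_deficit(y, alpha, tau, sigma_1, sigma_2, lam):
    """P(Z = 1 | y) for one trial, with all other parameters held fixed."""
    # Weight of the attention-deficit component: lambda * N(y; alpha + tau, sigma_2^2)
    p1 = lam * norm.pdf(y, alpha + tau, sigma_2)
    # Weight of the regular component: (1 - lambda) * N(y; alpha, sigma_1^2)
    p0 = (1 - lam) * norm.pdf(y, alpha, sigma_1)
    return p1 / (p0 + p1)

# Hypothetical parameter values; try changing sigma_2 relative to sigma_1.
ys = np.array([200, 300, 400, 600, 1000])
probs = prob_deficit(ys, alpha=400, tau=200, sigma_1=50, sigma_2=300, lam=0.3)
print(probs)
```

Rerunning the last lines with `sigma_2` set equal to `sigma_1` mirrors the substitution suggested in the hint.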

# Your turn!
Now let's work on writing our own sampler.  While ```PyMC3``` uses a NUTS sampler plus Gibbs sampling for the latent variables $Z_{i,j}$, we will build the entire system using Metropolis and compare our results.

### Structure of state
You are free to structure your code as you wish, but we recommend storing the state as a dictionary, as shown below, because it makes the code more readable.

In [None]:
# setup the initial point
lambda_initial = 0.5

# here's a nice initial point :)
v_init = { 'mu' : 400,
           'delta_alphas': 400 + np.random.randn(Np) * 50,  # holds alpha_i itself (the prior is centred on mu), not an offset
           'tau': 50,
           'sigma_ctrl': 100,
           'sigma_pat': 100,
           'lambdas' : np.ones(6,) * lambda_initial,
           'Zij': np.where(np.random.uniform(size=(6,Nt)) < lambda_initial, np.ones((6,Nt)), np.zeros((6,Nt)))
         }

MD = {
        'mu': (1, 0, np.inf),
        'delta_alphas': (Np, 0, np.inf),
        'tau' : (1, 0, np.inf),
        'sigma_ctrl' : (1, 0, np.inf),
        'sigma_pat' : (1, 0, np.inf),
        'lambdas': (6, 0, 1), 
    }

sigmas_0 = {
        'mu': 40,
        'delta_alphas': 10,
        'tau' : 10,
        'sigma_ctrl' : 20,
        'sigma_pat' : 20,
        'lambdas': 0.1,
        'Zij': 0.2
    }

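def gen(v0, sigmas):
    """Propose a new state: Gaussian random walk on the continuous parameters, random resampling of some Z entries."""
    def one(key, fd):
        size, lower, upper = fd, -np.inf, np.inf
        if isinstance(fd, tuple):
            size, lower, upper = fd
        val = v0[key] + sigmas[key] * np.random.randn(size)
        # Bounds are not enforced here: an out-of-support proposal gets a
        # -inf or nan log-posterior and is therefore rejected by the sampler.
        #val = np.clip(val, lower, upper)
        if size == 1:
            val = val[0]
        return val

    d = {}
    for key, fd in MD.items():
        d[key] = one(key, fd)

    # Each Z[i,j] is resampled uniformly from {0, 1} with probability
    # sigmas['Zij']; this move is symmetric, as Metropolis requires.
    zs = np.copy(v0['Zij'])
    for i in range(zs.shape[0]):
        for j in range(Nt):
            if np.random.uniform() < sigmas['Zij']:
                zs[i,j] = 1.0 if np.random.uniform() < 0.5 else 0.0
    d['Zij'] = zs
    for k, v in d.items():
        if isinstance(v, np.ndarray):
            assert v.shape == v0[k].shape
    return d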
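
The proposal leaves the clipping step commented out. Clipping would pile proposals onto the boundary and make the proposal asymmetric, which plain Metropolis does not allow. One common alternative, sketched below under our own naming (`reflect` is not part of the notebook), is to mirror out-of-range values back into the interval, which keeps the random walk symmetric.

```python
import numpy as np

def reflect(val, lower, upper):
    """Fold a proposed value back into [lower, upper] by mirroring at the bounds."""
    val = np.asarray(val, dtype=float)
    if np.isfinite(lower) and np.isfinite(upper):
        width = upper - lower
        # Position on the doubled interval [0, 2*width), then fold the upper half back.
        val = np.mod(val - lower, 2 * width)
        val = np.where(val > width, 2 * width - val, val)
        return lower + val
    if np.isfinite(lower):
        return lower + np.abs(val - lower)
    if np.isfinite(upper):
        return upper - np.abs(upper - val)
    return val

print(reflect(-0.25, 0, np.inf))       # a negative scale proposal mirrored at 0
print(reflect([1.25, -0.5], 0, 1))     # proportions mirrored into [0, 1]
```

With the state layout used here, such a helper could be applied to `val` in place of `np.clip`, using the `lower` and `upper` entries already stored in `MD`.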


In [None]:
from scipy.stats import norm, halfnorm, uniform, bernoulli

def log_prior(v):
    assert len(v['Zij']) == len(v['lambdas']) == 6
    PS = (norm.logpdf(v['mu'], 400, 100),
          norm.logpdf(v['delta_alphas'], v['mu'], 50).sum(),
            halfnorm.logpdf(v['tau'], scale=50),
            halfnorm.logpdf(v['sigma_ctrl'], scale=100),
            halfnorm.logpdf(v['sigma_pat'], scale=100),
            uniform.logpdf(v['lambdas']).sum(),
            sum(bernoulli.logpmf(v['Zij'][i,:], v['lambdas'][i]).sum() for i in range(6))
           )
    #print(PS)
    return sum(PS)

def log_likelihood(v, X=reaction_times):
    assert len(v['delta_alphas']) == Np
    s = 0.0
    # control
    for i in range(11):
        s += norm.logpdf(X[i, :], v['delta_alphas'][i], v['sigma_ctrl']).sum()
        #print(i, s)
    
    # patients
    for i in range(6):
        zs = v['Zij'][i]
        p_sigmas = v['sigma_pat'] * zs + (1-zs) * v['sigma_ctrl']
        assert all(sig in [v['sigma_pat'], v['sigma_ctrl']] for sig in p_sigmas)
        p_alphas = zs * v['tau'] + v['delta_alphas'][11 + i]
        s += norm.logpdf(X[11 + i, :], p_alphas, p_sigmas).sum()
        #print(i, s)
    
    return s


def log_posterior(v):
    return log_prior(v) + log_likelihood(v)

print(log_posterior(v_init))

In [None]:
def metropolis(x0, logp, N, gen, verbose=False):
    states = [x0]
    logps = [logp(x0)]
    acc = 0
    for i in range(N):
        if verbose:
            c = ''
            if i % 5000 == 4999:
                c = ''
                print("%.2f%%"%(100*i/N))
            elif i % 1000 == 999:
                c = '1'
            elif i % 100 == 99:
                c = '.'
            print(c, end='')
        
        xold = states[-1]
        lold = logps[-1]
        
        xnext = gen(xold, i/N)
        lnext = logp(xnext)

        u = np.random.uniform()
        if lnext - lold > np.log(u):
            states.append(xnext)
            logps.append(lnext)
            acc += 1
        else:
            # this is different from Monte Carlo rejection sampler
            # if we reject a new sample we 're-sample' the current state
            states.append(xold)
            logps.append(lold)
    if verbose:        
        print()
    print("acceptance = %.2f%%"%(100*acc/N))
    return acc / N, states, logps
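
Before running a sampler on the full model, it can help to check the accept/reject logic on a target whose answer is known. The sketch below is a stripped-down, one-dimensional version of the same rule applied to a standard normal target; the name `metropolis_1d` and the step size are our own choices, not part of the notebook.

```python
import numpy as np

def metropolis_1d(logp, x0, n_steps, step, seed=0):
    """Random-walk Metropolis on a scalar target given by its log-density."""
    rng = np.random.default_rng(seed)
    xs = np.empty(n_steps)
    x, lx = x0, logp(x0)
    accepted = 0
    for t in range(n_steps):
        x_new = x + step * rng.normal()      # symmetric Gaussian proposal
        l_new = logp(x_new)
        # Accept with probability min(1, p(x_new) / p(x)), evaluated in log space.
        if np.log(rng.uniform()) < l_new - lx:
            x, lx = x_new, l_new
            accepted += 1
        xs[t] = x                            # on rejection the old state is repeated
    return xs, accepted / n_steps

# Standard normal target, deliberately started away from the mode.
samples, acc_rate = metropolis_1d(lambda x: -0.5 * x**2, x0=3.0, n_steps=20000, step=2.4)
kept = samples[2000:]                        # discard an initial burn-in segment
print(kept.mean(), kept.std(), acc_rate)
```

If the sample mean and standard deviation are not close to 0 and 1, the acceptance rule or the bookkeeping of rejected states is the first place to look.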

In [None]:
from multiprocessing import Pool

# SUCH pain to get working values
sigmas_1 = {
        'mu': 6.4,
        'delta_alphas': 3.2,
        'tau' : 12.0,
        'sigma_ctrl' : 2.0,
        'sigma_pat' : 8.0,
        'lambdas': 0.01,
        'Zij':0.1
    }

def one(i):
    np.random.seed(i)
    return metropolis(v_init, log_posterior, 100000, (lambda v,p : gen(v, sigmas_1)))

RUNS = 4
with Pool(4) as p:
    runs = p.map(one, range(RUNS))

#acc, S0, L0 = metropolis(v_init, log_posterior, 4000, (lambda v,p : gen(v, sigmas_1)))
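
Running several independently seeded chains makes it possible to check whether they agree. A standard diagnostic is the Gelman-Rubin potential scale reduction factor, which compares between-chain and within-chain variance. The sketch below is a basic version of it (without chain splitting); the function name `gelman_rubin` and the burn-in length in the commented usage are assumptions.

```python
import numpy as np

def gelman_rubin(chains):
    """Potential scale reduction factor for an array of shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    m, n = chains.shape
    chain_means = chains.mean(axis=1)
    B = n * chain_means.var(ddof=1)            # between-chain variance
    W = chains.var(axis=1, ddof=1).mean()      # mean within-chain variance
    var_hat = (n - 1) / n * W + B / n          # pooled estimate of the posterior variance
    return np.sqrt(var_hat / W)

# Hypothetical usage with the `runs` list from the cell above:
# mu_chains = [[s['mu'] for s in states[20000:]] for _, states, _ in runs]
# print(gelman_rubin(mu_chains))

rng = np.random.default_rng(1)
mixed = rng.normal(size=(4, 5000))                           # chains that agree
stuck = rng.normal(size=(4, 5000)) + np.arange(4)[:, None]   # chains centred at 0, 1, 2, 3
print(gelman_rubin(mixed), gelman_rubin(stuck))
```

Values close to 1 suggest the chains are sampling the same distribution; values well above 1 indicate that they have not yet mixed.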

In [None]:
CUTOFF = 0

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

#plt.hist(L)
#plt.show()

plt.figure(figsize=(16,8))

for i, k in enumerate(['mu', 'tau', 'sigma_ctrl', 'sigma_pat']):
    for acc, SS, L in runs:
        S = SS[CUTOFF:]
        values = [s[k] for s in S]
        
        plt.subplot(4,2,i*2 + 1)
        plt.title(k)
        sns.kdeplot(values)
        plt.subplot(4,2,i*2 + 2)
        plt.title(k)
        # the cutoff was already applied above, so reuse `values`
        plt.plot(values)
        
    
plt.show()
for i in range(6):
    plt.figure(figsize=(16,8))
    for acc, SS, L in runs:
        S = SS[CUTOFF:]

        values = [s['lambdas'][i] for s in S]
        plt.subplot(1,2,1)
        sns.kdeplot(values)
        plt.title('lambda %d'%i)
        plt.subplot(1,2,2)
        plt.plot(values)
        plt.title('lambda %d'%i)

    plt.show()

plt.figure(figsize=(16,8))
for i in range(4):
    for acc, SS, L in runs:
        S = SS[CUTOFF:]
        values = [s['delta_alphas'][i] for s in S]
        plt.subplot(1,2,1)
        sns.kdeplot(values)
        plt.title('alphas')
        plt.subplot(1,2,2)
        plt.plot(values)
        plt.title('alphas')

plt.show()
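
Alongside the plots, a numeric summary similar in spirit to `pm.summary` can be computed directly from the list of state dictionaries. The sketch below assumes the state layout used in this notebook; the helper name `summarize`, the chosen quantiles, and the toy states are our own.

```python
import numpy as np

def summarize(states, key, burn_in=0):
    """Posterior mean and 5%/95% quantiles of one entry of the state dictionaries."""
    draws = np.array([s[key] for s in states[burn_in:]], dtype=float)
    return {
        'mean': draws.mean(axis=0),
        'q05': np.quantile(draws, 0.05, axis=0),
        'q95': np.quantile(draws, 0.95, axis=0),
    }

# Toy states standing in for the sampler output, only to show the call.
toy_states = [{'mu': 400.0 + k, 'lambdas': np.full(6, 0.1 * (k % 5))} for k in range(100)]
mu_summary = summarize(toy_states, 'mu', burn_in=50)
lambda_summary = summarize(toy_states, 'lambdas', burn_in=50)
print(mu_summary['mean'], lambda_summary['mean'].shape)
```

Applied to `'Zij'`, the same helper gives the per-trial posterior frequency of a deficit, which can be compared with the PyMC3 results above.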

### Final remarks
Thanks to Charles Lao for consulting on a suitable model structure for this experiment.