## Autism with mixed effects models

This follows the Coursera course's assignment on autism data. The data track the socialization of children with autism as they progress through childhood.

We will be comparing the statsmodels option (presented in the class) with the more explicit methods promoted by McElrath in his Statistical Rethinking course. If I get really ambitious, I will try to have an R notebook that also uses the `lme4` package.

The variables from the data file are:

* __AGE__ is the age of a child (between 2 and 13 years)
* __VSAE__ is a measure of the child's socialization
* __SICDEGP__ is the expressive language group at age 2, and can take values {1,2,3}. Higher values indicate more expressive language.
* __CHILDID__ is the unique ID that is given to each child and acts as their identifier.

We will be renaming the variables, as __VSAE__ and __SICDEGP__ don't have meaning to us.

This notebook follows the example given [on Coursera](https://www.coursera.org/learn/fitting-statistical-models-data-python/ungradedLab/ffKgX/fitting-multilevel-and-marginal-models-to-autism-data-in-python/lab?path=%2Fnotebooks%2Fweek3%2FAutism_Multilevel_Marginal_Models.ipynb).

In [1]:
import pandas as pd
import numpy as np
import statsmodels.api as sm
import matplotlib.pyplot as plt
import patsy
from scipy.stats import chi2

In [2]:
autism = pd.read_csv("autism.csv").dropna().rename(
    columns={
        'AGE': 'age',
        'vsae': 'social',
        'sicdegp': 'language',
        'childid': 'id'
    }
)

In [3]:
autism.head()

| | age | social | language | id |
|---|---|---|---|---|
| 0 | 2 | 6.0 | 3 | 1 |
| 1 | 3 | 7.0 | 3 | 1 |
| 2 | 5 | 18.0 | 3 | 1 |
| 3 | 9 | 25.0 | 3 | 1 |
| 4 | 13 | 27.0 | 3 | 1 |


## 1. Build first model (no centering)

A note from the statsmodels documentation on [`from_formula`](https://www.statsmodels.org/stable/generated/statsmodels.regression.mixed_linear_model.MixedLM.from_formula.html):

Let's say we have schools, and classrooms within schools. We want to find the relationship between tests and age, accounting for the effects of classroom and schools. The school will be the top-level group, and the classroom is a nested group. Note that classroom labels may be (but need not be) different across schools, and the number of classrooms can be different school to school.

```
>>> vc = {'classroom': '0 + C(classroom)'}
>>> MixedLM.from_formula('test_score ~ age', 
                         vc_formula=vc,
                         re_formula='1',
                         groups='school',
                         data=data)
```



In [4]:
# Build the model
mlm_mod = sm.MixedLM.from_formula(
    formula = 'social ~ age * C(language)', 
    groups = 'id', 
    re_formula="1 + age", 
    data=autism
)

# Run the fit
mlm_result = mlm_mod.fit()

# Print out the summary of the fit
mlm_result.summary()



| | | | |
|---|---|---|---|
| Model: | MixedLM | Dependent Variable: | social |
| No. Observations: | 610 | Method: | REML |
| No. Groups: | 158 | Scale: | 62.2592 |
| Min. group size: | 1 | Log-Likelihood: | -2348.7987 |
| Max. group size: | 5 | Converged: | Yes |
| Mean group size: | 3.9 | | |

| | Coef. | Std.Err. | z | P>\|z\| | [0.025 | 0.975] |
|---|---|---|---|---|---|---|
| Intercept | 1.901 | 1.600 | 1.188 | 0.235 | -1.235 | 5.038 |
| C(language)[T.2] | -0.415 | 2.109 | -0.197 | 0.844 | -4.549 | 3.718 |
| C(language)[T.3] | -3.917 | 2.345 | -1.670 | 0.095 | -8.514 | 0.680 |
| age | 2.957 | 0.593 | 4.986 | 0.000 | 1.794 | 4.119 |
| age:C(language)[T.2] | 0.741 | 0.784 | 0.945 | 0.344 | -0.795 | 2.277 |
| age:C(language)[T.3] | 4.356 | 0.869 | 5.014 | 0.000 | 2.653 | 6.058 |
| id Var | 58.265 | 2.990 | | | | |
| id x age Cov | -28.736 | 0.697 | | | | |
| age Var | 14.204 | 0.283 | | | | |


We see that we have
```
social = (1.901 - 0.415 delta_{lang,2} - 3.917 delta_{lang,3})
       + (2.957 + 0.741 delta_{lang,2} + 4.356 delta_{lang,3}) age
       + u_id
       + v_id * age
       + epsilon

Var(u_id)       =  58.265
Var(v_id)       =  14.204
Cov(u_id, v_id) = -28.736
Var(epsilon)    =  62.2592
```
This is a standard linear model plus two random effects per child, each with mean 0: a deviation `u_id` in the intercept with variance 58.265, and a deviation `v_id` in the age slope with variance 14.204. The two are negatively correlated (covariance -28.736), and the residual variance is 62.2592 (the `Scale` entry in the summary).
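To get a feel for what the variance components imply, here is a small sketch that plugs the numbers from the summary above into a 2x2 covariance matrix. It only uses the reported estimates; the names `G`, `resid_var`, and `random_effect_var` are my own and not part of statsmodels.

```python
import numpy as np

# Covariance matrix of the (intercept, age slope) random effects,
# copied from the "id Var", "id x age Cov" and "age Var" rows of the summary
G = np.array([[58.265, -28.736],
              [-28.736, 14.204]])
resid_var = 62.2592  # "Scale" in the summary

# Correlation between a child's intercept deviation and slope deviation
corr = G[0, 1] / np.sqrt(G[0, 0] * G[1, 1])

def random_effect_var(age):
    """Variance of (random intercept + random slope * age) implied by G."""
    z = np.array([1.0, age])
    return z @ G @ z

print("correlation:", corr)
for age in (2, 13):
    re_var = random_effect_var(age)
    print(age, re_var, re_var + resid_var)
```

The correlation comes out very close to -1, which suggests the estimated covariance matrix is nearly singular. With these numbers the child-to-child variance almost vanishes around age 2 and grows quickly after that, so the fitted lines effectively fan out from a common point near the first measurement.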

The course notes have a bunch of information about the inability to measure the socialization of infants. This is nonsense; data wasn't collected at age zero (presumably because of the difficulty of measurement), and the intercept is simply the linear extrapolation of the fitted line back to age zero.
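As a quick check on how the fixed effects combine, the sketch below evaluates the population-mean line for each language group using the coefficients from the summary above. Age 0 is the extrapolated intercept; ages 2 and 13 are the ends of the observed range. The helper name `mean_social` is made up for this illustration.

```python
import numpy as np

# Fixed effects from the fit with re_formula="1 + age";
# language group 1 is the baseline, groups 2 and 3 add their offsets
intercept = np.array([1.901, 1.901 - 0.415, 1.901 - 3.917])
slope = np.array([2.957, 2.957 + 0.741, 2.957 + 4.356])

def mean_social(age):
    """Population-mean socialization for language groups 1, 2, 3 at a given age."""
    return intercept + slope * age

for age in (0, 2, 13):
    print(age, np.round(mean_social(age), 3))
```

Group 3 starts with the lowest extrapolated intercept but has by far the steepest slope, so it ends up well above the other two groups by age 13.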

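If we are estimating a different age effect within each id, we don't have enough data (at most 5 observations per child) to also estimate the variance between individuals' intercepts, so the next model drops the random intercept and keeps only the random age slope.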

In [5]:
# Build the model
mlm_mod = sm.MixedLM.from_formula(
    formula = 'social ~ age * C(language)', 
    groups = 'id', 
    re_formula="0 + age", 
    data=autism
)

# Run the fit
mlm_result = mlm_mod.fit()

# Print out the summary of the fit
mlm_result.summary()

| | | | |
|---|---|---|---|
| Model: | MixedLM | Dependent Variable: | social |
| No. Observations: | 610 | Method: | REML |
| No. Groups: | 158 | Scale: | 84.5319 |
| Min. group size: | 1 | Log-Likelihood: | -2427.0905 |
| Max. group size: | 5 | Converged: | Yes |
| Mean group size: | 3.9 | | |

| | Coef. | Std.Err. | z | P>\|z\| | [0.025 | 0.975] |
|---|---|---|---|---|---|---|
| Intercept | 2.482 | 1.271 | 1.952 | 0.051 | -0.010 | 4.973 |
| C(language)[T.2] | -1.293 | 1.674 | -0.773 | 0.440 | -4.574 | 1.987 |
| C(language)[T.3] | -4.230 | 1.862 | -2.272 | 0.023 | -7.880 | -0.580 |
| age | 2.822 | 0.470 | 6.006 | 0.000 | 1.901 | 3.743 |
| age:C(language)[T.2] | 0.985 | 0.620 | 1.589 | 0.112 | -0.230 | 2.199 |
| age:C(language)[T.3] | 4.463 | 0.688 | 6.482 | 0.000 | 3.113 | 5.812 |
| age Var | 8.198 | 0.124 | | | | |


Writing this explicitly, we have
```
social = (2.482 - 1.293 delta_{lang,2} - 4.230 delta_{lang,3})
       + (2.822 + 0.985 delta_{lang,2} + 4.463 delta_{lang,3}) age
       + v_id * age
       + epsilon

Var(v_id)    =  8.198
Var(epsilon) = 84.5319
```
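The two statsmodels fits share the same fixed effects and were both estimated by REML, so their log-likelihoods can be compared with a likelihood ratio test on the random-effects structure. The sketch below uses the two `Log-Likelihood` values from the summaries above. The reference distribution is an assumption on my part: because the random-intercept variance is tested on the boundary of its parameter space, a 50:50 mixture of chi-square distributions with 1 and 2 degrees of freedom is a commonly used approximation.

```python
from scipy.stats import chi2

# REML log-likelihoods copied from the two summaries above
llf_full = -2348.7987     # re_formula="1 + age" (random intercept and slope)
llf_reduced = -2427.0905  # re_formula="0 + age" (random slope only)

# Likelihood ratio statistic: twice the gain in log-likelihood
lr_stat = 2 * (llf_full - llf_reduced)

# Assumed reference distribution: 0.5 * chi2(1) + 0.5 * chi2(2),
# since dropping the intercept removes one variance and one covariance
p_value = 0.5 * chi2.sf(lr_stat, 1) + 0.5 * chi2.sf(lr_stat, 2)

print(f"LR statistic = {lr_stat:.4f}, p-value = {p_value:.3g}")
```

By this test the data strongly prefer keeping the random intercept, so the choice of the simpler model here rests on interpretability rather than on fit.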

## Let's try the same model using numpyro

We are going to start with the simpler model (both because it is the model that makes sense, and because we don't have to deal with covariances).
  

In [6]:
import jax.numpy as jnp
from jax import random, vmap
from jax.scipy.special import logsumexp

from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import MCMC, NUTS
import numpyro

In [7]:
autism.describe()

Unnamed: 0,age,social,language,id
count,610.0,610.0,610.0,610.0
mean,5.770492,26.409836,1.959016,105.332787
std,3.974853,30.78981,0.762391,62.283496
min,2.0,1.0,1.0,1.0
25%,2.0,10.0,1.0,48.25
50%,4.0,14.0,2.0,107.5
75%,9.0,27.0,3.0,158.0
max,13.0,198.0,3.0,212.0


In [8]:
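def independent_model(age=None, language=None, child_id=None, social=None):
    # Complete pooling: a single line shared by every child.
    # language and child_id are accepted but unused.
    a = numpyro.sample('a', dist.Normal(0, 30))  # slope on age
    b = numpyro.sample('b', dist.Normal(0, 30))  # intercept
    sigma = numpyro.sample('sigma', dist.Exponential(1.))  # residual sd

    mu = a*age + b
    numpyro.sample('obs', dist.Normal(mu, sigma), obs=social)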

In [9]:
numpyro.__version__

'0.7.2'

In [10]:
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)

# Run NUTS.
kernel = NUTS(independent_model)
num_samples = 2000
mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples)
mcmc.run(rng_key_, 
         age=autism.age.values, 
         language=autism.language.values,
         child_id=autism['id'].values,
         social=autism.social.values,
        )
mcmc.print_summary()
samples_1 = mcmc.get_samples()

sample: 100%|██████████████████████████████████████████████████████████████████████████████████████| 3000/3000 [00:03<00:00, 941.57it/s, 3 steps of size 4.32e-01. acc. prob=0.93]



                mean       std    median      5.0%     95.0%     n_eff     r_hat
         a      4.55      0.26      4.55      4.09      4.92    781.93      1.00
         b      0.16      1.80      0.16     -2.58      3.24    752.61      1.00
     sigma     24.57      0.67     24.53     23.56     25.74   1137.49      1.00

Number of divergences: 0


In [11]:
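def model_child_units(age=None, language=None, child_id=None, social=None):
    num_language = 3
    # 610 is the number of rows, not the number of children (158); it is
    # large enough for the raw ids (max 212) to be used directly as indices.
    num_children = 610

    # Population-level slope (M) and intercept (B), and their between-child spreads
    M = numpyro.sample('M', dist.Normal(0, 30))
    B = numpyro.sample('B', dist.Normal(0, 30))
    M_sigma = numpyro.sample('M_sigma', dist.Exponential(1.0))
    B_sigma = numpyro.sample('B_sigma', dist.Exponential(1.0))
    sigma = numpyro.sample('sigma', dist.Exponential(1.))  # residual sd

    # Per-child slopes and intercepts, partially pooled towards M and B
    slopes = numpyro.sample('child_slope', dist.Normal(M, M_sigma), sample_shape=(num_children,))
    intercepts = numpyro.sample('child_intercept', dist.Normal(B, B_sigma), sample_shape=(num_children,))

    # Per-language-group offsets to the slope and intercept
    grad_lang = numpyro.sample('grad_language', dist.Normal(0,30), sample_shape=(num_language,))
    intercept_lang = numpyro.sample('intercept_language', dist.Normal(0, 4), sample_shape=(num_language,))

    # NOTE: language is coded 1..3, so language == 3 indexes past the end of the
    # length-3 arrays. JAX clamps out-of-range indices instead of raising, so
    # groups 2 and 3 share the last entry and entry 0 is never used.
    mu = (age*(slopes[child_id] + grad_lang[language])
          + (intercepts[child_id] + intercept_lang[language]))
    numpyro.sample('obs', dist.Normal(mu, sigma), obs=social)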
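The model indexes its parameter arrays with the raw `id` and `language` columns, which are 1-based (language takes values 1 to 3, and ids run up to 212 for only 158 children). One possible way to get contiguous 0-based indices before passing the data to the model is sketched below. The helper `to_zero_based` and the toy arrays are hypothetical, and this is not what the cells in this notebook actually do.

```python
import numpy as np

def to_zero_based(codes):
    """Map arbitrary integer labels to contiguous indices 0..K-1.

    Returns the index array and the number of distinct labels K.
    """
    labels, index = np.unique(np.asarray(codes), return_inverse=True)
    return index, len(labels)

# Toy stand-ins for autism['id'].values and autism.language.values
child_id = np.array([1, 1, 1, 3, 3, 212])
language = np.array([3, 3, 3, 1, 1, 2])

child_idx, num_children = to_zero_based(child_id)
lang_idx, num_language = to_zero_based(language)

print(child_idx, num_children)
print(lang_idx, num_language)
```

With indices like these, the parameter arrays can be sized by the number of distinct children and language groups, and every entry corresponds to exactly one group.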

In [12]:
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)

# Run NUTS.
kernel = NUTS(model_child_units)
num_samples = 2000
mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples)
mcmc.run(rng_key_, 
         age=autism.age.values, 
         language=autism.language.values,
         child_id=autism['id'].values,
         social=autism.social.values,
        )

samples_1 = mcmc.get_samples()

for param in ['B', 'B_sigma', 'M', 'M_sigma', 'grad_language', 'intercept_language', 'sigma']:
    print(f"{param}\t{samples_1[param].mean(axis=0)}")

sample: 100%|████████████████████████████████████████████████████████████████████████████████████| 3000/3000 [01:34<00:00, 31.62it/s, 1023 steps of size 3.99e-04. acc. prob=0.88]


B	-0.310611367225647
B_sigma	1.1656969785690308
M	0.10157603770494461
M_sigma	2.210665464401245
grad_language	[2.9776413 2.79052   5.1648445]
intercept_language	[ 1.0987537   2.3467777  -0.11276083]
sigma	10.015547752380371


In [13]:
mcmc.print_summary()


                           mean       std    median      5.0%     95.0%     n_eff     r_hat
                    B     -0.31      0.10     -0.31     -0.46     -0.13      3.22      2.08
              B_sigma      1.17      0.06      1.15      1.08      1.24     11.99      1.06
                    M      0.10      0.09      0.10     -0.02      0.25     20.14      1.06
              M_sigma      2.21      0.21      2.21      1.89      2.55      3.43      1.87
   child_intercept[0]     -0.92      1.00     -0.79     -2.54      0.66      8.54      1.01
   child_intercept[1]     -0.68      1.06     -0.66     -2.42      1.07     16.12      1.00
   child_intercept[2]     -0.27      0.97     -0.50     -1.56      1.56      3.81      1.46
   child_intercept[3]     -0.39      0.99     -0.32     -1.98      1.13      6.68      1.04
   child_intercept[4]     -0.57      1.31     -0.16     -2.70      1.27      4.61      1.00
   child_intercept[5]     -0.32      1.06     -0.12     -1.95      1.23      7.