# Week 3 Lecture 1 - Spurious Waffle Houses

McElreath's lectures for the whole book are available here: https://github.com/rmcelreath/stat_rethinking_2022

An R/Stan repo of code is available here: https://vincentarelbundock.github.io/rethinking2/

An excellent port to Python/PyMC Code is available here:  https://github.com/dustinstansbury/statistical-rethinking-2023

You are encouraged to work through both of these versions to reinforce what we're doing in class.

In [None]:
# Import python packages
%matplotlib inline
import pandas as pd
import numpy as np
import seaborn as sns
import scipy as sp 
import random as rd
import pymc as pm
from matplotlib import pyplot as plt

## Waffle Houses

Let's import the Waffle House divorce data:

In [None]:
# Import data
ddata = pd.read_csv('WaffleDivorce.csv',sep=';')
# Display top 5 rows
ddata.head()

In [None]:
# Table of descriptive statistics
ddata.describe().T

That's a fair amount of data to look at, but let's start with divorce rate and Waffle Houses:

In [None]:
plt.figure(figsize=(7,7))
plt.scatter(ddata.WaffleHouses,ddata.Divorce)
[plt.annotate(txt, (ddata.WaffleHouses[i], ddata.Divorce[i])) for i, txt in enumerate(ddata.Location)]
b1,b0 = np.polyfit(ddata.WaffleHouses,ddata.Divorce, 1)
xnew = np.linspace(0,400,100)
plt.plot(xnew,b0+b1*xnew,c='black')
plt.xlabel('Number of Waffle Houses', fontsize=17)
plt.ylabel('Divorce rate', fontsize=17)
plt.savefig('WaffleDivorce.jpg');

Or is divorce rate a product of marriage rate?

In [None]:
plt.figure(figsize=(7,7))
plt.scatter(ddata.Marriage,ddata.Divorce)
[plt.annotate(txt, (ddata.Marriage[i], ddata.Divorce[i])) for i, txt in enumerate(ddata.Location)]
b1,b0 = np.polyfit(ddata.Marriage,ddata.Divorce, 1)
xnew = np.linspace(13,32,100)
plt.plot(xnew,b0+b1*xnew,c='black')
plt.xlabel('Marriage rate', fontsize=17)
plt.ylabel('Divorce rate', fontsize=17)
plt.savefig('WaffleMarriage.jpg');

Or age at marriage?

In [None]:
plt.figure(figsize=(7,7))
plt.scatter(ddata.MedianAgeMarriage,ddata.Divorce)
[plt.annotate(txt, (ddata.MedianAgeMarriage[i], ddata.Divorce[i])) for i, txt in enumerate(ddata.Location)]
b1,b0 = np.polyfit(ddata.MedianAgeMarriage,ddata.Divorce, 1)
xnew = np.linspace(23,32,100)
plt.plot(xnew,b0+b1*xnew,c='black')
plt.xlabel('Median marriage age', fontsize=17)
plt.ylabel('Divorce rate', fontsize=17)
plt.savefig('WaffleAge.jpg');

And what does the South have to do with all this? Well with the assertion of a causal model we can take a look and see. For example, if we assert:

A->M->D
A->D

Then we can look and see what the effect of marriage rate (M) is on divorce (D), given that we know the median age (A). To do this we need a statistical model to help evaluate this DAG.

$$
D_i \sim N(\mu_i,\sigma)\\
\mu_i = \beta_0+\beta_M M_i+\beta_A A_i
$$

There is nothing magic here - we've all done multiple regression before - but what is new is our causal assertion. Weird that what we assert and assume changes things, eh? But you should get very comfortable with this idea because it turns out it lies at the core of scientific enquiry - as Popper argued, causality is built consensually.
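
To see why the causal assertion matters, here is a minimal sketch that simulates data from the asserted DAG and compares a bivariate regression with the multiple regression. Everything in it is an assumption for illustration: the effect sizes (`b_AM`, `b_AD`, `b_MD`), the noise scales, the sample size and the `_sim` names are made up, and it uses plain least squares rather than the Bayesian model.

```python
import numpy as np

rng = np.random.default_rng(42)
n = 5000  # synthetic sample size (assumption; the real data has far fewer rows)

# Assumed "true" effects for this toy world (illustrative values only)
b_AM = -0.7  # A -> M
b_AD = -0.6  # A -> D (direct)
b_MD = 0.0   # M -> D (no direct effect here)

# Simulate variables following A -> M -> D and A -> D
A_sim = rng.normal(0, 1, n)
M_sim = b_AM * A_sim + rng.normal(0, 0.7, n)
D_sim = b_AD * A_sim + b_MD * M_sim + rng.normal(0, 0.8, n)

# Bivariate regression of D on M alone: picks up the association created by A
slope_M_only = np.polyfit(M_sim, D_sim, 1)[0]

# Multiple regression of D on M and A via least squares on a design matrix
X = np.column_stack([np.ones(n), M_sim, A_sim])
b0_hat, bM_hat, bA_hat = np.linalg.lstsq(X, D_sim, rcond=None)[0]

print(f"D ~ M only: slope = {slope_M_only:.2f}")
print(f"D ~ M + A:  bM = {bM_hat:.2f}, bA = {bA_hat:.2f}")
```

In this toy world M has no direct effect on D, yet the M-only slope is clearly positive. Once A is in the model, the M coefficient falls to roughly zero while the A coefficient sits near its assumed value.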

First we should standardize variables:


In [None]:
def stdize(x):
    return (x-np.mean(x))/np.std(x)

In [None]:
A = stdize(ddata.MedianAgeMarriage.values)
M = stdize(ddata.Marriage.values)
D = stdize(ddata.Divorce.values)

State = ddata.Location.values
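
A quick sanity check on what standardizing buys us. This sketch uses made-up numbers (`x_raw` and `y_raw` are hypothetical, not the divorce data). It confirms that standardized variables have mean 0 and SD 1, and that a bivariate least-squares slope between two standardized variables equals their Pearson correlation. That is why slope priors on this scale are easy to reason about.

```python
import numpy as np

def stdize(x):
    # Same definition as above: centre on the mean, scale by the (population) SD
    return (x - np.mean(x)) / np.std(x)

rng = np.random.default_rng(1)
# Hypothetical raw predictor and outcome, for illustration only
x_raw = rng.normal(26, 1.2, 50)
y_raw = 10 - 0.9 * (x_raw - 26) + rng.normal(0, 1.5, 50)

x_std, y_std = stdize(x_raw), stdize(y_raw)

# Least-squares line on the standardized scale
slope_std, intercept_std = np.polyfit(x_std, y_std, 1)
# Pearson correlation on the raw scale
r = np.corrcoef(x_raw, y_raw)[0, 1]

print(f"mean = {x_std.mean():.3f}, sd = {x_std.std():.3f}")
print(f"standardized slope = {slope_std:.3f}, correlation = {r:.3f}")
```

Since a correlation is bounded by -1 and 1, a slope prior with an SD of 0.5 already covers most of the plausible range on this scale.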

With covariates in hand we can do some prior predictive simulation to see what priors might look like in terms of possible lines:

In [None]:
# Number of samples
nsamp = 100
# Intercept
β0_ = np.random.normal(0, .2, nsamp)
# Marriage rate slope
βm_ = np.random.normal(0, .5, nsamp)
# Marriage age slope
βa_ = np.random.normal(0, .5, nsamp)

In [None]:
_, ax = plt.subplots(1,2, figsize=(8,4))

# Grab range of marriage ages to plot over
A_ = np.linspace(min(A),max(A),50)
# Plot resulting lines given sample values for β0 and βa, using a list comprehension
[ax[0].plot(A_, b0+b1*A_, c='black', alpha=0.1) for b0,b1 in zip(β0_,βa_)]
# Make it look nice
ax[0].set_xlabel('Median marriage age (std)', fontsize=17)
ax[0].set_ylabel('Divorce rate (std)', fontsize=17)



# Grab range of marriage rates to plot over
M_ = np.linspace(min(M),max(M),50)
# Plot resulting lines given sample values for β0 and βm, using a list comprehension
[ax[1].plot(M_, b0+b1*M_, c='black', alpha=0.1) for b0,b1 in zip(β0_,βm_)]
# Make it look nice
ax[1].set_xlabel('Marriage rate (std)', fontsize=17)
ax[1].set_ylabel('Divorce rate (std)', fontsize=17)
plt.tight_layout()
plt.savefig('n0xstd.jpg');
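
Eyeballing the lines is the main point, but the same check can be put into a number. The sketch below draws from the same prior families and scales and asks what share of prior lines keep the expected outcome within 2 SD across a predictor range of -2 to 2. The 2 SD threshold, the `_check` and `_draws` names, and the deliberately wide comparison prior are assumptions added here for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
nsamp_check = 10_000

# Same prior families and scales as above
b0_draws = rng.normal(0, 0.2, nsamp_check)
b1_draws = rng.normal(0, 0.5, nsamp_check)

# Evaluate every prior line at the edges of a +/-2 SD predictor range
x_edges = np.array([-2.0, 2.0])
mu_edges = b0_draws[:, None] + b1_draws[:, None] * x_edges

# Share of prior lines whose expected outcome stays within +/-2 SD at both edges
frac_plausible = np.mean(np.all(np.abs(mu_edges) <= 2, axis=1))

# For contrast: a much wider (hypothetical) slope prior
b1_wide = rng.normal(0, 5, nsamp_check)
mu_wide = b0_draws[:, None] + b1_wide[:, None] * x_edges
frac_wide = np.mean(np.all(np.abs(mu_wide) <= 2, axis=1))

print(f"Normal(0, 0.5) slopes: {frac_plausible:.2f} of lines stay within +/-2 SD")
print(f"Normal(0, 5) slopes:   {frac_wide:.2f} of lines stay within +/-2 SD")
```

The wide prior puts most of its mass on lines that would predict impossible divorce rates, which is the kind of thing prior predictive simulation is meant to catch.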

Next, we can build the model in PyMC, which samples continuous parameters with NUTS by default:

In [None]:
# Causal model

# Bayesian PyMC
with pm.Model(coords={'State':State}) as divorce:
    # Priors
    β0 = pm.Normal('Intercept', 0, .2)
    βa = pm.Normal('Marriage age', 0, .5)
    βm = pm.Normal('Marriage rate', 0, .5)
    σ = pm.Exponential('Sigma', 1)
    
    # Linear model
    μ_ = pm.Deterministic('Mu', β0+βa*A+βm*M, dims='State')
    
    # Link function
    μ = μ_*1
    
    # Likelihood
    yi = pm.Normal('yi',μ, σ, observed=D)

In [None]:
# Run sampler
with divorce:
    trace = pm.sample(1000)

In [None]:
pm.summary(trace)

In [None]:
pm.plot_trace(trace, var_names=['Intercept', 'Marriage age', 'Marriage rate', 'Sigma'])
plt.tight_layout()
plt.savefig('posterior.jpg',dpi=300);

In [None]:
pm.plot_forest(trace, var_names=['Intercept', 'Marriage age', 'Marriage rate', 'Sigma'], ridgeplot_overlap=3)
plt.axvline(0,linestyle=':')
plt.xlabel('Effect size (std)')
plt.tight_layout()
plt.savefig('forest.jpg');

In [None]:
# Bayesian PyMC
with pm.Model(coords={'State':State}) as divorce_m:
    # Priors
    β0 = pm.Normal('Intercept', 0, .2)
    βm = pm.Normal('Marriage rate', 0, .5)
    σ = pm.Exponential('Sigma', 1)
    
    # Linear model
    μ_ = pm.Deterministic('Mu', β0+βm*M, dims='State')
    
    # Link function
    μ = μ_*1
    
    # Likelihood
    yi = pm.Normal('yi',μ, σ, observed=D)

In [None]:
# Bayesian PyMC
with pm.Model(coords={'State':State}) as divorce_a:
    # Priors
    β0 = pm.Normal('Intercept', 0, .2)
    βa = pm.Normal('Marriage age', 0, .5)
    σ = pm.Exponential('Sigma', 1)
    
    # Linear model
    μ_ = pm.Deterministic('Mu', β0+βa*A, dims='State')
    
    # Link function
    μ = μ_*1
    
    # Likelihood
    yi = pm.Normal('yi',μ, σ, observed=D)

In [None]:
# Run samplers
with divorce_m:
    trace_m = pm.sample(1000)
with divorce_a:
    trace_a = pm.sample(1000)

In [None]:
import arviz as az

In [None]:
_, ax = plt.subplots(1,2, figsize=(8,4))

v_ = 'Marriage age'
ax_ = 0
az.plot_dist(trace_a.posterior[v_],ax=ax[ax_], label=v_)
ax[ax_].set_xlim(-1,1)
ax[ax_].axvline(0,c='black',linestyle=":")

v_ = 'Marriage rate'
ax_ = 1
az.plot_dist(trace_m.posterior[v_],ax=ax[ax_], label=v_)
ax[ax_].set_xlim(-1,1)
ax[ax_].axvline(0,c='black',linestyle=":")
plt.savefig('singles.jpg',dpi=300);

In [None]:
_, ax = plt.subplots(1,2, figsize=(8,4))

v_ = 'Marriage age'
ax_ = 0
az.plot_dist(trace.posterior[v_],ax=ax[ax_], label=v_+' (full)',color='red')
az.plot_dist(trace_a.posterior[v_],ax=ax[ax_], label=v_+' (only)')
ax[ax_].axvline(0,c='black',linestyle=":")


v_ = 'Marriage rate'
ax_ = 1
az.plot_dist(trace.posterior[v_],ax=ax[ax_], label=v_+' (full)', color='red')
az.plot_dist(trace_m.posterior[v_],ax=ax[ax_], label=v_+' (only)')
ax[ax_].axvline(0,c='black',linestyle=":")

plt.savefig('conditionals.jpg',dpi=300);

# Plotting

Among the most - I'll say **the most** - important checks on your models is to plot the model and the data together. It is critical that you see things the way the model sees things, otherwise it is difficult to know how well you're doing in fitting these things. Three options are:

1. Predictor residual plots
2. Posterior prediction plots
3. Counterfactual plots
    
Each has its own value and can tell us something about how our model is doing.

## 1. Predictor residual plots

There's an awful legacy in Biology of modelling the residuals of another model. It's awful because it's wrong, and you should never do it. It's wrong because it doesn't get the uncertainties right, prioritizing variation in the first analysis and hiding it in the second. This leads to biased estimates, possibly for both models, but certainly for the second. But there is some utility in seeing what information remains in one predictor when you already have information about the other (which is what multiple regression does).

To do this we build two auxiliary models, each regressing one predictor on the other. The residuals from each model are the part of that predictor that the other predictor cannot account for, which is the information it adds once the other is already known.

So for the divorce case, we have two models:

In [None]:
# Bayesian PyMC
with pm.Model() as m_a:
    # Priors
    β0 = pm.Normal('Intercept', 0, .2)
    βm = pm.Normal('Marriage rate', 0, .5)
    σ = pm.Exponential('Sigma', 1)
    
    # Linear model
    μ_ = pm.Deterministic('Mu', β0+βm*M)
    
    # Link function
    μ = μ_*1
    
    # Likelihood
    yi = pm.Normal('yi',μ, σ, observed=A)

In [None]:
# Bayesian PyMC
with pm.Model() as a_m:
    # Priors
    β0 = pm.Normal('Intercept', 0, .2)
    βa = pm.Normal('Marriage age', 0, .5)
    σ = pm.Exponential('Sigma', 1)
    
    # Linear model
    μ_ = pm.Deterministic('Mu', β0+βa*A)
    
    # Link function
    μ = μ_*1
    
    # Likelihood
    yi = pm.Normal('yi',μ, σ, observed=M)

In [None]:
# Run samplers
with a_m:
    trace_m = pm.sample(1000)
with m_a:
    trace_a = pm.sample(1000)

In [None]:
# Residuals of each predictor given the other, using the posterior-mean Mu from the first chain only
m_pred = trace_m.posterior['Mu'].values[0,].mean(0)
residuals_m = M - m_pred

a_pred = trace_a.posterior['Mu'].values[0,].mean(0)
residuals_a = A - a_pred

In [None]:
residuals_a.shape

In [None]:
_, ax = plt.subplots(1,2, figsize=(10,4))

xnew = np.linspace(-2,3,100)

v_ = 'Marriage rate'
ax_ = 0
ax[ax_].scatter(M,A)
ax[ax_].plot(xnew,np.median(trace_a.posterior['Intercept'])+np.median(trace_a.posterior[v_])*xnew)
ax[ax_].set_ylabel('Marriage age', fontsize=17,c='red')
ax[ax_].set_xlabel(v_, fontsize=17)
ax[ax_].vlines(M, a_pred, a_pred + residuals_a, colors='grey')

v_ = 'Marriage age'
ax_ = 1
ax[ax_].scatter(A,M)
ax[ax_].plot(xnew,np.median(trace_m.posterior['Intercept'])+np.median(trace_m.posterior[v_])*xnew)
ax[ax_].set_ylabel('Marriage rate', fontsize=17)
ax[ax_].set_xlabel(v_, fontsize=17,c='red')
ax[ax_].vlines(A, m_pred, m_pred + residuals_m, colors='grey')
plt.tight_layout()
plt.savefig('residualplots.jpg',dpi=300);

What's seemingly bonkers is that, now that we have the residuals for each predictor, we can plot them against divorce to see how the **full model** actually sees these things inside its guts:

In [None]:


_, ax = plt.subplots(2,2, figsize=(10,10))

xnew = np.linspace(-2,3,100)

v_ = 'Marriage age'
ax_ = 0
ax[0,ax_].scatter(residuals_a,D)
coef = np.polyfit(residuals_a,D,1)
poly1d_fn = np.poly1d(coef)
ax[0,ax_].plot(residuals_a, poly1d_fn(residuals_a), 'red', label='Local fit')
ax[0,ax_].plot(xnew,np.median(trace.posterior['Intercept'].values[0,])+np.median(trace.posterior[v_].values[0,])*xnew,label='Model slope')
m,b = np.polyfit(residuals_a,D, 1)
[ax[0,ax_].plot(xnew, trace.posterior['Intercept'].values[0,][i]+trace.posterior[v_].values[0,][i]*xnew, alpha=0.05, c='black') for i in range(100)]
ax[0,ax_].set_ylabel('Divorce rate')
ax[0,ax_].set_xlabel(v_+' residuals (M->A)')
ax[0,ax_].legend()

v_ = 'Marriage rate'
ax_ = 1
ax[0,ax_].scatter(residuals_m,D)
coef = np.polyfit(residuals_m,D,1)
poly1d_fn = np.poly1d(coef)
ax[0,ax_].plot(residuals_m, poly1d_fn(residuals_m), 'red', label='Local fit')
ax[0,ax_].plot(xnew,np.median(trace.posterior['Intercept'].values[0,])+np.median(trace.posterior[v_].values[0,])*xnew,label='Model slope')
[ax[0,ax_].plot(xnew, trace.posterior['Intercept'].values[0,][i]+trace.posterior[v_].values[0,][i]*xnew, alpha=0.05, c='black') for i in range(100)]
ax[0,ax_].set_ylabel('Divorce rate')
ax[0,ax_].set_xlabel(v_+' residuals (A->M)')

v_ = 'Marriage age'
ax_ = 0
ax[1, ax_].scatter(A,D)
coef = np.polyfit(A,D,1)
poly1d_fn = np.poly1d(coef)
ax[1,ax_].plot(A, poly1d_fn(A), 'red', label='Local fit')
ax[1,ax_].plot(xnew,np.median(trace.posterior['Intercept'].values[0,])+np.median(trace.posterior[v_].values[0,])*xnew,label='Model slope')
[ax[1,ax_].plot(xnew, trace.posterior['Intercept'].values[0,][i]+trace.posterior[v_].values[0,][i]*xnew, alpha=0.05, c='black') for i in range(100)]
ax[1,ax_].set_ylabel('Divorce rate')
ax[1,ax_].set_xlabel(v_)

v_ = 'Marriage rate'
ax_ = 1
ax[1, ax_].scatter(M,D)
coef = np.polyfit(M,D,1)
poly1d_fn = np.poly1d(coef)
ax[1,ax_].plot(M, poly1d_fn(M), 'red', label='Local fit')
ax[1,ax_].plot(xnew,np.median(trace.posterior['Intercept'])+np.median(trace.posterior[v_])*xnew,label='Model slope')
[ax[1,ax_].plot(xnew, trace.posterior['Intercept'].values[0,][i]+trace.posterior[v_].values[0,][i]*xnew, alpha=0.05, c='black') for i in range(100)]
ax[1,ax_].set_ylabel('Divorce rate')
ax[1,ax_].set_xlabel(v_)

plt.tight_layout()

plt.savefig('machine.jpg');


So conditional on knowing marriage rate, marriage age still tells us something useful about divorce, but conditional on knowing marriage age, marriage rate tells us very little. Hence the difference in parameter estimates, with marriage age having a way bigger effect size. 
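
The link between predictor residuals and multiple-regression slopes can be checked numerically. This is a minimal sketch on synthetic data with plain least squares, so it is not the Bayesian model above: the effect sizes and `_sim` names are assumptions, and with priors in play the match would only be approximate. For least squares the slope of the outcome on the residualized predictor equals that predictor's coefficient in the full model.

```python
import numpy as np

rng = np.random.default_rng(7)
n = 200

# Synthetic data with assumed effect sizes (illustration only)
A_sim = rng.normal(0, 1, n)
M_sim = -0.7 * A_sim + rng.normal(0, 0.7, n)
D_sim = -0.6 * A_sim + 0.1 * M_sim + rng.normal(0, 0.8, n)

# Full model: D ~ 1 + M + A
X = np.column_stack([np.ones(n), M_sim, A_sim])
b0_hat, bM_hat, bA_hat = np.linalg.lstsq(X, D_sim, rcond=None)[0]

# Predictor residuals: what is left of M after regressing it on A
slope_MA, icpt_MA = np.polyfit(A_sim, M_sim, 1)
resid_M = M_sim - (icpt_MA + slope_MA * A_sim)

# Slope of D on those residuals
slope_D_on_resid = np.polyfit(resid_M, D_sim, 1)[0]

print(f"bM from full model:       {bM_hat:.4f}")
print(f"slope of D on M residual: {slope_D_on_resid:.4f}")
```

The two numbers agree to numerical precision, which is the sense in which the residual plot shows what the full model "sees" of each predictor.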

Incidentally, while we have these residuals, let's take a look at their distribution and what they mean:

In [None]:
_, ax = plt.subplots(1,2, figsize=(10,4))

tmp = ax[0].hist(residuals_a, label='Resid (A)')
# residuals_a come from the m_a model (A given M), whose samples are in trace_a
ax[0].plot(xnew,sp.stats.norm.pdf(xnew,0,trace_a.posterior['Sigma'].mean())*max(tmp[0]),label='Sigma (A)')
ax[0].set_xlabel('A residuals')
ax[0].legend()

tmp = ax[1].hist(residuals_m, label='Resid (M)')
# residuals_m come from the a_m model (M given A), whose samples are in trace_m
ax[1].plot(xnew,sp.stats.norm.pdf(xnew,0,trace_m.posterior['Sigma'].mean())*max(tmp[0]),label='Sigma (M)')
ax[1].set_xlabel('M residuals')
ax[1].legend()
plt.tight_layout()
plt.savefig('residplot.jpg',dpi=300);

The distribution of the residuals is the error distribution (`Sigma`) for the linear model - i.e. `Sigma` describes the magnitude of the deviations from the regression line. 
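
A small simulation makes the same point without any plotting. It is a sketch with assumed values: `true_sigma`, the coefficients and the sample size are made up, and the fit is least squares rather than MCMC. The residual standard deviation lands close to the sigma used to generate the data, and about 95% of residuals fall within 1.96 sigma of the line.

```python
import numpy as np

rng = np.random.default_rng(3)
n = 2000
true_sigma = 0.7  # assumed error SD used to generate the data

# Synthetic linear data with normal errors
x_sim = rng.normal(0, 1, n)
y_sim = 0.2 - 0.7 * x_sim + rng.normal(0, true_sigma, n)

# Fit a line and take the deviations from it
slope_hat, icpt_hat = np.polyfit(x_sim, y_sim, 1)
resid = y_sim - (icpt_hat + slope_hat * x_sim)

# Residual SD, with ddof=2 for the two fitted coefficients
resid_sd = resid.std(ddof=2)
# Share of residuals within +/-1.96 residual SDs of the line
share_within = np.mean(np.abs(resid) <= 1.96 * resid_sd)

print(f"true sigma = {true_sigma}, residual SD = {resid_sd:.3f}")
print(f"share of residuals within 1.96 SD: {share_within:.3f}")
```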

## 2. Posterior prediction plots

Another important question is: how well is our model capturing the observed data? Are our predictions about each observation any good? Having used MCMC for our inference (and stored the values using a `pm.Deterministic` node), we can just grab the observed and expected values and plot them:

In [None]:
# Create data frame of transposed traces for each state observation
PostObs = pd.DataFrame(trace.posterior['Mu'].values[0,], columns=ddata.Location)
PostObs

In [None]:
# Calculate expected y values and 95% uncertainty intervals (UIs)
y = PostObs.median(0).values
y_l95 = np.percentile(PostObs,2.5,axis=0)
y_u95 = np.percentile(PostObs,97.5,axis=0)

In [None]:
# Plot expected vs observed
plt.vlines(D, y_l95, y_u95, colors='grey')
plt.scatter(D,y)
plt.plot((-2,2),(-2,2),linestyle=":")
plt.xlabel('Observed')
plt.ylabel('Posterior')
plt.tight_layout()
plt.savefig('obs_ex.jpg',dpi=300);

So we can see that our model underpredicts high divorce rates (right side) and overpredicts low divorce rates (left side), but that is to be expected: it is a normal linear model after all, and its predictions tend to shrink toward the overall average.
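
This shrinkage is a general property of regression, not something peculiar to these data. Here is a sketch on synthetic data using least squares, where the coefficients and noise level are assumptions. In an expected-versus-observed plot, the slope of the fitted values against the observed values equals R², which is below 1 whenever the fit is imperfect, so the extremes are pulled toward the mean.

```python
import numpy as np

rng = np.random.default_rng(11)
n = 500

# Synthetic data with an assumed slope and noise level
x_sim = rng.normal(0, 1, n)
y_sim = -0.6 * x_sim + rng.normal(0, 0.8, n)

# Least-squares fit and fitted (expected) values
slope_hat, icpt_hat = np.polyfit(x_sim, y_sim, 1)
y_fit = icpt_hat + slope_hat * x_sim

# Proportion of variance explained
r2 = 1 - np.var(y_sim - y_fit) / np.var(y_sim)
# Slope of expected (y axis) against observed (x axis)
shrink_slope = np.polyfit(y_sim, y_fit, 1)[0]

# Average prediction error in the top and bottom 10% of observed values
hi = y_sim > np.quantile(y_sim, 0.9)
lo = y_sim < np.quantile(y_sim, 0.1)
err_hi = np.mean(y_fit[hi] - y_sim[hi])  # under-prediction shows up as a negative value
err_lo = np.mean(y_fit[lo] - y_sim[lo])  # over-prediction shows up as a positive value

print(f"R^2 = {r2:.3f}, slope of expected vs observed = {shrink_slope:.3f}")
print(f"mean error, top 10%: {err_hi:.2f}; bottom 10%: {err_lo:.2f}")
```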

But it does look like there are some outlying values, so let's label a few:

In [None]:
# Plot expected vs observed
plt.vlines(D, y_l95, y_u95, colors='grey')
plt.scatter(D,y)
plt.plot((-2,2),(-2,2),linestyle=":");


# Label states that are >x SD off
x = 1.3
[plt.text(D[i],y[i],ddata.Location.values[i]) for i in np.arange(0,len(ddata.Location.values))[abs(D-y)>x]]
plt.xlabel('Observed')
plt.ylabel('Posterior')
plt.tight_layout()
plt.savefig('obs_ex2.jpg',dpi=300);

## 3. Counterfactual plots

Counterfactuals are frequently brought up in statistical circles, and especially in economics, as a device to imagine what would happen if something else had happened in our data. In the case of counterfactual plots, they show us what happens if we manipulate one variable while keeping the others constant. 

If we return to the causal model where median marriage age influences divorce rate both directly and indirectly via marriage rate, we can develop a counterfactual plot by simulating from our `divorce` and `a_m` models above.

In [None]:
# Divorce model trace
pm.summary(trace,var_names=['Intercept','Marriage rate','Marriage age','Sigma'])

In [None]:
# a_m model trace
pm.summary(trace_m, var_names=['Intercept','Marriage age','Sigma'])

With these values in place, we can see what the predicted change in divorce rate is across the full range of changes in median marriage age. To do this, we first choose the range of marriage ages:

In [None]:
# Marriage age prediction range
nsim = 100
A_new = np.linspace(min(A),max(A),nsim)

Next we calculate the expected effect of marriage age on marriage rate:

In [None]:
# Marriage rates given marriage age range
M_new = trace_m.posterior['Intercept'].values[0,]+trace_m.posterior['Marriage age'].values[0,]*A_new[:,None]

And finally we simulate from the full `divorce` model, given our new (counterfactual) covariate values:


In [None]:
# Divorce rates given marriage age range and manipulated marriage rate
D_new = trace.posterior['Intercept'].values[0,]+trace.posterior['Marriage age'].values[0,]*A_new[:,None]+trace.posterior['Marriage rate'].values[0,]*M_new.mean(1)[:,None]

In [None]:
_, ax = plt.subplots(1,2, figsize=(10,4))

ax[0].plot(A_new,M_new.mean(1))
# Uncertainty intervals
ax[0].plot(A_new,np.quantile(M_new,0.95,1),linestyle=":", c='blue')
ax[0].plot(A_new,np.quantile(M_new,0.05,1),linestyle=":",  c='blue')
ax[0].set_xlabel('Marriage age range')
ax[0].set_ylabel('Manipulated marriage rate')
ax[0].set_ylim(-3,2)


# Expected trend
ax[1].plot(A_new,D_new.mean(1))
# Uncertainty intervals
ax[1].plot(A_new,np.quantile(D_new,0.95,1),linestyle=":", c='blue')
ax[1].plot(A_new,np.quantile(D_new,0.05,1),linestyle=":",  c='blue')

ax[1].set_xlabel('Marriage age range')
ax[1].set_ylabel('Counterfactual divorce rate')
ax[1].set_ylim(-3,2)

plt.savefig('counterfactual.jpg',dpi=300);
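
The two-step logic (manipulate A, let M respond, then predict D) can be sketched with stand-in posterior draws. All the draws below are simulated from assumed normal distributions and the `_draws` and `_cf` names are hypothetical; in the notebook they would come from `trace` and `trace_m`. This variant carries every draw through both steps rather than averaging the manipulated marriage rate first, so uncertainty in the A -> M step also shows up in the divorce prediction.

```python
import numpy as np

rng = np.random.default_rng(5)
ndraw = 1000

# Stand-in posterior draws (assumed values, for illustration only)
b0_draws = rng.normal(0.0, 0.1, ndraw)     # intercept of the D model
bA_draws = rng.normal(-0.6, 0.15, ndraw)   # A -> D
bM_draws = rng.normal(-0.05, 0.15, ndraw)  # M -> D
aM_draws = rng.normal(0.0, 0.1, ndraw)     # intercept of the M model
bAM_draws = rng.normal(-0.7, 0.1, ndraw)   # A -> M

# Manipulated (counterfactual) values of standardized marriage age
A_grid = np.linspace(-2, 2, 30)

# Step 1: marriage rate implied by each manipulated age, for every draw
M_cf = aM_draws + bAM_draws * A_grid[:, None]  # shape (30, ndraw)
# Step 2: divorce rate given the manipulated age and the M it induces
D_cf = b0_draws + bA_draws * A_grid[:, None] + bM_draws * M_cf

# Summaries across draws
D_mean = D_cf.mean(axis=1)
D_lo, D_hi = np.quantile(D_cf, [0.05, 0.95], axis=1)

# Total effect of A on D per draw: direct path plus the path through M
total_effect = bA_draws + bM_draws * bAM_draws
print(f"mean total effect of A on D: {total_effect.mean():.2f}")
```

Per draw, the slope of the counterfactual line is the direct effect plus the product of the two indirect coefficients, which is the usual way of reading a total effect off this DAG.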

# Masked relationships

One of the many (many, many,...) pitfalls of statistical models is the presence of masked relationships - variables that counteract each other so they each appear to have no particular relationship. The primate milk data has just such a case.

In [None]:
# Import data
mdata = pd.read_csv('milk.csv', sep=';')
# Drop rows where neocortex percent is nan
mdata = mdata[mdata['neocortex.perc'].notna()]
# Add log(mass) column
mdata['log(mass)'] = np.log(mdata.mass.values)
mdata.head()

If we take a look at the bivariate relationships among variables, it seems there's not too much going on beyond the relationship between 

In [None]:
g = sns.PairGrid(mdata, vars=['kcal.per.g','log(mass)','neocortex.perc'])
g.map_upper(sns.scatterplot, s=15)
g.map_lower(sns.kdeplot)
g.map_diag(sns.kdeplot, lw=2);

Yet if we run the full model, regressing kcal.per.g on both log(mass) and neocortex.perc, we get a surprise:

In [None]:
# Grab variables of interest
logMass = stdize(mdata['log(mass)'].values)
neocorp = stdize(mdata['neocortex.perc'].values)
kcal = stdize(mdata['kcal.per.g'].values)

In [None]:
# Bayesian PyMC
with pm.Model() as milker:
    # Priors
    β0 = pm.Normal('Intercept', 0, .2)
    β1 = pm.Normal('log(mass)', 0, .5)
    β2 = pm.Normal('neocortex_perc', 0, .5)
    σ = pm.Exponential('Sigma', 1)
    
    # Linear model
    μ_ = β0+β1*logMass+β2*neocorp
    
    # Link function
    μ = μ_*1
    
    # Likelihood
    yi = pm.Normal('yi',μ, σ, observed=kcal)

In [None]:
with milker:
    trace_milk = pm.sample(1000)

In [None]:
pm.summary(trace_milk)

What's this now? The posterior intervals for both log(mass) and percent neocortex exclude zero, meaning both have strong relationships with kcal.per.g in the data. This can happen when the two predictors are correlated with each other, perhaps through some unknown variable that drives both, but act on the outcome in opposite directions. Because their effects cancel, neither appears to do much in a bivariate plot, but when both are in the model their actual effects are revealed. Knowing that it can happen is half the battle. But it still sucks that it does.
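
Masking is easy to reproduce in a toy simulation. The sketch below is not the milk data: the unobserved driver `U`, the predictor names `x_pos` and `x_neg`, and all effect sizes are assumptions chosen to make the point, and the fits are least squares. Two predictors share a common cause and push the outcome in opposite directions, so each bivariate slope is weak while the multiple-regression slopes are strong.

```python
import numpy as np

rng = np.random.default_rng(21)
n = 2000

# Hypothetical unobserved variable that drives both predictors
U = rng.normal(0, 1, n)
x_pos = U + rng.normal(0, 0.5, n)  # predictor with a positive effect on y
x_neg = U + rng.normal(0, 0.5, n)  # predictor with a negative effect on y
y_sim = 1.0 * x_pos - 1.0 * x_neg + rng.normal(0, 1, n)

# Bivariate slopes: each predictor on its own
biv_pos = np.polyfit(x_pos, y_sim, 1)[0]
biv_neg = np.polyfit(x_neg, y_sim, 1)[0]

# Multiple regression with both predictors
X = np.column_stack([np.ones(n), x_pos, x_neg])
_, multi_pos, multi_neg = np.linalg.lstsq(X, y_sim, rcond=None)[0]

print(f"bivariate slopes: {biv_pos:.2f}, {biv_neg:.2f}")
print(f"multiple slopes:  {multi_pos:.2f}, {multi_neg:.2f}")
```

Each predictor alone carries the other's opposing effect along with it, which is why the bivariate slopes end up much closer to zero than the effects used to generate the data.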