# Purpose
## Building an Interaction between categorical predictors

In *Statistical Rethinking*, 2nd edition [1], Richard McElreath describes how to build interactions between continuous and categorical predictors, as well as between two continuous predictors. But how do we build a regression model with an interaction between two categorical predictors? My first attempt was simply to build the regression using *two intercepts*, where each categorical predictor gets its own set of means:

$$ y_i \sim Normal(\mu_i, \sigma)$$
$$\mu_i = \alpha_{factor_A[i]} + \alpha_{factor_B[i]}$$
$$\alpha_{factor_A[j]} \sim  Normal(0, 10)\quad \text{for}\: j = 1...J$$
$$\alpha_{factor_B[k]} \sim  Normal(0, 10)\quad \text{for}\: k = 1...K$$
$$\sigma \sim Exponential(1)$$

Above we use an index variable for each categorical predictor, an approach also discussed in chapter 5 of the book [1]. A more experienced practitioner might already see the problem with this model definition, but I didn't realize the error until I built the model. I wasn't the only one to think the above model was a reasonable approach; see the post on the PyMC discourse forum titled [Modeling two varying intercepts](https://discourse.pymc.io/t/modeling-two-varying-intercepts/6107) for a spoiler.

To be clear, the model definition above is **wrong** for predicting a continuous variable from two categorical predictors, but let's simulate a dataset to see why. We will simulate data that we know has no interactions, just to isolate the issue.

```python
import jax.numpy as jnp
from jax.numpy import DeviceArray

import numpyro
import numpyro.distributions as dist

proba_ = 0.28
cpt_ = jnp.array([[0.9, 0.1], [0.2, 0.8]])
mean_ = jnp.array([15.,10.])

def synthetic_00(proba: float, cpt: DeviceArray, mean: DeviceArray):
    A = numpyro.sample("A", dist.Bernoulli(proba))
    M = numpyro.sample("M", dist.Categorical(cpt[A]))
    D = numpyro.sample("D", dist.Normal(mean[A],1))
```

* model render 

The function above simulates a joint distribution where A is a common cause of both D and M, and D and M are conditionally independent given A. Before fitting, we expect that if we regress D on A alone, one category will have a mean of 15 and the other a mean of 10.
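As a quick sanity check on that expectation, here is a minimal sketch of the same generative story using ancestral sampling in plain NumPy rather than NumPyro. The sample size, the seed, and the names `mean_D_by_A` and `mean_D_by_M` are assumptions made for illustration; the probabilities and means are the values of `proba_`, `cpt_`, and `mean_` above.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5_000  # illustrative sample size

proba = 0.28                              # P(A = 1)
cpt = np.array([[0.9, 0.1], [0.2, 0.8]])  # row a holds P(M | A = a)
mean = np.array([15.0, 10.0])             # E[D | A = a]

# Ancestral sampling: draw the parent A first, then its children M and D.
A = rng.binomial(1, proba, size=n)
M = (rng.random(n) < cpt[A, 1]).astype(int)  # M = 1 with probability cpt[A, 1]
D = rng.normal(mean[A], 1.0)

# Average of D within each level of A, and within each level of M.
mean_D_by_A = np.array([D[A == a].mean() for a in (0, 1)])
mean_D_by_M = np.array([D[M == m].mean() for m in (0, 1)])
print(mean_D_by_A)
print(mean_D_by_M)
```

The group means of D by A should land near 15 and 10. The group means of D by M also differ, even though M has no direct effect on D, because both inherit their variation from A.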

* pairs plot

We will fit the model using NumPyro. The incorrect model for predicting a continuous outcome from two categorical predictors is shown below:

```python
from typing import Optional


def model_00(
    factor_00: DeviceArray, factor_01: DeviceArray, y: Optional[DeviceArray] = None
):
    num_data = factor_00.shape[0]
    num_factor_00 = len(set(factor_00))
    num_factor_01 = len(set(factor_01))

    sigma = numpyro.sample("sigma", dist.Exponential(1))

    # One intercept per level of the first factor
    with numpyro.plate("num_factor_00", num_factor_00):
        a_factor_00 = numpyro.sample("a_factor_00", dist.Normal(0, 10))

    # One intercept per level of the second factor
    with numpyro.plate("num_factor_01", num_factor_01):
        b_factor_01 = numpyro.sample("b_factor_01", dist.Normal(0, 10))

    with numpyro.plate("num_data", num_data):
        mu = numpyro.deterministic(
            "mu", a_factor_00[factor_00] + b_factor_01[factor_01]
        )
        numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)
```

* mcmc.print_summary()
* trace plot

Looking at the trace plots, the model appears to converge and we don't see any obvious problems with the inference. The problem with this model is multicollinearity: the two sets of means are not separately identifiable. With both in the model, there are infinitely many combinations of the parameters that fit the data equally well, as long as their sums recover the relationship between D and A. We can verify the cause by inspecting the correlation between the two sets of parameters in the posterior.
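The non-identifiability can be seen without running any sampler. The sketch below uses made-up index variables and parameter values (all hypothetical, chosen only for illustration) to show that shifting every level of one factor up by a constant and every level of the other factor down by the same constant leaves the linear predictor unchanged, and that the equivalent indicator-column design matrix is rank deficient.

```python
import numpy as np

# Hypothetical index variables for 8 observations, 2 levels per factor.
factor_a = np.array([0, 0, 0, 0, 1, 1, 1, 1])
factor_b = np.array([0, 0, 1, 1, 0, 0, 1, 1])

# Arbitrary illustrative parameter values.
alpha_a = np.array([9.0, 4.0])
alpha_b = np.array([6.0, 6.0])

mu = alpha_a[factor_a] + alpha_b[factor_b]

# Shift one factor's intercepts up by c and the other's down by c.
c = 123.0
mu_shifted = (alpha_a + c)[factor_a] + (alpha_b - c)[factor_b]
same_fit = np.allclose(mu, mu_shifted)

# Equivalent view: one indicator column per level of each factor.
X = np.column_stack([np.eye(2)[factor_a], np.eye(2)[factor_b]])
rank = np.linalg.matrix_rank(X)

print(same_fit)          # identical linear predictor after the shift
print(X.shape[1], rank)  # 4 columns, but fewer independent directions
```

Because the likelihood cannot distinguish between these shifted parameter sets, only the priors keep the posterior proper, which is why the sampler can still appear to converge.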

* image of the posterior distribution of the parameters

The plot above shows a strong correlation between the two parameters. Visually inspecting the posterior correlations between parameters is always good practice for model diagnostics; see [chapter 8 of Course Handouts for Bayesian Data Analysis Class](https://bookdown.org/marklhc/notes_bookdown/model-diagnostics.html) for examples.

## Wait... How do we build an interaction with a continuous outcome and two categorical predictors?

Again, those more traditionally trained in statistics might already recognize the pattern of our problem as a simple analysis of variance (ANOVA). Our next challenge is how to set up the model using our Bayesian approach. Typical statistical texts develop ANOVA by decomposing the variance for each factor through the ANOVA table, but there is an analogous way to set up ANOVA using regression. As an aside, most statistical tests are just linear models; see the post by Jonas Kristoffer Lindeløv titled [Common statistical tests are linear models](https://lindeloev.github.io/tests-as-linear/#6_three_or_more_means). I find the simplest ANOVA model to implement without re-parameterizing the data into a design matrix is the cell means model:

$$Y_{i,j,k} \sim Normal(\mu_{i,j}, \sigma)$$
$$\mu_{i,j} \sim Normal(0,10)$$
$$\sigma \sim Exponential(1)$$
$$i = 1,...,I\: \text{levels of factor}\: I$$
$$j = 1,...,J\: \text{levels of factor}\: J$$
$$k = 1,...,K\: \text{observations in cell}\: (i,j)$$
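To make the indexing concrete, here is a small sketch in plain NumPy of what the cell means model estimates: one mean per combination of the two factors, looked up for each observation with a pair of index variables. The data and the names `factor_i`, `factor_j`, and `y` are hypothetical and only meant to illustrate the layout.

```python
import numpy as np

# Hypothetical balanced 2 x 2 layout with 3 observations per cell.
factor_i = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
factor_j = np.array([0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1])
y = np.array([14.0, 15.0, 16.0, 9.0, 10.0, 11.0,
              10.0, 10.0, 10.0, 14.0, 15.0, 16.0])

n_i, n_j = factor_i.max() + 1, factor_j.max() + 1

# Accumulate the sum and the count of y within each (i, j) cell.
cell_sum = np.zeros((n_i, n_j))
cell_count = np.zeros((n_i, n_j))
np.add.at(cell_sum, (factor_i, factor_j), y)
np.add.at(cell_count, (factor_i, factor_j), 1)
cell_mean = cell_sum / cell_count

# Each observation looks up its own cell mean with two index variables.
mu = cell_mean[factor_i, factor_j]
print(cell_mean)
```

With a wide prior on each cell mean, the posterior means from a Bayesian fit should sit close to these sample cell averages.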


* I wanted to see how to build a regression predicting a continuous/metric response using two predictors that are discrete/categorical/factors/nominal
* The main motivation for the exercise is to recover known interaction effects between the two predictors
* McElreath [1] shows how to build interactions between continuous and categorical, and between continuous and continuous, but doesn't show categorical and categorical
* Using the index-variable approach for categorical predictors, I thought the model would simply be the addition of two separate means
* This approach is incorrect and is similar to [Modeling two varying intercepts](https://discourse.pymc.io/t/modeling-two-varying-intercepts/6107), which introduced multicollinearity


* Build a Bayesian regression with a continuous output on two categorical predictors
* The key is to see whether we can recover 2-way interactions when we model categorical predictors as index variables

* The data and examples are adapted from reference [2]

1. Ref: McElreath, R., 2020. Statistical rethinking: A Bayesian course with examples in R and Stan. 2nd ed. Chapman and Hall/CRC.
2. Ref: (2016). Understanding 2-way Interactions [Online]. University of Virginia Library Research Data Servi. Available at: [data.library.virginia.edu/understanding-2-way-interactions/](data.library.virginia.edu/understanding-2-way-interactions/) (Accessed: 7 October 2021).

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import arviz as az
import numpyro
import numpyro.distributions as dist
from numpyro.infer.util import Predictive
import pandas as pd
from numpyro.infer import MCMC, NUTS
import matplotlib.pyplot as plt

In [None]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2
%load_ext watermark

In [None]:
%watermark -v -m -p arviz,jax,matplotlib,numpy,pandas,scipy,numpyro

In [None]:
%watermark -gb

In [None]:
num_warmup = 500
num_samples = 1000
num_chains = 4

In [None]:
np.repeat(["male", "female"], 20)

In [None]:
np.tile(np.repeat(["yes", "no"], 10), 2)

In [None]:
d_ = dist.Normal(loc=jnp.array([15, 10]))
s_ = d_.sample(jax.random.PRNGKey(2), (10,))
a_ = jnp.hstack([s_[:, 0], s_[:, 1]])

d_ = dist.Normal(loc=jnp.array([10, 15]))
s_ = d_.sample(jax.random.PRNGKey(3), (10,))
b_ = jnp.hstack([s_[:, 0], s_[:, 1]])

In [None]:
np.hstack([a_, b_])

In [None]:
dat = pd.DataFrame(
    {
        "gender": np.repeat(["male", "female"], 20),
        "trt": np.tile(np.repeat(["yes", "no"], 10), 2),
        "resp": np.hstack([a_, b_]),
    }
)

dat["gender"] = dat["gender"].astype("category")
dat["trt"] = dat["trt"].astype("category")

dat.head()

In [None]:
dat.info()

In [None]:
dat.groupby("gender").mean(numeric_only=True)

In [None]:
dat.groupby("trt").mean(numeric_only=True)

In [None]:
dat.groupby(["trt", "gender"]).mean()

In [None]:
dat["resp_z_score"] = (dat["resp"] - dat["resp"].mean())/ dat["resp"].std()

In [None]:
dat.groupby(["trt", "gender"])[["resp","resp_z_score"]].mean()

In [None]:
dat["gender_id"] = dat.gender.cat.codes
dat["trt_id"] = dat.trt.cat.codes

In [None]:
dat.info()

In [None]:
def model_00(cat_00, y=None):
    num_data = cat_00.shape[0]
    num_cat_00 = len(set(cat_00))

    sigma = numpyro.sample("sigma", dist.Exponential(1))

    with numpyro.plate("num_cat_00", num_cat_00):
        a_cat_00 = numpyro.sample("a_cat_00", dist.Normal(0, 1000))

    with numpyro.plate("num_data", num_data):
        numpyro.sample("obs", dist.Normal(a_cat_00[cat_00], sigma), obs=y)

In [None]:
kernel = NUTS(model_00)
mcmc = MCMC(
    kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=num_chains
)
mcmc.run(jax.random.PRNGKey(4), jnp.array(dat["gender_id"]), jnp.array(dat["resp"]))
mcmc.print_summary()
ds_ = az.from_numpyro(mcmc)
az.plot_trace(ds_);

In [None]:
posterior_predictive = Predictive(model_00, mcmc.get_samples())(
    jax.random.PRNGKey(10),
    jnp.array(dat["gender_id"]),
)
fig, ax = plt.subplots()
ax.scatter(dat["resp"], np.array(posterior_predictive["obs"].mean(0)))

In [None]:
kernel = NUTS(model_00)
mcmc = MCMC(
    kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=num_chains
)
mcmc.run(jax.random.PRNGKey(4), jnp.array(dat["trt_id"]), jnp.array(dat["resp"]))
mcmc.print_summary()
ds_ = az.from_numpyro(mcmc)
az.plot_trace(ds_);

In [None]:
fig, ax = plt.subplots()
ax.scatter(np.array(mcmc.get_samples()["a_cat_00"][:,0]),np.array(mcmc.get_samples()["a_cat_00"][:,1]))

In [None]:
def model_01(cat_00, cat_01, y=None):
    num_data = cat_00.shape[0]
    num_cat_00 = len(set(cat_00))
    num_cat_01 = len(set(cat_01))

    sigma = numpyro.sample("sigma", dist.Exponential(1))

    with numpyro.plate("num_cat_00", num_cat_00):
        a_cat_00 = numpyro.sample("a_cat_00", dist.Normal(0, 0.5))

    with numpyro.plate("num_cat_01", num_cat_01):
        b_cat_01 = numpyro.sample("b_cat_01", dist.Normal(0, 0.5))

    with numpyro.plate("num_data", num_data):
        # Index each factor's intercepts with that factor's own index variable
        mu = numpyro.deterministic("mu", a_cat_00[cat_00] + b_cat_01[cat_01])
        numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)

In [None]:
kernel = NUTS(model_01)
mcmc = MCMC(
    kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=num_chains
)
mcmc.run(
    jax.random.PRNGKey(4),
    jnp.array(dat["gender_id"]),
    jnp.array(dat["trt_id"]),
    jnp.array(dat["resp"]),
)
mcmc.print_summary(exclude_deterministic=True)
ds_ = az.from_numpyro(mcmc)
az.plot_trace(ds_);

In [None]:
fig, ax = plt.subplots()
ax.scatter(np.array(mcmc.get_samples()["a_cat_00"][:1000,0]), np.array(mcmc.get_samples()["a_cat_00"][:1000,1]))

In [None]:
mcmc.get_samples()["a_cat_00"][:,0]

In [None]:
posterior_predictive = Predictive(model_01, mcmc.get_samples())(
    jax.random.PRNGKey(10),
    jnp.array(dat["gender_id"]),
    jnp.array(dat["trt_id"]),
)

In [None]:
fig, ax = plt.subplots()
ax.scatter(dat["resp_z_score"], np.array(posterior_predictive["obs"].mean(0)))

In [None]:
def model_02(cat_00, cat_01, y=None):
    num_data = cat_00.shape[0]
    num_cat_00 = len(set(cat_00))
    num_cat_01 = len(set(cat_01))
    num_interactions = num_cat_00 * num_cat_01

    sigma = numpyro.sample("sigma", dist.Exponential(1))
    
    grand_mean = numpyro.sample("grand_mean", dist.Normal(0, 0.5))

    with numpyro.plate("num_cat_00", num_cat_00):
        a_cat_00 = numpyro.sample("a_cat_00", dist.Normal(0, 0.5))

    with numpyro.plate("num_cat_01", num_cat_01):
        b_cat_01 = numpyro.sample("b_cat_01", dist.Normal(0, 0.5))
    

    with numpyro.plate("num_data", num_data):
        # Index each factor's effects with that factor's own index variable
        mu = numpyro.deterministic(
            "mu",
            grand_mean
            + a_cat_00[cat_00]
            + b_cat_01[cat_01]
            + a_cat_00[cat_00] * b_cat_01[cat_01],
        )
        numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)

In [None]:
kernel = NUTS(model_02)
mcmc = MCMC(
    kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=num_chains
)
mcmc.run(
    jax.random.PRNGKey(4),
    jnp.array(dat["gender_id"]),
    jnp.array(dat["trt_id"]),
    jnp.array(dat["resp"]),
)
mcmc.print_summary(exclude_deterministic=True)
ds_ = az.from_numpyro(mcmc)
az.plot_trace(ds_);

In [None]:
posterior_predictive = Predictive(model_02, mcmc.get_samples())(
    jax.random.PRNGKey(10),
    jnp.array(dat["gender_id"]),
    jnp.array(dat["trt_id"]),
)
fig, ax = plt.subplots()
ax.scatter(dat["resp"], np.array(posterior_predictive["obs"].mean(0)))

In [None]:
def model_04(factor_00, factor_01, y=None):
    num_data = factor_00.shape[0]
    num_factor_00 = len(set(factor_00))
    num_factor_01 = len(set(factor_01))

    sigma = numpyro.sample("sigma", dist.Exponential(1))
    
    with numpyro.plate("num_factor_01", num_factor_01):
        with numpyro.plate("num_factor_00", num_factor_00):
            cell_mean = numpyro.sample("cell_mean", dist.Normal(0, 1000))
    
    with numpyro.plate("num_data", num_data):
        numpyro.sample("obs", dist.Normal(cell_mean[factor_00, factor_01], sigma), obs=y)

In [None]:
kernel = NUTS(model_04)
mcmc_D_A_M = MCMC(
    kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=num_chains
)
mcmc_D_A_M.run(
    jax.random.PRNGKey(4),
    jnp.array(dat["gender_id"]),
    jnp.array(dat["trt_id"]),
    jnp.array(dat["resp"]),
)
mcmc_D_A_M.print_summary(exclude_deterministic=True)
ds_ = az.from_numpyro(mcmc_D_A_M)
az.plot_trace(ds_);

In [None]:
mcmc_D_A_M.get_samples()["cell_mean"].shape

In [None]:
mcmc_D_A_M.get_samples()["cell_mean"].mean(0)

In [None]:
dat.groupby(["trt", "gender"])["resp"].mean().values.reshape(2,2).mean(1)

In [None]:
mcmc_D_A_M.get_samples()["cell_mean"].mean(0).sum(0)

In [None]:
mcmc_D_A_M.get_samples()["cell_mean"].mean(0).sum(1)

$$H_0: \mu_{i,j} - \mu_{i\cdot} - \mu_{\cdot j} + \mu_{\cdot\cdot} = 0 \quad \text{for all}\: i, j$$
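The null hypothesis above says that every cell mean minus its row mean, minus its column mean, plus the grand mean is zero. A small worked sketch in NumPy follows, using the cell means 15, 10, 10, 15 that the simulated `resp` data were generated around; the second, purely additive table is a made-up contrast case.

```python
import numpy as np

def interaction_contrast(cell_mean):
    """Cell mean - row mean - column mean + grand mean, for every cell."""
    row_mean = cell_mean.mean(axis=1, keepdims=True)  # shape (rows, 1)
    col_mean = cell_mean.mean(axis=0, keepdims=True)  # shape (1, cols)
    return cell_mean - row_mean - col_mean + cell_mean.mean()

# Crossed pattern: the effect of one factor flips with the other factor.
crossed = np.array([[15.0, 10.0],
                    [10.0, 15.0]])

# Hypothetical additive pattern: rows differ by 5, columns differ by 5.
additive = np.array([[15.0, 10.0],
                     [20.0, 15.0]])

print(interaction_contrast(crossed))   # [[ 2.5 -2.5] [-2.5  2.5]]
print(interaction_contrast(additive))  # all zeros
```

The `keepdims=True` arguments matter: without them the row means would broadcast along the wrong axis and be subtracted from the wrong cells.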

In [None]:
cell_mean_hat = mcmc_D_A_M.get_samples()["cell_mean"].mean(0)
# Interaction contrast per cell: cell mean - row mean - column mean + grand mean
cell_mean_hat - cell_mean_hat.mean(1, keepdims=True) - cell_mean_hat.mean(0, keepdims=True) + cell_mean_hat.mean()

## Example from Purdue Stat 512

In [None]:
bread = pd.DataFrame(
    {
        "sales": [47, 43, 46, 40, 62, 68, 67, 71, 41, 39, 42, 46],
        "height": [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3],
        "width": [1, 1, 2, 2, 1, 1, 2, 2, 1, 1, 2, 2],
    }
)
bread["height_id"] = bread["height"].astype("category").cat.codes
bread["width_id"] = bread["width"].astype("category").cat.codes

In [None]:
kernel = NUTS(model_04)
mcmc = MCMC(
    kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=num_chains
)
mcmc.run(
    jax.random.PRNGKey(4),
    jnp.array(bread["height_id"]),
    jnp.array(bread["width_id"]),
    jnp.array(bread["sales"]),
)
mcmc.print_summary(exclude_deterministic=True)
ds_ = az.from_numpyro(mcmc)
az.plot_trace(ds_);

In [None]:
mcmc.get_samples()["cell_mean"].mean(0)

In [None]:
bread["sales"].mean()

In [None]:
mcmc.get_samples()["cell_mean"].mean(0).mean()

In [None]:
mcmc.get_samples()["cell_mean"].mean(0).mean(1) - mcmc.get_samples()["cell_mean"].mean(0).mean()

In [None]:
mcmc.get_samples()["cell_mean"].mean(0).mean(0) - mcmc.get_samples()["cell_mean"].mean(0).mean()

In [None]:
mcmc.get_samples()["cell_mean"].mean(0).mean(0)
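
The quantities computed in the cells above can be read as the pieces of the usual factor effects decomposition of the cell means. The symbols $\alpha_i$, $\beta_j$, and $(\alpha\beta)_{i,j}$ are not used elsewhere in this notebook and are introduced here only as a sketch of the standard notation, with $i$ indexing height and $j$ indexing width:

$$\mu_{i,j} = \mu_{\cdot\cdot} + \alpha_i + \beta_j + (\alpha\beta)_{i,j}$$
$$\alpha_i = \mu_{i\cdot} - \mu_{\cdot\cdot}$$
$$\beta_j = \mu_{\cdot j} - \mu_{\cdot\cdot}$$
$$(\alpha\beta)_{i,j} = \mu_{i,j} - \mu_{i\cdot} - \mu_{\cdot j} + \mu_{\cdot\cdot}$$

Here $\mu_{\cdot\cdot}$ is the average of all cell means, $\mu_{i\cdot}$ averages over width, and $\mu_{\cdot j}$ averages over height, so the main effects and the interactions each sum to zero across their indices.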

In [None]:
from patsy import dmatrix

In [None]:
bread["height"] = bread["height"].astype("category")
bread["width"] = bread["width"].astype("category")

In [None]:
bread

In [None]:
d = dmatrix("height*width", bread)
d

In [None]:
jnp.array(d)
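
The formula `height*width` expands to both main effects plus their interaction, and patsy's default treatment coding uses the first level of each factor as the reference. The sketch below builds what should be an equivalent design matrix by hand in NumPy (the column order is an assumption and may differ from patsy's) and solves it by least squares, to show how the coefficients relate to the cell means of the bread data.

```python
import numpy as np

sales = np.array([47, 43, 46, 40, 62, 68, 67, 71, 41, 39, 42, 46], dtype=float)
height = np.array([1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3])
width = np.array([1, 1, 2, 2, 1, 1, 2, 2, 1, 1, 2, 2])

# Treatment (dummy) coding with height == 1 and width == 1 as references.
h2 = (height == 2).astype(float)
h3 = (height == 3).astype(float)
w2 = (width == 2).astype(float)

# Columns: intercept, height 2, height 3, width 2, and the two interactions.
X = np.column_stack([np.ones_like(sales), h2, h3, w2, h2 * w2, h3 * w2])
beta, *_ = np.linalg.lstsq(X, sales, rcond=None)

# Sample cell means, rows indexed by height and columns by width.
cell_mean = np.array(
    [[sales[(height == h) & (width == w)].mean() for w in (1, 2)]
     for h in (1, 2, 3)]
)

print(beta)       # [45. 20. -5. -2.  6.  6.]
print(cell_mean)  # [[45. 43.] [65. 69.] [40. 44.]]
```

The intercept is the reference cell mean, each main-effect coefficient is a difference from the reference cell, and each interaction coefficient is a difference of differences. With six cells and six coefficients the model is saturated, so it carries the same information as the cell means model, just parameterized differently.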

In [None]:
def model_05(X, y=None):
    
    num_data = X.shape[0]
    num_feature = X.shape[1]
    
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    
    with numpyro.plate("num_feature", num_feature):
        beta = numpyro.sample("beta", dist.Normal(0, 1000))
    
    with numpyro.plate("num_data", num_data):
        mu = jnp.sum(beta * X, axis=-1)
        numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)

In [None]:
kernel = NUTS(model_05)
mcmc = MCMC(
    kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=num_chains
)
mcmc.run(
    jax.random.PRNGKey(4),
    jnp.array(d),
    jnp.array(bread["sales"]),
)
mcmc.print_summary(exclude_deterministic=True)
ds_ = az.from_numpyro(mcmc)
az.plot_trace(ds_);