In [1]:
import operator

import polars as pl
import jax
from jax import numpy as jnp
from jax.scipy.stats import norm


In [2]:
RAW_RADON_CSV_URL = (
    "https://raw.githubusercontent.com/stan-dev/example-models/"
    "e5b7d9e2e9ecc375805c7e49e4a4d4c1882b5e3b"
    "/jupyter/radon/data/mn_radon.csv"
)
RAW_URANIUM_CSV_URL = (
    "https://raw.githubusercontent.com/stan-dev/example-models/"
    "e5b7d9e2e9ecc375805c7e49e4a4d4c1882b5e3b"
    "/jupyter/radon/data/mn_uranium.csv"
)
mn_radon = pl.read_csv(RAW_RADON_CSV_URL)
mn_uranium = pl.read_csv(RAW_URANIUM_CSV_URL)
mn_uranium.head()

county,log_uranium,county_id,homes
str,f64,i64,i64
"""AITKIN """,-0.689048,1,4
"""ANOKA """,-0.847313,2,52
"""BECKER """,-0.113459,3,3
"""BELTRAMI """,-0.593353,4,7
"""BENTON """,-0.14289,5,4


In [3]:
mn_radon.head()

floor,county,log_radon,log_uranium,county_id
i64,str,f64,f64,i64
1,"""AITKIN """,0.788457,-0.689048,1
0,"""AITKIN """,0.788457,-0.689048,1
0,"""AITKIN """,1.064711,-0.689048,1
0,"""AITKIN """,0.0,-0.689048,1
0,"""ANOKA """,1.131402,-0.847313,2


```stan
data {
  int<lower=1> N;
  vector[N] x;
  vector[N] y;
}
parameters {
  real alpha;
  real beta;
  real<lower=0> sigma;
}
model {
  y ~ normal(alpha + beta * x, sigma);
  alpha ~ normal(0, 10);
  beta ~ normal(0, 10);
  sigma ~ normal(0, 10);
}
generated quantities {
  array[N] real y_rep = normal_rng(alpha + beta * x, sigma);
}
```

In [4]:
def fully_pooled_model_log_density(parameters, data):
    alpha, beta, log_sigma = parameters
    x, y = data
    sigma = jnp.exp(log_sigma)
    yhat = alpha + beta * x
    log_lik = norm.logpdf(y, loc=yhat, scale=jnp.full_like(yhat, sigma)).sum()
    log_prior = jax.tree.map(
        lambda leaf: norm.logpdf(leaf, loc=0, scale=10),
        parameters
    )
    return log_lik + jax.tree.reduce(operator.add, log_prior)

example_params_fp = (-1, 0.4, 0.0)
example_data_fp = (mn_radon["log_uranium"].to_numpy(), mn_radon["log_radon"].to_numpy())

fully_pooled_model_log_density(example_params_fp, example_data_fp)

Array(-3544.3743, dtype=float32)

```
data {
  int<lower=1> N;  // observations
  int<lower=1> J;  // counties
  array[N] int<lower=1, upper=J> county;
  vector[N] x;     // floor
  vector[N] y;     // radon
}
parameters {
  vector[J] alpha;
  real beta;
  real<lower=0> sigma;
}
model {
  y ~ normal(alpha[county] + beta * x, sigma);  
  alpha ~ normal(0, 10);
  beta ~ normal(0, 10);
  sigma ~ normal(0, 10);
}
generated quantities {
  array[N] real y_rep = normal_rng(alpha[county] + beta * x, sigma);
}
```

In [5]:
def unpooled_model_log_density(parameters, data):
    alpha, beta, log_sigma = parameters
    x, county_ix, y = data
    sigma = jnp.exp(log_sigma)
    yhat = alpha[county_ix] + beta * x
    log_lik = norm.logpdf(y, loc=yhat, scale=jnp.full_like(yhat, sigma)).sum()
    log_prior = jax.tree.map(
        lambda leaf: norm.logpdf(leaf, loc=0, scale=10).sum(),
        parameters
    )
    return log_lik + jax.tree.reduce(operator.add, log_prior)

N = len(mn_radon)
key_up = jax.random.key(1234)
key_a, key_b, key_s = jax.random.split(key_up, num=3)
example_params_up = jax.tree.map(
    lambda x, k: jax.random.normal(k, shape=(N,)) * 0.5,
    (-1, 0.4, 0.0),
    (key_a, key_b, key_s)
)
example_data_up = (
    mn_radon["log_uranium"].to_numpy(), 
    mn_radon["county_id"].to_numpy() - 1, 
    mn_radon["log_radon"].to_numpy()
)
unpooled_model_log_density(example_params_up, example_data_up)

Array(-11675.303, dtype=float32)

```stan
data {
  int<lower=1> N;  // observations
  int<lower=1> J;  // counties
  array[N] int<lower=1, upper=J> county;
  vector[N] x;
  vector[N] y;
}
parameters {
  real mu_alpha;
  real<lower=0> sigma_alpha;
  vector<offset=mu_alpha, multiplier=sigma_alpha>[J] alpha;  // non-centered parameterization
  real beta;
  real<lower=0> sigma;
}
model {
  y ~ normal(alpha[county] + beta * x, sigma);  
  alpha ~ normal(mu_alpha, sigma_alpha); // partial-pooling
  beta ~ normal(0, 10);
  sigma ~ normal(0, 10);
  mu_alpha ~ normal(0, 10);
  sigma_alpha ~ normal(0, 10);
}
generated quantities {
  array[N] real y_rep = normal_rng(alpha[county] + beta * x, sigma);
}

```

In [6]:
def partial_pooled_model_log_density(parameters, data):
    mu_alpha, sigma_alpha, alpha_z, beta, log_sigma = parameters
    prior_sd = (10.0, 10.0, 1.0, 10.0, 10.0)
    x, county_ix, y = data
    sigma = jnp.exp(log_sigma)
    alpha = mu_alpha + sigma_alpha * alpha_z 
    yhat = alpha[county_ix] + beta * x
    log_lik = norm.logpdf(y, loc=yhat, scale=jnp.full_like(yhat, sigma)).sum()
    log_prior = jax.tree.map(
        lambda leaf, sd: norm.logpdf(leaf, loc=0, scale=sd).sum(),
        parameters,
        prior_sd
    )
    return log_lik + jax.tree.reduce(operator.add, log_prior)

key_pp = jax.random.key(12345)
example_params_pp = (
    -1, 0.2, jax.random.normal(key_pp, shape=(N,)), 0.4, 0.0
)
example_data_pp = (
    mn_radon["log_uranium"].to_numpy(), 
    mn_radon["county_id"].to_numpy() - 1, 
    mn_radon["log_radon"].to_numpy()
)
partial_pooled_model_log_density(example_params_pp, example_data_pp)

Array(-4852.3555, dtype=float32)