In [10]:
import operator
from functools import partial

import arviz as az
import polars as pl
import jax
from jax import numpy as jnp
from jax.scipy.stats import norm

from blackjax_utils import run_nuts, get_idata

In [2]:
RAW_RADON_CSV_URL = (
    "https://raw.githubusercontent.com/stan-dev/example-models/"
    "e5b7d9e2e9ecc375805c7e49e4a4d4c1882b5e3b"
    "/jupyter/radon/data/mn_radon.csv"
)
RAW_URANIUM_CSV_URL = (
    "https://raw.githubusercontent.com/stan-dev/example-models/"
    "e5b7d9e2e9ecc375805c7e49e4a4d4c1882b5e3b"
    "/jupyter/radon/data/mn_uranium.csv"
)
mn_radon = pl.read_csv(RAW_RADON_CSV_URL)
mn_uranium = pl.read_csv(RAW_URANIUM_CSV_URL)
mn_uranium.head()

county,log_uranium,county_id,homes
str,f64,i64,i64
"""AITKIN """,-0.689048,1,4
"""ANOKA """,-0.847313,2,52
"""BECKER """,-0.113459,3,3
"""BELTRAMI """,-0.593353,4,7
"""BENTON """,-0.14289,5,4


In [3]:
mn_radon.head()

floor,county,log_radon,log_uranium,county_id
i64,str,f64,f64,i64
1,"""AITKIN """,0.788457,-0.689048,1
0,"""AITKIN """,0.788457,-0.689048,1
0,"""AITKIN """,1.064711,-0.689048,1
0,"""AITKIN """,0.0,-0.689048,1
0,"""ANOKA """,1.131402,-0.847313,2


```stan
data {
  int<lower=1> N;
  vector[N] x;
  vector[N] y;
}
parameters {
  real alpha;
  real beta;
  real<lower=0> sigma;
}
model {
  y ~ normal(alpha + beta * x, sigma);
  alpha ~ normal(0, 10);
  beta ~ normal(0, 10);
  sigma ~ normal(0, 10);
}
generated quantities {
  array[N] real y_rep = normal_rng(alpha + beta * x, sigma);
}
```

In [15]:
def fully_pooled_model_log_density(parameters, data):
    alpha, beta, log_sigma = parameters
    x, y = data
    sigma = jnp.exp(log_sigma)
    yhat = alpha + beta * x
    log_lik = norm.logpdf(y, loc=yhat, scale=jnp.full_like(yhat, sigma)).sum()
    log_prior = jax.tree.map(
        lambda leaf: norm.logpdf(leaf, loc=0, scale=10),
        parameters
    )
    return log_lik + jax.tree.reduce(operator.add, log_prior)

example_params_fp = (-1, 0.4, 0.0)
example_data_fp = (
    mn_radon["floor"].cast(pl.Float32).to_numpy(),
    mn_radon["log_radon"].to_numpy()
)

fully_pooled_model_log_density(example_params_fp, example_data_fp)

Array(-3369.8276, dtype=float32)

In [16]:
fully_pooled_model_log_posterior = partial(
    fully_pooled_model_log_density, 
    data=example_data_fp
)
fully_pooled_model_log_posterior(example_params_fp)

Array(-3369.8276, dtype=float32)

In [11]:
def get_idata(states, info, coords=None, dims=None):
    if isinstance(states.position, dict):
        target = states.position
    elif isinstance(states.position, tuple):
        target = {
            k: v 
            for k, v in enumerate(states.position)
        }
    else:
        raise ValueError("Unexpected input")
    idata = az.convert_to_inference_data(
        target,
        group="posterior",
        coords=coords,
        dims=dims,
    )
    idata.add_groups({"sample_stats": info._asdict()})
    return idata

In [17]:
mcmc_key = jax.random.key(1234)

states, info = run_nuts(
    key=mcmc_key,
    log_posterior=fully_pooled_model_log_posterior,
    init_params=jax.tree.map(jnp.array, example_params_fp),
    init_sd=(0.1, 0.1, 0.1),
)
idata_fp = get_idata(states, info)
idata_fp

In [18]:
az.summary(idata_fp)

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
0,1.326,0.029,1.271,1.382,0.001,0.001,1537.0,1348.0,1.01
1,-0.611,0.072,-0.743,-0.48,0.002,0.001,1371.0,1645.0,1.0
2,-0.194,0.024,-0.239,-0.151,0.001,0.0,1856.0,1624.0,1.0


In [23]:
idata_fp.sample_stats["is_divergent"].sum().item()

0

In [20]:
jnp.exp(-0.194)

Array(0.8236579, dtype=float32, weak_type=True)

```
data {
  int<lower=1> N;  // observations
  int<lower=1> J;  // counties
  array[N] int<lower=1, upper=J> county;
  vector[N] x;     // floor
  vector[N] y;     // radon
}
parameters {
  vector[J] alpha;
  real beta;
  real<lower=0> sigma;
}
model {
  y ~ normal(alpha[county] + beta * x, sigma);  
  alpha ~ normal(0, 10);
  beta ~ normal(0, 10);
  sigma ~ normal(0, 10);
}
generated quantities {
  array[N] real y_rep = normal_rng(alpha[county] + beta * x, sigma);
}
```

In [27]:
def unpooled_model_log_density(parameters, data):
    alpha = parameters["alpha"]
    beta = parameters["beta"]
    log_sigma = parameters["log_sigma"]
    x, county_ix, y = data
    sigma = jnp.exp(log_sigma)
    yhat = alpha[county_ix] + beta * x
    log_lik = norm.logpdf(y, loc=yhat, scale=jnp.full_like(yhat, sigma)).sum()
    log_prior = jax.tree.map(
        lambda leaf: norm.logpdf(leaf, loc=0, scale=10).sum(),
        parameters
    )
    return log_lik + jax.tree.reduce(operator.add, log_prior)

N_county = mn_radon["county"].n_unique()
key_up = jax.random.key(1234)
example_params_up = {
    "alpha": -1 + jax.random.normal(key_up, shape=(N_county,)) * 0.5,
    "beta": 0.4,
    "log_sigma": 0.0,
}
example_data_up = (
    mn_radon["floor"].cast(pl.Float32).to_numpy(), 
    mn_radon["county_id"].to_numpy() - 1, 
    mn_radon["log_radon"].to_numpy()
)
unpooled_model_log_density(example_params_up, example_data_up)

Array(-3925.0999, dtype=float32)

In [33]:
unpooled_model_log_posterior = partial(
    unpooled_model_log_density, 
    data=example_data_up
)

mcmc_key_up = jax.random.key(1234)

states, info = run_nuts(
    key=mcmc_key_up,
    log_posterior=unpooled_model_log_posterior,
    init_params=jax.tree.map(jnp.array, example_params_up),
    init_sd=jax.tree.map(
        lambda l: jnp.full_like(l, 0.1), 
        jax.tree.map(jnp.array, example_params_up)
    ),
)
idata_up = get_idata(states, info)
az.summary(idata_up)

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
alpha[0],0.832,0.374,0.178,1.570,0.006,0.009,3624.0,1493.0,1.00
alpha[1],0.878,0.110,0.682,1.092,0.002,0.003,3960.0,1426.0,1.00
alpha[2],1.534,0.448,0.749,2.416,0.008,0.012,3140.0,1433.0,1.00
alpha[3],1.551,0.282,1.068,2.078,0.005,0.007,3086.0,1674.0,1.00
alpha[4],1.432,0.404,0.683,2.166,0.006,0.011,4171.0,1431.0,1.01
...,...,...,...,...,...,...,...,...,...
alpha[82],1.622,0.221,1.222,2.040,0.003,0.006,4553.0,1437.0,1.00
alpha[83],1.641,0.211,1.253,2.027,0.004,0.005,3319.0,1380.0,1.00
alpha[84],1.171,0.544,0.221,2.250,0.009,0.012,3616.0,1567.0,1.00
beta,-0.720,0.073,-0.855,-0.580,0.001,0.002,2641.0,1617.0,1.00


```stan
data {
  int<lower=1> N;  // observations
  int<lower=1> J;  // counties
  array[N] int<lower=1, upper=J> county;
  vector[N] x;
  vector[N] y;
}
parameters {
  real mu_alpha;
  real<lower=0> sigma_alpha;
  vector<offset=mu_alpha, multiplier=sigma_alpha>[J] alpha;  // non-centered parameterization
  real beta;
  real<lower=0> sigma;
}
model {
  y ~ normal(alpha[county] + beta * x, sigma);  
  alpha ~ normal(mu_alpha, sigma_alpha); // partial-pooling
  beta ~ normal(0, 10);
  sigma ~ normal(0, 10);
  mu_alpha ~ normal(0, 10);
  sigma_alpha ~ normal(0, 10);
}
generated quantities {
  array[N] real y_rep = normal_rng(alpha[county] + beta * x, sigma);
}

```

In [None]:
def partial_pooled_model_log_density(parameters, data):
    mu_alpha, log_sigma_alpha, alpha_z, beta, log_sigma = parameters
    prior_sd = (10.0, 10.0, 1.0, 10.0, 10.0)
    x, county_ix, y = data
    sigma = jnp.exp(log_sigma)
    sigma_alpha = jnp.exp(log_sigma_alpha)
    alpha = mu_alpha + sigma_alpha * alpha_z 
    yhat = alpha[county_ix] + beta * x
    log_lik = norm.logpdf(y, loc=yhat, scale=jnp.full_like(yhat, sigma)).sum()
    log_prior = jax.tree.map(
        lambda leaf, sd: norm.logpdf(leaf, loc=0, scale=sd).sum(),
        parameters,
        prior_sd
    )
    return log_lik + jax.tree.reduce(operator.add, log_prior)

key_pp = jax.random.key(12345)
example_params_pp = (
    -1, 0.2, jax.random.normal(key_pp, shape=(N,)), 0.4, 0.0
)
example_data_pp = (
    mn_radon["log_uranium"].to_numpy(), 
    mn_radon["county_id"].to_numpy() - 1, 
    mn_radon["log_radon"].to_numpy()
)
partial_pooled_model_log_density(example_params_pp, example_data_pp)