# Hierarchical modelling

Hierarchical structures are commonly found in both natural data and statistical models. These hierarchies can represent various levels of organization or grouping within the data, and incorporating them into Bayesian inference can provide more accurate and insightful results. Such approach to modelling allows to account for different sources of variation in the data.


There are typically three ways to account for hierarchies in Bayesian inference: no pooling, complete pooling, and partial pooling. Let's explore each of these approaches and provide Numpyro code examples for each case.

## No Pooling:

In the "no pooling" approach, each data point is treated independently without any grouping or hierarchical structure. This approach assumes that there is no shared information between data points, which can be overly simplistic when there is underlying structure or dependencies in the data.

In [17]:
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

from jax import random
import jax.numpy as jnp

rng_key = random.PRNGKey(678)

In [23]:
# Data
data = jnp.array([10, 12, 9, 11, 8]) # remember to turn data into a jnp array

# Model
def no_pooling_model(data):
    for i, obs in enumerate(data):
        mu_i = numpyro.sample(f"mu_{i}", dist.Normal(0, 10))
        sigma_i = numpyro.sample(f"sigma_{i}", dist.Exponential(1))
        numpyro.sample(f"obs_{i}", dist.Normal(mu_i, sigma_i), obs=data[i])

# Inference
nuts_kernel = NUTS(no_pooling_model)
mcmc = MCMC(nuts_kernel, num_samples=1000, num_warmup=500)
mcmc.run(rng_key, data)

# Note how many mu-s and sigma-s are estimated
mcmc.print_summary()

sample: 100%|██████████| 1500/1500 [00:03<00:00, 400.58it/s, 31 steps of size 1.02e-01. acc. prob=0.74] 



                mean       std    median      5.0%     95.0%     n_eff     r_hat
      mu_0      9.98      1.06     10.00      8.17     11.35    338.50      1.00
      mu_1     11.71      1.77     11.97      8.72     14.16    304.16      1.00
      mu_2      8.85      1.35      8.92      6.53     10.57    419.85      1.00
      mu_3     10.81      1.37     10.92      8.91     12.81    212.94      1.00
      mu_4      7.87      1.23      7.96      6.34      9.82    115.95      1.00
   sigma_0      0.94      0.82      0.67      0.08      2.02    259.38      1.00
   sigma_1      1.15      1.11      0.81      0.06      2.53    117.35      1.00
   sigma_2      1.10      0.99      0.85      0.07      2.29    242.58      1.01
   sigma_3      1.06      0.89      0.84      0.08      2.20    102.34      1.00
   sigma_4      1.01      0.91      0.80      0.05      2.12    147.39      1.00

Number of divergences: 120


## Complete Pooling:

In the "complete pooling" approach, all data points are treated as if they belong to a single group or population, and the model estimates a single set of parameters for the entire dataset. This approach assumes that there is no variation between data points, which can be overly restrictive when there is actual heterogeneity in the data.

In [24]:
# Model
def complete_pooling_model(data):
    mu = numpyro.sample("mu", dist.Normal(0, 10))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    obs = numpyro.sample("obs", dist.Normal(mu, sigma), obs=data)

# Inference
nuts_kernel = NUTS(complete_pooling_model)
mcmc = MCMC(nuts_kernel, num_samples=1000, num_warmup=500)
mcmc.run(rng_key, data)

# Note how many mu-s and sigma-s are estimated
mcmc.print_summary()

sample: 100%|██████████| 1500/1500 [00:01<00:00, 786.76it/s, 3 steps of size 6.05e-01. acc. prob=0.93] 



                mean       std    median      5.0%     95.0%     n_eff     r_hat
        mu      9.92      0.77      9.94      8.64     11.13    460.42      1.00
     sigma      1.65      0.53      1.55      0.93      2.37    388.83      1.00

Number of divergences: 0


## Partial Pooling:

In the "partial pooling" approach, the data is grouped into distinct categories or levels, and each group has its own set of parameters. However, these parameters are constrained by a shared distribution, allowing for both individual variation within groups and shared information across groups.

In [28]:
# Data with grouping information (e.g., groups A, B, C)
group_ids = [0, 0, 1, 1, 2]
data = jnp.array([10, 12, 9, 11, 8])

# Model
def partial_pooling_model(group_ids, data):

    num_groups = len(set(group_ids))
    with numpyro.plate("groups", num_groups):
        group_mu = numpyro.sample("group_mu", dist.Normal(0, 10))
        group_sigma = numpyro.sample("group_sigma", dist.Exponential(1))

    with numpyro.plate("data", len(data)):
        mu = numpyro.deterministic("mu", group_mu[group_ids])
        sigma = numpyro.deterministic("sigma", group_sigma[group_ids])
        obs = numpyro.sample("obs", dist.Normal(mu, sigma), obs=data)

# Inference
nuts_kernel = NUTS(partial_pooling_model)
mcmc = MCMC(nuts_kernel, num_samples=1000, num_warmup=500)
mcmc.run(rng_key, group_ids, data)


TypeError: Using a non-tuple sequence for multidimensional indexing is not allowed; use `arr[array(seq)]` instead of `arr[seq]`. See https://github.com/google/jax/issues/4564 for more information.

In [15]:
# Data with grouping information (e.g., groups A, B, C)
group_ids = [0, 0, 1, 1, 2]
data = jnp.array([10, 12, 9, 11, 8])

# Model
def partial_pooling_model(group_ids, data):

    num_groups = len(set(group_ids))
    with numpyro.plate("groups", num_groups):
        group_mu = numpyro.sample("group_mu", dist.Normal(0, 10))
        group_sigma = numpyro.sample("group_sigma", dist.Exponential(1))

    with numpyro.plate("data", len(data)):    
        mu = numpyro.sample("mu", dist.Normal(group_mu[group_ids], group_sigma[group_ids]))
        sigma = numpyro.sample("sigma", dist.Exponential(1))
        obs = numpyro.sample("obs", dist.Normal(mu, sigma), obs=data)



    mu = group_mu[group_ids]
    sigma = group_sigma[group_ids]
    obs = numpyro.sample("obs", dist.Normal(mu, sigma), obs=data)

# Inference
nuts_kernel = NUTS(partial_pooling_model)
mcmc = MCMC(nuts_kernel, num_samples=1000, num_warmup=500)
mcmc.run(rng_key, group_ids, data)

# Note how many mu-s and sigma-s are estimated
mcmc.print_summary()


TypeError: Using a non-tuple sequence for multidimensional indexing is not allowed; use `arr[array(seq)]` instead of `arr[seq]`. See https://github.com/google/jax/issues/4564 for more information.

In the partial pooling example, the `group_ids` variable indicates the group to which each data point belongs. This allows for the estimation of group-specific parameters while sharing information across groups through the shared distributions of `group_mu` and `group_sigma``.

These three approaches represent different ways to account for hierarchies in Bayesian inference, each with its own assumptions and implications for modeling real-world data. Depending on the specific context and data structure, one of these approaches may be more appropriate than the others.

In [None]:
def partial_pooling_model(group_ids, data):
    μ_α = numpyro.sample("μ_α", dist.Normal(0., 100.))
    σ_α = numpyro.sample("σ_α", dist.HalfNormal(100.))
    μ_β = numpyro.sample("μ_β", dist.Normal(0., 100.))
    σ_β = numpyro.sample("σ_β", dist.HalfNormal(100.))

    unique_patient_IDs = np.unique(PatientID)
    n_patients = len(unique_patient_IDs)

    with numpyro.plate("plate_i", n_patients):
        α = numpyro.sample("α", dist.Normal(μ_α, σ_α))
        β = numpyro.sample("β", dist.Normal(μ_β, σ_β))

    σ = numpyro.sample("σ", dist.HalfNormal(100.))
    FVC_est = α[PatientID] + β[PatientID] * Weeks

    with numpyro.plate("data", len(PatientID)):
        numpyro.sample("obs", dist.Normal(FVC_est, σ), obs=FVC_obs)

In [None]:
# Data with grouping information (e.g., groups A, B, C)
group_ids = [0, 0, 1, 1, 2]
data = jnp.array([10, 12, 9, 11, 8])

# Model
def partial_pooling_model(group_ids, data):

    num_groups = len(set(group_ids))
    num_data = len(data)

    with numpyro.plate("groups", num_groups):
        group_mu = numpyro.sample("group_mu", dist.Normal(0, 10))
        group_sigma = numpyro.sample("group_sigma", dist.Exponential(1))

In [30]:




    
    # Hyperparameters for group-level distributions
    group_mu = numpyro.sample("group_mu", dist.Normal(0, 10))
    group_sigma = numpyro.sample("group_sigma", dist.Exponential(1))
    
    # Individual parameters for each group
    with numpyro.plate("plate_group", num_groups):
        mu = numpyro.sample("mu", dist.Normal(group_mu, group_sigma))
        
    # Likelihood
    with numpyro.plate("plate_data", len(data)):
        numpyro.sample("obs", dist.Normal(mu[group_ids], 1), obs=data)

# Inference
nuts_kernel = NUTS(partial_pooling_model)
mcmc = MCMC(nuts_kernel, num_samples=1000, num_warmup=500)
mcmc.run(group_ids, data)


ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())