In [22]:
import pandas as pd
import numpy as np
import constants as c
import os
import numpyro
import numpyro.distributions as dist
import jax.numpy as jnp

from jax import random
from numpyro.infer import MCMC, NUTS

In [23]:
df = pd.read_parquet(os.path.join(c.DATA_DIR, 'brvehins', 'brvehins_clean.parquet'))
df['log_exposure'] = np.log10(df['ExposTotal'])

n_cities = df['city_idx'].nunique()

In [27]:
def model(city_idx, log_exposure, claim_nb_coll):
    # Hyperparameters for the city random effect
    mu_b = numpyro.sample('mu_b', dist.Normal(0, 1))
    sigma_b = numpyro.sample('sigma_b', dist.HalfCauchy(1))

    # Random effect for each city
    with numpyro.plate('city', n_cities):
        b = numpyro.sample('b', dist.Normal(mu_b, sigma_b))

    # Global offset with single free parameter to be added to log mean
    log_offset = numpyro.sample('log_offset', dist.Normal(0, 1))

    # Model for the log-mean
    log_mean = log_exposure + b[city_idx] + log_offset

    # Poisson likelihood
    with numpyro.plate('data', len(city_idx)):
        numpyro.sample('obs', dist.Poisson(jnp.exp(log_mean)), obs=claim_nb_coll)

# Run NUTS
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)
nuts_kernel = NUTS(model)

mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=500, num_chains=1)

# Clocks at 12s/it at 1024 steps, so ~10 ms per stepb
mcmc.run(rng_key_, city_idx=df['city_idx'].values, log_exposure=df['log_exposure'].values, claim_nb_coll=df['ClaimNbColl'].values)


warmup:  15%|█▌        | 154/1000 [28:41<2:37:35, 11.18s/it, 1023 steps of size 4.29e-04. acc. prob=0.75]


KeyboardInterrupt: 

In [25]:
df['ExposTotal'].describe()

count    2.658372e+06
mean     2.363132e+00
std      1.955782e+01
min      2.739726e-03
25%      4.200000e-01
50%      5.000000e-01
75%      1.030000e+00
max      2.963260e+03
Name: ExposTotal, dtype: float64