In [1]:
import numpy as np
import pandas as pd
import numpyro as pyro
from numpyro import distributions as dist

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def model(mun_code, vote, yes_obs=None):
    μ_α = pyro.sample("μ_α", dist.Normal(0.0, 1.0))
    σ_α = pyro.sample("σ_α", dist.HalfNormal(1.0))
    μ_β = pyro.sample("μ_β", dist.Normal(0.0, 1.0))
    σ_β = pyro.sample("σ_β", dist.HalfNormal(1.0))

    n_muns = len(np.unique(mun_code))

    with pyro.plate("patient_i", n_muns):
        α = pyro.sample("α", dist.Normal(μ_α, σ_α))
        β = pyro.sample("β", dist.Normal(μ_β, σ_β))

    yes_est = α[mun_code] + β[mun_code] * vote
    with pyro.plate("data", len(mun_code)):
        pyro.sample("obs", dist.Normal(yes_est, 1.0), obs=yes_obs)

In [3]:
df = pd.read_csv("../data/processed/swissvotes_votes.csv")
df

Unnamed: 0,region_id,region_name,votes_total,votes_yes,vote_id,canton_id
0,1.0,Aeugst am Albis,90.0,33.333333,138.0,1.0
1,2.0,Affoltern am Albis,620.0,55.322581,138.0,1.0
2,3.0,Bonstetten,181.0,49.171271,138.0,1.0
3,4.0,Hausen am Albis,263.0,41.444867,138.0,1.0
4,5.0,Hedingen,188.0,50.531915,138.0,1.0
...,...,...,...,...,...,...
1408408,6806.0,Vendlincourt,264.0,40.151515,661.0,26.0
1408409,6807.0,Basse-Allaine,470.0,44.680851,661.0,26.0
1408410,6808.0,Clos du Doubs,549.0,37.704918,661.0,26.0
1408411,6809.0,Haute-Ajoie,524.0,41.030534,661.0,26.0


In [4]:
from sklearn.preprocessing import LabelEncoder

df.dropna(inplace=True)

mun_encoder = LabelEncoder()
mun_code = mun_encoder.fit_transform(df["region_id"].values)
yes_obs = df["votes_yes"].values / 100
votes = df["vote_id"].values - df["vote_id"].min()

In [5]:
from numpyro.infer import MCMC, NUTS, Predictive
from jax import random
rng_key = random.PRNGKey(0)

nuts_kernel = NUTS(model)

mcmc = MCMC(nuts_kernel, num_samples=2000, num_warmup=2000)
mcmc.run(rng_key, mun_code, votes, yes_obs=yes_obs)

posterior_samples = mcmc.get_samples()

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
warmup:   2%|▏         | 71/4000 [06:41<6:10:32,  5.66s/it, 1023 steps of size 1.64e-05. acc. prob=0.71] 


KeyboardInterrupt: 