In [None]:
import pickle
import os

import numpy as np
import geopandas as gpd
import jax
import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

import numpy as np
from aggGP import exp_sq_kernel
from aggVAE import vae_decoder

import time

In [None]:
def prev_model_vae_aggr(args):
    """NumPyro model: binomial regression with a VAE-decoded random effect.

    The linear predictor is the sum of an intercept, two covariate effects
    and a scaled output of a pre-trained VAE decoder. It is mapped to a
    probability with the inverse-logit and used as the success probability
    of a binomial likelihood on the observed counts.

    Parameters
    ----------
    args : dict
        Keys read by this model:

        - ``"out_dims"``: output size handed to ``vae_decoder``.
        - ``"pop_density"``, ``"hdi"``: covariates in the linear predictor.
        - ``"total_cases"``: observed counts (``obs`` of the likelihood).
        - ``"total_population"``: binomial ``total_count``.
        - ``"decoder_params"``: decoder weights; the shape of
          ``decoder_params[0][0]`` is unpacked as ``(z_dim, hidden_dim)``.

    Sites
    -----
    Sampled: ``z``, ``sigma``, ``b0``, ``b_pop_density``, ``b_hdi``,
    ``observed_cases``. Deterministic: ``vae_aggr``, ``vae``, ``theta``.
    """
    out_dims = args["out_dims"]
    pop_density = args["pop_density"]
    hdi = args["hdi"]
    total_cases = args["total_cases"]
    total_population = args["total_population"]

    # --- Random effect: decode a standard-normal latent with the trained decoder.
    decoder_params = args["decoder_params"]
    # Latent and hidden sizes are recovered from the first weight matrix,
    # e.g. (3, 6) when the decoder was trained with z_dim=3, hidden_dim=6.
    z_dim, hidden_dim = decoder_params[0][0].shape
    z = numpyro.sample("z", dist.Normal(jnp.zeros(z_dim), jnp.ones(z_dim)))
    # Only the apply function is needed; the weights come from decoder_params.
    _, decoder_apply = vae_decoder(hidden_dim, out_dims)
    vae_aggr = numpyro.deterministic("vae_aggr", decoder_apply(decoder_params, z))
    # Scale of the random effect.
    s = numpyro.sample("sigma", dist.HalfNormal(50))
    vae = numpyro.deterministic("vae", s * vae_aggr)

    # --- Fixed effects.
    b0 = numpyro.sample("b0", dist.Normal(0, 1))  # Intercept
    b_pop_density = numpyro.sample("b_pop_density", dist.Normal(0, 1))  # Effect of population density
    b_hdi = numpyro.sample("b_hdi", dist.Normal(0, 1))  # Effect of HDI

    # --- Linear predictor (on the logit scale).
    lp = b0 + vae + b_pop_density * pop_density + b_hdi * hdi

    # Map the linear predictor to a probability in (0, 1). Without this step
    # `theta` is undefined and the likelihood below cannot be built.
    # NOTE(review): a logit link is assumed here; confirm it is the intended one.
    theta = numpyro.deterministic("theta", jax.nn.sigmoid(lp))

    # --- Binomial likelihood on the observed counts.
    numpyro.sample(
        "observed_cases",
        dist.Binomial(total_count=total_population, probs=theta),
        obs=total_cases,
    )

## Load variables

In [None]:
# Lat/lon values of the artificial grid (passed to the model arguments as "x").
x = np.load("lat_lon_x_all.npy")

# Combined regional data.
# NOTE(review): the names suggest per-polygon points and a point-to-polygon
# index, but the array contents are not visible here; confirm.
pol_pts_all = np.load("pol_pts_all.npy")
pt_which_pol_all = np.load("pt_which_pol_all.npy")

# Read the combined divisions shapefile. Its "Cases", "Population", "HDI" and
# "Pop_density" columns are used to build the model arguments.
df_combined = gpd.read_file("final_combined_divisions.shp")

## Variables that need to be changed (TODO: confirm which)

In [None]:
# Alias of the loaded polygon-point array; stored under the "M" key of the
# model arguments.
M = pol_pts_all
# Number of rows in df_combined; stored under the "out_dims" key of the model
# arguments.
out_dims = len(df_combined)

## Arguments to model

In [None]:
# Observations and covariates: each entry is one column of df_combined
# converted to a JAX array.
args = {
    arg_name: jnp.array(df_combined[column])
    for arg_name, column in [
        ("total_cases", "Cases"),
        ("total_population", "Population"),
        ("hdi", "HDI"),
        ("pop_density", "Pop_density"),
    ]
}

# Remaining settings, appended in the same order as before.
args.update(
    {
        # Grid locations and GP settings
        "x": jnp.array(x),
        "gp_kernel": exp_sq_kernel,
        "jitter": 1e-4,
        "noise": 1e-4,
        "M": M,
        # VAE training
        "rng_key": random.PRNGKey(5),
        "num_epochs": 20,
        "learning_rate": 0.0005,
        "batch_size": 100,
        "hidden_dim": 6,
        "z_dim": 3,
        "out_dims": out_dims,
        "num_train": 100,
        "num_test": 100,
        "vae_var": 1,
    }
)

## Load decoder model

In [None]:
# Load the trained VAE weights. Change the file name to the specific weights
# file under the model_weights folder.
# SECURITY NOTE: pickle.load can execute arbitrary code while unpickling;
# only open weight files that come from a trusted source.
with open("model_weights/aggVAE", "rb") as file:
        vae_params = pickle.load(file)

# The pickled object is indexed by "encoder$params" and "decoder$params".
# Only the decoder weights are handed to the model arguments below.
encoder_params = vae_params["encoder$params"]
decoder_params = vae_params["decoder$params"]
args["decoder_params"] = decoder_params

## Run MCMC (TODO: ask the professor about the number of warmup steps and samples)

In [None]:
# Split a fixed seed into two keys; only mcmc_key is consumed in this cell.
mcmc_key, predict_key = random.split(random.PRNGKey(0))

start_time = time.time()
mcmc = MCMC(
    NUTS(prev_model_vae_aggr),
    num_warmup=200,
    num_samples=1000,
)

# The model takes a single `args` dict and reads the observed counts from
# args["total_cases"], so no separate observation array is passed here.
mcmc.run(mcmc_key, args)

# Wall-clock duration of the run, in seconds and in whole minutes.
t_elapsed = time.time() - start_time
t_elapsed_mins = int(t_elapsed / 60)

# Include deterministic sites in the posterior summary.
mcmc.print_summary(exclude_deterministic=False)