### Load Libraries

In [None]:
import os
import sys

import numpy as np
import geopandas as gpd
import jax
import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

import matplotlib.pyplot as plt

import time
import pickle

sys.path.append(os.path.pardir)
from aggGP import exp_sq_kernel
from aggVAE import vae_decoder

### Define necessary functions

In [None]:
def exp_sq_kernel(x, z, var, length, noise, jitter=1.0e-4):
    """Build a squared-exponential kernel matrix with an added diagonal term.

    Args:
        x: First set of input points; ``x.shape[0]`` fixes the size of the
            identity matrix added to the result.
        z: Second set of input points, handed to ``dist_euclid`` with ``x``.
        var: Multiplier applied to the exponential term.
        length: Lengthscale the distances are divided by.
        noise: Value added to the diagonal together with ``jitter``.
        jitter: Extra diagonal term (default ``1.0e-4``); presumably there for
            numerical stability.

    Returns:
        ``var * exp(-0.5 * (d / length) ** 2) + (noise + jitter) * I`` with
        ``d = dist_euclid(x, z)`` and ``I`` the identity of size ``x.shape[0]``.
    """
    # NOTE(review): dist_euclid is not defined in this file; presumably it
    # returns the pairwise Euclidean distance matrix -- confirm it is in scope.
    pairwise = dist_euclid(x, z)
    scaled_sq = jnp.power(pairwise / length, 2.0)
    # The identity is sized from x only, so the result is assumed square.
    nugget = (noise + jitter) * jnp.eye(x.shape[0])
    return var * jnp.exp(-0.5 * scaled_sq) + nugget

In [None]:
def vae_decoder(hidden_dim, out_dim):
    """Build the decoder network as a stax ``(init_fn, apply_fn)`` pair.

    The network is ``Dense(hidden_dim) -> Elu -> Dense(out_dim)``, with both
    dense layers using ``stax.randn()`` as weight initialiser.

    Args:
        hidden_dim: Width of the hidden dense layer.
        out_dim: Width of the output dense layer.

    Returns:
        The ``(init_fn, apply_fn)`` tuple produced by ``stax.serial``.
    """
    # Imported locally because this file never brings ``stax`` into scope;
    # without it the call below fails with NameError.
    from jax.example_libraries import stax

    return stax.serial(
        # (..., z_dim) -> (..., hidden_dim)
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Elu,
        # (..., hidden_dim) -> (..., out_dim)
        stax.Dense(out_dim, W_init=stax.randn()),
    )

### Disease Prevalence Modelling

In [None]:
def prev_model_vae_aggr(args):
    """Dengue prevalence model using a pre-trained VAE decoder as spatial term.

    A standard-normal latent vector is pushed through the decoder, the decoder
    output is aggregated with ``M``, added to an intercept and two covariate
    effects, passed through a sigmoid, and used as the success probability of
    a Binomial likelihood on the case counts.

    Sample sites: ``z``, ``b0``, ``b_pop_density``, ``b_hdi``,
    ``observed_cases``. Deterministic sites: ``vae_output``, ``vae_aggr``,
    ``theta``.

    Args:
        args: Mapping providing the following keys.
            ``pop_density``, ``hdi``: covariates of the linear predictor.
            ``M``: matrix applied to the decoder output (``M @ vae_output``).
            ``total_cases``: observed counts, used only when not predicting.
            ``total_population``: ``total_count`` of the Binomial.
            ``decoder_params``: trained decoder weights; the shape of
            ``decoder_params[0][0]`` is read as ``(latent_dim, hidden_dim)``.
            ``out_dims``: output width handed to ``vae_decoder``.
            ``predict``: when truthy, ``observed_cases`` is sampled instead of
            being conditioned on ``total_cases``.

    Returns:
        The value of the ``observed_cases`` sample site.
    """
    # Shape comments below follow the original annotations and are assumptions
    # about the caller's data -- TODO confirm.
    pop_density = args["pop_density"]  # presumably (num_districts,)
    hdi = args["hdi"]  # presumably (num_districts,)
    M = args["M"]  # presumably (num_districts, num_grid_points)
    total_cases = args["total_cases"]
    total_population = args["total_population"]
    decoder_params = args["decoder_params"]
    out_dims = args["out_dims"]
    predict = args["predict"]

    # Latent and hidden sizes come from the first dense layer's weight matrix.
    z_dim, h_dim = decoder_params[0][0].shape
    z = numpyro.sample("z", dist.Normal(jnp.zeros(z_dim), jnp.ones(z_dim)))

    # Only the apply function is needed: the weights are already trained.
    _, dec_apply_fn = vae_decoder(h_dim, out_dims)
    vae_output = numpyro.deterministic(
        "vae_output", dec_apply_fn(decoder_params, z)
    )

    # Aggregate the decoder output with M.
    vae_aggr = numpyro.deterministic("vae_aggr", M @ vae_output)

    # Fixed effects
    b0 = numpyro.sample("b0", dist.Normal(0, 1))  # Intercept
    b_pop_density = numpyro.sample("b_pop_density", dist.Normal(0, 1))
    b_hdi = numpyro.sample("b_hdi", dist.Normal(0, 1))

    # Linear predictor, mapped to a probability with a sigmoid
    lp = b0 + vae_aggr + b_pop_density * pop_density + b_hdi * hdi
    theta = numpyro.deterministic("theta", jax.nn.sigmoid(lp))

    # Binomial likelihood: obs=None leaves the site unobserved, so one call
    # covers both the fitting and the prediction case.
    obs = None if predict else total_cases
    return numpyro.sample(
        "observed_cases",
        dist.Binomial(total_count=total_population, probs=theta),
        obs=obs,
    )


### Load variables

In [None]:
# Lat/lon coordinates of the artificial grid
x = np.load("../data/processed/lat_lon_x_all.npy")

# Arrays for the combined regions, loaded in the same order as before
pol_pts_all, pt_which_pol_all = (
    np.load(npy_path)
    for npy_path in (
        "../data/processed/pol_pts_all.npy",
        "../data/processed/pt_which_pol_all.npy",
    )
)

# Shapefile of the combined divisions, read into a GeoDataFrame
df_combined = gpd.read_file(
    "../data/processed/final_combined_divisions/final_combined_divisions.shp"
)

### Load decoder model

In [None]:
# Read the pickled VAE weights from disk.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open("../model_weights/aggVAE_Dec_e20_h53_z48", "rb") as file:
    vae_params = pickle.load(file)

# Split the stored dict into its encoder and decoder parts.
encoder_params, decoder_params = (
    vae_params["encoder$params"],
    vae_params["decoder$params"],
)

In [None]:
# Single dict handed to the model.
# NOTE(review): the original read from an undefined name `df`; df_combined is
# the only table loaded in this file and is assumed to carry the Pop_density,
# HDI, Cases and Population columns -- confirm.
args = {
    "x": jnp.array(x),  # Spatial grid points
    "pop_density": jnp.array(df_combined.Pop_density),  # Population density per district
    "hdi": jnp.array(df_combined.HDI),  # HDI per district
    "M": jnp.array(pol_pts_all),  # Aggregation matrix for district-level prevalence
    "total_cases": jnp.array(df_combined.Cases),  # Observed dengue cases per district
    "total_population": jnp.array(df_combined.Population),  # Population tested per district

    # VAE training
    # NOTE(review): prev_model_vae_aggr takes its layer sizes from
    # decoder_params, not from hidden_dim / z_dim below.
    "rng_key": random.PRNGKey(5),
    "num_epochs": 20,
    "learning_rate": 0.0005,
    "batch_size": 100,
    "hidden_dim": 6,
    "z_dim": 3,
    "num_train": 100,
    "num_test": 100,
    "vae_var": 1,

    # NN weights
    "decoder_params": decoder_params,
    "out_dims": df_combined.shape[0],  # trailing comma was missing (SyntaxError)

    # Switches the likelihood between conditioning on total_cases (False)
    # and sampling observed_cases (True); avoids passing NaN observations.
    "predict": False,

    # Set to True only if you want to see VAE Aggr GP results before running MCMC
    "check_vae_samples": True,
}

### MCMC training

In [None]:
# PRNG keys: one for running MCMC, one reserved for prediction
run_key, predict_key = random.split(random.PRNGKey(3), 2)

# MCMC settings: warm-up iterations, retained samples, number of chains
n_warm, n_samples, n_chains = 1000, 2000, 4

In [None]:
# Get script location and define correct save directory (sibling to src/)
script_dir = os.getcwd()  # Get current working directory
save_dir = os.path.abspath(os.path.join(script_dir, "..", "model_weights"))  # Move up and into model_weights

# Ensure the directory exists
os.makedirs(save_dir, exist_ok=True)

# Run MCMC for each chain separately to prevent total loss on crash
chain_keys = random.split(run_key, n_chains)  # Precompute keys

for chain_id in range(n_chains):
    print(f"\nRunning Chain {chain_id + 1}/{n_chains}...")

    # Generate a separate key for each chain
    chain_run_key = chain_keys[chain_id]

    # Initialize MCMC with controlled step size
    mcmc = MCMC(NUTS(prev_model_vae_aggr),
        num_warmup=n_warm,
        num_samples=n_samples,
        num_chains=1)

    # Run the chain
    start = time.time()
    mcmc.run(chain_run_key, args)  # Ensure args is a tuple (args,)
    end = time.time()
    t_elapsed_min = round((end - start) / 60)

    # 🔹 Save after each chain completes
    f_path = os.path.join(save_dir, f"aggVAEPrev_chain{chain_id}_nsamples_{n_samples}_tt{t_elapsed_min}min_logit.pkl")
    with open(f_path, "wb") as file:
        dill.dump(mcmc, file)

    print(f"Saved Chain {chain_id + 1} to {f_path}")
    print(f"Time taken: {t_elapsed_min} min\n")

In [None]:
# 🔹 Print total elapsed time
total_end = time.time()
print("\nMCMC Total elapsed time:", round(total_end), "s")
print("MCMC Total elapsed time:", round(total_end / 60), "min")
print("MCMC Total elapsed time:", round(total_end / (60 * 60)), "h")