
### Load the necessary libraries


In [None]:
import os

import dill
import numpy as np
import jax.numpy as jnp
import pandas as pd
import geopandas as gpd

import jax
from jax import random

import numpyro
from numpyro.infer import Predictive
import numpyro.distributions as dist

import arviz as az
import plotly.express as px
import plotly.graph_objs as go
from plotly.subplots import make_subplots

import matplotlib.pyplot as plt

import sys
sys.path.append(os.path.pardir)

# Prefer the second visible device when there is one; otherwise fall back to
# the first so single-device machines do not fail with an IndexError.
devices = jax.devices()
default_device = devices[1] if len(devices) > 1 else devices[0]
jax.config.update("jax_default_device", default_device)
# Lists every device JAX can see (not only the selected default).
print(f"Jax using device : {jax.devices()}")

ModuleNotFoundError: No module named 'numpyro'

### Load the necessary variables

In [None]:
# All paths below are relative to the current working directory.
# Lat/Lon values of the artificial grid
x = np.load("../data/processed/lat_lon_x_all.npy")
# Low regional data
pol_pts_all = np.load("../data/processed/pol_pt_lo.npy")
pt_which_pol_all = np.load("../data/processed/pt_which_pol_all.npy")
# Dataframes (shapefile read into a GeoDataFrame)
df = gpd.read_file("../data/processed/final_combined_divisions/final_combined_divisions.shp")

In [None]:
# Load the saved model parameters: one serialized MCMC object per chain.
# The files live in "model_weights", one level above the current working
# directory.
script_dir = os.getcwd()
save_dir = os.path.abspath(os.path.join(script_dir, "..", "model_weights"))

# Fail early with a clear message if the chains were never saved.
if not os.path.exists(save_dir):
    raise FileNotFoundError(f"Directory '{save_dir}' does not exist. Ensure the chains were saved properly.")

n_chains = 3  # Chains are saved with ids 0 .. n_chains - 1
n_samples = 2000  # Adjust based on your settings

# List the directory once instead of once per chain.
saved_files = os.listdir(save_dir)

# Load MCMC objects
mcmc_list = []
for chain_id in range(n_chains):
    prefix = f"aggVAEPrev_chain{chain_id}_nsamples_{n_samples}_tt"
    # Descending lexicographic order; this picks the most recent run only if
    # the file-name suffix sorts chronologically.
    matching_files = sorted(
        (f for f in saved_files if f.startswith(prefix)),
        reverse=True,
    )

    if not matching_files:
        print(f"Warning: Missing Chain {chain_id} file!")
        continue

    file_path = os.path.join(save_dir, matching_files[0])
    # CAUTION: dill.load can execute arbitrary code while unpickling; only
    # load files that were produced by a trusted run.
    with open(file_path, "rb") as file:
        mcmc = dill.load(file)
    mcmc_list.append(mcmc)
    print(f"Loaded Chain {chain_id} from {file_path}")

# Ensure all chains were loaded
if len(mcmc_list) != n_chains:
    raise ValueError(f"Not all chains were loaded successfully! Loaded {len(mcmc_list)}/{n_chains} chains.")

# Extract samples from the MCMC objects, keeping the chain axis.
# (`mcmc` stays bound to the last loaded chain after the loop above; the
# comprehension variable does not leak.)
extracted_samples = [mcmc.get_samples(group_by_chain=True) for mcmc in mcmc_list]

### Check diagnostics

In [None]:
# Metric values (ESS and R-hat).
# Stack the per-chain sample dicts along the leading (chain) axis so the
# diagnostics cover every loaded chain. Using `mcmc.get_samples()` here would
# only see the last chain loaded, and without a chain axis.
combined_samples = {
    name: jnp.concatenate(
        [chain_samples[name] for chain_samples in extracted_samples], axis=0
    )
    for name in extracted_samples[0]
}

# Compute ESS and R-hat diagnostics; the samples carry an explicit chain axis.
ss = numpyro.diagnostics.summary(combined_samples, group_by_chain=True)

# Compute and print diagnostics
r = np.mean(ss["vae_aggr"]["n_eff"])
print(f"Average ESS for all aggVAE effects : {round(r)}")
print(f"Max r_hat for all aggVAE effects : {round(np.max(ss['vae_aggr']['r_hat']),2)}")
print(f"kernel_length R-hat : {round(ss['kernel_length']['r_hat'], 2)}")
print(f"kernel_var R-hat : {round(ss['kernel_var']['r_hat'],2)}")

In [None]:
# Plots
# Convert the samples to ArviZ InferenceData format, stored as the "posterior" group.
# NOTE(review): ArviZ presumably reads the two leading axes as (chain, draw) --
# confirm that combined_samples is grouped by chain before trusting the plots.
idata = az.from_dict(posterior=combined_samples)

In [None]:
# Trace plots (check mixing): one figure per kernel hyperparameter
for kernel_param in ("kernel_length", "kernel_var"):
    az.plot_trace(idata, var_names=kernel_param)

In [None]:
# Rank plot for the variables in idata (used here to check mixing across chains)
az.plot_rank(idata)

In [None]:
# R-hat convergence diagnostic for the variables in idata
print(az.rhat(idata))

In [None]:
# Effective sample size (ESS) for the variables in idata; ideally > 1000
print(az.ess(idata))

In [None]:
# Keep a handle on the posterior group of the InferenceData object
pos_samples = idata.posterior

# Print the MCMC summary table for the parameters of interest
summary_vars = ["vae_aggr", "kernel_length", "kernel_var"]
print(az.summary(idata, var_names=summary_vars))

### Extract the prevalence estimates from the posterior and combine them with the observed prevalence data in the existing df

In [None]:
# Extract the theta estimate (dengue prevalence probability) from the
# posterior predictive distribution.
# NOTE(review): `args`, `prev_model_vae_aggr` and `prev_samples` are not
# defined in the cells above -- confirm they are created before this cell runs.
args["predict"] = True
prev_posterior_predictive_vae = Predictive(prev_model_vae_aggr, prev_samples)(random.PRNGKey(1), args)

# Summarise the draws (axis 0) of theta.
theta_samps_vae_aggr = prev_posterior_predictive_vae["theta"]
theta_mean_vae_aggr = theta_samps_vae_aggr.mean(axis=0)
# 25% / 75% quantiles across draws, i.e. the bounds of the central 50% interval.
bci_vae_aggr_25 = np.quantile(theta_samps_vae_aggr, 0.25, axis=0)
bci_vae_aggr_75 = np.quantile(theta_samps_vae_aggr, 0.75, axis=0)

# Keep only the leading entries that line up with the rows of df, and use the
# same truncated vector for the limits below. Pairing the untruncated mean
# with df["prev"] would build a ragged list whenever the lengths differ.
n_regions = df.shape[0]
theta_vae_aggr = theta_mean_vae_aggr[:n_regions]
df["theta_vae_aggr"] = theta_vae_aggr

theta_observed = df["prev"]

# Common limits over observed and estimated prevalence, for a shared
# colour scale / axis range.
_max = np.max([theta_observed, theta_vae_aggr])
_min = np.min([theta_observed, theta_vae_aggr])

In [None]:
# Show the first rows of df to inspect its columns
df.head()

In [None]:
# Plot the observed prevalence next to the aggVAE-estimated prevalence.
# Both maps share vmin/vmax so their colours are directly comparable.
fig, ax = plt.subplots(1, 2, figsize=(12, 6))

# Left: observed values ("prev"); right: posterior mean ("theta_vae_aggr").
# Each panel's column now matches its title.
df.plot(column="prev", ax=ax[0], legend=True, cmap="plasma", vmin=_min, vmax=_max)
df.plot(column="theta_vae_aggr", ax=ax[1], legend=True, cmap="plasma", vmin=_min, vmax=_max)

ax[0].set_title("Observed Dengue Prevalence")
ax[1].set_title("Aggregated VAE-Estimated Dengue Prevalence")

In [None]:
# Scatterplot of observed vs. estimated prevalence
fig, ax = plt.subplots(1, figsize=(10, 5))

observed, estimated = df["prev"], df["theta_vae_aggr"]
ax.scatter(observed, estimated)

# Same padded range on both axes
axis_limits = (_min - 0.02, _max + 0.02)
ax.set_ylim(*axis_limits)
ax.set_xlim(*axis_limits)

# Dashed reference line with slope 1 through (1, 1)
ax.axline((1, 1), slope=1, ls="--", c=".3")

ax.set_xlabel("Observed prevalence")
ax.set_ylabel("Estimated prevalence")
ax.set_title("Observations using aggVAE Priors")