In [1]:
import os
import torch

import numpy as np
import pandas as pd
import geopandas as gpd

import time

import jax
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import NUTS, MCMC
from numpyro.infer import SVI, Trace_ELBO, Predictive
import numpyro.diagnostics

from termcolor import colored

import dill
import pickle
import arviz as az

In [2]:
# Restrict this process to GPU 0. The variable is set before any CUDA work is
# done in this notebook so that it can still take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# torch is only used here to report which device is available; the model
# itself is written in JAX/NumPyro.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Script A running on {device}")

Script A running on cuda


# GP Kernel Function

In [3]:
def dist_euclid(x, z):
    """
    Compute pairwise Euclidean distances between two sets of points.

    Used by ``exp_sq_kernel`` to build the Gaussian-process covariance matrix.

    Args:
        x: array-like of shape (n_x, m) or (n_x,). A 1-D input is treated as
           n_x points in one dimension. In this notebook x is the lat/lon
           grid, e.g. (7304, 2).
        z: array-like of shape (n_z, m) or (n_z,), same convention as ``x``.

    Returns:
        jnp.ndarray of shape (n_x, n_z) whose entry [i, j] is ||x_i - z_j||_2.

    Raises:
        AssertionError: if x and z do not have the same number of columns.
    """
    x = jnp.array(x)
    z = jnp.array(z)
    if x.ndim == 1:
        x = x.reshape(x.shape[0], 1)
    if z.ndim == 1:
        # Bug fix: this branch used to reshape ``x``, silently replacing z with x.
        z = z.reshape(z.shape[0], 1)
    n_x, m = x.shape
    n_z, m_z = z.shape
    assert m == m_z, f"x has {m} columns but z has {m_z}"

    # Accumulate squared differences one coordinate at a time. This keeps peak
    # memory at O(n_x * n_z) instead of materialising an (n_x, n_z, m) tensor,
    # which matters for the 7304 x 7304 grid used here.
    delta = jnp.zeros((n_x, n_z))
    for d in range(m):
        delta = delta + (x[:, d][:, jnp.newaxis] - z[:, d]) ** 2

    return jnp.sqrt(delta)

In [4]:
def exp_sq_kernel(x, z, var, length, noise, jitter=1.0e-4):
    """
    Squared-exponential (RBF) covariance matrix between point sets x and z.

    k[i, j] = var * exp(-0.5 * (||x_i - z_j|| / length) ** 2),
    with (noise + jitter) added to the diagonal.

    Args:
        x, z: point sets passed to ``dist_euclid``; ``x`` must expose ``.shape``.
        var: kernel variance (amplitude).
        length: kernel length-scale, in the same units as the coordinates.
        noise: diagonal noise term.
        jitter: small extra diagonal term, default 1e-4.

    Returns:
        (n_x, n_z) covariance matrix; in this notebook (7304, 7304).

    NOTE(review): the diagonal term is ``jnp.eye(x.shape[0])``, so the result is
    only meaningful when z is x (which is how the model calls it). A genuine
    cross-covariance call with different point sets would add the diagonal
    incorrectly or fail to broadcast.
    """
    pairwise = dist_euclid(x, z)
    scaled_sq = jnp.power(pairwise / length, 2.0)
    cov = var * jnp.exp(-0.5 * scaled_sq)
    diag_term = (noise + jitter) * jnp.eye(x.shape[0])
    return cov + diag_term


# Aggregation Functions

In [5]:
def M_g(M, g):
    '''
    Aggregate grid-level GP draws to polygon level.

    - $M$ is a matrix with binary entries $m_{ij}$, indicating whether grid point $j$ lies in polygon $i$
    - $g$ holds GP draws over the grid
    - the product $M g$ (``jnp.matmul``) is therefore a vector of per-polygon sums

    ``g`` is transposed before the product. A 1-D ``g`` is unaffected; a 2-D
    ``g`` of shape (k, n_grid) yields a result of shape (n_polygons, k).
    '''
    g_t = jnp.array(g).T
    return jnp.array(M) @ g_t

# Aggregated Prevalence Model (the old implementation produced all theta = 1; this version includes HDI and population density as covariates)

In [6]:
def _standardize(v):
    """Return (v - mean(v)) / std(v); a constant covariate (std == 0) is only centred instead of producing NaN."""
    v_std = jnp.std(v)
    safe_std = jnp.where(v_std > 0, v_std, 1.0)
    return (v - jnp.mean(v)) / safe_std


def prev_model_gp_aggr(args):
    """Dengue prevalence model with an aggregated Gaussian Process.

    A GP is placed on the grid points ``args["x"]`` and aggregated to areal
    units with the matrix ``args["M"]``. It is combined with standardized
    population-density and HDI covariates in a linear predictor, and linked to
    observed case counts through a Binomial likelihood.

    Args:
        args: dict with keys
            "x"                   grid coordinates, (num_grid_points, 2)
            "gp_kernel"           callable(x, z, var, length, noise, jitter) -> covariance
            "noise", "jitter"     diagonal terms forwarded to ``gp_kernel``
            "kernel_length"       prior distribution for the kernel length-scale
            "kernel_var"          prior distribution for the kernel variance
            "M"                   aggregation matrix, (num_units, num_grid_points)
            "pop_density", "hdi"  covariates, one value per row of ``M``
            "total_cases"         observed case counts per unit
            "total_population"    Binomial trial counts per unit
            "theta_scale"         optional, default 1e-3; theta = sigmoid(lp) * theta_scale,
                                  so theta_scale is the upper bound of the prevalence

    Returns:
        The value of the ``observed_cases`` sample site.
    """
    x = args["x"]  # Spatial grid points: (num_grid_points, 2)
    gp_kernel = args["gp_kernel"]  # Gaussian Process kernel
    noise = args["noise"]
    jitter = args["jitter"]

    M = args["M"]  # (num_units, num_grid_points) aggregation matrix
    total_cases = args["total_cases"]
    total_population = args["total_population"]

    # Covariates are fixed data: standardize them to zero mean / unit std.
    pop_density = _standardize(args["pop_density"])
    hdi = _standardize(args["hdi"])

    # Upper bound on prevalence. The historical value 1e-3 is kept as default,
    # but some rates in the displayed df_combined sample already exceed it
    # (16764 cases / 14830092 population ~ 1.13e-3), which theta can never reach.
    theta_scale = args.get("theta_scale", 1e-3)

    # GP hyperparameters
    kernel_length = numpyro.sample("kernel_length", args["kernel_length"])
    kernel_var = numpyro.sample("kernel_var", args["kernel_var"])

    # GP over the grid.
    # NOTE(review): a dense MultivariateNormal over every grid point requires a
    # Cholesky factorisation of the full kernel at each gradient evaluation,
    # which likely dominates NUTS runtime; a non-centred parameterisation or a
    # low-rank approximation may be worth considering.
    k = gp_kernel(x, x, kernel_var, kernel_length, noise, jitter)
    f = numpyro.sample("f", dist.MultivariateNormal(loc=jnp.zeros(x.shape[0]), covariance_matrix=k))  # (num_grid_points,)

    # Aggregate GP values to unit level. With a binary membership matrix this is
    # a per-unit SUM over grid points, not a mean - TODO confirm this is intended.
    gp_aggr = numpyro.deterministic("gp_aggr", M @ f)  # (num_units,)

    # Fixed effects
    b0 = numpyro.sample("b0", dist.Normal(0, 1))  # Intercept
    b_pop_density = numpyro.sample("b_pop_density", dist.Normal(0, 1))  # Effect of population density
    b_hdi = numpyro.sample("b_hdi", dist.Normal(0, 1))  # Effect of HDI

    # Linear predictor
    lp = b0 + gp_aggr + b_pop_density * pop_density + b_hdi * hdi  # (num_units,)

    # Prevalence probability, bounded above by theta_scale
    theta = numpyro.deterministic("theta", jax.nn.sigmoid(lp) * theta_scale)  # (num_units,)

    # Binomial likelihood
    observed_cases = numpyro.sample(
        "observed_cases",
        dist.Binomial(total_count=total_population, probs=theta),
        obs=total_cases,
    )

    return observed_cases

# Load Data

In [7]:
# --- Inputs produced by the preprocessing step (paths relative to this notebook) ---

# Lat/lon coordinates of the artificial grid points
x = np.load("../data/processed/lat_lon_x_all.npy")

# Polygon / grid-point membership arrays for all regions (pol_pts_all is used
# as the aggregation matrix M below; pt_which_pol_all presumably maps each grid
# point to its polygon - confirm against the preprocessing code).
# NOTE(review): allow_pickle=True can execute arbitrary code while loading;
# only use it on trusted, locally generated files.
pol_pts_all = np.load("../data/processed/pol_pts_all.npy", allow_pickle=True)
pt_which_pol_all = np.load("../data/processed/pt_which_pol_all.npy")

# Combined district-year table: covariates, cases, population and geometry
df_combined = gpd.read_file("../data/processed/final_combined_divisions/final_combined_divisions.shp")

In [8]:
# Inspect the columns and first rows of the combined table
df_combined.head(n=5)

Unnamed: 0,District,x,y,Year,Area_sq_km,HDI,Cases,Population,Pop_den,geometry
0,BANDUNG,107.610841,-7.099969,2020,1767.96,72.39,9180,14495160,8198.805403,"POLYGON ((107.73309 -6.814, 107.73354 -6.81427..."
1,BANDUNG,107.610841,-7.099969,2021,1767.96,72.73,8008,14662620,8293.52474,"POLYGON ((107.73309 -6.814, 107.73354 -6.81427..."
2,BANDUNG,107.610841,-7.099969,2022,1767.96,73.16,16764,14830092,8388.250865,"POLYGON ((107.73309 -6.814, 107.73354 -6.81427..."
3,BANDUNG,107.610841,-7.099969,2023,1767.96,73.74,4020,14997564,8482.97699,"POLYGON ((107.73309 -6.814, 107.73354 -6.81427..."
4,BANDUNG BARAT,107.414953,-6.897056,2020,1305.77,68.08,3864,7153344,5478.257273,"POLYGON ((107.40945 -6.68851, 107.40986 -6.688..."


# Variables to change (adjust them to match the aggregated prevalence model's parameters)

In [9]:
# Model inputs as JAX arrays, one entry per row of df_combined
M = jnp.array(pol_pts_all)  # aggregation matrix: (rows of df_combined, grid points)
pop_density, hdi = jnp.array(df_combined["Pop_den"]), jnp.array(df_combined["HDI"])
# NOTE: despite its name, test_cases holds the Population column (Binomial trials)
test_cases = jnp.array(df_combined["Population"])
cases = jnp.array(df_combined["Cases"])

In [10]:
# Sanity-check the shape of every model input, one per line
print(*(arr.shape for arr in (M, pop_density, hdi, test_cases, cases, x, pt_which_pol_all)), sep="\n")

(96, 7304)
(96,)
(96,)
(96,)
(96,)
(7304, 2)
(7304,)


# Aggregated GP Model

In [11]:
 args = {
        # Lat/lon of the grid points, shape (num_grid_points, 2) -> (7304, 2)
        "x" : jnp.array(x),
        "gp_kernel" : exp_sq_kernel,
        "jitter" : 1e-4,
        "noise" : 1e-4,
        # Aggregation matrix, shape (num_rows, num_grid_points) -> (96, 7304);
        # presumably one row per row of df_combined (district-year)
        "M" : jnp.array(pol_pts_all),
        # GP kernel hyperparameter priors
        "kernel_length" : dist.InverseGamma(3,3), # length-scale, in the units of x (lat/lon)
        # NOTE(review): HalfNormal(1e-5) keeps the GP variance of order 1e-5, far below
        # the noise + jitter (2e-4) added to the kernel diagonal. Together with the
        # ~1e-8 NUTS step size in the run log, this prior may be worth revisiting.
        "kernel_var" : dist.HalfNormal(1e-5),
        "pop_density": jnp.array(df_combined["Pop_den"]), # shape (num_rows,) -> (96,)
        "hdi": jnp.array(df_combined["HDI"]), # shape (num_rows,) -> (96,)
        "total_cases" : jnp.array(df_combined["Cases"]), # observed cases, shape (num_rows,)
        "total_population" : jnp.array(df_combined["Population"]) # Binomial trial counts, shape (num_rows,)
    }


# Run MCMC

In [12]:
# Split one seeded PRNG key into a key for running the sampler and a key
# presumably meant for posterior prediction. (The chain loop further down
# derives its own per-chain keys from the same seed.)
run_key, predict_key = random.split(random.PRNGKey(3), num=2)

# MCMC settings: warmup iterations and retained draws per chain
n_warm, n_samples = 1000, 2000

In [13]:
# Directory for saving MCMC results. Resolve it the same way as the
# chain-running cell below (../model_weights relative to the working directory,
# alongside ../data) so both cells agree and no stray ./model_weights is created.
save_dir = os.path.abspath(os.path.join(os.getcwd(), "..", "model_weights"))
os.makedirs(save_dir, exist_ok=True)

# Run Chains and Save Model

In [None]:
# Run NUTS chains one at a time, each with its own PRNG key, and pickle each
# finished MCMC object so that a failure in a later chain does not lose
# earlier ones.
# NOTE(review): the recorded progress bar shows ~135 s/iteration at a step size
# of ~1e-8 (over 100 h per chain for 3000 iterations). Consider revisiting the
# model/priors before launching all four chains.

# Random keys
base_key = random.PRNGKey(3)  # Base seed
chain_keys = random.split(base_key, 4)  # One independent key per chain

# MCMC settings (repeated here so this cell can be re-run on its own)
n_warm = 1000
n_samples = 2000

# Save directory: ../model_weights relative to the working directory
script_dir = os.getcwd()
save_dir = os.path.abspath(os.path.join(script_dir, "..", "model_weights"))
os.makedirs(save_dir, exist_ok=True)

# Wall-clock start of the whole multi-chain run, so the total elapsed time can
# be computed once all chains have finished.
total_start = time.time()

# Run MCMC for each chain
for i, chain_key in enumerate(chain_keys, start=1):
    print(f"\n🔹 Running Chain {i} ...")

    # A fresh NUTS kernel / MCMC object per chain
    nuts_kernel = NUTS(prev_model_gp_aggr)
    mcmc = MCMC(nuts_kernel, num_warmup=n_warm, num_samples=n_samples)

    start = time.time()
    # MCMC.run forwards extra positional arguments to the model, so the model
    # receives the `args` dict as its single argument.
    mcmc.run(chain_key, args)
    end = time.time()
    t_elapsed_min = round((end - start) / 60)

    # Save the full MCMC object with dill (it holds the model function).
    f_path = os.path.join(save_dir, f"aggGP_chain{i}_nsamples_{n_samples}_tt{t_elapsed_min}min_logit.pkl")
    with open(f_path, "wb") as file:
        dill.dump(mcmc, file)

    print(f"Saved Chain {i} to {f_path}")
    print(f"Time taken: {t_elapsed_min} min\n")


🔹 Running Chain 1 ...


warmup:   4%|▍         | 125/3000 [3:50:51<107:41:33, 134.85s/it, 1023 steps of size 1.69e-08. acc. prob=0.70]

In [None]:
# Print the total elapsed time of the multi-chain MCMC run.
# time.time() on its own is seconds since the Unix epoch, so the duration must
# be measured against a start time recorded before the chains were launched.
total_end = time.time()
try:
    total_elapsed = total_end - total_start
except NameError:
    raise NameError(
        "total_start is not defined; set `total_start = time.time()` before running the MCMC chains"
    ) from None
print("\nMCMC Total elapsed time:", round(total_elapsed), "s")
print("MCMC Total elapsed time:", round(total_elapsed / 60), "min")
print("MCMC Total elapsed time:", round(total_elapsed / (60 * 60)), "h")