In [93]:
# generic imports
import time
#from pyprojroot2 import here
import numpy as np
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
# jax import 
import jax
import jax.numpy as jnp
from jax import random
# numpyro import 
import numpyro
import numpyro.distributions as dist
from numpyro.infer import NUTS, MCMC, Predictive
# arviz import 
import arviz as az

In [94]:
# config
# Pin JAX to the CPU backend, then print the devices JAX reports so the cell
# output records which backend this run actually used.
jax.config.update('jax_platform_name', 'cpu');
print(f'We are using {jax.devices()} under the hood.')
#print(f'The default path is {here()}')

We are using [CpuDevice(id=0)] under the hood.


## Data Distribution Check

In [95]:
# Read the two spatial resolutions of the data: `df_lo` comes from the
# province shapefile and `df_hi` from the district shapefile.
#data_path = here() / "simulation study" / "data"
df_lo = gpd.read_file("../data/jkt_prov.shp")
df_hi = gpd.read_file("../data/jkt_dist.shp")
# List the columns of each frame (low resolution first) to see what we have.
for frame in (df_lo, df_hi):
    print(frame.columns)

Index(['Province', 'Year', 'Cases', 'Population', 'HDI', 'Area_sq_km',
       'Pop_den', 'urbanicity', 'geometry'],
      dtype='object')
Index(['District', 'Year', 'Area_sq_km', 'HDI', 'Province', 'Cases',
       'Population', 'Pop_den', 'urbanicity', 'geometry'],
      dtype='object')


In [96]:
# Dengue incidence (cases divided by population) by province and year for the
# low-resolution data. Work on a copy so `df_lo` itself gains no extra column.
lo_prev = df_lo.copy()
lo_prev['incidence'] = (lo_prev.Cases / lo_prev.Population)
print("Dengue incidence by province and year (low resolution):")
print(lo_prev.groupby(['Province', 'Year'])[['incidence']].mean(), '\n')

# The equivalent summary for the high-resolution data is kept below but left
# disabled on purpose: the high-resolution cases are what we want to model,
# so they must not bias our choice of priors.
# hi_prev = df_hi.copy()
# hi_prev['incidence'] = hi_prev.Cases / hi_prev.Population
# print("Dengue incidence by district and year (high resolution):")
# print(hi_prev.groupby(['District', 'Year'])[['incidence']].mean())

# Use a single year of data: the most recent year in the low-resolution frame.
year_data = df_lo.Year.max()
# print(f'We are using data for the year {year_data}\n')
# Keep only that year in both frames. `.copy()` makes each result an
# independent frame rather than a slice of the original, so columns can be
# added to it later without pandas' SettingWithCopyWarning.
df_lo = df_lo[df_lo.Year == year_data].copy()
df_hi = df_hi[df_hi.Year == year_data].copy()


Dengue incidence by province and year (low resolution):
                  incidence
Province    Year           
DKI Jakarta 2020   0.001519
            2021   0.001315
            2022   0.002667
            2023   0.001983 



# GP Kernel Function

In [97]:
def dist_euclid(x, z):
    """
    Computes the pairwise Euclidean distance between two sets of points.

    This function is used by exp_sq_kernel (the kernel function for the
    Gaussian process).

    Parameters
    ----------
    x : array-like, shape (n_x, m) or (n_x,)
        First set of points, one point per row. A 1-D input is treated as
        n_x one-dimensional points.
    z : array-like, shape (n_z, m) or (n_z,)
        Second set of points, with the same convention as ``x``.

    Returns
    -------
    jnp.ndarray, shape (n_x, n_z)
        Entry (i, j) is the Euclidean distance between x[i] and z[j].

    Raises
    ------
    AssertionError
        If ``x`` and ``z`` do not have the same number of coordinates.
    """
    x = jnp.array(x)
    z = jnp.array(z)
    if x.ndim == 1:
        x = x.reshape(x.shape[0], 1)
    if z.ndim == 1:
        # Reshape z itself: using x here would discard z and return the
        # distances of x to itself.
        z = z.reshape(z.shape[0], 1)
    n_x, m = x.shape
    n_z, m_z = z.shape
    assert m == m_z
    # Accumulate squared differences one coordinate at a time so only
    # (n_x, n_z) arrays are ever materialised.
    delta = jnp.zeros((n_x, n_z))
    for d in range(m):
        x_d = x[:, d]  # (n_x,)
        z_d = z[:, d]  # (n_z,)
        delta += (x_d[:, jnp.newaxis] - z_d) ** 2  # broadcasts to (n_x, n_z)

    return jnp.sqrt(delta)

In [98]:
def exp_sq_kernel(x, z, var, length, noise, jitter=1.0e-4):
    """
    Squared-exponential kernel matrix between the points in ``x`` and ``z``.

    Pairwise distances come from dist_euclid; they are divided by ``length``,
    squared, and passed through ``var * exp(-0.5 * .)``. ``noise + jitter`` is
    then added along the diagonal of an identity matrix sized by
    ``x.shape[0]``, so ``x`` must expose ``.shape`` and the result is expected
    to be square.
    """
    pairwise = dist_euclid(x, z)
    scaled_sq = jnp.power(pairwise / length, 2.0)
    cov = var * jnp.exp(-0.5 * scaled_sq)
    diagonal_term = (noise + jitter) * jnp.eye(x.shape[0])
    return cov + diagonal_term


# Aggregation Functions

In [99]:
def M_g(M, g):
    '''
    Aggregate GP draws from the grid to the polygons.

    - $M$ is a matrix with binary entries $m_{ij}$, showing whether point $j$ is in polygon $i$
    - $g$ is a vector of GP draws over the grid (it is transposed before the product)
    - $matmul(M, g^T)$ gives a vector of sums over each polygon
    '''
    membership = jnp.array(M)
    draws_t = jnp.array(g).T
    return membership @ draws_t

# Aggregated Prevalence Model - must edit this to include HDI, population density



In [100]:
def prev_model_gp_aggr(args):
    """Dengue prevalence model with a Gaussian Process.

    A latent GP is sampled on the grid ``x``, summed within each region via
    the aggregation matrices, and combined with region-level covariates into
    a linear predictor that is passed through a sigmoid. Only the
    low-resolution case counts enter the Binomial likelihood; the
    high-resolution regions are carried along (masked out of the likelihood)
    so that ``theta`` is also produced for them. In every concatenated vector
    the low-resolution entries come first, followed by the high-resolution
    ones.

    Parameters
    ----------
    args : dict
        Model inputs, read by key:

        - ``"x"``: grid point coordinates, one row per grid point.
        - ``"gp_kernel"``: callable ``(x, z, var, length, noise, jitter)``
          returning the GP covariance matrix.
        - ``"noise"``, ``"jitter"``: forwarded to ``gp_kernel``.
        - ``"kernel_length"``, ``"kernel_var"``: prior distributions for the
          kernel hyperparameters.
        - ``"M_lo"``, ``"M_hi"``: aggregation matrices, one row per region
          and one column per grid point.
        - ``"pop_density_lo"/"_hi"``, ``"hdi_lo"/"_hi"``,
          ``"urban_lo"/"_hi"``: covariates, one value per region.
        - ``"total_population_lo"/"_hi"``: population, one value per region.
        - ``"total_cases_lo"``: observed case counts for the low-resolution
          regions only.

    Returns
    -------
    The value of the ``"pred_cases"`` sample site.
    """
    x = args["x"]  # spatial grid points
    gp_kernel = args["gp_kernel"]  # Gaussian Process kernel
    noise = args["noise"]
    jitter = args["jitter"]

    M_lo = args["M_lo"]  # (n_lo_regions, num_grid_points) aggregation matrix
    M_hi = args["M_hi"]  # (n_hi_regions, num_grid_points) aggregation matrix

    # Stack low- and high-resolution covariates into one vector per covariate.
    pop_density = jnp.concatenate([args["pop_density_lo"], args["pop_density_hi"]], axis=0)
    hdi = jnp.concatenate([args["hdi_lo"], args["hdi_hi"]], axis=0)
    urban = jnp.concatenate([args["urban_lo"], args["urban_hi"]], axis=0)
    total_population = jnp.concatenate(
        [args["total_population_lo"], args["total_population_hi"]], axis=0
    )

    # Case counts are only supplied for the low-resolution regions. The
    # high-resolution entries (the ones we want to predict) are filled with
    # NaN and excluded from the likelihood by the mask. The NaN padding is
    # built directly rather than by padding with 0 and converting zeros to
    # NaN, so a genuine zero case count in the low-resolution data stays
    # observed instead of being silently dropped.
    total_cases_lo = jnp.asarray(args["total_cases_lo"], dtype=jnp.result_type(float))
    missing_hi = jnp.full((M_hi.shape[0],), jnp.nan, dtype=total_cases_lo.dtype)
    total_cases = jnp.concatenate([total_cases_lo, missing_hi], axis=0)
    total_cases_mask = ~jnp.isnan(total_cases)  # True where a count is observed

    # GP hyperparameters
    kernel_length = numpyro.sample("kernel_length", args["kernel_length"])
    kernel_var = numpyro.sample("kernel_var", args["kernel_var"])

    # GP kernel and sample over the grid
    k = gp_kernel(x, x, kernel_var, kernel_length, noise, jitter)
    f = numpyro.sample(
        "f",
        dist.MultivariateNormal(loc=jnp.zeros(x.shape[0]), covariance_matrix=k),
    )  # (num_grid_points,)

    # Aggregate GP values to region level
    gp_aggr_lo = numpyro.deterministic("gp_aggr_lo", M_g(M_lo, f))
    gp_aggr_hi = numpyro.deterministic("gp_aggr_hi", M_g(M_hi, f))

    # Both aggregations are needed: even though the model only sees the
    # low-resolution counts, producing high-resolution estimates requires the
    # GP realisations for those regions too.
    gp_aggr = numpyro.deterministic("gp_aggr", jnp.concatenate([gp_aggr_lo, gp_aggr_hi]))

    # Fixed effects
    b0 = numpyro.sample("b0", dist.Normal(-5.25, 0.5))  # Intercept
    b_pop_density = numpyro.sample("b_pop_density", dist.Normal(0, 0.33))  # Effect of population density
    b_hdi = numpyro.sample("b_hdi", dist.Normal(0, 0.33))  # Effect of HDI
    b_urban = numpyro.sample("b_urban", dist.Normal(0, 0.33))  # Effect of urbanicity

    # NOTE(review): the covariates enter the linear predictor on their raw
    # scale; the standardisation below is disabled. With a population density
    # in the thousands this lets lp span hundreds of logits under the prior -
    # confirm whether standardisation should be re-enabled.
    #pop_density = (pop_density - jnp.mean(pop_density)) / jnp.std(pop_density)
    #hdi = (hdi-jnp.mean(hdi)) / jnp.std(hdi)
    #urban = (urban - jnp.mean(urban)) / jnp.std(urban)

    # Linear predictor, one value per region
    lp = numpyro.deterministic(
        "lp",
        b0 - gp_aggr - b_pop_density * pop_density + b_hdi * hdi - b_urban * urban,
    )

    # Prevalence probability
    theta = numpyro.deterministic("theta", jax.nn.sigmoid(lp))

    # Binomial likelihood, restricted to the regions with observed counts
    with numpyro.handlers.mask(mask=total_cases_mask):
        pred_cases = numpyro.sample(
            "pred_cases",
            dist.Binomial(total_count=total_population, probs=theta),
            obs=total_cases)

    return pred_cases


# Load Data

In [101]:
# Coordinates (lat/lon) of the artificial grid points
x = np.load("../data/lat_lon_x_jkt.npy")

# Polygon/point arrays for the combined regional data, loaded per resolution:
# first the `pol_pts` arrays, then the `pt_which_pol` arrays.
pol_pts_jkt_lo = np.load("../data/pol_pts_jkt_lo.npy")
pol_pts_jkt_hi = np.load("../data/pol_pts_jkt_hi.npy")
pt_which_pol_jkt_lo = np.load("../data/pt_which_pol_jkt_lo.npy")
pt_which_pol_jkt_hi = np.load("../data/pt_which_pol_jkt_hi.npy")

# Variables to update (keep these in sync with the inputs of the aggregated prevalence model)

In [102]:
# Aggregation matrices: one row per region, one column per grid point.
M_lo = jnp.array(pol_pts_jkt_lo)
M_hi = jnp.array(pol_pts_jkt_hi)
print(M_lo)
print(M_hi)

# Region-level covariates, observed cases and populations taken from the
# year-filtered data frames.
pop_density_lo = jnp.array(df_lo["Pop_den"])
pop_density_hi = jnp.array(df_hi["Pop_den"])
hdi_lo = jnp.array(df_lo["HDI"])
hdi_hi = jnp.array(df_hi["HDI"])
urban_lo = jnp.array(df_lo["urbanicity"])
urban_hi = jnp.array(df_hi["urbanicity"])
cases_lo = jnp.array(df_lo["Cases"])
pop_lo = jnp.array(df_lo["Population"])
pop_hi = jnp.array(df_hi["Population"])

# Fail early, with a readable message, if the aggregation matrices do not line
# up with the data frames or the grid; otherwise the mismatch only surfaces as
# a shape error deep inside the model.
# NOTE(review): these checks cover sizes only - the row ORDER of M_lo / M_hi is
# assumed to match the row order of df_lo / df_hi; confirm against how the
# .npy files were generated.
assert M_lo.shape[0] == len(df_lo), (
    f"M_lo has {M_lo.shape[0]} rows but df_lo has {len(df_lo)} regions"
)
assert M_hi.shape[0] == len(df_hi), (
    f"M_hi has {M_hi.shape[0]} rows but df_hi has {len(df_hi)} regions"
)
assert M_lo.shape[1] == M_hi.shape[1] == x.shape[0], (
    "aggregation matrices must have one column per grid point in x"
)

[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 1 1 1 1 0 0 0 1 1 1 1
  1 1 1 0 0 1 1 1 1 1 1 1 1 0 0 1 1 1 1 1 1 1 1 0 0 1 1 1 1 1 1 1 1 0 0 1
  1 1 1 1 1 1 1 1 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 1 1 1 1]]
[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 1
  1 1 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 1 1 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0
  0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 0 0 1
  1 1 1 1 1 0 0 0 0 0 0

In [103]:
# Inspect the low-resolution population density values that go into the model.
print(pop_density_lo)

[1211.4202]


In [104]:
# Print the shape of every model input defined above, one per line, in the
# order: aggregation matrices, population density, HDI, cases, populations, grid.
for arr in (
    M_lo, M_hi,
    pop_density_lo, pop_density_hi,
    hdi_lo, hdi_hi,
    cases_lo,
    pop_lo, pop_hi,
    x,
):
    print(arr.shape)

(1, 100)
(5, 100)
(1,)
(5,)
(1,)
(5,)
(1,)
(1,)
(5,)
(100, 2)


In [105]:
# Inputs for prev_model_gp_aggr, which reads them by key. The *_lo entries
# describe the low-resolution regions and the *_hi entries the high-resolution
# regions.
args = {
        "x" : jnp.array(x), # Lat/lon vals of grid points # Shape (num_grid_points, 2)
        "gp_kernel" : exp_sq_kernel,
        "jitter" : 1e-4,
        "noise" : 1e-4,
        "M_lo" : M_lo, # Low-res aggregation matrix # Shape (num_lo_regions, num_grid_points)
        "M_hi" : M_hi, # High-res aggregation matrix # Shape (num_hi_regions, num_grid_points)
        # GP Kernel Hyperparams
        "kernel_length" :  dist.InverseGamma(3, 3), # prior for the kernel length-scale
        "kernel_var" : dist.LogNormal(0,0.5), # prior for the kernel variance
        "pop_density_lo": pop_density_lo, # Shape (num_lo_regions,)
        "pop_density_hi": pop_density_hi, # Shape (num_hi_regions,)
        "hdi_lo": hdi_lo, # Shape (num_lo_regions,)
        "hdi_hi": hdi_hi, # Shape (num_hi_regions,)
        "urban_lo": urban_lo,
        "urban_hi": urban_hi,
        "total_cases_lo" : cases_lo, # observed cases, low resolution only
        "total_population_lo" : pop_lo,
        "total_population_hi" : pop_hi,
}


## Prior Predictive Check

In [106]:
# Prior predictive check: draw 500 samples from the model without any
# posterior samples, using a fixed PRNG key.
draw_from_prior = Predictive(prev_model_gp_aggr, num_samples=500)
prior_samples = draw_from_prior(random.PRNGKey(6), args)

# Wrap the prior draws in an ArviZ inference object for summarising.
prior_samples_arviz = az.from_numpyro(prior=prior_samples)

In [107]:
# Median-focused summary of the prior draws for the prevalence, the linear
# predictor, the aggregated GP and the fixed-effect coefficients.
az.summary(prior_samples_arviz, var_names = ["theta", "lp", "gp_aggr", "b0", "b_pop_density", "b_hdi", "b_urban"], stat_focus = "median")



Unnamed: 0,median,mad,eti_3%,eti_97%,mcse_median,ess_median,ess_tail,r_hat
theta[0],1.0,0.0,0.0,1.0,0.478,450.902,500.0,
theta[1],0.004,0.004,0.0,1.0,0.002,416.989,514.0,
theta[2],0.005,0.005,0.0,0.979,0.002,446.318,472.0,
theta[3],0.004,0.004,0.0,1.0,0.003,500.016,526.0,
theta[4],0.002,0.002,0.0,1.0,0.002,456.061,500.0,
theta[5],0.002,0.002,0.0,1.0,0.001,458.526,500.0,
lp[0],9.551,263.442,-713.885,726.724,19.936,450.902,437.0,
lp[1],-5.654,6.606,-24.792,13.137,0.578,416.989,514.0,
lp[2],-5.384,3.344,-15.139,3.844,0.352,446.318,472.0,
lp[3],-5.576,7.947,-29.609,17.154,0.752,500.016,526.0,


# Run MCMC

In [108]:
# Base seed for reproducibility
base_seed = 3  # Keep this fixed for full replicability
# MCMC settings
n_warm = 1000  # warm-up iterations (passed to MCMC as num_warmup)
n_samples = 2000  # retained samples (passed to MCMC as num_samples)

# Save Model

In [109]:
# Seed the sampler from the fixed base seed.
rng = jax.random.PRNGKey(base_seed)

# NUTS kernel for the aggregated-GP model.
nuts_kernel = NUTS(
    prev_model_gp_aggr,
    target_accept_prob=0.8,
    max_tree_depth=10,
)  # change for later run to 0.95 and 12

# Four vectorized chains using the warm-up / sample counts set above.
mcmc = MCMC(
    nuts_kernel,
    num_warmup=n_warm,
    num_samples=n_samples,
    num_chains=4,
    chain_method="vectorized",
    progress_bar=True,
)

# Run the sampler and report the wall-clock time in whole minutes.
start = time.time()
mcmc.run(rng, args)
end = time.time()
t_elapsed_min = round((end - start) / 60)
print(f"Time taken to run MCMC: {t_elapsed_min} minutes")

warmup:   2%|▏         | 71/3000 [01:26<59:08,  1.21s/it]  


KeyboardInterrupt: 

In [None]:
# Build the prior and posterior predictive draws, each with its own PRNG key
# split from a fixed seed.
rng_key_pr, rng_key_po = random.split(random.PRNGKey(4))

# Posterior predictive: re-run the model conditioned on the MCMC samples.
posterior_samples = mcmc.get_samples()
sample_posterior_predictive = Predictive(prev_model_gp_aggr, posterior_samples)
posterior_predictive = sample_posterior_predictive(rng_key_po, args)

# Prior predictive: 500 draws from the model alone.
sample_prior = Predictive(prev_model_gp_aggr, num_samples=500)
prior = sample_prior(rng_key_pr, args)

# Working with Inference Object

In [None]:
# One "district" coordinate per modelled region: the low-resolution rows plus
# the high-resolution rows.
Total_districts = len(df_lo) + len(df_hi)
district_coords = {"district": np.arange(Total_districts)}
theta_dims = {"theta": ["district"]}

# Bundle the MCMC run with the prior and posterior predictive draws into a
# single ArviZ inference object, labelling theta's axis as "district".
numpyro_data = az.from_numpyro(
    mcmc,
    prior=prior,
    posterior_predictive=posterior_predictive,
    coords=district_coords,
    dims=theta_dims,
)

In [None]:
# save the inference object
# NOTE(review): this import sits outside the top-of-notebook import cell, where
# the same line is commented out - confirm the package is available before
# relying on this cell.
from pyprojroot2 import here
# Destination of the saved inference object, built from the path returned by here().
save_path = here() / "simulation study" / "model_runs" / "aggGP_inference_object.nc"
# Create the target directory (and any missing parents) if it does not exist yet.
save_path.parent.mkdir(parents=True, exist_ok=True)

In [None]:
# Write the inference object to disk as a netCDF file at save_path.
numpyro_data.to_netcdf(save_path)

In [None]:
# Reload the inference object from the saved file; the cells below work from
# this reloaded copy.
numpyro_data = az.from_netcdf(save_path)

In [None]:
# Median-focused summary of the variables selected by the name pattern
# "kernel" (filter_vars="like").
az.summary(numpyro_data, var_names = ["kernel"], filter_vars="like", stat_focus = "median")

In [None]:
# Median-focused summary of the variables selected by the name patterns below
# (filter_vars="like").
# NOTE(review): the pattern "b" is very broad - confirm it selects only the
# intended coefficient variables.
az.summary(numpyro_data, var_names = ["theta", "lp", "gp_aggr", "b"], filter_vars="like", stat_focus = "median")

In [None]:
# Posterior draws of theta; the first two axes (reduced over below) are the
# chain and draw axes, the last one indexes the regions (low-res rows first,
# then high-res rows).
theta_samples = numpyro_data.posterior.theta.values
n_lo = df_lo.shape[0]
n_hi = df_hi.shape[0]

# Point estimate and interquartile range per region, pooled over chains and
# draws. NOTE: despite its name, `theta_mean_gp` holds the posterior MEDIAN,
# and the `bci_*` arrays hold the 25% / 75% quantiles.
theta_mean_gp = np.median(theta_samples, axis=(0, 1))
bci_gp_25 = np.quantile(theta_samples, 0.25, axis=(0, 1))
bci_gp_75 = np.quantile(theta_samples, 0.75, axis=(0, 1))

# Slice the IQR bounds to match the low-res and high-res regions
bci_lo_25 = bci_gp_25[:n_lo]
bci_lo_75 = bci_gp_75[:n_lo]

bci_hi_25 = bci_gp_25[n_lo:n_lo + n_hi]
bci_hi_75 = bci_gp_75[n_lo:n_lo + n_hi]

# Attach the observed prevalence and the estimated theta to each frame.
# `.assign` returns a new frame, so this does not write into a filtered slice
# of the original data and can be re-run safely.
# theta is the Binomial probability for every region, so the low- and
# high-resolution estimates are on the same scale as cases / population and
# neither is rescaled.
df_lo = df_lo.assign(
    obs_prev=df_lo["Cases"] / df_lo["Population"],
    theta_gp=theta_mean_gp[0:n_lo],
)
df_hi = df_hi.assign(
    obs_prev=df_hi["Cases"] / df_hi["Population"],
    theta_gp=theta_mean_gp[n_lo:n_lo + n_hi],
)

theta_obs_lo = df_lo["obs_prev"]
theta_gp_est_lo = df_lo["theta_gp"]
theta_obs_hi = df_hi["obs_prev"]
theta_gp_est_hi = df_hi["theta_gp"]

# Overall value ranges across observed and estimated prevalence, per resolution
max_val_lo = np.max([theta_obs_lo, theta_gp_est_lo])
min_val_lo = np.min([theta_obs_lo, theta_gp_est_lo])

max_val_hi = np.max([theta_obs_hi, theta_gp_est_hi])
min_val_hi = np.min([theta_obs_hi, theta_gp_est_hi])

# Plot the low resolution data

In [None]:
# Side-by-side maps of observed vs. estimated prevalence (low resolution),
# drawn with a shared colour scale so the panels are directly comparable.
fig, ax = plt.subplots(1, 2, figsize=(12, 6))

# Left panel: observed prevalence
df_lo.plot(
    column="obs_prev",  # Column to use for color
    cmap="viridis",  # Colormap
    vmin=0.001,  # Minimum value for color scale
    vmax=0.01,  # Maximum value for color scale
    legend=True,  # Show legend
    ax=ax[0],  # Plot on the first subplot
)
ax[0].set_title("Low-Res Observed Prevalence")

# Right panel: estimated prevalence
df_lo.plot(
    column="theta_gp",  # Column to use for color
    cmap="viridis",  # Colormap
    vmin=0.001,  # Minimum value for color scale
    vmax=0.01,  # Maximum value for color scale
    legend=True,  # Show legend
    ax=ax[1],  # Plot on the second subplot
)
ax[1].set_title("Low-Res Estimated Prevalence (θ)")

# Apply the layout BEFORE saving so the file on disk matches what is shown.
fig.tight_layout()
fig.savefig("observed_vs_estimated_prevalence_lo.png")  # Save as PNG
# Or save as PDF:
# fig.savefig("observed_vs_estimated_prevalence_lo.pdf")

plt.show()

# Plot high resolution data

In [None]:
# Side-by-side maps of observed vs. estimated prevalence (high resolution),
# drawn with a shared colour scale so the panels are directly comparable.
fig, ax = plt.subplots(1, 2, figsize=(12, 6))

# Left panel: observed prevalence
df_hi.plot(
    column="obs_prev",  # Column to use for color
    cmap="viridis",  # Colormap
    vmin=0.001,  # Minimum value for color scale
    vmax=0.01,  # Maximum value for color scale
    legend=True,  # Show legend
    ax=ax[0],  # Plot on the first subplot
)
ax[0].set_title("High-Res Observed Prevalence")

# Right panel: estimated prevalence
df_hi.plot(
    column="theta_gp",  # Column to use for color
    cmap="viridis",  # Colormap
    vmin=0.001,  # Minimum value for color scale
    vmax=0.01,  # Maximum value for color scale
    legend=True,  # Show legend
    ax=ax[1],  # Plot on the second subplot
)
ax[1].set_title("High-Res Estimated Prevalence (θ)")

# Apply the layout BEFORE saving so the file on disk matches what is shown.
fig.tight_layout()
fig.savefig("observed_vs_estimated_prevalence_hi.png")  # Save as PNG
# Or save as PDF:
# fig.savefig("observed_vs_estimated_prevalence_hi.pdf")

plt.show()