In [13]:
# Standard library
import os
import sys
import math
import time
import pickle
from pathlib import Path
import itertools

# Numerical and scientific computing
import numpy as np
import matplotlib.pyplot as plt

# PyTorch (though not used in this snippet)
import torch

# JAX and NumPyro
import jax
import jax.numpy as jnp
import jax.nn as nn
from jax import random, lax
from jax.random import PRNGKey
from jax.example_libraries import stax

import numpyro
import numpyro.distributions as dist
from numpyro.infer import (
    SVI, MCMC, NUTS, Predictive,
    init_to_median, init_to_uniform,
    init_to_sample, init_to_mean, init_to_value,
    Trace_ELBO, RenyiELBO
)

# Data visualization and analysis
import plotly.express as px
import arviz as az
import geopandas as gpd

# Miscellaneous
from termcolor import colored

In [14]:
def dist_euclid(x, z):
    """
    Computes Eucledian Distance Between Regions. This function is used by
    exp_sq_kernel function (kernel function for gaussian processes)
    """
    x = jnp.array(x) # (ngrid_pts, lat/lon) <- i.e (7304,2)
    z = jnp.array(z) # (ngrid_pts, lat/lon) <- i.e (7304,2)
    if len(x.shape)==1:
        x = x.reshape(x.shape[0], 1) #(2618,) -> (7304,1)
    if len(z.shape)==1:
        z = x.reshape(x.shape[0], 1) #(2618,) -> (7304,1)
    n_x, m = x.shape # 7304 , 2
    n_z, m_z = z.shape # 7304 , 2
    assert m == m_z
    delta = jnp.zeros((n_x,n_z)) #(ngrid_pts,ngrid_pts) <- i.e (7304,7304)
    for d in jnp.arange(m):
        x_d = x[:,d] #(ngrid_pts-lat/lon,) <- (7304,)
        z_d = z[:,d] #(ngrid_pts-lat/lon,) <- (7304,)
        delta += (x_d[:,jnp.newaxis] - z_d)**2 # (7304,7304)

    return jnp.sqrt(delta) #(7304,7304)

def exp_sq_kernel(x, z, var, length, noise, jitter=1.0e-4):
    """
    Exponential squared kernel (RBF kernel) for Gaussian processes
    
    Args:
        x: First set of points (n_x, d)
        z: Second set of points (n_z, d)
        var: Kernel variance
        length: Length scale
        noise: Noise term
        jitter: Small constant for numerical stability
    
    Returns:
        Kernel matrix of shape (n_x, n_z)
    """
    dist = dist_euclid(x, z) #(7304, 7304)
    deltaXsq = jnp.power(dist/ length, 2.0)
    k = var * jnp.exp(-0.5 * deltaXsq)
    k += (noise + jitter) * jnp.eye(x.shape[0])
    return k # (ngrid_pts, ngrid_pts) <- (7304,7304)

def M_g(M, g):
    '''
    Matrix multiplication for aggregating GP draws over polygons
    
    Args:
        M: Matrix with binary entries m_ij, showing whether point j is in polygon i
        g: Vector of GP draws over grid
    
    Returns:
        Vector of sums over each polygon
    '''
    M = jnp.array(M)
    g = jnp.array(g).T
    return(jnp.matmul(M, g)) 

# AggVAE Model

## Function for Predictive Simulation (Prior)

In [15]:
def gp_aggr(config=None):
    """
    Gaussian Process aggregation model for spatial data
    
    Args:
        config: Dictionary containing model configuration:
            x: Spatial grid points (num_grid_points, 2)
            gp_kernel: Gaussian Process kernel function (default: exp_sq_kernel)
            noise: Noise parameter (default: 1e-4)
            jitter: Jitter parameter for numerical stability (default: 1e-4)
            M_lo: Low resolution aggregation matrix (9, 2618)
            M_hi: High resolution aggregation matrix (49, 2618)
            kernel_length: Prior distribution for kernel length (default: InverseGamma(4,1))
            kernel_var: Prior distribution for kernel variance (default: LogNormal(0,0.1))
    
    Returns:
        Sampled GP values and their aggregations
    """
    if config is None:
        config = {}
        
    # Set defaults
    x = config.get('x', None)
    gp_kernel = config.get('gp_kernel', exp_sq_kernel)
    noise = config.get('noise', 1e-4)
    jitter = config.get('jitter', 1e-4)
    M_lo = config.get('M_lo', None)
    M_hi = config.get('M_hi', None)
    kernel_length_prior = config.get('kernel_length', dist.InverseGamma(4, 1))
    kernel_var_prior = config.get('kernel_var', dist.LogNormal(0, 0.1))

    # GP hyperparameters
    kernel_length = numpyro.sample("kernel_length", kernel_length_prior)
    kernel_var = numpyro.sample("kernel_var", kernel_var_prior)
    log_mean = numpyro.sample("log_mean", dist.Normal(jnp.log(20), 0.1))
    
    # Create kernel with smaller variance
    k = gp_kernel(x, x, kernel_var, kernel_length, noise, jitter)
    
    # Sample GP on log scale with appropriate mean
    log_f = numpyro.sample("log_f", dist.MultivariateNormal(
        loc=jnp.ones(x.shape[0]) * log_mean,
        covariance_matrix=k))
    
    # Transform back to original scale
    f = jnp.exp(log_f)  # This will naturally be positive and centered around 1400
    
    # Aggregate as before
    gp_aggr_lo = numpyro.deterministic("gp_aggr_lo", M_g(M_lo, f))
    gp_aggr_hi = numpyro.deterministic("gp_aggr_hi", M_g(M_hi, f))
    gp_aggr = numpyro.deterministic("gp_aggr", jnp.concatenate([gp_aggr_lo, gp_aggr_hi]))

    return gp_aggr

## Define the VAE

In [16]:
#vae architecture
def vae_encoder(hidden_dim=50, z_dim=40):
    """
    VAE encoder network architecture
    
    Args:
        hidden_dim: Number of hidden dimensions
        z_dim: Number of latent dimensions
    
    Returns:
        Stax neural network for encoding
    """
    return stax.serial(
        #(num_samples, num_regions) -> (num_samples, hidden_dims)
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Elu,
        stax.FanOut(2),
        stax.parallel(
            # mean : (num_samples, hidden_dim) -> (num_samples, z_dim)
            stax.Dense(z_dim, W_init=stax.randn()), #(5,50)
            #std : (num_samples, hidden_dim) -> (num_samples, z_dim)
            stax.serial(stax.Dense(z_dim, W_init=stax.randn()), stax.Exp)
        )
    )

def vae_decoder(hidden_dim, out_dim):
    """
    VAE decoder network architecture
    
    Args:
        hidden_dim: Number of hidden dimensions
        out_dim: Output dimensions (number of regions)
    
    Returns:
        Stax neural network for decoding
    """
    return stax.serial(
        # (num_samples, z_dim) -> (num_samples, hidden_dim)
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Elu,
        # (num_samples, hidden_dim) -> (num_samples, num_regions)
        stax.Dense(out_dim, W_init=stax.randn())
    )

def vae_model(batch, hidden_dim, z_dim, vae_var):
    """
    VAE model (decoder portion)
    
    Args:
        batch: Input data batch
        hidden_dim: Number of hidden dimensions
        z_dim: Number of latent dimensions
        vae_var: VAE variance parameter
    
    Returns:
        Sampled observations
    """
    batch = jnp.reshape(batch, (batch.shape[0], -1)) # (num_samples, num_regions)
    batch_dim, out_dim = jnp.shape(batch)

    # vae-decoder in numpyro module
    decode = numpyro.module(
        name="decoder",
        nn=vae_decoder(hidden_dim=hidden_dim, out_dim=out_dim),
        input_shape=(batch_dim, z_dim) #(5,40)
    )

    # Sample a univariate normal
    z = numpyro.sample(
        "z",
        dist.Normal(
            jnp.zeros((batch_dim,z_dim)),
            jnp.ones((batch_dim,z_dim))
        )
    )
    # Forward pass from decoder
    gen_loc = decode(z) #(num_regions,)
    obs = numpyro.sample(
        "obs",
        dist.Normal(gen_loc, vae_var),
        obs=batch
    ) #(num_samples, num_regions)
    return obs

def vae_guide(batch, hidden_dim, z_dim):
    """
    VAE guide (encoder portion)
    
    Args:
        batch: Input data batch
        hidden_dim: Number of hidden dimensions
        z_dim: Number of latent dimensions
    
    Returns:
        Sampled latent variables
    """
    batch = jnp.reshape(batch, (batch.shape[0], -1)) #(num_samples, num_regions)
    batch_dim, input_dim = jnp.shape(batch)# num_samples , num_regions

    # vae-encoder in numpyro module
    encode = numpyro.module(
        name="encoder",
        nn=vae_encoder(hidden_dim=hidden_dim,z_dim=z_dim),
        input_shape=(batch_dim, input_dim) #(5,58)
    ) #(num_samples, num_regions) -> (num_samples, hidden_dims)

    # Sampling mu, sigma - Pretty much the forward pass
    z_loc, z_std = encode(batch) #mu : (num_samples, z_dim), sigma2 : (num_samples, z_dim)
    # Sample a value z based on mu and sigma
    z = numpyro.sample("z", dist.Normal(z_loc, z_std)) #(num_sample, z_dim)
    return z

## Train the VAE encoder

In [17]:
@jax.jit
def epoch_train(rng_key, svi_state, num_train):
    def body_fn(i, val):
        rng_key_i = jax.random.fold_in(rng_key, i) #Array(2,)
        rng_key_i, rng_key_ls, rng_key_var, rng_key_noise = jax.random.split(rng_key_i, 4) #Tuple(Array(2,) x 4)
        loss_sum, svi_state = val #val --svi_state

        batch = agg_gp_predictive(rng_key_i, args)["gp_aggr"] #(5,116) <- num_samples : 5, total_districts : 116
        #* svi is where the vae_model & vae_guide gets applied
        svi_state, loss = svi.update(svi_state, batch)
        loss_sum += loss / args["batch_size"]
        return loss_sum, svi_state

    return lax.fori_loop(lower = 0, upper = num_train, body_fun=body_fn, init_val=(0.0, svi_state))

@jax.jit
def eval_test(rng_key, svi_state, num_test):
    def body_fn(i, loss_sum):
        rng_key_i = jax.random.fold_in(rng_key, i)
        rng_key_i, rng_key_ls, rng_key_varm, rng_key_noise = jax.random.split(rng_key_i, 4)
        batch = agg_gp_predictive(rng_key_i, args)["gp_aggr"]
        #* svi is where the vae_model & vae_guide gets applied
        loss = svi.evaluate(svi_state, batch) / args["batch_size"]
        loss_sum += loss
        return loss_sum

    loss = lax.fori_loop(lower = 0, upper = num_test,body_fun =  body_fn, init_val = 0.0)
    loss = loss / num_test
    return loss

## Function to plot the GP

In [18]:
#plot process and plot incidence map
def plot_process(gp_draws):
    """
    Plot multiple GP draws as line plots
    
    Args:
        gp_draws: Array of GP draws with shape (num_samples, num_regions)
    """
    p = px.line()
    for i in range(gp_draws.shape[1]):
        p.add_scatter(
            x=np.arange(gp_draws.shape[2]), 
            y=gp_draws[0,i, :],
            line_color='rgb(31, 119, 180)',  # A nice blue color
            opacity=0.3  # Add transparency
        )

    p.update_layout(
        template="plotly_white",
        xaxis_title="region", 
        yaxis_title="num cases",
        showlegend=False
    )
    return p  # Return the figure instead of showing it directly

def plot_incidence_map(geodf,plot_col="incidence", title="Incidence", ax=None, vmin=0.001, vmax=0.008, cmap="viridis"):
    """
    Plot incidence data on a map with value annotations.
    
    Parameters:
    -----------
    geodf : geopandas.GeoDataFrame
        GeoDataFrame containing the incidence data and geometry
    plot_col : str
        Column name to plot
    title : str
        Title for the plot
    ax : matplotlib.axes.Axes
        The axes to plot on
    vmin, vmax : float
        Minimum and maximum values for the color scale
    cmap : str
        Colormap to use for the plot
        
    Returns:
    --------
    None
    """
    # Plot the map
    geodf.plot(
        column=plot_col,
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        legend=True,
        ax=ax,
    )
    
    # Add text labels with incidence values
    for idx, row in geodf.iterrows():
        # Get centroid coordinates of each polygon
        centroid = row.geometry.centroid
        # Format incidence value as percentage with 2 decimal places
        value = f"{row[plot_col]*100:.3f}%"
        # Add text with white background for visibility
        ax.annotate(
            value,
            xy=(centroid.x, centroid.y),
            xytext=(0, 0),
            textcoords='offset points',
            ha='center',
            va='center',
            fontsize=12,
            fontweight='bold',
            color='white',
            bbox=dict(
                facecolor='black',
                alpha=0.7,
                edgecolor='none',
                pad=1
            )
        )
    
    ax.set_title(title)

## Load the variables

In [19]:
# Lat/Lon Values of artificial grid
x = np.load("../data/lat_lon_x_jkt.npy")
pol_pts_jkt_lo = np.load("../data/pol_pts_jkt_lo.npy")
pol_pts_jkt_hi = np.load("../data/pol_pts_jkt_hi.npy")
df_lo = gpd.read_file("../data/jkt_prov.shp")
df_hi = gpd.read_file("../data/jkt_dist.shp")

In [20]:
print(pol_pts_jkt_lo)

[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 1 1 1 1 0 0 0 1 1 1 1
  1 1 1 0 0 1 1 1 1 1 1 1 1 0 0 1 1 1 1 1 1 1 1 0 0 1 1 1 1 1 1 1 1 0 0 1
  1 1 1 1 1 1 1 1 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 1 1 1 1]]


In [21]:
print(pol_pts_jkt_hi)

[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 1
  1 1 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 1 1 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0
  0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 0 0 1
  1 1 1 1 1 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 1 1 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0
  0 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0
  0 0 0 0 0 1 1 1 0 0 0 

## Arguments

In [22]:
args = {
        "x": x,
        "gp_kernel": exp_sq_kernel,
        "noise": 1e-4,
        "M_lo": jnp.array(pol_pts_jkt_lo),
        "M_hi": jnp.array(pol_pts_jkt_hi),
        "jitter" : 1e-4,
        # VAE training
        "rng_key": random.PRNGKey(5),
        #common num_epochs 20-50
        "num_epochs": 20,
        #learning rate 0.0005 common choice, ADAM optimiser adapts the learning rate accordingly
        "learning_rate": 0.0005,
        #chosen to be 100 (no tune)
        "batch_size": 100,
        #change this to the optimal values after hyperparameter tuning
        "hidden_dim": 50,
        "z_dim": 50,
        #chosen to be 100 (no tune)
        "num_train": 100,
        "num_test":100,
        #variance set to 1 bc the latent variable prior distribution is assumed to be normal
        "vae_var": 1,
        "kernel_length": dist.InverseGamma(3, 3),
        "kernel_var": dist.LogNormal(0, 0.5)
    }


In [25]:
# Model configuration
args = {
    "x": x,
    "gp_kernel": exp_sq_kernel,
    "noise": 1.e-2,
    "jitter": 1.e-2,
    "M_lo": jnp.array(pol_pts_jkt_lo[0,:].reshape(1, -1)),
    "M_hi":  jnp.array(pol_pts_jkt_hi)[[0,4,9,14,19], :],
    'kernel_length': dist.InverseGamma(5, 2),
    'kernel_var': dist.LogNormal(-1, 0.05),
    "vae_var": 0.1,
    "batch_size": 5,
    "hidden_dim": 50,
    "z_dim": 40,
    "learning_rate": 1e-3,
    "num_epochs": 100,
    "rng_key": PRNGKey(6)
}

## Prior predictive simulation

In [26]:
# Create predictive function for GP
from numpyro.infer import Predictive

# Draw samples from the prior
prior_samples = Predictive(gp_aggr, num_samples=5)(
    PRNGKey(6), 
    config=args
)

# transform prior samples to arviz inference object
prior_samples_arviz = az.from_numpyro(prior=prior_samples)

In [27]:
plot_process(prior_samples_arviz.prior.gp_aggr.values)

## Initiate Training Loop for VAE encoder

In [28]:
# Initialize SVI
optimizer = numpyro.optim.Adam(step_size=1e-3)
agg_gp_predictive = Predictive(gp_aggr,num_samples = 5)
@jax.jit
def epoch_train(rng_key, svi_state, num_train):
    def body_fn(i, val):
        rng_key_i = jax.random.fold_in(rng_key, i) #Array(2,)
        rng_key_i, rng_key_ls, rng_key_var, rng_key_noise = jax.random.split(rng_key_i, 4) #Tuple(Array(2,) x 4)
        loss_sum, svi_state = val #val --svi_state
        
        batch = agg_gp_predictive(rng_key_i, args)["gp_aggr"] #(5,116) <- num_samples : 5, total_districts : 116
        #* svi is where the vae_model & vae_guide gets applied
        svi_state, loss = svi.update(svi_state, batch)
        loss_sum += loss / args["batch_size"]
        return loss_sum, svi_state 
    
    return lax.fori_loop(lower = 0, upper = num_train, body_fun=body_fn, init_val=(0.0, svi_state))

@jax.jit 
def eval_test(rng_key, svi_state, num_test):
    def body_fn(i, loss_sum):
        rng_key_i = jax.random.fold_in(rng_key, i)
        rng_key_i, rng_key_ls, rng_key_varm, rng_key_noise = jax.random.split(rng_key_i, 4)
        batch = agg_gp_predictive(rng_key_i, args)["gp_aggr"]
        #* svi is where the vae_model & vae_guide gets applied
        loss = svi.evaluate(svi_state, batch) / args["batch_size"]
        loss_sum += loss
        return loss_sum 
    
    loss = lax.fori_loop(lower = 0, upper = num_test,body_fun =  body_fn, init_val = 0.0)
    loss = loss / num_test 
    return loss

In [29]:
#-------------------------- Train VAE -------------------------- #
run = True
if run:
    adam = numpyro.optim.Adam(step_size = args["learning_rate"])
    svi = SVI(
        model=lambda batch: vae_model(batch, args["hidden_dim"], args["z_dim"], args["vae_var"]),
        guide=lambda batch: vae_guide(batch, args["hidden_dim"], args["z_dim"]),
        optim=adam, 
        loss=RenyiELBO(),
    )
    rng_key, rng_key_samp, rng_key_init = random.split(args["rng_key"],3)
    #(num_samples, num_regions) 
    init_batch = agg_gp_predictive(rng_key_samp, args)["gp_aggr"] #(num_samples, num_regions) <- i.e (5,58)
    svi_state = svi.init(rng_key_init, init_batch)

    test_loss_list = []

    for i in range(args["num_epochs"]):
        rng_key, rng_key_train, rng_key_test, rng_key_infer = random.split(rng_key, 4)
        t_start = time.time()
        num_train = 1000
        # Where forward/backward pass gets called for train
        train_loss , svi_state = epoch_train(rng_key_train, svi_state, num_train)
        num_test = 1000
        # Where forward/backward pass gets called for test
        test_loss = eval_test(rng_key_test, svi_state, num_test)
        test_loss_list += [test_loss]

        print("Epoch : {}, train loss : {:.2f}, test loss : {:.2f} ({:.2f} s.)".format(i, train_loss, test_loss, time.time() - t_start))
        if np.isnan(test_loss):
            break 

Epoch : 0, train loss : 5361181116071936.00, test loss : 3631308800.00 (3.51 s.)
Epoch : 1, train loss : 252616302919680.00, test loss : 65434256.00 (0.00 s.)
Epoch : 2, train loss : 62537428992.00, test loss : 64563804.00 (0.00 s.)
Epoch : 3, train loss : 55318867968.00, test loss : 60580188.00 (0.00 s.)
Epoch : 4, train loss : 59868250112.00, test loss : 48440264.00 (0.00 s.)
Epoch : 5, train loss : 51777458176.00, test loss : 55431344.00 (0.00 s.)
Epoch : 6, train loss : 53629333504.00, test loss : 49645152.00 (0.00 s.)
Epoch : 7, train loss : 48495988736.00, test loss : 49497408.00 (0.00 s.)
Epoch : 8, train loss : 50514984960.00, test loss : 51984244.00 (0.00 s.)
Epoch : 9, train loss : 52616507392.00, test loss : 52535156.00 (0.00 s.)
Epoch : 10, train loss : 55186604032.00, test loss : 47850504.00 (0.00 s.)
Epoch : 11, train loss : 48830312448.00, test loss : 52707864.00 (0.00 s.)
Epoch : 12, train loss : 54355124224.00, test loss : 52063188.00 (0.00 s.)
Epoch : 13, train loss :

In [30]:
#extract the decoder
vae_params = svi.get_params(svi_state)

## Save the decoder

In [31]:
# Get script directory
script_dir = os.getcwd()  # Get current working directory

# Define the correct save path inside model_weights/
save_dir = os.path.abspath(os.path.join(script_dir, "..", "model weights", "aggVAE"))
os.makedirs(save_dir, exist_ok=True)  # Ensure the directory exists

# Save decoder parameters
save_path = os.path.join(save_dir, f"aggVAE_e{args['num_epochs']}_h{args['hidden_dim']}_z{args['z_dim']}")

with open(save_path, "wb") as file:
    pickle.dump(vae_params, file)

print(f"Decoder parameters saved to {save_path}")

Decoder parameters saved to c:\Users\jessi\Documents\school\y4\s2\DSE4101\Individual\FYP codes\DSE_FYP\simulation study\model weights\aggVAE\aggVAE_e100_h50_z40
