In [10]:
!pip install numpyro

Collecting numpyro
  Downloading numpyro-0.18.0-py3-none-any.whl.metadata (37 kB)
Downloading numpyro-0.18.0-py3-none-any.whl (365 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/365.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m [32m358.4/365.8 kB[0m [31m12.2 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m365.8/365.8 kB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpyro
Successfully installed numpyro-0.18.0


In [11]:
import os
import math
import numpy as np

import torch
import time

import itertools
import jax
import jax.numpy as jnp
from jax import random, lax, jit, ops
from jax.example_libraries import stax

import numpyro
from numpyro.infer import SVI, MCMC, NUTS, init_to_median, Predictive, RenyiELBO
import numpyro.distributions as dist

import geopandas as gpd
import plotly.express as px

from termcolor import colored

import pickle

In [None]:
#ensure this script runs on GPU 1
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # Use GPU 1

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Script B running on {device}")

In [2]:
#define the functions necessary
def dist_euclid(x, z):
    """
    Computes Eucledian Distance Between Regions. This function is used by
    exp_sq_kernel function (kernel function for gaussian processes)
    """
    x = jnp.array(x) # (ngrid_pts, lat/lon) <- i.e (7304,2)
    z = jnp.array(z) # (ngrid_pts, lat/lon) <- i.e (7304,2)
    if len(x.shape)==1:
        x = x.reshape(x.shape[0], 1) #(2618,) -> (7304,1)
    if len(z.shape)==1:
        z = x.reshape(x.shape[0], 1) #(2618,) -> (7304,1)
    n_x, m = x.shape # 7304 , 2
    n_z, m_z = z.shape # 7304 , 2
    assert m == m_z
    delta = jnp.zeros((n_x,n_z)) #(ngrid_pts,ngrid_pts) <- i.e (7304,7304)
    for d in jnp.arange(m):
        x_d = x[:,d] #(ngrid_pts-lat/lon,) <- (7304,)
        z_d = z[:,d] #(ngrid_pts-lat/lon,) <- (7304,)
        delta += (x_d[:,jnp.newaxis] - z_d)**2 # (7304,7304)

    return jnp.sqrt(delta) #(7304,7304)

def exp_sq_kernel(x, z, var, length, noise, jitter=1.0e-4):
    dist = dist_euclid(x, z) #(7304, 7304)
    deltaXsq = jnp.power(dist/ length, 2.0)
    k = var * jnp.exp(-0.5 * deltaXsq)
    k += (noise + jitter) * jnp.eye(x.shape[0])
    return k # (ngrid_pts, ngrid_pts) <- (7304,7304)

def M_g(M, g):
    '''
    - $M$ is a matrix with binary entries $m_{ij},$ showing whether point $j$ is in polygon $i$
    - $g$ is a vector of GP draws over grid
    - $maltmul(M, g)$ gives a vector of sums over each polygon
    '''
    M = jnp.array(M)
    g = jnp.array(g).T
    return(jnp.matmul(M, g))

#AggVAE Model

##Function for Predictive Simulation (Prior)

In [3]:
def gp_aggr(args):
    x = args["x"]  # Spatial grid points: (num_grid_points, 2)
    gp_kernel = args["gp_kernel"]  # Gaussian Process kernel
    noise = args["noise"]
    jitter = args["jitter"]
    M = args["M"]  # (num_districts, num_grid_points) aggregation matrix

    # GP hyperparameters
    kernel_length = numpyro.sample("kernel_length", args["kernel_length"])
    kernel_var = numpyro.sample("kernel_var", args["kernel_var"])

    # GP Kernel and Sample
    k = gp_kernel(x, x, kernel_var, kernel_length, noise, jitter)
    f = numpyro.sample("f", dist.MultivariateNormal(loc=jnp.zeros(x.shape[0]), covariance_matrix=k))  # (num_grid_points,)

    # Aggregate GP values to district level
    gp_aggr = numpyro.deterministic("gp_aggr", M @ f)  # (num_districts,)

    return gp_aggr

##Define the VAE

In [4]:
def vae_encoder(hidden_dim = 50, z_dim = 40):
    return stax.serial(
        #(num_samples, num_regions) -> (num_samples, hidden_dims)
        stax.Dense(hidden_dim, W_init = stax.randn()),
        stax.Elu,
        stax.FanOut(2),
        stax.parallel(
            # mean : (num_samples, hidden_dim) -> (num_samples, z_dim)
            stax.Dense(z_dim, W_init = stax.randn()), #(5,50)
            #std : (num_samples, hidden_dim) -> (num_samples, z_dim)
            stax.serial(stax.Dense(z_dim, W_init = stax.randn()), stax.Exp)
        )
    )

def vae_decoder(hidden_dim, out_dim):
    return stax.serial(
        # (num_samples, z_dim) -> (num_samples, hidden_dim)
        stax.Dense(hidden_dim, W_init = stax.randn()),
        stax.Elu,
        # (num_samples, hidden_dim) -> (num_samples, num_regions)
        stax.Dense(out_dim, W_init = stax.randn())
    )


def vae_model(batch, hidden_dim, z_dim):
    """This computes the decoder portion"""
    batch = jnp.reshape(batch, (batch.shape[0], -1)) # (num_samples, num_regions)
    batch_dim, out_dim = jnp.shape(batch)

    # vae-decoder in numpyro module
    decode = numpyro.module(
        name = "decoder",
        nn = vae_decoder(hidden_dim = hidden_dim, out_dim = out_dim),
        input_shape = (batch_dim, z_dim) #(5,40)
    )

    # Sample a univariate normal
    z = numpyro.sample(
        "z",
        dist.Normal(
            jnp.zeros((batch_dim,z_dim)),
            jnp.ones((batch_dim,z_dim))
            )
    )
    # Forward pass from decoder
    gen_loc = decode(z) #(num_regions,)
    obs = numpyro.sample(
        "obs",
        dist.Normal(gen_loc, args["vae_var"]),
        obs = batch
    ) #(num_samples, num_regions)
    return obs


def vae_guide(batch, hidden_dim, z_dim):
    """This computes the encoder portion"""
    batch = jnp.reshape(batch, (batch.shape[0], -1)) #(num_samples, num_regions)
    batch_dim, input_dim = jnp.shape(batch)# num_samples , num_regions

    # vae-encoder in numpyro module
    encode = numpyro.module(
        name = "encoder",
        nn = vae_encoder(hidden_dim=hidden_dim,z_dim = z_dim),
        input_shape = (batch_dim, input_dim) #(5,58)
    ) #(num_samples, num_regions) -> (num_samples, hidden_dims)

    # Samapling mu, sigma - Pretty much the forward pass
    z_loc, z_std = encode(batch) #mu : (num_samples, z_dim), sigma2 : (num_samples, z_dim)
    # Sample a value z based on mu and sigma
    z = numpyro.sample("z", dist.Normal(z_loc, z_std)) #(num_sample, z_dim)
    return z

##Train the VAE encoder

In [5]:
@jax.jit
def epoch_train(rng_key, svi_state, num_train):
    def body_fn(i, val):
        rng_key_i = jax.random.fold_in(rng_key, i) #Array(2,)
        rng_key_i, rng_key_ls, rng_key_var, rng_key_noise = jax.random.split(rng_key_i, 4) #Tuple(Array(2,) x 4)
        loss_sum, svi_state = val #val --svi_state

        batch = agg_gp_predictive(rng_key_i, args)["gp_aggr"] #(5,116) <- num_samples : 5, total_districts : 116
        #* svi is where the vae_model & vae_guide gets applied
        svi_state, loss = svi.update(svi_state, batch)
        loss_sum += loss / args["batch_size"]
        return loss_sum, svi_state

    return lax.fori_loop(lower = 0, upper = num_train, body_fun=body_fn, init_val=(0.0, svi_state))

@jax.jit
def eval_test(rng_key, svi_state, num_test):
    def body_fn(i, loss_sum):
        rng_key_i = jax.random.fold_in(rng_key, i)
        rng_key_i, rng_key_ls, rng_key_varm, rng_key_noise = jax.random.split(rng_key_i, 4)
        batch = agg_gp_predictive(rng_key_i, args)["gp_aggr"]
        #* svi is where the vae_model & vae_guide gets applied
        loss = svi.evaluate(svi_state, batch) / args["batch_size"]
        loss_sum += loss
        return loss_sum

    loss = lax.fori_loop(lower = 0, upper = num_test,body_fun =  body_fn, init_val = 0.0)
    loss = loss / num_test
    return loss

##Function to plot the GP

In [6]:
def plot_process(gp_draws):
    p = px.line()
    for i in range(len(gp_draws)):
        p.add_scatter(x = np.arange(gp_draws.shape[1]), y = gp_draws[i, :])

    p.update_traces(line_color = "black")
    p.update_layout(
        template = "plotly_white",
        xaxis_title = "region", yaxis_title = "num cases",
        showlegend = False)
    p.show()

##Load the variables

In [12]:
# Lat/Lon Values of artificial grid
x = np.load("lat_lon_x_jkt.npy")
pol_pts_jkt = np.load("pol_pts_jkt.npy")
pt_which_pol_jkt = np.load("pt_which_pol_jkt.npy")

df_combined = gpd.read_file("jkt_combined_divisions.shp")

##Arguments

In [18]:
args = {
        "x": x,
        "gp_kernel": exp_sq_kernel,
        "noise": 1e-4,
        "M": jnp.array(pol_pts_jkt),
        "jitter" : 1e-4,
        # VAE training
        "rng_key": random.PRNGKey(5),
        #common num_epochs 20-50
        "num_epochs": 20,
        #learning rate 0.0005 common choice, ADAM optimiser adapts the learning rate accordingly
        "learning_rate": 0.0005,
        #chosen to be 100 (no tune)
        "batch_size": 100,
        #change this to the optimal values after hyperparameter tuning
        "hidden_dim": 30,
        "z_dim": 40,
        #chosen to be 100 (no tune)
        "num_train": 100,
        "num_test":100,
        #variance set to 1 bc the latent variable prior distribution is assumed to be normal
        "vae_var": 1,
        "kernel_length": dist.InverseGamma(3,3),
        "kernel_var": dist.HalfNormal(1e-5)
    }


##Prior predictive simulation

In [14]:
rng_key, rng_key_ = random.split(random.PRNGKey(4))
agg_gp_predictive = Predictive(gp_aggr,num_samples = 5)
agg_gp_draws = agg_gp_predictive(rng_key_, args)["gp_aggr"] #(num_samples, num_regions)

In [15]:
# Plotting
plot_process(agg_gp_draws)

##Hyperparameter Tuning

In [16]:
# Define hyperparameter grid
hidden_dims = [20, 30, 40, 50]
z_dims = [20, 30, 40, 50]

# Store results
test_loss_results_final = {}
test_loss_results_mean = {}

num_train = args["num_train"]
num_test = args["num_test"]
num_epochs = args["num_epochs"]

for hidden_dim, z_dim in itertools.product(hidden_dims, z_dims):
    print(f"Training with hidden_dim={hidden_dim}, z_dim={z_dim}")

    args["hidden_dim"] = hidden_dim
    args["z_dim"] = z_dim

    # Initialize optimizer and SVI
    adam = numpyro.optim.Adam(step_size=args["learning_rate"])
    svi = SVI(
        vae_model,
        vae_guide,
        adam,
        RenyiELBO(),
        hidden_dim=hidden_dim,
        z_dim=z_dim
    )

    # Split RNG keys
    rng_key, rng_key_samp, rng_key_init = random.split(args["rng_key"], 3)
    init_batch = agg_gp_predictive(rng_key_samp, args)["gp_aggr"]

    # Initialize SVI state
    svi_state = svi.init(rng_key_init, init_batch)

    # Pre-allocate test loss array
    test_loss_list = jnp.zeros(num_epochs)

    # Training loop
    for epoch in range(num_epochs):
        rng_key, rng_key_train, rng_key_test = random.split(rng_key, 3)
        t_start = time.time()

        train_loss, svi_state = epoch_train(rng_key_train, svi_state, num_train)
        test_loss = eval_test(rng_key_test, svi_state, num_test)
        test_loss_list = test_loss_list.at[epoch].set(test_loss)

        print(f"Epoch: {epoch}, Train Loss: {train_loss:.2f}, Test Loss: {test_loss:.2f} ({time.time() - t_start:.2f} s)")

        if math.isnan(test_loss):  # Stop early if NaN
            print(f"NaN encountered at hidden_dim={hidden_dim}, z_dim={z_dim}. Skipping...")
            break  # Stop training if NaN occurs

    # Store results only if valid
    if not math.isnan(test_loss_list[-1]):
        test_loss_results_final[(hidden_dim, z_dim)] = test_loss_list[-1]
        test_loss_results_mean[(hidden_dim, z_dim)] = jnp.nanmean(test_loss_list)
        print(f"Final Test loss for hidden_dim={hidden_dim}, z_dim={z_dim}: {test_loss_list[-1]}")
        print(f"Mean Test loss for hidden_dim={hidden_dim}, z_dim={z_dim}: {jnp.nanmean(test_loss_list)}")

print("Grid search complete.")
print("Results:", test_loss_results_final)
print("Mean Test Loss Results:", test_loss_results_mean)

Training with hidden_dim=20, z_dim=20
Epoch: 0, Train Loss: 92.06, Test Loss: 0.92 (9.48 s)
Epoch: 1, Train Loss: 92.07, Test Loss: 0.92 (0.28 s)
Epoch: 2, Train Loss: 92.08, Test Loss: 0.92 (0.27 s)
Epoch: 3, Train Loss: 92.08, Test Loss: 0.92 (0.49 s)
Epoch: 4, Train Loss: 92.09, Test Loss: 0.92 (0.42 s)
Epoch: 5, Train Loss: 92.09, Test Loss: 0.92 (0.50 s)
Epoch: 6, Train Loss: 92.05, Test Loss: 0.92 (0.74 s)
Epoch: 7, Train Loss: 92.09, Test Loss: 0.92 (0.67 s)
Epoch: 8, Train Loss: 92.08, Test Loss: 0.92 (0.65 s)
Epoch: 9, Train Loss: 92.05, Test Loss: 0.92 (0.52 s)
Epoch: 10, Train Loss: 92.05, Test Loss: 0.92 (0.46 s)
Epoch: 11, Train Loss: 92.11, Test Loss: 0.92 (0.38 s)
Epoch: 12, Train Loss: 92.13, Test Loss: 0.92 (0.29 s)
Epoch: 13, Train Loss: 92.05, Test Loss: 0.92 (0.27 s)
Epoch: 14, Train Loss: 92.10, Test Loss: 0.92 (0.26 s)
Epoch: 15, Train Loss: 92.06, Test Loss: 0.92 (0.30 s)
Epoch: 16, Train Loss: 92.05, Test Loss: 0.92 (0.26 s)
Epoch: 17, Train Loss: 92.07, Test Lo

In [17]:
print(test_loss_results_final)
print(test_loss_results_mean)

{(20, 20): Array(0.923106, dtype=float32), (20, 30): Array(0.92014986, dtype=float32), (20, 40): Array(0.9173188, dtype=float32), (20, 50): Array(0.9182265, dtype=float32), (30, 20): Array(0.92311615, dtype=float32), (30, 30): Array(0.9201259, dtype=float32), (30, 40): Array(0.9172175, dtype=float32), (30, 50): Array(0.91810465, dtype=float32), (40, 20): Array(0.923177, dtype=float32), (40, 30): Array(0.92018914, dtype=float32), (40, 40): Array(0.91723204, dtype=float32), (40, 50): Array(0.91817856, dtype=float32), (50, 20): Array(0.9231604, dtype=float32), (50, 30): Array(0.92014307, dtype=float32), (50, 40): Array(0.9172466, dtype=float32), (50, 50): Array(0.9181372, dtype=float32)}
{(20, 20): Array(0.9207327, dtype=float32), (20, 30): Array(0.9213196, dtype=float32), (20, 40): Array(0.9207714, dtype=float32), (20, 50): Array(0.9211856, dtype=float32), (30, 20): Array(0.9207401, dtype=float32), (30, 30): Array(0.9213139, dtype=float32), (30, 40): Array(0.92079985, dtype=float32), (30

### Optimal Hyperparams are 30 hidden dims and 40 latent dims


##Initiate Training Loop with optimal hyperparams

In [19]:
adam = numpyro.optim.Adam(step_size = args["learning_rate"])
svi = SVI(
        vae_model,
        vae_guide,
        adam,
        RenyiELBO(),
        hidden_dim = args["hidden_dim"],
        z_dim = args["z_dim"]
    )

rng_key, rng_key_samp, rng_key_init = random.split(args["rng_key"],3)
#(num_samples, num_regions)
init_batch = agg_gp_predictive(rng_key_, args)["gp_aggr"] #(num_samples, num_regions) <- i.e (5,58)
svi_state = svi.init(rng_key_init, init_batch)

test_loss_list = []

In [20]:
for i in range(args["num_epochs"]):
    rng_key, rng_key_train, rng_key_test, rng_key_infer = random.split(rng_key, 4)
    t_start = time.time()

    num_train = 1000
    # Where forward/backward pass gets called for train
    train_loss , svi_state = epoch_train(rng_key_train, svi_state, num_train)

    num_test = 1000

    # Where forward/backward pass gets called for test
    test_loss = eval_test(rng_key_test, svi_state, num_test)
    test_loss_list += [test_loss]

    print("Epoch : {}, train loss : {:.2f}, test loss : {:.2f} ({:.2f} s.)".format(i, train_loss, test_loss, time.time() - t_start))

    if math.isnan(test_loss):
        break

Epoch : 0, train loss : 920.80, test loss : 0.92 (0.00 s.)
Epoch : 1, train loss : 920.90, test loss : 0.92 (0.00 s.)
Epoch : 2, train loss : 920.93, test loss : 0.92 (0.00 s.)
Epoch : 3, train loss : 920.79, test loss : 0.92 (0.00 s.)
Epoch : 4, train loss : 920.90, test loss : 0.92 (0.00 s.)
Epoch : 5, train loss : 920.90, test loss : 0.92 (0.00 s.)
Epoch : 6, train loss : 920.93, test loss : 0.92 (0.00 s.)
Epoch : 7, train loss : 920.90, test loss : 0.93 (0.00 s.)
Epoch : 8, train loss : 920.93, test loss : 0.92 (0.00 s.)
Epoch : 9, train loss : 920.82, test loss : 0.92 (0.00 s.)
Epoch : 10, train loss : 920.98, test loss : 0.92 (0.00 s.)
Epoch : 11, train loss : 920.87, test loss : 0.92 (0.00 s.)
Epoch : 12, train loss : 920.94, test loss : 0.92 (0.00 s.)
Epoch : 13, train loss : 920.93, test loss : 0.92 (0.00 s.)
Epoch : 14, train loss : 920.90, test loss : 0.92 (0.00 s.)
Epoch : 15, train loss : 920.89, test loss : 0.92 (0.00 s.)
Epoch : 16, train loss : 920.83, test loss : 0.92 

In [21]:
#extract the decoder
decoder_params = svi.get_params(svi_state)

##Save the decoder

In [24]:
# Get script directory
script_dir = os.getcwd()  # Get current working directory

# Define the correct save path inside model_weights/
save_dir = os.path.abspath(os.path.join(script_dir, "..", model_weights", "aggVAE"))
os.makedirs(save_dir, exist_ok=True)  # Ensure the directory exists

# Save decoder parameters
save_path = os.path.join(save_dir, f"aggVAE_e{args['num_epochs']}_h{args['hidden_dim']}_z{args['z_dim']}")

with open(save_path, "wb") as file:
    pickle.dump(decoder_params, file)

print(f"Decoder parameters saved to {save_path}")

Decoder parameters saved to /content/model_weights/aggVAE/aggVAE_e20_h30_z40
