In [1]:
!pip install numpyro

Defaulting to user installation because normal site-packages is not writeable


In [2]:
import os
import math
import numpy as np

import torch
import time

import itertools
import jax
import jax.numpy as jnp
from jax import random, lax, jit, ops
from jax.example_libraries import stax

import numpyro
from numpyro.infer import SVI, MCMC, NUTS, init_to_median, Predictive, RenyiELBO
import numpyro.distributions as dist

import geopandas as gpd
import plotly.express as px

from termcolor import colored

import pickle

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
#ensure this script runs on GPU 1
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # Use GPU 1

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Script B running on {device}")

Script B running on cpu


In [4]:
#define the functions necessary
def dist_euclid(x, z):
    """
    Computes Eucledian Distance Between Regions. This function is used by
    exp_sq_kernel function (kernel function for gaussian processes)
    """
    x = jnp.array(x) # (ngrid_pts, lat/lon) <- i.e (7304,2)
    z = jnp.array(z) # (ngrid_pts, lat/lon) <- i.e (7304,2)
    if len(x.shape)==1:
        x = x.reshape(x.shape[0], 1) #(2618,) -> (7304,1)
    if len(z.shape)==1:
        z = x.reshape(x.shape[0], 1) #(2618,) -> (7304,1)
    n_x, m = x.shape # 7304 , 2
    n_z, m_z = z.shape # 7304 , 2
    assert m == m_z
    delta = jnp.zeros((n_x,n_z)) #(ngrid_pts,ngrid_pts) <- i.e (7304,7304)
    for d in jnp.arange(m):
        x_d = x[:,d] #(ngrid_pts-lat/lon,) <- (7304,)
        z_d = z[:,d] #(ngrid_pts-lat/lon,) <- (7304,)
        delta += (x_d[:,jnp.newaxis] - z_d)**2 # (7304,7304)

    return jnp.sqrt(delta) #(7304,7304)

def exp_sq_kernel(x, z, var, length, noise, jitter=1.0e-4):
    dist = dist_euclid(x, z) #(7304, 7304)
    deltaXsq = jnp.power(dist/ length, 2.0)
    k = var * jnp.exp(-0.5 * deltaXsq)
    k += (noise + jitter) * jnp.eye(x.shape[0])
    return k # (ngrid_pts, ngrid_pts) <- (7304,7304)

def M_g(M, g):
    '''
    - $M$ is a matrix with binary entries $m_{ij},$ showing whether point $j$ is in polygon $i$
    - $g$ is a vector of GP draws over grid
    - $maltmul(M, g)$ gives a vector of sums over each polygon
    '''
    M = jnp.array(M)
    g = jnp.array(g).T
    return(jnp.matmul(M, g))

# AggVAE Model

## Function for Predictive Simulation (Prior)

In [5]:
def gp_aggr(args):
    x = args["x"]  # Spatial grid points: (num_grid_points, 2)
    gp_kernel = args["gp_kernel"]  # Gaussian Process kernel
    noise = args["noise"]
    jitter = args["jitter"]

    M_lo= args["M_lo"] # (9, 2618)
    M_hi = args["M_hi"] # (49, 2618),

    # GP hyperparameters
    kernel_length = numpyro.sample("kernel_length", args["kernel_length"])
    kernel_var = numpyro.sample("kernel_var", args["kernel_var"])

    # GP Kernel and Sample
    k = gp_kernel(x, x, kernel_var, kernel_length, noise, jitter)
    f = numpyro.sample("f", dist.MultivariateNormal(loc=jnp.zeros(x.shape[0]), covariance_matrix=k))  # (num_grid_points,)

    #aggregate f into gp_aggr according to indexing of (point in polygon)
    gp_aggr_lo = numpyro.deterministic("gp_aggr_lo", M_g(M_lo, f)) #(num_regions,) <- i.e (9,) for lo
    gp_aggr_hi = numpyro.deterministic("gp_aggr_hi", M_g(M_hi, f)) #(49,)
    gp_aggr = numpyro.deterministic("gp_aggr", jnp.concatenate([gp_aggr_lo, gp_aggr_hi])) #(58,)


    return gp_aggr

## Define the VAE

In [6]:
def vae_encoder(hidden_dim = 50, z_dim = 40):
    return stax.serial(
        #(num_samples, num_regions) -> (num_samples, hidden_dims)
        stax.Dense(hidden_dim, W_init = stax.randn()),
        stax.Elu,
        stax.FanOut(2),
        stax.parallel(
            # mean : (num_samples, hidden_dim) -> (num_samples, z_dim)
            stax.Dense(z_dim, W_init = stax.randn()), #(5,50)
            #std : (num_samples, hidden_dim) -> (num_samples, z_dim)
            stax.serial(stax.Dense(z_dim, W_init = stax.randn()), stax.Exp)
        )
    )

def vae_decoder(hidden_dim, out_dim):
    return stax.serial(
        # (num_samples, z_dim) -> (num_samples, hidden_dim)
        stax.Dense(hidden_dim, W_init = stax.randn()),
        stax.Elu,
        # (num_samples, hidden_dim) -> (num_samples, num_regions)
        stax.Dense(out_dim, W_init = stax.randn())
    )


def vae_model(batch, hidden_dim, z_dim):
    """This computes the decoder portion"""
    batch = jnp.reshape(batch, (batch.shape[0], -1)) # (num_samples, num_regions)
    batch_dim, out_dim = jnp.shape(batch)

    # vae-decoder in numpyro module
    decode = numpyro.module(
        name = "decoder",
        nn = vae_decoder(hidden_dim = hidden_dim, out_dim = out_dim),
        input_shape = (batch_dim, z_dim) #(5,40)
    )

    # Sample a univariate normal
    z = numpyro.sample(
        "z",
        dist.Normal(
            jnp.zeros((batch_dim,z_dim)),
            jnp.ones((batch_dim,z_dim))
            )
    )
    # Forward pass from decoder
    gen_loc = decode(z) #(num_regions,)
    obs = numpyro.sample(
        "obs",
        dist.Normal(gen_loc, args["vae_var"]),
        obs = batch
    ) #(num_samples, num_regions)
    return obs


def vae_guide(batch, hidden_dim, z_dim):
    """This computes the encoder portion"""
    batch = jnp.reshape(batch, (batch.shape[0], -1)) #(num_samples, num_regions)
    batch_dim, input_dim = jnp.shape(batch)# num_samples , num_regions

    # vae-encoder in numpyro module
    encode = numpyro.module(
        name = "encoder",
        nn = vae_encoder(hidden_dim=hidden_dim,z_dim = z_dim),
        input_shape = (batch_dim, input_dim) #(5,58)
    ) #(num_samples, num_regions) -> (num_samples, hidden_dims)

    # Samapling mu, sigma - Pretty much the forward pass
    z_loc, z_std = encode(batch) #mu : (num_samples, z_dim), sigma2 : (num_samples, z_dim)
    # Sample a value z based on mu and sigma
    z = numpyro.sample("z", dist.Normal(z_loc, z_std)) #(num_sample, z_dim)
    return z

## Train the VAE encoder

In [7]:
@jax.jit
def epoch_train(rng_key, svi_state, num_train):
    def body_fn(i, val):
        rng_key_i = jax.random.fold_in(rng_key, i) #Array(2,)
        rng_key_i, rng_key_ls, rng_key_var, rng_key_noise = jax.random.split(rng_key_i, 4) #Tuple(Array(2,) x 4)
        loss_sum, svi_state = val #val --svi_state

        batch = agg_gp_predictive(rng_key_i, args)["gp_aggr"] #(5,116) <- num_samples : 5, total_districts : 116
        #* svi is where the vae_model & vae_guide gets applied
        svi_state, loss = svi.update(svi_state, batch)
        loss_sum += loss / args["batch_size"]
        return loss_sum, svi_state

    return lax.fori_loop(lower = 0, upper = num_train, body_fun=body_fn, init_val=(0.0, svi_state))

@jax.jit
def eval_test(rng_key, svi_state, num_test):
    def body_fn(i, loss_sum):
        rng_key_i = jax.random.fold_in(rng_key, i)
        rng_key_i, rng_key_ls, rng_key_varm, rng_key_noise = jax.random.split(rng_key_i, 4)
        batch = agg_gp_predictive(rng_key_i, args)["gp_aggr"]
        #* svi is where the vae_model & vae_guide gets applied
        loss = svi.evaluate(svi_state, batch) / args["batch_size"]
        loss_sum += loss
        return loss_sum

    loss = lax.fori_loop(lower = 0, upper = num_test,body_fun =  body_fn, init_val = 0.0)
    loss = loss / num_test
    return loss

## Function to plot the GP

In [8]:
def plot_process(gp_draws):
    p = px.line()
    for i in range(len(gp_draws)):
        p.add_scatter(x = np.arange(gp_draws.shape[1]), y = gp_draws[i, :])

    p.update_traces(line_color = "black")
    p.update_layout(
        template = "plotly_white",
        xaxis_title = "region", yaxis_title = "num cases",
        showlegend = False)
    p.show()

## Load the variables

In [9]:
# Lat/Lon Values of artificial grid
x = np.load("../data/lat_lon_x_jkt.npy")
pol_pts_jkt_lo = np.load("../data/pol_pts_jkt_lo.npy")
pol_pts_jkt_hi = np.load("../data/pol_pts_jkt_hi.npy")

## Arguments

In [19]:
args = {
        "x": x,
        "gp_kernel": exp_sq_kernel,
        "noise": 1e-4,
        "M_lo": jnp.array(pol_pts_jkt_lo),
        "M_hi": jnp.array(pol_pts_jkt_hi),
        "jitter" : 1e-4,
        # VAE training
        "rng_key": random.PRNGKey(5),
        #common num_epochs 20-50
        "num_epochs": 20,
        #learning rate 0.0005 common choice, ADAM optimiser adapts the learning rate accordingly
        "learning_rate": 0.0005,
        #chosen to be 100 (no tune)
        "batch_size": 100,
        #change this to the optimal values after hyperparameter tuning
        "hidden_dim": 40,
        "z_dim": 40,
        #chosen to be 100 (no tune)
        "num_train": 100,
        "num_test":100,
        #variance set to 1 bc the latent variable prior distribution is assumed to be normal
        "vae_var": 1,
        "kernel_length": dist.LogNormal(-1, 0.5),
        "kernel_var": dist.LogNormal(0, 1)
    }


## Prior predictive simulation

In [11]:
rng_key, rng_key_ = random.split(random.PRNGKey(4))
agg_gp_predictive = Predictive(gp_aggr,num_samples = 5)
agg_gp_draws = agg_gp_predictive(rng_key_, args)["gp_aggr"] #(num_samples, num_regions)

In [12]:
# Plotting
plot_process(agg_gp_draws)

## Hyperparameter Tuning

In [13]:
# Define hyperparameter grid
hidden_dims = [20, 30, 40, 50]
z_dims = [20, 30, 40, 50]

# Store results
test_loss_results_final = {}
test_loss_results_mean = {}

num_train = args["num_train"]
num_test = args["num_test"]
num_epochs = args["num_epochs"]

for hidden_dim, z_dim in itertools.product(hidden_dims, z_dims):
    print(f"Training with hidden_dim={hidden_dim}, z_dim={z_dim}")

    args["hidden_dim"] = hidden_dim
    args["z_dim"] = z_dim

    # Initialize optimizer and SVI
    adam = numpyro.optim.Adam(step_size=args["learning_rate"])
    svi = SVI(
        vae_model,
        vae_guide,
        adam,
        RenyiELBO(),
        hidden_dim=hidden_dim,
        z_dim=z_dim
    )

    # Split RNG keys
    rng_key, rng_key_samp, rng_key_init = random.split(args["rng_key"], 3)
    init_batch = agg_gp_predictive(rng_key_samp, args)["gp_aggr"]

    # Initialize SVI state
    svi_state = svi.init(rng_key_init, init_batch)

    # Pre-allocate test loss array
    test_loss_list = jnp.zeros(num_epochs)

    # Training loop
    for epoch in range(num_epochs):
        rng_key, rng_key_train, rng_key_test = random.split(rng_key, 3)
        t_start = time.time()

        train_loss, svi_state = epoch_train(rng_key_train, svi_state, num_train)
        test_loss = eval_test(rng_key_test, svi_state, num_test)
        test_loss_list = test_loss_list.at[epoch].set(test_loss)

        print(f"Epoch: {epoch}, Train Loss: {train_loss:.2f}, Test Loss: {test_loss:.2f} ({time.time() - t_start:.2f} s)")

        if math.isnan(test_loss):  # Stop early if NaN
            print(f"NaN encountered at hidden_dim={hidden_dim}, z_dim={z_dim}. Skipping...")
            break  # Stop training if NaN occurs

    # Store results only if valid
    if not math.isnan(test_loss_list[-1]):
        test_loss_results_final[(hidden_dim, z_dim)] = test_loss_list[-1]
        test_loss_results_mean[(hidden_dim, z_dim)] = jnp.nanmean(test_loss_list)
        print(f"Final Test loss for hidden_dim={hidden_dim}, z_dim={z_dim}: {test_loss_list[-1]}")
        print(f"Mean Test loss for hidden_dim={hidden_dim}, z_dim={z_dim}: {jnp.nanmean(test_loss_list)}")

print("Grid search complete.")
print("Results:", test_loss_results_final)
print("Mean Test Loss Results:", test_loss_results_mean)

Training with hidden_dim=20, z_dim=20
Epoch: 0, Train Loss: 62690.86, Test Loss: 659.35 (3.63 s)
Epoch: 1, Train Loss: 39545.20, Test Loss: 267.84 (0.12 s)
Epoch: 2, Train Loss: 26016.62, Test Loss: 141.80 (0.13 s)
Epoch: 3, Train Loss: 13316.55, Test Loss: 98.49 (0.10 s)
Epoch: 4, Train Loss: 8194.10, Test Loss: 77.63 (0.11 s)
Epoch: 5, Train Loss: 7716.47, Test Loss: 59.24 (0.10 s)
Epoch: 6, Train Loss: 5440.21, Test Loss: 58.47 (0.09 s)
Epoch: 7, Train Loss: 5796.90, Test Loss: 60.20 (0.09 s)
Epoch: 8, Train Loss: 3981.10, Test Loss: 36.58 (0.09 s)
Epoch: 9, Train Loss: 4065.93, Test Loss: 35.60 (0.10 s)
Epoch: 10, Train Loss: 7679.67, Test Loss: 56.80 (0.10 s)
Epoch: 11, Train Loss: 13832.93, Test Loss: 1347.12 (0.09 s)
Epoch: 12, Train Loss: 3406.36, Test Loss: 26.75 (0.10 s)
Epoch: 13, Train Loss: 2806.83, Test Loss: 28.86 (0.10 s)
Epoch: 14, Train Loss: 2723.90, Test Loss: 30.63 (0.09 s)
Epoch: 15, Train Loss: 2407.02, Test Loss: 25.74 (0.09 s)
Epoch: 16, Train Loss: 2661.98, Te

In [14]:
print(test_loss_results_final)
print(test_loss_results_mean)

{(20, 20): Array(29.021763, dtype=float32), (20, 30): Array(29.71638, dtype=float32), (20, 40): Array(24.346157, dtype=float32), (20, 50): Array(29.911406, dtype=float32), (30, 20): Array(27.477472, dtype=float32), (30, 30): Array(54.17611, dtype=float32), (30, 40): Array(22.630238, dtype=float32), (30, 50): Array(31.427656, dtype=float32), (40, 20): Array(25.22492, dtype=float32), (40, 30): Array(0., dtype=float32), (40, 40): Array(21.912703, dtype=float32), (40, 50): Array(28.029053, dtype=float32), (50, 20): Array(22.802025, dtype=float32), (50, 30): Array(53239.55, dtype=float32), (50, 40): Array(28.411907, dtype=float32), (50, 50): Array(502.92773, dtype=float32)}
{(20, 20): Array(155.94492, dtype=float32), (20, 30): Array(104.37584, dtype=float32), (20, 40): Array(84.874245, dtype=float32), (20, 50): Array(92.755226, dtype=float32), (30, 20): Array(190.97612, dtype=float32), (30, 30): Array(77.080765, dtype=float32), (30, 40): Array(71.7529, dtype=float32), (30, 50): Array(80.850

### Optimal Hyperparams are 40 hidden dims and 40 latent dims


## Initiate Training Loop with optimal hyperparams

In [20]:
adam = numpyro.optim.Adam(step_size = args["learning_rate"])
svi = SVI(
        vae_model,
        vae_guide,
        adam,
        RenyiELBO(),
        hidden_dim = args["hidden_dim"],
        z_dim = args["z_dim"]
    )

rng_key, rng_key_samp, rng_key_init = random.split(args["rng_key"],3)
#(num_samples, num_regions)
init_batch = agg_gp_predictive(rng_key_, args)["gp_aggr"] #(num_samples, num_regions) <- i.e (5,58)
svi_state = svi.init(rng_key_init, init_batch)

test_loss_list = []

In [21]:
for i in range(args["num_epochs"]):
    rng_key, rng_key_train, rng_key_test, rng_key_infer = random.split(rng_key, 4)
    t_start = time.time()

    num_train = 1000
    # Where forward/backward pass gets called for train
    train_loss , svi_state = epoch_train(rng_key_train, svi_state, num_train)

    num_test = 1000

    # Where forward/backward pass gets called for test
    test_loss = eval_test(rng_key_test, svi_state, num_test)
    test_loss_list += [test_loss]

    print("Epoch : {}, train loss : {:.2f}, test loss : {:.2f} ({:.2f} s.)".format(i, train_loss, test_loss, time.time() - t_start))

    if math.isnan(test_loss):
        break

Epoch : 0, train loss : 141412.16, test loss : 29.42 (0.00 s.)
Epoch : 1, train loss : 41129124.00, test loss : 303.64 (0.00 s.)
Epoch : 2, train loss : 601856.00, test loss : 135.27 (0.00 s.)
Epoch : 3, train loss : 82521.20, test loss : 64.69 (0.00 s.)
Epoch : 4, train loss : 110382.38, test loss : 56.13 (0.00 s.)
Epoch : 5, train loss : 45075.00, test loss : 36.62 (0.00 s.)
Epoch : 6, train loss : 32760.90, test loss : 28.69 (0.00 s.)
Epoch : 7, train loss : 26884.22, test loss : 23.66 (0.00 s.)
Epoch : 8, train loss : 21530.72, test loss : 17.94 (0.00 s.)
Epoch : 9, train loss : 17382.08, test loss : 16.38 (0.00 s.)
Epoch : 10, train loss : 15925.57, test loss : 14.91 (0.00 s.)
Epoch : 11, train loss : 15186.05, test loss : 15.35 (0.00 s.)
Epoch : 12, train loss : 14599.62, test loss : 13.73 (0.00 s.)
Epoch : 13, train loss : 13660.88, test loss : 13.42 (0.00 s.)
Epoch : 14, train loss : 13085.31, test loss : 13.34 (0.00 s.)
Epoch : 15, train loss : 14449.29, test loss : 13.14 (0.0

In [22]:
#extract the decoder
decoder_params = svi.get_params(svi_state)

## Save the decoder

In [24]:
# Get script directory
script_dir = os.getcwd()  # Get current working directory

# Define the correct save path inside model_weights/
save_dir = os.path.abspath(os.path.join(script_dir, "..", "model weights", "aggVAE"))
os.makedirs(save_dir, exist_ok=True)  # Ensure the directory exists

# Save decoder parameters
save_path = os.path.join(save_dir, f"aggVAE_e{args['num_epochs']}_h{args['hidden_dim']}_z{args['z_dim']}")

with open(save_path, "wb") as file:
    pickle.dump(decoder_params, file)

print(f"Decoder parameters saved to {save_path}")

Decoder parameters saved to c:\Users\jessi\Documents\school\y4\s2\DSE4101\Individual\FYP codes\DSE_FYP\simulation study\model weights\aggVAE\aggVAE_e20_h40_z40
