In [2]:
import os
import math
import numpy as np

import torch
import time

import itertools
import jax
import jax.numpy as jnp
from jax import random, lax, jit, ops
from jax.example_libraries import stax

import numpyro
from numpyro.infer import SVI, MCMC, NUTS, init_to_median, Predictive, RenyiELBO
import numpyro.distributions as dist

import geopandas as gpd
import pandas as pd
import plotly.express as px

from termcolor import colored

import pickle

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
#define the functions necessary
def dist_euclid(x, z):
    """
    Compute the pairwise Euclidean distance between two sets of points.

    This function is used by exp_sq_kernel (the kernel function for the
    Gaussian process).

    Parameters
    ----------
    x : array-like, shape (n_x, m) or (n_x,)
        First set of points. A 1-D input is treated as n_x points with a
        single coordinate.
    z : array-like, shape (n_z, m) or (n_z,)
        Second set of points, with the same number of coordinates as x.

    Returns
    -------
    jnp.ndarray, shape (n_x, n_z)
        Entry (i, j) is the Euclidean distance between x[i] and z[j].

    Raises
    ------
    AssertionError
        If x and z do not have the same number of coordinates.
    """
    x = jnp.array(x)
    z = jnp.array(z)
    if len(x.shape) == 1:
        x = x.reshape(x.shape[0], 1)  # (n_x,) -> (n_x, 1)
    if len(z.shape) == 1:
        # Bug fix: the original reshaped `x` here, which silently replaced a
        # 1-D `z` by a copy of `x`.
        z = z.reshape(z.shape[0], 1)  # (n_z,) -> (n_z, 1)
    n_x, m = x.shape
    n_z, m_z = z.shape
    assert m == m_z
    # Accumulate squared differences one coordinate at a time so that only
    # (n_x, n_z) arrays are ever materialised.
    delta = jnp.zeros((n_x, n_z))
    for d in range(m):
        x_d = x[:, d]  # (n_x,)
        z_d = z[:, d]  # (n_z,)
        delta += (x_d[:, jnp.newaxis] - z_d) ** 2  # (n_x, n_z)

    return jnp.sqrt(delta)  # (n_x, n_z)

def exp_sq_kernel(x, z, var, length, noise, jitter=1.0e-4):
    """
    Squared-exponential covariance between the points in x and z.

    Evaluates var * exp(-0.5 * (d / length)^2), where d is the pairwise
    distance returned by dist_euclid(x, z), and then adds (noise + jitter)
    times an identity matrix of size x.shape[0].

    x must expose a .shape attribute. Because the identity has one row per
    point in x, the function is meant to be called with z holding as many
    points as x (as in gp_aggr, which passes the same grid twice).
    """
    scaled_sq_dist = jnp.power(dist_euclid(x, z) / length, 2.0)
    diagonal_term = (noise + jitter) * jnp.eye(x.shape[0])
    return var * jnp.exp(-0.5 * scaled_sq_dist) + diagonal_term

def M_g(M, g):
    '''
    Aggregate GP draws on the grid to polygon level.

    - $M$ is a matrix with binary entries $m_{ij}$, showing whether point $j$ is in polygon $i$
    - $g$ is a vector of GP draws over the grid (it is transposed before the product)
    - $matmul(M, g^T)$ gives the sums over each polygon
    '''
    membership = jnp.array(M)
    draws_t = jnp.array(g).T
    return membership @ draws_t

# AggVAE Model

## Function for predictive simulation (prior)

In [4]:
def gp_aggr(args):
    """
    NumPyro model: draw a GP over the grid and aggregate it with args["M"].

    Reads "x", "gp_kernel", "noise", "jitter" and "M" from args, plus the prior
    distributions stored under "kernel_length" and "kernel_var".

    Sample sites: "kernel_length", "kernel_var", "f".
    Deterministic site (also the return value): "gp_aggr" = M @ f.
    """
    x, gp_kernel, noise, jitter, M = (
        args[key] for key in ("x", "gp_kernel", "noise", "jitter", "M")
    )

    # Hyperparameters are sampled in the same order as before so that seeded
    # runs consume random keys identically.
    kernel_length = numpyro.sample("kernel_length", args["kernel_length"])
    kernel_var = numpyro.sample("kernel_var", args["kernel_var"])

    # Covariance of the grid with itself, then one joint draw over all grid points.
    covariance = gp_kernel(x, x, kernel_var, kernel_length, noise, jitter)
    gp_prior = dist.MultivariateNormal(
        loc=jnp.zeros(x.shape[0]),
        covariance_matrix=covariance,
    )
    f = numpyro.sample("f", gp_prior)  # one value per grid point

    # Sum/weight the grid values into one value per row of M.
    return numpyro.deterministic("gp_aggr", M @ f)

## Read in the training dataset

In [5]:
# Load the processed dataset. It is expected to contain a 'split' column
# (values 'train' / 'test') and the 'Cases' and 'Population' columns used below.
df = pd.read_csv("../data/processed/df_aggVAE_rf_split.csv")

In [6]:
# Keep only the rows whose 'split' label is 'train'.
df_train = df.loc[df['split'].eq('train')]

In [7]:
# Keep only the rows whose 'split' label is 'test'.
df_test = df.loc[df['split'].eq('test')]

In [8]:
# Write the train and test subsets to disk.
# Note: to_csv is called without index=False, so the DataFrame index is written
# as an extra leading column in both files.
df_train.to_csv('../data/processed/df_train.csv')
df_test.to_csv('../data/processed/df_test.csv')

In [9]:
# Build the feature matrices once, ahead of the training loop, so the jitted
# training/evaluation functions can index plain JAX arrays instead of DataFrames.
# Each matrix has one row per observation and two columns: Cases, Population.
train_features = jnp.array(df_train[['Cases', 'Population']].to_numpy(), dtype=jnp.float32)
test_features = jnp.array(df_test[['Cases', 'Population']].to_numpy(), dtype=jnp.float32)

## Define the VAE

In [10]:
def vae_encoder(hidden_dim = 50, z_dim = 40):
    """
    Build the stax encoder network.

    Architecture: Dense(hidden_dim) -> Elu -> two parallel heads
      * location head: Dense(z_dim)
      * scale head:    Dense(z_dim) followed by Exp, so its output is positive

    Returns the (init_fn, apply_fn) pair produced by stax.serial; applying the
    network yields the two head outputs [location, scale], each with z_dim
    columns.
    """
    def dense(width):
        # Every dense layer in the encoder uses the same weight initialiser.
        return stax.Dense(width, W_init = stax.randn())

    location_head = dense(z_dim)
    scale_head = stax.serial(dense(z_dim), stax.Exp)

    return stax.serial(
        dense(hidden_dim),   # (batch, n_features) -> (batch, hidden_dim)
        stax.Elu,
        stax.FanOut(2),      # duplicate the hidden activations for the two heads
        stax.parallel(location_head, scale_head),
    )

def vae_decoder(hidden_dim, out_dim):
    """
    Build the stax decoder network: Dense(hidden_dim) -> Elu -> Dense(out_dim).

    Returns the (init_fn, apply_fn) pair produced by stax.serial. The network
    maps (batch, z_dim) latent inputs to (batch, out_dim) outputs.
    """
    layers = (
        stax.Dense(hidden_dim, W_init = stax.randn()),  # latent -> hidden
        stax.Elu,
        stax.Dense(out_dim, W_init = stax.randn()),     # hidden -> output
    )
    return stax.serial(*layers)


def vae_model(batch, hidden_dim, z_dim):
    """
    Generative (decoder) half of the VAE, used as the SVI model.

    The batch is flattened to (batch_dim, out_dim). A latent "z" of shape
    (batch_dim, z_dim) is drawn from a standard normal, decoded to a mean of
    shape (batch_dim, out_dim), and the flattened batch is observed under a
    Normal centred on that mean.

    Note: the observation Normal takes args["vae_var"] (a notebook-level
    global) as its second argument, i.e. as the scale of the distribution.
    """
    flat_batch = jnp.reshape(batch, (batch.shape[0], -1))
    n_rows, n_features = jnp.shape(flat_batch)
    latent_shape = (n_rows, z_dim)

    # Register the decoder network; its output width follows the data.
    decode = numpyro.module(
        name = "decoder",
        nn = vae_decoder(hidden_dim = hidden_dim, out_dim = n_features),
        input_shape = latent_shape,
    )

    # Standard-normal prior over every latent coordinate.
    latent_prior = dist.Normal(jnp.zeros(latent_shape), jnp.ones(latent_shape))
    z = numpyro.sample("z", latent_prior)

    # Decode the latent draw and score the observed batch against it.
    likelihood = dist.Normal(decode(z), args["vae_var"])
    return numpyro.sample("obs", likelihood, obs = flat_batch)


def vae_guide(batch, hidden_dim, z_dim):
    """
    Inference (encoder) half of the VAE, used as the SVI guide.

    The batch is flattened to (batch_dim, input_dim) and passed through the
    encoder, which returns a location and a positive scale of shape
    (batch_dim, z_dim). The latent "z" is then sampled from the Normal with
    those parameters and returned.
    """
    flat_batch = jnp.reshape(batch, (batch.shape[0], -1))

    # Register the encoder network; its input width follows the data.
    encode = numpyro.module(
        name = "encoder",
        nn = vae_encoder(hidden_dim = hidden_dim, z_dim = z_dim),
        input_shape = jnp.shape(flat_batch),
    )

    # Forward pass: per-row parameters of the approximate posterior.
    z_loc, z_std = encode(flat_batch)
    return numpyro.sample("z", dist.Normal(z_loc, z_std))

## Train the VAE (encoder and decoder)

In [11]:
@jax.jit
def epoch_train(rng_key, svi_state, num_train):
    """
    Run num_train SVI update steps on random mini-batches of train_features.

    Each step derives its own key from rng_key and the step index, draws
    args["batch_size"] row indices with jax.random.choice, and applies
    svi.update.

    Returns (loss_sum, svi_state). loss_sum is the SUM over all steps of
    (step loss / batch_size); unlike eval_test it is not divided by the number
    of steps.

    The function reads the notebook globals svi, args and train_features when
    it is traced.
    """
    batch_size = args["batch_size"]

    def train_step(step, carry):
        running_loss, state = carry
        # Only the first of the four sub-keys is used; the split is kept so the
        # random stream matches the original implementation.
        step_key = jax.random.split(jax.random.fold_in(rng_key, step), 4)[0]

        # Random row indices into the real training data.
        rows = jax.random.choice(step_key, train_features.shape[0], (batch_size,))

        # One optimisation step of the model/guide pair on this mini-batch.
        state, loss = svi.update(state, train_features[rows])
        return running_loss + loss / batch_size, state

    return lax.fori_loop(0, num_train, train_step, (0.0, svi_state))

@jax.jit
def eval_test(rng_key, svi_state, num_test):
    """
    Average the SVI loss over num_test random mini-batches of test_features.

    Each step derives its own key from rng_key and the step index, draws
    args["batch_size"] row indices with jax.random.choice, and evaluates
    svi.evaluate / batch_size without updating svi_state.

    Returns the mean of those per-step values.

    The function reads the notebook globals svi, args and test_features when
    it is traced.
    """
    batch_size = args["batch_size"]

    def eval_step(step, running_loss):
        # Only the first of the four sub-keys is used; the split is kept so the
        # random stream matches the original implementation.
        step_key = jax.random.split(jax.random.fold_in(rng_key, step), 4)[0]

        # Random row indices into the held-out data.
        rows = jax.random.choice(step_key, test_features.shape[0], (batch_size,))

        step_loss = svi.evaluate(svi_state, test_features[rows]) / batch_size
        return running_loss + step_loss

    total_loss = lax.fori_loop(0, num_test, eval_step, 0.0)
    return total_loss / num_test

## Function to plot the GP

In [12]:
def plot_process(gp_draws):
    """
    Plot every row of gp_draws as a black line and show the figure.

    gp_draws is indexed as a 2-D array: one line per row, with the column
    index on the x-axis.
    """
    fig = px.line()
    for row in range(len(gp_draws)):
        fig.add_scatter(x = np.arange(gp_draws.shape[1]), y = gp_draws[row, :])

    # Uniform styling for all traces, then render.
    fig.update_traces(line_color = "black").update_layout(
        template = "plotly_white",
        xaxis_title = "region",
        yaxis_title = "num cases",
        showlegend = False,
    ).show()

## Load the variables

In [13]:
# Lat/Lon Values of artificial grid
# Coordinates of the artificial grid; passed to the GP kernel as args["x"].
x = np.load("../data/processed/lat_lon_x_all.npy")
# Polygon/grid-point matrix; becomes the aggregation matrix args["M"].
pol_pts_all = np.load("../data/processed/pol_pts_all.npy")
# NOTE(review): presumably the polygon index of each grid point; it is not
# referenced in the cells shown in this notebook section.
pt_which_pol_all = np.load("../data/processed/pt_which_pol_all.npy")

# NOTE(review): shapefile of the combined divisions; not referenced in the
# cells shown in this notebook section.
df_combined = gpd.read_file("../data/processed/final_combined_divisions/final_combined_divisions.shp")

## Arguments

In [14]:
# Single configuration dictionary shared by the GP prior model (gp_aggr),
# the VAE model (reads "vae_var") and the jitted train/eval functions
# (read "batch_size").
args = {
        # --- GP prior ---
        # grid coordinates
        "x": x,
        # covariance function called as gp_kernel(x, x, var, length, noise, jitter)
        "gp_kernel": exp_sq_kernel,
        # noise and jitter are both added to the diagonal of the kernel matrix
        "noise": 1e-4,
        # aggregation matrix applied to the GP draw (M @ f)
        "M": jnp.array(pol_pts_all),
        "jitter" : 1e-4,
        # --- VAE training ---
        # base key; both the grid search and the final training split it
        "rng_key": random.PRNGKey(5),
        # common choice for num_epochs: 20-50
        "num_epochs": 20,
        # 0.0005 is a common choice; the Adam optimiser adapts the step size
        "learning_rate": 0.0005,
        # mini-batch size, fixed at 100 (not tuned)
        "batch_size": 100,
        # set these to the optimal values found by the hyperparameter search
        "hidden_dim": 50,
        "z_dim": 20,
        # steps per epoch used by the grid search, fixed at 100 (not tuned);
        # the final training loop sets its own num_train / num_test
        "num_train": 100,
        "num_test":100,
        # passed as the second (scale) argument of the observation Normal in vae_model
        "vae_var": 1,
        # priors for the GP hyperparameters sampled in gp_aggr
        "kernel_length": dist.InverseGamma(3,3),
        # NOTE(review): a HalfNormal scale of 1e-5 makes the kernel variance far
        # smaller than noise + jitter (2e-4) - confirm this is intended
        "kernel_var": dist.HalfNormal(1e-5)
    }


## Prior predictive simulation

In [15]:
# Prior predictive check: no data is conditioned on, so Predictive simply runs
# gp_aggr forward 5 times and we keep the aggregated values of each run.
rng_key, rng_key_ = random.split(random.PRNGKey(4))
agg_gp_predictive = Predictive(gp_aggr,num_samples = 5)
agg_gp_draws = agg_gp_predictive(rng_key_, args)["gp_aggr"] # one row per prior draw, one column per row of args["M"]

In [None]:
# Plot the prior predictive draws: one line per draw, one x position per aggregated region.
plot_process(agg_gp_draws)

: 

## Hyperparameter tuning (z_dim, h_dim) using grid search

In [None]:
# Hyperparameter grid: every (hidden_dim, z_dim) combination is trained from scratch.
hidden_dims = [20, 30, 40, 50]
z_dims = [20, 30, 40, 50]

# Results keyed by (hidden_dim, z_dim). Only runs that completed without a NaN
# test loss are recorded.
test_loss_results_final = {}
test_loss_results_mean = {}

num_train = args["num_train"]
num_test = args["num_test"]
num_epochs = args["num_epochs"]

# The initialisation batch is the same for every grid point, so build it once.
# It only needs the same feature layout as the batches drawn by epoch_train.
init_batch = train_features

for hidden_dim, z_dim in itertools.product(hidden_dims, z_dims):
    print(f"Training with hidden_dim={hidden_dim}, z_dim={z_dim}")

    # args["hidden_dim"] / args["z_dim"] are deliberately NOT overwritten here:
    # the final training cell reads them, and mutating them would leave the
    # last grid point (not the selected one) in args after the search.

    # Fresh optimiser and SVI object for this combination.
    # NOTE(review): epoch_train / eval_test are jitted and read the global `svi`
    # when traced; this relies on each grid point having distinct parameter
    # shapes so that a new trace (and the new `svi`) is used - confirm if the
    # grid ever contains repeated combinations.
    adam = numpyro.optim.Adam(step_size=args["learning_rate"])
    svi = SVI(
        vae_model,
        vae_guide,
        adam,
        RenyiELBO(),
        hidden_dim=hidden_dim,
        z_dim=z_dim
    )

    # Every combination starts from the same base key so runs are comparable.
    rng_key, rng_key_samp, rng_key_init = random.split(args["rng_key"], 3)
    svi_state = svi.init(rng_key_init, init_batch)

    # Test loss of each completed epoch, as plain Python floats.
    test_loss_list = []
    # Set when a NaN test loss aborts the run. The previous implementation
    # checked a pre-allocated zero array, so an aborted run was stored with a
    # "final loss" of 0.0.
    diverged = False

    for epoch in range(num_epochs):
        rng_key, rng_key_train, rng_key_test = random.split(rng_key, 3)
        t_start = time.time()

        train_loss, svi_state = epoch_train(rng_key_train, svi_state, num_train)
        train_loss = float(train_loss)
        test_loss = float(eval_test(rng_key_test, svi_state, num_test))
        test_loss_list.append(test_loss)

        print(f"Epoch: {epoch}, Train Loss: {train_loss:.2f}, Test Loss: {test_loss:.2f} ({time.time() - t_start:.2f} s)")

        if math.isnan(test_loss):  # Stop early if NaN
            print(f"NaN encountered at hidden_dim={hidden_dim}, z_dim={z_dim}. Skipping...")
            diverged = True
            break

    # Store results only for runs that finished every epoch with finite losses.
    if test_loss_list and not diverged:
        final_test_loss = test_loss_list[-1]
        mean_test_loss = float(np.mean(test_loss_list))
        test_loss_results_final[(hidden_dim, z_dim)] = final_test_loss
        test_loss_results_mean[(hidden_dim, z_dim)] = mean_test_loss
        print(f"Final Test loss for hidden_dim={hidden_dim}, z_dim={z_dim}: {final_test_loss}")
        print(f"Mean Test loss for hidden_dim={hidden_dim}, z_dim={z_dim}: {mean_test_loss}")

print("Grid search complete.")
print("Results:", test_loss_results_final)
print("Mean Test Loss Results:", test_loss_results_mean)

Training with hidden_dim=20, z_dim=20


### Optimal hyperparameters: hidden_dim = 50, z_dim = 20

## Initiate Training Loop with optimal hyperparameters 

In [None]:
# Optimiser and SVI object for the final run. hidden_dim and z_dim are read
# from args at this point, so args must hold the selected hyperparameters.
adam = numpyro.optim.Adam(step_size = args["learning_rate"])
svi = SVI(
        vae_model,
        vae_guide,
        adam,
        RenyiELBO(),
        hidden_dim = args["hidden_dim"],
        z_dim = args["z_dim"]
    )

rng_key, rng_key_samp, rng_key_init = random.split(args["rng_key"],3)
# Initialise with a batch that has the same feature layout as the mini-batches
# that epoch_train / eval_test draw from train_features / test_features.
# The encoder input width and decoder output width are fixed by this batch;
# initialising from the aggregated GP draws (one column per region) would create
# parameters whose shapes do not match the (batch_size, 2) training batches.
init_batch = train_features[: args["batch_size"]]
svi_state = svi.init(rng_key_init, init_batch)

# Test loss of every epoch of the final run.
test_loss_list = []

In [None]:
# Steps per epoch for the final run (these override args["num_train"] / args["num_test"]).
num_train = 1000
num_test = 1000

for i in range(args["num_epochs"]):
    # Four-way split kept as-is so the random stream is unchanged; rng_key_infer is unused.
    rng_key, rng_key_train, rng_key_test, rng_key_infer = random.split(rng_key, 4)
    t_start = time.time()

    # Optimisation steps on the training data.
    train_loss , svi_state = epoch_train(rng_key_train, svi_state, num_train)

    # Loss on held-out data with the updated parameters (no parameter update).
    test_loss = eval_test(rng_key_test, svi_state, num_test)
    test_loss_list.append(test_loss)

    print("Epoch : {}, train loss : {:.2f}, test loss : {:.2f} ({:.2f} s.)".format(i, train_loss, test_loss, time.time() - t_start))

    # Abort as soon as the loss becomes NaN.
    if math.isnan(test_loss):
        break

Epoch : 0, train loss : 4832.24, test loss : 4.76 (214.08 s.)
Epoch : 1, train loss : 4780.85, test loss : 4.72 (213.74 s.)
Epoch : 2, train loss : 4741.84, test loss : 4.73 (213.64 s.)
Epoch : 3, train loss : 4727.01, test loss : 4.74 (213.71 s.)
Epoch : 4, train loss : 4732.86, test loss : 4.73 (213.87 s.)
Epoch : 5, train loss : 4720.53, test loss : 4.71 (214.00 s.)
Epoch : 6, train loss : 4708.59, test loss : 4.70 (214.04 s.)
Epoch : 7, train loss : 4708.42, test loss : 4.69 (213.91 s.)
Epoch : 8, train loss : 4707.25, test loss : 4.72 (213.81 s.)
Epoch : 9, train loss : 4706.12, test loss : 4.71 (213.74 s.)
Epoch : 10, train loss : 4701.12, test loss : 4.70 (213.68 s.)
Epoch : 11, train loss : 4697.85, test loss : 4.70 (213.65 s.)
Epoch : 12, train loss : 4694.84, test loss : 4.70 (213.67 s.)
Epoch : 13, train loss : 4698.27, test loss : 4.69 (213.84 s.)
Epoch : 14, train loss : 4699.49, test loss : 4.69 (213.99 s.)
Epoch : 15, train loss : 4700.86, test loss : 4.69 (214.07 s.)
Ep

In [None]:
# Fetch the current parameter values from the SVI state.
# Note: get_params returns every parameter tracked by the SVI state, not only
# the decoder's. NOTE(review): with the two numpyro.module calls in
# vae_model / vae_guide this presumably holds both the encoder and the decoder
# weights - confirm the keys before loading this object downstream.
decoder_params = svi.get_params(svi_state)

## Save the decoder

In [None]:
# Base directory for the output. os.getcwd() is the kernel's current working
# directory, which is not necessarily the folder containing this notebook.
script_dir = os.getcwd()  # Get current working directory

# Resolve <cwd>/../model_weights/aggVAE and create it if needed.
save_dir = os.path.abspath(os.path.join(script_dir, "..", "model_weights", "aggVAE"))
os.makedirs(save_dir, exist_ok=True)  # Ensure the directory exists

# File name encodes the run configuration (epochs, hidden_dim, z_dim); no file extension is added.
save_path = os.path.join(save_dir, f"aggVAE_e{args['num_epochs']}_h{args['hidden_dim']}_z{args['z_dim']}")

# Serialise the parameter object returned by svi.get_params with pickle.
# Only unpickle such files from trusted sources: pickle.load can execute arbitrary code.
with open(save_path, "wb") as file:
    pickle.dump(decoder_params, file)

print(f"Decoder parameters saved to {save_path}")

Decoder parameters saved to /home/jupyter-jwidyawati/model_weights/aggVAE/aggVAE_e20_h50_z20
