In [None]:
# Optional: switch JAX to 64-bit (double) precision for internal calculations.
# Trade-off: more accurate, but about half the speed. Whether it is necessary
# depends on the problem (to be discussed).
# JAX reads this flag at startup, so this cell must run before JAX is first
# imported in the kernel (restart the kernel after changing it).
from os import environ

environ.update({"JAX_ENABLE_X64": "True"})

In [None]:
import numpy as np
import pandas as pd
from scipy.stats import truncnorm
import matplotlib.pyplot as plt

from jax import random, lax, numpy as jnp

from emu_filter.efjax.state_model import ensemble_step
from emu_filter.efjax.importance import truncnorm_importance
from emu_filter.efjax.outputs import get_counts_from_particles, plot_particle_results, get_links_from_pedigree, plot_links

In [None]:
# ---- Model inputs ----
# Rates passed to `ensemble_step` on every filter step.
contact_rate, recovery_rate = 1.0, 0.1
# Total population and number initially infectious.
total_pop, inf_init = 1000, 12
# Ensemble size of the particle filter.
n_particles = 100

# ---- Observations ----
# One value per filter step, weighed against each particle's infectious count.
# Must be a JAX array (not NumPy): it is indexed with the traced step counter
# inside `lax.scan`, which NumPy arrays cannot be indexed by.
observations = jnp.array([0, 15, 40, 65, 124, 204, 252, 210])
# Standard deviation passed to `truncnorm_importance`.
target_sd = 50.0

In [None]:
# Initialise particles: every particle starts from the same state.
# Layout is (compartment, particle):
#   row 0 = susceptible, row 1 = infectious, row 2 = starts at 0 (presumably recovered)
init_state = np.tile(
    np.array([[total_pop - inf_init], [inf_init], [0]], dtype=np.int32),
    (1, n_particles),
)


In [None]:
# Single root PRNG key for the filter. Per-step keys are not pre-generated:
# each scan iteration splits the key it receives and carries one part forward.
k = random.PRNGKey(seed=0)

def normalise_weights(weights):
    """Normalise importance weights so they sum to one.

    Plain ``weights / weights.sum()`` produces NaNs when every particle gets
    zero weight (or the weights are NaN/inf). NaN probabilities would silently
    corrupt the resampling step. In that degenerate case, fall back to uniform
    weights so that the filter keeps running.

    Args:
        weights: 1-D array of non-negative importance weights.

    Returns:
        Array of the same shape whose entries sum to one.
    """
    total = weights.sum()
    usable = jnp.isfinite(total) & (total > 0)
    # Guard the denominator: `jnp.where` evaluates both branches.
    safe_total = jnp.where(usable, total, 1.0)
    return jnp.where(usable, weights / safe_total, 1.0 / max(weights.size, 1))


def step_particle_filter(carry, x):
    """Run one propagate / weight / resample step of the particle filter (a `lax.scan` body).

    Reads the notebook globals `contact_rate`, `recovery_rate`, `total_pop`,
    `observations`, `target_sd` and `n_particles`.

    Args:
        carry: tuple ``(particles, k)``. `particles` is the current ensemble,
            with compartments as rows and particles as columns. `k` is the PRNG
            key for this step.
        x: integer step index into `observations`.

    Returns:
        ``((resampled_particles, next_key), outputs)``, where `outputs` is a dict
        holding the proposed ensemble, the resampled ensemble, and the resampling
        indices ("pedigree") for this step.
    """
    particles, k = carry
    # Split the key for our 2 random operations (and one to pass on to the next iteration).
    k_next, k_ens, k_choice = random.split(k, 3)

    # Propagate the whole ensemble one step through the stochastic model.
    proposed_particles = ensemble_step(particles, k_ens, contact_rate, recovery_rate, total_pop)

    # Importance: weight each particle's infectious count (row 1) against this step's observation.
    weights = truncnorm_importance(proposed_particles[1], observations[x], target_sd)
    norm_weights = normalise_weights(weights)

    # Resampling: draw particle indices with replacement (the `random.choice` default), proportional to weight.
    indices = random.choice(k_choice, n_particles, shape=(n_particles,), p=norm_weights)
    resamp_particles = proposed_particles[:, indices]

    # Update
    return (resamp_particles, k_next), {"proposed": proposed_particles, "resampled": resamp_particles, "pedigree": indices}

# Warm-up run of the full filter: this first call also pays JAX's tracing/compilation cost.
_, res = lax.scan(step_particle_filter, (init_state, k), xs=jnp.arange(len(observations)))

In [None]:
%%time
# Check performance post-JIT
_, res = lax.scan(step_particle_filter, (init_state, k), jnp.arange(0,len(observations)))

In [None]:
# Pick a different seed for a different (but repeatable) run.
k = random.PRNGKey(seed=8)
_, res = lax.scan(step_particle_filter, (init_state, k), xs=jnp.arange(len(observations)))

# The plotting helpers expect every history array to begin with the initial state.
# `lax.scan` stacks per-step outputs along a new leading axis, so prepend
# `init_state` as step 0 along that axis.
proposed, resampled = (
    np.concatenate([init_state[np.newaxis], np.asarray(res[name])], axis=0)
    for name in ("proposed", "resampled")
)
pedigree = res["pedigree"]
particles = resampled  # alias used by the (commented-out) lineage-plot cell

plot_particle_results(proposed, resampled, observations)

In [None]:
# Optional (disabled): derive links between particles across steps from the
# resampling indices in `pedigree`, then plot them. Uncomment to run; it uses
# `particles` and `pedigree` from the previous cell.
# NOTE(review): confirm this still matches the current `get_links_from_pedigree` / `plot_links` API before re-enabling.
#links = get_links_from_pedigree(particles, pedigree, observations)
#plot_links(particles, links, observations)