This notebook outlines how we used the functions provided in `jump_diffusion.py` and `particle_gibbs.py` to perform Bayesian filtering.

The Jump-Diffusion model is the following:

$$
\begin{aligned}
Z_{t+\Delta t} &= Z_t + (\theta + \kappa Z_t) \Delta t + \sigma_z \sqrt{Z_t} \Delta W_t^z + V_{t+\Delta t}^z J_{t+\Delta t} \\
X_{t+\Delta t} &= X_t + \alpha \Delta t + \sqrt{Z_t} \Delta W_t^x + V_{t+\Delta t}^x J_{t+\Delta t} 
\end{aligned}
$$

where

$$
\begin{aligned}
\Delta W_t^x, \Delta W_t^z &\sim \textrm{iid } N(0, \Delta t) \\
V_t^z &\sim \textrm{Exp}(\mu_z) \\
V_t^x &\sim N(\mu_x, \sigma^2_x) \\
J_{t+\Delta t} &\sim \textrm{Bern}(\lambda \Delta t).
\end{aligned}
$$

Thus, our parameters are $\Theta = (\alpha, \theta, \kappa, \sigma_z, \lambda,\mu_x, \sigma_x, \mu_z)$, the observation is $X_t$ (the logarithm of the asset price), and the latent state is $Z_t$ (the unobserved volatility).

Note that the positivity-constrained parameters $\sigma_z$, $\lambda$ and $\sigma_x$ are passed to the model on the log scale (e.g. `sigma_z = jnp.log(0.12)`).

In [None]:
%load_ext autoreload
%autoreload 2

# needed for local imports in Jupyter:
import os
import sys
module_path = os.path.abspath(os.path.join('.'))
sys.path.append(module_path)
    
from jump_diffusion import *
from utils import *
from particle_filter_native import particle_filter_for

In [None]:
import pandas as pd
import seaborn as sns; sns.set()
import matplotlib.pyplot as plt

# plotting
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import pfjax as pf
from functools import partial
from pfjax import particle_resamplers as resampler
import time
from tqdm import tqdm

import warnings
warnings.filterwarnings("ignore")

### Simulate Data

In [None]:
# Simulation settings: observation spacing `_dt`, `_n_res` rows of latent state
# per observation (see `x_init`), and the number of observations.
key = random.PRNGKey(0)
_dt, _n_res, n_obs = 1, 5, 3000

# Model parameters. sigma_z, gamma (= log lambda) and sigma_x are on the log scale.
alpha, _theta, kappa = 0.08, 0.02, -0.03
sigma_z = jnp.log(0.12)
gamma = jnp.log(0.01)
mu_x, sigma_x, mu_z = -3.1, jnp.log(2.7), 1.7

theta = jnp.array([alpha, _theta, kappa, sigma_z, gamma, mu_x, sigma_x, mu_z])

# Initial latent block: zeros except the last row, which holds the starting state
# (column 0: volatility, column 1: price, columns 2-3: volatility / price jumps).
x_init = jnp.zeros((_n_res, 4)).at[-1].set(jnp.array([2.0, 50.0, 0, 0]))

jdmodel = JumpDiff(_dt, _n_res)

In [None]:
y_meas, x_state = pf.simulate(jdmodel, key, n_obs, x_init, theta)

# Observations with their positions on the latent-step grid (observation t sits at step t * _n_res).
obs_times = jnp.arange(y_meas.shape[0]) * _n_res
point_plot = pd.DataFrame(
    jnp.stack([obs_times, y_meas], axis=1),
    columns=["Time", "Log Asset Price"],
)

In [None]:
# Simulated latent paths. Each component is flattened over (time, sub-step) into one
# series, and the first _n_res - 1 entries (the zero rows of x_init) are dropped.
price_path = x_state[..., 1].reshape(-1)[_n_res - 1:]
vol_path = x_state[..., 0].reshape(-1)[_n_res - 1:]

fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(12, 7), sharex=True)

sns.lineplot(data=price_path, ax=ax[0], linewidth=0.7,
             label="Price").set(title="$X_t$", ylabel="Price")
sns.lineplot(data=vol_path, ax=ax[1], alpha=0.9, linewidth=0.7, color="firebrick",
             label="Volatility").set(xlabel="Time", title="$Z_t$", ylabel="Volatility");

In [None]:
fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(12, 7), sharex=True)


def _flat_path(col):
    """Flatten latent component `col` over time and drop the first _n_res - 1 entries."""
    return x_state[..., col].reshape(-1)[_n_res - 1:]


# Top panel: price path, observed values, and price jumps (shifted by the first observation).
sns.lineplot(data=_flat_path(1), ax=ax[0], linewidth=0.7,
             label="Price").set(title="$X_t$", ylabel="Price")
sns.scatterplot(x="Time", y="Log Asset Price", data=point_plot, color="firebrick",
                ax=ax[0], s=5, label="Observed").set(title="$X_t$")
sns.lineplot(data=_flat_path(3) + y_meas[0], ax=ax[0], alpha=0.9, linewidth=0.7,
             label="Price jumps");

# Bottom panel: volatility path and volatility jumps.
sns.lineplot(data=_flat_path(0), ax=ax[1], alpha=0.9, linewidth=0.7, color="firebrick",
             label="Volatility").set(xlabel="Time", title="$Z_t$", ylabel="Volatility");
sns.lineplot(data=_flat_path(2), ax=ax[1], alpha=0.9, linewidth=0.7,
             label="Vol jumps");

### Unit Tests

In this section we test the native Python and JAX implementations of our functions: 

In [None]:
# Compare the JAX and native-Python implementations of pf_step on a single
# transition, from the latent block at time 1 to the block at time 2.
x_curr = x_state[2, :, :]
x_prev = x_state[1, :, :]
# The measurement paired with x_curr. Each x_state[t] has only _n_res rows, so
# indexing row _n_res would be out of range (JAX clamps such indices to the last
# row instead of raising, which would silently give the previous observation).
y_curr = y_meas[2]

print("pf_step JAX: ", jdmodel.pf_step(key, x_prev, y_curr, theta))
print("pf_step native Python: ", jdmodel._pf_step_for(key, x_prev, y_curr, theta))

In [None]:
# Evaluate the state transition log-density with both implementations on the same pair.
for lpdf_label, lpdf_fn in (("log-pdf JAX: ", jdmodel.state_lpdf),
                            ("log-pdf native Python: ", jdmodel._state_lpdf_for)):
    print(lpdf_label, lpdf_fn(x_curr, x_prev, theta))

### Particle Filter

This section runs the particle filter on the simulated data

In [None]:
# Resampler for the particle filter: jittered multinomial resampling with its
# jitter parameter h tied to the particle count, h = 1 / (5 * num_particles).
num_particles = 100  # e.g. 30_000 for a higher-accuracy run
resample_jittered_multinomial = partial(jittered_multinomial, h=1 / (5 * num_particles))

In [None]:
# Particle filter with everything fixed except the particle count, which is
# marked static for jax.jit since it sets the number of particles (array sizes).
pf_jit = jax.jit(
    partial(
        pf.particle_filter,
        model=jdmodel,
        key=random.PRNGKey(0),
        y_meas=y_meas,
        theta=theta,
        resampler=resample_jittered_multinomial,
        history=True,
    ),
    static_argnames="n_particles",
)

multinom_pf = pf_jit(n_particles=num_particles)

In [None]:
# Filtering summaries of the volatility component at each observation time after
# the first: the weighted particle mean and the 2.5% / 97.5% weighted quantiles.
vol_particles = multinom_pf["x_particles"][1:, ..., 0]
vol_logw = multinom_pf["logw"][1:, ...]


def _weighted_particle_mean(x, logw):
    """Average particles `x` over the particle axis, weighted by exp-normalized `logw`."""
    return jnp.average(x, axis=0, weights=pf.utils.logw_to_prob(logw))


est_vol_mean = jax.vmap(_weighted_particle_mean)(vol_particles, vol_logw)
est_vol_lower = jax.vmap(lambda x, logw: x[quantile_index(logw, q=0.025)])(vol_particles, vol_logw)
est_vol_upper = jax.vmap(lambda x, logw: x[quantile_index(logw, q=0.975)])(vol_particles, vol_logw)

In [None]:
# True volatility against its filtering mean (top) and against the 2.5%-97.5%
# band of the filtering distribution (bottom).
true_vol = x_state[1:, :, 0].flatten()

fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(12, 7), sharex=True)
top_ax, bottom_ax = ax

sns.lineplot(data=true_vol, linewidth=0.9, ax=top_ax,
             label="True Volatility").set(xlabel="Time", title="$Z_t$", ylabel="Volatility")
sns.lineplot(data=est_vol_mean.flatten(), color="firebrick", linewidth=0.9,
             alpha=0.7, ax=top_ax, label="Estimated Volatility")

sns.lineplot(data=true_vol, ax=bottom_ax, linewidth=0.9,
             label="True Volatility").set(xlabel="Time", ylabel="Volatility")
bottom_ax.fill_between(
    x=jnp.arange(true_vol.shape[0]),
    y1=est_vol_lower.flatten(),
    y2=est_vol_upper.flatten(),
    color="firebrick",
    label="95% Posterior Bands",
    alpha=0.3,
)
bottom_ax.legend();

In [None]:
# RMSE of the estimated (filtering-mean) volatility against the true volatility path.
true_vol = x_state[1:, :, 0].flatten()
rmse(est_vol_mean.flatten(), true_vol)

In [None]:
# Plot all particles for (X_t, Z_t). This takes about an hour to render, so it is
# disabled by default to keep "Restart & Run All" practical; set to True to draw it.
PLOT_ALL_PARTICLES = False

if PLOT_ALL_PARTICLES:
    plot_particles(x_state, y_meas,
                   vol_particles=multinom_pf["x_particles"][1:, :, :, 0],
                   price_particles=multinom_pf["x_particles"][1:, :, :, 1],
                   plot_res=False,
                   n_res=_n_res,
                   n_obs=n_obs,
                   point_plot=point_plot,
                   title = "Filtering of Jump Diffusion")

### Runtime Simulations

Before running the timing simulations, we first check that both implementations of the particle filter produce the same results by comparing their estimates of the log-likelihood.

In [None]:
def particle_filter_for(model, key, y_meas, theta, n_particles, for_loop=False):
    r"""
    Loop-based particle filter (Algorithm 1 of the Stat 906 project writeup).

    Used as a reference implementation to compare against ``pf.particle_filter``.
    Note: this definition replaces the ``particle_filter_for`` imported from
    ``particle_filter_native`` at the top of the notebook.

    Args:
        model: Model object providing ``pf_init(key, y_init, theta)`` and
            ``pf_step(key, x_prev, y_curr, theta)``, each returning
            ``(x, logw)``, plus the per-particle state shape ``model._n_state``.
        key: JAX PRNG key.
        y_meas: Observations, indexed by time along the first axis.
        theta: Parameter vector passed through to the model.
        n_particles: Number of particles.
        for_loop: If ``False`` (default), propagate all particles with a single
            ``jax.vmap`` call; if ``True``, propagate them one at a time in a
            Python loop.

    Returns:
        A dict with
        - ``"x_particles"``: particles at every observation time, shape
          ``(n_obs, n_particles, *model._n_state)``;
        - ``"logw"``: unnormalized log-weights, shape ``(n_obs, n_particles)``;
        - ``"loglik"``: ``sum_t logsumexp(logw[t]) - n_obs * log(n_particles)``.
    """
    n_obs = y_meas.shape[0]

    # initial particles and weights, one pf_init call per particle
    key, *subkeys = random.split(key, num=n_particles + 1)
    x_particles = jnp.zeros((n_particles, *model._n_state))
    logw = jnp.zeros(n_particles)
    for i, _subkey in enumerate(subkeys):
        init_tmp = model.pf_init(key=_subkey, y_init=y_meas[0], theta=theta)
        x_particles = x_particles.at[i].set(init_tmp[0])
        logw = logw.at[i].set(init_tmp[1])

    # storage for the full history, seeded with the initial step
    all_particles = jnp.zeros((n_obs, *x_particles.shape)).at[0].set(x_particles)
    all_logw = jnp.zeros((n_obs, n_particles)).at[0].set(logw)
    loglik = jsp.special.logsumexp(logw)

    # Iterate over Python ints (not a jnp.arange) so the time index stays on the
    # host instead of materializing a device scalar every iteration.
    for t in range(1, n_obs):
        key, subkey = random.split(key)

        # resample particles
        resample_out = resampler.resample_multinomial(
            key=subkey,
            x_particles_prev=x_particles,
            logw=logw
        )

        # sample particles for current timepoint
        key, *subkeys = random.split(key, num=n_particles + 1)
        if for_loop:
            x_particles = jnp.zeros_like(resample_out["x_particles"])
            logw = jnp.zeros(n_particles)
            for i, _subkey in enumerate(subkeys):
                step_tmp = model.pf_step(
                    key=_subkey,
                    x_prev=resample_out["x_particles"][i],
                    y_curr=y_meas[t],
                    theta=theta
                )
                x_particles = x_particles.at[i].set(step_tmp[0])
                logw = logw.at[i].set(step_tmp[1])
        else:
            x_particles, logw = jax.vmap(
                lambda k, x, y: model.pf_step(key=k, x_prev=x, y_curr=y, theta=theta),
                in_axes=(0, 0, None)
            )(jnp.array(subkeys), resample_out["x_particles"], y_meas[t])

        loglik += jsp.special.logsumexp(logw)  # log-likelihood calculation
        all_particles = all_particles.at[t].set(x_particles)
        all_logw = all_logw.at[t].set(logw)

    return {
        "x_particles": all_particles,
        "logw": all_logw,
        "loglik": loglik - n_obs * jnp.log(n_particles)
    }

In [None]:
# For loop implementation of particle filter:
start = time.perf_counter()
jd_for = particle_filter_for(
    model = jdmodel,
    key = random.PRNGKey(0),
    y_meas = y_meas,
    theta = theta,
    n_particles = 100
)
# JAX dispatches work asynchronously: wait until every output array is computed
# before reading the clock, otherwise only the dispatch time is measured.
jax.tree_util.tree_map(
    lambda leaf: leaf.block_until_ready() if hasattr(leaf, "block_until_ready") else leaf,
    jd_for)
print("Time: ", time.perf_counter() - start)

# JAX implementation of particle filter using pfjax (no JIT, no resampler given,
# so pfjax's default is used). NOTE: this overwrites `multinom_pf` from the
# jittered-resampler run earlier in the notebook.
start = time.perf_counter()
multinom_pf = pf.particle_filter(
    theta=theta,
    model=jdmodel,
    y_meas=y_meas,
    n_particles=100,
    key=random.PRNGKey(0),
    history = True)
jax.tree_util.tree_map(
    lambda leaf: leaf.block_until_ready() if hasattr(leaf, "block_until_ready") else leaf,
    multinom_pf)
print("Time: ", time.perf_counter() - start)

In [None]:
# Log-likelihood estimates from the two implementations, side by side.
for impl_label, pf_result in (("Python log-likelihood: ", jd_for),
                              ("JAX log-likelihood: ", multinom_pf)):
    print(impl_label, pf_result["loglik"])

In [None]:
# The three particle-filter implementations being timed (native-Python loop,
# pfjax without JIT, pfjax with JIT), each with only n_particles left free.
common_pf_kwargs = dict(
    model=jdmodel,
    key=random.PRNGKey(0),
    y_meas=y_meas,
    theta=theta,
)

non_jax_pf = partial(particle_filter_for, **common_pf_kwargs)
non_jit_pf = partial(pf.particle_filter, **common_pf_kwargs)
# n_particles is static under jit because it sets the number of particles
jit_pf = jax.jit(partial(pf.particle_filter, **common_pf_kwargs),
                 static_argnames="n_particles")

In [None]:
# Average runtimes over a grid of particle counts. The native-Python filter is
# timed with fewer repetitions (it is expected to be the slowest).
num_particle_list = [50, 100, 250, 500]
n_sim_slow, n_sim_fast = 3, 15

non_jax_info = pf_timer(non_jax_pf, num_particle_list, n_sim=n_sim_slow)
non_jit_info = pf_timer(non_jit_pf, num_particle_list, n_sim=n_sim_fast)
jit_info = pf_timer(jit_pf, num_particle_list, n_sim=n_sim_fast)

In [None]:
# Average runtime vs. number of particles, one line per implementation.
for timing_info, impl_label in ((jit_info, "JIT"),
                                (non_jit_info, "JAX, no JIT"),
                                (non_jax_info, "Native Python")):
    runtime_ax = sns.lineplot(x=num_particle_list, y=timing_info["avg_times"], label=impl_label)

runtime_ax.set(
    title = "Runtime for PF Different Implementations", xlabel = "Number of Particles",
    ylabel = "Runtime (seconds)"
);

In [None]:
# Runtime summary: one row per particle count, one column per implementation.
# Built column-wise in one step, so repeated particle counts cannot be silently
# merged (as a dict keyed on the count would do).
timing_df = pd.DataFrame({
    "Num particles": num_particle_list,
    "Non-JIT Runtime": non_jit_info["avg_times"],
    "JIT Runtime": jit_info["avg_times"],
    "Non-JAX Runtime": non_jax_info["avg_times"],
})
timing_df
# LaTeX export for the writeup: print(timing_df.style.to_latex())

### S&P 500 Index Data

We also run the particle filter on S&P 500 daily closing prices from January 1986 to January 2000.

In [None]:
import yfinance as yf

# auto_adjust=False keeps the "Adj Close" column (newer yfinance releases default
# to auto_adjust=True and drop it). ravel() flattens the one-column frame that
# newer releases return for a single ticker, giving a 1-D price series.
snp_data = yf.download('^GSPC', '1986-01-03', '2000-01-03', auto_adjust=False)
snp_closing = jnp.asarray(snp_data["Adj Close"].to_numpy()).ravel()
snp_log_closing = jnp.log(snp_closing)

In [None]:
# S&P 500 adjusted closing prices (top) and their logarithm (bottom); the log
# price is the series the model treats as the observation X_t.
fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(12, 6), sharex=True)
price_ax, log_price_ax = ax
sns.lineplot(data=snp_closing, ax=price_ax, label="Daily Closing Price")
sns.lineplot(data=snp_log_closing, ax=log_price_ax, label="log(Daily Closing Price)");

In [None]:
# NOTE(review): the simulation `theta` has 8 entries ordered as
# (alpha, theta, kappa, log sigma_z, log lambda, mu_x, log sigma_x, mu_z), but
# snp_theta has only 7, and some entries (e.g. 0.007, 2.595) may be on the natural
# rather than the log scale. TODO confirm which parameter is missing and match the
# parameterization. JAX clamps out-of-range indices instead of raising, so a short
# parameter vector can fail silently.
snp_theta = jnp.array([0.076, 0.018, -0.03, 0.007, -3.175, 2.595, 1.489])
snp_jdmodel = JumpDiff(dt=1, n_res=5)

# The model's observation is the log asset price X_t, so filter the log closing prices.
snp_pf = jax.jit(partial(
    pf.particle_filter,
    model = snp_jdmodel,
    key = random.PRNGKey(0),
    y_meas = snp_log_closing,
    theta = snp_theta,
    resampler = resample_jittered_multinomial,
    history = True
), static_argnames="n_particles")

# The first call includes JIT compilation. Block until the results are computed,
# since JAX dispatches asynchronously.
start = time.perf_counter()
snp_filtered = snp_pf(n_particles = num_particles)
jax.tree_util.tree_map(
    lambda leaf: leaf.block_until_ready() if hasattr(leaf, "block_until_ready") else leaf,
    snp_filtered)
print("Time: ", time.perf_counter() - start)

In [None]:
# plot estimates of volatility: weighted particle mean at each time step after the first.
# Stored as snp_est_vol_mean so the simulated-data est_vol_mean (used for the RMSE
# earlier) is not overwritten if cells are re-run.
snp_est_vol_mean = jax.vmap(
    lambda x, logw: jnp.average(x, axis=0, weights=pf.utils.logw_to_prob(logw)),
    in_axes = (0, 0))(snp_filtered["x_particles"][1:, ..., 0],
                      snp_filtered["logw"][1:, ...])

fig, ax = plt.subplots(figsize=(12, 5))
sns.lineplot(data = snp_est_vol_mean.flatten(),
             color = "firebrick", linewidth = 0.9,
             alpha = 0.7, ax=ax,
             label = "Estimated Volatility").set(xlabel="Time",title = "Daily Closing Prices, S&P 500, $Z_t$",
                                                 ylabel="Volatility");