# Stochastic Volatility Model

Implementing model in [Bayesian Filtering for Jump-Diffusions With Application to Stochastic Volatility](https://www.tandfonline.com/doi/abs/10.1198/jcgs.2009.07137).

\begin{equation}
    \begin{aligned}
        dX_t = \alpha dt + \sqrt{Z_t} dW_t^x + V_t^x dN_t \\
        dZ_t = (\theta + \kappa Z_t) dt + \sigma_z \sqrt{Z_t} dW_t^z + V_t^z dN_t \\
    \end{aligned}
\end{equation}

Euler-Maruyama approximation: 

\begin{equation}
    \begin{aligned}
        X_{t + \Delta t} = X_t + \alpha \Delta t + \sqrt{Z_t} \Delta W_t^x + V_t^x J_{t + \Delta t} \\
        Z_{t + \Delta t} = Z_t + (\theta + \kappa Z_t) \Delta t + \sigma_z \sqrt{Z_t} \Delta W_t^z + V_t^z J_{t + \Delta t} \\
    \end{aligned}
\end{equation}


For `pfjax`, we need to specify: 

- `state_lpdf`
- `state_sample`
- `meas_lpdf`
- `meas_sample`

In [3]:
import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import random
from jax import lax
from pfjax import sde as sde


import numpy as np
import seaborn as sns
import scipy.stats
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
sns.set()

In [4]:
# Measurement model is X + N(0,1)
# State model is the joint density: (X, Z)

class SVJumps(sde.SDEModel):
    def __init__ (self, dt, n_res):
        super().__init__(dt, n_res, diff_diag=True)
        self._n_state = (self._n_res, 2) # 2 latent variables (X, Z)
        self.theta_names = ["alpha", "mu_x", "sigma_x", "lambda", "_theta", "kappa", "mu_z", "sigma_z"]
        
    def zip_theta (self, theta):
        return dict(zip(self.theta_names, theta))
    
    def state_lpdf (self, x_curr, x_prev, theta):
        r"""
        Evaluate the log-pdf for the current state
        """
        pass
    
    def state_sample (self, key, x_prev, theta):
        r"""
        Sample from the state model
        
        Args:
            - key: PRNGKey
            - x_prev: (X, Z)
            - theta: (alpha, mu_x, sigma_x, lambda, theta, kappa, mu_z, sigma_z)
        """
        theta_dict = self.zip_theta(theta)
        
        # sample jump sizes: 
        V_x = theta_dict["mu_x"] + theta_dict["sigma_x"]*random.normal(key) # jump size for X
        V_z = random.exponential(key)/theta_dict["mu_z"] # jump size for Z
        
        # sample jump:
        jump = random.bernoulli(key, p=theta_dict["lambda"]*self._dt)
        
        # sample X conditional on x_prev
        noise_term = jnp.sqrt(x_prev[-1, 1]*self._dt)
        x_mean = x_prev[-0] + theta_dict["alpha"]*self._dt + V_x*jump
        x_sd = noise_term
        
        z_mean = x_prev[1] + (theta_dict["_theta"] + theta_dict["kappa"]*x_prev[-1, 1])*self._dt + V_z*jump
        z_sd = theta_dict["sigma_z"]*noise_term
        
        return jnp.array([x_mean, z_mean]) + \
            jnp.array([x_sd, z_sd]) * random.normal(key, (self._n_state[1],))
    
    def meas_lpdf (self, y_curr, x_curr, theta):
        pass
    
    def meas_sample (self, key, x_curr, theta):
        return x_curr[-1]

In [5]:
theta_true = [0.5, 1., 0.1, 0.2, 1., 0.5, 2, 0.3]
n_res = 3
dt = 0.5

x_init = jnp.block([[jnp.zeros((n_res-1, 2))],
           [jnp.array([5., 3.])]])

sv = SVJumps(dt = dt, n_res = n_res)

x_curr = sv.state_sample(
    random.PRNGKey(0), 
    x_prev = x_init, 
    theta = theta_true)

In [6]:
x_curr

DeviceArray([[-0.7111381 ,  0.56467795],
             [ 0.28886187,  1.564678  ]], dtype=float32)

In [None]:
def simulate(model, key, n_obs, x_init, theta):
    def fun(carry, x):
        key, *subkeys = random.split(carry["key"], num=3)
        x_state = model.state_sample(subkeys[0], carry["x_state"], theta)
        y_meas = model.meas_sample(subkeys[1], x_state, theta)
        res = {"y_meas": y_meas, "x_state": x_state, "key": key}
        return res, res

    key, subkey = random.split(key)
    init = {
        "y_meas": model.meas_sample(subkey, x_init, theta),
        "x_state": x_init,
        "key": key
    }
    # scan itself
    last, full = lax.scan(fun, init, jnp.arange(n_obs-1))
    # append initial values
    x_state = tree_append_first(full["x_state"], first=init["x_state"])
    y_meas = tree_append_first(full["y_meas"], first=init["y_meas"])
    return y_meas, x_state