In [4]:
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.random as jrandom
from jax import lax
from jax import config
import optax
config.update("jax_enable_x64", True)

import pfjax as pf
import pickle

In [5]:
import matplotlib.pyplot as plt

In [6]:
def logw_to_prob(logw):
    r"""
    Convert unnormalized log-weights into a probability vector.

    The largest log-weight is subtracted before exponentiating, so the
    biggest term becomes `exp(0) = 1`; the shift cancels in the normalization.

    Args:
        logw: Vector of `n_particles` unnormalized log-weights.

    Returns:
        Vector of `n_particles` weights normalized to sum to 1.
    """
    shifted = logw - jnp.max(logw)
    unnormalized = jnp.exp(shifted)
    return unnormalized / jnp.sum(unnormalized)

def resample_custom(key, x_particles_prev, logw):
    r"""
    Particle resampler using a univariate normal approximation of the volatility.

    Only the entries `x_particles_prev[:, -1, 0]` (first state component at the
    last sub-step of every particle) are resampled: their weighted mean and
    standard deviation are computed and `n_particles` fresh values are drawn
    from the corresponding normal distribution. Every other entry of the
    particle array is carried over unchanged.

    Args:
        key: PRNG key.
        x_particles_prev: An `ndarray` with leading dimension `n_particles` consisting of the particles from the previous time step. It is indexed as `[:, -1, 0]`, so it must have at least three dimensions.
        logw: Vector of corresponding `n_particles` unnormalized log-weights.

    Returns:
        A dictionary with elements:
            - `x_particles`: Copy of `x_particles_prev` whose `[:, -1, 0]` entries are replaced by the normal draws.
            - `vol_mean`: Weighted mean of the resampled component, a scalar.
            - `vol_std`: Weighted standard deviation of the resampled component, a scalar.
    """
    n_particles = logw.shape[0]
    prob = logw_to_prob(logw)

    # If the normalized weights do not have a strictly positive sum (this
    # comparison is also False when the sum is NaN), fall back to uniform
    # weights.
    uniform = jnp.full((n_particles,), 1. / n_particles)
    prob = jnp.where(jnp.sum(prob) > 0., prob, uniform)

    vol = x_particles_prev[:, -1, 0]
    vol_mean = jnp.average(vol, weights=prob)
    vol_var = jnp.average((vol - vol_mean)**2, weights=prob)

    # Floor the variance so that a degenerate particle cloud still yields a
    # strictly positive standard deviation.
    vol_std = jnp.sqrt(jnp.where(vol_var > 0., vol_var, 1e-6))
    vol_samples = vol_mean + vol_std * jrandom.normal(key, shape=(n_particles,))
    x_particles = x_particles_prev.at[:, -1, 0].set(vol_samples)
    return {
        "x_particles": x_particles,
        "vol_mean": vol_mean,
        "vol_std": vol_std
    }

In [7]:
class SDEJumpModel(object):
    """
    Generic jump-diffusion SDE model with an error-free measurement.

    Subclasses are expected to provide two methods, both called once per
    Euler step by `euler_sim`:
        - `diff(key, x, theta, dt)`: the diffusion update.
        - `jump(key, x, theta, dt)`: the jump vector.

    Args:
        dt: Time between observations, a scalar.
        n_res: Number of Euler steps per observation interval.
    """
    def __init__(self, dt, n_res):
        self._dt = dt
        self._n_res = n_res

    def euler_sim(self, key, x, dt, theta):
        """
        Simulate one step of the jump-diffusion SDE under the Euler-Maruyama scheme.
        Args:
            key: PRNG key.
            x: Current latent state, a vector of size `n_dims`.
            dt: Step size, a scalar.
            theta: Static parameters, a vector of size `n_pars`.
        Returns:
            Next latent state sample: the diffusion update plus all but the
            first entry of the jump vector, followed by the full jump vector.
        """
        diff_subkey, jump_subkey = jrandom.split(key)
        diff_term = self.diff(diff_subkey, x, theta, dt)
        jump_term = self.jump(jump_subkey, x, theta, dt)
        # Entries 1: of the jump vector are added to the diffusion update; the
        # whole jump vector is then stored alongside it in the new state.
        return jnp.append(diff_term + jump_term[1:], jump_term)

    def state_sample(self, key, x_prev, theta):
        """
        Samples from `x_curr ~ p(x_curr | x_prev, theta)`.
        Args:
            key: PRNG key.
            x_prev: Latent state at previous time, an array of size `n_res` by `n_dim`.
            theta: Static parameters, a vector of size `n_pars`.
        Returns:
            Sample of the latent state at current time, an array of size
            `n_res` by `n_dim` holding the state after each Euler step.
        """
        def step(carry, _):
            key, subkey = jrandom.split(carry["key"])
            x = self.euler_sim(
                key=subkey, x=carry["x"],
                dt=self._dt/self._n_res, theta=theta
            )
            return {"x": x, "key": key}, x

        # Start from the last sub-step of the previous observation interval.
        init = {"x": x_prev[-1], "key": key}
        _, x_path = lax.scan(step, init, jnp.arange(self._n_res))
        return x_path

    def meas_sample(self, key, x_curr, theta):
        """
        Sample from the error-free measurement model.
        Args:
            key: PRNG key (unused, the measurement is deterministic).
            x_curr: Current latent state, an array of size `n_res` by `n_dim`.
            theta: Static parameters, a vector of size `n_pars` (unused).
        Returns:
            Sample of the observation at current time, a scalar: the second
            component of the last row of `x_curr`.
        """
        return x_curr[-1][1]

    def meas_lpdf(self, y_curr, x_curr, theta):
        """
        The log-density of the error-free measurement model.
        Args:
            y_curr: Observation at current time, a scalar.
            x_curr: Current latent state, an array of size `n_res` by `n_dim`.
            theta: Static parameters, a vector of size `n_pars`.
        Returns:
            0.0, since measurement model is a delta function.
        """
        return 0.0
        
    

In [36]:
class StochVolJump(SDEJumpModel):
    """
    Stochastic volatility jump-diffusion model.

    Each latent state vector is laid out as
    `[vol, price, jump, jumpsize_vol, jumpsize_price]`. Under the exponential-OU
    dynamics the first entry is exponentiated before it is used as a variance.

    Args:
        dt: Time between observations, a scalar.
        n_res: Number of Euler steps per observation interval.
        lambda_star: Stored as `_lambda_star`; not used by any method defined here.
        vol_type: `"Heston"` selects the Heston drift/diffusion; any other value selects the exponential-OU one.
    """
    def __init__(self, dt, n_res, lambda_star, vol_type):
        # Inherits from SDEJumpModel
        super().__init__(dt, n_res)
        self._n_state = (self._n_res, 5) # [vol, price, jump, jumpsize_vol, jumpsize_price]
        self._lambda_star = lambda_star
        self._vol_type = vol_type
        self._dr_df = self._get_dr_df()

    def _dr_df_heston(self, x, param_transformed):
        """
        Drift and diffusion coefficients under Heston volatility dynamics.
        Args:
            x: Current latent state, a vector of size `n_dims`.
            param_transformed: Static parameters on the transformed scale, a vector of size `n_pars`.
        Returns:
            A tuple `(mu_1, mu_2, sigma_1, sigma_2, rho)`: drifts and diffusion
            scales of the first and second state components, and the `rho` parameter.
        """
        param = self.get_param(param_transformed)
        # Mean reversion towards exp(theta).
        mu_1 = param["kappa"]*(jnp.exp(param["theta"])-x[0])
        mu_2 = param["alpha"] - x[0]/2
        # NOTE(review): sqrt(x[0]) is NaN whenever an Euler step pushes x[0] below zero.
        sigma_1 = param["sigma"]*jnp.sqrt(x[0])
        sigma_2 = jnp.sqrt(x[0])
        rho = param["rho"]
        return mu_1, mu_2, sigma_1, sigma_2, rho

    def _dr_df_expou(self, x, param_transformed):
        """
        Drift and diffusion coefficients under exponential-OU volatility dynamics.
        Args:
            x: Current latent state, a vector of size `n_dims`.
            param_transformed: Static parameters on the transformed scale, a vector of size `n_pars`.
        Returns:
            A tuple `(mu_1, mu_2, sigma_1, sigma_2, rho)`: drifts and diffusion
            scales of the first and second state components, and the `rho` parameter.
        """
        param = self.get_param(param_transformed)
        vol = jnp.exp(x[0])
        # Mean reversion of x[0] itself towards theta.
        mu_1 = param["kappa"]*(param["theta"]-x[0])
        mu_2 = param["alpha"] - vol/2
        sigma_1 = param["sigma"]
        sigma_2 = jnp.sqrt(vol)
        rho = param["rho"]
        return mu_1, mu_2, sigma_1, sigma_2, rho

    def _get_dr_df(self):
        """Select the drift/diffusion function matching `self._vol_type`."""
        if self._vol_type == "Heston":
            return self._dr_df_heston
        else:
            return self._dr_df_expou

    def dr_df(self, x, param_transformed):
        """Evaluate the drift/diffusion function selected at construction time."""
        return self._dr_df(x, param_transformed)

    def diff(self, key, x, param_transformed, dt):
        """
        One Euler-Maruyama update of the first two state components.
        Args:
            key: PRNG key.
            x: Current latent state, a vector of size `n_dims`.
            param_transformed: Static parameters on the transformed scale, a vector of size `n_pars`.
            dt: Time step, a scalar.
        Returns:
            The updated first two state components, a vector of size 2.
        """
        mu_1, mu_2, sigma_1, sigma_2, rho = self.dr_df(x, param_transformed)
        dr = jnp.array([mu_1, mu_2])
        # Lower-triangular factor giving the two noise terms correlation rho.
        df_chol = jnp.array([[sigma_1, 0.],
                             [rho*sigma_2, jnp.sqrt(1-rho**2)*sigma_2]])
        diff_process = x[:2] + dr*dt + jnp.matmul(df_chol, jrandom.normal(key, (2,)))*jnp.sqrt(dt)
        return diff_process

    def j_q(self, key, x, param_transformed, dt):
        """
        Not implemented yet.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("StochVolJump.j_q is not implemented yet.")

    def jump(self, key, x, param_transformed, dt):
        """
        Jump process.
        Args:
            key: PRNG key.
            x: Current latent state, a vector of size `n_dims`.
            param_transformed: Static parameters on the transformed scale, a vector of size `n_pars`.
            dt: Time step, a scalar.
        Returns:
            The jump process, a vector of size 3: the jump indicator, an
            exponential jump size scaled by `mu_z`, and a normal jump size with
            mean `mu_x` and standard deviation `sigma_x`. Both sizes are zero
            when no jump occurs.
        """
        param = self.get_param(param_transformed)
        keys = jrandom.split(key, 3)
        is_jump = jrandom.bernoulli(keys[0], p=param["lambda"]*dt)
        jump_process = jnp.array([is_jump,
                                  jnp.where(is_jump, jrandom.exponential(keys[1]) * param["mu_z"], 0.),
                                  jnp.where(is_jump, param["mu_x"] + param["sigma_x"]*jrandom.normal(keys[2]), 0.),
                                  ])
        return jump_process

    def _bridge_param(self, x, y_curr, param_transformed, n, vz, vx, jvx_invcumsum):
        """
        Calculate the mean and standard deviation parameters of the bridge proposal.
        Args:
            x: Current latent state, a vector of size `n_dims`.
            y_curr: Observation at current time, a scalar.
            param_transformed: Static parameters on the transformed scale, a vector of size `n_pars`.
            n: Index of the subinterval, a scalar.
            vz: Jump size added to the first state component, a scalar.
            vx: Jump size added to the second state component, a scalar.
            jvx_invcumsum: n-th inverse cumulative sum of the `vx` jump sizes, a scalar.
        Returns:
            A tuple `(mu_z, mu_x, sigma_z, sigma_x, rho)`: proposal means and
            scales for the first and second state components, and `rho`.
        """
        # Number of sub-steps remaining until the next observation.
        k = self._n_res - n
        dt_res = self._dt/self._n_res
        mu_1, mu_2, sigma_1, sigma_2, rho = self.dr_df(x, param_transformed)

        mu_z = x[0] + mu_1*dt_res + vz
        sigma_z = sigma_1*jnp.sqrt(dt_res)

        # Pull the second component towards y_curr over the k remaining steps.
        mu_x = x[1] + (y_curr - x[1])/k + vx - jvx_invcumsum/k
        # The scale shrinks to zero on the final sub-step (k == 1).
        sigma_x = sigma_2*jnp.sqrt((k - 1.)/k*dt_res)

        return mu_z, mu_x, sigma_z, sigma_x, rho

    def pf_step(self, key, x_prev, y_curr, theta):
        """
        Particle update for a bridge particle filter (work in progress).
        Args:
            key: PRNG key.
            x_prev: State variable at previous time `t-1`.
            y_curr: Measurement variable at current time `t`.
            theta: Parameter value.
        Returns:
            Intended to return a tuple:
                - x_curr: Current latent state sample, an array of size `n_res` by `n_dim`.
                - logw: The log-weight of the particle, a scalar.
        Raises:
            NotImplementedError: Always, because the bridge proposal and the
                log-weight are not implemented yet.
        """
        theta_use = self.get_param(theta)

        # `perm_subkey` is split off but not consumed by the code below.
        key, jump_subkey, perm_subkey, z_subkey, x_subkey = jrandom.split(key, 5)
        # NOTE(review): x_prev[-1][-1] is the last state entry (jumpsize_price
        # in the layout above), not the jump indicator - confirm the intended index.
        jumps = jrandom.permutation(jump_subkey, jnp.append(jnp.zeros(self._n_res-1), x_prev[-1][-1]))
        vzs = (theta_use["mu_z"]*jax.random.exponential(z_subkey, shape=(self._n_res,)))
        vxs = (theta_use["mu_x"]+theta_use["sigma_x"]*jax.random.normal(x_subkey, shape=(self._n_res,)))
        jvxs = vxs * jumps
        # Reverse cumulative sum: entry n is the sum of jvxs[n:].
        jvxs_invcumsums = jax.lax.cumsum(jvxs[::-1])[::-1]

        # TODO: propagate the state through the bridge proposal and accumulate
        # the log-weight.
        raise NotImplementedError(
            "StochVolJump.pf_step is not implemented yet: the bridge proposal "
            "and the log-weight computation are missing."
        )

    # Helper functions for parameter transformation

    def recover_param(self, param_transformed):
        """
        Map parameters from the transformed scale back to the natural scale.
        Args:
            param_transformed: Vector of 9 transformed parameters.
        Returns:
            Vector of 9 parameters ordered as
            `[alpha, theta, kappa, sigma, lambda, mu_x, sigma_x, mu_z, rho]`.
        """
        return jnp.array([
            param_transformed[0],
            param_transformed[1],
            jnp.exp(param_transformed[2]),
            jnp.exp(param_transformed[3]),
            jsp.special.expit(param_transformed[4]),
            param_transformed[5],
            jnp.exp(param_transformed[6]),
            jnp.exp(param_transformed[7]),
            # Equal to (exp(t) - 1) / (exp(t) + 1), but does not overflow to
            # NaN for large t.
            jnp.tanh(param_transformed[8]/2.)
        ])

    def transform_param(self, param):
        """
        Map parameters from the natural scale to the transformed scale.
        Args:
            param: Vector of 9 parameters ordered as
                `[alpha, theta, kappa, sigma, lambda, mu_x, sigma_x, mu_z, rho]`.
        Returns:
            Vector of 9 transformed parameters; inverse of `recover_param`.
        """
        return jnp.array([
            param[0],
            param[1],
            jnp.log(param[2]),
            jnp.log(param[3]),
            jsp.special.logit(param[4]),
            param[5],
            jnp.log(param[6]),
            jnp.log(param[7]),
            # Equal to log(1 + rho) - log(1 - rho).
            2.*jnp.arctanh(param[8])
        ])

    def get_param(self, param_transformed):
        """
        Recover the natural-scale parameters and label them by name.
        Args:
            param_transformed: Vector of 9 transformed parameters.
        Returns:
            A dictionary mapping each parameter name to its natural-scale value.
        """
        param = self.recover_param(param_transformed)
        d = {
            "alpha": param[0],
            "theta": param[1],
            "kappa": param[2],
            "sigma": param[3],
            "lambda": param[4],
            "mu_x": param[5],
            "sigma_x": param[6],
            "mu_z": param[7],
            "rho": param[8]
        }
        return d


In [37]:
# Simulation settings
my_key = jrandom.PRNGKey(143)
my_dt = 1.0  # time between observations
my_n_res = 20  # Euler steps per observation interval
my_n_obs = 5 * 252  # five years of 252 trading days
my_lambda_star = 0.3

In [38]:
# One model per volatility specification, both on the same time grid.
heston = StochVolJump(
    dt=my_dt, n_res=my_n_res, lambda_star=my_lambda_star, vol_type="Heston"
)
expou = StochVolJump(
    dt=my_dt, n_res=my_n_res, lambda_star=my_lambda_star, vol_type="ExpOU"
)

In [43]:
# Heston parameters on the natural scale, in the order used by `get_param`:
# [alpha, theta, kappa, sigma, lambda, mu_x, sigma_x, mu_z, rho]
my_heston_param = jnp.array([
    0.22, jnp.log(0.2), 0.023, 0.04, 0.012, -2.1, 1.7, 0.24, -0.6
])
# Only the last row of the initial state is populated: the first component
# starts at exp(theta) and the second at 100.
my_x_init = jnp.vstack([
    jnp.zeros((my_n_res - 1, 5)),
    jnp.array([jnp.exp(my_heston_param[1]), 100.0, 0.0, 0.0, 0.0]),
])
my_param_transformed = heston.transform_param(my_heston_param)
y_meas, x_state = pf.simulate(
    heston, my_key, my_n_obs, my_x_init, my_param_transformed
)

In [64]:
# Exponential-OU parameters on the natural scale, in the order used by `get_param`:
# [alpha, theta, kappa, sigma, lambda, mu_x, sigma_x, mu_z, rho]
my_expou_param = jnp.array([
    0.22, jnp.log(0.2), 0.023, 0.04, 0.012, -2.1, 1.7, 0.54, -0.6
])
# Only the last row of the initial state is populated: the first component
# starts at theta and the second at 100.
my_x_init = jnp.vstack([
    jnp.zeros((my_n_res - 1, 5)),
    jnp.array([my_expou_param[1], 100.0, 0.0, 0.0, 0.0]),
])
my_param_transformed = expou.transform_param(my_expou_param)
y_meas, x_state = pf.simulate(
    expou, my_key, my_n_obs, my_x_init, my_param_transformed
)

In [575]:
def pf_mvn_objective_full(param, key, n_particles):
    """
    Negative particle-filter log-likelihood of the notebook-level `y_meas`
    under the notebook-level `expou` model, resampling with `resample_custom`.

    Args:
        param: Parameter vector passed to the particle filter as `theta`.
        key: PRNG key.
        n_particles: Number of particles.

    Returns:
        The negated `"loglik"` entry of the particle filter output.
    """
    pf_out = pf.particle_filter(
        model=expou,
        key=key,
        y_meas=y_meas,
        theta=param,
        n_particles=n_particles,
        resampler=resample_custom,
    )
    return -pf_out["loglik"]

In [12]:
# Display the simulated observations.
y_meas

Array([100.        , 100.15931131, 100.06231277, ..., 117.15425584,
       117.34576192, 117.29252756], dtype=float64)