## Stochastic Volatility Model Explanation and Code

Stochastic volatility models are commonly used in finance for option pricing. They generally consist of a two-dimensional stochastic differential equation (SDE) of the form:

$$ dS_t = \alpha S_t dt + S_t V_t dB_{t}^{S} $$
$$ dV_t = \mu(V_t, \phi) dt + \sigma(V_t, \phi) dB_t^{V} $$

where $S_t$ is the asset price, $V_t$ is its volatility, and $B_t^{V}$ and $B_t^{S}$ are both Brownian motions. In this particular case, we will look at the multivariate version of this SDE, given by the following equations:

$$dS_{it} = \alpha S_{it} dt + S_{it}V_{it}dB_{it}^{S}$$
$$d \log (V_{it}) = \gamma_{i}(\mu_i - \log V_{it})dt + \sigma_{i} dB_{it}^{V}$$
$$d \log (V_{0t}) = \gamma_{0}(\mu_0 - \log V_{0t})dt + \sigma_{0} dB_{0t}^{V}$$
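
For intuition, one way to read these equations is through an Euler-Maruyama approximation over a small time step $\Delta t$. This is only a sketch: the discretization scheme and the standard normal draws $Z_{it}^{S}$, $Z_{it}^{V}$ are notation introduced here for illustration, not something fixed by the model itself.

$$S_{i,t+\Delta t} \approx S_{it} + \alpha S_{it} \Delta t + S_{it} V_{it} \sqrt{\Delta t} \, Z_{it}^{S}$$
$$\log V_{i,t+\Delta t} \approx \log V_{it} + \gamma_{i}(\mu_i - \log V_{it}) \Delta t + \sigma_{i} \sqrt{\Delta t} \, Z_{it}^{V}$$

The first term after each current value is the drift contribution and the last term is the diffusion contribution. The log-volatility is pulled toward $\mu_i$ at rate $\gamma_i$.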

In [None]:
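import jax.numpy as jnp
import jax.scipy as jsp
from jax import random
from pfjax import sde  # assumption: `sde.SDEModel` is the pfjax base class


class StochVolModel(sde.SDEModel):
    def __init__(self, dt, n_res):
        # creates "private" variables self._dt and self._n_res
        super().__init__(dt, n_res, diff_diag=True)
        self._n_state = (self._n_res, 2)

    def drift(self, x, theta):
        """
        Calculates the SDE drift function.

        Assumed layout: the last axis of x holds (S, log V) and the last axis
        of theta holds (alpha, gamma, mu, sigma). This works for a single
        asset (x of shape (2,)) or for q assets (x of shape (q, 2)).
        """
        # The drift is deterministic, so no Brownian increments are drawn here.
        s, log_v = x[..., 0], x[..., 1]
        alpha, gamma, mu = theta[..., 0], theta[..., 1], theta[..., 2]
        return jnp.stack([alpha * s, gamma * (mu - log_v)], axis=-1)

    def diff(self, x, theta):
        """
        Calculates the SDE diffusion function.

        Returns the diagonal diffusion coefficients (since diff_diag=True):
        S * V for the price and sigma for the log-volatility.
        """
        s, log_v = x[..., 0], x[..., 1]
        sigma = theta[..., 3]
        return jnp.stack([s * jnp.exp(log_v), sigma], axis=-1)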
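
To see how a drift function and a diffusion function combine into a simulated path, the following is a minimal standalone sketch in JAX. It rests on several assumptions that are not part of the model description: each asset's state is stored as $(S_{it}, \log V_{it})$, each row of `theta` is laid out as `(alpha, gamma, mu, sigma)`, the Brownian motions are independent, the common factor $V_{0t}$ is left out, and paths are advanced with a plain Euler-Maruyama step. The helper names `euler_step` and `simulate` and all parameter values are hypothetical.

```python
import math

import jax
import jax.numpy as jnp
from jax import random


def drift(x, theta):
    """Drift of (S_i, log V_i) for each asset i; x has shape (q, 2)."""
    s, log_v = x[:, 0], x[:, 1]
    alpha, gamma, mu = theta[:, 0], theta[:, 1], theta[:, 2]
    return jnp.stack([alpha * s, gamma * (mu - log_v)], axis=1)


def diff(x, theta):
    """Diagonal diffusion coefficients of (S_i, log V_i)."""
    s, log_v = x[:, 0], x[:, 1]
    sigma = theta[:, 3]
    return jnp.stack([s * jnp.exp(log_v), sigma], axis=1)


def euler_step(key, x, theta, dt):
    """One Euler-Maruyama step, assuming independent Brownian increments."""
    z = random.normal(key, x.shape)
    return x + drift(x, theta) * dt + diff(x, theta) * jnp.sqrt(dt) * z


def simulate(key, x0, theta, dt, n_steps):
    """Simulates a path; the result has shape (n_steps + 1, q, 2)."""
    def body(x, k):
        x_new = euler_step(k, x, theta, dt)
        return x_new, x_new

    keys = random.split(key, n_steps)  # one PRNG key per time step
    _, path = jax.lax.scan(body, x0, keys)
    return jnp.concatenate([x0[None], path], axis=0)


# Illustrative (made-up) values for q = 2 assets.
key = random.PRNGKey(0)
theta = jnp.array([[0.05, 2.0, math.log(0.2), 0.3],
                   [0.03, 1.5, math.log(0.3), 0.4]])
x0 = jnp.array([[100.0, math.log(0.2)],
                [50.0, math.log(0.3)]])
path = simulate(key, x0, theta, dt=1.0 / 252, n_steps=252)
print(path.shape)  # (253, 2, 2)
```

Splitting the key once per step gives each increment its own random draw, which reusing a single key for both Brownian motions would not. Correlated Brownian motions or the common factor $V_{0t}$ would need an extended state and a non-diagonal diffusion, both of which this sketch omits.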
