In [3]:
import jax
import jax.numpy as jnp
import scipy 

In [4]:
class CIR:
    def __init__(self, tau, alpha, beta, seed=None):
        self.tau = tau
        self.alpha = alpha
        self.beta = beta
        self.key = jax.random.key(0) if seed is None else jax.random.key(seed)

    @staticmethod
    def sample_wiener_increments(self, num_paths, num_timepoints, delta_t):
        return (
            jax.random.normal(self.key, (num_paths, num_timepoints)) 
            * jnp.sqrt(delta_t)
        )

    @staticmethod
    def sample_ncx2(self, lambda_):
        return scipy.stats.ncx2(lambda_)

    def __call__(self, x_0, num_paths, num_timesteps, delta_t, method):
        # TODO: add validation logic for x_0 (should be 1D array)

        # Preallocate and set initial value.
        x = jnp.full((num_paths, num_timesteps), fill_value=jnp.nan)
        x[:, 0] = x_0

        if method == 'euler_maruyama':
            dW = self.sample_wiener_increments(num_paths, num_timesteps, delta_t)
            kappa = 1 / self.tau
            sigma = jnp.sqrt(2 / (self.tau * self.beta))
            theta = self.alpha / self.beta
            for i_timestep in range(num_timesteps-1):
                dX = kappa * (theta - x[:, i_timestep]) + sigma * jnp.sqrt(x[:, i_timestep]) * dW[:, i_timestep]
                x[:, i_timestep+1] = x[:, i_timestep] + dX

            # Collect params to be returned.
            params = {
                'kappa' : kappa,
                'theta' : theta, 
                'sigma' : sigma
            }

        elif method == 'ncx2':
            # Precompute necessary parameters for sampling.
            kappa = 1 / self.tau
            sigma_sqr = 2 / (self.tau * self.beta)
            upsilon = jnp.exp(-kappa * delta_t)
            eta = (
                4 * kappa * upsilon 
                / (sigma_sqr * (1 - upsilon))
            )
            dof = 4 * kappa * (self.alpha / self.beta) / sigma_sqr

            # Iteratively update via sampling.
            for i_timestep in range(num_timesteps-1):
                x[:, i_timestep + 1] \
                    = self.sample_ncx2(eta * x[:, i_timestep])
                
            # Collect params to be returned.
            params = {
                'kappa' : kappa,
                'sigma_sqr' : sigma_sqr,
                'eta' : eta,
                'dof' : dof
            }

        return x, params



In [5]:

params = {
    'rate' : 10,
    'parameters' : {
        'tau' : 20,
        'beta' : 1,
    },
    'simulation' : {
        'num_paths' : 10,
        'num_timesteps' : 1000,
        'delta_t' : 0.1,
        'method' : 'euler_maruyama'
    }
}

params['parameters']['alpha'] = params['parameters']['beta'] * params['rate']
params['simulation']['x_0'] = jnp.zeros((params['simulation']['num_paths'],))

cir = CIR(**params['parameters'])




In [6]:
x, sim_params = cir(**params['simulation'])

TypeError: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html