From https://www.sciencedirect.com/science/article/pii/S030440762100227X#appSB

# Scalable Inference for stochastic volatility.
## Produce randomised datasets for the given model

In [63]:
import jax.random as jr
import jax.numpy as jnp
from jax import lax, jit, Array, vmap
import jax

SEED = 10

### Diagnostic Functions

In [64]:
def is_psd_cholesky(A: jnp.ndarray) -> bool:
    try:
        _ = jnp.linalg.cholesky(A + 1e-10 * jnp.eye(A.shape[0]))
        return True
    except jax.errors.ConcretizationTypeError:
        # when tracing under jit, don't use Python exceptions
        raise
    except Exception:
        return False


### Helper Functions

In [47]:
def givens_rotation(D, i, j, omega):
    """
    Create an n x n Givens rotation matrix that rotates in the (i, j) plane.

    Parameters:
        D (int): Dimension of the square matrix.
        i (int): First index (0-based).
        j (int): Second index (0-based).
        omega (float): Rotation angle in radians.

    Returns:
        jnp.ndarray: Givens rotation matrix.
    """
    
    # Start with identity
    G = jnp.eye(D)
    c = jnp.cos(omega)
    s = jnp.sin(omega)

    # Apply rotation in (i, j) plane
    G = G.at[i, i].set(c)
    G = G.at[j, j].set(c)
    G = G.at[i, j].set(s)
    G = G.at[j, i].set(-s)

    return G


@jit
def givens_product(w_ij: Array):
    """
    w_ij: (D, D) angles; we use only strict upper-triangle (i<j).
    
    Returns:
        G_all: (M, D, D) stack of Givens in lexicographic (i,j) order
        R: (D, D) product G_{(0)} @ G_{(1)} @ ... @ G_{(M-1)}
    """
    
    D = w_ij.shape[0]
    i_idx, j_idx = jnp.triu_indices(D, k=1)          # vectors of length M
    omegas = w_ij[i_idx, j_idx]                       # (M,)

    # construct Givens matrices over all (i,j,omega)
    G_all = vmap(
        lambda ii, jj, om: givens_rotation(D, ii, jj, om),
        in_axes=(0, 0, 0)
    )(i_idx, j_idx, omegas)

    # multiply all Givens matrices
    P = lax.associative_scan(lambda A, B: B @ A, G_all)[-1]
    return P


def 

In [82]:
def generate_stock_price_data(key: jr.PRNGKey, K: int, D: int, T: int):
    """
    Generate returns according to the model in scalable inference paper

    K: factor dimension
    D: returns dimension
    """

    assert D >= K, "Factor dimension cannot be greater than observation dimension"

    # randomly sample all necessary parameters 
    key, key_h = jr.split(key)
    h_i0 = jr.normal(key_h, shape=(K))

    key, key_d = jr.split(key)
    d_ij0 = jr.normal(key_d, shape=(K, K))
    d_ij0 = d_ij0.at[jnp.tril_indices(K)].set(0)

    key, key_phi_h = jr.split(key)
    phi_h = jr.uniform(key_phi_h, shape=(K))

    key, key_phi_d = jr.split(key)
    phi_d = jr.uniform(key_phi_d, shape=(K, K))
    phi_d = phi_d.at[jnp.tril_indices(K)].set(0)

    key, key_sigma_h = jr.split(key)
    sigma_h = jr.uniform(key_sigma_h, shape=(K))

    key, key_sigma_d = jr.split(key)
    sigma_d = jr.uniform(key_sigma_d, shape=(K, K))

    key, key_h1 = jr.split(key)
    h_i1 = (jr.normal(key_h1, shape=(K)) + h_i0) * sigma_h / jnp.sqrt(1 - phi_h**2)

    key, key_d1 = jr.split(key)
    d_ij1 = (jr.normal(key_d1, shape=(K, K)) + d_ij0) * sigma_d / jnp.sqrt(1 - phi_d**2)
    
    key, key_B = jr.split(key)
    B = jr.normal(key_B, shape=(D, K))

    key, key_V = jr.split(key)
    V = jnp.diag(jr.uniform(key_V, shape=(D)))

    # calculate factors and returns at t1 
    w_ij1 = (jnp.pi / 2) * (jnp.exp(d_ij1) - 1) / (jnp.exp(d_ij1) + 1)
    P1 = givens_product(w_ij1)
    L1 = jnp.diag(jnp.exp(h_i1))
    S1 = P1 @ L1 @ P1.T

    key, key_f, key_r = jr.split(key, 3)
    f1 = jr.multivariate_normal(key_f, jnp.zeros(K), S1)
    r1 = jr.multivariate_normal(key_r, B @ f1, V)

    # pre-generate noise for latent Gaussian processes
    key, key_eta_h, key_eta_d = jr.split(key, 3)
    eta_h = jr.normal(key_eta_h, shape=(T, K))
    eta_d = jr.normal(key_eta_d, shape=(T, K, K))

    # simulate latent dynamics
    def step(carry, noise):
        
        h_it, d_ijt = carry   # other params are accessed from global scope
        eta_ht, eta_dt = noise
        
        h_i = h_i0 + phi_h * (h_it - h_i0) + sigma_h * eta_ht 
        d_ij = d_ij0 + phi_d * (d_ijt - d_ij0) + sigma_d * eta_dt

        return (h_i, d_ij), (h_i, d_ij)
        
    carry0 = (h_i1, d_ij1)
    noise = (eta_h, eta_d)

    _, paths = lax.scan(step, carry0, noise)

    # back-calculate all rotation angles, factors and returns
    w_ij = vmap(lambda: )

    
    return h_i0, d_ij0, phi_h, phi_d, sigma_h, sigma_d, h_i1, d_ij1, B, V, w_ij1, L1, P1, S1, f1, r1, seq, final

In [84]:
h_i0, d_ij0, phi_h, phi_d, sigma_h, sigma_d, h_i1, d_ij1, B, V, w_ij1, eigs_1, P1, S1, f1, r1, seq, final = generate_stock_price_data(jr.PRNGKey(SEED), 5, 10, 100)

In [86]:
final

(Array([[ 1.16322368e-01, -8.00348520e-02, -1.19448161e+00,
          6.61807120e-01, -1.84862345e-01],
        [ 2.80451179e-02,  7.78588891e-01, -1.33685720e+00,
          9.77791369e-01,  6.88703835e-01],
        [-9.39396545e-02,  1.54957271e+00, -1.39010227e+00,
          2.03449273e+00,  8.01795006e-01],
        [-5.53299785e-01,  1.56143284e+00, -1.56367052e+00,
          1.50840664e+00,  1.51718175e+00],
        [ 2.23242402e-01,  2.41772485e+00, -1.31321096e+00,
          6.91148043e-01,  1.23480189e+00],
        [ 6.51936293e-01,  7.25263059e-01, -1.48336744e+00,
         -9.86862779e-02,  1.36611784e+00],
        [ 6.41526163e-01,  5.48454106e-01, -1.69663239e+00,
          7.22866893e-01,  1.01559985e+00],
        [ 7.77406454e-01,  1.36226511e+00, -1.65846574e+00,
          1.31945157e+00,  7.83660769e-01],
        [ 5.47308683e-01,  1.95877862e+00, -1.43825305e+00,
          1.58688676e+00,  5.52213073e-01],
        [-2.47778863e-01,  1.09520423e+00, -1.51004577e+00,
    