In [32]:
import jax
import jax.numpy as jnp
import numpy as np

# Define the Lorenz 63 system
def lorenz63(state, params):
    """Time derivative of the Lorenz-63 system.

    Args:
        state: sequence of three values ``(x, y, z)``.
        params: sequence of three values ``(sigma, rho, beta)``.

    Returns:
        ``jnp.ndarray`` holding ``[dx/dt, dy/dt, dz/dt]``.
    """
    x, y, z = state
    sigma, rho, beta = params
    return jnp.array([
        sigma * (y - x),      # dx/dt
        x * (rho - z) - y,    # dy/dt
        x * y - beta * z,     # dz/dt
    ])

# Analytic Jacobian
def analytic_jacobian(state, params):
    """Hand-derived Jacobian of the Lorenz-63 vector field with respect to the state.

    Row ``i`` holds the partial derivatives of the i-th time derivative
    (dx/dt, dy/dt, dz/dt) with respect to ``(x, y, z)``.

    Args:
        state: sequence of three values ``(x, y, z)``.
        params: sequence of three values ``(sigma, rho, beta)``.

    Returns:
        3x3 ``jnp.ndarray``.
    """
    x, y, z = state
    sigma, rho, beta = params
    d_xdot = [-sigma, sigma, 0.0]     # d(dx/dt)/d(x, y, z)
    d_ydot = [rho - z, -1.0, -x]      # d(dy/dt)/d(x, y, z)
    d_zdot = [y, x, -beta]            # d(dz/dt)/d(x, y, z)
    return jnp.array([d_xdot, d_ydot, d_zdot])

def analytical_eigenvalues(params):
    """Closed-form eigenvalues of the Lorenz-63 Jacobian at the origin (0, 0, 0).

    At the origin the z-equation decouples and contributes ``-beta``. The
    remaining (x, y) block ``[[-sigma, sigma], [rho, -1]]`` has characteristic
    polynomial ``lam**2 + (sigma + 1) * lam + sigma * (1 - rho) = 0``.
    These formulas are valid ONLY at the origin; for any other state,
    diagonalize the Jacobian numerically instead.

    Args:
        params: sequence ``(sigma, rho, beta)``.

    Returns:
        list ``[lambda_1, lambda_2, lambda_3]`` with ``lambda_1 = -beta`` and
        ``lambda_2/3 = (A +/- B) / 2``, where ``A = -(sigma + 1)`` and
        ``B = sqrt((sigma + 1)**2 + 4 * sigma * (rho - 1))``. If the
        discriminant is negative, ``lambda_2`` and ``lambda_3`` are returned
        as a complex-conjugate pair instead of ``nan``.
    """
    sigma, rho, beta = np.array(params)
    lambda_1 = -beta

    A = -(sigma + 1.)
    discriminant = (sigma + 1.)**2 + 4*sigma*(rho - 1)
    # np.emath.sqrt matches np.sqrt for non-negative input but returns a
    # complex root for negative input, where np.sqrt would yield nan.
    B = np.emath.sqrt(discriminant)
    lambda_2 = (A + B)/2.
    lambda_3 = (A - B)/2.
    return [lambda_1, lambda_2, lambda_3]


def calculate_rH(params):
    """Hopf-bifurcation threshold of rho for the Lorenz-63 system.

    Evaluates ``rho_H = sigma * (sigma + beta + 3) / (sigma - beta - 1)``.
    ``rho`` in ``params`` is ignored. The denominator vanishes when
    ``sigma == beta + 1``.

    Args:
        params: sequence ``(sigma, rho, beta)``.

    Returns:
        The scalar ``rho_H``.
    """
    sigma, _rho, beta = np.array(params)
    numerator = sigma * (sigma + beta + 3.)
    denominator = sigma - beta - 1
    return numerator / denominator


In [33]:
# Example state and parameters.
# analytical_eigenvalues() is derived for the origin fixed point, so the
# state must stay at (0, 0, 0) for the eigenvalue comparison below to hold.
state = jnp.array([.0, .0, .0])
params = jnp.array([10.0, 28.0, 8/3])
rH = calculate_rH(params)

# Compute the Jacobian with respect to the state vector using JAX
jacobian_fn = jax.jacfwd(lorenz63, argnums=0)
jax_jacobian = jacobian_fn(state, params)

# Compute the analytic Jacobian
analytic_jac = analytic_jacobian(state, params)

# Compute analytical eigenvalues
analytical_eigen = analytical_eigenvalues(params)

# Compute numerical eigenvalues (np.linalg.eig returns them in no particular order)
num_eigen_values, eigen_vectors = np.linalg.eig(np.array(jax_jacobian))

# Verify the two eigenvalue sets agree regardless of ordering; the JAX
# Jacobian is float32, hence a relative tolerance rather than exact equality.
np.testing.assert_allclose(
    np.sort_complex(num_eigen_values),
    np.sort_complex(np.asarray(analytical_eigen)),
    rtol=1e-5,
)

In [9]:
print("JAX-Computed Jacobian:\n", jax_jacobian)
print("\nAnalytic Jacobian:\n", analytic_jac)

# Compare the Jacobians: show the element-wise difference, then fail loudly
# if automatic differentiation and the hand-derived formula disagree.
jacobian_diff = jax_jacobian - analytic_jac
print("\nDifference between JAX and Analytic Jacobian:\n", jacobian_diff)
np.testing.assert_allclose(np.asarray(jax_jacobian), np.asarray(analytic_jac), rtol=1e-6, atol=1e-6)

In [34]:
# Hopf-bifurcation threshold rho_H = sigma * (sigma + beta + 3) / (sigma - beta - 1)
# for the sigma and beta in `params` (computed by calculate_rH).
rH

24.736843848822936

In [28]:
# Closed-form eigenvalues at the origin, in the order returned by
# analytical_eigenvalues: [-beta, (A + B) / 2, (A - B) / 2].
analytical_eigen

[-2.6666667, 11.827723451163457, -22.827723451163457]

In [29]:
# Eigenvalues of the JAX Jacobian from np.linalg.eig. np.linalg.eig does not
# sort its output, so compare with the analytical list up to ordering.
num_eigen_values

array([-22.827723 ,  11.8277235,  -2.6666667], dtype=float32)