In [None]:
import jax.numpy as jnp
from jax import jit, vmap
import diffrax as dfx

# Standard gravitational acceleration g0 (m/s^2, assuming SI units throughout).
# Used to convert specific impulse into a propellant mass-flow rate.
STANDARD_GRAV: float = 9.81

# Control law
@jit
def lts_control(t, x, params):
    """Linear-tangent steering law.

    Returns the angle ``u`` satisfying ``tan(u) = a * t + b``, with
    ``a = params[3]`` and ``b = params[4]``. ``x`` is accepted but unused; it
    mirrors the ``(t, x, params)`` signature of ``dynamics``.

    NOTE: JAX clamps out-of-range indices instead of raising. If ``params``
    has fewer than five entries, its last element is silently used for
    ``a`` and/or ``b``.
    """
    slope, intercept = params[3], params[4]
    tangent = slope * t + intercept
    return jnp.arctan(tangent)
# Equations of motion, written with JAX so they can be jit-compiled and vmapped.
@jit
def dynamics(t, x, params):
    """Time derivative of the state for a thrusting point mass under gravity.

    Args:
        t: Time, forwarded to ``lts_control`` to compute the thrust angle.
        x: State ``[x, y, vx, vy, mass]`` (layout as documented for ``x0``).
        params: ``[thrust, g, Isp, a, b]``. ``thrust``, ``g`` and ``Isp`` are
            used here; ``a`` and ``b`` are read by ``lts_control``.

    Returns:
        A length-5 array ``dx/dt`` in the dtype of ``x``.
    """
    # Index explicitly rather than writing ``thrust, g, Isp = params``. The
    # steering coefficients share this array, and tuple-unpacking a
    # 5-element array into three names would raise.
    thrust, g, Isp = params[0], params[1], params[2]
    u = lts_control(t, x, params)

    cos_theta = jnp.cos(u)
    sin_theta = jnp.sin(u)
    mass = x[4]

    dx = jnp.stack(
        [
            x[2],  # dx/dt = vx
            x[3],  # dy/dt = vy
            thrust * cos_theta / mass,  # dvx/dt
            thrust * sin_theta / mass - g,  # dvy/dt, gravity along -y
            -thrust / (STANDARD_GRAV * Isp),  # dm/dt = -T / (g0 * Isp)
        ]
    )
    # Keep the derivative in the same dtype as the state.
    return dx.astype(x.dtype)


# Batched dynamics: maps over a leading batch axis of ``x`` while sharing
# ``t`` and ``params`` across the batch.
vectorized_dynamics = vmap(dynamics, in_axes=(None, 0, None))

# Parameters: thrust, gravity, specific impulse, steering slope a0, intercept b0.
# All five entries must be present: JAX clamps out-of-bounds indexing, so a
# shorter array makes lts_control silently reuse the last element.
# a0 = b0 = 300.0 reproduce what the former three-element array effectively
# supplied (params[2] via clamping), so results are unchanged.
# TODO: set the intended steering coefficients.
params = jnp.array([25000.0, 9.81, 300.0, 300.0, 300.0])
x0 = jnp.array([0.0, 0.0, 1.0, 1.0, 100.0])  # Initial state [x, y, vx, vy, mass]
t0, t1 = 0.0, 10.0  # Time span

# Tsitouras 5/4 Runge-Kutta method, stepped at the fixed size dt0.
ode_solver = dfx.Tsit5()
stepsize_controller = dfx.ConstantStepSize()

term = dfx.ODETerm(lambda t, x, args: dynamics(t, x, params))

# NOTE: ``solver`` holds the returned Solution; the name is kept in case later
# cells reference it. With diffrax's default SaveAt only the state at t1 is
# stored. Pass ``saveat=dfx.SaveAt(ts=...)`` to record a trajectory.
solver = dfx.diffeqsolve(
    term,
    ode_solver,
    t0=t0,
    t1=t1,
    dt0=0.1,
    y0=x0,
    stepsize_controller=stepsize_controller,
)

# Solution times and states
print("Times:", solver.ts)
print("Solution:", solver.ys)