
<p align="center">
  <img src="https://github.com/based-robotics/jaxadi/blob/master/_assets/_logo.png?raw=true" alt="JAXADI Logo" width="500"/>
</p>

Welcome to [JaxADi](https://github.com/based-robotics/jaxadi), a Python library that converts CasADi functions into JAX-compatible functions. It pairs CasADi's symbolic modeling with JAX's compilation and vectorization, so you can build highly efficient, batchable code that runs on CPUs, GPUs, and TPUs.

JaxADi shines in various scenarios, including:

- Complex robotics simulations
- Challenging optimal control problems
- Machine learning models with intricate dynamics

Let's dive in and explore the capabilities of JaxADi!

# **Getting Started with JaxADi**




## **Installation**

Getting JaxADi up and running is a breeze. Simply use pip to install the [package](https://pypi.org/project/jaxadi/):


In [None]:
!pip install jaxadi

## **Basic Usage**

JaxADi offers a straightforward and intuitive API. Let's start by defining an example CasADi function:

In [None]:
import casadi as cs

# Define input variables
x = cs.SX.sym("x", 3, 2)
y = cs.SX.sym("y", 2, 2)
# Define a nonlinear function
z = x @ y  # Matrix multiplication
z_squared = z * z  # Element-wise squaring
z_sin = cs.sin(z)  # Element-wise sine
result = z_squared + z_sin  # Element-wise addition
# Create the CasADi function
casadi_fn = cs.Function("complex_nonlinear_func", [x, y], [result])
casadi_fn

`translate` returns the generated JAX-compatible code as a string, which is handy for inspecting what JaxADi produces:

In [None]:
from jaxadi import translate

# Get JAX-compatible function string representation
jax_fn_string = translate(casadi_fn)
jax_fn_string


`convert` builds a callable JAX function directly from the CasADi function:

In [None]:
from jaxadi import convert

# Convert the CasADi function into a JAX function
jax_fn = convert(casadi_fn, compile=True)
jax_fn

Now, let's verify that our JaxADi function produces the same results as the original CasADi function:

In [None]:
import numpy as np
from jax import numpy as jnp

# Run compiled function
input_x = np.random.rand(3, 2)
input_y = np.random.rand(2, 2)
output_jaxadi = np.array(jax_fn(jnp.array(input_x), jnp.array(input_y)))
output_casadi = np.array(casadi_fn(input_x, input_y))
if np.allclose(output_jaxadi, output_casadi):
    print("The outputs of the CasADi and JaxADi functions match")
else:
    print("Something went wrong...")
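The main payoff of having a JAX function is that it can be batched. The sketch below illustrates that pattern under an explicit assumption: instead of `jax_fn`, it uses `reference_fn`, a hypothetical hand-written `jax.numpy` transcription of `complex_nonlinear_func` (not the code JaxADi generates), and vectorizes it over a batch with `jax.vmap`. Wrapping `jax_fn` the same way should work similarly, but check the converted function's output structure before relying on it.

```python
import jax
import jax.numpy as jnp
import numpy as np


def reference_fn(x, y):
    # Hypothetical hand-written version of complex_nonlinear_func (illustration only):
    # element-wise (x @ y)^2 + sin(x @ y)
    z = x @ y
    return z * z + jnp.sin(z)


# vmap maps over the leading (batch) axis of both inputs; jit compiles the batched function.
batched_fn = jax.jit(jax.vmap(reference_fn))

rng = np.random.default_rng(0)
xs = rng.random((1000, 3, 2))
ys = rng.random((1000, 2, 2))
out = batched_fn(xs, ys)
print(out.shape)  # (1000, 3, 2)
```

One compiled call evaluates all 1000 input pairs, instead of 1000 separate Python-level calls.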

# **JaxADi in Action: Pendulum Rollout Example**

To showcase the power of JaxADi, let's dive into a practical example: simulating an uncontrolled pendulum. We'll compare the performance of CasADi and JAX implementations for batch simulations.

First, let's set up our pendulum model:

In [6]:
# Static parameters
dt = 0.02
g = 9.81  # Acceleration due to gravity
L = 1.0  # Length of the pendulum
b = 0.1  # Damping coefficient
I = 1.0  # Moment of inertia

Define the pendulum model as a CasADi function:

In [7]:
state = cs.SX.sym("state", 2)
theta, omega = state[0], state[1]

theta_dot = omega
omega_dot = (-b * omega - (g / L) * cs.sin(theta)) / I

next_theta = theta + theta_dot * dt
next_omega = omega + omega_dot * dt

next_state = cs.vertcat(next_theta, next_omega)
casadi_pendulum = cs.Function("pendulum_model", [state], [next_state])
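The cell above implements one explicit (forward) Euler step of the damped pendulum dynamics, with $\Delta t$ = `dt`:

$$
\begin{aligned}
\theta_{k+1} &= \theta_k + \Delta t\,\omega_k,\\
\omega_{k+1} &= \omega_k + \Delta t\,\frac{-b\,\omega_k - (g/L)\sin\theta_k}{I}.
\end{aligned}
$$

Each call to `casadi_pendulum` advances the state $(\theta_k, \omega_k)$ by one step, so a rollout of `timesteps` steps covers `timesteps * dt` seconds of simulated time (2 s with `timesteps = 100` and `dt = 0.02`).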

Convert it to JAX:

In [8]:
jax_model = convert(casadi_pendulum, compile=True)

Now, let's implement rollout functions for both CasADi and JaxADi:

In [10]:
import jax

timesteps = 100


def casadi_sequential_rollout(initial_states):
    batch_size = initial_states.shape[0]
    rollout_states = np.zeros((timesteps + 1, batch_size, 2))

    rollout_states[0] = initial_states
    for t in range(1, timesteps + 1):
        rollout_states[t] = np.array([casadi_pendulum(state).full().flatten() for state in rollout_states[t - 1]])

    return rollout_states


@jax.jit
def jax_vectorized_rollout(initial_states):
    def single_step(state):
        # Stack the model outputs into one array and flatten to a state of shape (2,)
        return jnp.array(jax_model(state)).reshape(2)

    def scan_fn(carry, _):
        next_state = jax.vmap(single_step)(carry)
        return next_state, next_state

    _, rollout_states = jax.lax.scan(scan_fn, initial_states, None, length=timesteps)
    return jnp.concatenate([initial_states[None, ...], rollout_states], axis=0)
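To make the `jax.lax.scan` pattern in `jax_vectorized_rollout` easier to follow, here is a self-contained sketch that swaps `jax_model` for a hand-written `jax.numpy` transcription of the same Euler step. The helpers `euler_step`, `rollout`, and `loop_rollout` are hypothetical names for illustration. `scan` threads the carry (the batch of states) through every iteration and stacks the second return value of `scan_fn`, so the stacked trajectory starts at step 1 and the initial state has to be prepended. The plain Python loop computes the same trajectory for comparison.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Same static parameters as the pendulum model above.
dt, g, L, b, I = 0.02, 9.81, 1.0, 0.1, 1.0


def euler_step(state):
    # Hypothetical hand-written transcription of casadi_pendulum (illustration only).
    theta, omega = state[0], state[1]
    omega_dot = (-b * omega - (g / L) * jnp.sin(theta)) / I
    return jnp.stack([theta + omega * dt, omega + omega_dot * dt])


def scan_fn(carry, _):
    next_states = jax.vmap(euler_step)(carry)  # advance every pendulum in the batch
    return next_states, next_states  # (carry for the next step, value stacked into the output)


def rollout(initial_states, timesteps):
    _, traj = jax.lax.scan(scan_fn, initial_states, None, length=timesteps)
    # The stacked outputs start at step 1, so prepend the initial state.
    return jnp.concatenate([initial_states[None, ...], traj], axis=0)


def loop_rollout(initial_states, timesteps):
    # Equivalent Python loop, for comparison.
    states = [initial_states]
    for _ in range(timesteps):
        states.append(jax.vmap(euler_step)(states[-1]))
    return jnp.stack(states)


x0 = jnp.array([[0.5, 0.0], [-1.0, 0.3], [2.0, -0.5]])
traj = rollout(x0, 10)
print(traj.shape)  # (11, 3, 2)
print(np.allclose(traj, loop_rollout(x0, 10), atol=1e-5))
```

The scan version is what makes the rollout a single compiled program; the loop version dispatches one batched step at a time from Python.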

Let's compare the performance of these two implementations:


In [None]:
import timeit

# Test parameters
batch_size = 4096


def generate_random_inputs(batch_size):
    return np.random.uniform(-np.pi, np.pi, (batch_size, 2))


initial_states = generate_random_inputs(batch_size)
print("Warming up JAX...")
_ = jax_vectorized_rollout(initial_states)
print("Warm-up complete. Let's roll!")

print("\nPerformance Showdown:")
initial_states = generate_random_inputs(batch_size)

print(f"CasADi sequential rollout ({batch_size} pendulums, {timesteps} timesteps):")
casadi_time = timeit.timeit(lambda: casadi_sequential_rollout(initial_states), number=1)
print(f"Time: {casadi_time:.4f} seconds")

print(f"\nJAX vectorized rollout ({batch_size} pendulums, {timesteps} timesteps):")
jax_time = timeit.timeit(lambda: np.array(jax_vectorized_rollout(initial_states)), number=1)
print(f"Time: {jax_time:.4f} seconds")

print(f"\nSpeedup factor: {casadi_time / jax_time:.2f}x")

# Verify results
print("\nDouble-checking our results:")
casadi_results = casadi_sequential_rollout(initial_states[:10])
jax_results = np.array(jax_vectorized_rollout(initial_states[:10]))

print("First 10 rollouts match:", np.allclose(casadi_results, jax_results, atol=1e-4))
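JAX dispatches work asynchronously, so a fair timing has to wait for the result; converting the output to a NumPy array, as above, does that implicitly. Since a single `number=1` run can be noisy, a slightly more robust variant is sketched below: it compiles once outside the timed region, calls `.block_until_ready()` on the output, and reports the best of several repeats. `time_jax` and `work` are hypothetical stand-ins; passing `jax_vectorized_rollout` and `initial_states` instead should time the rollout the same way.

```python
import timeit

import jax
import jax.numpy as jnp


@jax.jit
def work(x):
    # Hypothetical stand-in workload; substitute jax_vectorized_rollout here.
    return jnp.sin(x) @ jnp.cos(x).T


def time_jax(fn, *args, repeat=5):
    """Return the best wall-clock time (seconds) of `repeat` runs of fn(*args)."""
    # The first call triggers tracing and compilation, which is excluded from the timing.
    fn(*args).block_until_ready()
    # block_until_ready() waits for the asynchronously dispatched computation to finish.
    times = timeit.repeat(lambda: fn(*args).block_until_ready(), number=1, repeat=repeat)
    return min(times)


x = jnp.ones((512, 512))
best = time_jax(work, x)
print(f"Best of 5 runs: {best:.6f} seconds")
```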


# **JaxADi + Other Libraries: A Perfect Match**

JaxADi plays well with other CasADi-oriented libraries. Let's see how we can use it with [liecasadi](https://github.com/ami-iit/liecasadi) to vectorize the `log` map of the `SO3` rotation group:

In [None]:
!pip install liecasadi

Let's form a CasADi function that takes a quaternion and returns the corresponding tangent vector:

In [13]:
from liecasadi import SO3

# Create SO3 object from quaternion
quat = cs.SX.sym("quaternion", 4)
transform = SO3(xyzw=quat)
# Get the tangent via Log and convert this to function
tang_vec = transform.log().vec
tang_fn = cs.Function("tangent_function", [quat], [tang_vec])

Generate a JAX function that computes the tangent vector:

In [14]:
jax_tang = convert(tang_fn, compile=True)

Check that both functions agree on a random unit quaternion:

In [None]:
quat_random = np.random.randn(4)
quat_random /= np.linalg.norm(quat_random)
print(np.array(tang_fn(quat_random)).reshape(3))
print(np.array(jax_tang(quat_random)).reshape(3))

With this JAX-compatible function, you can now easily batch the log operation and perform your sample-based calculations efficiently!
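As a sketch of what that batching looks like, the block below applies `jax.vmap` to `quat_log_xyzw`, a hypothetical hand-written `jax.numpy` quaternion logarithm based on the standard axis-angle formula (it is not liecasadi's implementation). It assumes unit quaternions in `xyzw` order and flips the sign when $w < 0$ so the rotation angle lies in $[0, \pi]$; liecasadi's handling of that sign and of near-identity rotations may differ. With `jax_tang`, the same `jax.jit(jax.vmap(...))` pattern should apply, though its output container is worth checking first.

```python
import jax
import jax.numpy as jnp


def quat_log_xyzw(q, eps=1e-9):
    # Hypothetical helper: SO(3) log of a unit quaternion q = (x, y, z, w).
    # Canonicalize to w >= 0 so the angle lies in [0, pi] (shortest rotation).
    q = jnp.where(q[3] < 0, -q, q)
    v, w = q[:3], q[3]
    norm_v = jnp.linalg.norm(v)
    angle = 2.0 * jnp.arctan2(norm_v, w)
    # Avoid dividing by zero near the identity, where angle / |v| tends to 2 / w.
    safe_norm = jnp.where(norm_v > eps, norm_v, 1.0)
    scale = jnp.where(norm_v > eps, angle / safe_norm, 2.0 / w)
    return scale * v  # tangent vector = angle * unit axis


# Vectorize over a leading batch axis of quaternions and compile.
batched_log = jax.jit(jax.vmap(quat_log_xyzw))

key = jax.random.PRNGKey(0)
quats = jax.random.normal(key, (1024, 4))
quats = quats / jnp.linalg.norm(quats, axis=1, keepdims=True)  # normalize to unit length
tangents = batched_log(quats)
print(tangents.shape)  # (1024, 3)
```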

# **Wrapping Up**

We've just scratched the surface of what's possible with JaxADi. There's a whole world of CasADi-oriented libraries out there waiting to be supercharged with JaxADi. We encourage you to explore how you can [use JaxADi to transform Pinocchio](https://github.com/based-robotics/jaxadi/blob/master/examples/04_pinocchio.py) calculations, [compare it with MJX](https://github.com/based-robotics/jaxadi/blob/master/examples/04_pinocchio.py), and dive into our repository for more [examples](https://github.com/based-robotics/jaxadi/tree/master/examples).

We're always on the lookout for exciting applications, such as parallelizable Model Predictive Control (MPC). If you come up with something cool, don't hesitate to share it with the community!

If JaxADi helps you in your research, we'd be thrilled if you could cite it:

```bibtex
@misc{jaxadi2024,
  title = {JaxADi: Bridging CasADi and JAX for Efficient Numerical Computing},
  author = {Alentev, Igor and Kozlov, Lev and Nedelchev, Simeon},
  year = {2024},
  url = {https://github.com/based-robotics/jaxadi},
  note = {Accessed: [Insert Access Date]}
}
```


Got questions, issues, or brilliant ideas? We'd love to hear from you! [Open an issue](https://github.com/based-robotics/jaxadi/issues) on our GitHub repository, and let's make JaxADi even better together.

We hope JaxADi supercharges your numerical computing and optimization tasks. Now go forth and compute efficiently! Happy coding!