
<p align="center">
  <img src="https://github.com/based-robotics/jaxadi/blob/master/_assets/_logo.png?raw=true" alt="JAXADI Logo" width="500"/>
</p>

Welcome to [JaxADi](https://github.com/based-robotics/jaxadi), a Python library that converts CasADi functions into JAX-compatible functions. Combining CasADi's symbolic modeling with JAX's compilation lets you write efficient, batchable code that runs on CPUs, GPUs, and TPUs.

JaxADi is well suited to:

- Complex robotics simulations
- Challenging optimal control problems
- Machine learning models with intricate dynamics

Let's dive in and explore the capabilities of JaxADi!

# **Getting Started with JaxADi**




## **Installation**

Getting JaxADi up and running is a breeze. Simply use pip to install the [package](https://pypi.org/project/jaxadi/):


In [1]:
!pip install jaxadi



## **Basic Usage**

Define a CasADi function

In [2]:
import casadi as cs

# Define input variables
x = cs.SX.sym("x", 3, 2)
y = cs.SX.sym("y", 2, 2)
# Define a nonlinear function
z = x @ y  # Matrix multiplication
z_squared = z * z  # Element-wise squaring
z_sin = cs.sin(z)  # Element-wise sine
result = z_squared + z_sin  # Element-wise addition
# Create the CasADi function
casadi_fn = cs.Function("complex_nonlinear_func", [x, y], [result])
casadi_fn

Function(complex_nonlinear_func:(i0[3x2],i1[2x2])->(o0[3x2]) SXFunction)

Equivalent NumPy function

In [3]:
import numpy as np

def numpy_fn(x, y):
    z = x @ y  # Matrix multiplication
    z_squared = z * z  # Element-wise squaring
    z_sin = np.sin(z)  # Element-wise sine
    return z_squared + z_sin  # Element-wise addition

numpy_fn

Define the equivalent JAX function

In [4]:
from jax import jit
from jax import numpy as jnp

@jit
def jax_fn(x, y):
    z = x @ y  # Matrix multiplication
    z_squared = z * z  # Element-wise squaring
    z_sin = jnp.sin(z)  # Element-wise sine
    return z_squared + z_sin  # Element-wise addition

jax_fn

<PjitFunction of <function jax_fn at 0x7f1a182eb370>>

Get the string representation of the JAX-compatible function:

In [5]:
from jaxadi import translate

# Get JAX-compatible function string representation
jax_fn_string = translate(casadi_fn)
jax_fn_string

'def evaluate_complex_nonlinear_func(*args):\n    inputs = [jnp.expand_dims(jnp.array(arg), axis=-1) for arg in args]\n    o = [jnp.zeros(out) for out in [(3, 2)]]\n    o[0] = o[0].at[([0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1])].set([((((inputs[0][0, 0]*inputs[1][0, 0])+(inputs[0][0, 1]*inputs[1][1, 0])) * ((inputs[0][0, 0]*inputs[1][0, 0])+(inputs[0][0, 1]*inputs[1][1, 0])))+jnp.sin(((inputs[0][0, 0]*inputs[1][0, 0])+(inputs[0][0, 1]*inputs[1][1, 0]))))[0], ((((inputs[0][1, 0]*inputs[1][0, 0])+(inputs[0][1, 1]*inputs[1][1, 0])) * ((inputs[0][1, 0]*inputs[1][0, 0])+(inputs[0][1, 1]*inputs[1][1, 0])))+jnp.sin(((inputs[0][1, 0]*inputs[1][0, 0])+(inputs[0][1, 1]*inputs[1][1, 0]))))[0], ((((inputs[0][2, 0]*inputs[1][0, 0])+(inputs[0][2, 1]*inputs[1][1, 0])) * ((inputs[0][2, 0]*inputs[1][0, 0])+(inputs[0][2, 1]*inputs[1][1, 0])))+jnp.sin(((inputs[0][2, 0]*inputs[1][0, 0])+(inputs[0][2, 1]*inputs[1][1, 0]))))[0], ((((inputs[0][0, 0]*inputs[1][0, 1])+(inputs[0][0, 1]*inputs[1][1, 1])) * ((inputs


Convert the CasADi function into a JAX function

In [6]:
from jaxadi import convert

# Convert the CasADi function into a JAX function
jaxadi_fn = convert(casadi_fn, compile=True)
jaxadi_fn

<function jaxadi._declare.evaluate_complex_nonlinear_func(*args)>

In [7]:
# Random input matrices
input_x = np.random.rand(3, 2)
input_y = np.random.rand(2, 2)
input_x_jax = jnp.array(input_x)
input_y_jax = jnp.array(input_y)
input_x_cas = cs.DM(input_x)
input_y_cas = cs.DM(input_y)
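By default JAX stores arrays as 32-bit floats, whereas NumPy and CasADi work in 64-bit doubles, so `jnp.array(input_x)` drops some precision. This matters when results are compared with `np.allclose`. The sketch below uses NumPy alone to show the size of the effect for the function in this tutorial; the seed and variable names are illustrative choices.

```python
import numpy as np


def numpy_fn(x, y):
    z = x @ y
    return z * z + np.sin(z)


rng = np.random.default_rng(0)
x64 = rng.random((3, 2))
y64 = rng.random((2, 2))

# float64 reference and the same computation carried out in float32
ref = numpy_fn(x64, y64)
low = numpy_fn(x64.astype(np.float32), y64.astype(np.float32))

max_abs_err = np.max(np.abs(ref - low))
print(low.dtype, max_abs_err)

# The default tolerances of np.allclose (rtol=1e-5, atol=1e-8) absorb float32 rounding here.
assert np.allclose(low, ref)
# A tolerance tuned for float64 agreement rejects the float32 result.
assert not np.allclose(low, ref, rtol=1e-12, atol=0.0)
```

If double precision is needed on the JAX side, `jax.config.update("jax_enable_x64", True)` enables 64-bit types; it should be set before any arrays are created.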

In [8]:
# Call functions
output_jax = jax_fn(input_x_jax, input_y_jax)
output_jaxadi = jaxadi_fn(input_x_jax, input_y_jax)
output_casadi = casadi_fn(input_x_cas, input_y_cas)
output_numpy = numpy_fn(input_x, input_y)

# Compare results
assert np.allclose(output_jax, output_casadi)
assert np.allclose(output_jaxadi, output_casadi)
assert np.allclose(output_numpy, output_casadi)
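The introduction mentions batchable code. As a minimal sketch of what batching looks like in JAX, the snippet below applies `jax.vmap` to the hand-written `jax_fn` from this tutorial and evaluates four independent input pairs in one call. The batch size and random inputs are illustrative, and nothing here is specific to JaxADi's converted functions.

```python
import jax
import jax.numpy as jnp


def jax_fn(x, y):
    z = x @ y  # Matrix multiplication
    return z * z + jnp.sin(z)  # Element-wise square plus sine


# Batch of 4 independent (x, y) pairs; the leading axis is the batch axis.
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
batch_x = jax.random.uniform(key_x, (4, 3, 2))
batch_y = jax.random.uniform(key_y, (4, 2, 2))

# vmap maps over axis 0 of both arguments; jit compiles the batched function.
batched_fn = jax.jit(jax.vmap(jax_fn))
batch_out = batched_fn(batch_x, batch_y)

print(batch_out.shape)  # (4, 3, 2)
```

Each slice `batch_out[i]` matches `jax_fn(batch_x[i], batch_y[i])` up to floating-point rounding.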

In [9]:
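%%time
output_jax = jax_fn(input_x_jax, input_y_jax)

CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 8.11 µs

Note: this output was recorded with the line magic `%time` on a line by itself, which likely timed an empty statement rather than the call. Rerun the cell with `%%time` to get a number comparable with the timings below.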


In [10]:
%%time
output_jaxadi = jaxadi_fn(input_x_jax, input_y_jax)

CPU times: user 36 ms, sys: 0 ns, total: 36 ms
Wall time: 46.8 ms


In [11]:
%%time
output_numpy = numpy_fn(input_x, input_y)

CPU times: user 77 µs, sys: 0 ns, total: 77 µs
Wall time: 81.8 µs


In [12]:
%%time
output_cas = casadi_fn(input_x_cas, input_y_cas)

CPU times: user 94 µs, sys: 0 ns, total: 94 µs
Wall time: 98.5 µs
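Each `%%time` cell measures a single call, which is noisy at microsecond scale. Two JAX behaviors also affect such measurements: the first call to a jitted function includes tracing and compilation, and JAX dispatches work asynchronously, so a call can return before the result is ready. The sketch below shows one common way to account for both, using the hand-written `jax_fn`; the number of repetitions is an arbitrary choice.

```python
import timeit

import jax
import jax.numpy as jnp


@jax.jit
def jax_fn(x, y):
    z = x @ y  # Matrix multiplication
    return z * z + jnp.sin(z)  # Element-wise square plus sine


key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.uniform(key_x, (3, 2))
y = jax.random.uniform(key_y, (2, 2))

# Warm-up call: triggers tracing and compilation so they are not timed below.
jax_fn(x, y).block_until_ready()

# block_until_ready() waits for the asynchronous computation to finish,
# so the timer covers the full evaluation and not only the dispatch.
n_calls = 1000
total = timeit.timeit(lambda: jax_fn(x, y).block_until_ready(), number=n_calls)
print(f"mean time per call: {total / n_calls * 1e6:.1f} µs")
```

In a notebook, `%timeit` gives a similar averaged measurement; the warm-up call and `block_until_ready()` still apply.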
