In [1]:
import jax.numpy as jnp
import casadi as ca
import numpy as np
import jax
from jax_casadi_callback import JaxCasadiCallback

def f_test(x, y):
    """Test function mapping (x, y) to five scalar components, written in jax.numpy.

    Parameters
    ----------
    x : indexable array; entries x[0], x[1], x[2] are read.
    y : indexable array; entries y[0], y[1], y[2], y[3] are read.

    Returns
    -------
    jax.Array stacking, in order:
        x0*y0, 5*x2 + y1, 4*x1**2 + y1*y2 - 2*x2, exp(y3) + x2*sin(x0), x1*y2

    NOTE(review): when called through JaxCasadiCallback the inputs presumably
    arrive shaped per opts['in_dim'] (column vectors) -- TODO confirm.
    """
    return jnp.asarray([
        x[0] * y[0],
        5 * x[2] + y[1],
        4 * x[1]**2 + y[1] * y[2] - 2 * x[2],
        jnp.exp(y[3]) + x[2] * jnp.sin(x[0]),
        x[1] * y[2],
    ])
    
# Wrap f_test as a CasADi callback: 2 inputs, 1 output.
# NOTE(review): in_dim/out_dim presumably are CasADi (rows, cols) shapes -- 3x1 and
# 4x1 inputs, 5x1 output, matching f_test's five components; confirm in JaxCasadiCallback.
opts = {'in_dim': [[3, 1], [4, 1]], 'out_dim': [[5, 1]], 'n_in': 2, 'n_out': 1}
f_callback = JaxCasadiCallback('f1', f_test, opts)

# Evaluation point, as CasADi DM values ...
DM_x = ca.DM([1, 2, 3.0])
DM_y = ca.DM([3, 1, 2, 1.5])

# ... and the same point as float32 JAX arrays, for a direct JAX evaluation below.
jnp_x = jnp.asarray([1.0, 2, 3], dtype=jnp.float32)
jnp_y = jnp.asarray([3.0, 1, 2, 1.5], dtype=jnp.float32)

v_callback = f_callback(DM_x, DM_y)
print('callback vjp evaluate:')
print(v_callback)

# Sanity check: going through the callback must give the same numbers as calling
# f_test directly in JAX at the same point. The JAX side is float32, hence rtol.
np.testing.assert_allclose(
    np.ravel(v_callback.full()),
    np.ravel(np.asarray(f_test(jnp_x, jnp_y))),
    rtol=1e-5,
)

# The same function written natively with CasADi MX symbols. It serves as the
# reference for the callback's values (here) and derivatives (below).
x = ca.MX.sym("x", 3, 1)
# Each symbol gets its own name so printed/serialized expressions are unambiguous.
y = ca.MX.sym("y", 4, 1)
z = ca.vertcat(
    x[0] * y[0],
    5 * x[2] + y[1],
    4 * x[1]**2 + y[1] * y[2] - 2 * x[2],
    ca.exp(y[3]) + x[2] * ca.sin(x[0]),
    x[1] * y[2],
)
f_casadi = ca.Function('F', [x, y], [z])
v_casadi = f_casadi(DM_x, DM_y)
print('casadi original evaluate:')
print(v_casadi)

# The callback and native values must agree. The tolerance allows for the
# callback possibly computing in float32 (JAX's default dtype).
np.testing.assert_allclose(v_callback.full(), v_casadi.full(), rtol=1e-5)


# Jacobian of the 5x1 output with respect to x (5x3). It is built twice: once
# symbolically through the JAX callback and once from the native CasADi function.
# NOTE(review): the "vjp" label suggests the callback supplies reverse-mode
# derivatives via JAX -- TODO confirm in jax_casadi_callback.
J_callback = ca.Function('J1', [x, y], [ca.jacobian(f_callback(x, y), x)])
v_jac_callback = J_callback(DM_x, DM_y)
print('callback vjp jacobian:')
print(v_jac_callback)

J_casadi = ca.Function('j_casadi', [x, y], [ca.jacobian(f_casadi(x, y), x)])
jac_casadi = J_casadi(DM_x, DM_y)
print('casadi jacobian:')
print(jac_casadi)

# Both Jacobians must agree. .full() densifies any sparse DM before comparing.
# atol covers structural zeros; rtol covers possible float32 on the JAX side.
np.testing.assert_allclose(v_jac_callback.full(), jac_casadi.full(), rtol=1e-5, atol=1e-6)

# Second derivatives: the Jacobian (w.r.t. x) of the 5x3 Jacobian above. As the
# printed output shows, CasADi returns this as a 15x3 matrix.
H_callback = ca.Function('H1', [x, y], [ca.jacobian(J_callback(x, y), x)])
v_hes_callback = H_callback(DM_x, DM_y)
print('callback vjp hessian:')
print(v_hes_callback)

H_casadi = ca.Function('h_casadi', [x, y], [ca.jacobian(J_casadi(x, y), x)])
hes_casadi = H_casadi(DM_x, DM_y)
print('casadi hessian:')
print(hes_casadi)

# The native result is sparse and the callback result is dense, so densify both
# with .full() before comparing. Tolerances are as for the Jacobian check.
np.testing.assert_allclose(v_hes_callback.full(), hes_casadi.full(), rtol=1e-5, atol=1e-6)


callback vjp evaluate:
[3, 16, 12, 7.0061, 4]
casadi original evaluate:
[3, 16, 12, 7.0061, 4]
callback vjp jacobian:

[[3, 0, 0], 
 [0, 0, 5], 
 [0, 16, -2], 
 [1.62091, 0, 0.841471], 
 [0, 2, 0]]
casasi jacobian:

[[3, 00, 00], 
 [00, 00, 5], 
 [00, 16, -2], 
 [1.62091, 00, 0.841471], 
 [00, 2, 00]]
callback vjp hessian:

[[0, 0, 0], 
 [0, 0, 0], 
 [0, 0, 0], 
 [-2.52441, 0, 0.540302], 
 [0, 0, 0], 
 [0, 0, 0], 
 [0, 0, 0], 
 [0, 8, 0], 
 [0, 0, 0], 
 [0, 0, 0], 
 [0, 0, 0], 
 [0, 0, 0], 
 [0, 0, 0], 
 [0.540302, 0, 0], 
 [0, 0, 0]]
casasi hessian:
sparse: 15-by-3, 4 nnz
 (3, 0) -> -2.52441
 (13, 0) -> 0.540302
 (7, 1) -> 8
 (3, 2) -> 0.540302


In [2]:
print('jax callback evaluate: ')
%timeit f_callback(DM_x,DM_y)
print('casadi callback evaluate: ')
%timeit f_casadi(DM_x,DM_y)

print('jax jacobian evaluate: ')
%timeit J_callback(DM_x,DM_y)
print('casadi jacobian evaluate: ')
%timeit J_casadi(DM_x,DM_y)

print('jax hessian evaluate: ')
%timeit H_callback(DM_x,DM_y)

print('casadi hessian evaluate: ')
%timeit H_casadi(DM_x,DM_y)


jax callback evaluate: 
535 µs ± 4.63 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
casadi callback evaluate: 
4.58 µs ± 39.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
jax jacobian evaluate: 
2.01 ms ± 37 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
casadi jacobian evaluate: 
4.3 µs ± 18.5 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
jax hessian evaluate: 
8.35 ms ± 59.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
casadi hessian evaluate: 
4.59 µs ± 16.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


47.6 ns ± 0.275 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
