In [1]:
import jax.numpy as jnp
import casadi as ca
import numpy as np
import jax
from jax_casadi_callback import JaxCasadiCallback


def f_test(x, y):
    """Test function f(x, y) with x in R^3, y in R^4 and a 5-entry output.

    Written with jax.numpy so it can be wrapped in a CasADi callback and be
    jitted/differentiated by JAX. `z` below is the same function in CasADi MX.
    """
    return jnp.asarray([
        x[0] * y[0],
        5 * x[2] + y[1],
        4 * x[1]**2 + y[1] * y[2] - 2 * x[2],
        jnp.exp(y[3]) + x[2] * jnp.sin(x[0]),
        x[1] * y[2],
    ])


# Callback I/O shapes: x is 3x1, y is 4x1, f(x, y) is 5x1.
opts = {'in_dim': [[3, 1], [4, 1]], 'out_dim': [[5, 1]], 'n_in': 2, 'n_out': 1}
# CasADi does not keep the Python Callback object alive: `f_callback` must stay
# bound for as long as J_callback / H_callback (built from it below) are used.
# Re-binding it in another cell without re-running this whole cell makes those
# Functions fail with "Callback object has been deleted".
f_callback = JaxCasadiCallback('f1', f_test, opts)

# Seeded random inputs so the printed results are reproducible.
rng = np.random.default_rng(0)
DM_x = ca.DM(rng.random((3, 1)))
DM_y = ca.DM(rng.random((4, 1)))
# The same values as 1-D float32 JAX arrays, used by the pure-JAX timings.
jnp_x = jnp.asarray(DM_x.full().ravel(), dtype=jnp.float32)
jnp_y = jnp.asarray(DM_y.full().ravel(), dtype=jnp.float32)

# Loose tolerance for the callback-vs-CasADi checks: the JAX side presumably
# computes in float32 (JAX's default unless x64 mode is enabled).
CHECK_ATOL = 1e-4

v_callback = f_callback(DM_x, DM_y)
print('callback vjp evaluate:')
print(v_callback)
f_jax = jax.jit(f_test)

# Reference: the same function written directly with CasADi MX symbols.
x = ca.MX.sym("x", 3, 1)
y = ca.MX.sym("y", 4, 1)
z = ca.vertcat(
    x[0] * y[0],
    5 * x[2] + y[1],
    4 * x[1]**2 + y[1] * y[2] - 2 * x[2],
    ca.exp(y[3]) + x[2] * ca.sin(x[0]),
    x[1] * y[2],
)
f_casadi = ca.Function('F', [x, y], [z])
v_casadi = f_casadi(DM_x, DM_y)
print('casadi original evaluate:')
print(v_casadi)
assert float(ca.mmax(ca.fabs(v_callback - v_casadi))) < CHECK_ATOL, 'callback value differs from CasADi'

# Jacobian of f with respect to x (5x3).
J_callback = ca.Function('J1', [x, y], [ca.jacobian(f_callback(x, y), x)])
v_jac_callback = J_callback(DM_x, DM_y)
print('callback vjp jacobian:')
print(v_jac_callback)
J_casadi = ca.Function('j_casadi', [x, y], [ca.jacobian(f_casadi(x, y), x)])
jac_casadi = J_casadi(DM_x, DM_y)
print('casadi jacobian:')
print(jac_casadi)
assert float(ca.mmax(ca.fabs(v_jac_callback - jac_casadi))) < CHECK_ATOL, 'callback jacobian differs from CasADi'

j_jax = jax.jit(jax.jacobian(f_jax))

# Jacobian of the (vectorised) Jacobian with respect to x, i.e. the stacked Hessians (15x3).
H_callback = ca.Function('H1', [x, y], [ca.jacobian(J_callback(x, y), x)])
v_hes_callback = H_callback(DM_x, DM_y)
print('callback vjp hessian:')
print(v_hes_callback)
H_casadi = ca.Function('h_casadi', [x, y], [ca.jacobian(J_casadi(x, y), x)])
hes_casadi = H_casadi(DM_x, DM_y)
print('casadi hessian:')
print(hes_casadi)
assert float(ca.mmax(ca.fabs(v_hes_callback - hes_casadi))) < CHECK_ATOL, 'callback hessian differs from CasADi'

# Hessian of f with respect to its first argument x, shape (5, 3, 3).
h_jax = jax.jit(jax.hessian(f_jax))



callback vjp evaluate:
[0.436633, 2.28303, -0.793243, 2.156, 0.0292098]
casadi original evaluate:
[0.436633, 2.28303, -0.793243, 2.156, 0.0292098]
callback vjp jacobian:

[[0.704854, 0, 0], 
 [0, 0, 5], 
 [0, 0.273053, -2], 
 [0.356811, 0, 0.580601], 
 [0, 0.855798, 0]]
casasi jacobian:

[[0.704854, 00, 00], 
 [00, 00, 5], 
 [00, 0.273053, -2], 
 [0.356811, 00, 0.580601], 
 [00, 0.855798, 00]]
callback vjp hessian:

[[0, 0, 0], 
 [0, 0, 0], 
 [0, 0, 0], 
 [-0.254443, 0, 0.814188], 
 [0, 0, 0], 
 [0, 0, 0], 
 [0, 0, 0], 
 [0, 8, 0], 
 [0, 0, 0], 
 [0, 0, 0], 
 [0, 0, 0], 
 [0, 0, 0], 
 [0, 0, 0], 
 [0.814188, 0, 0], 
 [0, 0, 0]]
casasi hessian:
sparse: 15-by-3, 4 nnz
 (3, 0) -> -0.254443
 (13, 0) -> 0.814188
 (7, 1) -> 8
 (3, 2) -> 0.814188


In [3]:
print('jax callback evaluate: ')
%timeit f_callback(DM_x,DM_y)
print('casadi callback evaluate: ')
%timeit f_casadi(DM_x,DM_y)
print('jax evaluate: ')
%timeit f_jax(jnp_x,jnp_y)


print('jax jacobian evaluate: ')
%timeit J_callback(DM_x,DM_y)
print('casadi jacobian evaluate: ')
%timeit J_casadi(DM_x,DM_y)
print('jax jacobian: ')
%timeit j_jax(jnp_x,jnp_y)

print('jax hessian evaluate: ')
%timeit H_callback(DM_x,DM_y)

print('casadi hessian evaluate: ')
%timeit H_casadi(DM_x,DM_y)
print('jax hessian: ')
%timeit h_jax(jnp_x,jnp_y)


jax callback evaluate: 
6.78 ms ± 42.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
casadi callback evaluate: 
45.9 µs ± 140 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
jax evaluate: 
18.2 µs ± 89.9 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
jax jacobian evaluate: 


RuntimeError: Error in Function::operator() for 'jac_wrap_f1' [MXFunction] at .../casadi/core/function.cpp:1368:
Error in Function::operator() for 'f1' [CallbackInternal] at .../casadi/core/function.cpp:1368:
.../casadi/core/function_internal.cpp:3366: Failed to evaluate 'eval_dm' for f1:
.../casadi/core/callback_internal.cpp:122: Error calling "eval" for object f1:
.../casadi/core/callback_internal.cpp:122: Assertion "(self_)!=0" failed:
Callback object has been deleted

In [2]:

def f_test(x, y):
    """Column-wise test function: x is (3, rows), y is (4, rows) -> (5, rows).

    Each column of the output depends only on the same column of x and y.
    `z` below is the same function in CasADi MX.
    """
    return jnp.asarray([
        x[0, :] * y[0, :],
        5 * x[2, :] + y[1, :],
        4 * x[1, :]**2 + y[1, :] * y[2, :] - 2 * x[2, :],
        jnp.exp(y[3, :]) + x[2, :] * jnp.sin(x[0, :]),
        x[1, :] * y[2, :],
    ])


# Number of columns (independent samples) of x and y, despite the name.
rows = 1000
opts = {'in_dim': [[3, rows], [4, rows]], 'out_dim': [[5, rows]], 'n_in': 2, 'n_out': 1}
# CasADi does not keep the Python Callback object alive: re-binding f_callback
# here drops the callback built by the earlier setup cell, so any J_callback /
# H_callback built from that one must not be used afterwards ("Callback object
# has been deleted"). Re-run this whole cell so all three are rebuilt together.
f_callback = JaxCasadiCallback('f1', f_test, opts)

# Seeded random inputs so the printed results are reproducible.
rng = np.random.default_rng(0)
DM_x = ca.DM(rng.random((3, rows)))
DM_y = ca.DM(rng.random((4, rows)))
# The same values as float32 JAX arrays, used by the pure-JAX timings.
jnp_x = jnp.asarray(DM_x.full(), dtype=jnp.float32)
jnp_y = jnp.asarray(DM_y.full(), dtype=jnp.float32)

# Loose tolerance for the callback-vs-CasADi checks: the JAX side presumably
# computes in float32 (JAX's default unless x64 mode is enabled).
CHECK_ATOL = 1e-4

v_callback = f_callback(DM_x, DM_y)
print('callback vjp evaluate:')
print(v_callback)
f_jax = jax.jit(f_test)

# Reference: the same function written directly with CasADi MX symbols.
x = ca.MX.sym("x", 3, rows)
y = ca.MX.sym("y", 4, rows)
z = ca.vertcat(
    x[0, :] * y[0, :],
    5 * x[2, :] + y[1, :],
    4 * x[1, :]**2 + y[1, :] * y[2, :] - 2 * x[2, :],
    ca.exp(y[3, :]) + x[2, :] * ca.sin(x[0, :]),
    x[1, :] * y[2, :],
)
f_casadi = ca.Function('F', [x, y], [z])
v_casadi = f_casadi(DM_x, DM_y)
print('casadi original evaluate:')
print(v_casadi)
assert float(ca.mmax(ca.fabs(v_callback - v_casadi))) < CHECK_ATOL, 'callback value differs from CasADi'

# The derivative matrices are far too large to print; show size and nnz instead.
J_callback = ca.Function('J1', [x, y], [ca.jacobian(f_callback(x, y), x)])
v_jac_callback = J_callback(DM_x, DM_y)
print(f'callback vjp jacobian: {v_jac_callback.size1()}x{v_jac_callback.size2()}, {v_jac_callback.nnz()} nnz')
J_casadi = ca.Function('j_casadi', [x, y], [ca.jacobian(f_casadi(x, y), x)])
jac_casadi = J_casadi(DM_x, DM_y)
print(f'casadi jacobian: {jac_casadi.size1()}x{jac_casadi.size2()}, {jac_casadi.nnz()} nnz')
assert float(ca.mmax(ca.fabs(v_jac_callback - jac_casadi))) < CHECK_ATOL, 'callback jacobian differs from CasADi'

# Dense JAX Jacobian with respect to x, shape (5, rows, 3, rows).
j_jax = jax.jit(jax.jacobian(f_jax))

H_callback = ca.Function('H1', [x, y], [ca.jacobian(J_callback(x, y), x)])
v_hes_callback = H_callback(DM_x, DM_y)
print(f'callback vjp hessian: {v_hes_callback.size1()}x{v_hes_callback.size2()}, {v_hes_callback.nnz()} nnz')
H_casadi = ca.Function('h_casadi', [x, y], [ca.jacobian(J_casadi(x, y), x)])
hes_casadi = H_casadi(DM_x, DM_y)
print(f'casadi hessian: {hes_casadi.size1()}x{hes_casadi.size2()}, {hes_casadi.nnz()} nnz')
assert float(ca.mmax(ca.fabs(v_hes_callback - hes_casadi))) < CHECK_ATOL, 'callback hessian differs from CasADi'


def f_column(x_col, y_col):
    """Evaluate f_test on a single column: x_col (3,), y_col (4,) -> (5,)."""
    return f_test(x_col[:, None], y_col[:, None])[:, 0]


# A dense Hessian would have 5*3*3*rows**3 entries. Since f acts column by column,
# only the per-column 3x3 blocks are non-zero, so compute those: shape (rows, 5, 3, 3).
h_jax = jax.jit(jax.vmap(jax.hessian(f_column), in_axes=(1, 1)))


callback vjp evaluate:

[[0.387178, 0.403445, 0.801093, ..., 0.523478, 0.0774161, 0.148112], 
 [1.37725, 4.40724, 4.78205, ..., 4.89566, 5.17198, 1.88916], 
 [4.07808, -1.59985, -3.03557e-06, ..., 0.0507618, -1.615, -0.2471], 
 [2.13315, 3.19641, 2.31766, ..., 2.20798, 2.64426, 1.72517], 
 [0.876587, 0.160778, 0.365872, ..., 0.4849, 0.0814726, 0.0693544]]
casadi original evaluate:

[[0.387178, 0.403445, 0.801093, ..., 0.523478, 0.0774161, 0.148112], 
 [1.37725, 4.40724, 4.78205, ..., 4.89566, 5.17198, 1.88916], 
 [4.07808, -1.59985, -3.03557e-06, ..., 0.0507618, -1.615, -0.2471], 
 [2.13315, 3.19641, 2.31766, ..., 2.20798, 2.64426, 1.72517], 
 [0.876587, 0.160778, 0.365872, ..., 0.4849, 0.0814726, 0.0693544]]


RuntimeError: Error in MX::jacobian at .../casadi/core/mx.cpp:1663:
Error in XFunction::jac for 'helper_jacobian_MX' [MXFunction] at .../casadi/core/x_function.hpp:719:
Error in MXFunction::ad_forward at .../casadi/core/mx_function.cpp:831:
Error in MX::ad_forward for node of type N6casadi4CallE at .../casadi/core/mx.cpp:2035:
Error in Call::ad_forward for 'f1' [CallbackInternal] at .../casadi/core/casadi_call.cpp:123:
Error in Function::jacobian for 'wrap_f1' [MXFunction] at .../casadi/core/function.cpp:824:
Error in XFunction::get_jacobian for 'wrap_f1' [MXFunction] at .../casadi/core/x_function.hpp:891:
Error in XFunction::jac for 'flattened_jac_wrap_f1' [MXFunction] at .../casadi/core/x_function.hpp:719:
Error in MXFunction::ad_reverse at .../casadi/core/mx_function.cpp:1042:
Error in MX::ad_reverse for node of type N6casadi4CallE at .../casadi/core/mx.cpp:2044:
Error in Call::ad_reverse for 'f1' [CallbackInternal] at .../casadi/core/casadi_call.cpp:147:
.../casadi/core/callback_internal.cpp:170: Error calling "has_reverse" for object f1:
KeyboardInterrupt

In [2]:
print('jax callback evaluate: ')
%timeit f_callback(DM_x,DM_y)
print('casadi callback evaluate: ')
%timeit f_casadi(DM_x,DM_y)
print('jax evaluate: ')
%timeit f_jax(jnp_x,jnp_y)


print('jax jacobian evaluate: ')
%timeit J_callback(DM_x,DM_y)
print('casadi jacobian evaluate: ')
%timeit J_casadi(DM_x,DM_y)
print('jax jacobian: ')
%timeit j_jax(jnp_x,jnp_y)

print('jax hessian evaluate: ')
%timeit H_callback(DM_x,DM_y)

print('casadi hessian evaluate: ')
%timeit H_casadi(DM_x,DM_y)
print('jax hessian: ')
%timeit h_jax(jnp_x,jnp_y)


jax callback evaluate: 
516 ms ± 5.62 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
casadi callback evaluate: 
796 µs ± 4.83 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
jax evaluate: 
18.3 µs ± 156 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
jax jacobian evaluate: 
1.93 s ± 19.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
casadi jacobian evaluate: 
716 µs ± 1.69 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
jax jacobian: 
18.4 µs ± 106 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
jax hessian evaluate: 
8.21 s ± 77.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
casadi hessian evaluate: 
788 µs ± 1.06 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
jax hessian: 
18.4 µs ± 107 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [12]:
param_converter = lambda x:x.full()
%timeit param_converter(DM_x)

5.59 µs ± 40.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [5]:
param_converter = lambda jnp_x:np.asarray(jnp_x)
%timeit param_converter(jnp_x)


1.89 µs ± 18.7 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
