In [1]:
from jax import config 
config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import jit, vmap
from scipy.optimize import minimize
import jaxopt as jo
import numpy as np
import jax

In [2]:
def distance(x, y):
    return np.linalg.norm(x - y)


def distance_jax(x, y):
    return jnp.linalg.norm(x - y)


params = np.arange(5).astype(float)
np.random.seed(1234)
y_arr = np.random.uniform(size=(100, 5)) 


params_jax = jnp.array(params)
y_arr_jax = jnp.array(y_arr)





In [3]:
solver = jo.LBFGS(fun=distance_jax)

def solve(x, y):
    return solver.run(init_params=x, y=y)[0]

batch_solve = jit(vmap(solve, in_axes=(None, 0)))


%timeit batch_solve(params_jax, y_arr_jax)



6.5 ms ± 87 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [4]:
%%timeit 
for y in y_arr:
    minimize(distance, params, method="BFGS", args=(y,))

2.11 s ± 30.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [5]:
res_jax = batch_solve(params_jax, y_arr_jax)

In [13]:
res_scipy = np.stack(
    [minimize(distance, params, method="BFGS", args=(y,)).x
    for y in y_arr],
)

In [17]:
from numpy.testing import assert_array_almost_equal

In [18]:
assert_array_almost_equal(res_scipy, y_arr)

In [19]:
assert_array_almost_equal(res_jax, y_arr_jax)

AssertionError: 
Arrays are not almost equal to 6 decimals

Mismatched elements: 500 / 500 (100%)
Max absolute difference: 1.8668777e+164
Max relative difference: 1.37190048e+165
 x: array([[-3.998883e+154,  7.870015e+154,  3.293105e+155,  7.617420e+155,
         1.103644e+156],
       [ 5.120822e+157, -1.247023e+158, -8.499112e+157, -1.495049e+158,...
 y: array([[0.191519, 0.622109, 0.437728, 0.785359, 0.779976],
       [0.272593, 0.276464, 0.801872, 0.958139, 0.875933],
       [0.357817, 0.500995, 0.683463, 0.712702, 0.370251],...