In [1]:
from jax import config 
config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import jit, vmap
from scipy.optimize import minimize
import jaxopt as jo
import numpy as np
import jax
from numpy.testing import assert_array_almost_equal

## Define criterion and params

In [2]:
import numpy as np
import jax.numpy as jnp
import pytest

In [18]:
def distance(x, y):
    """Return the sum of squared element-wise differences of ``x`` and ``y``.

    NumPy implementation of the criterion; the result is reduced over all
    elements with ``np.sum``.
    """
    residual = x - y
    squared = residual ** 2
    return np.sum(squared)


def distance_jax(x, y):
    """Return the sum of squared element-wise differences of ``x`` and ``y``.

    Same formula as ``distance`` but built from ``jax.numpy`` operations.
    """
    residual = x - y
    squared = residual ** 2
    return jnp.sum(squared)


# Shared start vector for every optimisation: [0., 1., 2., 3., 4.].
params = np.arange(5, dtype=float)

# Seed the legacy global RNG so the 100 target rows (5 values each, drawn
# uniformly from [0, 1)) are identical on every run.
np.random.seed(1234)
y_arr = np.random.uniform(low=0.0, high=1.0, size=(100, 5))


# JAX copies of the start vector and the targets.
params_jax = jnp.asarray(params)
y_arr_jax = jnp.asarray(y_arr)

## Scipy

In [19]:
from scipy.optimize import minimize

In [21]:
def batched_optimization_scipy(targets=None, x0=None, fun=None):
    """Solve one L-BFGS-B problem per target row with SciPy and stack the optima.

    Each row ``y`` of ``targets`` is handled by an independent
    ``scipy.optimize.minimize`` call that minimises ``fun(x, y)`` over ``x``.

    Parameters
    ----------
    targets : array-like, optional
        Iterable of rows, each passed as the extra argument ``y`` of ``fun``.
        Defaults to the notebook-level ``y_arr``.
    x0 : array-like, optional
        Start vector used for every problem. Defaults to the notebook-level
        ``params``.
    fun : callable, optional
        Criterion with signature ``fun(x, y)``. Defaults to the notebook-level
        NumPy criterion ``distance``.

    Returns
    -------
    numpy.ndarray
        The solutions ``res.x`` stacked along a new leading axis, one row per
        row of ``targets``.
    """
    if targets is None:
        targets = y_arr
    if x0 is None:
        x0 = params
    if fun is None:
        # ``distance`` is the NumPy criterion defined in this notebook; no
        # ``distance2`` exists, so referencing it fails on a fresh kernel.
        fun = distance

    solutions = [
        minimize(fun=fun, x0=x0, method="L-BFGS-B", args=(y,)).x for y in targets
    ]
    return np.stack(solutions)

In [9]:
# The squared-distance criterion is minimised at the target itself, so every
# solution row is expected to match the corresponding row of ``y_arr``.
result = batched_optimization_scipy()

np.testing.assert_array_almost_equal(result, y_arr, decimal=5)

In [22]:
%timeit batched_optimization_scipy()

138 ms ± 614 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


## JAXopt

In [8]:
def batched_optimization_jax_loop(solver):
    """Run ``solver`` once per row of ``y_arr_jax`` in a Python loop.

    Every run starts from the notebook-level ``params`` and receives the
    current row as the keyword argument ``y``. Only the ``params`` field of
    each solver result is kept.

    Parameters
    ----------
    solver
        Object exposing ``run(init_params=..., y=...)`` whose result has a
        ``params`` attribute.

    Returns
    -------
    The per-row solutions stacked with ``jnp.stack``.
    """
    optima = [
        solver.run(init_params=params, y=target).params for target in y_arr_jax
    ]
    return jnp.stack(optima)

## JAXopt Scipy

In [12]:
from jaxopt import ScipyMinimize

In [13]:
# JAXopt's ScipyMinimize wrapper configured with the L-BFGS-B method and the
# JAX criterion.
solver = ScipyMinimize(
    fun=distance_jax,
    method="L-BFGS-B",
)

In [14]:
# Loop the ScipyMinimize solver over all targets and check each optimum
# against its target row (default tolerance of the assertion helper).
results = batched_optimization_jax_loop(solver=solver)

np.testing.assert_array_almost_equal(results, y_arr_jax)

In [15]:
%timeit batched_optimization_jax_loop(solver)

291 ms ± 12.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


## JAXopt native solvers

### LBFGS

In [5]:
from jaxopt import LBFGS

In [6]:
# JAXopt's own LBFGS implementation, minimising the JAX criterion.
solver = LBFGS(
    fun=distance_jax,
)

In [9]:
# Loop the LBFGS solver over all targets one at a time.
# NOTE: the recorded run of this cell ended in a KeyboardInterrupt, i.e. it
# was stopped manually before finishing.
results = batched_optimization_jax_loop(solver=solver)

np.testing.assert_array_almost_equal(results, y_arr_jax)


KeyboardInterrupt



### GradientDescent

In [None]:
from jaxopt import GradientDescent

# JAXopt gradient-descent solver for the JAX criterion.
solver = GradientDescent(
    fun=distance_jax,
)

In [None]:
# Loop the gradient-descent solver over all targets; the comparison uses
# 4 decimal places.
results = batched_optimization_jax_loop(solver=solver)

np.testing.assert_array_almost_equal(results, y_arr_jax, decimal=4)

In [None]:
# %timeit batched_optimization_jax_loop(solver)

### NonlinearCG

In [None]:
from jaxopt import NonlinearCG

In [None]:
# JAXopt nonlinear conjugate-gradient solver for the JAX criterion.
solver = NonlinearCG(
    fun=distance_jax,
)

In [None]:
# Loop the currently selected solver over all targets, one row at a time.
results = batched_optimization_jax_loop(solver=solver)

In [None]:
# Compare the looped solutions with the targets to 3 decimal places.
np.testing.assert_array_almost_equal(results, y_arr_jax, decimal=3)

### Proper batching

In [14]:
# ``distance_jax`` is the JAX criterion defined earlier in this notebook; no
# ``distance_jax2`` exists, so referencing it fails on a fresh kernel.
solver = jo.LBFGS(fun=distance_jax)


def solve(x, y):
    """Minimise the criterion for a single target ``y`` starting from ``x``.

    Returns only the optimised parameters (the ``params`` field of the solver
    result), the same field ``batched_optimization_jax_loop`` collects.
    """
    return solver.run(init_params=x, y=y).params


# ``in_axes=(None, 0)`` shares one start vector across the batch and maps over
# the leading axis of the targets; ``jit`` compiles the batched function.
batch_solve = jit(vmap(solve, in_axes=(None, 0)))

In [15]:
# Solve for every row of ``y_arr_jax`` in a single call to the jit-compiled,
# vmapped solver; ``params_jax`` is the start vector shared by all rows.
# Arguments stay positional because ``in_axes`` is matched by position.
result = batch_solve(params_jax, y_arr_jax)

In [16]:
# Compare the batched solutions with the targets (default tolerance of the
# assertion helper).
np.testing.assert_array_almost_equal(result, y_arr_jax)

In [17]:
%timeit batch_solve(params_jax, y_arr_jax)

117 µs ± 3.86 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
