In [1]:
from jax import config 
config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import jit, vmap
from scipy.optimize import minimize
import jaxopt as jo
import numpy as np
import jax
from numpy.testing import assert_array_almost_equal

## Define the criterion and parameters

In [2]:
import numpy as np
import jax.numpy as jnp
import pytest

ModuleNotFoundError: No module named 'pytest'

In [None]:
def distance(x, y):
    """Squared Euclidean distance ``||x - y||^2``, computed with NumPy."""
    residual = x - y
    return np.sum(np.square(residual))


def distance_jax(x, y):
    """Squared Euclidean distance ``||x - y||^2``, written with ``jax.numpy``.

    JAX counterpart of ``distance`` so it can be handed to the jaxopt solvers.
    """
    residual = x - y
    return jnp.square(residual).sum()


# Seed the legacy global NumPy RNG before any draw so the targets are reproducible.
np.random.seed(1234)

# Shared start value for every solver: [0., 1., 2., 3., 4.] as float64.
params = np.arange(5, dtype=float)
# 100 target vectors of length 5, drawn uniformly from [0, 1).
y_arr = np.random.uniform(size=(100, 5))

# JAX copies of the same data, used by the jaxopt-based runs.
params_jax = jnp.array(params)
y_arr_jax = jnp.array(y_arr)

## Scipy

In [3]:
from scipy.optimize import minimize

In [4]:
def batched_optimization_scipy(fun=None, ys=None, x0=None):
    """Minimise ``fun(x, y)`` with scipy L-BFGS-B once per target ``y`` in ``ys``.

    Parameters
    ----------
    fun : callable, optional
        Criterion ``fun(x, y) -> float``. Defaults to the notebook-level
        ``distance``.
    ys : array-like, optional
        Targets, iterated over the leading axis. Defaults to ``y_arr``.
    x0 : array-like, optional
        Start value shared by every run. Defaults to ``params``.

    Returns
    -------
    np.ndarray
        The solutions ``res.x`` stacked row-wise, one row per target.
    """
    # Resolve defaults at call time so the notebook globals are looked up lazily.
    if fun is None:
        fun = distance
    if ys is None:
        ys = y_arr
    if x0 is None:
        x0 = params
    # Each target is solved independently from the same start value.
    solutions = [
        minimize(fun=fun, x0=x0, method="L-BFGS-B", args=(y,)).x for y in ys
    ]
    return np.stack(solutions)

In [5]:
# Solve every target with the pure-scipy loop.
# NOTE(review): the recorded run failed with NameError on `y_arr` because the
# data-definition cell was never executed (In [None]); run it first.
result = batched_optimization_scipy()

# Each solution row is expected to recover its target row of y_arr to 5 decimals.
assert_array_almost_equal(result, y_arr, decimal=5)

NameError: name 'y_arr' is not defined

In [6]:
# Wall-clock benchmark of the pure-scipy Python loop over all targets.
%timeit batched_optimization_scipy()

NameError: name 'y_arr' is not defined

## JAXopt

In [8]:
def batched_optimization_jax_loop(solver, ys=None, init_params=None):
    """Run ``solver`` once per target in a plain Python loop and stack the results.

    Parameters
    ----------
    solver
        A jaxopt-style solver exposing ``run(init_params=..., y=...)`` whose
        return value has a ``.params`` attribute.
    ys : array-like, optional
        Targets, iterated over the leading axis. Defaults to the
        notebook-level ``y_arr_jax``.
    init_params : array-like, optional
        Start value shared by every run. Defaults to the notebook-level
        ``params``.

    Returns
    -------
    jax.Array
        ``jnp.stack`` of the per-target solutions, one row per target.
    """
    # Resolve defaults at call time so the notebook globals are looked up lazily.
    if ys is None:
        ys = y_arr_jax
    if init_params is None:
        init_params = params
    # Every run starts from the same init_params; nothing is carried between targets.
    solutions = [solver.run(init_params=init_params, y=y).params for y in ys]
    return jnp.stack(solutions)

## JAXopt `ScipyMinimize`

In [12]:
from jaxopt import ScipyMinimize

In [13]:
# jaxopt's ScipyMinimize solver, configured for L-BFGS-B on the JAX criterion.
solver = ScipyMinimize(method="L-BFGS-B", fun=distance_jax)

In [14]:
# Solve every target with ScipyMinimize in a Python loop.
results = batched_optimization_jax_loop(solver)

# Every solution row should match its target row (default tolerance).
assert_array_almost_equal(results, y_arr_jax)

In [15]:
# Benchmark: ScipyMinimize driven by a Python loop over all targets.
%timeit batched_optimization_jax_loop(solver)

291 ms ± 12.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


## JAXopt native solvers

In [5]:
from jaxopt import LBFGS

In [6]:
# Native jaxopt L-BFGS on the JAX criterion, run unbatched by the loop helper below.
solver = LBFGS(fun=distance_jax)

In [9]:
# Solve every target with LBFGS in a Python loop.
results = batched_optimization_jax_loop(solver)

# NOTE(review): the recorded run of this cell was interrupted (KeyboardInterrupt),
# so this check has not been observed to pass.
assert_array_almost_equal(results, y_arr_jax)


KeyboardInterrupt



### GradientDescent

In [None]:
from jaxopt import GradientDescent

# Plain gradient descent on the same JAX criterion.
solver = GradientDescent(fun=distance_jax)

In [None]:
# Solve every target with GradientDescent in a Python loop.
results = batched_optimization_jax_loop(solver)

# Gradient-descent results are only checked to 4 decimals.
assert_array_almost_equal(results, y_arr_jax, decimal=4)

In [None]:
# %timeit batched_optimization_jax_loop(solver)

### NonlinearCG

In [None]:
from jaxopt import NonlinearCG

In [None]:
# Nonlinear conjugate-gradient solver on the same JAX criterion.
solver = NonlinearCG(fun=distance_jax)

In [None]:
# Python-loop run of NonlinearCG; its accuracy is checked in the next cell.
results = batched_optimization_jax_loop(solver)

In [None]:
# NonlinearCG results are only checked to 3 decimals.
assert_array_almost_equal(results, y_arr_jax, decimal=3)

### Proper batching with `jit` and `vmap`

In [14]:
# Build the LBFGS solver once; `solve` closes over it so jit/vmap trace this
# single solver configuration.
solver = jo.LBFGS(fun=distance_jax)

def solve(x, y):
    """Run LBFGS from start value ``x`` for one target ``y``; return the solution.

    ``solver.run`` returns a (params, state) step; only ``.params`` is kept.
    """
    return solver.run(init_params=x, y=y).params

# in_axes=(None, 0): share the start value, map over the leading axis of the
# targets. jit compiles the vmapped solve (per input shape/dtype).
batch_solve = jit(vmap(solve, in_axes=(None, 0)))

In [15]:
# One call solves all targets: the start value is shared, the targets are mapped over axis 0.
result = batch_solve(params_jax, y_arr_jax)

In [16]:
# The vmapped LBFGS solutions should match the targets row by row.
assert_array_almost_equal(result, y_arr_jax)

In [17]:
# Benchmark the jitted, vmapped solve. Presumably this excludes compilation,
# since the cell above already called batch_solve with the same inputs.
%timeit batch_solve(params_jax, y_arr_jax)

117 µs ± 3.86 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [7]:
import jax

In [11]:
import jax.numpy as jnp

In [8]:
import jaxopt

In [9]:
from jaxopt import LBFGS

In [25]:
import jax.numpy as jnp

x0 = 1 + jnp.arange(3, dtype=float)  # [1., 2., 3.]

shift = x0.copy()

def criterion(x, shift):
    """Quadratic ``x . (x + shift)`` in ``jax.numpy``; for real x it is minimised at ``-shift / 2``."""
    shifted = x + shift
    return jnp.vdot(x, shifted)

In [26]:
# jaxopt LBFGS on the toy quadratic criterion.
solver = LBFGS(fun=criterion)

In [27]:
# `shift` is passed as a keyword to `run` and reaches criterion(x, shift).
res = solver.run(init_params=x0, shift=shift)

In [29]:
# Optimised parameters; analytically the minimiser is -shift / 2.
res.params

DeviceArray([-0.5, -1. , -1.5], dtype=float64)

In [18]:
import numpy as np

# NOTE(review): this cell redefines `x0`, `shift` and `criterion` from the JAX
# cell above; later cells see the NumPy versions.
x0 = 1 + np.arange(3, dtype=float)  # [1., 2., 3.]

shift = x0.copy()

def criterion(x, shift):
    """NumPy version of the quadratic ``x . (x + shift)``."""
    shifted = x + shift
    return np.vdot(x, shifted)

In [23]:
from scipy.optimize import minimize

In [24]:
# SciPy reference solution for the NumPy criterion; `args` supplies `shift`.
minimize(criterion, x0, args=(shift,))

      fun: -3.499999999998926
 hess_inv: array([[ 0.96428578, -0.0714285 , -0.10714278],
       [-0.0714285 ,  0.85714289, -0.21428571],
       [-0.10714278, -0.21428571,  0.67857135]])
      jac: array([-1.84774399e-06, -5.66244125e-07,  7.74860382e-07])
  message: 'Optimization terminated successfully.'
     nfev: 16
      nit: 3
     njev: 4
   status: 0
  success: True
        x: array([-0.50000092, -1.00000028, -1.49999962])