This notebook compares scipy.optimize.least_squares, scipy.optimize.minimize and
jaxopt for solving least squares problem. Takeaway: use least_quares.

In [1]:
import diffrax as dfx
import dynax as dx
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np


jax.config.update("jax_enable_x64", True)


class LoudspeakerDynamics(dx.ControlAffine):
    Bl: float
    Re: float
    Rm: float
    K: float
    L: float
    M: float
    outputs: list = eqx.static_field()

    def __init__(self, params, outputs=[0, 2]):
        self.n_states = 3
        self.n_params = 6
        self.Bl, self.Re, self.Rm, self.K, self.L, self.M = params
        self.outputs = outputs

    def f(self, x, t=None):
        i, d, v = x
        di = (-self.Re * i - self.Bl * v) / self.L
        dd = v
        dv = (self.Bl * i - self.Rm * v - self.K * d) / self.M
        return jnp.array([di, dd, dv])

    def g(self, x, t=None):
        di = 1 / self.L
        dd = 0
        dv = 0
        return jnp.array([di, dd, dv])

    def h(self, x, t=None):
        return x[np.array(self.outputs)]


import functools


# from https://github.com/google/jax/pull/762#issuecomment-1002267121


def value_and_jacfwd(f, x):
    pushfwd = functools.partial(jax.jvp, f, (x,))
    basis = jnp.eye(x.size, dtype=x.dtype)
    y, jac = jax.vmap(pushfwd, out_axes=(None, 1))((basis,))
    return y, jac


def value_and_jacrev(f, x):
    y, pullback = jax.vjp(f, x)
    basis = jnp.eye(y.size, dtype=y.dtype)
    jac = jax.vmap(pullback)(basis)
    return y, jac


# Training data
n = 9600
sr = 96000
t = jnp.array(np.arange(n) / sr)
u = jnp.array(np.random.normal(size=n))
coeffs = dfx.backward_hermite_coefficients(t, u)
cubic = dfx.CubicInterpolation(t, coeffs)
ufun = lambda t: cubic.evaluate(t)
initial_params = [1.0, 1.0, 1.0, 1000.0, 1e-3, 1e-3]
dyn = LoudspeakerDynamics([i * 2 for i in initial_params])
true_model = dx.ForwardModel(dyn, sr)
x0 = jnp.array([0.0, 0.0, 0.0])
y, sol = true_model(t, x0, ufun)



In [10]:
# model
model = dx.ForwardModel(LoudspeakerDynamics(initial_params), sr)
init_params, treedef = jax.tree_flatten(model)
std_y = np.std(y, axis=0)


def residuals(params):
    model = treedef.unflatten(params)
    pred_y, _ = model(t, x0, ufun)
    res = ((y - pred_y) / std_y).reshape(-1)
    return res / np.sqrt(len(res))


def loss(params):
    return jnp.sum(jnp.square(residuals(params)))

Sperate functions for fun and jacm

In [27]:
# using least_squares
from scipy.optimize import least_squares


fun = jax.jit(residuals)
jac = jax.jit(jax.jacfwd(residuals))

import time


reslist_ls = []
time_ls = []
for method in ("trf", "dogbox", "lm"):
    res = least_squares(
        fun, init_params, jac=jac, method=method, x_scale="jac", max_nfev=1
    )
    start = time.time()
    res = least_squares(fun, init_params, jac=jac, method=method, x_scale="jac")
    end = time.time()
    print(method, "took", end - start)
    time_ls.append(end - start)
    reslist_ls.append(res)

from tabulate import tabulate


print(
    tabulate(
        [
            ("trf", "dogbox", "lm"),
            [r.cost for r in reslist_ls],
            [r.nfev for r in reslist_ls],
            [r.njev for r in reslist_ls],
            time_ls,
        ]
    )
)

trf took 0.6331028938293457
dogbox took 0.6533203125
lm took 0.48241162300109863
---------------------  ----------------------  --------------------
trf                    dogbox                  lm
7.271429895608476e-31  7.5066315474993315e-31  7.00090733847295e-31
7                      7                       7
7                      7                       6
0.6331028938293457     0.6533203125            0.48241162300109863
---------------------  ----------------------  --------------------


result: for 10s at 96kHz: trf 28s, dogbox 27s, lm 40s. So use trf!

Solve primal and tangent problem simultaneously with cache. Cache misses.

In [28]:
# using least_squares
from scipy.optimize import least_squares


jac2 = jax.jit(jax.jacfwd(residuals))
fun2 = jax.jit(residuals)
val_jac = jax.jit(lambda x: value_and_jacfwd(residuals, x))
jac = True

val_cache = {}
jac_cache = {}


def fun(arr):
    try:
        return val_cache.pop(id(arr))
    except KeyError:
        (v, j) = val_jac(arr)
        jac_cache[id(arr)] = j
        return v


def jac(arr):
    try:
        return jac_cache.pop(id(arr))
    except KeyError:
        print("missed jac_cache")
        (v, j) = val_jac(arr)
        val_cache[id(arr)] = v
        return j


import time


reslist_ls = []
time_ls = []
for method in ("trf", "dogbox", "lm"):
    res = least_squares(
        fun, init_params, jac=jac, method=method, x_scale="jac", max_nfev=1
    )
    start = time.time()
    res = least_squares(fun, init_params, jac=jac, method=method, x_scale="jac")
    end = time.time()
    print(method, "took", end - start)
    time_ls.append(end - start)
    reslist_ls.append(res)

from tabulate import tabulate


print(
    tabulate(
        [
            ("trf", "dogbox", "lm"),
            [r.cost for r in reslist_ls],
            [r.nfev for r in reslist_ls],
            [r.njev for r in reslist_ls],
            time_ls,
        ]
    )
)

trf took 0.44023728370666504
dogbox took 0.4348287582397461
missed jac_cache
missed jac_cache
missed jac_cache
missed jac_cache
missed jac_cache
lm took 0.6225440502166748
---------------------  ----------------------  ---------------------
trf                    dogbox                  lm
7.271429895608476e-31  7.5066315474993315e-31  6.433023508154689e-31
7                      7                       9
7                      7                       8
0.44023728370666504    0.4348287582397461      0.6225440502166748
---------------------  ----------------------  ---------------------


Common cache leads to more cache misses and fucks up LM result.

In [29]:
# using least_squares
from scipy.optimize import least_squares


val_jac = jax.jit(lambda x: value_and_jacfwd(residuals, x))
jac = True

cache = {}


def fun(arr):
    try:
        return cache.pop(id(arr))[0]
    except KeyError:
        (v, j) = val_jac(arr)
        cache[id(arr)] = (v, j)
        return v


def jac(arr):
    try:
        return cache.pop(id(arr))[1]
    except KeyError:
        print("missed cache")
        (v, j) = val_jac(arr)
        cache[id(arr)] = (v, j)
        return j


import time


reslist_ls = []
time_ls = []
for method in ("trf", "dogbox", "lm"):
    res = least_squares(
        fun, init_params, jac=jac, method=method, x_scale="jac", max_nfev=1
    )
    start = time.time()
    res = least_squares(fun, init_params, jac=jac, method=method, x_scale="jac")
    end = time.time()
    print(method, "took", end - start)
    time_ls.append(end - start)
    reslist_ls.append(res)

from tabulate import tabulate


print(
    tabulate(
        [
            ("trf", "dogbox", "lm"),
            [r.cost for r in reslist_ls],
            [r.nfev for r in reslist_ls],
            [r.njev for r in reslist_ls],
            time_ls,
        ]
    )
)

trf took 0.4343602657318115
dogbox took 0.5049810409545898
missed cache
missed cache
missed cache
missed cache
missed cache
missed cache
missed cache
missed cache
missed cache
missed cache
missed cache
missed cache
missed cache
lm took 0.9419846534729004
---------------------  ----------------------  -------------------
trf                    dogbox                  lm
7.271429895608476e-31  7.5066315474993315e-31  0.07904991004683685
7                      7                       32
7                      7                       11
0.4343602657318115     0.5049810409545898      0.9419846534729004
---------------------  ----------------------  -------------------


Common cache with hashing tuples removes cache misses and LM problem. Saves around 33% time! 

In [30]:
# using least_squares
from scipy.optimize import least_squares


val_jac = jax.jit(lambda x: value_and_jacfwd(residuals, x))

cache = {}


def fun(arr):
    try:
        return cache[tuple(arr)][0]
    except KeyError:
        (v, j) = val_jac(arr)
        cache[tuple(arr)] = (v, j)
        return v


def jac(arr):
    try:
        return cache[tuple(arr)][1]
    except KeyError:
        print("missed cache")
        (v, j) = val_jac(arr)
        cache[tuple(arr)] = (v, j)
        return j


import time


reslist_ls = []
time_ls = []
for method in ("trf", "dogbox", "lm"):
    res = least_squares(
        fun, init_params, jac=jac, method=method, x_scale="jac", max_nfev=1
    )
    cache.clear()
    start = time.time()
    res = least_squares(fun, init_params, jac=jac, method=method, x_scale="jac")
    end = time.time()
    print(method, "took", end - start)
    time_ls.append(end - start)
    reslist_ls.append(res)

from tabulate import tabulate


print(
    tabulate(
        [
            ("trf", "dogbox", "lm"),
            [r.cost for r in reslist_ls],
            [r.nfev for r in reslist_ls],
            [r.njev for r in reslist_ls],
            time_ls,
        ]
    )
)

trf took 0.42288732528686523
dogbox took 0.4270446300506592
lm took 0.3169374465942383
---------------------  ----------------------  --------------------
trf                    dogbox                  lm
7.271429895608476e-31  7.5066315474993315e-31  7.00090733847295e-31
7                      7                       7
7                      7                       6
0.42288732528686523    0.4270446300506592      0.3169374465942383
---------------------  ----------------------  --------------------


In [31]:
class MemoizeJac:
    """Decorator that caches the return values of a function returning `(fun, grad)`
    each time it is called.

    from https://github.com/scipy/scipy/blob/85895a2fdfed853801846b56c9f1418886e2ccc2/scipy/optimize/_optimize.py#L57
    """

    def __init__(self, fun):
        self.fun = fun
        self.jac = None
        self._value = None
        self.x = None

    def _compute_if_needed(self, x, *args, der=False):
        if not np.all(x == self.x) or self._value is None or self.jac is None:
            if der:
                print("Cache missed.")
            self.x = np.asarray(x).copy()
            fg = self.fun(x, *args)
            self.jac = fg[1]
            self._value = fg[0]

    def __call__(self, x, *args):
        """returns the the function value"""
        self._compute_if_needed(x, *args)
        return self._value

    def derivative(self, x, *args):
        self._compute_if_needed(x, *args, der=True)
        return self.jac


val_jac = jax.jit(lambda x: value_and_jacfwd(residuals, x))

fun = MemoizeJac(val_jac)
jac = fun.derivative

import time


reslist_ls = []
time_ls = []
for method in ("trf", "dogbox", "lm"):
    res = least_squares(
        fun, init_params, jac=jac, method=method, x_scale="jac", max_nfev=1
    )
    cache.clear()
    start = time.time()
    res = least_squares(fun, init_params, jac=jac, method=method, x_scale="jac")
    end = time.time()
    print(method, "took", end - start)
    time_ls.append(end - start)
    reslist_ls.append(res)

from tabulate import tabulate


print(
    tabulate(
        [
            ("trf", "dogbox", "lm"),
            [r.cost for r in reslist_ls],
            [r.nfev for r in reslist_ls],
            [r.njev for r in reslist_ls],
            time_ls,
        ]
    )
)

trf took 0.36748623847961426
dogbox took 0.4010353088378906
lm took 0.32466626167297363
---------------------  ----------------------  --------------------
trf                    dogbox                  lm
7.271429895608476e-31  7.5066315474993315e-31  7.00090733847295e-31
7                      7                       7
7                      7                       6
0.36748623847961426    0.4010353088378906      0.32466626167297363
---------------------  ----------------------  --------------------


This here doesn't work and i have no idea why.

In [20]:
def least_squares_double_time(val_jac, x0, **kwargs):
    val_jac(jnp.array(x0))
    cache = {}

    def fun(arr):
        try:
            return cache[tuple(arr)][0]
        except KeyError:
            (v, j) = val_jac(arr)
            cache[tuple(arr)] = (v, j)
            return v

    def jac(arr):
        try:
            return cache[tuple(arr)][1]
        except KeyError:
            (v, j) = val_jac(arr)
            cache[tuple(arr)] = (v, j)
            return j

    return least_squares(fun, x0, jac=jac, **kwargs)


val_jac = jax.jit(lambda x: value_and_jacfwd(res, x))

import time


reslist_ls = []
time_ls = []
for method in ("trf", "dogbox", "lm"):
    res = least_squares_double_time(
        val_jac, init_params, method=method, x_scale="jac", max_nfev=1
    )
    start = time.time()
    res = least_squares_double_time(val_jac, init_params, method=method, x_scale="jac")
    end = time.time()
    print(method, "took", end - start)
    time_ls.append(end - start)
    reslist_ls.append(res)

from tabulate import tabulate


print(
    tabulate(
        [
            ("trf", "dogbox", "lm"),
            [r.cost for r in reslist_ls],
            [r.nfev for r in reslist_ls],
            [r.njev for r in reslist_ls],
            time_ls,
        ]
    )
)

TypeError: Expected a callable value, got  active_mask: array([0, 0, 0, 0, 0, 0])
        cost: 7.00090733847295e-31
         fun: array([ 0.00000000e+00,  0.00000000e+00, -1.74586542e-18, ...,
        1.92049312e-18,  4.74875393e-17,  2.03935100e-18])
        grad: array([ 4.64704768e-17, -4.80156921e-17, -3.04459798e-17,  1.16417243e-20,
       -5.42168243e-14, -2.74112284e-14])
         jac: array([[-0.00000000e+00, -0.00000000e+00, -0.00000000e+00,
        -0.00000000e+00, -0.00000000e+00, -0.00000000e+00],
       [-0.00000000e+00, -0.00000000e+00, -0.00000000e+00,
        -0.00000000e+00, -0.00000000e+00, -0.00000000e+00],
       [ 1.34918270e-08,  1.80654375e-06, -1.83101570e-11,
        -3.80927710e-17,  2.92603070e-01, -6.72756523e-06],
       ...,
       [ 6.34014634e-04,  1.44353715e-03,  8.51964278e-04,
         9.21516599e-07, -1.74760805e+00, -2.07755178e+00],
       [ 1.24289610e-03, -2.08329983e-03, -1.46188193e-04,
        -2.59478169e-07, -2.12932450e+00, -2.15781688e-01],
       [ 7.01884472e-04,  1.42511011e-03,  8.30072503e-04,
         9.30277658e-07, -1.79175454e+00, -2.12699459e+00]])
     message: '`xtol` termination condition is satisfied.'
        nfev: 7
        njev: 6
  optimality: 5.421682434511197e-14
      status: 3
     success: True
           x: array([2.e+00, 2.e+00, 2.e+00, 2.e+03, 2.e-03, 2.e-03])

In [32]:
# using minimize
from scipy.optimize import minimize


fun = jax.value_and_grad(loss)
jac = jax.jit(jax.grad(loss))
hess = jax.jit(jax.hessian(loss))
reslist = []
times = []
methods = ["Newton-CG", "dogleg", "trust-ncg", "trust-krylov", "trust-exact"]
for method in methods:
    print("method:", method)
    start = time.time()
    res = minimize(fun, init_params, jac=True, hess=hess, method=method)
    end = time.time()
    print("time:", end - start)
    print("sol:", res.x)
    print("success:", res.success)
    reslist.append(res)
    times.append(end - start)

print(
    tabulate(
        [
            methods,
            [r.fun for r in reslist],
            [r.nfev for r in reslist],
            [r.njev for r in reslist],
            [r.nhev for r in reslist],
            times,
        ]
    )
)

method: Newton-CG


In [34]:
from jaxopt import GaussNewton, LevenbergMarquardt


gm = GaussNewton(residual_fun=jax.jit(residuals))
lm = LevenbergMarquardt(residual_fun=jax.jit(residuals))
pred_params = lm.run(jnp.array(initial_params))

# from scipy.optimize import least_squares, minimize
# # # solve least_squares in scaled parameter space
# fun = jax.jit(residuals)
# jac = jax.jit(jax.jacfwd(residuals))
# res = least_squares(fun, init_params, jac=jac, x_scale='jac', verbose=2)

NotImplementedError: outfeed rewrite custom_linear_solve