<a href="https://colab.research.google.com/github/profteachkids/StemUnleashed/blob/main/least_squares_linear_operator_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from scipy.optimize import least_squares
import jax.numpy as jnp
import jax
import numpy as np
from scipy.sparse.linalg import LinearOperator
from jax.config import config
from functools import partial
# Switch JAX to 64-bit floats so the solver tolerances used later (gtol=1e-12)
# are meaningful. Going through the `jax` package attribute avoids relying on
# the name imported from the deprecated `jax.config` submodule.
jax.config.update("jax_enable_x64", True)

In [6]:
def getLO(f):
    """Build a matrix-free Jacobian factory for ``f``.

    Parameters
    ----------
    f : callable
        Function of a single array argument that JAX can differentiate.

    Returns
    -------
    callable
        ``LO(x)``, which returns a ``scipy.sparse.linalg.LinearOperator``
        representing the Jacobian of ``f`` at ``x``. Its ``matvec`` is a
        forward-mode Jacobian-vector product and its ``rmatvec`` is a
        reverse-mode vector-Jacobian product; the dense Jacobian is never
        formed. The operator shape is ``(f(x).size, x.size)``.
    """
    def mv(x, v):
        # J(x) @ v. Callers may hand over v as (n,) or (n, 1), or with a
        # dtype that differs from x; jax.jvp requires the tangent to match
        # the primal exactly, so coerce both shape and dtype.
        tangent = jnp.reshape(v, x.shape).astype(x.dtype)
        return jax.jvp(f, (x,), (tangent,))[1]

    def rmv(x, v):
        # J(x).T @ v. The cotangent must match the output of f.
        out, pullback = jax.vjp(f, x)
        cotangent = jnp.reshape(v, out.shape).astype(out.dtype)
        # The pullback returns one entry per primal input; unpack the only one.
        return pullback(cotangent)[0]

    jit_mv = jax.jit(mv)
    jit_rmv = jax.jit(rmv)

    def LO(x):
        x = jnp.asarray(x)
        # Abstract evaluation yields the output shape/dtype without running f,
        # so the operator no longer depends on a global problem size.
        out = jax.eval_shape(f, x)
        shape = (int(np.prod(out.shape)), int(np.prod(x.shape)))
        # Passing dtype explicitly spares SciPy from probing the operator with
        # a test matvec just to infer it.
        return LinearOperator(
            shape,
            matvec=partial(jit_mv, x),
            rmatvec=partial(jit_rmv, x),
            dtype=out.dtype,
        )

    return LO

In [7]:
# Seeded generator so a fresh kernel reproduces the same random test problem.
SEED = 0
rng = np.random.default_rng(SEED)

N = 1000                    # number of unknowns (and of residual equations)
N_offdiagonal = 10 * N      # number of cross-product terms in the residuals

# Starting point for the solver.
xguess = jnp.asarray(rng.uniform(size=N))
# Each row holds the two variable indices whose product forms one cross term.
pairs = rng.integers(0, N, size=(N_offdiagonal, 2))
# Index of the residual equation that each cross term is subtracted from.
eq = rng.integers(0, N, size=N_offdiagonal)
# Constant shift added to the unknowns before the residuals are formed.
coeff = rng.uniform(-10, 10, size=N)

In [8]:
def func(x):
    """Residual vector of the test system.

    The input is shifted by the module-level ``coeff``; each residual starts
    as the square of its shifted unknown, and for every row of ``pairs`` the
    product of the two referenced shifted unknowns is subtracted from the
    residual selected by the matching entry of ``eq``.
    """
    shifted = jnp.asarray(x) + coeff
    first, second = pairs[:, 0], pairs[:, 1]
    cross_terms = shifted[first] * shifted[second]
    # .at[].add accumulates, so repeated indices in `eq` sum their terms.
    return (shifted ** 2).at[eq].add(-cross_terms)


# Compiled version of the residual, used as the objective by the solver.
jit_func = jax.jit(func)

In [9]:
# Drive the residuals to zero. The Jacobian is supplied as a matrix-free
# LinearOperator built from `func`, so no dense N x N matrix is assembled.
sol = least_squares(
    fun=jit_func,
    x0=xguess,
    jac=getLO(func),
    gtol=1e-12,
)

In [10]:
# Residual vector evaluated at the solver's result; entries close to zero
# mean the system of equations is satisfied there.
func(sol.x)

DeviceArray([ 1.54967075e-10, -2.20217351e-10,  5.93419800e-10,
              2.86712012e-11,  1.63856875e-10,  3.64316137e-10,
              9.04084824e-11, -4.25735413e-11,  1.30799783e-10,
              1.65875766e-10, -8.43838644e-12,  3.27676199e-10,
              1.93103098e-10,  1.02992774e-10,  9.39662293e-11,
              6.56094189e-11, -7.63125437e-11,  2.32157409e-10,
              2.05205519e-10,  1.56727441e-10,  2.41205445e-10,
             -2.89039673e-10,  3.14580096e-10,  1.76590653e-10,
             -2.91858743e-10,  2.57637499e-10,  6.42632274e-11,
              3.51655874e-10,  4.41734593e-11,  9.70069914e-11,
              3.67208275e-10,  1.85589645e-10,  1.02694312e-10,
             -5.43341102e-11,  2.94018381e-10, -8.92837294e-11,
              4.77238662e-10, -8.27183790e-11,  5.39420918e-11,
              1.88777942e-11,  8.93695540e-11,  2.61221948e-11,
              6.63765876e-10,  4.40499200e-10, -7.30668328e-12,
              1.07162723e-10,  4.3621114

In [None]:
# The solution vector returned by least_squares.
sol.x

array([-6.33352467, -0.61726225, -4.40946944, ..., -8.06705956,
       -0.54791887, -1.51400014])

In [None]:
# Solver cost: number of residual evaluations and of Jacobian evaluations.
sol.nfev, sol.njev

(21, 21)