In [1]:
import optax
import jaxopt
import jax
import jax.numpy as jnp



### Function without additional arguments

In [2]:
def cost(theta):
    """Objective to minimise: the norm of ``theta``.

    Uses ``jnp.linalg.norm`` with its default arguments, i.e. the Euclidean
    norm when ``theta`` is a 1-D array.
    """
    magnitude = jnp.linalg.norm(theta)
    return magnitude

In [3]:
# Starting point for the optimisation: a 2-element parameter vector.
theta_init = jnp.array([4.0, 4.0])



In [4]:
# Sanity check: cost at the starting point vs. one unit closer to the origin.
cost(theta_init), cost(theta_init - 1)


(DeviceArray(5.656854, dtype=float32), DeviceArray(4.2426405, dtype=float32))

In [9]:
# jaxopt wrapper around optax's Adam (step size 1e-1) that minimises `cost`.
solver = jaxopt.OptaxSolver(
    fun=cost,
    opt=optax.adam(1e-1),
    maxiter=1000,
    tol=1e-8,
)

In [6]:
# Build the solver's initial state for the starting parameters `theta_init`.
state = solver.init_state(theta_init)

In [7]:
@jax.jit
def jit_update(theta, state):
    """JIT-compiled single optimisation step.

    Forwards the current parameters and solver state to ``solver.update`` and
    returns its result unchanged; callers unpack it as ``(theta, state)``.
    """
    step_result = solver.update(theta, state)
    return step_result

# Restart the run from scratch: reset BOTH the parameters and the solver
# state, so re-executing this cell reproduces the same trajectory instead of
# pairing fresh parameters with the stale state left over from a previous run.
theta = theta_init
state = solver.init_state(theta_init)

for i in range(1000):
    theta, state = jit_update(theta, state)
    # Report the current cost every 50 steps.
    if i % 50 == 0:
        print(f"[{i}] {cost(theta)}")

[0] 5.5154337882995605
[50] 0.377135694026947
[100] 0.028257401660084724
[150] 0.031453393399715424
[200] 0.009030254557728767
[250] 0.009096682071685791
[300] 0.009097049944102764
[350] 0.009097031317651272
[400] 0.009097004309296608
[450] 0.00909700058400631
[500] 0.009096973575651646
[550] 0.009097002446651459
[600] 0.009096972644329071
[650] 0.00909694004803896
[700] 0.009096913039684296
[750] 0.009096892550587654
[800] 0.009096899069845676
[850] 0.009096900932490826
[900] 0.009096892550587654
[950] 0.009096898138523102


In [8]:
# Show the parameters left after the 1000 update steps run above.
theta

DeviceArray([-0.00835913, -0.00835913], dtype=float32)

### Function with additional arguments

In [10]:
def cost_2(theta, X, y):
    """Norm of the residual between targets ``y`` and predictions ``X . theta``.

    Predictions come from ``jnp.dot(X, theta)``; the residual ``y - y_hat`` is
    reduced with ``jnp.linalg.norm`` using its default arguments.
    NOTE(review): presumably ``X`` is a 2-D design matrix and ``theta``/``y``
    are 1-D vectors -- confirm against the calling cells.
    """
    predictions = jnp.dot(X, theta)
    residual = y - predictions
    return jnp.linalg.norm(residual)