In [1]:
import optax
import jaxopt
import jax
import jax.numpy as jnp

### Function without additional arguments

In [2]:
def cost(theta):
    """Euclidean norm of ``theta``; its minimum value (0) is at the origin."""
    euclidean_length = jnp.linalg.norm(theta)
    return euclidean_length

In [3]:
# Starting point for the optimiser: both coordinates set to 4.
theta_init = 4.0 * jnp.ones(2)



In [4]:
# Sanity check: cost at the starting point vs. at a point one unit closer to
# the origin in each coordinate (the second value should be smaller).
cost(theta_init), cost(theta_init - 1)

(DeviceArray(5.656854, dtype=float32), DeviceArray(4.2426405, dtype=float32))

In [5]:
# Adam (learning rate 0.1) wrapped as a jaxopt solver for the parameter-only cost.
solver = jaxopt.OptaxSolver(
    fun=cost,
    opt=optax.adam(learning_rate=1e-1),
    maxiter=1000,
    tol=1e-8,
)

In [6]:
# Fresh optimiser state (Adam moments, iteration counter) for the first problem.
state = solver.init_state(init_params=theta_init)

In [7]:
@jax.jit
def jit_update(theta, state):
    """Run one jitted OptaxSolver step on ``cost``.

    Returns the updated ``(params, state)`` pair produced by ``solver.update``.
    """
    return solver.update(theta, state)


# Re-initialise the solver state in this cell (not only in the previous one)
# so re-running the cell restarts the optimisation from `theta_init` instead of
# continuing with Adam moments / iteration count left over from an earlier run.
theta = theta_init
state = solver.init_state(theta_init)

# Take exactly `solver.maxiter` steps; `tol` is not checked in this manual
# loop. Log the cost every 50 steps.
for i in range(solver.maxiter):
    theta, state = jit_update(theta, state)
    if i % 50 == 0:
        print(f"[{i}] {cost(theta)}")

[0] 5.5154337882995605
[50] 0.377135694026947
[100] 0.028257401660084724
[150] 0.031453393399715424
[200] 0.009030254557728767
[250] 0.009096682071685791
[300] 0.009097049944102764
[350] 0.009097031317651272
[400] 0.009097004309296608
[450] 0.00909700058400631
[500] 0.009096973575651646
[550] 0.009097002446651459
[600] 0.009096972644329071
[650] 0.00909694004803896
[700] 0.009096913039684296
[750] 0.009096892550587654
[800] 0.009096899069845676
[850] 0.009096900932490826
[900] 0.009096892550587654
[950] 0.009096898138523102


In [8]:
# Parameters after the optimisation loop. The cost ||theta|| is minimised at
# the origin; the recorded run plateaus near (but not exactly at) zero.
theta

DeviceArray([-0.00835913, -0.00835913], dtype=float32)

### Function with additional arguments

In [9]:
def cost_2(theta, X, y):
    """L2 norm of the residuals of an affine model.

    ``theta[0]`` is the intercept and ``theta[1:]`` are the weights combined
    with ``X`` via ``jnp.dot``; the result is ``||y - (X . w + b)||``.
    """
    intercept, weights = theta[0], theta[1:]
    predictions = jnp.dot(X, weights) + intercept
    return jnp.linalg.norm(y - predictions)

In [10]:
# Synthetic 1-D regression data: 100 evenly spaced inputs on [-1, 1] and
# noise-free targets on the line y = 4x + 5. X_true is a (100, 1) column.
X_true = jnp.linspace(-1, 1, 100).reshape(-1, 1)
y = 4 * X_true[:, 0] + 5

In [11]:
# Initial guess, laid out as [intercept, slope] to match cost_2's indexing.
theta_init = jnp.asarray((-2.0, 2.0))

In [12]:
# Cost at the initial guess [intercept=-2, slope=2].
cost_2(theta_init, X_true, y)

DeviceArray(70.96497, dtype=float32)

In [13]:
# Cost at the parameters used to generate `y` (intercept 5, slope 4); should be 0.
# Float literals keep the parameter dtype consistent with `theta_init` instead
# of relying on implicit int -> float promotion inside jnp.dot.
cost_2(jnp.array([5.0, 4.0]), X_true, y)

DeviceArray(0., dtype=float32)

In [14]:
# Adam with a smaller learning rate (0.01) for the regression objective.
solver = jaxopt.OptaxSolver(
    fun=cost_2,
    opt=optax.adam(learning_rate=1e-2),
    maxiter=1000,
    tol=1e-8,
)

In [15]:
# Fresh optimiser state for the regression problem.
state = solver.init_state(init_params=theta_init)

In [16]:
@jax.jit
def jit_update(theta, state, data):
    """Run one jitted OptaxSolver step on ``cost_2``.

    ``data`` is an ``(X, y)`` pair forwarded to the objective as its extra
    positional arguments. Returns the updated ``(params, state)`` pair.
    """
    X, y = data
    return solver.update(theta, state, X, y)


# Re-initialise the solver state in this cell so re-running it restarts from
# `theta_init` rather than continuing with state left by a previous run.
# The objective's extra arguments are passed along for completeness.
theta = theta_init
state = solver.init_state(theta_init, X_true, y)

# Take exactly `solver.maxiter` steps (`tol` is not checked in this manual
# loop); log the cost every 50 steps.
for i in range(solver.maxiter):
    theta, state = jit_update(theta, state, (X_true, y))
    if i % 50 == 0:
        print(f"[{i}] {cost_2(theta, X_true, y)}")

[0] 70.85675811767578
[50] 65.48595428466797
[100] 60.20701217651367
[150] 55.017539978027344
[200] 49.902427673339844
[250] 44.83944320678711
[300] 39.80586624145508
[350] 34.785404205322266
[400] 29.770063400268555
[450] 24.757165908813477
[500] 19.745960235595703
[550] 14.736101150512695
[600] 9.727349281311035
[650] 4.719514846801758
[700] 0.22945404052734375
[750] 0.03783104941248894
[800] 0.013602446764707565
[850] 0.021956508979201317
[900] 0.01089780405163765
[950] 0.00854769628494978


In [17]:
# Fitted [intercept, slope]; the data were generated with intercept 5 and slope 4.
theta

DeviceArray([4.999103 , 3.9999988], dtype=float32)

### Function with additional arguments and data minibatching

In [18]:
def cost_3(theta, data):
    """Residual L2 norm of the affine model on one ``(X, y)`` minibatch.

    ``theta[0]`` is the intercept, ``theta[1:]`` the weights.
    NOTE(review): the parameter name ``data`` is presumably what the solver's
    ``run_iterator`` passes each batch as — do not rename it.
    """
    features, targets = data
    predictions = jnp.dot(features, theta[1:]) + theta[0]
    return jnp.linalg.norm(targets - predictions)

In [19]:
# Minibatching configuration: number of minibatches and rows per minibatch.
n_iter = 100
batch_size = 5
n_samples = y.shape[0]
# Alias for the `jax.random` module itself (not a PRNG key or generator).
rng = jax.random

def data_iterator(X=None, targets=None, n_batches=None, batch=None):
    """Return an iterator of random ``(X_batch, y_batch)`` minibatches.

    Minibatch ``step`` (0-based) takes the rows given by the first ``batch``
    entries of ``jax.random.permutation(jax.random.PRNGKey(step), n)``, so the
    sequence of minibatches is the same on every run.

    Args:
        X: feature rows; defaults to the global ``X_true``.
        targets: targets aligned row-for-row with ``X``; defaults to the global ``y``.
        n_batches: number of minibatches to yield; defaults to the global ``n_iter``.
        batch: rows per minibatch; defaults to the global ``batch_size``.

    Raises:
        ValueError: if ``X`` and ``targets`` have different lengths, or if
            ``batch`` is not in ``[1, len(targets)]`` (the permutation slice
            would otherwise silently return a smaller batch).
    """
    X = X_true if X is None else X
    targets = y if targets is None else targets
    n_batches = n_iter if n_batches is None else n_batches
    batch = batch_size if batch is None else batch

    n = len(targets)
    if len(X) != n:
        raise ValueError(f"X has {len(X)} rows but targets has {n}")
    if not 1 <= batch <= n:
        raise ValueError(f"batch must be in [1, {n}], got {batch}")

    def _batches():
        # The loop index doubles as the PRNG seed, so it gets a real name
        # rather than `_` (which conventionally marks an unused value).
        for step in range(n_batches):
            idx = jax.random.permutation(jax.random.PRNGKey(step), n)[:batch]
            yield X[idx], targets[idx]

    return _batches()

In [20]:
# Adam (learning rate 0.1) on the minibatch objective.
solver = jaxopt.OptaxSolver(
    fun=cost_3,
    opt=optax.adam(learning_rate=1e-1),
    maxiter=1000,
    tol=1e-8,
)
# data_iterator() yields n_iter minibatches; the recorded result reports
# iter_num=100, i.e. the run ended when the iterator was exhausted, not at maxiter.
iterator = data_iterator()
res = solver.run_iterator(theta_init, iterator)

In [21]:
# Full OptStep from run_iterator: final params plus the solver state
# (iteration count, last value/error, Adam internal state).
res

OptStep(params=DeviceArray([5.014913 , 3.9717538], dtype=float32), state=OptaxState(iter_num=DeviceArray(100, dtype=int32, weak_type=True), value=DeviceArray(0.10703757, dtype=float32), error=DeviceArray(1.6834657, dtype=float32), internal_state=(ScaleByAdamState(count=DeviceArray(100, dtype=int32), mu=DeviceArray([ 0.3258862 , -0.21115768], dtype=float32), nu=DeviceArray([0.44145292, 0.04544457], dtype=float32)), EmptyState()), aux=None))

In [22]:
# Fitted [intercept, slope] from the minibatch run (data generated with 5 and 4).
res.params

DeviceArray([5.014913 , 3.9717538], dtype=float32)