In [1]:
import jax.numpy as jnp

def loss(x, x0):
    r = x - x0
    return jnp.sum(r ** 2)

In [2]:
import optax

optimizer = optax.adamw(learning_rate=1e-3)
optimizer

GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x10fdb0fe0>, update=<function chain.<locals>.update_fn at 0x10fdb1080>)

In [3]:
import jax

# x0 and x could be any point, e.g. np.zeros((5,), dtype=float)
x0 = jnp.array([-1., 0., 1.])
x = jnp.array([0., 0., 0.])

# Gradient with respect to x. x0 is treated as constant.
grad_loss = jax.grad(loss, argnums=0)
grad = grad_loss(x, x0)
grad



Array([ 2.,  0., -2.], dtype=float32)

In [4]:
optimizer_state = optimizer.init(x)
optimizer_state

(ScaleByAdamState(count=Array(0, dtype=int32), mu=Array([0., 0., 0.], dtype=float32), nu=Array([0., 0., 0.], dtype=float32)),
 EmptyState(),
 EmptyState())

In [5]:
updates, optimizer_state = optimizer.update(grad, optimizer_state, x)
display(updates)
display(optimizer_state)

Array([-0.00099999, -0.        ,  0.00099999], dtype=float32)

(ScaleByAdamState(count=Array(1, dtype=int32), mu=Array([ 0.2,  0. , -0.2], dtype=float32), nu=Array([0.004, 0.   , 0.004], dtype=float32)),
 EmptyState(),
 EmptyState())

In [6]:
x = optax.apply_updates(x, updates)
x

Array([-0.00099999,  0.        ,  0.00099999], dtype=float32)

In [7]:
# Putting everything together
for i in range(3000):
    grad = grad_loss(x, x0)
    updates, optimizer_state = optimizer.update(grad, optimizer_state, x)
    x = optax.apply_updates(x, updates)
    if i % 100 == 0:
        print(x)


[-0.00199996  0.          0.00199996]
[-0.10017119  0.          0.10017119]
[-0.19332755  0.          0.19332755]
[-0.28124353  0.          0.28124353]
[-0.36382174  0.          0.36382174]
[-0.44096917  0.          0.44096917]
[-0.5126047  0.         0.5126047]
[-0.578666  0.        0.578666]
[-0.63911706  0.          0.63911706]
[-0.69395846  0.          0.69395846]
[-0.74323493  0.          0.74323493]
[-0.78704435  0.          0.78704435]
[-0.8255418  0.         0.8255418]
[-0.8589458  0.         0.8589458]
[-0.88753575  0.          0.88753575]
[-0.9116481  0.         0.9116481]
[-0.9316673  0.         0.9316673]
[-0.9480142  0.         0.9480142]
[-0.9611299  0.         0.9611299]
[-0.9714603  0.         0.9714603]
[-0.9794408  0.         0.9794408]
[-0.9854826  0.         0.9854826]
[-0.9899609  0.         0.9899609]
[-0.9932073  0.         0.9932073]
[-0.9955071  0.         0.9955071]
[-0.9970973  0.         0.9970973]
[-0.99816936  0.          0.99816936]
[-0.99887353  0.      