In [9]:
import jax.numpy as jnp
import jax
from jax import jit
import optax
import functools
import time

In [13]:
@functools.partial(jax.vmap, in_axes=(None, 0))
def network(params, x):
  """Linear model: dot product of the shared weights with one input example.

  The `jax.vmap` decorator maps this function over axis 0 of `x` while
  passing `params` unmapped (`in_axes=(None, 0)`), so callers hand in a
  batch of examples and get one prediction per example back.

  Args:
    params: Weights, shared by every example in the batch.
    x: One example inside the body; a batch stacked on axis 0 at the call site.

  Returns:
    `jnp.dot(params, x)` evaluated for each example.
  """
  prediction = jnp.dot(params, x)
  return prediction

def compute_loss(params, x, y):
  """Mean L2 loss of the model's predictions against the targets.

  Args:
    params: Model weights, forwarded to `network`.
    x: Batch of inputs (mapped over axis 0 by `network`).
    y: Target values compared against the predictions.

  Returns:
    The mean of `optax.l2_loss(predictions, y)` over all elements.
  """
  per_example_loss = optax.l2_loss(network(params, x), y)
  return jnp.mean(per_example_loss)

In [14]:
# Fixed PRNG seed so the synthetic dataset is reproducible.
key = jax.random.PRNGKey(42)
# Ground-truth weight applied to every input feature.
target_params = 0.5

# Synthetic regression data: 16 examples with 2 standard-normal features each.
xs = jax.random.normal(key, (16, 2))
# Targets: scale each feature by `target_params`, then sum over the feature axis.
ys = (xs * target_params).sum(axis=-1)

In [15]:
# Display the generated input batch.
xs

Array([[-2.0201101e+00,  8.6349919e-03],
       [-2.0828791e+00,  9.4689780e-01],
       [-7.4697673e-02,  2.1900117e-01],
       [ 1.7400689e+00,  1.4436092e+00],
       [ 1.6966347e+00, -9.8481425e-04],
       [ 1.9873228e+00, -1.3630089e+00],
       [-3.1369337e-01, -4.6323735e-01],
       [ 1.7433221e+00, -9.6858436e-01],
       [ 5.2875841e-01, -1.1646140e-02],
       [ 3.3797663e-01,  9.7233158e-01],
       [ 6.4158249e-01,  7.3273242e-01],
       [ 1.5225714e+00,  8.5781729e-01],
       [ 2.6217616e-01,  2.8434937e+00],
       [ 5.7764381e-01,  5.0313210e-01],
       [ 3.8976768e-01, -1.3724487e-01],
       [-1.7162652e+00, -9.3954539e-01]], dtype=float32)

In [42]:
# Plain Adam with a constant learning rate.
start_learning_rate = 1e-1
optimizer = optax.adam(learning_rate=start_learning_rate)

# Start both weights at zero and build the matching optimizer state.
params = jnp.array([0.0, 0.0])
opt_state = optimizer.init(params)

# Gradient of the loss with respect to its first argument (the weights),
# wrapped in `jax.jit` so repeated calls reuse the compiled computation.
compute_loss_grad = jax.grad(compute_loss, argnums=0)
compute_loss_grad_jit = jax.jit(compute_loss_grad)

# Evaluate once at the initial weights and display the result.
compute_loss_grad_jit(params, xs, ys)

Array([-0.9058708, -0.5763672], dtype=float32)

In [43]:
# Time 1000 Adam steps. `time.perf_counter` is a monotonic clock meant for
# measuring intervals, unlike the wall-clock `time.time`.
t0 = time.perf_counter()
# A simple update loop. It continues from the current `params` / `opt_state`,
# so re-run the initialization cell first to restart from zero.
for _ in range(1000):
  # Gradient at the current weights, using the jitted gradient function.
  grads = compute_loss_grad_jit(params, xs, ys)
  # Turn gradients into Adam updates and advance the optimizer state.
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
# JAX dispatches work asynchronously: block until the final weights are
# actually computed so the timing covers the computation, not just its launch.
params.block_until_ready()
print(f"time taken:{ time.perf_counter()-t0 }")
assert jnp.allclose(params, target_params), \
'Optimization should retrive the target params used to generate the data.'

time taken:0.2632148265838623


In [41]:
# Display the fitted weights.
params

Array([0.5, 0.5], dtype=float32)

In [24]:
# Learning-rate schedule: start at `start_learning_rate` and decay it
# exponentially, by a factor of 0.99 per 1000 steps.
scheduler = optax.exponential_decay(
    init_value=start_learning_rate,
    transition_steps=1000,
    decay_rate=0.99,
)

# Individual gradient transformations, listed in the order they are applied.
_transform_stages = (
    # Clip the gradients by their global norm.
    optax.clip_by_global_norm(1.0),
    # Rescale using Adam's update rule.
    optax.scale_by_adam(),
    # Multiply by the learning rate supplied by the scheduler.
    optax.scale_by_schedule(scheduler),
    # Negate: optax.apply_updates is additive and we want to descend on the loss.
    optax.scale(-1.0),
)

# Combine the stages into a single transform with `optax.chain`.
gradient_transform = optax.chain(*_transform_stages)
     

In [25]:

# Initialize parameters of the model + optimizer.
params = jnp.array([0.0, 0.0])  # Recall target_params=0.5.
opt_state = gradient_transform.init(params)

# Build the gradient function once, outside the loop, and jit-compile it.
# Calling `jax.grad(compute_loss)` inside the loop would re-create (and
# re-trace) the transformed function on every one of the 1000 iterations.
loss_grad_fn = jax.jit(jax.grad(compute_loss))

# A simple update loop.
for _ in range(1000):
  grads = loss_grad_fn(params, xs, ys)
  # Clip, Adam-scale, apply the scheduled learning rate and negate the gradients.
  updates, opt_state = gradient_transform.update(grads, opt_state)
  params = optax.apply_updates(params, updates)

assert jnp.allclose(params, target_params), \
'Optimization should retrive the target params used to generate the data.'

In [26]:
# Display the weights fitted with the chained gradient transform.
params

Array([0.5, 0.5], dtype=float32)

In [10]:
@jax.jit
def fun(x):
    """Return the element-wise square of `x`, compiled with `jax.jit`."""
    squared = jnp.square(x)
    return squared

In [13]:
# Try the jitted function on a small integer array.
fun(jnp.array([1,2,3,9]))

Array([ 1,  4,  9, 81], dtype=int32)