# Quickstart with Optax

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.sandbox.google.com/github/google-deepmind/optax/blob/master/examples/quick_start.ipynb)

Optax is a simple optimization library for [JAX](https://jax.readthedocs.io/). Its main object is the `GradientTransformation`, which can be chained
with other transformations to obtain the final update operation and the optimizer state.
Optax also contains some simple loss functions and utilities to help you write the full optimization step. This notebook walks through a few examples of how to use Optax to train a simple linear model. Begin by importing the necessary packages:




In [None]:
import jax.numpy as jnp
import jax
import optax
import functools

In this example, we begin by setting up a simple linear model and a loss function. You can use any other library, such as [Haiku](https://github.com/deepmind/dm-haiku) or [Flax](https://github.com/google/flax), to construct your networks. Here, we keep it simple and write the model ourselves. The loss function (L2 loss) comes from Optax's [common loss functions](https://optax.readthedocs.io/en/latest/api.html#common-losses) via `optax.l2_loss()`.

In [None]:
@functools.partial(jax.vmap, in_axes=(None, 0))
def network(params, x):
  return jnp.dot(params, x)

def compute_loss(params, x, y):
  y_pred = network(params, x)
  loss = jnp.mean(optax.l2_loss(y_pred, y))
  return loss

Here we generate data under a known linear model (with `target_params=0.5`):

In [None]:
key = jax.random.PRNGKey(42)
target_params = 0.5

# Generate some data.
xs = jax.random.normal(key, (16, 2))
ys = jnp.sum(xs * target_params, axis=-1)
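
Since `target_params` is a scalar, it may help to see that the targets are the same as those of a linear model whose two weights both equal `0.5`. The sketch below regenerates the data and checks this; `target_vector` and `ys_from_dot` are names introduced here for illustration only.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
target_params = 0.5

xs = jax.random.normal(key, (16, 2))
ys = jnp.sum(xs * target_params, axis=-1)

# The same targets, written as a dot product with an explicit weight vector.
target_vector = jnp.full((2,), target_params)
ys_from_dot = xs @ target_vector
```

This is why the final checks in this notebook can compare the learned two-element `params` against the scalar `target_params`: the comparison broadcasts `0.5` over both weights.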

## Basic usage of Optax

Optax contains implementations of [many popular optimizers](https://optax.readthedocs.io/en/latest/api.html#Common-Optimizers) that can be used very simply. For example, the gradient transformation for the Adam optimizer is available as `optax.adam()`. For now, we simply call the `GradientTransformation` object for Adam the `optimizer`. We then initialize the optimizer state by calling its `init` function with the `params` of the network.

In [None]:
start_learning_rate = 1e-1
optimizer = optax.adam(start_learning_rate)

# Initialize parameters of the model + optimizer.
params = jnp.array([0.0, 0.0])
opt_state = optimizer.init(params)



Next we write the update loop. The `GradientTransformation` object has an `update` function that takes the gradients and the current optimizer state, and returns the `updates` to apply to the parameters together with the new optimizer state: `updates, new_opt_state = optimizer.update(grads, opt_state)`.

Optax comes with a few simple [update rules](https://optax.readthedocs.io/en/latest/api.html#apply-updates) that apply the updates from the gradient transforms to the current parameters to return new ones: `new_params = optax.apply_updates(params, updates)`.



In [None]:
# A simple update loop.
for _ in range(1000):
  grads = jax.grad(compute_loss)(params, xs, ys)
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)

assert jnp.allclose(params, target_params), \
'Optimization should retrieve the target params used to generate the data.'

### Custom optimizers

Optax makes it easy to create custom optimizers by `chain`ing gradient transformations. For example, the following creates an optimizer based on `adam`. Note that the updates are ultimately scaled by `-learning_rate`; this is an important detail, since `apply_updates` is additive.

In [None]:
# Exponential decay of the learning rate.
scheduler = optax.exponential_decay(
    init_value=start_learning_rate, 
    transition_steps=1000,
    decay_rate=0.99)

# Combining gradient transforms using `optax.chain`.
gradient_transform = optax.chain(
    optax.clip_by_global_norm(1.0),  # Clip the gradient by the global norm.
    optax.scale_by_adam(),  # Use the updates from adam.
    optax.scale_by_schedule(scheduler),  # Use the learning rate from the scheduler.
    # Scale updates by -1 since optax.apply_updates is additive and we want to descend on the loss.
    optax.scale(-1.0)
)

In [None]:
# Initialize parameters of the model + optimizer.
params = jnp.array([0.0, 0.0])  # Recall target_params=0.5.
opt_state = gradient_transform.init(params)

# A simple update loop.
for _ in range(1000):
  grads = jax.grad(compute_loss)(params, xs, ys)
  updates, opt_state = gradient_transform.update(grads, opt_state)
  params = optax.apply_updates(params, updates)

assert jnp.allclose(params, target_params), \
'Optimization should retrieve the target params used to generate the data.'

## Advanced usage of Optax

### Modifying hyperparameters of optimizers in a schedule

In some scenarios, changing hyperparameters other than the learning rate over the course of training can help keep training stable. This is easy to do with `optax.inject_hyperparams`. For example, the following code decays the `max_norm` of the `clip_by_global_norm` gradient transformation as training progresses:




In [None]:
decaying_global_norm_tx = optax.inject_hyperparams(optax.clip_by_global_norm)(
    max_norm=optax.linear_schedule(1.0, 0.0, transition_steps=99))

opt_state = decaying_global_norm_tx.init(None)
assert opt_state.hyperparams['max_norm'] == 1.0, 'Max norm should start at 1.0'

for _ in range(100):
  _, opt_state = decaying_global_norm_tx.update(None, opt_state)

assert opt_state.hyperparams['max_norm'] == 0.0, 'Max norm should end at 0.0'