# Learn Optax

## Quick Start

Let's use optax to fit a parametrized function. We will consider the problem of learning to identify when a value is odd or even.

We will begin by creating a dataset that consists of batches of random 8-bit integers (represented using their binary representation), with each value labelled as "even" or "odd" using one-hot encoding (i.e. `[1, 0]` means even and `[0, 1]` means odd).


In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import optax

BATCH_SIZE = 5
NUM_TRAIN_STEPS = 1_000
# `randint` excludes the upper bound, so 256 covers every 8-bit value (0..255).
RAW_TRAINING_DATA = np.random.randint(256, size=(NUM_TRAIN_STEPS, BATCH_SIZE, 1))

TRAINING_DATA = np.unpackbits(RAW_TRAINING_DATA.astype(np.uint8), axis=-1)
LABELS = jax.nn.one_hot(RAW_TRAINING_DATA % 2, 2).astype(jnp.float32).reshape(NUM_TRAIN_STEPS, BATCH_SIZE, 2)

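As a quick sanity check of this encoding, the sketch below applies the same two steps to two hand-picked values using NumPy only. The helper names `encode_bits` and `parity_one_hot` are hypothetical and exist only for this illustration. Index 0 of the label marks an even value and index 1 an odd one, and the last bit of the binary representation carries the same information.

```python
import numpy as np


def encode_bits(values: np.ndarray) -> np.ndarray:
  """Returns the 8-bit binary representation, most significant bit first."""
  return np.unpackbits(values.astype(np.uint8)[..., None], axis=-1)


def parity_one_hot(values: np.ndarray) -> np.ndarray:
  """Returns [1, 0] for even values and [0, 1] for odd values."""
  return np.eye(2, dtype=np.float32)[values % 2]


values = np.array([6, 7])
bits = encode_bits(values)
labels = parity_one_hot(values)

print(bits)    # One row of 8 bits per value.
print(labels)  # One one-hot row per value.
```
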
We may now define a parametrized function using JAX. This will allow us to efficiently compute gradients.

There are a number of libraries that provide common building blocks for parametrized functions (such as flax and haiku). For this case though, we shall implement our function from scratch.

Our function will be an MLP (multi-layer perceptron) with a single hidden layer and a single output layer, both without bias terms. We initialize all parameters using a standard Gaussian $\mathcal{N}(0,1)$ distribution.

In [None]:
initial_params = {
    'hidden': jax.random.normal(shape=[8, 32], key=jax.random.PRNGKey(0)),
    'output': jax.random.normal(shape=[32, 2], key=jax.random.PRNGKey(1)),
}


def net(x: jnp.ndarray, params: optax.Params) -> jnp.ndarray:
  x = jnp.dot(x, params['hidden'])
  x = jax.nn.relu(x)
  x = jnp.dot(x, params['output'])
  return x


def loss(params: optax.Params, batch: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
  y_hat = net(batch, params)

  # optax also provides a number of common loss functions.
  loss_value = optax.sigmoid_binary_cross_entropy(y_hat, labels).sum(axis=-1)

  return loss_value.mean()

We will use `optax.adam` to compute the parameter updates from their gradients on each optimizer step.

Note that since optax optimizers are implemented using pure functions, we will also need to keep track of the optimizer state. For the Adam optimizer, this state contains running estimates of the first and second moments of the gradients, together with a step count.

In [None]:
def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)

  @jax.jit
  def step(params, opt_state, batch, labels):
    loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value

  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    if i % 100 == 0:
      print(f'step {i}, loss: {loss_value}')

  return params

# Finally, we can fit our parametrized function using the Adam optimizer
# provided by optax.
optimizer = optax.adam(learning_rate=1e-2)
params = fit(initial_params, optimizer)

step 0, loss: 5.60183048248291
step 100, loss: 0.14773361384868622
step 200, loss: 0.28999248147010803
step 300, loss: 0.05951451137661934
step 400, loss: 0.08592046797275543
step 500, loss: 0.005035111214965582
step 600, loss: 0.0028563595842570066
step 700, loss: 0.013286210596561432
step 800, loss: 0.01311601884663105
step 900, loss: 0.003692328929901123


We see that the loss has dropped by several orders of magnitude, which suggests that we have found better parameters for our network.

## Weight Decay, Schedules and Clipping

Many research models make use of techniques such as learning rate scheduling and gradient clipping. These may be achieved by _chaining_ together gradient transformations such as `optax.adam` and `optax.clip`.

In the following, we will use `Adam` with weight decay (`optax.adamw`), a cosine learning rate schedule (with warmup) and also gradient clipping.

In [None]:
schedule = optax.warmup_cosine_decay_schedule(
  init_value=0.0,
  peak_value=1.0,
  warmup_steps=50,
  decay_steps=1_000,
  end_value=0.0,
)

optimizer = optax.chain(
  optax.clip(1.0),
  optax.adamw(learning_rate=schedule),
)

params = fit(initial_params, optimizer)

step 0, loss: 5.60183048248291
step 100, loss: 1.0181801179953709e-08
step 200, loss: 0.27725887298583984
step 300, loss: 0.0
step 400, loss: 0.0
step 500, loss: 0.0
step 600, loss: 0.0
step 700, loss: 0.0
step 800, loss: 0.0
step 900, loss: 0.0


## Components

We refer to the [docs](https://optax.readthedocs.io/en/latest/index.html) for a detailed list of available Optax components. Here, we highlight the main categories of building blocks provided by Optax.

### Gradient Transformations ([transform.py](https://github.com/google-deepmind/optax/blob/main/optax/_src/transform.py))

One of the key building blocks of `optax` is a `GradientTransformation`.

Each transformation is defined by two functions:

In [None]:
state = init(params)
grads, state = update(grads, state, params=None)

The `init` function initializes a (possibly empty) set of statistics (aka state) and the `update` function transforms a candidate gradient given some statistics, and (optionally) the current value of the parameters.

For example:

In [None]:
tx = scale_by_rms()
state = tx.init(params)  # init stats
grads, state = tx.update(grads, state, params)  # transform & update stats.

### Composing Gradient Transformations ([combine.py](https://github.com/google-deepmind/optax/blob/main/optax/_src/combine.py))

Transformations take candidate gradients as input and return processed gradients as output, rather than returning the updated parameters. This is what makes it possible to combine arbitrary transformations into a custom optimiser / gradient processor, and to combine transformations for different gradients that operate on a shared set of variables.

For instance, `chain` returns a new `GradientTransformation` that applies several transformations in sequence.

For example:

In [None]:
my_optimiser = chain(
    clip_by_global_norm(max_norm),
    scale_by_adam(eps=1e-4),
    scale(-learning_rate))

### Wrapping Gradient Transformations ([wrappers.py](https://github.com/google-deepmind/optax/blob/main/optax/_src/wrappers.py))

Optax also provides several wrappers that take a `GradientTransformation` as input and return a new `GradientTransformation` that modifies the behaviour of the inner transformation in a specific way.

For instance, the `flatten` wrapper flattens gradients into a single large vector before applying the inner GradientTransformation. The transformed updates are then unflattened before being returned to the user. This can be used to reduce the overhead of performing many calculations on lots of small variables, at the cost of increasing memory usage.

For example:

In [None]:
my_optimiser = flatten(adam(learning_rate))

Other examples of wrappers include accumulating gradients over multiple steps or applying the inner transformation only to specific parameters or at specific steps.

### Schedules ([schedule.py](https://github.com/google-deepmind/optax/blob/main/optax/_src/schedule.py))

Many popular transformations use time-dependent components, e.g. to anneal some hyper-parameter (e.g. the learning rate). For this purpose, Optax provides `schedules`, which can be used to decay scalars as a function of a `step` count.

For example, you may use a polynomial schedule (with `power=1`) to decay a hyper-parameter linearly over a number of steps:

In [None]:
schedule_fn = polynomial_schedule(
    init_value=1., end_value=0., power=1, transition_steps=5)

for step_count in range(6):
  print(schedule_fn(step_count))  # [1., 0.8, 0.6, 0.4, 0.2, 0.]

Schedules are used by certain gradient transformations, for instance:

In [None]:
schedule_fn = polynomial_schedule(
    init_value=-learning_rate, end_value=0., power=1, transition_steps=5)
optimiser = chain(
    clip_by_global_norm(max_norm),
    scale_by_adam(eps=1e-4),
    scale_by_schedule(schedule_fn))

### Popular optimisers ([alias.py](https://github.com/google-deepmind/optax/blob/main/optax/_src/alias.py))

In addition to the low-level building blocks, we also provide aliases for popular optimisers built using these components (e.g. RMSProp, Adam, AdamW). These are all still instances of a `GradientTransformation`, and can therefore be further combined with any of the individual building blocks.

For example:

In [None]:
def adamw(learning_rate, b1, b2, eps, weight_decay):
  return chain(
      scale_by_adam(b1=b1, b2=b2, eps=eps),
      add_decayed_weights(weight_decay),  # adds weight_decay * params
      scale(-learning_rate))

### Applying updates ([update.py](https://github.com/google-deepmind/optax/blob/main/optax/_src/update.py))

After transforming an update using a `GradientTransformation` or any custom manipulation of the update, you will typically apply the update to a set of parameters. This can be done trivially using JAX's `tree_map`.

For convenience, we expose an `apply_updates` function to apply updates to parameters. The function just adds the updates and the parameters together, i.e. `tree_map(lambda p, u: p + u, params, updates)`.

In [None]:
updates, state = tx.update(grads, state, params)  # transform & update stats.
new_params = optax.apply_updates(params, updates)  # update the parameters.

Note that separating gradient transformations from the parameter update is critical to support composing a sequence of transformations (e.g. `chain`), as well as combining multiple updates to the same parameters (e.g. in multi-task settings where different tasks need different sets of gradient transformations).

### Losses ([loss.py](https://github.com/google-deepmind/optax/tree/main/optax/losses))

Optax provides a number of standard losses used in deep learning, such as `l2_loss`, `softmax_cross_entropy`, `cosine_distance`, etc.

In [None]:
loss = huber_loss(predictions, targets)

The losses accept batches as inputs; however, they perform no reduction across the batch dimension(s). Reducing is trivial to do in JAX, for example:

In [None]:
avg_loss = jnp.mean(huber_loss(predictions, targets))
sum_loss = jnp.sum(huber_loss(predictions, targets))

### Second Order ([second_order.py](https://github.com/google-deepmind/optax/tree/main/optax/second_order))

Computing the full Hessian or Fisher information matrix for a neural network is typically intractable because memory requirements grow quadratically with the number of parameters. Computing only the diagonals of these matrices is often a more practical alternative. The library offers functions for computing these diagonals with sub-quadratic memory requirements.

### Stochastic gradient estimators ([stochastic_gradient_estimators.py](https://github.com/google-deepmind/optax/blob/main/optax/monte_carlo/stochastic_gradient_estimators.py))

Stochastic gradient estimators compute Monte Carlo estimates of gradients of the expectation of a function under a distribution with respect to the distribution's parameters.

Unbiased estimators, such as the score function estimator (REINFORCE), pathwise estimator (reparameterization trick) or measure valued estimator, are implemented: `score_function_jacobians`, `pathwise_jacobians` and `measure_valued_jacobians`. Their applicability (both in terms of functions and distributions) is discussed in their respective documentation.

Stochastic gradient estimators can be combined with common control variates for variance reduction via `control_variates_jacobians`. For provided control variates see `delta` and `moving_avg_baseline`.

The result of a gradient estimator or `control_variates_jacobians` contains the Jacobians of the function with respect to the samples from the input distribution. These can then be used to update distributional parameters or to assess gradient variance.

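The idea behind the pathwise estimator can be sketched in plain JAX, without relying on the optax estimator functions. The code below is an illustration under stated assumptions, not the library's implementation: it assumes a diagonal Gaussian parametrized by `mean` and `log_scale`, writes each sample as `mean + exp(log_scale) * eps`, and differentiates through that expression to obtain one Jacobian row per sample. All values are arbitrary.

```python
import jax
import jax.numpy as jnp

weights = jnp.array([1.0, -2.0, 0.5])
mean = jnp.zeros(3)
log_scale = jnp.zeros(3)
num_samples = 1_000


def function(x):
  return jnp.sum(x * weights)


def function_of_params(mean, log_scale, eps):
  # Reparameterization: x = mean + exp(log_scale) * eps, with eps ~ N(0, I).
  return function(mean + jnp.exp(log_scale) * eps)


eps = jax.random.normal(jax.random.PRNGKey(0), (num_samples, 3))

# One gradient per sample, with respect to both distribution parameters.
per_sample_grads = jax.vmap(
    jax.grad(function_of_params, argnums=(0, 1)), in_axes=(None, None, 0))
mean_jacobians, log_scale_jacobians = per_sample_grads(mean, log_scale, eps)

# Averaging over samples gives the Monte Carlo gradient estimates.
mean_grads = jnp.mean(mean_jacobians, axis=0)
log_scale_grads = jnp.mean(log_scale_jacobians, axis=0)
print(mean_grads, log_scale_grads)
```

For this linear function, the gradient with respect to `mean` equals `weights` for every sample, while the per-sample gradients with respect to `log_scale` vary and only average out towards zero, which makes the difference in variance between the two easy to see.
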
Example of how to use the `pathwise_jacobians` estimator:

In [None]:
dist_params = [mean, log_scale]
function = lambda x: jnp.sum(x * weights)
jacobians = pathwise_jacobians(
      function, dist_params,
      utils.multi_normal, rng, num_samples)

mean_grads = jnp.mean(jacobians[0], axis=0)
log_scale_grads = jnp.mean(jacobians[1], axis=0)
grads = [mean_grads, log_scale_grads]
optim_update, optim_state = optim.update(grads, optim_state)
updated_dist_params = optax.apply_updates(dist_params, optim_update)

Here, `optim` is an optax optimizer.