# Learn Optax

## Quick Start

Let's use optax to fit a parametrized function. We will consider the problem of learning to identify when a value is odd or even.

We will begin by creating a dataset that consists of batches of random 8-bit integers (each represented by its 8 binary digits), with each value labelled as "even" or "odd" using one-hot encoding. The class index is the value modulo 2, so `[1, 0]` means even and `[0, 1]` means odd.


In [1]:
import optax
import jax.numpy as jnp
import jax

BATCH_SIZE = 5
NUM_TRAIN_STEPS = 1_000
# Note: `maxval` is exclusive in `jax.random.randint`, so values lie in [0, 254].
RAW_TRAINING_DATA = jax.random.randint(jax.random.PRNGKey(42), (NUM_TRAIN_STEPS, BATCH_SIZE, 1), 0, 255)

TRAINING_DATA = jnp.unpackbits(RAW_TRAINING_DATA.astype(jnp.uint8), axis=-1)
LABELS = jax.nn.one_hot(RAW_TRAINING_DATA % 2, 2).astype(jnp.float32).reshape(NUM_TRAIN_STEPS, BATCH_SIZE, 2)
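
To make the encoding concrete, here is a minimal sketch that applies the same two calls to a few hand-picked values. The values and the names `example_values`, `example_bits` and `example_labels` are assumptions chosen for illustration and are not part of the dataset above.

```python
import jax
import jax.numpy as jnp

# Hand-picked values (illustrative assumption), shaped like one tiny batch.
example_values = jnp.array([[5], [6], [128]], dtype=jnp.int32)

# Each value becomes a row of 8 bits, most significant bit first.
example_bits = jnp.unpackbits(example_values.astype(jnp.uint8), axis=-1)

# The class index is value % 2, so index 0 is "even" and index 1 is "odd".
example_labels = jax.nn.one_hot(example_values[:, 0] % 2, 2)

print(example_bits)
print(example_labels)
```

Only the last bit of each row determines the label, which is the pattern the network has to discover.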

We may now define a parametrized function using JAX. This will allow us to efficiently compute gradients.

There are a number of libraries that provide common building blocks for parametrized functions (such as flax and haiku). For this case though, we shall implement our function from scratch.

Our function will be an MLP (multi-layer perceptron) with a single hidden layer followed by a single output layer, and no bias terms. We initialize all parameters using a standard Gaussian $\mathcal{N}(0,1)$ distribution.

In [2]:
initial_params = {
    'hidden': jax.random.normal(shape=[8, 32], key=jax.random.PRNGKey(0)),
    'output': jax.random.normal(shape=[32, 2], key=jax.random.PRNGKey(1)),
}


def net(x: jnp.ndarray, params: optax.Params) -> jnp.ndarray:
  x = jnp.dot(x, params['hidden'])
  x = jax.nn.relu(x)
  x = jnp.dot(x, params['output'])
  return x


def loss(params: optax.Params, batch: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
  y_hat = net(batch, params)

  # optax also provides a number of common loss functions.
  loss_value = optax.sigmoid_binary_cross_entropy(y_hat, labels).sum(axis=-1)

  return loss_value.mean()

We will use `optax.adam` to compute the parameter updates from their gradients on each optimizer step.

Note that since optax optimizers are implemented using pure functions, we also need to keep track of the optimizer state ourselves. For the Adam optimizer, this state contains a step count and running estimates of the first and second moments of the gradients.

In [3]:
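@jax.jit
def step(params, opt_state, batch, labels):
  # `optimizer` is not an argument: it is read from the enclosing (global)
  # scope when `step` is traced by `jax.jit`.
  loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
  # Transform the gradients into updates and get the new optimizer state.
  updates, opt_state = optimizer.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
  return params, opt_state, loss_value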

def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)

  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    if i % 100 == 0:
      print(f'Step: {i:3}, Loss: {loss_value:.3f}')

  return params

# Finally, we can fit our parametrized function using the Adam optimizer
# provided by optax.
optimizer = optax.adam(learning_rate=1e-2)
params = fit(initial_params, optimizer)

Step:   0, Loss: 5.624
Step: 100, Loss: 0.188
Step: 200, Loss: 0.053
Step: 300, Loss: 0.025
Step: 400, Loss: 0.004
Step: 500, Loss: 0.028
Step: 600, Loss: 0.002
Step: 700, Loss: 0.025
Step: 800, Loss: 0.017
Step: 900, Loss: 0.003


The loss falls from 5.624 at step 0 to below 0.03 from step 300 onward, which suggests that training has converged and that we have found better parameters for our network.

## Weight Decay, Schedules and Clipping

Many research models make use of techniques such as learning rate scheduling and gradient clipping. These can be combined by _chaining_ together gradient transformations such as `optax.adam` and `optax.clip`, so that each transformation processes the output of the previous one.

In the following, we will use `Adam` with weight decay (`optax.adamw`), a cosine learning rate schedule (with warmup) and also gradient clipping.

In [4]:
schedule = optax.warmup_cosine_decay_schedule(
  init_value=0.0,
  peak_value=1.0,
  warmup_steps=50,
  decay_steps=1_000,
  end_value=0.0,
)

optimizer = optax.chain(
  optax.clip(1.0),
  optax.adamw(learning_rate=schedule),
)

params = fit(initial_params, optimizer)

Step:   0, Loss: 5.624
Step: 100, Loss: 0.000
Step: 200, Loss: 0.000
Step: 300, Loss: 0.000
Step: 400, Loss: 0.000
Step: 500, Loss: 0.000
Step: 600, Loss: 0.000
Step: 700, Loss: 0.000
Step: 800, Loss: 0.000
Step: 900, Loss: 0.000


## Reading the Learning Rate inside the Train Loop

Sometimes we want to access certain hyperparameters in the optimizer. For example, we may want to send the current learning rate to a logging service.

To extract the learning rate inside the train loop, we can use the [inject_hyperparams](https://optax.readthedocs.io/en/latest/api.html#optax.inject_hyperparams) wrapper, which makes any numeric hyperparameter a modifiable part of the optimizer state. Once the learning rate is stored there, it can be read directly from the state after each step.

The following example demonstrates how to extend the previous code to extract the learning rate.

In [5]:
# Wrap the optimizer to inject the hyperparameters
optimizer = optax.inject_hyperparams(optax.adamw)(learning_rate=schedule)

def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)

  # Since we injected hyperparams, we can access them directly here
  print(f'Available hyperparams: {" ".join(opt_state.hyperparams.keys())}\n')

  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    if i % 100 == 0:
      # Get the updated learning rate
      lr = opt_state.hyperparams['learning_rate']
      print(f'Step {i:3}, Loss: {loss_value:.3f}, Learning rate: {lr:.3f}')

  return params

params = fit(initial_params, optimizer)

Available hyperparams: b1 b2 eps eps_root weight_decay learning_rate

Step   0, Loss: 5.624, Learning rate: 0.020
Step 100, Loss: 0.000, Learning rate: 0.993
Step 200, Loss: 0.000, Learning rate: 0.939
Step 300, Loss: 0.000, Learning rate: 0.837
Step 400, Loss: 0.000, Learning rate: 0.699
Step 500, Loss: 0.000, Learning rate: 0.540
Step 600, Loss: 0.000, Learning rate: 0.376
Step 700, Loss: 0.000, Learning rate: 0.225
Step 800, Loss: 0.000, Learning rate: 0.104
Step 900, Loss: 0.000, Learning rate: 0.027
