# Optax Lessons

In [1]:
import optax
import jax
import jax.numpy as jnp
import flax
from flax import linen as nn
from flax import struct


# Fixed seed so parameter initialization is reproducible across runs.
seed = 42
key = jax.random.PRNGKey(seed)


A simple linear regression model will let us explore Optax with minimal fuss.

In [29]:
class LinearRegression(nn.Module):
    """Linear regression model: a single ``nn.Dense`` layer with ``features`` outputs."""

    features: int

    @nn.compact
    def __call__(self, x):
        # The submodule is named "dense" explicitly so the parameter tree is
        # {'params': {'dense': {...}}}, exactly as with a setup()-assigned
        # ``self.dense`` attribute.
        linear_layer = nn.Dense(features=self.features, name="dense")
        return linear_layer(x)

# Initialize the model: a single output feature, i.e. one scalar prediction per example.
model = LinearRegression(features=1)

# Initialize parameters. The dummy input only needs the right trailing
# dimension (one input feature) so the dense layer can size its kernel.
x = jnp.ones((1, 1))  # Dummy input
params = model.init(key, x)

# Mean squared error loss used for training.
def mse_loss(params, model, x, y):
    """Return the mean squared error between ``model.apply(params, x)`` and ``y``.

    The squared residuals are averaged over every element, so the result is a
    scalar regardless of batch or feature dimensions.
    """
    residual = model.apply(params, x) - y
    return jnp.mean(jnp.square(residual))

# Example usage: a toy dataset that follows y = 2x. Each array has shape
# (4, 1): four examples with one feature each, matching the model's input.
x_train = jnp.array([[1.0], [2.0], [3.0], [4.0]])
y_train = jnp.array([[2.0], [4.0], [6.0], [8.0]])


In [30]:
import jax
import jax.numpy as jnp
from flax import struct
import optax


def scheduler(epoch: int, boundaries=(100, 500), values=(1e-3, 1e-4, 1e-5)) -> jax.Array:
    """
    Piecewise-constant learning-rate schedule that is safe to call inside jax.jit.

    With the defaults the rate is 1e-3 for epoch < 100, 1e-4 for
    100 <= epoch < 500 and 1e-5 from epoch 500 on.

    Under jit ``epoch`` is a traced value, so a Python ``if`` on it fails.
    The selection is therefore expressed with ``jnp.where``, which evaluates
    the comparison inside the computation. ``jax.lax.cond`` also works, but
    needs one nested call per boundary.

    Args:
        epoch: current epoch; a Python int or an integer JAX scalar (possibly traced).
        boundaries: strictly increasing epochs at which the rate switches to
            the next entry of ``values``.
        values: learning rate for each interval; must have
            ``len(boundaries) + 1`` entries.

    Returns:
        The learning rate as a 0-d JAX array.

    Raises:
        ValueError: if ``values`` does not have exactly one more entry than
            ``boundaries``, or ``boundaries`` is not strictly increasing.
    """
    if len(values) != len(boundaries) + 1:
        raise ValueError("values must have exactly one more entry than boundaries")
    if any(lo >= hi for lo, hi in zip(boundaries, boundaries[1:])):
        raise ValueError("boundaries must be strictly increasing")

    # Start with the rate of the last interval. Then walk the boundaries from
    # largest to smallest, so the smallest boundary above `epoch` wins.
    learning_rate = jnp.asarray(values[-1])
    for boundary, value in zip(reversed(boundaries), reversed(values[:-1])):
        learning_rate = jnp.where(epoch < boundary, value, learning_rate)
    return learning_rate

# This function builds and returns the optimizer (an optax.GradientTransformation instance).
def momentum_optimizer(args: dict) -> optax.GradientTransformation:
    """
    Build a momentum optimizer as an ``optax.GradientTransformation`` instance.

    The returned object pairs an ``init`` function, which creates zero momentum
    shaped like the params, with an ``update`` function, which turns gradients
    into parameter updates. The updates are applied outside the optimizer, for
    example with ``optax.apply_updates``.

    Args:
        args: configuration dict with the optional keys
            ``'beta'``: momentum decay factor (default 0.9);
            ``'learning_rate'``: constant step size, used only when no
            scheduler is given;
            ``'scheduler'``: callable ``epoch -> learning rate``, which takes
            precedence over ``'learning_rate'``.

    Returns:
        An ``optax.GradientTransformation`` whose update function has the
        signature ``update(grads, opt_state, epoch=None)``.

    Raises:
        ValueError: if neither ``'learning_rate'`` nor ``'scheduler'`` is given.
    """
    from typing import Any

    beta = args.get('beta', 0.9)
    fixed_learning_rate = args.get('learning_rate', None)
    # Named schedule_fn so it does not shadow the module-level `scheduler` function.
    schedule_fn = args.get('scheduler', None)

    if schedule_fn is None and fixed_learning_rate is None:
        raise ValueError("momentum_optimizer needs a 'learning_rate' or a 'scheduler' in args")

    # To pass the state through jitted functions it must be a pytree;
    # flax's struct.dataclass registers the class as one.
    @struct.dataclass
    class OptState:
        momentum: Any  # pytree with the same structure as the params

    def init_fn(params):
        momentum = jax.tree_util.tree_map(jnp.zeros_like, params)
        return OptState(momentum)

    def update_fn(grads, opt_state, epoch=None):
        # Pick the learning-rate source with a plain Python `if`: whether a
        # scheduler exists is known when the function is traced. jax.lax.cond
        # would trace *both* branches, so a missing scheduler (None) would
        # still be called and a missing learning rate would be returned.
        if schedule_fn is not None:
            if epoch is None:
                raise ValueError("update() needs an epoch when a scheduler is configured")
            learning_rate = schedule_fn(epoch)
        else:
            learning_rate = fixed_learning_rate

        # Momentum update: m <- beta * m + g, leaf by leaf.
        momentum_next = jax.tree_util.tree_map(
            lambda m, g: beta * m + g, opt_state.momentum, grads
        )

        # Only compute the parameter updates; they are applied outside the optimizer.
        param_updates = jax.tree_util.tree_map(
            lambda m: -learning_rate * m, momentum_next
        )

        return param_updates, OptState(momentum_next)

    return optax.GradientTransformation(init_fn, update_fn)

# Example usage: configure the optimizer. When a 'scheduler' is present, the
# learning rate is taken from scheduler(epoch) and 'learning_rate' is not used.
opt_args = {'learning_rate': 0.1, 'beta': 0.9, 'scheduler': scheduler}
# Alternative configuration with a constant learning rate and no schedule:
# opt_args = {'learning_rate': 0.1, 'beta': 0.9}
optimizer = momentum_optimizer(opt_args)
opt_state = optimizer.init(params)  # zero momentum with the same tree structure as params

# Training step: one jitted loss/gradient evaluation followed by an optimizer update.
@jax.jit
def train_step(params, opt_state, x, y, epoch):
    """Run one optimization step and return ``(new_params, new_opt_state, loss)``.

    ``loss`` is evaluated at the incoming ``params``, i.e. before the update.
    The step relies on the module-level ``model`` and ``optimizer``, and
    forwards ``epoch`` to ``optimizer.update``.
    """
    loss_and_grad_fn = jax.value_and_grad(mse_loss)
    current_loss, gradients = loss_and_grad_fn(params, model, x, y)
    param_updates, next_opt_state = optimizer.update(gradients, opt_state, epoch=epoch)
    next_params = optax.apply_updates(params, param_updates)
    return next_params, next_opt_state, current_loss

# Training loop: 1000 full-batch steps on the toy dataset, logging every 100 epochs.
for epoch in range(1000):
    # The epoch is forwarded so the optimizer can look up its scheduled learning rate.
    # The returned loss was computed before this step's update.
    params, opt_state, loss = train_step(params, opt_state, x_train, y_train, epoch)
    if epoch % 100 == 0:
        print(f'Epoch {epoch}, Loss: {loss}')

Epoch 0, Loss: 47.301307678222656
Epoch 100, Loss: 0.056781843304634094
Epoch 200, Loss: 0.052745550870895386
Epoch 300, Loss: 0.049648720771074295
Epoch 400, Loss: 0.04675476998090744
Epoch 500, Loss: 0.04402981325984001
Epoch 600, Loss: 0.04376675933599472
Epoch 700, Loss: 0.04350510984659195
Epoch 800, Loss: 0.043245431035757065
Epoch 900, Loss: 0.04298752173781395


  momentum = jax.tree_map(jnp.zeros_like, params)
  momentum_next = jax.tree_map(
  param_updates = jax.tree_map(
