# Schedule-free optimizer

This notebook illustrates how to incorporate the [optax.contrib.schedule_free](https://optax.readthedocs.io/en/latest/api/contrib.html#optax.contrib.schedule_free) optimizer into a typical training pipeline.

This notebook is intended purely to illustrate implementation details, not to showcase performance.

In [None]:
import math

from flax import linen as nn
import jax
import jax.numpy as jnp
import optax
import optax.tree_utils as otu
import tensorflow as tf
import tensorflow_datasets as tfds

from matplotlib import pyplot as plt

# Hide every GPU from TensorFlow, which this notebook only uses for the input
# pipeline -- presumably so that it does not reserve accelerator memory that
# JAX needs.
tf.config.experimental.set_visible_devices([], "GPU")
print(f"JAX running on {jax.devices()[0].platform.upper()}")

In [None]:
# @markdown Total number of training steps (optimizer updates) to run:
N_STEPS = 1000  # @param{type:"integer"}

# @markdown Number of samples in each batch:
BATCH_SIZE = 4  # @param{type:"integer"}

# @markdown Evaluate on the test set every this many training steps:
EVAL_EVERY = 50  # @param{type:"integer"}

## Setup


In [None]:
# @title Data

def tf_to_numpy(xs):
  """Maps every leaf of the pytree `xs` to the result of its `_numpy()` method.

  Args:
    xs: a pytree whose leaves expose a `_numpy()` method (TensorFlow eager
      tensors in this notebook).

  Returns:
    A pytree with the same structure as `xs` holding the converted leaves.
  """

  def _leaf_to_numpy(tensor):
    return tensor._numpy()

  return jax.tree_util.tree_map(_leaf_to_numpy, xs)


def get_data():
  """Builds the CIFAR-10 train and test batch iterators.

  The training split is repeated, augmented, shuffled and batched. The test
  split is batched and then repeated. Both are wrapped into Python iterators
  that never run out and yield `(images, labels)` batches of `BATCH_SIZE`
  examples converted with `tf_to_numpy`.

  Returns:
    A tuple `(train_loader, test_loader, info)` where `info` is a dict with the
    keys `'train_steps_per_epoch'` and `'val_steps_per_epoch'`.
  """
  (train_ds, test_ds), ds_info = tfds.load(
    "cifar10", split=["train", "test"], as_supervised=True, with_info=True
  )

  def augment(image, label):
    """Performs data augmentation."""
    # Pad to 40x40 and take a random 32x32 crop, i.e. a random translation.
    image = tf.image.resize_with_crop_or_pad(image, 40, 40)
    image = tf.image.random_crop(image, [32, 32, 3])
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.2)
    image = tf.image.random_contrast(image, 0.8, 1.2)
    image = tf.image.random_saturation(image, 0.8, 1.2)
    return image, label

  # The training stream is repeated *before* batching, so it is infinite and
  # `drop_remainder` never discards any example.
  train_ds = train_ds.repeat().map(augment)
  train_ds = train_ds.shuffle(
      buffer_size=10_000, reshuffle_each_iteration=True
  ).batch(BATCH_SIZE, drop_remainder=True)
  train_loader = map(tf_to_numpy, train_ds)

  # The test stream is batched *before* repeating: every pass over the split
  # drops the trailing partial batch.
  test_ds = test_ds.batch(BATCH_SIZE, drop_remainder=True).repeat().prefetch(10)
  test_loader = map(tf_to_numpy, test_ds)

  num_train_examples = ds_info.splits['train'].num_examples
  num_test_examples = ds_info.splits['test'].num_examples
  train_steps_per_epoch = math.ceil(num_train_examples / BATCH_SIZE)
  # Floor division matches `drop_remainder=True` above; rounding up would make
  # an evaluation pass spill into the next repetition of the test split
  # whenever BATCH_SIZE does not divide the number of test examples.
  val_steps_per_epoch = num_test_examples // BATCH_SIZE
  info = {'train_steps_per_epoch': train_steps_per_epoch, 'val_steps_per_epoch': val_steps_per_epoch}
  return train_loader, test_loader, info

In [None]:
# @title Model
class CNN(nn.Module):
  """Small convolutional network: two conv stages followed by a dense head."""

  @nn.compact
  def __call__(self, x):
    # Each stage is a 3x3 convolution, a ReLU and a 2x2 average pooling with
    # stride 2; the stages only differ in their number of output features.
    for num_features in (32, 64):
      x = nn.Conv(features=num_features, kernel_size=(3, 3))(x)
      x = nn.relu(x)
      x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

    # Flatten everything but the leading (batch) axis.
    leading_dim = x.shape[0]
    x = x.reshape((leading_dim, -1))

    hidden = nn.relu(nn.Dense(features=256)(x))
    logits = nn.Dense(features=10)(hidden)
    return logits


net = CNN()


In [None]:
# @title Train, eval steps

def get_eval_params(params, state):
  """Selects the parameters that should be used for evaluation.

  Looks up a `"ScheduleFreeState"` entry in the optimizer `state`. When one is
  found, `params` are converted with
  `optax.contrib.schedule_free_eval_params`; otherwise `params` are returned
  unchanged.

  Args:
    params: current model parameters.
    state: optimizer state to search.

  Returns:
    The parameters to evaluate the model with.
  """
  schedule_free_state = otu.tree_get(state, "ScheduleFreeState")
  if schedule_free_state is None:
    return params
  return optax.contrib.schedule_free_eval_params(schedule_free_state, params)


def train_obj(params, data):
  """Evaluates `net` on one batch and returns its mean loss and accuracy.

  Args:
    params: parameters passed to `net.apply`.
    data: an `(inputs, labels)` pair; `labels` are integer class labels.

  Returns:
    A `(loss, accuracy)` pair: the mean softmax cross-entropy and the fraction
    of examples whose arg-max logit equals the label.
  """
  images, targets = data
  predictions = net.apply(params, images)
  per_example_loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=predictions, labels=targets
  )
  predicted_labels = jnp.argmax(predictions, axis=-1)
  return per_example_loss.mean(), jnp.mean(predicted_labels == targets)


def train_step(params, state, data, opt):
  """Performs a single optimizer update on one batch.

  Args:
    params: current model parameters.
    state: current optimizer state.
    data: batch forwarded to `train_obj`.
    opt: optimizer providing `update`.

  Returns:
    The `(params, state)` pair after the update.
  """
  value_and_grad_fn = jax.value_and_grad(train_obj, has_aux=True)
  # The loss and its auxiliary accuracy are not needed here, only the gradient.
  _, gradients = value_and_grad_fn(params, data)
  updates, new_state = opt.update(gradients, state, params)
  new_params = optax.apply_updates(params, updates)
  return new_params, new_state


# `opt` is declared static: it is not traced by `jax.jit` but treated as a
# compile-time constant.
train_step = jax.jit(train_step, static_argnames=("opt",))


def eval_step(params, state, data):
  """Computes the loss and accuracy of one batch with the evaluation params.

  Args:
    params: current model parameters.
    state: optimizer state, used by `get_eval_params` to pick the parameters.
    data: batch forwarded to `train_obj`.

  Returns:
    A `(loss, accuracy)` pair.
  """
  params_for_eval = get_eval_params(params, state)
  return train_obj(params_for_eval, data)


eval_step = jax.jit(eval_step)


def eval(params, state, dataset, num_batch_per_eval):
  """Averages loss and accuracy over `num_batch_per_eval` batches.

  Note: this function shadows the Python builtin `eval` in the notebook.

  Args:
    params: current model parameters.
    state: optimizer state forwarded to `eval_step`.
    dataset: iterator of batches; `num_batch_per_eval` batches are consumed.
    num_batch_per_eval: number of batches to average over.

  Returns:
    A `(mean_loss, mean_accuracy)` pair.
  """
  loss_sum, acc_sum = 0.0, 0.0
  for _ in range(num_batch_per_eval):
    batch_loss, batch_acc = eval_step(params, state, next(dataset))
    loss_sum = loss_sum + batch_loss
    acc_sum = acc_sum + batch_acc
  mean_loss = loss_sum / num_batch_per_eval
  mean_acc = acc_sum / num_batch_per_eval
  return mean_loss, mean_acc



In [None]:
# @title Train loop

def init_params(input_example, opt):
  """Initializes the parameters of `net` and the matching optimizer state.

  Args:
    input_example: example input passed to `net.init`.
    opt: optimizer providing `init`.

  Returns:
    A `(params, state)` pair.
  """
  # The key is fixed, so every call uses the same random seed.
  params = net.init(jax.random.PRNGKey(0), input_example)
  return params, opt.init(params)

def train_loop(params, state, opt, train_loader, test_loader, info_data):
  """Trains for up to `N_STEPS` steps, evaluating every `EVAL_EVERY` steps.

  Args:
    params: initial model parameters.
    state: initial optimizer state.
    opt: optimizer forwarded to `train_step`.
    train_loader: iterator of training batches.
    test_loader: iterator of evaluation batches.
    info_data: dict whose `'val_steps_per_epoch'` entry gives the number of
      batches used for each evaluation.

  Returns:
    A `(loss_log, acc_log)` pair of lists with one entry per evaluation. The
    trained parameters and optimizer state are not returned.
  """
  eval_losses, eval_accs = [], []

  # `range` comes first in the zip so that no extra training batch is pulled
  # once the step budget is exhausted.
  for step, batch in zip(range(N_STEPS), train_loader):
    params, state = train_step(params, state, batch, opt)
    if step % EVAL_EVERY != 0:
      continue
    avg_loss, avg_acc = eval(params, state, test_loader, info_data['val_steps_per_epoch'])
    print(f'step: {step}, loss: {avg_loss}, acc: {avg_acc}')
    eval_losses.append(avg_loss)
    eval_accs.append(avg_acc)

  return eval_losses, eval_accs


## Experiments

In [None]:
# @title Adam with prefixed schedule

# Learning rate warms up from 0 to 1e-3 over the first tenth of N_STEPS and
# then follows a cosine decay; the schedule spans N_STEPS steps in total.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=1e-3,
    warmup_steps=int(N_STEPS / 10),
    decay_steps=N_STEPS,
)
opt = optax.adam(learning_rate=schedule)

# The inputs of the first training batch serve as the example used to
# initialize the parameters (that batch is consumed from the loader).
train_loader, test_loader, info_data = get_data()
input_exmp = next(iter(train_loader))[0]
params, state = init_params(input_exmp, opt)

loss_log, acc_log = train_loop(
    params, state, opt, train_loader, test_loader, info_data
)

In [None]:
# Plot the test accuracy against the training step at which it was measured:
# evaluations happen at steps 0, EVAL_EVERY, 2 * EVAL_EVERY, ...
fig, ax = plt.subplots()
ax.plot(range(0, len(acc_log) * EVAL_EVERY, EVAL_EVERY), acc_log)
ax.set(
    title="Adam with warmup-cosine schedule",
    xlabel="Training step",
    ylabel="Test accuracy",
)
plt.show()

In [None]:
# @title Schedule-free Adamw
# Schedule-free AdamW: only a constant learning rate and a warmup length
# (a tenth of N_STEPS) are given, no decay schedule.
opt = optax.contrib.schedule_free_adamw(
    learning_rate=1e-3,
    warmup_steps=int(N_STEPS / 10),
)

# The inputs of the first training batch serve as the example used to
# initialize the parameters (that batch is consumed from the loader).
train_loader, test_loader, info_data = get_data()
input_exmp = next(iter(train_loader))[0]
params, state = init_params(input_exmp, opt)

loss_log, acc_log = train_loop(
    params, state, opt, train_loader, test_loader, info_data
)

In [None]:
# Plot the test accuracy against the training step at which it was measured:
# evaluations happen at steps 0, EVAL_EVERY, 2 * EVAL_EVERY, ...
fig, ax = plt.subplots()
ax.plot(range(0, len(acc_log) * EVAL_EVERY, EVAL_EVERY), acc_log)
ax.set(
    title="Schedule-free AdamW",
    xlabel="Training step",
    ylabel="Test accuracy",
)
plt.show()