### 1. Imports

In [1]:
import jax
import jax.numpy as jnp                # JAX NumPy
from jestimator import amos            # The Amos optimizer implementation
from jestimator import amos_helper     # Helper module for Amos

from flax import linen as nn           # The Linen API
from flax.training import train_state  # Useful dataclass to keep train state

import math
import tensorflow_datasets as tfds     # TFDS for MNIST
from sklearn.metrics import accuracy_score

### 2. Load data

In [2]:
def get_datasets():
  """Load the MNIST train and test splits fully into memory.

  Returns:
    A ``(train_ds, test_ds)`` pair. Each element is the dict produced by
    ``tfds.as_numpy`` on a full-batch (``batch_size=-1``) split, with its
    ``'image'`` entry replaced by a float32 JAX array divided by 255
    (presumably uint8 pixels, i.e. values end up in [0, 1]).
  """
  builder = tfds.builder('mnist')
  builder.download_and_prepare()

  def _load_split(split_name):
    # Read the whole split as a single batch and rescale its pixels.
    split_ds = tfds.as_numpy(builder.as_dataset(split=split_name, batch_size=-1))
    split_ds['image'] = jnp.float32(split_ds['image']) / 255.
    return split_ds

  return _load_split('train'), _load_split('test')

### 3. Build model

In [3]:
class CNN(nn.Module):
  """Small convolutional classifier that maps images to 10 logits.

  Two (Conv 3x3 -> ReLU -> 2x2 average pool) stages with 32 and 64 filters,
  followed by a 256-unit ReLU dense layer and a 10-way dense output layer.
  Inputs are presumably channels-last images, e.g. (batch, 28, 28, 1).
  """

  @nn.compact
  def __call__(self, x):
    # Submodules are created in the same order as a hand-unrolled version,
    # so Flax still names them Conv_0, Conv_1, Dense_0, Dense_1 (the
    # optimizer's eta_fn/shape_fn regexes rely on these names).
    for num_filters in (32, 64):
      x = nn.Conv(features=num_filters, kernel_size=(3, 3))(x)
      x = nn.relu(x)
      x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    flat = x.reshape((x.shape[0], -1))
    hidden = nn.relu(nn.Dense(features=256)(flat))
    return nn.Dense(features=10)(hidden)

  def classify_xe_loss(self, x, labels):
    """Summed (not averaged) cross-entropy of `labels` under the model.

    Args:
      x: Input images.
      labels: Integer class ids (0..9 for tfds MNIST), one per example.

    Returns:
      Scalar: the negative log-likelihood summed over the batch.
    """
    log_probs = jax.nn.log_softmax(self(x))
    # Gather the log-probability of each example's true class.
    picked = jnp.take_along_axis(log_probs, labels[..., None], axis=-1)
    return -jnp.sum(picked)

### 4. Create train state

A `TrainState` object keeps the model parameters and optimizer states, and can be checkpointed into files.

We create the model and optimizer in this function.

**For the optimizer, we use Amos here.** The following hyper-parameters are set:

 * *learning_rate*:       The global learning rate.
 * *eta_fn*:              The model-specific 'eta'.
 * *shape_fn*:            Memory reduction setting.
 * *beta*:                Rate for running average of gradient squares.
 * *clip_value*:          Gradient clipping for stable training.

The global learning rate is usually set to 1/sqrt(N), where N is the number of batches in the training data. For MNIST, we have 60k training examples and the batch size is 32, so learning_rate=1/sqrt(60000/32).

The model-specific 'eta_fn' requires a function that, given a variable name and shape, returns a float indicating the expected scale of that variable. Hopefully in the near future we will have libraries that can automatically calculate this 'eta_fn' from the modeling code; but for now we have to specify it manually.

One can use the amos_helper.params_fn_from_assign_map() helper function to create 'eta_fn' from an assign_map. An assign_map is a dict that maps regex rules to values or simple Python expressions. The helper finds the first regex rule that matches the name of a variable and, if the value is an expression, evaluates it to produce the returned value. See our example below.

The 'shape_fn' similarly requires a function that, given a variable name and shape, returns a reduced shape for the corresponding slot variables. We can use the amos_helper.params_fn_from_assign_map() helper function to create 'shape_fn' from an assign_map as well.

'beta' is the exponential decay rate for running average of gradient squares. We set it to 0.98 here.

'clip_value' is the gradient clipping value, which should match the magnitude of the loss function. If the loss function is a sum of cross-entropy losses, then 'clip_value' should be set to the square root of the number of labels summed over. With a batch size of 32, that is sqrt(32).

Please refer to our [paper](https://arxiv.org/abs/2210.11693) for more details of the hyper-parameters.

In [4]:
def get_train_state(rng, num_train_examples=60000, batch_size=32):
  """Create the CNN, initialize its parameters and attach an Amos optimizer.

  The global learning rate is 1/sqrt(number of batches per epoch) and the
  gradient clip value is sqrt(batch_size), matching a loss that sums the
  cross-entropy over a batch (see `CNN.classify_xe_loss`).

  Args:
    rng: PRNG key used to initialize the model parameters.
    num_train_examples: Number of training examples (60000 for MNIST).
    batch_size: Number of examples per training batch.

  Returns:
    A `train_state.TrainState` holding `model.apply`, the initialized
    parameters and the Amos optimizer.

  Raises:
    ValueError: If `num_train_examples` or `batch_size` is not positive.
  """
  if num_train_examples <= 0 or batch_size <= 0:
    raise ValueError(
        'num_train_examples and batch_size must be positive, got %r and %r'
        % (num_train_examples, batch_size))

  model = CNN()
  dummy_x = jnp.ones([1, 28, 28, 1])
  params = model.init(rng, dummy_x)

  # Expected scale ('eta') of each variable, selected by the first regex that
  # matches the variable name. String values are evaluated
  # (eval_str_value=True); SHAPE presumably binds to the variable's shape.
  eta_fn = amos_helper.params_fn_from_assign_map(
      {
          '.*/bias': 0.5,
          '.*Conv_0/kernel': 'sqrt(8/prod(SHAPE[:-1]))',
          '.*Conv_1/kernel': 'sqrt(2/prod(SHAPE[:-1]))',
          '.*Dense_0/kernel': 'sqrt(2/SHAPE[0])',
          '.*Dense_1/kernel': 'sqrt(1/SHAPE[0])',
      },
      eval_str_value=True,
  )
  # Reduced shapes for the optimizer's slot variables (memory reduction):
  # conv kernels and the first dense kernel keep only their last dimension;
  # every other variable falls through to shape ().
  shape_fn = amos_helper.params_fn_from_assign_map(
      {
          '.*Conv_[01]/kernel': '(1, 1, 1, SHAPE[-1])',
          '.*Dense_0/kernel': '(1, SHAPE[1])',
          '.*': (),
      },
      eval_str_value=True,
  )
  num_batches = num_train_examples / batch_size
  optimizer = amos.amos(
      learning_rate=1 / math.sqrt(num_batches),
      eta_fn=eta_fn,
      shape_fn=shape_fn,
      beta=0.98,
      # The loss is a sum over `batch_size` labels, so clip at sqrt of that.
      clip_value=math.sqrt(batch_size),
  )
  return train_state.TrainState.create(
      apply_fn=model.apply, params=params, tx=optimizer)

### 5. Training step

Use JAX’s @jit decorator to just-in-time compile the function for better performance.

In [5]:
@jax.jit
def train_step(batch, state):
  """Apply one optimizer update computed from a single batch.

  Args:
    batch: Dict with 'image' and 'label' entries for one batch.
    state: Current `TrainState`.

  Returns:
    The `TrainState` obtained by applying the gradient of
    `CNN.classify_xe_loss` with respect to `state.params`.
  """
  def loss_of(params):
    # Route `apply_fn` to the loss method instead of `__call__`.
    return state.apply_fn(
        params, batch['image'], batch['label'], method=CNN.classify_xe_loss)

  param_grads = jax.grad(loss_of)(state.params)
  return state.apply_gradients(grads=param_grads)

### 6. Infer step

Use JAX’s @jit decorator to just-in-time compile the function for better performance.

In [6]:
@jax.jit
def infer_step(batch, state):
  """Predict a class index for every image in `batch`.

  Args:
    batch: Dict with an 'image' entry.
    state: `TrainState` whose `apply_fn`/`params` produce logits.

  Returns:
    Integer array holding the argmax over the last (class) axis of the logits.
  """
  class_scores = state.apply_fn(state.params, batch['image'])
  return class_scores.argmax(axis=-1)

### 7. Main

Run the training loop and evaluate on test set.

In [7]:
train_ds, test_ds = get_datasets()

# Training configuration. NOTE: by default `get_train_state` derives the Amos
# learning rate and clip value from a batch size of 32 and 60000 training
# examples; keep BATCH_SIZE consistent with that.
BATCH_SIZE = 32
NUM_EPOCHS = 9
SEED = 0

# Derive the epoch length from the data instead of hard-coding 60000.
num_train = train_ds['image'].shape[0]
steps_per_epoch = num_train // BATCH_SIZE

rng = jax.random.PRNGKey(SEED)
rng, init_rng = jax.random.split(rng)
state = get_train_state(init_rng)
del init_rng  # Must not be used anymore.

for epoch in range(1, NUM_EPOCHS + 1):
  # Use a separate PRNG key to permute image data during shuffling.
  rng, input_rng = jax.random.split(rng)
  perms = jax.random.permutation(input_rng, num_train)
  del input_rng
  # Drop a trailing partial batch (if any) so the reshape is always valid.
  perms = perms[:steps_per_epoch * BATCH_SIZE]
  perms = perms.reshape((steps_per_epoch, BATCH_SIZE))
  for perm in perms:
    batch = {k: v[perm, ...] for k, v in train_ds.items()}
    state = train_step(batch, state)

  pred = jax.device_get(infer_step(test_ds, state))
  accuracy = accuracy_score(test_ds['label'], pred)
  print('epoch: %d, test accuracy: %.2f' % (epoch, accuracy * 100))



epoch: 1, test accuracy: 97.28
epoch: 2, test accuracy: 98.46
epoch: 3, test accuracy: 98.63
epoch: 4, test accuracy: 97.91
epoch: 5, test accuracy: 98.59
epoch: 6, test accuracy: 99.05
epoch: 7, test accuracy: 99.15
epoch: 8, test accuracy: 99.21
epoch: 9, test accuracy: 99.26
