In [1]:
from flax import linen as nn
from flax.training import train_state
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import optax
import tensorflow_datasets as tfds
from time import time


class MLP(nn.Module):
  """Small fully connected classifier producing 2 logits per example.

  Layout: Dense(20) -> ReLU -> Dense(20) -> ReLU -> Dense(2). The final layer
  has no activation; its outputs are raw logits.
  """

  @nn.compact
  def __call__(self, x):
    hidden = x
    # Two identical hidden layers. Submodules are created in the same order as
    # an unrolled version would create them, so parameter names (Dense_0,
    # Dense_1, Dense_2) are unchanged.
    for _ in range(2):
      hidden = nn.relu(nn.Dense(features=20)(hidden))
    return nn.Dense(features=2)(hidden)


@jax.jit
def apply_model(state, X, labels):
  """Computes gradients, loss and accuracy for a single batch.

  Args:
    state: `TrainState` whose `params` are differentiated and whose
      `apply_fn` runs the model's forward pass.
    X: batch of model inputs.
    labels: integer class labels, one per row of `X`.

  Returns:
    `(grads, loss, accuracy)`:
      - `grads`: gradients of the mean softmax cross-entropy with respect to
        `state.params`.
      - `loss`: that mean loss.
      - `accuracy`: the fraction of rows whose argmax logit equals the label.
  """
  def loss_fn(params):
    # Use the model stored in the TrainState rather than re-instantiating a
    # specific module class, so this works for whatever model built `state`.
    logits = state.apply_fn({'params': params}, X)
    # Take the number of classes from the model output instead of
    # hard-coding it, so the one-hot width always matches the last layer.
    one_hot = jax.nn.one_hot(labels, logits.shape[-1])
    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
    return loss, logits

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  return grads, loss, accuracy


def train_epoch(state, train_ds, batch_size, rng):
  """Runs one epoch of gradient updates over shuffled full-size mini-batches.

  Args:
    state: current `TrainState`.
    train_ds: dict with inputs under 'X' and integer labels under 'labels',
      aligned along axis 0.
    batch_size: examples per step. The trailing partial batch is dropped.
    rng: PRNG key used to shuffle the example order.

  Returns:
    `(state, train_loss, train_accuracy)`: the updated state, plus the
    per-step loss and accuracy averaged over the epoch.

  Raises:
    ValueError: if `batch_size` is not in `[1, len(train_ds['X'])]`. In that
      case no full batch can be formed, and the epoch would otherwise do
      nothing and report NaN metrics (or divide by zero).
  """
  train_ds_size = len(train_ds['X'])
  if batch_size <= 0 or batch_size > train_ds_size:
    raise ValueError(
        f'batch_size must be in [1, {train_ds_size}], got {batch_size}')
  steps_per_epoch = train_ds_size // batch_size

  perms = jax.random.permutation(rng, train_ds_size)
  perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
  # Convert once to NumPy so indexing the NumPy dataset below does not
  # convert a device array on every step.
  perms = np.asarray(perms).reshape((steps_per_epoch, batch_size))

  epoch_loss = []
  epoch_accuracy = []

  for perm in perms:
    batch_X = train_ds['X'][perm, ...]
    batch_labels = train_ds['labels'][perm, ...]
    grads, loss, accuracy = apply_model(state, batch_X, batch_labels)
    # Go through the jitted `apply_gradients` helper. Calling
    # `state.apply_gradients` directly would run the optimizer (the `tx`
    # set in create_train_state) eagerly, op by op, on every step.
    state = apply_gradients(state, grads)
    epoch_loss.append(loss)
    epoch_accuracy.append(accuracy)
  train_loss = np.mean(epoch_loss)
  train_accuracy = np.mean(epoch_accuracy)
  return state, train_loss, train_accuracy

@jax.jit
def apply_gradients(state, grads):
  """Compiled single optimizer step.

  Returns a new state produced by `state.apply_gradients(grads=grads)`.
  """
  updated_state = state.apply_gradients(grads=grads)
  return updated_state

def create_train_state(rng, config):
  """Creates initial `TrainState`.

  The `MLP` parameters are initialised from `rng` by tracing the model on one
  dummy all-ones input of shape (1, 2). The optimizer is `optax.sgd` with
  `config.learning_rate` and `config.momentum`.
  """
  model = MLP()
  dummy_input = jnp.ones([1, 2])
  variables = model.init(rng, dummy_input)
  optimizer = optax.sgd(config.learning_rate, config.momentum)
  return train_state.TrainState.create(
      apply_fn=model.apply,
      params=variables['params'],
      tx=optimizer,
  )


def train_and_evaluate(config: ml_collections.ConfigDict,
                       datasets=None,
                       ) -> train_state.TrainState:
  """Execute model training and evaluation loop.

  After each epoch, prints the epoch's mean training loss and accuracy plus
  the loss and accuracy on the full test split.

  Args:
    config: Hyperparameter configuration for training and evaluation. Reads
      `learning_rate`, `momentum`, `batch_size` and `num_epochs`.
    datasets: optional `(train_ds, test_ds)` pair of dicts with 'X' and
      'labels' entries. When omitted, `get_datasets()` is called, so callers
      that already hold the data can train and evaluate on exactly that data.

  Returns:
    The train state (which includes the `.params`).
  """
  if datasets is None:
    datasets = get_datasets()
  train_ds, test_ds = datasets

  rng = jax.random.PRNGKey(0)
  rng, init_rng = jax.random.split(rng)
  state = create_train_state(init_rng, config)

  for epoch in range(1, config.num_epochs + 1):
    # Fresh key per epoch, so each epoch shuffles differently but
    # reproducibly.
    rng, input_rng = jax.random.split(rng)
    state, train_loss, train_accuracy = train_epoch(state, train_ds,
                                                    config.batch_size,
                                                    input_rng)
    # Only loss and accuracy are needed for evaluation; the gradients are
    # discarded.
    _, test_loss, test_accuracy = apply_model(state, test_ds['X'],
                                              test_ds['labels'])

    print(
        'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f'
        % (epoch, train_loss, train_accuracy * 100, test_loss,
           test_accuracy * 100))

  return state


In [2]:
def get_datasets(seed=0):
  """Generates a synthetic 2-D binary classification dataset in memory.

  Draws 70,000 points uniformly from [0, 1)^2. Each point is labelled 1 if
  its first coordinate is greater than its second, otherwise 0. The first
  60,000 points form the training split; the remaining 10,000 form the test
  split.

  Args:
    seed: seed for a local NumPy generator. Repeated calls with the same seed
      return identical data, so independent callers (e.g. the training loop
      and a later error-rate check) see the same splits.

  Returns:
    `(train_ds, test_ds)`. Each is a dict with 'X' of shape (n, 2) and
    integer 'labels' of shape (n,).
  """
  # A local generator keeps results reproducible without touching NumPy's
  # global random state.
  rng = np.random.default_rng(seed)
  X = rng.random((70_000, 2))
  labels = (X[:, 0] > X[:, 1]).astype(int)
  train_ds = {
      'X': X[:60_000],
      'labels': labels[:60_000],
  }
  test_ds = {
      'X': X[60_000:],
      'labels': labels[60_000:]
  }
  return train_ds, test_ds

train_ds, test_ds = get_datasets()

config = ml_collections.ConfigDict(dict(
    learning_rate=0.1,
    momentum=0.2,
    batch_size=128,
    num_epochs=5,
))

start = time()
state = train_and_evaluate(config)
elapsed = time() - start
print(f'Total training time: {int(elapsed)}')

# Locate every misclassified test example. The forward pass uses the model
# function stored on the trained state (the same `MLP().apply`).
logits = state.apply_fn({'params': state.params}, test_ds['X'])
predictions = logits.argmax(axis=1)
(error_idxs,) = jnp.where(test_ds['labels'] != predictions)
print(f'Error rate: {len(error_idxs) / len(logits)}')



epoch:  1, train_loss: 0.1268, train_accuracy: 98.47, test_loss: 0.0459, test_accuracy: 98.61
epoch:  2, train_loss: 0.0357, train_accuracy: 99.27, test_loss: 0.0323, test_accuracy: 98.69
epoch:  3, train_loss: 0.0277, train_accuracy: 99.25, test_loss: 0.0298, test_accuracy: 98.61
epoch:  4, train_loss: 0.0227, train_accuracy: 99.39, test_loss: 0.0330, test_accuracy: 98.27
epoch:  5, train_loss: 0.0205, train_accuracy: 99.34, test_loss: 0.0175, test_accuracy: 99.47
Total training time: 6
Error rate: 0.0053
