My first model with Flax.

I created dummy data which is two numeric features. The target is binary classification, where the label is true iff the first number comes before the second when the two numbers are spelled out in alphabetical order.



In [1]:
from flax import linen as nn
from flax.training import train_state
from num2words import num2words
from time import time

import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import optax


class MLP(nn.Module):
    """A simple MLP model."""

    @nn.compact
    def __call__(self, x):
        # Since 4 and 400 are close in alphabetical order (even if not in numerical value), I hypothesize that
        # the first digit is especially important. So I create a first_digit feature below
        num_digits = jnp.log10(jnp.abs(x))
        first_digit = jnp.floor((jnp.abs(x) // (10 ** (num_digits)))).clip(0)
        x = jnp.hstack([x, num_digits, first_digit])
        x = nn.Dense(features=20)(x)
        x = nn.relu(x)
        x = nn.Dense(features=20)(x)
        x = nn.relu(x)
        x = nn.Dense(features=2)(x)
        return x


@jax.jit
def apply_model(state, X, labels):
    """Computes gradients, loss and accuracy for a single batch."""

    def loss_fn(params):
        logits = MLP().apply({"params": params}, X)
        one_hot = jax.nn.one_hot(labels, 2)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return grads, loss, accuracy


def train_epoch(state, train_ds, batch_size, rng):
    train_ds_size = len(train_ds["X"])
    steps_per_epoch = train_ds_size // batch_size

    perms = jax.random.permutation(rng, len(train_ds["X"]))
    perms = perms[: steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))

    epoch_loss = []
    epoch_accuracy = []

    for perm in perms:
        batch_X = train_ds["X"][perm, ...]
        batch_labels = train_ds["labels"][perm, ...]
        grads, loss, accuracy = apply_model(state, batch_X, batch_labels)
        # state is a flax TrainState object with a tx property defining step
        # tx (effectively the optimizer) is set in create_train_state function below
        state = state.apply_gradients(grads=grads)
        epoch_loss.append(loss)
        epoch_accuracy.append(accuracy)
    train_loss = np.mean(epoch_loss)
    train_accuracy = np.mean(epoch_accuracy)
    return state, train_loss, train_accuracy


@jax.jit
def apply_gradients(state, grads):
    return state.apply_gradients(grads=grads)


def create_train_state(rng, config):
    """Creates initial `TrainState`."""
    mlp = MLP()
    params = mlp.init(rng, jnp.ones([1, 2]))["params"]
    tx = optax.adabelief(config.learning_rate)
    return train_state.TrainState.create(apply_fn=mlp.apply, params=params, tx=tx)


def train_and_evaluate(config: ml_collections.ConfigDict,) -> train_state.TrainState:
    """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.

  Returns:
    The train state (which includes the `.params`).
  """
    train_ds, test_ds = make_data()
    rng = jax.random.PRNGKey(0)

    rng, init_rng = jax.random.split(rng)
    state = create_train_state(init_rng, config)

    for epoch in range(1, config.num_epochs + 1):
        rng, input_rng = jax.random.split(rng)
        state, train_loss, train_accuracy = train_epoch(
            state, train_ds, config.batch_size, input_rng
        )
        _, test_loss, test_accuracy = apply_model(
            state, test_ds["X"], test_ds["labels"]
        )

        print(
            "epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f"
            % (epoch, train_loss, train_accuracy * 100, test_loss, test_accuracy * 100)
        )

    return state


def make_data(train_size=50_000, val_size=1_000):
    total_size = train_size + val_size
    X = 10 * np.random.randn(total_size, 2)
    X_spelled_out = np.vectorize(num2words)(X)
    labels = (X_spelled_out[:, 0] > X_spelled_out[:, 1]).astype(int)
    labels = (
        np.array([num2words(i) for i in X[:, 0]])
        > np.array([num2words(i) for i in X[:, 1]])
    ).astype(int)
    train_ds = {
        "X": X[:train_size],
        "labels": labels[:train_size],
    }
    test_ds = {"X": X[train_size:], "labels": labels[train_size:]}
    return train_ds, test_ds


train_ds, test_ds = make_data()

config = ml_collections.ConfigDict(
    {"learning_rate": 0.05, "batch_size": 256, "num_epochs": 10,}
)

start = time()
state = train_and_evaluate(config)
print(f"Total training time: {int(time() - start)}")
# Print final test accuracy
logits = MLP().apply({"params": state.params}, test_ds["X"])
(error_idxs,) = jnp.where(test_ds["labels"] != logits.argmax(axis=1))
print(f"Error rate: {len(error_idxs) / len(logits)}")



epoch:  1, train_loss: 0.6741, train_accuracy: 59.08, test_loss: 0.5611, test_accuracy: 68.10
epoch:  2, train_loss: 0.5626, train_accuracy: 66.97, test_loss: 0.5271, test_accuracy: 71.40
epoch:  3, train_loss: 0.5514, train_accuracy: 67.34, test_loss: 0.5113, test_accuracy: 72.10
epoch:  4, train_loss: 0.5410, train_accuracy: 68.65, test_loss: 0.5119, test_accuracy: 71.60
epoch:  5, train_loss: 0.5343, train_accuracy: 68.70, test_loss: 0.5099, test_accuracy: 71.40
epoch:  6, train_loss: 0.5310, train_accuracy: 68.89, test_loss: 0.4993, test_accuracy: 71.00
epoch:  7, train_loss: 0.5268, train_accuracy: 69.08, test_loss: 0.5027, test_accuracy: 69.40
epoch:  8, train_loss: 0.5255, train_accuracy: 68.76, test_loss: 0.5022, test_accuracy: 71.10
epoch:  9, train_loss: 0.5200, train_accuracy: 69.26, test_loss: 0.4852, test_accuracy: 72.60
epoch: 10, train_loss: 0.5291, train_accuracy: 68.38, test_loss: 0.4920, test_accuracy: 70.80
Total training time: 84
Error rate: 0.307
