In [1]:
!pip install git+https://github.com/google-research/flax.git@prerelease

Collecting git+https://github.com/google-research/flax.git@prerelease
  Cloning https://github.com/google-research/flax.git (to revision prerelease) to /tmp/pip-req-build-tk28ygok
  Running command git clone -q https://github.com/google-research/flax.git /tmp/pip-req-build-tk28ygok
  Running command git checkout -b prerelease --track origin/prerelease
  Switched to a new branch 'prerelease'
  Branch 'prerelease' set up to track remote branch 'prerelease' from 'origin'.
Building wheels for collected packages: flax
  Building wheel for flax (setup.py) ... done
  Created wheel for flax: filename=flax-0.0.1a0-cp36-none-any.whl size=58663 sha256=c51c37e7d4fb6751ae3e690bda124e1be838834c6da127842bdcf94222729653
  Stored in directory: /tmp/pip-ephem-wheel-cache-8n7vrtus/wheels/9b/83/b6/971d75100ac49feb064f934c1026f7c50e7abb681879a959e2
Successfully built flax
Installing collected packages: flax
Successfully installed flax-0.0.1a0


Annotated full end-to-end MNIST example


In [0]:
import jax
import flax

JAX provides a reimplementation of the NumPy API, jax.numpy, that runs on GPU and TPU.


In [0]:
import jax.numpy as jnp
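As a minimal sketch of what this looks like in practice, the snippet below uses a few jax.numpy calls that mirror their NumPy counterparts. The array values are made up for illustration; the code runs on whatever default device JAX finds (CPU if no accelerator is present).

```python
import numpy as np
import jax.numpy as jnp

# Constructors and methods follow NumPy naming.
x = jnp.arange(6.0).reshape((2, 3))
y = jnp.dot(x, x.T)          # computed on the default JAX device
row_sums = x.sum(axis=1)

# Convert back to a host-side NumPy array when a plain value is needed.
y_host = np.asarray(y)
print(y_host)
print(row_sums)
```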


Flax can use any data loading pipeline. Here we use TensorFlow Datasets (tfds).

In [4]:
import tensorflow_datasets as tfds

A Flax "module" lets you write a normal function, which defines learnable parameters in-line. In this case, we define a simple convolutional neural network.

Each call to flax.nn.Conv defines a learnable kernel.

In [0]:
class CNN(flax.nn.Module):
  """A simple CNN model."""

  def apply(self, x):
    x = flax.nn.Conv(x, features=32, kernel_size=(3, 3))
    x = flax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = flax.nn.Conv(x, features=64, kernel_size=(3, 3))
    x = flax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = flax.nn.Dense(x, features=256)
    x = flax.nn.relu(x)
    x = flax.nn.Dense(x, features=10)
    x = flax.nn.log_softmax(x)
    return x


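Utility function to create the model. CNN.init runs the initializers on a dummy input of shape (1, 28, 28, 1) to produce the initial parameters.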

In [0]:
def create_model(key):
  _, initial_params = CNN.init(key, jnp.zeros((1, 28, 28, 1), jnp.float32))
  model = flax.nn.Model(CNN, initial_params)
  return model

Utility function to create the optimizer

In [0]:
def create_optimizer(model, learning_rate, beta):
  optimizer_def = flax.optim.Momentum(learning_rate=learning_rate, beta=beta)
  optimizer = optimizer_def.create(model)
  return optimizer

jax.vmap allows us to define the cross_entropy_loss function as if it acts on a single sample. jax.vmap automatically vectorizes code efficiently to run on entire batches.



In [0]:
@jax.vmap
def cross_entropy_loss(logits, label):
  return -logits[label]
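To make the effect of jax.vmap concrete, here is a small standalone sketch in plain JAX. It applies a per-example loss to a batch and compares the result against hand-vectorized indexing. The names per_example_loss and batched_loss and the input values are illustrative only; as in the notebook, the loss expects log-probabilities (the output of log_softmax).

```python
import jax
import jax.numpy as jnp

def per_example_loss(log_probs, label):
  # log_probs has shape (num_classes,); label is a scalar integer.
  return -log_probs[label]

# vmap maps the function over the leading (batch) axis of every argument.
batched_loss = jax.vmap(per_example_loss)

logits = jnp.array([[2.0, 0.5, -1.0],
                    [0.1, 0.2, 3.0]])
log_probs = jax.nn.log_softmax(logits)   # shape (2, 3)
labels = jnp.array([0, 2])               # shape (2,)

losses = batched_loss(log_probs, labels)  # shape (2,)

# The same computation written with explicit batch indexing.
manual = -log_probs[jnp.arange(labels.shape[0]), labels]
print(losses, manual)
```

Writing the loss for one example and letting jax.vmap add the batch dimension avoids the explicit index arithmetic in the last line.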

Compute loss and accuracy. We use jnp (jax.numpy) which can run on device (GPU or TPU).



In [0]:
def compute_metrics(logits, labels):
  loss = jnp.mean(cross_entropy_loss(logits, labels))
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  return {'loss': loss, 'accuracy': accuracy}
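As a quick sanity check, the sketch below repeats the two functions above and runs them on a hand-made batch of four examples. The probabilities and labels are invented for illustration; three of the four predictions match their labels, so the accuracy should come out as 0.75.

```python
import jax
import jax.numpy as jnp

@jax.vmap
def cross_entropy_loss(logits, label):
  return -logits[label]

def compute_metrics(logits, labels):
  loss = jnp.mean(cross_entropy_loss(logits, labels))
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  return {'loss': loss, 'accuracy': accuracy}

# Made-up class probabilities for 4 examples and 3 classes.
probs = jnp.array([[0.1, 0.8, 0.1],
                   [0.7, 0.2, 0.1],
                   [0.3, 0.2, 0.5],
                   [0.2, 0.2, 0.6]])
log_probs = jnp.log(probs)          # the model outputs log-probabilities
labels = jnp.array([1, 0, 1, 2])    # the third example is misclassified

metrics = compute_metrics(log_probs, labels)
print(metrics)
```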

jax.jit traces the train_step function and compiles into fused device operations that run on GPU or TPU.



In [0]:
@jax.jit
def train_step(optimizer, batch):
  def loss_fn(model):
    logits = model(batch['image'])
    loss = jnp.mean(cross_entropy_loss(
        logits, batch['label']))
    return loss, logits
  optimizer, _, _ = optimizer.optimize(loss_fn)
  return optimizer
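The notebook does not show what optimizer.optimize does internally. As a rough sketch of the general pattern, the snippet below writes a jitted training step in plain JAX: compute the loss and its gradient with jax.value_and_grad, then apply a momentum update by hand. The linear model, the update rule (velocity = beta * velocity + grad), the constants, and all names here are assumptions for illustration, not Flax's implementation.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # illustrative values, matching the notebook's settings
BETA = 0.9

def loss_fn(params, batch):
  # Hypothetical linear classifier standing in for the CNN.
  logits = batch['image'] @ params['w'] + params['b']
  log_probs = jax.nn.log_softmax(logits)
  picked = jnp.take_along_axis(log_probs, batch['label'][:, None], axis=1)
  return -jnp.mean(picked)

@jax.jit
def sgd_momentum_step(params, velocity, batch):
  loss, grads = jax.value_and_grad(loss_fn)(params, batch)
  # Assumed momentum rule: accumulate gradients, then step along the velocity.
  velocity = jax.tree_util.tree_map(
      lambda v, g: BETA * v + g, velocity, grads)
  params = jax.tree_util.tree_map(
      lambda p, v: p - LEARNING_RATE * v, params, velocity)
  return params, velocity, loss

key = jax.random.PRNGKey(0)
batch = {
    'image': jax.random.normal(key, (8, 4)),  # 8 fake examples, 4 features
    'label': jnp.array([0, 1, 2, 0, 1, 2, 0, 1]),
}
params = {'w': jnp.zeros((4, 3)), 'b': jnp.zeros((3,))}
velocity = jax.tree_util.tree_map(jnp.zeros_like, params)

first_loss = None
for _ in range(20):
  params, velocity, loss = sgd_momentum_step(params, velocity, batch)
  if first_loss is None:
    first_loss = loss
print(float(first_loss), float(loss))
```

The first call traces and compiles sgd_momentum_step; later calls with the same input shapes reuse the compiled version.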

Making model predictions is as simple as calling model(input):

In [0]:
@jax.jit
def eval(model, eval_ds):
  logits = model(eval_ds['image'] / 255.0)
  return compute_metrics(logits, eval_ds['label'])

**Main train loop**





In [0]:
def train():
  # Load and shuffle MNIST.

  train_ds = tfds.load('mnist', split=tfds.Split.TRAIN)
  train_ds = train_ds.cache().shuffle(1000).batch(128)
  test_ds = tfds.as_numpy(tfds.load(
      'mnist', split=tfds.Split.TEST, batch_size=-1))
  # Create a new model, running all necessary initializers.
  # The parameters are stored as nested dicts on model.params.
  model = create_model(
      jax.random.PRNGKey(0))
  # Define an optimizer. At any particular optimization step,
  # optimizer.target contains the model.
  optimizer = create_optimizer(model,
                               learning_rate=0.1, beta=0.9)
  # Run an optimization step for each batch of training data.
  for epoch in range(10):
    for batch in tfds.as_numpy(train_ds):
      batch['image'] = batch['image'] / 255.0
      optimizer = train_step(optimizer, batch)
    
    # Once an epoch, evaluate on the test set.
    metrics = eval(optimizer.target, test_ds)

    # metrics are only retrieved from device when needed on host (like in this print statement)
    print('eval epoch: %d, loss: %.4f, accuracy: %.2f'
          % (epoch + 1, metrics['loss'], metrics['accuracy'] * 100))

In [13]:
train()

local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead set
data_dir=gs://tfds-data/datasets.

Downloading and preparing dataset mnist/3.0.0 (download: 11.06 MiB, generated: Unknown size, total: 11.06 MiB) to /root/tensorflow_datasets/mnist/3.0.0...
Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.0. Subsequent calls will reuse this data.
eval epoch: 1, loss: 0.0640, accuracy: 98.03
eval epoch: 2, loss: 0.0613, accuracy: 98.06
eval epoch: 3, loss: 0.0375, accuracy: 98.76
eval epoch: 4, loss: 0.0382, accuracy: 98.71
eval epoch: 5, loss: 0.0360, accuracy: 98.92
eval epoch: 6, loss: 0.0365, accuracy: 98.92
eval epoch: 7, loss: 0.0313, accuracy: 99.14
eval epoch: 8, loss: 0.0355, accuracy: 99.05
eval epoch: 9, loss: 0.0448, accuracy: 98.77
eval epoch: 10, loss: 0.0298, accuracy: 99.14
