Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions linen_examples/mnist/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
## MNIST classification

Trains a simple convolutional network on the MNIST dataset.

### Requirements
* The TensorFlow Datasets `mnist` dataset is downloaded and prepared automatically if necessary.

### Example output

```
I0828 08:51:41.821526 139971964110656 mnist_lib.py:128] train epoch: 10, loss: 0.0097, accuracy: 99.69
I0828 08:51:42.248714 139971964110656 mnist_lib.py:178] eval epoch: 10, loss: 0.0299, accuracy: 99.14
```

### How to run

`python mnist_main.py --model_dir=/tmp/mnist`
73 changes: 73 additions & 0 deletions linen_examples/mnist/mnist_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright 2020 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Benchmark for the MNIST example."""
import time
from absl import flags
from absl.testing import absltest
from absl.testing.flagsaver import flagsaver

import jax
import numpy as np
from flax.testing import Benchmark

import mnist_main


# Parse absl flags test_srcdir and test_tmpdir.
jax.config.parse_flags_with_absl()


# Module-level handle to the absl flag registry; the benchmark below assigns
# FLAGS.model_dir through it before invoking mnist_main.main.
FLAGS = flags.FLAGS


class MnistBenchmark(Benchmark):
  """Benchmarks for the MNIST Flax example."""

  @flagsaver
  def test_cpu(self):
    """Train MNIST via `mnist_main.main` and report timing and accuracy.

    The run writes its summaries into a temporary model directory; the
    'eval_accuracy' summaries are then read back to derive the mean time
    between consecutive entries and the final evaluation accuracy.
    """
    workdir = self.get_tmp_model_dir()
    FLAGS.model_dir = workdir

    # Wall-clock time of the complete training run.
    started_at = time.time()
    mnist_main.main([])
    elapsed = time.time() - started_at

    # Summaries contain all the information necessary for the regression
    # metrics. Each 'eval_accuracy' entry unpacks to
    # (wall time, <unused>, accuracy value).
    eval_entries = self.read_summaries(workdir)['eval_accuracy']
    timestamps, _, accuracies = zip(*eval_entries)
    mean_epoch_seconds = np.mean(np.diff(np.array(timestamps)))
    final_accuracy = accuracies[-1]

    # Assertions are deferred until the test finishes, so the metrics are
    # always reported and benchmark success is determined based on *all*
    # assertions.
    self.assertBetween(final_accuracy, 0.98, 1.0)

    # Use the reporting API to report single or multiple metrics/extras.
    self.report_wall_time(elapsed)
    reported_metrics = {
        'sec_per_epoch': mean_epoch_seconds,
        'accuracy': final_accuracy,
    }
    self.report_metrics(reported_metrics)
    reported_extras = {
        'model_name': 'MNIST',
        'description': 'CPU test for MNIST.'
    }
    self.report_extras(reported_extras)


# Run the benchmark through absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()
186 changes: 186 additions & 0 deletions linen_examples/mnist/mnist_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
# Copyright 2020 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MNIST example.

Library file which executes the training and evaluation loop for MNIST.
The data is loaded using tensorflow_datasets.
"""

from absl import logging

import jax
from jax import random
import jax.numpy as jnp
from jax.config import config
config.enable_omnistaging()

import numpy as onp

import tensorflow_datasets as tfds

from flax import linen as nn
from flax import optim
from flax.metrics import tensorboard

class CNN(nn.Module):
  """A simple CNN model.

  Two stages of 3x3 convolution -> ReLU -> 2x2 average pooling (32 and then
  64 features), followed by a 256-unit dense layer with ReLU and a 10-unit
  dense output passed through log-softmax.
  """

  @nn.compact
  def __call__(self, x):
    """Apply the network to a batch of inputs.

    Args:
      x: Input batch; the leading axis is kept as the batch axis when the
        convolutional features are flattened.

    Returns:
      The log-softmax of the 10 output units for every example.
    """
    # Both convolutional stages share the same layout and differ only in the
    # number of output features.
    for conv_features in (32, 64):
      x = nn.Conv(features=conv_features, kernel_size=(3, 3))(x)
      x = nn.relu(x)
      x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

    flattened = x.reshape((x.shape[0], -1))
    hidden = nn.relu(nn.Dense(features=256)(flattened))
    class_scores = nn.Dense(features=10)(hidden)
    return nn.log_softmax(class_scores)


def get_params(key):
  """Initialise the CNN and return its parameters.

  Args:
    key: PRNG key handed to `CNN().init`.

  Returns:
    The "param" entry of the variables created by initialising the CNN on an
    all-ones float32 input of shape (1, 28, 28, 1).
  """
  dummy_images = jnp.ones((1, 28, 28, 1), jnp.float32)
  variables = CNN().init(key, dummy_images)
  return variables["param"]


def create_optimizer(params, learning_rate, beta):
  """Wrap `params` in a momentum optimizer.

  Args:
    params: Parameters the optimizer is created for.
    learning_rate: Learning rate forwarded to `optim.Momentum`.
    beta: Momentum coefficient forwarded to `optim.Momentum`.

  Returns:
    The optimizer produced by `optim.Momentum(...).create(params)`.
  """
  momentum_def = optim.Momentum(learning_rate=learning_rate, beta=beta)
  return momentum_def.create(params)


def onehot(labels, num_classes=10):
  """One-hot encode integer class labels as float32.

  Args:
    labels: Array of class indices.
    num_classes: Size of the trailing one-hot axis that is appended.

  Returns:
    A float32 array holding 1.0 where the class index equals the label and
    0.0 elsewhere.
  """
  # A leading axis on the class ids and a trailing axis on the labels let the
  # comparison broadcast to one row of `num_classes` entries per label.
  class_ids = jnp.arange(num_classes)[None]
  matches = jnp.equal(labels[..., None], class_ids)
  return matches.astype(jnp.float32)


def cross_entropy_loss(logits, labels):
  """Mean cross-entropy between predicted scores and integer labels.

  The number of classes is taken from the trailing axis of `logits`, so the
  loss is not tied to a 10-class output.

  Args:
    logits: Per-class scores with the class axis last. No softmax is applied
      here; the callers in this file pass the output of `CNN`, which already
      ends in `log_softmax`, i.e. log-probabilities.
    labels: Integer class indices, one per example.

  Returns:
    The negated mean over examples of the score assigned to the labelled
    class.
  """
  targets = onehot(labels, num_classes=logits.shape[-1])
  return -jnp.mean(jnp.sum(targets * logits, axis=-1))


def compute_metrics(logits, labels):
  """Compute loss and accuracy for a batch of predictions.

  Args:
    logits: Per-class scores with the class axis last.
    labels: Integer class indices, one per example.

  Returns:
    A dict with 'loss' (from `cross_entropy_loss`) and 'accuracy' (the mean
    of argmax-prediction == label).
  """
  predicted_classes = jnp.argmax(logits, -1)
  return {
      'loss': cross_entropy_loss(logits, labels),
      'accuracy': jnp.mean(predicted_classes == labels),
  }


@jax.jit
def train_step(optimizer, batch):
  """Train for a single step.

  Args:
    optimizer: Optimizer whose `target` holds the current CNN parameters.
    batch: Mapping with 'image' and 'label' entries.

  Returns:
    A tuple `(optimizer, metrics)`: the optimizer after applying the
    gradient, and the metrics computed from the outputs of the forward pass
    made before the update.
  """
  images = batch['image']
  labels = batch['label']

  def loss_fn(params):
    # The model outputs are returned as auxiliary data so the metrics can
    # reuse them without a second forward pass.
    outputs = CNN().apply({'param': params}, images)
    return cross_entropy_loss(outputs, labels), outputs

  value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, logits), grad = value_and_grad_fn(optimizer.target)
  updated_optimizer = optimizer.apply_gradient(grad)
  return updated_optimizer, compute_metrics(logits, labels)


@jax.jit
def eval_step(params, batch):
  """Evaluate the CNN on one batch.

  Args:
    params: CNN parameters, applied as the 'param' collection.
    batch: Mapping with 'image' and 'label' entries.

  Returns:
    The metrics dict produced by `compute_metrics` for this batch.
  """
  variables = {'param': params}
  outputs = CNN().apply(variables, batch['image'])
  return compute_metrics(outputs, batch['label'])


def train_epoch(optimizer, train_ds, batch_size, epoch, rng):
  """Train for a single epoch.

  The training set is shuffled with `rng`, cut into full batches of
  `batch_size` examples (a trailing incomplete batch is skipped) and fed to
  `train_step` batch by batch.

  Args:
    optimizer: Optimizer holding the current parameters.
    train_ds: Mapping of arrays sharing the same leading axis; its 'image'
      entry determines the dataset size.
    batch_size: Number of examples per training step.
    epoch: Epoch number, used only in the log message.
    rng: PRNG key used to permute the training examples.

  Returns:
    A tuple `(optimizer, epoch_metrics)` where `epoch_metrics` maps each
    metric name to its mean over all batches of the epoch.

  Raises:
    ValueError: If the dataset holds fewer examples than `batch_size`, so
      that not a single full batch can be formed.
  """
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size
  if steps_per_epoch == 0:
    # Without this check the metric aggregation below would fail with an
    # opaque IndexError on an empty list.
    raise ValueError(
        f'batch_size ({batch_size}) exceeds the number of training examples '
        f'({train_ds_size}); no full batch can be formed.')

  perms = random.permutation(rng, train_ds_size)
  perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
  perms = perms.reshape((steps_per_epoch, batch_size))
  batch_metrics = []
  for perm in perms:
    batch = {k: v[perm] for k, v in train_ds.items()}
    optimizer, metrics = train_step(optimizer, batch)
    batch_metrics.append(metrics)

  # compute mean of metrics across each batch in epoch.
  batch_metrics_np = jax.device_get(batch_metrics)
  epoch_metrics_np = {
      k: onp.mean([metrics[k] for metrics in batch_metrics_np])
      for k in batch_metrics_np[0]}

  logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
               epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100)

  return optimizer, epoch_metrics_np


def eval_model(model, test_ds):
  """Evaluate on the whole test set in a single `eval_step` call.

  Args:
    model: Parameters forwarded to `eval_step` (the caller in this file
      passes `optimizer.target`).
    test_ds: Mapping with 'image' and 'label' entries.

  Returns:
    A tuple `(loss, accuracy)` of Python scalars.
  """
  host_metrics = jax.device_get(eval_step(model, test_ds))
  # `.item()` converts each fetched metric into a plain Python scalar.
  loss = host_metrics['loss'].item()
  accuracy = host_metrics['accuracy'].item()
  return loss, accuracy


def get_datasets():
  """Load MNIST train and test datasets into memory.

  Returns:
    A tuple `(train_ds, test_ds)`; in each dataset the 'image' entry has been
    converted to float32 and divided by 255.
  """
  builder = tfds.builder('mnist')
  builder.download_and_prepare()

  loaded = []
  for split_name in ('train', 'test'):
    # batch_size=-1 requests the complete split at once.
    dataset = tfds.as_numpy(builder.as_dataset(split=split_name, batch_size=-1))
    dataset['image'] = jnp.float32(dataset['image']) / 255.
    loaded.append(dataset)

  train_ds, test_ds = loaded
  return train_ds, test_ds


def train_and_evaluate(model_dir: str, num_epochs: int, batch_size: int,
                       learning_rate: float, momentum: float):
  """Execute model training and evaluation loop.

  Every epoch trains on the full training set, evaluates on the test set,
  logs the evaluation result and writes train/eval loss and accuracy as
  tensorboard scalars.

  Args:
    model_dir: Directory where the tensorboard summaries are written to.
    num_epochs: Number of epochs to cycle through the dataset before stopping.
    batch_size: Batch size of the input.
    learning_rate: Learning rate for the momentum optimizer.
    momentum: Momentum value for the momentum optimizer.

  Returns:
    The optimizer after the last epoch.
  """
  train_ds, test_ds = get_datasets()
  summary_writer = tensorboard.SummaryWriter(model_dir)

  # Fixed seed; one key is split off for parameter initialisation.
  rng, init_rng = random.split(random.PRNGKey(0))
  optimizer = create_optimizer(get_params(init_rng), learning_rate, momentum)

  for epoch in range(1, num_epochs + 1):
    # A fresh key per epoch drives the shuffling inside train_epoch.
    rng, input_rng = random.split(rng)
    optimizer, train_metrics = train_epoch(
        optimizer, train_ds, batch_size, epoch, input_rng)
    eval_loss, eval_accuracy = eval_model(optimizer.target, test_ds)

    logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f',
                 epoch, eval_loss, eval_accuracy * 100)

    epoch_scalars = {
        'train_loss': train_metrics['loss'],
        'train_accuracy': train_metrics['accuracy'],
        'eval_loss': eval_loss,
        'eval_accuracy': eval_accuracy,
    }
    for tag, value in epoch_scalars.items():
      summary_writer.scalar(tag, value, epoch)

  summary_writer.flush()
  return optimizer
63 changes: 63 additions & 0 deletions linen_examples/mnist/mnist_lib_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2020 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for linen_examples.mnist.mnist_lib."""

import pathlib
import tempfile

from absl.testing import absltest

import jax
from jax import numpy as jnp

import tensorflow_datasets as tfds

import mnist_lib


class MnistLibTest(absltest.TestCase):
  """Test cases for mnist_lib."""

  def test_cnn(self):
    """Tests CNN module used as the trainable model."""
    rng = jax.random.PRNGKey(0)
    inputs = jnp.ones((5, 224, 224, 3), jnp.float32)
    output, variables = mnist_lib.CNN().init_with_output(rng, inputs)

    # One row of 10 class scores for each of the 5 inputs.
    self.assertEqual((5, 10), output.shape)

    # TODO(mohitreddy): Consider creating a testing module which
    # gives a parameters overview including number of parameters.
    # The CNN creates two Conv and two Dense submodules.
    self.assertLen(variables['param'], 4)

  def test_train_and_evaluate(self):
    """Tests training and evaluation code by running a single epoch with
    mocked data for MNIST dataset.
    """
    # Create a temporary directory where tensorboard metrics are written and
    # register its removal, so test runs do not leave directories behind.
    tmp_dir = tempfile.TemporaryDirectory()
    self.addCleanup(tmp_dir.cleanup)
    model_dir = tmp_dir.name

    # Go two directories up to the root of the flax directory.
    flax_root_dir = pathlib.Path(__file__).parents[2]
    data_dir = str(flax_root_dir / '.tfds' / 'metadata')

    with tfds.testing.mock_data(num_examples=8, data_dir=data_dir):
      mnist_lib.train_and_evaluate(
          model_dir=model_dir, num_epochs=1, batch_size=8,
          learning_rate=0.1, momentum=0.9)


# Run the test cases through absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()
Loading