From b5a7d3de6abdceadffe8a5221d290e8adc84dcb1 Mon Sep 17 00:00:00 2001 From: Marc van Zee Date: Fri, 28 Aug 2020 08:56:22 +0000 Subject: [PATCH 1/2] Adds MNIST to Linen examples --- linen_examples/mnist/README.md | 17 +++ linen_examples/mnist/mnist_benchmark.py | 73 ++++++++++ linen_examples/mnist/mnist_lib.py | 186 ++++++++++++++++++++++++ linen_examples/mnist/mnist_lib_test.py | 63 ++++++++ linen_examples/mnist/mnist_main.py | 58 ++++++++ 5 files changed, 397 insertions(+) create mode 100644 linen_examples/mnist/README.md create mode 100644 linen_examples/mnist/mnist_benchmark.py create mode 100644 linen_examples/mnist/mnist_lib.py create mode 100644 linen_examples/mnist/mnist_lib_test.py create mode 100644 linen_examples/mnist/mnist_main.py diff --git a/linen_examples/mnist/README.md b/linen_examples/mnist/README.md new file mode 100644 index 000000000..c12615f07 --- /dev/null +++ b/linen_examples/mnist/README.md @@ -0,0 +1,17 @@ +## MNIST classification + +Trains a simple convolutional network on the MNIST dataset. + +### Requirements +* TensorFlow dataset `mnist` will be downloaded and prepared automatically, if necessary + +### Example output + +``` +I0828 08:51:41.821526 139971964110656 mnist_lib.py:128] train epoch: 10, loss: 0.0097, accuracy: 99.69 +I0828 08:51:42.248714 139971964110656 mnist_lib.py:178] eval epoch: 10, loss: 0.0299, accuracy: 99.14 +``` + +### How to run + +`python mnist_main.py --model_dir=/tmp/mnist` diff --git a/linen_examples/mnist/mnist_benchmark.py b/linen_examples/mnist/mnist_benchmark.py new file mode 100644 index 000000000..aa5472464 --- /dev/null +++ b/linen_examples/mnist/mnist_benchmark.py @@ -0,0 +1,73 @@ +# Copyright 2020 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Benchmark for the MNIST example.""" +import time +from absl import flags +from absl.testing import absltest +from absl.testing.flagsaver import flagsaver + +import jax +import numpy as np +from flax.testing import Benchmark + +import mnist_main + + +# Parse absl flags test_srcdir and test_tmpdir. +jax.config.parse_flags_with_absl() + + +FLAGS = flags.FLAGS + + +class MnistBenchmark(Benchmark): + """Benchmarks for the MNIST Flax example.""" + + @flagsaver + def test_cpu(self): + """Run full training for MNIST CPU training.""" + model_dir = self.get_tmp_model_dir() + FLAGS.model_dir = model_dir + start_time = time.time() + mnist_main.main([]) + benchmark_time = time.time() - start_time + summaries = self.read_summaries(model_dir) + + # Summaries contain all the information necessary for the regression + # metrics. + wall_time, _, eval_accuracy = zip(*summaries['eval_accuracy']) + wall_time = np.array(wall_time) + sec_per_epoch = np.mean(wall_time[1:] - wall_time[:-1]) + end_eval_accuracy = eval_accuracy[-1] + + # Assertions are deferred until the test finishes, so the metrics are + # always reported and benchmark success is determined based on *all* + # assertions. + self.assertBetween(end_eval_accuracy, 0.98, 1.0) + + # Use the reporting API to report single or multiple metrics/extras. + self.report_wall_time(benchmark_time) + self.report_metrics({ + 'sec_per_epoch': sec_per_epoch, + 'accuracy': end_eval_accuracy, + }) + self.report_extras({ + 'model_name': 'MNIST', + 'description': 'CPU test for MNIST.' + }) + + +if __name__ == '__main__': + absltest.main() diff --git a/linen_examples/mnist/mnist_lib.py b/linen_examples/mnist/mnist_lib.py new file mode 100644 index 000000000..d74f368e5 --- /dev/null +++ b/linen_examples/mnist/mnist_lib.py @@ -0,0 +1,186 @@ +# Copyright 2020 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MNIST example. + +Library file which executes the training and evaluation loop for MNIST. +The data is loaded using tensorflow_datasets. +""" + +from absl import logging + +import jax +from jax import random +import jax.numpy as jnp +from jax.config import config +config.enable_omnistaging() + +import numpy as onp + +import tensorflow_datasets as tfds + +from flax import linen as nn +from flax import optim +from flax.metrics import tensorboard + +class CNN(nn.Module): + """A simple CNN model.""" + + @nn.compact + def __call__(self, x): + x = nn.Conv(features=32, kernel_size=(3, 3))(x) + x = nn.relu(x) + x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) + x = nn.Conv(features=64, kernel_size=(3, 3))(x) + x = nn.relu(x) + x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) + x = x.reshape((x.shape[0], -1)) # flatten + x = nn.Dense(features=256)(x) + x = nn.relu(x) + x = nn.Dense(features=10)(x) + x = nn.log_softmax(x) + return x + + +def get_params(key): + init_shape = jnp.ones((1, 28, 28, 1), jnp.float32) + initial_params = CNN().init(key, init_shape)["param"] + return initial_params + + +def create_optimizer(params, learning_rate, beta): + optimizer_def = optim.Momentum(learning_rate=learning_rate, beta=beta) + optimizer = optimizer_def.create(params) + return optimizer + + +def onehot(labels, num_classes=10): + x = (labels[..., None] == jnp.arange(num_classes)[None]) + return x.astype(jnp.float32) + + +def cross_entropy_loss(logits, labels): + return -jnp.mean(jnp.sum(onehot(labels) * logits, axis=-1)) + + +def compute_metrics(logits, labels): + loss = cross_entropy_loss(logits, labels) + accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) + metrics = { + 'loss': loss, + 'accuracy': accuracy, + } + return metrics + + +@jax.jit +def train_step(optimizer, batch): + """Train for a single step.""" + def loss_fn(params): + logits = CNN().apply({'param': params}, batch['image']) + loss = cross_entropy_loss(logits, batch['label']) + return loss, logits + grad_fn = jax.value_and_grad(loss_fn, has_aux=True) + (_, logits), grad = grad_fn(optimizer.target) + optimizer = optimizer.apply_gradient(grad) + metrics = compute_metrics(logits, batch['label']) + return optimizer, metrics + + +@jax.jit +def eval_step(params, batch): + logits = CNN().apply({'param': params}, batch['image']) + return compute_metrics(logits, batch['label']) + + +def train_epoch(optimizer, train_ds, batch_size, epoch, rng): + """Train for a single epoch.""" + train_ds_size = len(train_ds['image']) + steps_per_epoch = train_ds_size // batch_size + + perms = random.permutation(rng, len(train_ds['image'])) + perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch + perms = perms.reshape((steps_per_epoch, batch_size)) + batch_metrics = [] + for perm in perms: + batch = {k: v[perm] for k, v in train_ds.items()} + optimizer, metrics = train_step(optimizer, batch) + batch_metrics.append(metrics) + + # compute mean of metrics across each batch in epoch. + batch_metrics_np = jax.device_get(batch_metrics) + epoch_metrics_np = { + k: onp.mean([metrics[k] for metrics in batch_metrics_np]) + for k in batch_metrics_np[0]} + + logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch, + epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100) + + return optimizer, epoch_metrics_np + + +def eval_model(model, test_ds): + metrics = eval_step(model, test_ds) + metrics = jax.device_get(metrics) + summary = jax.tree_map(lambda x: x.item(), metrics) + return summary['loss'], summary['accuracy'] + + +def get_datasets(): + """Load MNIST train and test datasets into memory.""" + ds_builder = tfds.builder('mnist') + ds_builder.download_and_prepare() + train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1)) + test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1)) + train_ds['image'] = jnp.float32(train_ds['image']) / 255. + test_ds['image'] = jnp.float32(test_ds['image']) / 255. + return train_ds, test_ds + + +def train_and_evaluate(model_dir: str, num_epochs: int, batch_size: int, + learning_rate: float, momentum: float): + """Execute model training and evaluation loop. + + Args: + model_dir: Directory where the tensorboard summaries are written to. + num_epochs: Number of epochs to cycle through the dataset before stopping. + batch_size: Batch size of the input. + learning_rate: Learning rate for the momentum optimizer. + momentum: Momentum value for the momentum optimizer. +""" + train_ds, test_ds = get_datasets() + rng = random.PRNGKey(0) + + summary_writer = tensorboard.SummaryWriter(model_dir) + + rng, init_rng = random.split(rng) + params = get_params(init_rng) + optimizer = create_optimizer(params, learning_rate, momentum) + + for epoch in range(1, num_epochs + 1): + rng, input_rng = random.split(rng) + optimizer, train_metrics = train_epoch( + optimizer, train_ds, batch_size, epoch, input_rng) + loss, accuracy = eval_model(optimizer.target, test_ds) + + logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', + epoch, loss, accuracy * 100) + + summary_writer.scalar('train_loss', train_metrics['loss'], epoch) + summary_writer.scalar('train_accuracy', train_metrics['accuracy'], epoch) + summary_writer.scalar('eval_loss', loss, epoch) + summary_writer.scalar('eval_accuracy', accuracy, epoch) + + summary_writer.flush() + return optimizer diff --git a/linen_examples/mnist/mnist_lib_test.py b/linen_examples/mnist/mnist_lib_test.py new file mode 100644 index 000000000..b86537d34 --- /dev/null +++ b/linen_examples/mnist/mnist_lib_test.py @@ -0,0 +1,63 @@ +# Copyright 2020 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for flax.examples.mnist.mnist_lib.""" + +import pathlib +import tempfile + +from absl.testing import absltest + +import jax +from jax import numpy as jnp + +import tensorflow_datasets as tfds + +import mnist_lib + + +class MnistLibTest(absltest.TestCase): + """Test cases for mnist_lib.""" + + def test_cnn(self): + """Tests CNN module used as the trainable model.""" + rng = jax.random.PRNGKey(0) + output, init_params = mnist_lib.CNN.init_by_shape( + rng, [((5, 224, 224, 3), jnp.float32)]) + + self.assertEqual((5, 10), output.shape) + + # TODO(mohitreddy): Consider creating a testing module which + # gives a parameters overview including number of parameters. + self.assertLen(init_params, 4) + + def test_train_and_evaluate(self): + """Tests training and evaluation code by running a single step with + mocked data for MNIST dataset. + """ + # Create a temporary directory where tensorboard metrics are written. + model_dir = tempfile.mkdtemp() + + # Go two directories up to the root of the flax directory. + flax_root_dir = pathlib.Path(__file__).parents[2] + data_dir = str(flax_root_dir) + '/.tfds/metadata' + + with tfds.testing.mock_data(num_examples=8, data_dir=data_dir): + mnist_lib.train_and_evaluate( + model_dir=model_dir, num_epochs=1, batch_size=8, + learning_rate=0.1, momentum=0.9) + + +if __name__ == '__main__': + absltest.main() diff --git a/linen_examples/mnist/mnist_main.py b/linen_examples/mnist/mnist_main.py new file mode 100644 index 000000000..9edaa721a --- /dev/null +++ b/linen_examples/mnist/mnist_main.py @@ -0,0 +1,58 @@ +# Copyright 2020 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MNIST example. + +This script trains a simple Convolutional Neural Net on the MNIST dataset. + +""" + +from absl import app +from absl import flags + +import mnist_lib + +FLAGS = flags.FLAGS + +flags.DEFINE_float( + 'learning_rate', default=0.1, + help=('The learning rate for the momentum optimizer.')) + +flags.DEFINE_float( + 'momentum', default=0.9, + help=('The decay rate used for the momentum optimizer.')) + +flags.DEFINE_integer( + 'batch_size', default=128, + help=('Batch size for training.')) + +flags.DEFINE_integer( + 'num_epochs', default=10, + help=('Number of training epochs.')) + +flags.DEFINE_string( + 'model_dir', default=None, + help=('Directory to store model data.')) + + +def main(_): + mnist_lib.train_and_evaluate( + model_dir=FLAGS.model_dir, num_epochs=FLAGS.num_epochs, + batch_size=FLAGS.batch_size, learning_rate=FLAGS.learning_rate, + momentum=FLAGS.momentum) + + +if __name__ == '__main__': + flags.mark_flag_as_required('model_dir') + app.run(main) From 67eb8e14315de70c4e1564edbe712cf9774e19cf Mon Sep 17 00:00:00 2001 From: Marc van Zee Date: Fri, 28 Aug 2020 10:44:10 +0000 Subject: [PATCH 2/2] Fix tests --- linen_examples/mnist/mnist_lib_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/linen_examples/mnist/mnist_lib_test.py b/linen_examples/mnist/mnist_lib_test.py index b86537d34..f58b24a61 100644 --- a/linen_examples/mnist/mnist_lib_test.py +++ b/linen_examples/mnist/mnist_lib_test.py @@ -33,14 +33,14 @@ class MnistLibTest(absltest.TestCase): def test_cnn(self): """Tests CNN module used as the trainable model.""" rng = jax.random.PRNGKey(0) - output, init_params = mnist_lib.CNN.init_by_shape( - rng, [((5, 224, 224, 3), jnp.float32)]) + inputs = jnp.ones((5, 224, 224, 3), jnp.float32) + output, variables = mnist_lib.CNN().init_with_output(rng, inputs) self.assertEqual((5, 10), output.shape) # TODO(mohitreddy): Consider creating a testing module which # gives a parameters overview including number of parameters. - self.assertLen(init_params, 4) + self.assertLen(variables['param'], 4) def test_train_and_evaluate(self): """Tests training and evaluation code by running a single step with