From 613af1b3e2be5ff8ff7b8f45a6916b91902c449a Mon Sep 17 00:00:00 2001
From: Andreas Steiner
Date: Tue, 10 May 2022 14:04:14 +0200
Subject: [PATCH] Adds "full eval" HOWTO.

---
 docs/flax.jax_utils.rst        |    1 +
 docs/howtos/full_eval.rst      |  161 +++++
 docs/index.rst                 |    1 +
 docs/notebooks/full_eval.ipynb | 1023 ++++++++++++++++++++++++++++++++
 flax/jax_utils.py              |   68 +++
 tests/jax_utils_test.py        |   94 +++
 6 files changed, 1348 insertions(+)
 create mode 100644 docs/howtos/full_eval.rst
 create mode 100644 docs/notebooks/full_eval.ipynb
 create mode 100644 tests/jax_utils_test.py

diff --git a/docs/flax.jax_utils.rst b/docs/flax.jax_utils.rst
index 115755de8a..25955ebb3b 100644
--- a/docs/flax.jax_utils.rst
+++ b/docs/flax.jax_utils.rst
@@ -20,3 +20,4 @@ Multi device utilities
 
 .. autofunction:: pmean
 
+.. autofunction:: pad_shard_unpad

diff --git a/docs/howtos/full_eval.rst b/docs/howtos/full_eval.rst
new file mode 100644
index 0000000000..c67a6bfe3a
--- /dev/null
+++ b/docs/howtos/full_eval.rst
@@ -0,0 +1,161 @@
.. image:: https://colab.research.google.com/assets/colab-badge.svg
   :target: https://colab.research.google.com/github/google/flax/blob/main/docs/notebooks/full_eval.ipynb

Processing the entire Dataset
=============================

For efficiency reasons, we form batches that contain multiple examples and
process them in parallel. Especially when evaluating a model, it is important
that we process all examples and avoid losing the remaining examples that do
not form a complete batch at the end.


The problem
-----------

When evaluating on a single device, one can either drop the last incomplete
batch, or form a last batch with a shape different from the preceding batches.
The latter has the disadvantage that it triggers a recompilation of the
``eval_step()`` because XLA is not shape polymorphic.

.. code-block:: python

  collections.Counter(
      tuple(batch['image'].shape)
      for batch in tfds.load('mnist', split='test').batch(per_device_batch_size)
  )
  # output:
  # Counter({(272, 28, 28, 1): 1, (512, 28, 28, 1): 19})

The problem is accentuated when using multiple devices for data parallelism. If
the batch size is not divisible by the number of devices, then the last step
would have to be executed on a single device (or a subset of devices). Usually
one would drop the last batch, but this leads to incorrect results.

.. code-block:: python

  sum(
      np.prod(batch['label'].shape)
      for batch in tfds.load('mnist', split='test')
      .batch(per_device_batch_size, drop_remainder=True)
      .batch(jax.local_device_count())
  )
  # output:
  # 9728

Using multiple hosts further complicates the situation because JAX uses the SPMD
paradigm and every host must execute the same program. We would usually form
non-overlapping splits for different hosts with |tfds.split_for_jax_process()|_,
but this can lead to different numbers of examples on different hosts, resulting
in different JAX programs when all examples are to be processed.

.. code-block:: python

  process_count = 6
  [
      len(tfds.load(dataset_name, split=tfds.split_for_jax_process(
          'test', process_index=process_index, process_count=process_count)))
      for process_index in range(process_count)
  ]
  # output:
  # [1667, 1667, 1667, 1667, 1666, 1666]

.. |tfds.split_for_jax_process()| replace:: ``tfds.split_for_jax_process()``
.. _tfds.split_for_jax_process(): https://www.tensorflow.org/datasets/api_docs/python/tfds/split_for_jax_process
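
The recompilation mentioned above can be made visible by adding a ``print()``
inside the jitted function, since it only executes when the function is traced.
A minimal sketch along the lines of the Colab (assuming a ``model`` and
variables ``vs`` as defined there):

.. code-block:: python

  @jax.jit
  def get_preds(vs, inputs):
    # Runs once per compilation (i.e. once per distinct input shape).
    print('retrigger compilation', inputs.shape)
    return model.apply(vs, inputs)

  ds = tfds.load('mnist', split='test').batch(per_device_batch_size)
  for batch in ds.as_numpy_iterator():
    get_preds(vs, batch['image'])
  # prints:
  # retrigger compilation (512, 28, 28, 1)
  # retrigger compilation (272, 28, 28, 1)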

The solution: padding
---------------------

Even though it's possible to solve this problem by cleverly adjusting the number
of batches executed by different devices on different hosts, such a solution
quickly becomes complicated and clutters the main eval loop with cumbersome
logic.

The more straightforward solution to this problem is to pad the end of the
dataset so that the last batch has the same size as the preceding batches.


Manual implementation
~~~~~~~~~~~~~~~~~~~~~

The last batch is manually padded to contain the same number of examples as the
preceding batches. The predictions for the padded examples are discarded from
the computation.

.. code-block:: python

  shard = lambda x: einops.rearrange(
      x, '(d b) ... -> d b ...', d=jax.local_device_count())
  unshard = lambda x: einops.rearrange(x, 'd b ... -> (d b) ...')

  correct = total = 0
  for batch in ds.as_numpy_iterator():
    images = batch['image']
    n = len(images)
    padding = np.zeros([per_host_batch_size - n, *images.shape[1:]], images.dtype)
    padded_images = np.concatenate([images, padding])
    preds = unshard(get_preds(vs_p, shard(padded_images)))[:n]
    total += n
    correct += (batch['label'] == preds.argmax(axis=-1)).sum()


Using ``pad_shard_unpad()``
~~~~~~~~~~~~~~~~~~~~~~~~~~~

The above pattern, namely the pad→shard→predict→unshard→unpad sequence, can be
extracted into the utility wrapper ``pad_shard_unpad()``, which greatly
simplifies the above evaluation loop.

.. code-block:: python

  correct = total = 0
  for batch in ds.as_numpy_iterator():
    preds = flax.jax_utils.pad_shard_unpad(get_preds)(
        vs, batch['image'], min_device_batch=per_device_batch_size)
    total += len(batch['image'])
    correct += (batch['label'] == preds.argmax(axis=-1)).sum()


Adding "infinite padding"
~~~~~~~~~~~~~~~~~~~~~~~~~

The above solution works in most cases, but it has some limitations:

1. In the rare case where even splitting of the dataset across multiple hosts
   leads to a different number of batches on different hosts. Imagine having a
   dataset of ``n=4097`` examples, and evaluating this on ``h=8`` hosts, each
   having ``d=8`` local devices, and forming on-device batch sizes of ``b=64``.
   With even dataset splitting, the first host would get ``4096/8+1==513``
   examples, and all other hosts would get ``4096/8==512`` examples. Forming
   per-host batches of ``d*b==512``, this would lead to two batches on the
   first host and a single batch on all other hosts, violating SPMD principles
   and hanging the multi-host setup in the last ``psum()`` directive (which
   would only be executed by the first host, but not the others). See the
   sketch after this list for the helpers used in this case.

2. When dropping examples dynamically by using ``ds.filter()``.

In these more complicated cases we can add "infinite padding" to the dataset on
each of the hosts independently, and continue processing examples until *all*
hosts run out of unpadded examples.

.. code-block:: python

  correct = total = 0
  for batch in ds.as_numpy_iterator():
    n = count_p(batch['mask'])[0].item()  # adds sync barrier
    if not n: break

    preds = get_preds(vs, batch['image']).argmax(axis=-1)
    total += n
    correct += count_correct_p(batch['label'], preds, batch['mask'])[0]
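
The ``mask`` feature and the two ``pmap``-ed counters used in this loop are
defined in the Colab; for reference, these are the same definitions as in the
notebook:

.. code-block:: python

  def with_infinite_padding(dataset):
    """Adds "infinite padding" to the dataset."""
    # A single all-zeros element with mask=False, repeated forever below.
    filler_element = tf.nest.map_structure(
        lambda spec: tf.zeros(spec.shape, spec.dtype)[None], dataset.element_spec)
    filler_element['mask'] = [False]
    filler_dataset = tf.data.Dataset.from_tensor_slices(filler_element)
    dataset = dataset.map(
        lambda features: dict(mask=True, **features),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return dataset.concatenate(filler_dataset.repeat(None))

  # Counts unpadded examples (and correct predictions) across all devices.
  count_p = jax.pmap(
      lambda mask: jax.lax.psum(mask.sum(), axis_name='batch'),
      axis_name='batch',
  )
  count_correct_p = jax.pmap(
      lambda labels, preds, mask:
          jax.lax.psum((mask & (labels == preds)).sum(), axis_name='batch'),
      axis_name='batch',
  )

  ds = tfds.load(dataset_name, split=tfds.split_for_jax_process('test'))
  ds = with_infinite_padding(ds).batch(per_device_batch_size).batch(
      jax.local_device_count())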

As for the other examples in this HOWTO, the complete executable code can be
found in the Colab:

.. image:: https://colab.research.google.com/assets/colab-badge.svg
   :target: https://colab.research.google.com/github/google/flax/blob/main/docs/notebooks/full_eval.ipynb

diff --git a/docs/index.rst b/docs/index.rst
index 43e22f4650..17f13876e5 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -44,6 +44,7 @@ For a quick introduction and short example snippets, see our `README
    howtos/convert_pytorch_to_flax
    howtos/optax_update_guide
    howtos/linen_upgrade_guide
+   howtos/full_eval
 
 .. toctree::

diff --git a/docs/notebooks/full_eval.ipynb b/docs/notebooks/full_eval.ipynb
new file mode 100644
index 0000000000..6de171ea7f
--- /dev/null
+++ b/docs/notebooks/full_eval.ipynb
@@ -0,0 +1,1023 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SMNC51ldX-Nq"
   },
   "source": [
    "This notebook only contains executable code cells for the examples mentioned in\n",
    "https://flax.readthedocs.io/en/latest/howtos/full_eval.html\n",
    "\n",
    "Please refer to the above link for an explanation of the problem and the proposed solutions."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Um6ZK_o1W-Vu"
   },
   "source": [
    "### setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "62DTHYCYHWp1",
    "outputId": "b38d096f-58db-4d61-effa-eafa4c732826"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[K |████████████████████████████████| 72 kB 387 kB/s \n",
      "\u001b[K |████████████████████████████████| 4.2 MB 4.9 MB/s \n",
      "\u001b[K |████████████████████████████████| 140 kB 4.5 MB/s \n",
      "\u001b[?25h  Building wheel for flax (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
     ]
    }
   ],
   "source": [
    "!pip install -q chex einops\n",
    "# tfds.split_for_jax_process() was added in 4.5.1\n",
    "!pip install -q tensorflow_datasets -U\n",
    "# flax.jax_utils.pad_shard_unpad() is only available at HEAD\n",
    "!pip install -q git+https://github.com/google/flax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "NdzAaRwVExA9"
   },
   "outputs": [],
   "source": [
    "import collections\n",
    "\n",
    "import chex\n",
    "import einops\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import flax\n",
    "import flax.linen as nn\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import tensorflow_datasets as tfds\n",
    "\n",
    "chex.set_n_cpu_devices(8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "E30SS9gCIvrV"
   },
   "outputs": [],
   "source": [
    "per_device_batch_size = 512\n",
    "dataset_name = 'mnist'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "bZ-HWxKZHf6I",
    "outputId": "639262cb-b617-4561-c31f-60b33156a15f"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "DeviceArray([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
       "             [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class FakeModel(nn.Module):\n",
    "  num_classes: int\n",
    "  @nn.compact\n",
    "  def __call__(self, x):\n",
    "    return jax.nn.one_hot(jnp.zeros([len(x)], jnp.int32), self.num_classes)\n",
    "\n",
    "model = FakeModel(num_classes=10)\n",
    "vs = {}\n",
    "vs_p = flax.jax_utils.replicate(vs)\n",
    "inputs = jnp.zeros([2, 28, 28, 1])\n",
    "model.apply(vs, inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "sx65cZiiW_cq"
   },
   "source": [
    "### The problem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 121,
     "referenced_widgets": [
      "5ff2865d0c0240909b5123a16cab4847",
      "6f7316963aba4fd4b01fbe5f6c0e62f7",
      "8712c2f9196a48b19f5a362eed9b96cd",
      "e7102cf7f0e342c1a26a2042927d621c",
      "276324b8d0f24b5382ec8e646818e41b",
      "0ca8b31500144c53bf16c4af69667cbd",
      "510b002fcda944ee9d148c8ab089cb8a",
      "07c84a778f3b4be586fb0cb0946d03a0",
      "3c96dc03c61348b3b014ba98427a8431",
      "7783236eae3c48749c2bce6f8facef42",
      "a0418596a12048d7b1dc37fd34000a01"
     ]
    },
    "id": "yfGNjMBFWEUk",
    "outputId": "09f0c28b-d28e-4a7a-8afe-8797da44ad6d"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1mDownloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...\u001b[0m\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5ff2865d0c0240909b5123a16cab4847",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Dl Completed...: 0%| | 0/4 [00:00 leads to recompilation & only works on single device\n",
    "\n",
    "@jax.jit\n",
    "def get_preds(vs, inputs):\n",
    "  print('retrigger compilation', inputs.shape)\n",
    "  return model.apply(vs, inputs)\n",
    "\n",
    "ds = tfds.load(dataset_name, split='test')\n",
    "ds = ds.batch(per_device_batch_size, drop_remainder=False)\n",
    "\n",
    "correct = total = 0\n",
    "for batch in ds.as_numpy_iterator():\n",
    "  preds = get_preds(vs, batch['image'])\n",
    "  total += len(batch['label'])\n",
    "  correct += (batch['label'] == preds.argmax(axis=1)).sum()\n",
    "\n",
    "correct = correct.item()\n",
    "correct, total, correct / total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "dlJuEBcLKY94",
    "outputId": "e94cf79c-a033-4bc3-a086-75ecd8bd21f0"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "retrigger compilation (512, 28, 28, 1)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(DeviceArray(814, dtype=int32), 8192, DeviceArray(0.09936523, dtype=float32))"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# when the remainder is dropped, we can use multiple devices and avoid\n",
    "# recompilations\n",
    "# => but results are incorrect\n",
    "\n",
    "@jax.pmap\n",
    "def get_preds(vs, inputs):\n",
    "  print('retrigger compilation', inputs.shape)\n",
    "  return model.apply(vs, inputs)\n",
    "\n",
    "ds = tfds.load(dataset_name, split=tfds.split_for_jax_process('test'))\n",
    "# This `drop_remainder=True` is required so we can do a second batch level.\n",
    "ds = ds.batch(per_device_batch_size, drop_remainder=True)\n",
    "# This `drop_remainder=True` is required so we can avoid a recompilation.\n",
    "ds = ds.batch(jax.local_device_count(), drop_remainder=True)\n",
    "\n",
    "correct = total = 0\n",
    "for batch in ds.as_numpy_iterator():\n",
    "  preds = get_preds(vs_p, batch['image'])\n",
    "  total += len(batch['label'].flatten())\n",
    "  correct += (batch['label'] == preds.argmax(axis=-1)).sum()\n",
    "\n",
    "correct = correct.item()\n",
    "correct, total, correct / total"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vfu54P0pJwEH"
   },
   "source": [
    "### The solution: padding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LIkNUHsfXKCp"
   },
   "source": [
    "#### Manual implementation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "I1hg8paaasXj",
    "outputId": "2e6c611d-357e-4e51-99d8-47c24d785b11"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "retrigger compilation (512, 28, 28, 1)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(980, 10000, 0.098)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# manual padding\n",
    "# => precise & allows for data parallelism\n",
    "\n",
    "@jax.pmap\n",
    "def get_preds(vs, inputs):\n",
    "  print('retrigger compilation', inputs.shape)\n",
    "  return model.apply(vs, inputs)\n",
    "\n",
    "ds = tfds.load(dataset_name, split=tfds.split_for_jax_process('test'))\n",
    "per_host_batch_size = per_device_batch_size * jax.local_device_count()\n",
    "ds = ds.batch(per_host_batch_size, drop_remainder=False)\n",
    "\n",
    "shard = lambda x: einops.rearrange(\n",
    "    x, '(d b) ... -> d b ...', d=jax.local_device_count())\n",
    "unshard = lambda x: einops.rearrange(x, 'd b ... -> (d b) ...')\n",
    "\n",
    "correct = total = 0\n",
    "for batch in ds.as_numpy_iterator():\n",
    "  images = batch['image']\n",
    "  n = len(images)\n",
    "  padding = np.zeros([per_host_batch_size - n, *images.shape[1:]], images.dtype)\n",
    "  padded_images = np.concatenate([images, padding])\n",
    "  preds = unshard(get_preds(vs_p, shard(padded_images)))[:n]\n",
    "  total += n\n",
    "  correct += (batch['label'] == preds.argmax(axis=-1)).sum()\n",
    "\n",
    "correct = correct.item()\n",
    "correct, total, correct / total"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Wh6CymyjXQ-a"
   },
   "source": [
    "#### Using `pad_shard_unpad()`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "pQX__5DfEX9g",
    "outputId": "71017214-c4ce-4da0-8db5-9300dba79c3a"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "retrigger compilation (512, 28, 28, 1)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(980, 10000, 0.098)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# same as before, but using the pad_shard_unpad() wrapper\n",
    "# => precise & allows for data parallelism\n",
    "\n",
    "@jax.pmap\n",
    "def get_preds(vs, inputs):\n",
    "  print('retrigger compilation', inputs.shape)\n",
    "  return model.apply(vs, inputs)\n",
    "\n",
    "ds = tfds.load(dataset_name, split=tfds.split_for_jax_process('test'))\n",
    "per_host_batch_size = per_device_batch_size * jax.local_device_count()\n",
    "ds = ds.batch(per_host_batch_size, drop_remainder=False)\n",
    "\n",
    "correct = total = 0\n",
    "for batch in 
ds.as_numpy_iterator():\n", + " preds = flax.jax_utils.pad_shard_unpad(get_preds)(\n", + " vs, batch['image'], min_device_batch=per_device_batch_size)\n", + " total += len(batch['image'])\n", + " correct += (batch['label'] == preds.argmax(axis=-1)).sum()\n", + "\n", + "correct = correct.item()\n", + "correct, total, correct / total" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ptn8NQeAXbeL" + }, + "source": [ + "#### Multi-host complications" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "MjtmUUjWPV1X", + "outputId": "70ee173a-dcdf-4136-a3e0-6685c09f8198" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "retrigger compilation (512, 28, 28, 1)\n" + ] + }, + { + "data": { + "text/plain": [ + "(980, 10000, 0.098)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# infinite zero padding\n", + "\n", + "def with_infinite_padding(dataset):\n", + " \"\"\"Adds \"infinite padding\" to the dataset.\"\"\"\n", + " filler_element = tf.nest.map_structure(\n", + " lambda spec: tf.zeros(spec.shape, spec.dtype)[None], dataset.element_spec)\n", + " filler_element['mask'] = [False]\n", + " filler_dataset = tf.data.Dataset.from_tensor_slices(filler_element)\n", + " dataset = dataset.map(\n", + " lambda features: dict(mask=True, **features),\n", + " num_parallel_calls=tf.data.experimental.AUTOTUNE)\n", + " return dataset.concatenate(filler_dataset.repeat(None))\n", + "\n", + "@jax.pmap\n", + "def get_preds(vs, inputs):\n", + " print('retrigger compilation', inputs.shape)\n", + " return model.apply(vs, inputs)\n", + "\n", + "count_p = jax.pmap(\n", + " lambda mask: jax.lax.psum(mask.sum(), axis_name='batch'),\n", + " axis_name='batch',\n", + ")\n", + "count_correct_p = jax.pmap(\n", + " lambda labels, preds, mask:\n", + " jax.lax.psum((mask & (labels == preds)).sum(), axis_name='batch'),\n", + " axis_name='batch',\n", + ")\n", + "\n", + "ds = tfds.load(dataset_name, split=tfds.split_for_jax_process('test'))\n", + "ds = with_infinite_padding(ds).batch(per_device_batch_size).batch(jax.local_device_count())\n", + "\n", + "correct = total = 0\n", + "for batch in ds.as_numpy_iterator():\n", + " n = count_p(batch['mask'])[0].item() # adds sync barrier\n", + " if not n: break\n", + "\n", + " preds = get_preds(vs, batch['image']).argmax(axis=-1)\n", + " total += n\n", + " correct += count_correct_p(batch['label'], preds, batch['mask'])[0]\n", + "\n", + "correct = correct.item()\n", + "correct, total, correct / total" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "name": "flax full_eval HOWTO", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "07c84a778f3b4be586fb0cb0946d03a0": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + 
"grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "0ca8b31500144c53bf16c4af69667cbd": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "276324b8d0f24b5382ec8e646818e41b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "3c96dc03c61348b3b014ba98427a8431": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": 
"1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "510b002fcda944ee9d148c8ab089cb8a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "5ff2865d0c0240909b5123a16cab4847": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_6f7316963aba4fd4b01fbe5f6c0e62f7", + "IPY_MODEL_8712c2f9196a48b19f5a362eed9b96cd", + "IPY_MODEL_e7102cf7f0e342c1a26a2042927d621c" + ], + "layout": "IPY_MODEL_276324b8d0f24b5382ec8e646818e41b" + } + }, + "6f7316963aba4fd4b01fbe5f6c0e62f7": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_0ca8b31500144c53bf16c4af69667cbd", + "placeholder": "​", + "style": "IPY_MODEL_510b002fcda944ee9d148c8ab089cb8a", + "value": "Dl Completed...: 100%" + } + }, + "7783236eae3c48749c2bce6f8facef42": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "8712c2f9196a48b19f5a362eed9b96cd": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + 
"_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_07c84a778f3b4be586fb0cb0946d03a0", + "max": 4, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_3c96dc03c61348b3b014ba98427a8431", + "value": 4 + } + }, + "a0418596a12048d7b1dc37fd34000a01": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "e7102cf7f0e342c1a26a2042927d621c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_7783236eae3c48749c2bce6f8facef42", + "placeholder": "​", + "style": "IPY_MODEL_a0418596a12048d7b1dc37fd34000a01", + "value": " 4/4 [00:00<00:00, 10.51 file/s]" + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/flax/jax_utils.py b/flax/jax_utils.py index cc6fefba0c..f3f089cb7b 100644 --- a/flax/jax_utils.py +++ b/flax/jax_utils.py @@ -243,3 +243,71 @@ def body_wrapper(c, xs): c, ys = _scan_nd(body_wrapper, init, xs, n=len(axis), unroll=unroll) ys = jax.tree_map(transpose_out, ys) return c, ys + + +# Copied from https://github.com/google-research/big_vision +def pad_shard_unpad(wrapped, static_argnums=(0,), static_argnames=()): + """Wraps a function with code that pads, shards, then un-shards, un-pads. + + Args: + wrapped: the function to be wrapped. Signature is `params, *args, *kwargs`. + static_argnums: indices of arguments to `wrapped` that should _not_ be + padded and sharded, but instead be forwarded as-is. The default is (0,) + because by far the most common use-case is to pass `params` first. + static_argnames: names of kwargs to `wrapped` that should _not_ be padded + and sharded, but instead be forwarded as-is. + + Returns: + A new function that pads and shards its arguments before passing them to + the wrapped function, and un-shards and un-pads the returned pytree. + + This is useful for calling a pmap'ed function with inputs that aren't + divisible by the number of devices. A typical use is: + @pad_shard_unpad + @jax.pmap + def forward(params, x): ... + + Notes: + The padding is done in host-memory before being passed to the function, and + the values returned by the function are transferred back to host memory. + + The returned function is augmented with a new keyword-only argument + `min_device_batch` that, if specified, forces padding inputs to at least + this size per device. This can be useful to avoid recompiles for the last + batch and reduce memory fragmentation. 
+ + For more information refer to + https://flax.readthedocs.io/en/latest/howtos/full_eval.html + """ + + def pad_shard_unpad_wrapper(*args, min_device_batch=None, **kw): + d = jax.local_device_count() # d = devices, b = batch + batch_sizes = ( + {a.shape[0] for i, a in enumerate(args) if i not in static_argnums} | + {v.shape[0] for k, v in kw.items() if k not in static_argnames}) + assert len(batch_sizes) == 1, f"Inconsistent batch-sizes: {batch_sizes}" + b = batch_sizes.pop() + + def maybe_pad(x, actually_pad=True): + if not actually_pad: return x # For call-site convenience below. + _, *shape = x.shape + db, rest = divmod(b, d) + if rest: + x = np.concatenate([x, np.zeros((d - rest, *shape), x.dtype)], axis=0) + db += 1 + if min_device_batch and db < min_device_batch: + x = np.concatenate( + [x, np.zeros((d * (min_device_batch - db), *shape), x.dtype)]) + db = min_device_batch + return x.reshape(d, db, *shape) + + args = [maybe_pad(a, i not in static_argnums) for i, a in enumerate(args)] + kw = {k: maybe_pad(v, k not in static_argnames) for k, v in kw.items()} + out = wrapped(*args, **kw) + + def unpad(x): + # Transfer back before cutting, to reduce on-device shape diversity. + return jax.device_get(x).reshape([np.prod(x.shape[:2]), *x.shape[2:]])[:b] + return jax.tree_map(unpad, out) + + return pad_shard_unpad_wrapper diff --git a/tests/jax_utils_test.py b/tests/jax_utils_test.py new file mode 100644 index 0000000000..b173f967a7 --- /dev/null +++ b/tests/jax_utils_test.py @@ -0,0 +1,94 @@ +# Copyright 2022 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for flax.jax_utils.""" + +from functools import partial +from absl.testing import parameterized +import chex +from flax import jax_utils +import jax +import numpy as np +import tensorflow as tf + +NDEV = 4 + + +def setUpModule(): + chex.set_n_cpu_devices(NDEV) + + +class PadShardUnpadTest(chex.TestCase, tf.test.TestCase): + BATCH_SIZES = [NDEV, NDEV + 1, NDEV - 1, 5 * NDEV, 5 * NDEV + 1, 5 * NDEV - 1] + DTYPES = [np.float32, np.uint8, jax.numpy.bfloat16, np.int32] + + def tearDown(self): + chex.clear_trace_counter() + super().tearDown() + + @parameterized.product(dtype=DTYPES, bs=BATCH_SIZES) + def test_basics(self, dtype, bs): + # Just tests that basic calling works without exploring caveats. 
+ @partial(jax_utils.pad_shard_unpad, static_argnums=()) + def add(a, b): + return a + b + + x = np.arange(bs, dtype=dtype) + y = add(x, 10*x) + chex.assert_type(y.dtype, x.dtype) + np.testing.assert_allclose(np.float64(y), np.float64(x + 10*x)) + + @parameterized.parameters(DTYPES) + def test_min_device_batch_avoids_recompile(self, dtype): + @partial(jax_utils.pad_shard_unpad, static_argnums=()) + @jax.jit + @chex.assert_max_traces(n=1) + def add(a, b): + return a + b + + chex.clear_trace_counter() + + for bs in self.BATCH_SIZES: + x = np.arange(bs, dtype=dtype) + y = add(x, 10*x, min_device_batch=9) # pylint: disable=unexpected-keyword-arg + chex.assert_type(y.dtype, x.dtype) + np.testing.assert_allclose(np.float64(y), np.float64(x + 10*x)) + + @parameterized.product(dtype=DTYPES, bs=BATCH_SIZES) + def test_static_argnum(self, dtype, bs): + @partial(jax_utils.pad_shard_unpad, static_argnums=(1,)) + def add(a, b): + return a + b + + x = np.arange(bs, dtype=dtype) + y = add(x, 10) + chex.assert_type(y.dtype, x.dtype) + np.testing.assert_allclose(np.float64(y), np.float64(x + 10)) + + @parameterized.product(dtype=DTYPES, bs=BATCH_SIZES) + def test_static_argnames(self, dtype, bs): + # In this test, leave static_argnums at the default value too, in order to + # test the default/most canonical path where `params` are the first arg. + @partial(jax_utils.pad_shard_unpad, static_argnames=('b',)) + def add(params, a, *, b): + return params * a + b + + x = np.arange(bs, dtype=dtype) + y = add(5, x, b=10) + chex.assert_type(y.dtype, x.dtype) + np.testing.assert_allclose(np.float64(y), np.float64(5 * x + 10)) + + +if __name__ == '__main__': + tf.test.main() \ No newline at end of file