From 66ba734882daf8206e41c96495a9ea3c53082bb1 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Wed, 3 Jun 2020 09:46:00 -0700 Subject: [PATCH] Add note to docs describing how pytree arguments work. (#3284) Addresses #3095. I'm not sure if we wanna link to this from API docstrings. This also subsumes the original pytrees notebook. --- docs/index.rst | 2 +- docs/notebooks/JAX_pytrees.ipynb | 329 ------------------------------- docs/pytrees.rst | 251 +++++++++++++++++++++++ 3 files changed, 252 insertions(+), 330 deletions(-) delete mode 100644 docs/notebooks/JAX_pytrees.ipynb create mode 100644 docs/pytrees.rst diff --git a/docs/index.rst b/docs/index.rst index ab46511e59e8..b7e9725edabe 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -23,7 +23,6 @@ For an introduction to JAX, start at the notebooks/Common_Gotchas_in_JAX notebooks/Custom_derivative_rules_for_Python_code - notebooks/JAX_pytrees notebooks/XLA_in_Python notebooks/How_JAX_primitives_work notebooks/Writing_custom_interpreters_in_Jax.ipynb @@ -39,6 +38,7 @@ For an introduction to JAX, start at the concurrency gpu_memory_allocation profiling + pytrees rank_promotion_warning type_promotion diff --git a/docs/notebooks/JAX_pytrees.ipynb b/docs/notebooks/JAX_pytrees.ipynb deleted file mode 100644 index 2f6128bc4f99..000000000000 --- a/docs/notebooks/JAX_pytrees.ipynb +++ /dev/null @@ -1,329 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "JAX_pytrees.ipynb", - "provenance": [], - "collapsed_sections": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - } - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "g_vouXWulcNh", - "colab_type": "text" - }, - "source": [ - "# JAX pytrees\n", - "\n", - "Date: October 2019" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "lu00ShwgaPEW", - "colab_type": "text" - }, - "source": [ - "**This is primarily JAX internal documentation, end-users are not supposed 
to need to understand this to use JAX, except when registering new user-defined container types with JAX. Some of these details may change.**\n", - "\n", - "Python has a lot of container data types (list, tuple, dict, namedtuple, etc.), and users sometimes define their own. To keep the JAX internals simpler while supporting lots of container types, we canonicalize nested containers into flat lists of numeric or array types at the `api.py` boundary (and also in control flow primitives). That way `grad`, `jit`, `vmap` etc., can handle user functions that accept and return these containers, while all the other parts of the system can operate on functions that only take (multiple) array arguments and always return a flat list of arrays. \n", - "\n", - "We refer to a recursive structured value whose leaves are basic types as a `pytree`. When JAX flattens a pytree it will produce a list of leaves and a `treedef` object that encodes the structure of the original value. The `treedef` can then be used to construct a matching structured value after transforming the leaves. Pytrees are tree-like, rather than DAG-like or graph-like, in that we handle them assuming referential transparency and that they can't contain reference cycles. 
\n", - "\n", - "Here is a simple example:\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "X8DlAmOMmufl", - "colab_type": "code", - "outputId": "f5069593-b36e-4f2d-b8f0-7642e7034bbd", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 85 - } - }, - "source": [ - "from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node\n", - "from jax import numpy as np\n", - "\n", - "# The structured value to be transformed\n", - "value_structured = [1., (2., 3.)]\n", - "\n", - "# The leaves in value_flat correspond to the `*` markers in value_tree\n", - "value_flat, value_tree = tree_flatten(value_structured)\n", - "print(\"value_flat={}\\nvalue_tree={}\".format(value_flat, value_tree))\n", - "\n", - "# Transform the flt value list using an element-wise numeric transformer\n", - "transformed_flat = list(map(lambda v: v * 2., value_flat))\n", - "print(\"transformed_flat={}\".format(transformed_flat))\n", - "\n", - "# Reconstruct the structured output, using the original \n", - "transformed_structured = tree_unflatten(value_tree, transformed_flat)\n", - "print(\"transformed_structured={}\".format(transformed_structured))" - ], - "execution_count": 1, - "outputs": [ - { - "output_type": "stream", - "text": [ - "value_flat=[1.0, 2.0, 3.0]\n", - "value_tree=PyTreeDef(list, [*,PyTreeDef(tuple, [*,*])])\n", - "transformed_flat=[2.0, 4.0, 6.0]\n", - "transformed_structured=[2.0, (4.0, 6.0)]\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "sgUJpiXSsRSi", - "colab_type": "text" - }, - "source": [ - "Pytrees containers can be lists, tuples, dicts, namedtuple, None, OrderedDict. 
Other types of values, including numeric and ndarray values, are treated as leaves:" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "ViXja8YxsXZC", - "colab_type": "code", - "outputId": "ff8120b2-f1fc-4647-9e0d-c35ee87bdd2e", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 459 - } - }, - "source": [ - "from collections import namedtuple\n", - "Point = namedtuple('Point', ['x', 'y'])\n", - "\n", - "example_containers = [\n", - " (1., [2., 3.]),\n", - " (1., {'b': 2., 'a': 3.}),\n", - " 1.,\n", - " None,\n", - " np.zeros(2),\n", - " Point(1., 2.)\n", - "]\n", - "def show_example(structured):\n", - " flat, tree = tree_flatten(structured)\n", - " unflattened = tree_unflatten(tree, flat)\n", - " print(\"structured={}\\n flat={}\\n tree={}\\n unflattened={}\".format(\n", - " structured, flat, tree, unflattened))\n", - " \n", - "for structured in example_containers:\n", - " show_example(structured)\n", - " " - ], - "execution_count": 2, - "outputs": [ - { - "output_type": "stream", - "text": [ - "structured=(1.0, [2.0, 3.0])\n", - " flat=[1.0, 2.0, 3.0]\n", - " tree=PyTreeDef(tuple, [*,PyTreeDef(list, [*,*])])\n", - " unflattened=(1.0, [2.0, 3.0])\n", - "structured=(1.0, {'b': 2.0, 'a': 3.0})\n", - " flat=[1.0, 3.0, 2.0]\n", - " tree=PyTreeDef(tuple, [*,PyTreeDef(dict[['a', 'b']], [*,*])])\n", - " unflattened=(1.0, {'a': 3.0, 'b': 2.0})\n", - "structured=1.0\n", - " flat=[1.0]\n", - " tree=*\n", - " unflattened=1.0\n", - "structured=None\n", - " flat=[]\n", - " tree=PyTreeDef(None, [])\n", - " unflattened=None\n" - ], - "name": "stdout" - }, - { - "output_type": "stream", - "text": [ - "structured=[0. 0.]\n", - " flat=[_FilledConstant([0., 0.], dtype=float32)]\n", - " tree=*\n", - " unflattened=[0. 
0.]\n", - "structured=Point(x=1.0, y=2.0)\n", - " flat=[1.0, 2.0]\n", - " tree=PyTreeDef(namedtuple[], [*,*])\n", - " unflattened=Point(x=1.0, y=2.0)\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "f5iYkKRx2ILR", - "colab_type": "text" - }, - "source": [ - "## Pytrees are extensible" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Cb35Y8vBtVKp", - "colab_type": "text" - }, - "source": [ - "By default, any part of a structured value that is not recognized as an internal pytree node is treated as a leaf (and such containers could not be passed to JAX-traceable functions):\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "17JyT_arth7P", - "colab_type": "code", - "outputId": "142485b6-c7fb-4bfd-b9a3-c01e8be7e9e4", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 85 - } - }, - "source": [ - "class Special(object):\n", - " def __init__(self, x, y):\n", - " self.x = x\n", - " self.y = y\n", - " \n", - " def __repr__(self):\n", - " return \"Special(x={}, y={})\".format(self.x, self.y)\n", - " \n", - "\n", - "show_example(Special(1., 2.))" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "structured=Special(x=1.0, y=2.0)\n", - " flat=[Special(x=1.0, y=2.0)]\n", - " tree=*\n", - " unflattened=Special(x=1.0, y=2.0)\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3y9NECzRthKi", - "colab_type": "text" - }, - "source": [ - "The set of Python types that are considered internal pytree nodes is extensible, through a global registry of types. 
Values of registered types\n", - "are traversed recursively:\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "3Emk3EN5uPMr", - "colab_type": "code", - "outputId": "4b5b3ff6-6b80-424c-a6c0-8da97b943a7b", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 85 - } - }, - "source": [ - "class RegisteredSpecial(Special):\n", - " def __repr__(self):\n", - " return \"RegisteredSpecial(x={}, y={})\".format(self.x, self.y)\n", - "\n", - "def special_flatten(v):\n", - " \"\"\"Specifies a flattening recipe.\n", - " \n", - " Params:\n", - " v: the value of registered type to flatten.\n", - " Returns: \n", - " a pair of an iterable with the children to be flattened recursively,\n", - " and some opaque auxiliary data to pass back to the unflattening recipe.\n", - " The auxiliary data is stored in the treedef for use during unflattening.\n", - " The auxiliary data could be used, e.g., for dictionary keys.\n", - " \"\"\"\n", - " children = (v.x, v.y)\n", - " aux_data = None\n", - " return (children, aux_data)\n", - "\n", - "def special_unflatten(aux_data, children):\n", - " \"\"\"Specifies an unflattening recipe.\n", - " \n", - " Params:\n", - " aux_data: the opaque data that was specified during flattening of the \n", - " current treedef.\n", - " children: the unflattened children\n", - " \n", - " Returns:\n", - " a re-constructed object of the registered type, using the specified \n", - " children and auxiliary data.\n", - " \"\"\"\n", - " return RegisteredSpecial(*children)\n", - "\n", - "# Global registration\n", - "register_pytree_node(\n", - " RegisteredSpecial,\n", - " special_flatten, # tell JAX what are the children nodes\n", - " special_unflatten # tell JAX how to pack back into a RegisteredSpecial\n", - ")\n", - "\n", - "show_example(RegisteredSpecial(1., 2.))" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "structured=RegisteredSpecial(x=1.0, y=2.0)\n", - " flat=[1.0, 2.0]\n", - " 
tree=PyTreeDef([None], [*,*])\n", - " unflattened=RegisteredSpecial(x=1.0, y=2.0)\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TIBrH5KaxImR", - "colab_type": "text" - }, - "source": [ - "JAX needs sometimes to compare treedef for equality. Therefore care must be taken to ensure that the auxiliary data specified in the flattening recipe supports a meaningful equality comparison. \n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Qoi69-I64_qe", - "colab_type": "text" - }, - "source": [ - "The whole set of functions for operating on pytrees are in the [tree_util module](https://jax.readthedocs.io/en/latest/jax.tree_util.html).\n" - ] - } - ] -} \ No newline at end of file diff --git a/docs/pytrees.rst b/docs/pytrees.rst new file mode 100644 index 000000000000..c290bc688e9f --- /dev/null +++ b/docs/pytrees.rst @@ -0,0 +1,251 @@ +Pytrees +======== + +What is a pytree? +^^^^^^^^^^^^^^^^^ + +In JAX, a pytree is **a container of leaf elements and/or more pytrees**. +Containers include lists, tuples, and dicts (JAX can be extended to consider +other container types as pytrees, see `Extending pytrees`_ below). A leaf +element is anything that's not a pytree, e.g. an array. In other words, a pytree +is just **a possibly-nested standard or user-registered Python container**. If +nested, note that the container types do not need to match. A single "leaf", +i.e. a non-container object, is also considered a pytree. + +Example pytrees:: + + [1, "a", object()] # 3 leaves + + (1, (2, 3), ()) # 3 leaves + + [1, {"k1": 2, "k2": (3, 4)}, 5] # 5 leaves + +Pytrees and JAX functions +^^^^^^^^^^^^^^^^^^^^^^^^^ + +Many JAX functions, including all function transformations, operate over pytrees +of arrays (other leaf types are sometimes allowed as well). 
Transformations are +only applied to the leaf arrays while preserving the original pytree structure; +for example, ``vmap`` and ``pmap`` only map over arrays, but automatically map +over arrays inside of standard Python sequences, and can return mapped Python +sequences. + +Applying optional parameters to pytrees +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Some JAX function transformations take optional parameters that specify how +certain input or output values should be treated (e.g. the ``in_axes`` and +``out_axes`` arguments to ``vmap``). These parameters are also pytrees, and the +leaf values are "matched up" with the corresponding input or output leaf arrays. +For example, if we pass the following input to ``vmap`` (note that the input +arguments to a function are considered a tuple):: + + (a1, {"k1": a2, "k2": a3}) + +We can use the following ``in_axes`` pytree to specify that only the "k2" +argument is mapped (axis=0) and the rest aren't mapped over (axis=None):: + + (None, {"k1": None, "k2": 0}) + +Note that the optional parameter pytree structure must match that of the main +input pytree. However, the optional parameters can optionally be specified as a +"prefix" pytree, meaning that a single leaf value can be applied to an entire +sub-pytree. For example, if we have the same ``vmap`` input as above, but wish +to only map over the dictionary argument, we can use:: + + (None, 0) # equivalent to (None, {"k1": 0, "k2": 0}) + +Or, if we want every argument to be mapped, we can simply write a single leaf value +that is applied over the entire argument tuple pytree:: + + 0 + +This happens to be the default ``in_axes`` value! + +The same logic applies to other optional parameters that refer to specific input +or output values of a transformed function, e.g. ``vmap``'s ``out_axes`` and +``pmap``'s ``in_axes``. 
+ + +Developer information +^^^^^^^^^^^^^^^^^^^^^^ + +*This is primarily JAX internal documentation; end-users are not supposed to need +to understand this to use JAX, except when registering new user-defined +container types with JAX. Some of these details may change.* + +Internal pytree handling +------------------------ + +JAX canonicalizes pytrees into flat lists of numeric or array types at the +``api.py`` boundary (and also in control flow primitives). This keeps downstream +JAX internals simpler: ``vmap`` etc. can handle user functions that accept and +return Python containers, while all the other parts of the system can operate on +functions that only take (multiple) array arguments and always return a flat +list of arrays. + +When JAX flattens a pytree it will produce a list of leaves and a ``treedef`` +object that encodes the structure of the original value. The ``treedef`` can then +be used to construct a matching structured value after transforming the +leaves. Pytrees are tree-like, rather than DAG-like or graph-like, in that we +handle them assuming referential transparency and that they can't contain +reference cycles. 
+ +Here is a simple example:: + + from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node + from jax import numpy as np + + # The structured value to be transformed + value_structured = [1., (2., 3.)] + + # The leaves in value_flat correspond to the `*` markers in value_tree + value_flat, value_tree = tree_flatten(value_structured) + print("value_flat={}\nvalue_tree={}".format(value_flat, value_tree)) + + # Transform the flat value list using an element-wise numeric transformer + transformed_flat = list(map(lambda v: v * 2., value_flat)) + print("transformed_flat={}".format(transformed_flat)) + + # Reconstruct the structured output, using the original treedef + transformed_structured = tree_unflatten(value_tree, transformed_flat) + print("transformed_structured={}".format(transformed_structured)) + + # Output: + # value_flat=[1.0, 2.0, 3.0] + # value_tree=PyTreeDef(list, [*,PyTreeDef(tuple, [*,*])]) + # transformed_flat=[2.0, 4.0, 6.0] + # transformed_structured=[2.0, (4.0, 6.0)] + +By default, pytree containers can be lists, tuples, dicts, namedtuple, None, +OrderedDict. Other types of values, including numeric and ndarray values, are +treated as leaves:: + + from collections import namedtuple + Point = namedtuple('Point', ['x', 'y']) + + example_containers = [ + (1., [2., 3.]), + (1., {'b': 2., 'a': 3.}), + 1., + None, + np.zeros(2), + Point(1., 2.) 
+ ] + def show_example(structured): + flat, tree = tree_flatten(structured) + unflattened = tree_unflatten(tree, flat) + print("structured={}\n flat={}\n tree={}\n unflattened={}".format( + structured, flat, tree, unflattened)) + + for structured in example_containers: + show_example(structured) + + # Output: + # structured=(1.0, [2.0, 3.0]) + # flat=[1.0, 2.0, 3.0] + # tree=PyTreeDef(tuple, [*,PyTreeDef(list, [*,*])]) + # unflattened=(1.0, [2.0, 3.0]) + # structured=(1.0, {'b': 2.0, 'a': 3.0}) + # flat=[1.0, 3.0, 2.0] + # tree=PyTreeDef(tuple, [*,PyTreeDef(dict[['a', 'b']], [*,*])]) + # unflattened=(1.0, {'a': 3.0, 'b': 2.0}) + # structured=1.0 + # flat=[1.0] + # tree=* + # unflattened=1.0 + # structured=None + # flat=[] + # tree=PyTreeDef(None, []) + # unflattened=None + # structured=[0. 0.] + # flat=[DeviceArray([0., 0.], dtype=float32)] + # tree=* + # unflattened=[0. 0.] + # structured=Point(x=1.0, y=2.0) + # flat=[1.0, 2.0] + # tree=PyTreeDef(namedtuple[], [*,*]) + # unflattened=Point(x=1.0, y=2.0) + +Extending pytrees +----------------- + +By default, any part of a structured value that is not recognized as an internal +pytree node is treated as a leaf (and such containers could not be passed to +JAX-traceable functions):: + + class Special(object): + def __init__(self, x, y): + self.x = x + self.y = y + + def __repr__(self): + return "Special(x={}, y={})".format(self.x, self.y) + + + show_example(Special(1., 2.)) + + # Output: + # structured=Special(x=1.0, y=2.0) + # flat=[Special(x=1.0, y=2.0)] + # tree=* + # unflattened=Special(x=1.0, y=2.0) + +The set of Python types that are considered internal pytree nodes is extensible, +through a global registry of types. Values of registered types are traversed +recursively:: + + class RegisteredSpecial(Special): + def __repr__(self): + return "RegisteredSpecial(x={}, y={})".format(self.x, self.y) + + def special_flatten(v): + """Specifies a flattening recipe. + + Params: + v: the value of registered type to flatten. 
+ Returns: + a pair of an iterable with the children to be flattened recursively, + and some opaque auxiliary data to pass back to the unflattening recipe. + The auxiliary data is stored in the treedef for use during unflattening. + The auxiliary data could be used, e.g., for dictionary keys. + """ + children = (v.x, v.y) + aux_data = None + return (children, aux_data) + + def special_unflatten(aux_data, children): + """Specifies an unflattening recipe. + + Params: + aux_data: the opaque data that was specified during flattening of the + current treedef. + children: the unflattened children + + Returns: + a re-constructed object of the registered type, using the specified + children and auxiliary data. + """ + return RegisteredSpecial(*children) + + # Global registration + register_pytree_node( + RegisteredSpecial, + special_flatten, # tell JAX what are the children nodes + special_unflatten # tell JAX how to pack back into a RegisteredSpecial + ) + + show_example(RegisteredSpecial(1., 2.)) + + # Output: + # structured=RegisteredSpecial(x=1.0, y=2.0) + # flat=[1.0, 2.0] + # tree=PyTreeDef([None], [*,*]) + # unflattened=RegisteredSpecial(x=1.0, y=2.0) + +JAX sometimes needs to compare treedefs for equality. Therefore care must be +taken to ensure that the auxiliary data specified in the flattening recipe +supports a meaningful equality comparison. + +The whole set of functions for operating on pytrees is in the `tree_util module +<https://jax.readthedocs.io/en/latest/jax.tree_util.html>`_.