Skip to content

Commit

Permalink
tweak jit docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Apr 20, 2024
1 parent 52f5f70 commit b8df23c
Showing 1 changed file with 50 additions and 72 deletions.
122 changes: 50 additions & 72 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,60 +157,37 @@ def jit(
"""Sets up ``fun`` for just-in-time compilation with XLA.
Args:
fun: Function to be jitted. ``fun`` should be a pure function, as
side-effects may only be executed once.
The arguments and return value of ``fun`` should be arrays,
scalars, or (nested) standard Python containers (tuple/list/dict) thereof.
Positional arguments indicated by ``static_argnums`` can be anything at
all, provided they are hashable and have an equality operation defined.
Static arguments are included as part of a compilation cache key, which is
why hash and equality operators must be defined.
JAX keeps a weak reference to ``fun`` for use as a compilation cache key,
so the object ``fun`` must be weakly-referenceable. Most :class:`Callable`
objects will already satisfy this requirement.
in_shardings: Pytree of structure matching that of arguments to ``fun``,
with all actual arguments replaced by resource assignment specifications.
It is also valid to specify a pytree prefix (e.g. one value in place of a
whole subtree), in which case the leaves get broadcast to all values in
that subtree.
The ``in_shardings`` argument is optional. JAX will infer the shardings
from the input :py:class:`jax.Array`'s and defaults to replicating the input
if the sharding cannot be inferred.
The valid resource assignment specifications are:
- :py:class:`XLACompatibleSharding`, which will decide how the value
will be partitioned. With this, using a mesh context manager is not
required.
- :py:obj:`None`, will give JAX the freedom to choose whatever sharding
it wants.
For in_shardings, JAX will mark is as replicated but this behavior
can change in the future.
For out_shardings, we will rely on the XLA GSPMD partitioner to
determine the output shardings.
The size of every dimension has to be a multiple of the total number of
resources assigned to it. This is similar to pjit's in_shardings.
out_shardings: Like ``in_shardings``, but specifies resource
assignment for function outputs. This is similar to pjit's
out_shardings.
The ``out_shardings`` argument is optional. If not specified, :py:func:`jax.jit`
will use GSPMD's sharding propagation to figure out what the sharding of the
output(s) should be.
static_argnums: An optional int or collection of ints that specify which
positional arguments to treat as static (compile-time constant).
Operations that only depend on static arguments will be constant-folded in
Python (during tracing), and so the corresponding argument values can be
any Python object.
fun: Function to be jitted. ``fun`` should be a pure function.
The arguments and return value of ``fun`` should be arrays, scalar, or
(nested) standard Python containers (tuple/list/dict) thereof. Positional
arguments indicated by ``static_argnums`` can be any hashable type. Static
arguments are included as part of a compilation cache key, which is why
hash and equality operators must be defined. JAX keeps a weak reference to
``fun`` for use as a compilation cache key, so the object ``fun`` must be
weakly-referenceable.
in_shardings: optional, a :py:class:`Sharding` or pytree with
:py:class:`Sharding` leaves and structure that is a tree prefix of the
positional arguments tuple to ``fun``. If provided, the positional
arguments passed to ``fun`` must have shardings that are compatible with
``in_shardings`` or an error is raised, and the compiled computation has
input shardings corresponding to ``in_shardings``. If not provided, the
compiled computation's input shardings are inferred from argument
sharings.
out_shardings: optional, a :py:class:`Sharding` or pytree with
:py:class:`Sharding` leaves and structure that is a tree prefix of the
output of ``fun``. If provided, it has the same effect as applying
corresponding :py:func:`jax.lax.with_sharding_constraint`s to the output
of ``fun``.
static_argnums: optional, an int or collection of ints that specify which
positional arguments to treat as static (trace- and compile-time
constant).
Static arguments should be hashable, meaning both ``__hash__`` and
``__eq__`` are implemented, and immutable. Calling the jitted function
with different values for these constants will trigger recompilation.
Arguments that are not arrays or containers thereof must be marked as
static.
``__eq__`` are implemented, and immutable. Otherwise they can be arbitrary
Python objects. Calling the jitted function with different values for
these constants will trigger recompilation. Arguments that are not
array-like or containers thereof must be marked as static.
If neither ``static_argnums`` nor ``static_argnames`` is provided, no
arguments are treated as static. If ``static_argnums`` is not provided but
Expand All @@ -221,17 +198,18 @@ def jit(
provided, ``inspect.signature`` is not used, and only actual
parameters listed in either ``static_argnums`` or ``static_argnames`` will
be treated as static.
static_argnames: An optional string or collection of strings specifying
static_argnames: optional, a string or collection of strings specifying
which named arguments to treat as static (compile-time constant). See the
comment on ``static_argnums`` for details. If not
provided but ``static_argnums`` is set, the default is based on calling
``inspect.signature(fun)`` to find corresponding named arguments.
donate_argnums: Specify which positional argument buffers are "donated" to
the computation. It is safe to donate argument buffers if you no longer
need them once the computation has finished. In some cases XLA can make
use of donated buffers to reduce the amount of memory needed to perform a
donate_argnums: optional, collection of integers to specify which positional
argument buffers can be overwritten by the computation and marked deleted
in the caller. It is safe to donate argument buffers if you no longer need
them once the computation has started. In some cases XLA can make use of
donated buffers to reduce the amount of memory needed to perform a
computation, for example recycling one of your input buffers to store a
result. You should not reuse buffers that you donate to a computation, JAX
result. You should not reuse buffers that you donate to a computation; JAX
will raise an error if you try to. By default, no argument buffers are
donated.
Expand All @@ -247,15 +225,16 @@ def jit(
For more details on buffer donation see the
`FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
donate_argnames: An optional string or collection of strings specifying
donate_argnames: optional, a string or collection of strings specifying
which named arguments are donated to the computation. See the
comment on ``donate_argnums`` for details. If not
provided but ``donate_argnums`` is set, the default is based on calling
``inspect.signature(fun)`` to find corresponding named arguments.
keep_unused: If `False` (the default), arguments that JAX determines to be
unused by `fun` *may* be dropped from resulting compiled XLA executables.
Such arguments will not be transferred to the device nor provided to the
underlying executable. If `True`, unused arguments will not be pruned.
keep_unused: optional boolean. If `False` (the default), arguments that JAX
determines to be unused by `fun` *may* be dropped from resulting compiled
XLA executables. Such arguments will not be transferred to the device nor
provided to the underlying executable. If `True`, unused arguments will
not be pruned.
device: This is an experimental feature and the API is likely to change.
Optional, the Device the jitted function will run on. (Available devices
can be retrieved via :py:func:`jax.devices`.) The default is inherited
Expand All @@ -264,9 +243,8 @@ def jit(
backend: This is an experimental feature and the API is likely to change.
Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or
``'tpu'``.
inline: Specify whether this function should be inlined into enclosing
jaxprs (rather than being represented as an application of the xla_call
primitive with its own subjaxpr). Default False.
inline: Optional boolean. Specify whether this function should be inlined
into enclosing jaxprs. Default False.
Returns:
A wrapped version of ``fun``, set up for just-in-time compilation.
Expand All @@ -287,8 +265,8 @@ def jit(
[-0.54485 0.27744 -0.29255 -0.91421 -0.62452 -0.24748
-0.85743 -0.78232 0.76827 0.59566 ]
To pass arguments such as ``static_argnames`` when decorating a function, a common
pattern is to use :func:`functools.partial`:
To pass arguments such as ``static_argnames`` when decorating a function, a
common pattern is to use :func:`functools.partial`:
>>> from functools import partial
>>>
Expand Down Expand Up @@ -2470,10 +2448,10 @@ def device_put(
Args:
x: An array, scalar, or (nested) standard Python container thereof.
device: The (optional) :py:class:`Device`, `Sharding`, or a (nested)
`Sharding` in standard Python container (must be a tree prefix of ``x``),
representing the device(s) to which ``x`` should be transferred. If
given, then the result is committed to the device(s).
device: The (optional) :py:class:`Device`, :py:class:`Sharding`, or a
(nested) :py:class:`Sharding` in standard Python container (must be a tree
prefix of ``x``), representing the device(s) to which ``x`` should be
transferred. If given, then the result is committed to the device(s).
Returns:
A copy of ``x`` that resides on ``device``.
Expand Down

0 comments on commit b8df23c

Please sign in to comment.