Skip to content

Commit

Permalink
Applied review suggestsions
Browse files Browse the repository at this point in the history
  • Loading branch information
gnecula committed Oct 8, 2021
1 parent f03baae commit 3938018
Showing 1 changed file with 6 additions and 14 deletions.
20 changes: 6 additions & 14 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2721,20 +2721,12 @@ def checkpoint(fun: Callable, concrete: bool = False, prevent_cse: bool = True,
At that time, the value ``jnp.sin(2.0)`` is recomputed, along with the values
``jnp.cos(2.0)`` and ``jnp.cos(jnp.sin(2.0))``.
Note that in some cases the XLA compiler may do its optimization that affect
the meory usage, e.g., common-subexpression optimization, fusion, or
even rematerialization. For example, if the code uses only element-wise
operations, like in our example, XLA may fuse both the forward and backward
pass, or may rematerialize aggressively, resulting in low memory usage even
without ``jax.checkpoint``. In fact, in some such cases ``jax.checkpoint`` may
hinder the compiler and you may see larger memory usage. For complex examples,
the effect of ``jax.checkpoint`` is likely to be more significant than the
XLA optimizations.
The best way to use ``jax.checkpoint`` is to experiment with its placement
on sub-computations. For example, ``lambda x: f(jax.checkpoint(g)(x))`` is
likely to have lower memory usage under ``jax.grad`` than when not using
``jax.checkpoint``, or than when using it on ``f`` or the whole function.
While ``jax.checkpoint`` controls what values are stored from the forward-pass
to be used on the backward pass, the total amount of memory required to
evaluate a function or its VJP depends on many additional internal details of
that function. Those details include which numerical primitives are used,
how they're composed, where jit and control flow primitives like scan
are used, and other factors.
The :func:`jax.checkpoint` decorator can be applied recursively to express
sophisticated autodiff rematerialization strategies. For example:
Expand Down

0 comments on commit 3938018

Please sign in to comment.