``jax.checkpoint()`` (aliased as ``jax.remat()``) lets you control which intermediate values are saved on the forward pass and which are recomputed on the backward pass.

Without it, the forward pass of ``jax.grad(f)(x)`` stores Jacobian coefficients and other intermediate values, called residuals, for use during the backward pass.

In [1]:
import jax 
import jax.ad_checkpoint
import jax.numpy as jnp 

def g(W, x):
    y = jnp.dot(W, x)
    return jnp.sin(y)

def f(W1, W2, W3, x):
    x = g(W1, x)
    x = g(W2, x)
    x = g(W3, x)
    return x 

W1 = jnp.ones((5,4))
W2 = jnp.ones((6, 5))
W3 = jnp.ones((7, 6))
x = jnp.ones(4)

# Inspect the residual values that would be saved on the forward pass
# when differentiating f with respect to (W1, W2, W3, x).

from jax.ad_checkpoint import print_saved_residuals
print_saved_residuals(f, W1, W2, W3, x)

f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of sin from C:\Users\konar\AppData\Local\Temp\ipykernel_19492\2016089115.py:7 (g)
f32[5] output of cos from C:\Users\konar\AppData\Local\Temp\ipykernel_19492\2016089115.py:7 (g)
f32[6] output of sin from C:\Users\konar\AppData\Local\Temp\ipykernel_19492\2016089115.py:7 (g)
f32[6] output of cos from C:\Users\konar\AppData\Local\Temp\ipykernel_19492\2016089115.py:7 (g)
f32[7] output of cos from C:\Users\konar\AppData\Local\Temp\ipykernel_19492\2016089115.py:7 (g)
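The ``cos`` entries above are the Jacobian coefficients of ``sin``: the derivative of ``sin(y)`` is ``cos(y)``, so once it is available from the forward pass, the backward pass only has to multiply by it. A minimal check of that relationship with ``jax.vjp`` on a single ``sin`` (the input values are arbitrary):

```python
import jax
import jax.numpy as jnp

y = jnp.array([0.0, 0.5, 1.0, 2.0])

# Forward pass: evaluates sin(y) and returns a VJP function that closes over
# whatever residuals the backward pass will need.
out, sin_vjp = jax.vjp(jnp.sin, y)

# Backward pass: pull a cotangent of ones back through sin.
cotangent = jnp.ones_like(y)
(y_bar,) = sin_vjp(cotangent)

# The pulled-back cotangent is cos(y) * cotangent, i.e. the Jacobian
# coefficient of sin applied elementwise.
print(jnp.allclose(y_bar, jnp.cos(y) * cotangent))
```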


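By applying ``jax.checkpoint()`` to sub-functions, either as a decorator or at specific call sites, you force JAX not to save any of that sub-function's residuals. Only the inputs of each checkpointed call may be saved, and any residuals needed on the backward pass are recomputed from those inputs.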

In [2]:
def f2(W1, W2, W3, x):
    x = jax.checkpoint(g)(W1, x)
    x = jax.checkpoint(g)(W2, x)
    x = jax.checkpoint(g)(W3, x)
    return x 
jax.ad_checkpoint.print_saved_residuals(f2, W1, W2, W3, x)

f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of sin from C:\Users\konar\AppData\Local\Temp\ipykernel_19492\2016089115.py:7 (g)
f32[6] output of sin from C:\Users\konar\AppData\Local\Temp\ipykernel_19492\2016089115.py:7 (g)
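The decorator form works the same way. A minimal sketch, where ``g_remat``, ``f_decorated`` and ``f_plain`` are hypothetical names introduced for illustration; rematerialization changes what is stored, not what is computed, so the gradients match those of the uncheckpointed function:

```python
import jax
import jax.numpy as jnp

# Decorator form: every call to g_remat recomputes its internals on the
# backward pass instead of saving them.
@jax.checkpoint
def g_remat(W, x):
    return jnp.sin(jnp.dot(W, x))

def f_decorated(W1, W2, W3, x):
    x = g_remat(W1, x)
    x = g_remat(W2, x)
    x = g_remat(W3, x)
    return x

# Reference version without any checkpointing.
def f_plain(W1, W2, W3, x):
    for W in (W1, W2, W3):
        x = jnp.sin(jnp.dot(W, x))
    return x

W1 = jnp.ones((5, 4))
W2 = jnp.ones((6, 5))
W3 = jnp.ones((7, 6))
x = jnp.ones(4)

# jax.grad needs a scalar output, so differentiate the sum of the outputs.
argnums = (0, 1, 2, 3)
grads_plain = jax.grad(lambda *a: jnp.sum(f_plain(*a)), argnums=argnums)(W1, W2, W3, x)
grads_remat = jax.grad(lambda *a: jnp.sum(f_decorated(*a)), argnums=argnums)(W1, W2, W3, x)

print(all(jnp.allclose(a, b) for a, b in zip(grads_plain, grads_remat)))
```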


To control which values are saveable without editing the definition of the function being differentiated, you can use a __rematerialization policy__.

The policy below saves only the results of ``dot`` operations with __no batch dimensions__; everything else is recomputed on the backward pass.

In [3]:
f3 = jax.checkpoint(f, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
jax.ad_checkpoint.print_saved_residuals(f3, W1, W2, W3, x)

f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of dot_general from C:\Users\konar\AppData\Local\Temp\ipykernel_19492\2016089115.py:6 (g)
f32[6] output of dot_general from C:\Users\konar\AppData\Local\Temp\ipykernel_19492\2016089115.py:6 (g)
f32[7] output of dot_general from C:\Users\konar\AppData\Local\Temp\ipykernel_19492\2016089115.py:6 (g)
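``jax.checkpoint_policies`` also provides the two extremes: ``everything_saveable`` (save all residuals, which matches not using ``jax.checkpoint`` at all) and ``nothing_saveable`` (save nothing beyond the inputs and recompute the rest). A sketch that captures what ``print_saved_residuals`` reports for each and confirms that the gradients agree; it repeats the ``g``/``f`` setup from above so it runs on its own:

```python
import contextlib
import io

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import print_saved_residuals

def g(W, x):
    return jnp.sin(jnp.dot(W, x))

def f(W1, W2, W3, x):
    for W in (W1, W2, W3):
        x = g(W, x)
    return x

W1, W2, W3 = jnp.ones((5, 4)), jnp.ones((6, 5)), jnp.ones((7, 6))
x = jnp.ones(4)
args = (W1, W2, W3, x)

policies = {
    "everything_saveable": jax.checkpoint_policies.everything_saveable,
    "nothing_saveable": jax.checkpoint_policies.nothing_saveable,
}

reports = {}
for name, policy in policies.items():
    f_policy = jax.checkpoint(f, policy=policy)
    # Capture the printed report so the two policies can be compared.
    buf = io.StringIO()
    with contextlib.redirect_stdout(buf):
        print_saved_residuals(f_policy, *args)
    reports[name] = buf.getvalue().splitlines()
    print(f"{name}: {len(reports[name])} residuals saved")

# Whatever the policy, the gradient values are the same.
f_nothing = jax.checkpoint(f, policy=jax.checkpoint_policies.nothing_saveable)
argnums = (0, 1, 2, 3)
grads_ref = jax.grad(lambda *a: jnp.sum(f(*a)), argnums=argnums)(*args)
grads_nothing = jax.grad(lambda *a: jnp.sum(f_nothing(*a)), argnums=argnums)(*args)
print(all(jnp.allclose(a, b) for a, b in zip(grads_ref, grads_nothing)))
```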


Policies can also refer to specific intermediate values that you want saved. To do this, give those values a name with ``jax.ad_checkpoint.checkpoint_name()`` and list the names in the policy.

In [4]:
from jax.ad_checkpoint import checkpoint_name 

def f4(W1, W2, W3, x):
    x = checkpoint_name(g(W1, x), name='a')
    x = checkpoint_name(g(W2, x), name='b')
    x = checkpoint_name(g(W3, x), name='c')
    return x 

f4 = jax.checkpoint(f4, policy=jax.checkpoint_policies.save_only_these_names('a'))
jax.ad_checkpoint.print_saved_residuals(f4, W1, W2, W3, x)

f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] named 'a' from C:\Users\konar\AppData\Local\Temp\ipykernel_19492\1817552606.py:4 (f4)
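``save_only_these_names`` accepts several names (and ``jax.checkpoint_policies.save_any_names_but_these`` is its complement). A minimal sketch that saves both ``'a'`` and ``'b'`` and checks the captured report; ``f5`` and ``f5_remat`` are hypothetical names, and the ``g`` setup is repeated so the snippet runs on its own:

```python
import contextlib
import io

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name, print_saved_residuals

def g(W, x):
    return jnp.sin(jnp.dot(W, x))

def f5(W1, W2, W3, x):
    # Tag each layer's output so a policy can refer to it by name.
    x = checkpoint_name(g(W1, x), name='a')
    x = checkpoint_name(g(W2, x), name='b')
    x = checkpoint_name(g(W3, x), name='c')
    return x

W1, W2, W3 = jnp.ones((5, 4)), jnp.ones((6, 5)), jnp.ones((7, 6))
x = jnp.ones(4)

# Save the values named 'a' and 'b'; everything else is recomputed.
f5_remat = jax.checkpoint(
    f5, policy=jax.checkpoint_policies.save_only_these_names('a', 'b'))

buf = io.StringIO()
with contextlib.redirect_stdout(buf):
    print_saved_residuals(f5_remat, W1, W2, W3, x)
report = buf.getvalue()
print(report)
```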


To take a closer look at what's going on, you can use a custom ``print_fwd_bwd`` utility that prints the jaxprs of the forward and backward computations side by side.

In [5]:
from jax.tree_util import tree_flatten, tree_unflatten

from rich.console import Console
import rich.jupyter
from rich.table import Table
import rich.text

def print_fwd_bwd(f, *args, **kwargs) -> None:
    # Flatten (args, kwargs) into a flat list of leaves so jax.vjp sees positional arrays.
    args, in_tree = tree_flatten((args, kwargs))

    def f_(*args):
        args, kwargs = tree_unflatten(in_tree, args)
        return f(*args, **kwargs)

    # Forward pass: trace jax.vjp to see the primal computation and the residuals it produces.
    fwd = jax.make_jaxpr(lambda *args: jax.vjp(f_, *args))(*args).jaxpr

    # The VJP function is a pytree whose leaves hold the saved residuals.
    y, f_vjp = jax.vjp(f_, *args)
    res, vjp_tree = tree_flatten(f_vjp)

    def g_(*args):
        *res, y = args
        f_vjp = tree_unflatten(vjp_tree, res)
        return f_vjp(y)

    # Backward pass: trace the VJP as a function of the residuals and an output cotangent.
    bwd = jax.make_jaxpr(g_)(*res, y).jaxpr

    # Render the two jaxprs side by side in a borderless table.
    table = Table(show_header=False, show_lines=True, padding=(1, 2, 0, 2), box=None)
    table.add_row("[bold green]forward computation:",
                  "[bold green]backward computation:")
    table.add_row(rich.text.Text.from_ansi(str(fwd)),
                  rich.text.Text.from_ansi(str(bwd)))
    console = Console(width=240, force_jupyter=True)
    console.print(table)

# Patch rich's Jupyter renderable so the notebook displays the table via its HTML representation.
def _renderable_repr(self):
    return self.html
rich.jupyter.JupyterRenderable._repr_html_ = _renderable_repr
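If ``rich`` is not installed, or you are outside Jupyter, a plain-text variant of the same idea is enough. A sketch for positional array arguments only; ``fwd_bwd_text`` is a hypothetical helper, not part of JAX, and it simplifies the backward step by tracing the VJP function directly instead of flattening it:

```python
import jax
import jax.numpy as jnp

def fwd_bwd_text(f, *args) -> str:
    """Return the forward and backward jaxprs of f as plain text (hypothetical helper)."""
    # Forward: trace jax.vjp so the jaxpr shows the primal computation
    # together with the residuals it produces for the backward pass.
    fwd = jax.make_jaxpr(lambda *a: jax.vjp(f, *a))(*args).jaxpr
    # Backward: trace the VJP function on a cotangent shaped like the output.
    # The residual arrays it closes over typically appear as constants of this jaxpr.
    y, f_vjp = jax.vjp(f, *args)
    bwd = jax.make_jaxpr(f_vjp)(y).jaxpr
    return "\n".join(["forward computation:", str(fwd), "",
                      "backward computation:", str(bwd)])

def g(W, x):
    return jnp.sin(jnp.dot(W, x))

text = fwd_bwd_text(g, jnp.ones((5, 4)), jnp.ones(4))
print(text)
```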


In [6]:
# Without jax.checkpoint:
print_fwd_bwd(f, W1, W2, W3, x)

In [7]:
# Using `jax.checkpoint` with policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable:
print_fwd_bwd(f3, W1, W2, W3, x)

When you apply ``jax.jit()`` to a function that contains a ``jax.grad()`` call, XLA can automatically optimize the computation, including decisions about when to compute or rematerialize values.

As a result, ``jax.checkpoint()`` __is often not needed for differentiated functions under a ``jax.jit()``__.
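A minimal sketch of this pattern: the whole gradient computation is wrapped in ``jax.jit()``, so XLA compiles the forward and backward passes as one program. Adding ``jax.checkpoint()`` inside is still allowed and gives the same values. ``loss`` is a hypothetical scalar objective built from the ``f`` used throughout, repeated here so the snippet runs on its own:

```python
import jax
import jax.numpy as jnp

def g(W, x):
    return jnp.sin(jnp.dot(W, x))

def f(W1, W2, W3, x):
    for W in (W1, W2, W3):
        x = g(W, x)
    return x

# Hypothetical scalar objective so that jax.grad applies.
def loss(W1, W2, W3, x):
    return jnp.sum(f(W1, W2, W3, x))

W1, W2, W3 = jnp.ones((5, 4)), jnp.ones((6, 5)), jnp.ones((7, 6))
x = jnp.ones(4)

# jit around grad: XLA sees forward and backward together and decides
# on its own what to keep in memory and what to recompute.
grad_jit = jax.jit(jax.grad(loss, argnums=(0, 1, 2)))

# The same computation with explicit rematerialization requested inside.
grad_jit_remat = jax.jit(jax.grad(jax.checkpoint(loss), argnums=(0, 1, 2)))

dW = grad_jit(W1, W2, W3, x)
dW_remat = grad_jit_remat(W1, W2, W3, x)
print(all(jnp.allclose(a, b) for a, b in zip(dW, dW_remat)))
```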