scan with gradient checkpointing #2139
Comments
Support for a single level of gradient checkpointing might also be useful, which requires only one extra forward pass. Apparently the optimal way to do it reduces memory usage to the square root of the number of steps, e.g., per this implementation in TensorFlow: …
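For illustration, here is a minimal JAX sketch of that single-level (square-root) scheme; the function and argument names are my own, not from the linked TensorFlow implementation, and for simplicity it assumes the number of steps is a perfect square.

import math
import jax
import jax.numpy as jnp

def sqrt_checkpoint_scan(f, init, xs):
  """Scan over xs in sqrt(N)-sized chunks, checkpointing each chunk.

  Only the carries at chunk boundaries are kept during the forward pass;
  each chunk's intermediates are recomputed once during backprop, so peak
  memory scales with sqrt(N) instead of N, at the cost of one extra
  forward pass.
  """
  n = xs.shape[0]
  chunk = math.isqrt(n)
  if chunk * chunk != n:
    raise ValueError('this sketch assumes the number of steps is a perfect square')

  @jax.checkpoint
  def scan_chunk(carry, chunk_xs):
    return jax.lax.scan(f, carry, chunk_xs)

  carry, ys = jax.lax.scan(
      scan_chunk, init, xs.reshape(chunk, chunk, *xs.shape[1:]))
  return carry, ys.reshape(n, *ys.shape[2:])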
There is also a class of algorithms that optimize to a fixed memory budget: https://dl.acm.org/doi/10.1145/347837.347846 (I'm not sure they are worth it over the simpler strategies though.)
Yep, I realize now that that's what "binomial checkpointing" in particular means. I was originally thinking of something simpler, just using recursion.
So Diffrax actually implements a `bounded_while_loop`. The implementation is here: https://github.com/patrick-kidger/diffrax/blob/2b4e4d863c15abc7143919bac7825090bbfe50be/diffrax/misc/bounded_while_loop.py It's worth noting that there are a lot of caveats that need to be worked around in order to make something like this feasible.
In practice most of these details are hidden from an end-user. (You just end up with a funny-looking extra argument to …)
A few other reference points for anyone who finds this issue:
# Copyright 2022 Google LLC.
# SPDX-License-Identifier: Apache-2.0
import math
from typing import Callable, Optional, Sequence, Tuple, TypeVar

import jax
import jax.numpy as jnp

Carry = TypeVar('Carry')
Input = TypeVar('Input')
Output = TypeVar('Output')
Func = TypeVar('Func', bound=Callable)


def nested_checkpoint_scan(
    f: Callable[[Carry, Input], Tuple[Carry, Output]],
    init: Carry,
    xs: Input,
    length: Optional[int] = None,
    *,
    nested_lengths: Sequence[int],
    scan_fn: Callable = jax.lax.scan,
    checkpoint_fn: Callable[[Func], Func] = jax.checkpoint,
) -> Tuple[Carry, Output]:
  """A version of lax.scan that supports recursive gradient checkpointing.

  The interface of `nested_checkpoint_scan` exactly matches lax.scan, except
  for the required `nested_lengths` argument.

  The key feature of `nested_checkpoint_scan` is that gradient calculations
  require O(max(nested_lengths)) memory, vs O(prod(nested_lengths)) for
  unnested scans, which it achieves by re-evaluating the forward pass
  `len(nested_lengths) - 1` times.

  `nested_checkpoint_scan` reduces to `lax.scan` when `nested_lengths` has a
  single element.

  Args:
    f: function to scan over.
    init: initial value.
    xs: scanned over values.
    length: optional length of the leading axis; if given, it must equal the
      product of `nested_lengths`.
    nested_lengths: required list of lengths to scan over for each level of
      checkpointing. The product of nested_lengths must match length (if
      provided) and the size of the leading axis for all arrays in ``xs``.
    scan_fn: function matching the API of lax.scan.
    checkpoint_fn: function matching the API of jax.checkpoint.

  Returns:
    Carry and output values.
  """
  if length is not None and length != math.prod(nested_lengths):
    raise ValueError(f'inconsistent {length=} and {nested_lengths=}')

  def nested_reshape(x):
    # Split the leading axis into one axis per level of checkpointing.
    x = jnp.asarray(x)
    new_shape = tuple(nested_lengths) + x.shape[1:]
    return x.reshape(new_shape)

  sub_xs = jax.tree_map(nested_reshape, xs)
  return _inner_nested_scan(f, init, sub_xs, nested_lengths, scan_fn,
                            checkpoint_fn)


def _inner_nested_scan(f, init, xs, lengths, scan_fn, checkpoint_fn):
  """Recursively applied scan function."""
  if len(lengths) == 1:
    return scan_fn(f, init, xs, lengths[0])

  @checkpoint_fn
  def sub_scans(carry, xs):
    return _inner_nested_scan(f, carry, xs, lengths[1:], scan_fn, checkpoint_fn)

  carry, out = scan_fn(sub_scans, init, xs, lengths[0])
  # Each level stacks its outputs along a new leading axis; concatenate them
  # back into a single leading axis so the result matches lax.scan.
  stacked_out = jax.tree_map(jnp.concatenate, out)
  return carry, stacked_out
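For concreteness, a small usage sketch of the function above (the step function, shapes, and nested_lengths values are my own illustration, not from the original gist):

def step(carry, x):
  carry = carry + jnp.sin(x)
  return carry, carry

xs = jnp.arange(1000.0)
# 10 * 10 * 10 = 1000 steps: backprop stores roughly one level's worth of
# carries at a time, at the cost of re-running the forward pass twice more.
final_carry, ys = nested_checkpoint_scan(
    step, 0.0, xs, nested_lengths=(10, 10, 10))

loss = lambda xs: nested_checkpoint_scan(
    step, 0.0, xs, nested_lengths=(10, 10, 10))[0]
grads = jax.grad(loss)(xs)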
Reporting back to this old thread: Equinox now supports a while loop with gradient checkpointing. This is available in `equinox.internal`; source code here: https://github.com/patrick-kidger/equinox/blob/main/equinox/internal/while_loop/checkpointed.py

This means that (a) we now have a proper checkpointing scheme to help manage memory, and (b) we also get reverse-mode autodifferentiable while loops in JAX! To my knowledge this is the first implementation of this in JAX. We've had similar things floating around before (e.g. stuff like …).

For the pedantically curious: this is technically slightly different to scan-with-gradient-checkpointing. The difference is that in a scan, the number of steps is known in advance; in a while loop, it is not. This implies using slightly different checkpointing algorithms: "online treeverse" vs "classical treeverse", and the former may be slightly less efficient due to having less information to work with. Given a fixed …
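For readers who want to try this, a rough sketch of how a checkpointed while loop might be invoked; the exact module path and keyword arguments (`max_steps`, `kind`) are assumptions on my part and should be checked against the linked source:

import equinox.internal as eqxi
import jax.numpy as jnp

def cond_fun(val):
  step, x = val
  return (step < 100) & (jnp.abs(x) > 1e-3)

def body_fun(val):
  step, x = val
  return step + 1, 0.5 * x  # example contraction step

# Assumed interface: a while loop with an upper bound on the number of steps
# and a "checkpointed" strategy that makes it reverse-mode differentiable;
# see the linked file for the actual API.
final_step, final_x = eqxi.while_loop(
    cond_fun, body_fun, (0, jnp.array(1.0)),
    max_steps=100, kind="checkpointed")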
It would be great to have a version of lax.scan that uses recursive gradient checkpointing (e.g., "binomial checkpointing"), which would allow differentiating through long time series with logarithmic time/space costs.
In principle this could be built on top of the experimental remat decorator: #1749
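By way of illustration (a sketch of mine, not from the issue): wrapping a scan body in jax.checkpoint (also exposed as jax.remat) already gives a single level of rematerialization, where intermediates inside each step are recomputed during the backward pass rather than stored:

import jax
import jax.numpy as jnp

def body(carry, x):
  carry = jnp.tanh(carry + x)
  return carry, carry

xs = jnp.arange(100.0)

# Plain scan: intermediates for every step are kept for the backward pass.
loss_plain = lambda xs: jax.lax.scan(body, 0.0, xs)[0]

# Rematerialized scan: intermediates inside each step are recomputed during
# backprop, trading recomputation for less per-step storage.
loss_remat = lambda xs: jax.lax.scan(jax.checkpoint(body), 0.0, xs)[0]

g_plain = jax.grad(loss_plain)(xs)
g_remat = jax.grad(loss_remat)(xs)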