
scan with gradient checkpointing #2139

Open
shoyer opened this issue Feb 2, 2020 · 6 comments

Comments

@shoyer (Collaborator) commented Feb 2, 2020

It would be great to have a version of lax.scan that uses recursive gradient checkpointing (e.g., "binomial checkpointing"), which would allow differentiating through long time series with logarithmic time/space costs.

In principle this could be built on top of the experimental remat decorator: #1749
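
As a strawman sketch of the recursive idea (not a final API): assuming xs is a single array whose length is a power of two and that f's per-step output is an array, one could split the scan in half and remat each half recursively. Note this unrolls the recursion at trace time, so compile times grow with the number of steps.

import jax
import jax.numpy as jnp

def recursive_checkpoint_scan(f, init, xs):
  # Divide and conquer: each half is wrapped in jax.checkpoint, so the
  # backward pass keeps only O(log n) carries live and recomputes the rest.
  n = xs.shape[0]
  if n == 1:
    carry, y = f(init, xs[0])
    return carry, y[None]
  scan_half = jax.checkpoint(lambda c, x: recursive_checkpoint_scan(f, c, x))
  carry, ys_lo = scan_half(init, xs[: n // 2])
  carry, ys_hi = scan_half(carry, xs[n // 2 :])
  return carry, jnp.concatenate([ys_lo, ys_hi])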

@shoyer (Collaborator, Author) commented Aug 6, 2020

Support for a single level of gradient checkpointing, which requires only one extra forward pass, might also be useful. Apparently the optimal way to do it reduces memory usage to the square root of the number of steps, e.g., per this implementation in TensorFlow:
https://github.com/cybertronai/gradient-checkpointing
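
As a rough sketch of that single-level scheme (assuming xs is a single array whose length is a perfect square and that f's per-step output is an array), one can scan over ~sqrt(n) chunks and wrap the inner per-chunk scan in jax.checkpoint:

import math
import jax
import jax.numpy as jnp

def sqrt_checkpoint_scan(f, init, xs):
  n = xs.shape[0]
  k = math.isqrt(n)  # assumes n == k * k
  chunked = xs.reshape((k, k) + xs.shape[1:])

  @jax.checkpoint
  def scan_chunk(carry, chunk):
    # Only the carry entering each chunk is stored; the chunk itself is
    # re-run during the backward pass (the one extra forward pass).
    return jax.lax.scan(f, carry, chunk)

  carry, ys = jax.lax.scan(scan_chunk, init, chunked)  # outer scan over chunks
  return carry, ys.reshape((n,) + ys.shape[2:])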

@hawkinsp (Collaborator) commented Aug 6, 2020

There is also a class of algorithms that optimize to a fixed memory budget: https://dl.acm.org/doi/10.1145/347837.347846

(I'm not sure they are worth it over the simpler strategies though.)

@shoyer (Collaborator, Author) commented Aug 6, 2020

Yep, I realize now that that's what "binomial checkpointing" in particular means. I was originally thinking of something simpler, just using recursion.

@patrick-kidger (Collaborator) commented Feb 14, 2022

So Diffrax actually implements a bounded_while_loop that does exactly this -- early exit by nesting scan-conds, and managing memory using recursive checkpointing. In Diffrax's case it's used to handle the stepping of a differential equation solver.

The implementation is here: https://github.com/patrick-kidger/diffrax/blob/2b4e4d863c15abc7143919bac7825090bbfe50be/diffrax/misc/bounded_while_loop.py

It's worth noting that there are a lot of caveats that need to be worked around in order to make something like this feasible.


In practice most of these details are hidden from an end-user. (You just end up with a funny-looking extra argument to body_fun, and in many cases have to suffer subpar performance.) But I thought I'd record them here for anyone who ends up treading down the same path I did. Implementing a bounded_while_loop that exhibits reasonable performance was easily the single hardest part of implementing Diffrax, by a very large margin.
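
For anyone who just wants the core trick, here is a minimal sketch of the scan/cond construction (not Diffrax's actual implementation, which layers recursive checkpointing and other machinery on top of this):

import jax

def bounded_while_loop(cond_fun, body_fun, init_val, max_steps):
  # Run a fixed number of scan steps; once cond_fun is False, lax.cond takes
  # the identity branch, so the remaining iterations are cheap no-ops. Because
  # this is a scan rather than lax.while_loop, it is reverse-mode
  # differentiable (at the cost of storing max_steps carries, unless the scan
  # levels are nested and checkpointed).
  def step(val, _):
    val = jax.lax.cond(cond_fun(val), body_fun, lambda v: v, val)
    return val, None
  val, _ = jax.lax.scan(step, init_val, None, length=max_steps)
  return val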

@shoyer (Collaborator, Author) commented Jul 19, 2022

A few other reference points for anyone who finds this issue:

  1. Flax has flax.linen.remat_scan for scanning over Flax modules.
  2. I wrote a simpler version of scanning with nested gradient checkpointing, based on some of the same design principles as Diffrax's bounded_while_loop:
# Copyright 2022 Google LLC.
# SPDX-License-Identifier: Apache-2.0
import math
from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar, Union

import jax
import jax.numpy as jnp


Carry = TypeVar('Carry')
Input = TypeVar('Input')
Output = TypeVar('Output')
Func = TypeVar('Func', bound=Callable)


def nested_checkpoint_scan(
    f: Callable[[Carry, Input], Tuple[Carry, Output]],
    init: Carry,
    xs: Input,
    length: Optional[int] = None,
    *,
    nested_lengths: Sequence[int],
    scan_fn: Callable = jax.lax.scan,
    checkpoint_fn: Callable[[Func], Func] = jax.checkpoint,
) -> Tuple[Carry, Output]:
  """A version of lax.scan that supports recursive gradient checkpointing.

  The interface of `nested_checkpoint_scan` exactly matches lax.scan, except for
  the required `nested_lengths` argument.

  The key feature of `nested_checkpoint_scan` is that gradient calculations
  require O(max(nested_lengths)) memory, vs O(prod(nested_lengths)) for unnested
  scans, which it achieves by re-evaluating the forward pass
  `len(nested_lengths) - 1` times.

  `nested_checkpoint_scan` reduces to `lax.scan` when `nested_lengths` has a
  single element.

  Args:
    f: function to scan over.
    init: initial value.
    xs: scanned over values.
    length: optional length of the leading axis of all arrays in ``xs``, as in
      ``lax.scan``.
    nested_lengths: required list of lengths to scan over for each level of
      checkpointing. The product of nested_lengths must match length (if
      provided) and the size of the leading axis for all arrays in ``xs``.
    scan_fn: function matching the API of lax.scan
    checkpoint_fn: function matching the API of jax.checkpoint.

  Returns:
    Carry and output values.
  """
  if length is not None and length != math.prod(nested_lengths):
    raise ValueError(f'inconsistent {length=} and {nested_lengths=}')

  def nested_reshape(x):
    x = jnp.asarray(x)
    new_shape = tuple(nested_lengths) + x.shape[1:]
    return x.reshape(new_shape)

  sub_xs = jax.tree_map(nested_reshape, xs)
  return _inner_nested_scan(f, init, sub_xs, nested_lengths, scan_fn,
                            checkpoint_fn)


def _inner_nested_scan(f, init, xs, lengths, scan_fn, checkpoint_fn):
  """Recursively applied scan function."""
  if len(lengths) == 1:
    return scan_fn(f, init, xs, lengths[0])

  @checkpoint_fn
  def sub_scans(carry, xs):
    return _inner_nested_scan(f, carry, xs, lengths[1:], scan_fn, checkpoint_fn)

  carry, out = scan_fn(sub_scans, init, xs, lengths[0])
  stacked_out = jax.tree_map(jnp.concatenate, out)
  return carry, stacked_out
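
A hypothetical usage example (the step function and sizes are purely illustrative):

def step(carry, x):
  carry = carry + x            # any per-step computation
  return carry, carry          # (new carry, per-step output)

xs = jnp.arange(64.0)
carry, ys = nested_checkpoint_scan(step, 0.0, xs, nested_lengths=(8, 8))
# Same results as jax.lax.scan(step, 0.0, xs), but a gradient such as
#   jax.grad(lambda c: nested_checkpoint_scan(
#       step, c, xs, nested_lengths=(8, 8))[0])(0.0)
# keeps only ~8 carries live and re-runs the forward pass once.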

@patrick-kidger (Collaborator) commented Feb 22, 2023

Reporting back to this old thread: Equinox now supports a while loop with gradient checkpointing. This is available as equinox.internal.while_loop(..., kind="checkpointed").

Source code here: https://github.com/patrick-kidger/equinox/blob/main/equinox/internal/while_loop/checkpointed.py

This means that (a) we now have a proper checkpointing scheme to help manage memory, and (b) we also get reverse-mode autodifferentiable while loops in JAX!

To my knowledge this is the first implementation of this in JAX. We've had similar things floating around before (e.g. stuff like lax.scan(jax.checkpoint(f), ...), or multi-level versions of that), but they've always suffered from either asymptotically slow runtimes (since the checkpointing scheme wasn't really the right thing) or from slow compile times (e.g. due to unrolling loops).
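
A rough usage sketch (cond_fun/body_fun are made up, and the max_steps keyword is an assumption about the API rather than something spelled out in this thread):

import equinox.internal as eqxi

def cond_fun(carry):
  step, x = carry
  return x < 100.0

def body_fun(carry):
  step, x = carry
  return step + 1, x * 1.1

# kind="checkpointed" selects the scheme described above; max_steps (assumed
# keyword) provides the static bound the checkpoint schedule is built against.
final_carry = eqxi.while_loop(cond_fun, body_fun, (0, 1.0),
                              max_steps=10_000, kind="checkpointed")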

For the pedantically curious.

This is technically slightly different to scan-with-gradient-checkpointing. The difference is that in a scan, the number of steps is known in advance. In a while loop, the number of steps is not known in advance. This implies using slightly different checkpointing algorithms: "online treeverse" vs "classical treeverse", and the former may be slightly less efficient due to having less information to work with.

Given a fixed num_checkpoints, and then running to see how many num_steps you get: if num_steps <= (num_checkpoints + 1) * (num_checkpoints + 2) / 2 then it turns out that both approaches match each other exactly. If num_steps is larger than this bound then online treeverse will make some extra computations (as compared to classical treeverse with an oracle on the number of steps) -- but it will still at least have the same asymptotic complexity!
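
As a concrete instance of that bound (numbers purely illustrative):

num_checkpoints = 10
threshold = (num_checkpoints + 1) * (num_checkpoints + 2) // 2  # = 66
# For num_steps <= 66 the online and classical treeverse schedules coincide;
# beyond that, online treeverse does some extra recomputation but keeps the
# same asymptotic complexity.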
