
Add support for buffer donation in jit and pmap. #2936

Merged 1 commit into jax-ml:master from changelist/304781360 on May 31, 2020

Conversation

@tomhennigan (Collaborator)

For a computation of the form:

>>> f = lambda x: x ** 2
>>> f = jax.jit(f)
>>> while run:
...   x = f(x)

JAX must currently always keep two copies of `x` in device memory, since there
is no reliable way in Python to determine whether there will be future uses of
`x`. This causes two classes of problems:

  1. Users at the limit of available device memory are constrained by the
     additional copy of their parameters and other state, even though they
     typically only require one copy. Removing the extra copy typically frees
     100M+ of device memory and is a critical optimization for larger models
     to match state-of-the-art performance in other frameworks.

  2. The constant alloc/free of the input/output buffers can cause memory
     fragmentation on some platforms (although having a reusing allocator and
     limiting run-ahead may be a better solution to this problem).

We propose fixing this by using input/output aliasing as supported by XLA. We
will support this in JAX by allowing certain arguments of jit/pmap decorated
functions to be donated and reused as outputs:

>>> f = lambda x: x ** 2
>>> f = jit(f, donate_argnums=0)
>>> while run:
...   x = f(x)

JAX will determine that the donated input `x` can alias with the output of the
function and will instruct XLA that it must write the result to this buffer.

If a user tries to reuse a buffer after it has been donated, they get an error
that the buffer is invalid:

>>> y = f(x)
>>> jax.device_get(x)
...
RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer.

The semantics of `donate_argnums` follow those of `static_argnums`: it
identifies positional arguments to the computation that are to be donated
to the computation and reused as part of the output.
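
As a quick illustration (a hedged sketch assuming the API above; `update`,
`params`, `opt_state`, and `batch` are hypothetical names):

>>> @partial(jit, donate_argnums=(0, 1))
... def update(params, opt_state, batch):
...   # Donated positions 0 and 1 may be reused for the matching outputs;
...   # `batch` (position 2) is never donated.
...   new_params = tree_map(lambda p: p * 0.9, params)
...   new_opt_state = tree_map(lambda s: s + 1, opt_state)
...   return new_params, new_opt_state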

One feature this also enables is invalidating buffers that should
only be used once, for example PRNGKeys:

>>> @partial(jit, donate_argnums=0)
... def move(x):
...   # Do something complex enough for JAX to just optimize it away.
...   return tree_map(lambda x: x + x - x, x)

>>> def safe_eager_uniform(key, *a, **k):
...   assert hasattr(key, 'device_buffer'), "random must run eagerly"
...   key = move(key)
...   return jax.random.uniform(key, *a, **k)

This is not a complete answer to random safety, since it is still possible to
reuse a key as part of a traced computation; however, it can be used to support
this feature (somewhat inefficiently) in eager mode.

@tomhennigan (Collaborator, Author)

@hawkinsp @skye FYI

@gnecula (Collaborator) left a comment

You definitely need to get mattjj@ to look at this. When I added mapped_invars I missed several places where it should have been updated (#2828).

@skye (Member) left a comment

This is really nice! I didn't do a super in-depth review; please let me know if you want one from me (and sorry for taking a few days to look at this, I'll be more prompt with future reviews).

@hawkinsp (Collaborator) left a comment

Some comments.

@avital commented May 12, 2020

This looks very exciting! Is there any plan to add support for GPU as well?

@mattjj (Collaborator) commented May 14, 2020

@avital my understanding is that this is platform-agnostic from the JAX end, but we'd need to request support for this feature from the XLA:GPU folks. EDIT: b/149489114 for googlers.

@tomhennigan (Collaborator, Author)

I asked a few people to meet with me to review the design of this PR. Overall the proposed API was approved by the group and we will now iterate on a few changes to the implementation that will not be user visible (e.g. interaction with JVPTrace).

One major change as a result of the review is that we are leaning into the "gifting" nature of "donation" (previously we would error if we could not use a donation; now we will simply discard it). The original implementation focused exclusively on the case where a donated buffer is aliased with an output, but there are other future cases where donation may make sense (e.g. reusing an input buffer for an intermediate computation), and on some backends we may not be able to make use of donated buffers at all (e.g. on CPU/GPU input/output aliasing is not supported).

We decided that donation will be best effort: on backends that cannot do anything with donated buffers (e.g. CPU/GPU) we will log a warning and invalidate the Python objects (to avoid accidental reuse). We will additionally no longer require perfect aliasing of inputs and outputs (meaning that on backends which do support donation you can donate more than you need).

There were concerns raised about giving the user control over which specific input buffers alias which outputs. For now we will document that input/output aliasing is managed by traversing the flattened inputs and outputs in order, matching each donated input to the first unclaimed output with a compatible shape/dtype. This means you can exercise fine-grained control over aliasing by re-ordering outputs in Python, as sketched below.
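
To make that matching rule concrete, here is a rough Python sketch of the pairing (illustrative only, not the actual implementation; `match_donations` is a hypothetical helper operating on flat lists of values with `shape`/`dtype` attributes):

>>> def match_donations(donated_inputs, outputs):
...   # Pair each donated input (in order) with the first unclaimed output of
...   # the same shape and dtype; donations that find no match go unused.
...   claimed = [False] * len(outputs)
...   aliases = {}  # donated input index -> output index
...   for i, inp in enumerate(donated_inputs):
...     for o, out in enumerate(outputs):
...       if not claimed[o] and inp.shape == out.shape and inp.dtype == out.dtype:
...         claimed[o] = True
...         aliases[i] = o
...         break
...   return aliases

Because outputs are claimed in order, reordering the outputs returned from your function changes which donated buffer each output aliases.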

Thanks to everyone who participated for an engaging discussion! I hope to have the patch ready to submit by the end of the month.

@mattjj (Collaborator) commented May 27, 2020

Heads up, might need a rebase on #3210 when that goes in. Since that PR aims to be a cleanup, hopefully if there's any conflict it's only to make this PR easier.

@tomhennigan (Collaborator, Author)

Thanks for the comments everyone, this PR is ready for your review again 😄 I've rebased on master and made some major changes based on the design review:

  1. We now interact correctly with JVPTrace, following the guidance of mapped_invars (thanks for the help @mattjj!!).
  2. Donation is now optional; on CPU/GPU we simply log a warning at "compile" time that the buffers aren't going to be used, and do nothing else.
  3. Function arguments and type hints are now consistent with the rest of JAX.

PTAL!

@tomhennigan (Collaborator, Author)

FYI two known issues that I think we can address in a follow up:

  1. If you try to donate the same array twice in the same computation (e.g. jit(f, donate_argnums=(0, 1))(x, x)) you get a deadlock.
  2. Buffers that are donated but not used are currently not marked as "deleted", meaning that you can reuse them (this is not a correctness issue, but it might lead to confusing "this buffer is deleted" errors if you move your code to TPU, where input/output aliasing is supported); see the sketch below. In a follow-up we should delete these buffers at call time.
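
To illustrate issue 2, a minimal hypothetical sketch under the behaviour described in this PR (`f`, `x`, and `y` are made-up names; don't rely on this):

>>> @partial(jit, donate_argnums=(0, 1))
... def f(x, y):
...   return x + 1.0  # a single output, so at most one donation can be aliased

>>> x = jax.numpy.ones(4)
>>> y = jax.numpy.zeros(4)
>>> out = f(x, y)
>>> jax.device_get(y)  # donated but unused, so today this still succeeds

On a backend that supports donation, `x` would be aliased to the output and reusing it would raise the "invalid buffer" error shown above, while `y` is donated but unused and is not (yet) invalidated.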

@mattjj (Collaborator) left a comment

LGTM modulo the few suggestions offered, mostly docs/names.

@tomhennigan force-pushed the changelist/304781360 branch 2 times, most recently from 92cbfa5 to ee57d74 on May 31, 2020 19:44
donate_argnums = sorted(set(donate_argnums))
i = j = o = 0
out = []
while j < len(donate_argnums):
Contributor:

This is gonna be super slow, but it's only invoked at compile time so it doesn't really matter.

if 'donated_invars' in params:
  new_donated_invars = (*[d for d, x in zip(params['donated_invars'], args)
                          if not is_undefined_primal(x)],
                        *[False for x in ct if x is not zero])
Contributor:

I have a vague sense that we might just be able to make this True (assuming that backward_pass is only called inside the AD infra) because it only has an effect in the op-by-op grad-of-jit case, and in that case using donation for cotangent buffer accumulation is always a good idea. But let's follow up and not make that change now!

Collaborator:

Hey nice catch, I think you're right!

@@ -187,9 +187,14 @@ def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
    lifted_jaxpr = convert_constvars_jaxpr(jaxpr)
    out_tracers = [JaxprTracer(self, pval, None) for pval in out_pvals]
    new_params = dict(params, call_jaxpr=lifted_jaxpr)
    invars = tuple(it.chain(const_tracers, env_tracers, tracers))
    if 'donated_invars' in params:
      new_donated_invars = ((False,) * len(const_tracers) +
Contributor:

Similarly, in the case where partial_eval is used for reverse-mode AD (but only grad, not linearize, because that's all about reusing residuals!), it's possible that we could use donation on const_tracers because we know they'll never be reused.

Collaborator:

Indeed, as you say, this wouldn't work for linearize (of jit), but it also wouldn't work for vjp, because in both cases we might call the linearized function many times and hence can't throw away the residuals after the first execution. It'd only work for grad, though of course grad is an important special case, and it might be worth thinking about lightweight ways to provide that optimization.

@jekbradbury merged commit 6124f70 into jax-ml:master on May 31, 2020
@jekbradbury (Contributor)

Filed #3286 to keep track of AD-related buffer donation opportunities.

NeilGirdhar pushed a commit to NeilGirdhar/jax that referenced this pull request Jun 11, 2020