Add support for buffer donation in `jit` and `pmap` #2936

Conversation
You definitely need to get mattjj@ to look at this. When I added `mapped_invars` I missed several places where it should have been updated (#2828).
This is really nice! I didn't do a super in-depth review, please let me know if you want one from me (and sorry for taking a few days to look at this, I'll be more prompt with future review).
Some comments.
This looks very exciting! Is there any plan to add support for GPU as well?

@avital my understanding is that this is platform-agnostic from the JAX end, but we'd need to request support for this feature from the XLA:GPU folks. EDIT: b/149489114 for googlers.
I asked a few people to meet with me to review the design of this PR. Overall the proposed API was approved by the group, and we will now iterate on a few changes to the implementation that will not be user visible (e.g. interaction with JVPTrace).

One major change as a result of the review is that we are leaning into the "gifting" nature of "donation" (previously we would error if we could not use a donation; now we will simply discard it). The original implementation focused exclusively on the case where a donated buffer is aliased with an output; however, there are other future cases where donation may make sense (e.g. reusing an input buffer for an intermediate computation), and on some backends we may not be able to make use of donated buffers (e.g. on CPU/GPU input/output aliasing is not supported). We decided that donation will be best effort, and on backends that don't offer anything to do with donated buffers (e.g. CPU/GPU) we will log a warning and invalidate the Python objects (to avoid accidental reuse). We will additionally no longer require perfect aliasing of inputs and outputs (meaning that on backends which do support donation you can donate more than you need).

There were concerns raised about providing the user control over which specific buffers in the input alias which outputs. For now we will document that input/output aliasing is managed by traversing the flat inputs and flat outputs in order, matching each donated input to the first output with an appropriate shape/dtype. This means you can have fine-grained control over aliasing by re-ordering outputs in Python.

Thanks to everyone who participated for an engaging discussion! I hope to have the patch ready to submit by the end of the month.
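The in-order matching rule described above can be sketched in plain Python. This is only an illustration of the documented behaviour, not the actual JAX implementation; `match_donations` and the `(shape, dtype)` descriptors are hypothetical names:

```python
def match_donations(donated, outputs):
    """Pair each donated input with the first unclaimed output of the same
    shape/dtype. Unmatched donations are simply discarded, reflecting the
    best-effort 'gifting' semantics agreed in the design review."""
    claimed = set()
    aliases = {}  # donated-input position -> output position
    for i, spec in enumerate(donated):          # spec = (shape, dtype)
        for j, out_spec in enumerate(outputs):
            if j not in claimed and spec == out_spec:
                aliases[i] = j
                claimed.add(j)
                break
    return aliases

# Re-ordering outputs changes which one receives the donated buffer:
f32 = 'float32'
print(match_donations([((2, 2), f32)], [((2, 2), f32), ((2, 2), f32)]))  # {0: 0}
print(match_donations([((2, 2), f32)], [((3,), f32), ((2, 2), f32)]))    # {0: 1}
```

Note how donating more buffers than there are matching outputs is not an error under this rule; the surplus donation is silently dropped.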
Heads up, might need a rebase on #3210 when that goes in. Since that PR aims to be a cleanup, hopefully if there's any conflict it's only to make this PR easier.
Thanks for the comments everyone, this PR is ready for your review again 😄 I've rebased on master and made some major changes based on the design review:
PTAL!
FYI two known issues that I think we can address in a follow up:
LGTM modulo the few suggestions offered, mostly docs/names.
For a computation of the form:

```python
>>> f = lambda x: x ** 2
>>> f = jax.jit(f)
>>> while run:
...   x = f(x)
```

JAX must currently always have two copies of `x` in device memory since there is no reliable way in Python to determine whether there will be future uses of `x`. This causes two classes of problem:

1. Users at the limit of available device memory are constrained by the additional copy of their parameters and other state while they typically only require one copy. This typically frees 100M+ of device memory and is a critical optimization for larger models to match state-of-the-art performance in other frameworks.
2. This constant alloc/free of the input/output buffers can cause memory fragmentation on some platforms (although having a reusing allocator and limiting run-ahead may be a better solution for this problem).

We propose fixing this by using input/output aliasing as supported by XLA. We will support this in JAX by allowing certain arguments of jit/pmap decorated functions to be donated and reused as outputs:

```python
>>> f = lambda x: x ** 2
>>> f = jit(f, donate_argnums=0)
>>> while run:
...   x = f(x)
```

JAX will determine that the donated input `x` can alias with the output of the function and it will instruct XLA that it _must_ write the result to this buffer. If a user tries to reuse a buffer after it has been donated they get an error that the buffer is invalid:

```python
>>> y = f(x)
>>> jax.device_get(x)
...
RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer.
```

The semantics of `donate_argnums` follows that of `static_argnums`, namely that it identifies positional arguments to the computation that are to be donated to the computation and used as part of the output.

One feature that is also enabled by this is invalidating buffers that should only be used once, for example PRNGKeys:

```python
>>> @partial(jit, donate_argnums=0)
... def move(x):
...   # Do something complex enough for JAX to just optimize it away.
...   return tree_map(lambda x: x + x - x, x)

>>> def safe_eager_uniform(key, *a, **k):
...   assert hasattr(key, 'device_buffer'), "random must run eagerly"
...   key = move(key)
...   return jax.random.uniform(key, *a, **k)
```

This is not a complete answer to random safety since it is still possible to reuse a key as part of a traced computation, however it can be used to support this feature (somewhat inefficiently) in eager mode.
```python
donate_argnums = sorted(set(donate_argnums))
i = j = o = 0
out = []
while j < len(donate_argnums):
```
This is gonna be super slow, but it's only invoked at compile time so it doesn't really matter.
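For context, a loop of this shape can merge two sorted index lists in a single pass. A self-contained sketch of what such a helper might look like — a hypothetical `rebase_donate_argnums` that shifts donated positions down past the removed static arguments, not necessarily the code in this PR — is:

```python
def rebase_donate_argnums(donate_argnums, static_argnums):
    """Shift each donated argument index down by the number of static
    arguments that precede it, since static args are stripped before the
    jitted function sees its positional arguments."""
    if not static_argnums:
        return tuple(sorted(set(donate_argnums)))
    static_argnums = sorted(static_argnums)
    donate_argnums = sorted(set(donate_argnums))
    i = j = o = 0   # static index, donate index, offset accumulated so far
    out = []
    while j < len(donate_argnums):
        if i < len(static_argnums) and static_argnums[i] == donate_argnums[j]:
            raise ValueError(
                f"Argument {donate_argnums[j]} is both static and donated")
        elif i < len(static_argnums) and static_argnums[i] < donate_argnums[j]:
            i += 1
            o += 1  # one more static arg precedes all remaining donations
        else:
            out.append(donate_argnums[j] - o)
            j += 1
    return tuple(out)

print(rebase_donate_argnums((1, 3), (0, 2)))  # (0, 1)
```

As the comment above notes, the quadratic-looking structure is fine: it only runs once per compilation, on argument lists that are tiny in practice.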
```python
if 'donated_invars' in params:
  new_donated_invars = (*[d for d, x in zip(params['donated_invars'], args)
                          if not is_undefined_primal(x)],
                        *[False for x in ct if x is not zero])
```
I have a vague sense that we might just be able to make this `True` (assuming that `backward_pass` is only called inside the AD infra) because it only has an effect in the op-by-op grad-of-jit case, and in that case using donation for cotangent buffer accumulation is always a good idea. But let's follow up and not make that change now!
Hey nice catch, I think you're right!
```diff
@@ -187,9 +187,14 @@ def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
     lifted_jaxpr = convert_constvars_jaxpr(jaxpr)
     out_tracers = [JaxprTracer(self, pval, None) for pval in out_pvals]
     new_params = dict(params, call_jaxpr=lifted_jaxpr)
+    invars = tuple(it.chain(const_tracers, env_tracers, tracers))
+    if 'donated_invars' in params:
+      new_donated_invars = ((False,) * len(const_tracers) +
```
Similarly, in the case where `partial_eval` is used for reverse-mode AD (but only `grad`, not `linearize`, because that's all about reusing residuals!), it's possible that we could use donation on `const_tracers` because we know they'll never be reused.
Indeed as you say this wouldn't work for `linearize` (of `jit`), but it also wouldn't work for `vjp`, because in both cases we might call the linearized function many times and hence can't throw away the residuals after the first execution. It'd only work for `grad`, though of course `grad` is an important special case, and it might be worth thinking about lightweight ways to provide that optimization.
Filed #3286 to keep track of AD-related buffer donation opportunities.