[remat] Change remat lowering to XLA::Conditional #2391
Conversation
`jax.remat` creates rematerializing passes that don't have data dependencies on the actual loss-computing forward pass. This means the XLA scheduler was free to schedule the remat forward pass before the loss-computing pass, defeating the goal of saving accelerator memory with `jax.remat`. In practice, it sometimes did for my workloads.

This change expresses the lowering of remat_call(f) as: Conditional(true, inputs, f, inputs, dummy_f).

In the common case of `jax.grad(jax.remat(f))`, the contents of the lowered remat_call are both the forward and backward passes; that is, the incoming cotangents are part of the args. Additionally, Conditional (AFAIK) is un-inlineable in the sense that it doesn't execute until all its inputs (e.g. cotangents!) are available.

Downsides:
- AFAICT, we can no longer interleave computation inside/outside the rematerialized block.
- Potentially, lower performance. I do not observe this in my tests.
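For concreteness, here is a minimal user-level sketch of the `jax.grad(jax.remat(f))` pattern this change targets (the function `f` and the shapes are made up for illustration); with this PR, the rematerialized block in the backward pass lowers to an XLA Conditional rather than an inlined computation:

```python
import jax
import jax.numpy as jnp

def f(x):
    # A toy "loss-computing" forward pass; any elementwise chain works here.
    h = jnp.sin(x)
    h = jnp.tanh(h)
    return jnp.sum(h ** 2)

loss_fn = jax.remat(f)              # rematerialize f's intermediates in the backward pass
grad_fn = jax.jit(jax.grad(loss_fn))

x = jnp.ones((1024,))
print(grad_fn(x))                   # the lowered remat_call holds both forward and backward
```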
This looks good, but IIUC @trevorcai suggested we wait to merge until he runs more tests. Let us know!
Had to include a one-line change to work around an upstream XLA bug which got the parameter replication on the Conditionals wrong somehow. I think this is ready to merge.
LGTM, thanks!
* [remat] Change remat lowering to XLA::Conditional
* provide no replication info for subcomputation params