
[remat] Change remat lowering to XLA::Conditional #2391

Merged
mattjj merged 2 commits into google:master from ckpt_cse on Mar 11, 2020

Conversation

trevorcai (Contributor)

`jax.remat` creates rematerializing passes that don't have data dependencies on
the actual loss-computing forward pass. This means that the XLA scheduler was
free to schedule the remat forward pass before the loss-computing pass,
defeating the goal of saving accelerator memory with `jax.remat`.

In practice, it sometimes did for my workloads.

This change expresses the lowering of `remat_call(f)` as
`Conditional(true, inputs, f, inputs, dummy_f)`.
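
A rough sketch of the same idea at the JAX level (illustration only: the actual change is in the XLA lowering rule for `remat_call`; the names `wrap_in_cond` and `dummy_f` here are hypothetical, and XLA's conditional simplifier could in principle fold a constant-predicate conditional):

```python
import jax
import jax.numpy as jnp
from jax import lax

def wrap_in_cond(f):
    """Hypothetical sketch: route f through a conditional so its body
    cannot start until all of its operands (e.g. cotangents) exist."""
    def dummy_f(*args):
        # Cheap false branch: same output shapes/dtypes as f, built
        # without running f's actual math.
        out = jax.eval_shape(f, *args)
        return jax.tree_util.tree_map(
            lambda s: jnp.zeros(s.shape, s.dtype), out)
    def wrapped(*args):
        # The predicate is always True, so f is the branch taken.
        return lax.cond(jnp.array(True), f, dummy_f, *args)
    return wrapped
```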

In the common case of `jax.grad(jax.remat(f))`, the contents of the
lowered `remat_call` are both the forward and backward passes; that is,
the incoming cotangents are part of the args.
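
For reference, that common case looks like this (a minimal usage sketch; the loss function is made up):

```python
import jax
import jax.numpy as jnp

def loss(params, x):
    # A made-up, memory-hungry forward pass: each layer's activations
    # would normally be stored for the backward pass.
    for w in params:
        x = jnp.tanh(x @ w)
    return jnp.sum(x ** 2)

# jax.remat recomputes the forward activations during the backward
# pass; with this change, that recomputation is wrapped in an XLA
# Conditional so the scheduler cannot hoist it before the loss pass.
grad_fn = jax.jit(jax.grad(jax.remat(loss)))

params = [jnp.ones((128, 128)) for _ in range(4)]
grads = grad_fn(params, jnp.ones((32, 128)))
```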

Additionally, Conditional (AFAIK) is un-inlineable in the sense that it
doesn't execute until all its inputs (e.g. cotangents!) are available.

Downsides:

- AFAICT, we can no longer interleave computation in/outside the
  rematerialized block.
- Potentially, lower performance. I do not observe this in my tests.
mattjj (Member) commented Mar 10, 2020

This looks good, but IIUC @trevorcai suggested we wait to merge until he runs more tests. Let us know!

trevorcai (Contributor, Author)

Had to include a one-line change to work around an upstream XLA bug that somehow got the parameter replication on the Conditionals wrong. I think this is ready to merge.

mattjj (Member) left a review comment

LGTM, thanks!

mattjj merged commit 620bf43 into google:master on Mar 11, 2020
srvasude pushed a commit to srvasude/jax that referenced this pull request May 5, 2020
* [remat] Change remat lowering to XLA::Conditional

* provide no replication info for subcomputation params
trevorcai deleted the ckpt_cse branch on December 1, 2020