forked from google/jax
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a simple form of partial evaluation for while_loop. (google#2497)
The issue that I wanted to fix was that when running grad(while_loop), the error was a cryptic assertion failure (that all primals are known after linearization, in ad.py:linearize). I could not figure out how to detect before that assertion that we are doing a reverse AD for while_loop. So, I implemented a simple form of partial evaluation, to allow the primals after linearization to be known, so that the code proceeds and can then fail gracefully when trying to transpose the while. This is not a proper implementation of partial evaluation. The known outputs are computed early, properly. But the unknown outputs are computed by a *whole* computation of, including the known parts. Fixes issue: google#2129
- Loading branch information
Showing
3 changed files
with
149 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters