Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[jax2tf] Fix grad of pjit in native lowering.
Since jax2tf.convert is called recursively for the purpose of serializing the vjp function, we must ensure that if the primal function is a pjit with shardings then the vjp function must also be converted as a pjit. Without this fix the serialization with gradients of a pjit function will fail the an error that there are shardings but not pjit at the top-level.
- Loading branch information
Showing
4 changed files
with
150 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters