Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[jax2tf] Refactor the experimental_native_lowering path in jax2tf
This is part of a suite of refactorings aimed towards supporting pjit by jax2tf experimental_native_lowering. The goal here is to remove many references to internal JAX core APIs, and instead use the AOT APIs: jax.jit(func_jax).lower(*args). Only the experimental_native_lowering behavior should be affected.
- Loading branch information
Showing
2 changed files
with
55 additions
and
74 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters