Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[shape_poly] Improve compile-time shape checking.
JAX shape polymorphism relies on implicit assumptions. For example, when tracing with input specification `(a, a)`, we assume that the first two dimensions have the same size greater or equal to 1. Here we extend the checking that these assumptions hold. When we call an `Exported` module from jax, with `jax_export.call_exported` we check these assumptions statically. However, when we stage an `Exported` using `XlaCallModule` to be called from TensorFlow, or when we use TF graph serialization we need to check these assumptions when we execute and compile the op (that is when the shapes are available). To prepare for this compile-time shape checking we add `Exported.shape_check_module` to produce a serialized MLIR module containing the shape checking code. This will be added in a future change to `XlaCallModule`.
- Loading branch information
Showing
6 changed files
with
749 additions
and
306 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.