Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor
stages
types, adding methods for text and for cost/memory …
…analyses Re-organizing things this way in order to: * Clarify internally what a lowering and executable should do, rather than what current XLA-backed versions happen to provide. * Document that some features (e.g. cost analysis) are best-effort and intended mainly for debugging purposes. They may be unimplemented on some backends and what they return is intentionally undefined. For an example of the latter item, this change adds a `cost_analysis()` method on `jax.stages.Compiled`. However, the expression `jit(f).lower(*args).compile().cost_analysis()` may return `None` depending on backend. Otherwise, guarantees about its output and return type are very limited -- these can differ across invocations and across JAX/jaxlib versions. Some specifics: * Introduce `cost_analysis` and `memory_analysis` methods on `Compiled` that do as their name suggests. * Introduce `as_text` methods on `Lowered` and `Compiled` that do as the name suggests. * Rename `_src.stages.Computation` protocol to `_src.stages.Lowering`. * Fix a handful of type annotations, add various docstrings and comments explaining the above. PiperOrigin-RevId: 458574166
- Loading branch information