Skip to content

Commit

Permalink
refactor stages types, adding methods for text and for cost/memory …
Browse files Browse the repository at this point in the history
…analyses

Re-organizing things this way in order to:

* Clarify internally what a lowering and executable should do, rather than what current XLA-backed versions happen to provide.

* Document that some features (e.g. cost analysis) are best-effort and intended mainly for debugging purposes. They may be unimplemented on some backends and what they return is intentionally undefined.

For an example of the latter item, this change adds a `cost_analysis()` method on `jax.stages.Compiled`. However, the expression `jit(f).lower(*args).compile().cost_analysis()` may return `None` depending on backend. Otherwise, guarantees about its output and return type are very limited -- these can differ across invocations and across JAX/jaxlib versions.

Some specifics:
* Introduce `cost_analysis` and `memory_analysis` methods on `Compiled` that do as their name suggests.
* Introduce `as_text` methods on `Lowered` and `Compiled` that do as the name suggests.
* Rename `_src.stages.Computation` protocol to `_src.stages.Lowering`.
* Fix a handful of type annotations, add various docstrings and comments explaining the above.

PiperOrigin-RevId: 458574166
  • Loading branch information
froystig authored and jax authors committed Jul 2, 2022
1 parent 1fc9afd commit f12af93
Show file tree
Hide file tree
Showing 8 changed files with 480 additions and 84 deletions.
13 changes: 6 additions & 7 deletions jax/_src/dispatch.py
Expand Up @@ -778,7 +778,7 @@ def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers,
else h(None, *device_put(x, device)) for h, x in zip(handlers, outs)]


class XlaComputation(stages.Computation):
class XlaComputation(stages.XlaLowering):
name: str
_is_trivial: bool
_executable: Optional[XlaCompiledComputation]
Expand All @@ -801,6 +801,8 @@ def __init__(self, name: str, hlo, is_trivial: bool,
def is_trivial(self):
return self._is_trivial

# -- stages.XlaLowering overrides

def hlo(self) -> xc.XlaComputation:
if self.is_trivial():
raise ValueError("A trivial computation has no HLO")
Expand Down Expand Up @@ -892,7 +894,7 @@ def compile_or_get_cached(backend, computation, compile_options):
return backend_compile(backend, computation, compile_options)


class XlaCompiledComputation(stages.Executable):
class XlaCompiledComputation(stages.XlaExecutable):
def __init__(self, xla_executable, in_avals, kept_var_idx, unsafe_call,
keepalive: Any):
self._xla_executable = xla_executable
Expand Down Expand Up @@ -963,14 +965,11 @@ def from_trivial_jaxpr(
return XlaCompiledComputation(None, in_avals, kept_var_idx, unsafe_call,
keepalive)

# -- stages.Executable protocol
# -- stages.XlaExecutable overrides

def runtime_executable(self):
def xla_extension_executable(self):
return self.xla_executable

def hlo_modules(self):
return self.xla_executable.hlo_modules()

def call(self, *args):
arg_specs = unsafe_map(arg_spec, args)
arg_avals = [spec[0] for i, spec in enumerate(arg_specs)
Expand Down

0 comments on commit f12af93

Please sign in to comment.