Skip to content

Commit

Permalink
Move some utilities out of dispatch.py next to their users, add more …
Browse files Browse the repository at this point in the history
…types.

Internal cleanups only, no user-visible changes intended.

PiperOrigin-RevId: 554876522
  • Loading branch information
hawkinsp authored and jax authors committed Aug 8, 2023
1 parent afd56c1 commit e58f1ba
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 56 deletions.
2 changes: 1 addition & 1 deletion jax/_src/api.py
Expand Up @@ -117,7 +117,7 @@ def _nan_check_posthook(fun, args, kwargs, output):
buffers.extend(leaf.device_buffers)

try:
dispatch.check_special(pjit.pjit_p, buffers)
dispatch.check_special(pjit.pjit_p.name, buffers)
except FloatingPointError:
# compiled_fun can only raise in this case
assert config.jax_debug_nans or config.jax_debug_infs
Expand Down
79 changes: 29 additions & 50 deletions jax/_src/dispatch.py
Expand Up @@ -31,6 +31,7 @@

import numpy as np

from jax._src import basearray
from jax._src import compilation_cache
from jax._src import config as jax_config
from jax._src import core
Expand All @@ -47,7 +48,6 @@
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.interpreters import pxla
from jax._src.lib.mlir import ir
Expand Down Expand Up @@ -281,29 +281,6 @@ def should_tuple_args(num_args: int, platform: str) -> bool:
else:
return False


def raise_warnings_or_errors_for_jit_of_pmap(nreps, backend, name, jaxpr):
if nreps > 1:
warnings.warn(
f"The jitted function {name} includes a pmap. Using "
"jit-of-pmap can lead to inefficient data movement, as the outer jit "
"does not preserve sharded data representations and instead collects "
"input and output arrays onto a single device. "
"Consider removing the outer jit unless you know what you're doing. "
"See https://github.com/google/jax/issues/2926.")

if nreps > xb.device_count(backend):
raise ValueError(
f"compiling computation `{name}` that requires {nreps} replicas, but "
f"only {xb.device_count(backend)} XLA devices are available.")

if xb.process_count() > 1 and (nreps > 1 or
jaxpr_has_primitive(jaxpr, "xla_pmap")):
raise NotImplementedError(
"jit of multi-host pmap not implemented (and jit-of-pmap can cause "
"extra data movement anyway, so maybe you don't want it after all).")


def jaxpr_has_primitive(jaxpr: core.Jaxpr, prim_name: str) -> bool:
"""Whether there is a primitive given by user anywhere inside a Jaxpr."""
for eqn in jaxpr.eqns:
Expand Down Expand Up @@ -371,14 +348,6 @@ def _is_bint_axis_size(d: core.AxisSize) -> bool:
type(d.aval.dtype) is core.bint)
return False

def _prune_unused_inputs(
jaxpr: core.Jaxpr) -> tuple[core.Jaxpr, set[int], set[int]]:
used_outputs = [True] * len(jaxpr.outvars)
new_jaxpr, used_consts, used_inputs = pe.dce_jaxpr_consts(jaxpr, used_outputs)
kept_const_idx = {i for i, b in enumerate(used_consts) if b}
kept_var_idx = {i for i, b in enumerate(used_inputs) if b}
return new_jaxpr, kept_const_idx, kept_var_idx


# We can optionally set a Jaxpr rewriter that can be applied just before
# compilation. This mechanism is used for compiling id_tap, we can
Expand Down Expand Up @@ -407,40 +376,38 @@ def check_arg(arg: Any):
"JAX type.")


def jaxpr_replicas(jaxpr) -> int:
def jaxpr_replicas(jaxpr: core.Jaxpr) -> int:
"""The number of replicas needed for a jaxpr.
For an eqn, multiply the `axis_size` by the `jaxpr_replicas` of the
subjaxprs. For a list of eqns, take the maximum number of replicas.
"""
if isinstance(jaxpr, core.ClosedJaxpr):
jaxpr = jaxpr.jaxpr
return max(unsafe_map(eqn_replicas, jaxpr.eqns), default=1)
return max(unsafe_map(_eqn_replicas, jaxpr.eqns), default=1)

# TODO(mattjj): this function assumes that only pmap has a parameter named
# axis_size, and that it corresponds to cross-replica mapping
def eqn_replicas(eqn):
def _eqn_replicas(eqn: core.JaxprEqn) -> int:
call_jaxpr = eqn.params.get("call_jaxpr")
if call_jaxpr:
return eqn.params.get('axis_size', 1) * jaxpr_replicas(call_jaxpr)
elif eqn.primitive in xla.initial_style_primitives:
return initial_style_primitive_replicas(eqn.params)
return _initial_style_primitive_replicas(eqn.params)
else:
return 1

def initial_style_primitive_replicas(params):
return max(core.traverse_jaxpr_params(jaxpr_replicas, params).values(), default=1)
def _initial_style_primitive_replicas(params: dict[str, Any]) -> int:
return max(core.traverse_jaxpr_params(jaxpr_replicas, params).values(),
default=1)


def needs_check_special():
def needs_check_special() -> bool:
return config.jax_debug_infs or config.jax_debug_nans

def check_special(name, bufs):
def check_special(name: str, bufs: Sequence[basearray.Array]) -> None:
if needs_check_special():
for buf in bufs:
_check_special(name, buf.dtype, buf)

def _check_special(name, dtype, buf):
def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None:
if dtypes.issubdtype(dtype, np.inexact):
if config.jax_debug_nans and np.any(np.isnan(np.asarray(buf))):
raise FloatingPointError(f"invalid value (nan) encountered in {name}")
Expand All @@ -449,7 +416,12 @@ def _check_special(name, dtype, buf):


@profiler.annotate_function
def backend_compile(backend, module: ir.Module, options, host_callbacks):
def backend_compile(
backend: Backend,
module: ir.Module,
options: xc.CompileOptions,
host_callbacks: Sequence[Any],
) -> xc.LoadedExecutable:
# Convert ir.Module to a string representation, unless the
# back-end explicitly flags the ability to handle a module directly
# (avoiding the overhead of back and forth conversions)
Expand All @@ -468,6 +440,7 @@ def backend_compile(backend, module: ir.Module, options, host_callbacks):
# to take in `host_callbacks`
return backend.compile(built_c, compile_options=options)


_ir_dump_counter = itertools.count()

def _make_string_safe_for_filename(s: str) -> str:
Expand All @@ -480,8 +453,13 @@ def _dump_ir_to_file(name: str, ir: str):
name.write_text(ir)


def compile_or_get_cached(backend, computation: ir.Module, devices: np.ndarray,
compile_options, host_callbacks):
def compile_or_get_cached(
backend: Backend,
computation: ir.Module,
devices: np.ndarray,
compile_options: xc.CompileOptions,
host_callbacks: Sequence[Any],
) -> xc.LoadedExecutable:
sym_name = computation.operation.attributes['sym_name']
module_name = ir.StringAttr(sym_name).value

Expand Down Expand Up @@ -522,7 +500,8 @@ def compile_or_get_cached(backend, computation: ir.Module, devices: np.ndarray,


def _cache_read(
module_name: str, cache_key: str, compile_options, backend
module_name: str, cache_key: str, compile_options: xc.CompileOptions,
backend: Backend
) -> tuple[xc.LoadedExecutable | None, int | None]:
"""Looks up the `computation` and its compilation time in the persistent
compilation cache repository.
Expand All @@ -543,7 +522,7 @@ def _cache_write(cache_key: str,
compile_time_secs: float,
module_name: str,
backend: Backend, executable: xc.LoadedExecutable,
host_callbacks: list[Any]):
host_callbacks: Sequence[Any]) -> None:
"""Writes the `serialized_computation` and its compilation time to the
persistent compilation cache repository.
"""
Expand Down Expand Up @@ -580,7 +559,7 @@ def _cache_write(cache_key: str,

# TODO(yashkatariya): Generalize is_compatible_aval (maybe renamed) and use that
# to check if shardings are compatible with the input.
def _check_sharding(aval, s):
def _check_sharding(aval: core.AbstractValue, s: Sharding):
from jax._src import pjit

if isinstance(s, XLACompatibleSharding) and not isinstance(s, PmapSharding):
Expand Down
40 changes: 37 additions & 3 deletions jax/_src/interpreters/pxla.py
Expand Up @@ -25,6 +25,7 @@
import logging
import math
from typing import (Any, Callable, NamedTuple, Optional, Union, cast, TypeVar)
import warnings

import numpy as np

Expand Down Expand Up @@ -1643,6 +1644,16 @@ def wrapped(f, *args, **kwargs):
return wrapped


def prune_unused_inputs(
jaxpr: core.Jaxpr,
) -> tuple[core.Jaxpr, set[int], set[int]]:
used_outputs = [True] * len(jaxpr.outvars)
new_jaxpr, used_consts, used_inputs = pe.dce_jaxpr_consts(jaxpr, used_outputs)
kept_const_idx = {i for i, b in enumerate(used_consts) if b}
kept_var_idx = {i for i, b in enumerate(used_inputs) if b}
return new_jaxpr, kept_const_idx, kept_var_idx


@cache_wrap
def _trace_to_jaxpr_and_dce(fun_or_jaxpr, global_in_avals, api_name, fun_name,
keep_unused, donated_invars, auto_spmd_lowering):
Expand All @@ -1665,7 +1676,7 @@ def _trace_to_jaxpr_and_dce(fun_or_jaxpr, global_in_avals, api_name, fun_name,
for a in global_in_avals)):
kept_var_idx = set(range(len(global_in_avals)))
else:
jaxpr, kept_const_idx, kept_var_idx = dispatch._prune_unused_inputs(jaxpr)
jaxpr, kept_const_idx, kept_var_idx = prune_unused_inputs(jaxpr)
consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
global_in_avals = tuple(a for i, a in enumerate(global_in_avals) if i in kept_var_idx)
donated_invars = tuple(x for i, x in enumerate(donated_invars) if i in kept_var_idx)
Expand Down Expand Up @@ -1699,6 +1710,30 @@ def __eq__(self, other):
)


def _raise_warnings_or_errors_for_jit_of_pmap(
nreps: int, backend: xc.Client, name: str, jaxpr: core.Jaxpr) -> None:
if nreps > 1:
warnings.warn(
f"The jitted function {name} includes a pmap. Using "
"jit-of-pmap can lead to inefficient data movement, as the outer jit "
"does not preserve sharded data representations and instead collects "
"input and output arrays onto a single device. "
"Consider removing the outer jit unless you know what you're doing. "
"See https://github.com/google/jax/issues/2926.")

if nreps > xb.device_count(backend):
raise ValueError(
f"compiling computation `{name}` that requires {nreps} replicas, but "
f"only {xb.device_count(backend)} XLA devices are available.")

if xb.process_count() > 1 and (
nreps > 1 or dispatch.jaxpr_has_primitive(jaxpr, "xla_pmap")
):
raise NotImplementedError(
"jit of multi-host pmap not implemented (and jit-of-pmap can cause "
"extra data movement anyway, so maybe you don't want it after all).")


@weakref_lru_cache
def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
semantic_in_shardings, semantic_out_shardings,
Expand All @@ -1724,8 +1759,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
# `jax.Array` is turned on by default.
# TODO(yashkatariya): Remove this when `jit(pmap)` is removed.
nreps = dispatch.jaxpr_replicas(jaxpr)
dispatch.raise_warnings_or_errors_for_jit_of_pmap(
nreps, backend, fun_name, jaxpr)
_raise_warnings_or_errors_for_jit_of_pmap(nreps, backend, fun_name, jaxpr)

in_mlir_shardings: list[sharding_impls.XLACompatibleSharding | None] | None
out_mlir_shardings: list[sharding_impls.XLACompatibleSharding | None] | None
Expand Down
4 changes: 2 additions & 2 deletions tests/xla_interpreter_test.py
Expand Up @@ -16,7 +16,7 @@

import jax
from jax._src import test_util as jtu
from jax._src import dispatch
from jax._src.interpreters import pxla


class XlaInterpreterTest(jtu.JaxTestCase):
Expand All @@ -26,7 +26,7 @@ def f(*args):
return args[0]

closed_jaxpr = jax.make_jaxpr(f)(*range(10))
pruned_jaxpr, kept_const_idx, kept_var_idx = dispatch._prune_unused_inputs(
pruned_jaxpr, kept_const_idx, kept_var_idx = pxla.prune_unused_inputs(
closed_jaxpr.jaxpr)
assert len(pruned_jaxpr.invars) == 1
assert kept_const_idx == set()
Expand Down

0 comments on commit e58f1ba

Please sign in to comment.