From b024e01440350145ebe03ac042805af77acf295b Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Tue, 8 Aug 2023 06:33:35 -0700
Subject: [PATCH] Improve dispatch.py typing.

Inline _xla_callable_uncached, which is trivial, into its only caller.

Cleanup only, no user-visible changes intended.

PiperOrigin-RevId: 554805210
---
 jax/_src/dispatch.py | 72 ++++++++++++++++++++++----------------------
 1 file changed, 36 insertions(+), 36 deletions(-)

diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index 241f7fdf4cba..ee78dfcdd6ad 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -22,7 +22,7 @@
 from functools import partial
 import itertools
 import time
-from typing import (Any, Callable, Optional, NamedTuple)
+from typing import Any, Callable, NamedTuple
 import logging
 import os
 import re
@@ -73,8 +73,6 @@
 
 traceback_util.register_exclusion(__file__)
 
-MYPY = False  # Are we currently type checking with mypy?
-
 xe = xc._xla
 
 Backend = xe.Client
@@ -92,19 +90,22 @@
 
 ### op-by-op execution
 
-ArgSpec = tuple[core.AbstractValue, Optional[Device]]
+class _ArgSpec(NamedTuple):
+  aval: core.AbstractValue
+  sharding: XLACompatibleSharding | None
+
 
-def arg_spec(x: Any) -> ArgSpec:
+def _arg_spec(x: Any) -> _ArgSpec:
   from jax._src import pjit
 
   aval = xla.abstractify(x)
   try:
     if isinstance(x.sharding, PmapSharding):
-      return aval, None
-    return aval, (pjit.to_gspmd_sharding(x.sharding, x.ndim)  # type: ignore
-                  if x._committed else None)
+      return _ArgSpec(aval, None)
+    return _ArgSpec(aval, (pjit.to_gspmd_sharding(x.sharding, x.ndim)  # type: ignore
+                           if x._committed else None))
   except:
-    return aval, None
+    return _ArgSpec(aval, None)
 
 
 @dataclasses.dataclass(frozen=True)
@@ -126,7 +127,7 @@ def apply_primitive(prim, *args, **params):
   from jax._src import pjit
 
   try:
-    in_avals, in_shardings = util.unzip2([arg_spec(a) for a in args])
+    in_avals, in_shardings = util.unzip2([_arg_spec(a) for a in args])
     compiled_fun = xla_primitive_callable(
         prim, in_avals, OrigShardings(in_shardings), **params)
   except pxla.DeviceAssignmentMismatchError as e:
@@ -209,8 +210,12 @@ def block_until_ready(self):
 def wait_for_tokens():
   runtime_tokens.block_until_ready()
 
+
 @util.cache()
-def xla_primitive_callable(prim, in_avals, orig_in_shardings, **params):
+def xla_primitive_callable(
+    prim: core.Primitive, in_avals: tuple[core.AbstractValue, ...],
+    orig_in_shardings: OrigShardings, **params,
+) -> Callable:
   def prim_fun(*args):
     out = prim.bind(*args, **params)
     if prim.multiple_results:
@@ -218,40 +223,34 @@ def prim_fun(*args):
     else:
       return out,
   donated_invars = (False,) * len(in_avals)
-  compiled = _xla_callable_uncached(
-      lu.wrap_init(prim_fun), prim.name, donated_invars, False, in_avals,
-      orig_in_shardings)
+  computation = sharded_lowering(
+      lu.wrap_init(prim_fun), prim.name, donated_invars, keep_unused=False,
+      inline=True, in_avals=in_avals, in_shardings=orig_in_shardings.shardings,
+      lowering_platform=None)
+  compiled = computation.compile().unsafe_call
   if not prim.multiple_results:
     return lambda *args, **kw: compiled(*args, **kw)[0]
  else:
    return compiled
 
 
-def sharded_lowering(fun, name, donated_invars, keep_unused, inline,
-                     in_avals, in_shardings, lowering_platform: str | None):
-  if isinstance(in_shardings, OrigShardings):
-    in_shardings = in_shardings.shardings
-
-  in_shardings = [UNSPECIFIED if i is None else i for i in in_shardings]  # type: ignore
+def sharded_lowering(
+    fun: lu.WrappedFun, name: str, donated_invars: Sequence[bool],
+    keep_unused: bool, inline: bool, in_avals: tuple[core.AbstractValue, ...],
+    in_shardings: Sequence[Sharding | None], lowering_platform: str | None
+) -> pxla.MeshComputation:
+  in_shardings_unspec = [UNSPECIFIED if i is None else i for i in in_shardings]
 
   # Pass in a singleton `UNSPECIFIED` for out_shardings because we don't know
   # the number of output avals at this stage. lower_sharding_computation will
   # apply it to all out_avals.
   return pxla.lower_sharding_computation(
-      fun, 'jit', name, in_shardings, UNSPECIFIED, donated_invars,
-      tuple(in_avals), keep_unused=keep_unused, inline=inline, always_lower=False,
+      fun, 'jit', name, in_shardings_unspec, UNSPECIFIED, donated_invars,
+      in_avals, keep_unused=keep_unused, inline=inline, always_lower=False,
       devices_from_context=None, lowering_platform=lowering_platform)
 
 
-def _xla_callable_uncached(fun: lu.WrappedFun, name, donated_invars,
-                           keep_unused, in_avals, orig_in_shardings):
-  computation = sharded_lowering(
-      fun, name, donated_invars, keep_unused, True, in_avals, orig_in_shardings,
-      lowering_platform=None)
-  return computation.compile().unsafe_call
-
-
-def is_single_device_sharding(sharding) -> bool:
+def is_single_device_sharding(sharding: Sharding) -> bool:
   # Special case PmapSharding here because PmapSharding maps away an axis
   # and needs to be handled separately.test_pjit_single_device_sharding_add
   return len(sharding.device_set) == 1 and not isinstance(sharding, PmapSharding)
@@ -273,7 +272,7 @@ def log_elapsed_time(fmt: str, fun_name: str, event: str | None = None):
       record_event_duration_secs(event, elapsed_time)
 
 
-def should_tuple_args(num_args: int, platform: str):
+def should_tuple_args(num_args: int, platform: str) -> bool:
   # CPU and GPU do not need tuples as they use host-side data structures that
   # do not have small bounds.
   # TPU only needs a tuple for very long lists
@@ -305,7 +304,7 @@ def raise_warnings_or_errors_for_jit_of_pmap(nreps, backend, name, jaxpr):
         "extra data movement anyway, so maybe you don't want it after all).")
 
 
-def jaxpr_has_primitive(jaxpr, prim_name: str):
+def jaxpr_has_primitive(jaxpr: core.Jaxpr, prim_name: str) -> bool:
   """Whether there is a primitive given by user anywhere inside a Jaxpr."""
   for eqn in jaxpr.eqns:
     if prim_name in eqn.primitive.name:
@@ -322,7 +321,8 @@ class SourceInfo(NamedTuple):
 
 
 def jaxpr_shardings(
-    jaxpr) -> Iterator[tuple[XLACompatibleSharding, SourceInfo]]:
+    jaxpr: core.Jaxpr,
+) -> Iterator[tuple[XLACompatibleSharding, SourceInfo]]:
   from jax._src import pjit
   from jax.experimental import shard_map
 
@@ -393,7 +393,7 @@ def apply_outfeed_rewriter(jaxpr: core.Jaxpr) -> core.Jaxpr:
 
 # TODO(mattjj,necula): this duplicates code in core.valid_jaxtype, but one
 # internal user relies on it for duck-typing. must fix downstream user!
-def _valid_jaxtype(arg):
+def _valid_jaxtype(arg: Any) -> bool:
   try:
     xla.abstractify(arg)  # faster than core.get_aval
   except TypeError:
@@ -401,7 +401,7 @@ def _valid_jaxtype(arg):
     return False
   else:
     return True
-def check_arg(arg):
+def check_arg(arg: Any):
   if not (isinstance(arg, core.Tracer) or _valid_jaxtype(arg)):
     raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid "
                     "JAX type.")
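
Illustrative sketch (not part of the patch above): the diff replaces the bare ArgSpec tuple alias
with the _ArgSpec NamedTuple so that the aval and sharding fields carry names and declared types.
The snippet below shows that pattern in isolation, using hypothetical stand-in types rather than
the real core.AbstractValue and XLACompatibleSharding:

from typing import NamedTuple, Optional

AbstractValue = str  # stand-in for core.AbstractValue (illustration only)
Sharding = str       # stand-in for XLACompatibleSharding (illustration only)

# Before: a bare tuple alias; the two fields are positional and anonymous.
ArgSpec = tuple[AbstractValue, Optional[Sharding]]

# After: a NamedTuple; fields have names and declared types, so constructions
# like _ArgSpec(aval, None) and accesses like spec.sharding are type-checked.
class _ArgSpec(NamedTuple):
  aval: AbstractValue
  sharding: Optional[Sharding]

spec = _ArgSpec("f32[8]", None)
assert spec.aval == "f32[8]" and spec.sharding is None

In the patch itself the fields are core.AbstractValue and XLACompatibleSharding | None, and
call sites such as _arg_spec construct _ArgSpec(aval, ...) instead of returning bare tuples.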