Improve dispatch.py typing.

Inline _xla_callable_uncached, which is trivial, into its only caller.

Cleanup only, no user-visible changes intended.

PiperOrigin-RevId: 554805210
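
The heart of the typing change is replacing a bare tuple alias with a NamedTuple, which gives each field a name and a type that mypy can check while still unpacking like a plain tuple. A minimal standalone sketch of the pattern (placeholder field types, not the actual jax._src classes):

    from typing import NamedTuple, Optional

    # Before: a positional tuple alias. Mypy cannot name the fields, and
    # call sites must remember which slot holds the optional sharding.
    ArgSpec = tuple[str, Optional[str]]

    # After: a NamedTuple with named, typed fields that still behaves as
    # a tuple, so existing unpacking code keeps working unchanged.
    class _ArgSpec(NamedTuple):
      aval: str
      sharding: Optional[str]

    spec = _ArgSpec(aval="f32[3]", sharding=None)
    aval, sharding = spec   # tuple-style unpacking still works
    print(spec.sharding)    # and fields are now accessible by name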
hawkinsp authored and jax authors committed Aug 8, 2023
1 parent dec2366 commit b024e01
Showing 1 changed file with 36 additions and 36 deletions.
jax/_src/dispatch.py: 36 additions & 36 deletions
@@ -22,7 +22,7 @@
 from functools import partial
 import itertools
 import time
-from typing import (Any, Callable, Optional, NamedTuple)
+from typing import Any, Callable, NamedTuple
 import logging
 import os
 import re
@@ -73,8 +73,6 @@
 
 traceback_util.register_exclusion(__file__)
 
-MYPY = False  # Are we currently type checking with mypy?
-
 xe = xc._xla
 
 Backend = xe.Client
@@ -92,19 +90,22 @@
 
 ### op-by-op execution
 
-ArgSpec = tuple[core.AbstractValue, Optional[Device]]
+class _ArgSpec(NamedTuple):
+  aval: core.AbstractValue
+  sharding: XLACompatibleSharding | None
+
 
-def arg_spec(x: Any) -> ArgSpec:
+def _arg_spec(x: Any) -> _ArgSpec:
   from jax._src import pjit
 
   aval = xla.abstractify(x)
   try:
     if isinstance(x.sharding, PmapSharding):
-      return aval, None
-    return aval, (pjit.to_gspmd_sharding(x.sharding, x.ndim)  # type: ignore
-                  if x._committed else None)
+      return _ArgSpec(aval, None)
+    return _ArgSpec(aval, (pjit.to_gspmd_sharding(x.sharding, x.ndim)  # type: ignore
+                           if x._committed else None))
   except:
-    return aval, None
+    return _ArgSpec(aval, None)
 
 
 @dataclasses.dataclass(frozen=True)
@@ -126,7 +127,7 @@ def apply_primitive(prim, *args, **params):
   from jax._src import pjit
 
   try:
-    in_avals, in_shardings = util.unzip2([arg_spec(a) for a in args])
+    in_avals, in_shardings = util.unzip2([_arg_spec(a) for a in args])
     compiled_fun = xla_primitive_callable(
         prim, in_avals, OrigShardings(in_shardings), **params)
   except pxla.DeviceAssignmentMismatchError as e:
@@ -209,49 +210,47 @@ def block_until_ready(self):
 def wait_for_tokens():
   runtime_tokens.block_until_ready()
 
+
 @util.cache()
-def xla_primitive_callable(prim, in_avals, orig_in_shardings, **params):
+def xla_primitive_callable(
+    prim: core.Primitive, in_avals: tuple[core.AbstractValue, ...],
+    orig_in_shardings: OrigShardings, **params,
+) -> Callable:
   def prim_fun(*args):
     out = prim.bind(*args, **params)
     if prim.multiple_results:
       return out
     else:
       return out,
   donated_invars = (False,) * len(in_avals)
-  compiled = _xla_callable_uncached(
-      lu.wrap_init(prim_fun), prim.name, donated_invars, False, in_avals,
-      orig_in_shardings)
+  computation = sharded_lowering(
+      lu.wrap_init(prim_fun), prim.name, donated_invars, keep_unused=False,
+      inline=True, in_avals=in_avals, in_shardings=orig_in_shardings.shardings,
+      lowering_platform=None)
+  compiled = computation.compile().unsafe_call
   if not prim.multiple_results:
     return lambda *args, **kw: compiled(*args, **kw)[0]
   else:
     return compiled
 
 
-def sharded_lowering(fun, name, donated_invars, keep_unused, inline,
-                     in_avals, in_shardings, lowering_platform: str | None):
-  if isinstance(in_shardings, OrigShardings):
-    in_shardings = in_shardings.shardings
-
-  in_shardings = [UNSPECIFIED if i is None else i for i in in_shardings]  # type: ignore
+def sharded_lowering(
+    fun: lu.WrappedFun, name: str, donated_invars: Sequence[bool],
+    keep_unused: bool, inline: bool, in_avals: tuple[core.AbstractValue, ...],
+    in_shardings: Sequence[Sharding | None], lowering_platform: str | None
+) -> pxla.MeshComputation:
+  in_shardings_unspec = [UNSPECIFIED if i is None else i for i in in_shardings]
 
   # Pass in a singleton `UNSPECIFIED` for out_shardings because we don't know
   # the number of output avals at this stage. lower_sharding_computation will
   # apply it to all out_avals.
   return pxla.lower_sharding_computation(
-      fun, 'jit', name, in_shardings, UNSPECIFIED, donated_invars,
-      tuple(in_avals), keep_unused=keep_unused, inline=inline, always_lower=False,
+      fun, 'jit', name, in_shardings_unspec, UNSPECIFIED, donated_invars,
+      in_avals, keep_unused=keep_unused, inline=inline, always_lower=False,
      devices_from_context=None, lowering_platform=lowering_platform)
 
 
-def _xla_callable_uncached(fun: lu.WrappedFun, name, donated_invars,
-                           keep_unused, in_avals, orig_in_shardings):
-  computation = sharded_lowering(
-      fun, name, donated_invars, keep_unused, True, in_avals, orig_in_shardings,
-      lowering_platform=None)
-  return computation.compile().unsafe_call
-
-
-def is_single_device_sharding(sharding) -> bool:
+def is_single_device_sharding(sharding: Sharding) -> bool:
   # Special case PmapSharding here because PmapSharding maps away an axis
   # and needs to be handled separately.
   return len(sharding.device_set) == 1 and not isinstance(sharding, PmapSharding)
@@ -273,7 +272,7 @@ def log_elapsed_time(fmt: str, fun_name: str, event: str | None = None):
     record_event_duration_secs(event, elapsed_time)
 
 
-def should_tuple_args(num_args: int, platform: str):
+def should_tuple_args(num_args: int, platform: str) -> bool:
   # CPU and GPU do not need tuples as they use host-side data structures that
   # do not have small bounds.
   # TPU only needs a tuple for very long lists
@@ -305,7 +304,7 @@ def raise_warnings_or_errors_for_jit_of_pmap(nreps, backend, name, jaxpr):
         "extra data movement anyway, so maybe you don't want it after all).")
 
 
-def jaxpr_has_primitive(jaxpr, prim_name: str):
+def jaxpr_has_primitive(jaxpr: core.Jaxpr, prim_name: str) -> bool:
   """Whether there is a primitive given by user anywhere inside a Jaxpr."""
   for eqn in jaxpr.eqns:
     if prim_name in eqn.primitive.name:
@@ -322,7 +321,8 @@ class SourceInfo(NamedTuple):
 
 
 def jaxpr_shardings(
-    jaxpr) -> Iterator[tuple[XLACompatibleSharding, SourceInfo]]:
+    jaxpr: core.Jaxpr,
+) -> Iterator[tuple[XLACompatibleSharding, SourceInfo]]:
   from jax._src import pjit
   from jax.experimental import shard_map
 
@@ -393,15 +393,15 @@ def apply_outfeed_rewriter(jaxpr: core.Jaxpr) -> core.Jaxpr:
 
 # TODO(mattjj,necula): this duplicates code in core.valid_jaxtype, but one
 # internal user relies on it for duck-typing. must fix downstream user!
-def _valid_jaxtype(arg):
+def _valid_jaxtype(arg: Any) -> bool:
   try:
     xla.abstractify(arg)  # faster than core.get_aval
   except TypeError:
     return core.valid_jaxtype(arg)
   else:
     return True
 
-def check_arg(arg):
+def check_arg(arg: Any):
   if not (isinstance(arg, core.Tracer) or _valid_jaxtype(arg)):
     raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid "
                     "JAX type.")
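Since a NamedTuple iterates exactly like an ordinary tuple, the unzip2 call in apply_primitive needed only the rename from arg_spec to _arg_spec. A quick standalone check of that property (unzip2 is re-implemented here for illustration; the real helper lives in jax._src.util):

    from typing import NamedTuple, Optional

    class _ArgSpec(NamedTuple):
      aval: str
      sharding: Optional[str]

    def unzip2(pairs):
      # Split an iterable of pairs into a pair of tuples.
      xs, ys = [], []
      for x, y in pairs:
        xs.append(x)
        ys.append(y)
      return tuple(xs), tuple(ys)

    specs = [_ArgSpec("f32[3]", None), _ArgSpec("f32[4]", "sharded")]
    avals, shardings = unzip2(specs)
    print(avals)      # ('f32[3]', 'f32[4]')
    print(shardings)  # (None, 'sharded')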
