Skip to content

Commit

Permalink
Don't call inspect.signature() each time we trace a jit().
Browse files Browse the repository at this point in the history
We can just call it once when jit itself is called.

While we're here, also don't recompute api_util.fun_sourceinfo.

PiperOrigin-RevId: 607443283
  • Loading branch information
hawkinsp authored and jax authors committed Feb 15, 2024
1 parent 8888006 commit b5e4ba4
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 13 deletions.
16 changes: 13 additions & 3 deletions jax/_src/api.py
Expand Up @@ -41,6 +41,7 @@
tree_map, tree_flatten, tree_unflatten, tree_structure, tree_transpose,
tree_leaves, Partial, PyTreeDef, all_leaves, keystr, broadcast_prefix,
prefix_errors, generate_key_paths)
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import dispatch
Expand All @@ -60,7 +61,8 @@
argnums_partial_except, flatten_axes, donation_vector,
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
shaped_abstractify, _ensure_str_tuple, apply_flat_fun_nokwargs,
check_callable, debug_info, result_paths, flat_out_axes, debug_info_final)
check_callable, debug_info, result_paths, flat_out_axes, debug_info_final,
fun_sourceinfo)
from jax._src.lax import lax as lax_internal
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
Expand Down Expand Up @@ -304,12 +306,16 @@ def jit(
fun, in_shardings, out_shardings, donate_argnums, donate_argnames,
static_argnums, static_argnames, device, backend, abstracted_axes)

fun_sourceinfo = api_util.fun_sourceinfo(fun)
fun_signature = api_util.fun_signature(fun)

def infer_params(*args, **kwargs):
# TODO(yashkatariya): Remove this when it's added on jit.
in_layouts = kwargs.pop('_in_layouts', None)
out_layouts = kwargs.pop('_out_layouts', None)
pjit_info_args = pjit.PjitInfo(
fun=fun, in_shardings=in_shardings,
fun=fun, fun_sourceinfo=fun_sourceinfo, fun_signature=fun_signature,
in_shardings=in_shardings,
out_shardings=out_shardings, static_argnums=static_argnums,
static_argnames=static_argnames, donate_argnums=donate_argnums,
donate_argnames=donate_argnames, device=device, backend=backend,
Expand Down Expand Up @@ -1651,7 +1657,11 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
if in_devices is not None and len(in_devices) == 0:
raise ValueError("'devices' argument to pmap must be non-empty, or None.")

dbg = debug_info('pmap', fun, args, kwargs, static_broadcasted_tuple, ())
src = fun_sourceinfo(fun)
signature = api_util.fun_signature(fun)

dbg = debug_info('pmap', src, signature, args, kwargs,
static_broadcasted_tuple, ())

f = lu.wrap_init(fun)
if static_broadcasted_tuple:
Expand Down
23 changes: 16 additions & 7 deletions jax/_src/api_util.py
Expand Up @@ -612,16 +612,24 @@ def api_hook(fun, tag: str):
return fun


def debug_info(traced_for: str, fun: Callable, args: tuple[Any],
kwargs: dict[str, Any], static_argnums: tuple[int, ...],
static_argnames: tuple[str, ...]) -> TracingDebugInfo | None:
def debug_info(
traced_for: str, src: str | None, fun_signature: inspect.Signature | None,
args: tuple[Any], kwargs: dict[str, Any], static_argnums: tuple[int, ...],
static_argnames: tuple[str, ...]
) -> TracingDebugInfo | None:
"""Try to build trace-time debug info for fun when applied to args/kwargs."""
src = fun_sourceinfo(fun)
arg_names = _arg_names(fun, args, kwargs, static_argnums, static_argnames)
arg_names = _arg_names(fun_signature, args, kwargs, static_argnums,
static_argnames)
if src is None or arg_names is None:
return None
return TracingDebugInfo(traced_for, src, arg_names, None)

def fun_signature(fun: Callable) -> inspect.Signature | None:
try:
return inspect.signature(fun)
except (ValueError, TypeError):
return None

# TODO(mattjj): make this function internal to this module
def fun_sourceinfo(fun: Callable) -> str | None:
while isinstance(fun, partial):
Expand All @@ -634,15 +642,16 @@ def fun_sourceinfo(fun: Callable) -> str | None:
except AttributeError:
return None

def _arg_names(fn, args, kwargs, static_argnums, static_argnames,
def _arg_names(fn_signature, args, kwargs, static_argnums, static_argnames,
) -> tuple[str, ...] | None:
if fn_signature is None: return None
static = object()
static_argnums_ = _ensure_inbounds(True, len(args), static_argnums)
static_argnames_ = set(static_argnames)
args_ = [static if i in static_argnums_ else x for i, x in enumerate(args)]
kwargs = {k:static if k in static_argnames_ else x for k, x in kwargs.items()}
try:
ba = inspect.signature(fn).bind(*args_, **kwargs)
ba = fn_signature.bind(*args_, **kwargs)
except (ValueError, TypeError):
return None
return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items()
Expand Down
18 changes: 15 additions & 3 deletions jax/_src/pjit.py
Expand Up @@ -17,6 +17,7 @@
from collections.abc import Sequence, Iterable
import dataclasses
from functools import partial, lru_cache
import inspect
import itertools as it
import logging
import operator as op
Expand All @@ -28,6 +29,7 @@
import numpy as np

from jax._src import api
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import dispatch
Expand Down Expand Up @@ -384,6 +386,8 @@ def _pjit_explicit_sharding(in_shardings, out_shardings, device,

class PjitInfo(NamedTuple):
fun: Callable
fun_sourceinfo: str | None
fun_signature: inspect.Signature
in_shardings: Any
out_shardings: Any
static_argnums: tuple[int, ...]
Expand All @@ -401,7 +405,8 @@ class PjitInfo(NamedTuple):


def common_infer_params(pjit_info_args, *args, **kwargs):
(fun, user_in_shardings, user_out_shardings, static_argnums, static_argnames,
(fun, fun_sourceinfo, fun_signature, user_in_shardings, user_out_shardings,
static_argnums, static_argnames,
donate_argnums, donate_argnames, device, backend, keep_unused, inline,
resource_env, abstracted_axes, in_layouts, out_layouts) = pjit_info_args

Expand All @@ -424,7 +429,8 @@ def common_infer_params(pjit_info_args, *args, **kwargs):

jit_name = 'jit' if resource_env is None else 'pjit'

dbg = debug_info(jit_name, fun, args, kwargs, static_argnums, static_argnames)
dbg = debug_info(jit_name, fun_sourceinfo, fun_signature, args, kwargs,
static_argnums, static_argnames)
f = lu.wrap_init(fun)
f, res_paths = result_paths(f)
f, dyn_args = argnums_partial_except(f, static_argnums, args, allow_invalid=True)
Expand Down Expand Up @@ -767,6 +773,9 @@ def pjit(
fun, in_shardings, out_shardings, donate_argnums, donate_argnames,
static_argnums, static_argnames, device, backend, abstracted_axes)

fun_sourceinfo = api_util.fun_sourceinfo(fun)
fun_signature = api_util.fun_signature(fun)

def infer_params(*args, **kwargs):
# Putting this outside of wrapped would make resources lexically scoped
resource_env = mesh_lib.thread_resources.env
Expand All @@ -775,7 +784,10 @@ def infer_params(*args, **kwargs):
in_layouts = kwargs.pop('_in_layouts', None)
out_layouts = kwargs.pop('_out_layouts', None)
pjit_info_args = PjitInfo(
fun=fun, in_shardings=in_shardings,
fun=fun,
fun_sourceinfo=fun_sourceinfo,
fun_signature=fun_signature,
in_shardings=in_shardings,
out_shardings=out_shardings, static_argnums=static_argnums,
static_argnames=static_argnames, donate_argnums=donate_argnums,
donate_argnames=donate_argnames, device=device, backend=backend,
Expand Down

0 comments on commit b5e4ba4

Please sign in to comment.