From b5e4ba4900aa0c823c623f7a7855f546b508b0cb Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 15 Feb 2024 13:48:49 -0800 Subject: [PATCH] Don't call inspect.signature() each time we trace a jit(). We can just call it once when jit itself is called. While we're here, also don't recompute api_util.fun_sourceinfo. PiperOrigin-RevId: 607443283 --- jax/_src/api.py | 16 +++++++++++++--- jax/_src/api_util.py | 23 ++++++++++++++++------- jax/_src/pjit.py | 18 +++++++++++++++--- 3 files changed, 44 insertions(+), 13 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 7af35361d4ab..c3c3b94381b0 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -41,6 +41,7 @@ tree_map, tree_flatten, tree_unflatten, tree_structure, tree_transpose, tree_leaves, Partial, PyTreeDef, all_leaves, keystr, broadcast_prefix, prefix_errors, generate_key_paths) +from jax._src import api_util from jax._src import config from jax._src import core from jax._src import dispatch @@ -60,7 +61,8 @@ argnums_partial_except, flatten_axes, donation_vector, rebase_donate_argnums, _ensure_index, _ensure_index_tuple, shaped_abstractify, _ensure_str_tuple, apply_flat_fun_nokwargs, - check_callable, debug_info, result_paths, flat_out_axes, debug_info_final) + check_callable, debug_info, result_paths, flat_out_axes, debug_info_final, + fun_sourceinfo) from jax._src.lax import lax as lax_internal from jax._src.lib import jax_jit from jax._src.lib import xla_client as xc @@ -304,12 +306,16 @@ def jit( fun, in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums, static_argnames, device, backend, abstracted_axes) + fun_sourceinfo = api_util.fun_sourceinfo(fun) + fun_signature = api_util.fun_signature(fun) + def infer_params(*args, **kwargs): # TODO(yashkatariya): Remove this when it's added on jit. in_layouts = kwargs.pop('_in_layouts', None) out_layouts = kwargs.pop('_out_layouts', None) pjit_info_args = pjit.PjitInfo( - fun=fun, in_shardings=in_shardings, + fun=fun, fun_sourceinfo=fun_sourceinfo, fun_signature=fun_signature, + in_shardings=in_shardings, out_shardings=out_shardings, static_argnums=static_argnums, static_argnames=static_argnames, donate_argnums=donate_argnums, donate_argnames=donate_argnames, device=device, backend=backend, @@ -1651,7 +1657,11 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple, if in_devices is not None and len(in_devices) == 0: raise ValueError("'devices' argument to pmap must be non-empty, or None.") - dbg = debug_info('pmap', fun, args, kwargs, static_broadcasted_tuple, ()) + src = fun_sourceinfo(fun) + signature = api_util.fun_signature(fun) + + dbg = debug_info('pmap', src, signature, args, kwargs, + static_broadcasted_tuple, ()) f = lu.wrap_init(fun) if static_broadcasted_tuple: diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index cf09e27ed0d1..2f3f6866bca0 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -612,16 +612,24 @@ def api_hook(fun, tag: str): return fun -def debug_info(traced_for: str, fun: Callable, args: tuple[Any], - kwargs: dict[str, Any], static_argnums: tuple[int, ...], - static_argnames: tuple[str, ...]) -> TracingDebugInfo | None: +def debug_info( + traced_for: str, src: str | None, fun_signature: inspect.Signature | None, + args: tuple[Any], kwargs: dict[str, Any], static_argnums: tuple[int, ...], + static_argnames: tuple[str, ...] +) -> TracingDebugInfo | None: """Try to build trace-time debug info for fun when applied to args/kwargs.""" - src = fun_sourceinfo(fun) - arg_names = _arg_names(fun, args, kwargs, static_argnums, static_argnames) + arg_names = _arg_names(fun_signature, args, kwargs, static_argnums, + static_argnames) if src is None or arg_names is None: return None return TracingDebugInfo(traced_for, src, arg_names, None) +def fun_signature(fun: Callable) -> inspect.Signature | None: + try: + return inspect.signature(fun) + except (ValueError, TypeError): + return None + # TODO(mattjj): make this function internal to this module def fun_sourceinfo(fun: Callable) -> str | None: while isinstance(fun, partial): @@ -634,15 +642,16 @@ def fun_sourceinfo(fun: Callable) -> str | None: except AttributeError: return None -def _arg_names(fn, args, kwargs, static_argnums, static_argnames, +def _arg_names(fn_signature, args, kwargs, static_argnums, static_argnames, ) -> tuple[str, ...] | None: + if fn_signature is None: return None static = object() static_argnums_ = _ensure_inbounds(True, len(args), static_argnums) static_argnames_ = set(static_argnames) args_ = [static if i in static_argnums_ else x for i, x in enumerate(args)] kwargs = {k:static if k in static_argnames_ else x for k, x in kwargs.items()} try: - ba = inspect.signature(fn).bind(*args_, **kwargs) + ba = fn_signature.bind(*args_, **kwargs) except (ValueError, TypeError): return None return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items() diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 557b9e87607a..0f45f5dd665d 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -17,6 +17,7 @@ from collections.abc import Sequence, Iterable import dataclasses from functools import partial, lru_cache +import inspect import itertools as it import logging import operator as op @@ -28,6 +29,7 @@ import numpy as np from jax._src import api +from jax._src import api_util from jax._src import config from jax._src import core from jax._src import dispatch @@ -384,6 +386,8 @@ def _pjit_explicit_sharding(in_shardings, out_shardings, device, class PjitInfo(NamedTuple): fun: Callable + fun_sourceinfo: str | None + fun_signature: inspect.Signature in_shardings: Any out_shardings: Any static_argnums: tuple[int, ...] @@ -401,7 +405,8 @@ class PjitInfo(NamedTuple): def common_infer_params(pjit_info_args, *args, **kwargs): - (fun, user_in_shardings, user_out_shardings, static_argnums, static_argnames, + (fun, fun_sourceinfo, fun_signature, user_in_shardings, user_out_shardings, + static_argnums, static_argnames, donate_argnums, donate_argnames, device, backend, keep_unused, inline, resource_env, abstracted_axes, in_layouts, out_layouts) = pjit_info_args @@ -424,7 +429,8 @@ def common_infer_params(pjit_info_args, *args, **kwargs): jit_name = 'jit' if resource_env is None else 'pjit' - dbg = debug_info(jit_name, fun, args, kwargs, static_argnums, static_argnames) + dbg = debug_info(jit_name, fun_sourceinfo, fun_signature, args, kwargs, + static_argnums, static_argnames) f = lu.wrap_init(fun) f, res_paths = result_paths(f) f, dyn_args = argnums_partial_except(f, static_argnums, args, allow_invalid=True) @@ -767,6 +773,9 @@ def pjit( fun, in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums, static_argnames, device, backend, abstracted_axes) + fun_sourceinfo = api_util.fun_sourceinfo(fun) + fun_signature = api_util.fun_signature(fun) + def infer_params(*args, **kwargs): # Putting this outside of wrapped would make resources lexically scoped resource_env = mesh_lib.thread_resources.env @@ -775,7 +784,10 @@ def infer_params(*args, **kwargs): in_layouts = kwargs.pop('_in_layouts', None) out_layouts = kwargs.pop('_out_layouts', None) pjit_info_args = PjitInfo( - fun=fun, in_shardings=in_shardings, + fun=fun, + fun_sourceinfo=fun_sourceinfo, + fun_signature=fun_signature, + in_shardings=in_shardings, out_shardings=out_shardings, static_argnums=static_argnums, static_argnames=static_argnames, donate_argnums=donate_argnums, donate_argnames=donate_argnames, device=device, backend=backend,