Skip to content

Commit

Permalink
feat: validate jit args
Browse files Browse the repository at this point in the history
  • Loading branch information
JeppeKlitgaard committed May 18, 2022
1 parent 6110be4 commit 838a053
Show file tree
Hide file tree
Showing 8 changed files with 378 additions and 175 deletions.
192 changes: 103 additions & 89 deletions jax/_src/api.py
Expand Up @@ -32,6 +32,7 @@
from typing import (Any, Callable, Iterable, NamedTuple, Mapping, Optional,
Sequence, Tuple, TypeVar, Union, overload, Dict, Hashable,
List)
from typing_extensions import Literal
from warnings import warn

import numpy as np
Expand All @@ -56,7 +57,7 @@
flatten_fun, apply_flat_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2,
argnums_partial, argnums_partial_except, flatten_axes, donation_vector,
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
shaped_abstractify, _ensure_str_tuple, argnames_partial_except)
shaped_abstractify, _ensure_str_tuple, argnames_partial_except, validate_argnames, validate_argnums)
from jax._src.lax import lax as lax_internal
from jax._src.lib import jax_jit
from jax._src.lib import xla_bridge as xb
Expand Down Expand Up @@ -178,6 +179,7 @@ def _check_callable(fun):
raise TypeError(f"Expected a function, got a generator function: {fun}")

def _isgeneratorfunction(fun):
# TODO 3.9+: remove
# re-implemented here because of https://bugs.python.org/issue33261
while inspect.ismethod(fun):
fun = fun.__func__
Expand All @@ -188,40 +190,35 @@ def _isgeneratorfunction(fun):
_POSITIONAL_OR_KEYWORD = inspect.Parameter.POSITIONAL_OR_KEYWORD

def _infer_argnums_and_argnames(
fun: Callable,
sig: inspect.Signature,
argnums: Union[int, Iterable[int], None],
argnames: Union[str, Iterable[str], None],
) -> Tuple[Tuple[int, ...], Tuple[str, ...]]:
"""Infer missing argnums and argnames for a function with inspect."""
if argnums is None and argnames is None:
argnums = ()
argnames = ()
elif argnums is not None and argnames is not None:
return (), ()

if argnums is not None and argnames is not None:
argnums = _ensure_index_tuple(argnums)
argnames = _ensure_str_tuple(argnames)

return argnums, argnames

parameters = sig.parameters
if argnums is None:
assert argnames is not None
argnames = _ensure_str_tuple(argnames)
argnums = tuple(
i for i, (k, param) in enumerate(parameters.items())
if param.kind == _POSITIONAL_OR_KEYWORD and k in argnames
)
else:
try:
signature = inspect.signature(fun)
except ValueError:
# In rare cases, inspect can fail, e.g., on some builtin Python functions.
# In these cases, don't infer any parameters.
parameters: Mapping[str, inspect.Parameter] = {}
else:
parameters = signature.parameters
if argnums is None:
assert argnames is not None
argnames = _ensure_str_tuple(argnames)
argnums = tuple(
i for i, (k, param) in enumerate(parameters.items())
if param.kind == _POSITIONAL_OR_KEYWORD and k in argnames
)
else:
assert argnames is None
argnums = _ensure_index_tuple(argnums)
argnames = tuple(
k for i, (k, param) in enumerate(parameters.items())
if param.kind == _POSITIONAL_OR_KEYWORD and i in argnums
)
argnums = _ensure_index_tuple(argnums)
argnames = tuple(
k for i, (k, param) in enumerate(parameters.items())
if param.kind == _POSITIONAL_OR_KEYWORD and i in argnums
)

return argnums, argnames


Expand Down Expand Up @@ -332,15 +329,63 @@ def jit(
DeviceArray([ 0, 1, 256, 6561], dtype=int32)
"""
if FLAGS.experimental_cpp_jit and not config.jax_dynamic_shapes:
return _cpp_jit(fun, static_argnums, static_argnames, device, backend,
return _jit(True, fun, static_argnums, static_argnames, device, backend,
donate_argnums, inline, keep_unused)
return _jit(False, fun, static_argnums, static_argnames, device, backend,
donate_argnums, inline, keep_unused, abstracted_axes)

def _jit(
use_cpp_jit: bool,
fun: Callable,
static_argnums: Union[int, Iterable[int], None] = None,
static_argnames: Union[str, Iterable[str], None] = None,
device: Optional[xc.Device] = None,
backend: Optional[str] = None,
donate_argnums: Union[int, Iterable[int]] = (),
inline: bool = False,
keep_unused: bool = False,
abstracted_axes: Optional[Any] = None,
) -> stages.Wrapped:
# Implemements common logic between CPP and Python backends
_check_callable(fun)

# Coerce input
donate_argnums = _ensure_index_tuple(donate_argnums)

try:
sig = inspect.signature(fun)
except ValueError:
# Some built-in functions don't support signature.
# See: https://github.com/python/cpython/issues/73485
# In this case no validation is done
static_argnums = () if static_argnums is None else _ensure_index_tuple(static_argnums)
static_argnames = () if static_argnames is None else _ensure_str_tuple(static_argnames)
else:
return _python_jit(fun, static_argnums, static_argnames, device, backend,
donate_argnums, inline, keep_unused, abstracted_axes)
# Infer argnums and argnames according to docstring
static_argnums, static_argnames = _infer_argnums_and_argnames(
sig, static_argnums, static_argnames)

# Validation
validate_argnums(sig, static_argnums, "static_argnums")
validate_argnums(sig, donate_argnums, "donate_argnums")

validate_argnames(sig, static_argnames, "static_argnames")

# Compensate for static argnums absorbing args
donate_argnums = rebase_donate_argnums(donate_argnums, static_argnums)

if use_cpp_jit:
return _cpp_jit(fun, static_argnums=static_argnums, static_argnames=static_argnames,
device=device, backend=backend,
donate_argnums=donate_argnums, inline=inline, keep_unused=keep_unused)

return _python_jit(fun, static_argnums=static_argnums, static_argnames=static_argnames,
device=device, backend=backend, donate_argnums=donate_argnums,
inline=inline, keep_unused=keep_unused, abstracted_axes=abstracted_axes)

def _prepare_jit(fun, static_argnums, static_argnames, donate_argnums,
args, kwargs):
# Validate donate_argnums
if max(donate_argnums, default=-1) >= len(args):
raise ValueError(
f"jitted function has donate_argnums={donate_argnums} but "
Expand All @@ -362,22 +407,16 @@ def _prepare_jit(fun, static_argnums, static_argnames, donate_argnums,

def _python_jit(
fun: Callable,
static_argnums: Union[int, Iterable[int], None] = None,
static_argnames: Union[str, Iterable[str], None] = None,
device: Optional[xc.Device] = None,
backend: Optional[str] = None,
donate_argnums: Union[int, Iterable[int]] = (),
inline: bool = False,
keep_unused: bool = False,
abstracted_axes: Optional[PytreeOfAbstractedAxesSpec] = None,
*,
static_argnums: Tuple[int, ...],
static_argnames: Tuple[str, ...],
device: Optional[xc.Device],
backend: Optional[str],
donate_argnums: Tuple[int, ...],
inline: bool,
keep_unused: bool,
abstracted_axes: Optional[PytreeOfAbstractedAxesSpec],
) -> stages.Wrapped:
_check_callable(fun)
static_argnums, static_argnames = _infer_argnums_and_argnames(
fun, static_argnums, static_argnames)
static_argnums = _ensure_index_tuple(static_argnums)
donate_argnums = _ensure_index_tuple(donate_argnums)
donate_argnums = rebase_donate_argnums(donate_argnums, static_argnums)

@wraps(fun)
@api_boundary
def f_jitted(*args, **kwargs):
Expand Down Expand Up @@ -429,13 +468,14 @@ class _FastpathData(NamedTuple):

def _cpp_jit(
fun: Callable,
static_argnums: Union[int, Iterable[int], None] = None,
static_argnames: Union[str, Iterable[str], None] = None,
device: Optional[xc.Device] = None,
backend: Optional[str] = None,
donate_argnums: Union[int, Iterable[int]] = (),
inline: bool = False,
keep_unused: bool = False,
*,
static_argnums: Tuple[int, ...],
static_argnames: Tuple[str, ...],
device: Optional[xc.Device],
backend: Optional[str],
donate_argnums: Tuple[int, ...],
inline: bool,
keep_unused: bool,
) -> stages.Wrapped:
# An implementation of `jit` that tries to do as much as possible in C++.
# The goal of this function is to speed up the time it takes to process the
Expand All @@ -444,13 +484,6 @@ def _cpp_jit(
# As long as it does not support all features of the Python implementation
# the C++ code will fallback to `_python_jit` when it faces some unsupported
# feature.
_check_callable(fun)
static_argnums, static_argnames = _infer_argnums_and_argnames(
fun, static_argnums, static_argnames)
static_argnums = _ensure_index_tuple(static_argnums)
donate_argnums = _ensure_index_tuple(donate_argnums)
donate_argnums = rebase_donate_argnums(donate_argnums, static_argnums)

if device is not None and backend is not None:
raise ValueError("can't specify both a device and a backend for jit, "
f"got device={device} and backend={backend}.")
Expand Down Expand Up @@ -2372,37 +2405,18 @@ def _vjp_pullback_wrapper(cotangent_dtypes, cotangent_shapes,
ans = fun(*args)
return tree_unflatten(out_tree, ans)

@overload
def vjp(fun: Callable[..., T],
*primals: Any,
has_aux: Literal[False] = False,
reduce_axes: Sequence[AxisName] = ()) -> Tuple[T, Callable]:
...

if sys.version_info >= (3, 8):
from typing import Literal

@overload # type: ignore
def vjp(fun: Callable[..., T],
*primals: Any,
has_aux: Literal[False] = False,
reduce_axes: Sequence[AxisName] = ()) -> Tuple[T, Callable]:
...

@overload
def vjp(fun: Callable[..., Tuple[T, U]], *primals: Any,
has_aux: Literal[True],
reduce_axes: Sequence[AxisName] = ()) -> Tuple[T, Callable, U]:
...
else:

@overload # type: ignore
def vjp(fun: Callable[..., T], *primals: Any) -> Tuple[T, Callable]:
...

@overload
def vjp(
fun: Callable[..., Any], *primals: Any,
has_aux: bool,
reduce_axes: Sequence[AxisName] = ()
) -> Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]]:
...


@overload
def vjp(fun: Callable[..., Tuple[T, U]], *primals: Any,
has_aux: Literal[True],
reduce_axes: Sequence[AxisName] = ()) -> Tuple[T, Callable, U]:
...
def vjp( # type: ignore
fun: Callable, *primals, has_aux: bool = False, reduce_axes=()
) -> Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]]:
Expand Down

0 comments on commit 838a053

Please sign in to comment.