Skip to content

Commit

Permalink
DOC: ensure that _wraps() generates correct links to wrapped functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jul 21, 2022
1 parent 07fcf79 commit 9769a0a
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 77 deletions.
16 changes: 8 additions & 8 deletions jax/_src/numpy/lax_numpy.py
Expand Up @@ -302,12 +302,12 @@ def load(*args, **kwargs):

### implementations of numpy functions in terms of lax

@_wraps(np.fmin)
@_wraps(np.fmin, module='numpy')
@jit
def fmin(x1, x2):
return where((x1 < x2) | isnan(x2), x1, x2)

@_wraps(np.fmax)
@_wraps(np.fmax, module='numpy')
@jit
def fmax(x1, x2):
return where((x1 > x2) | isnan(x2), x1, x2)
Expand Down Expand Up @@ -346,7 +346,7 @@ def trapz(y, x=None, dx=1.0, axis: int = -1):
return 0.5 * (dx * (y[..., 1:] + y[..., :-1])).sum(-1)


@_wraps(np.trunc)
@_wraps(np.trunc, module='numpy')
@jit
def trunc(x):
_check_arraylike('trunc', x)
Expand Down Expand Up @@ -2760,7 +2760,7 @@ def dot(a, b, *, precision=None): # pylint: disable=missing-docstring
return lax.dot_general(a, b, (contract_dims, batch_dims), precision)


@_wraps(np.matmul, lax_description=_PRECISION_DOC)
@_wraps(np.matmul, module='numpy', lax_description=_PRECISION_DOC)
@partial(jit, static_argnames=('precision',), inline=True)
def matmul(a, b, *, precision=None): # pylint: disable=missing-docstring
_check_arraylike("matmul", a, b)
Expand Down Expand Up @@ -4087,7 +4087,7 @@ def _gcd_body_fn(xs):
where(x2 != 0, lax.rem(x1, x2), _lax_const(x2, 0)))
return (where(x1 < x2, x2, x1), where(x1 < x2, x1, x2))

@_wraps(np.gcd)
@_wraps(np.gcd, module='numpy')
@jit
def gcd(x1, x2):
_check_arraylike("gcd", x1, x2)
Expand All @@ -4100,7 +4100,7 @@ def gcd(x1, x2):
return gcd


@_wraps(np.lcm)
@_wraps(np.lcm, module='numpy')
@jit
def lcm(x1, x2):
_check_arraylike("lcm", x1, x2)
Expand Down Expand Up @@ -4580,8 +4580,8 @@ def _notimplemented_flat(self):
*** This function is not yet implemented by jax.numpy, and will raise NotImplementedError ***
"""

def _not_implemented(fun):
@_wraps(fun, update_doc=False, lax_description=_NOT_IMPLEMENTED_DESC)
def _not_implemented(fun, module=None):
@_wraps(fun, module=module, update_doc=False, lax_description=_NOT_IMPLEMENTED_DESC)
def wrapped(*args, **kwargs):
msg = "Numpy function {} not yet implemented"
raise NotImplementedError(msg.format(fun))
Expand Down
82 changes: 41 additions & 41 deletions jax/_src/numpy/ufuncs.py
Expand Up @@ -57,9 +57,9 @@ def _one_to_one_unop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False):
fn = jit(fn, inline=True)
if lax_doc:
doc = dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip()
return _wraps(numpy_fn, lax_description=doc)(fn)
return _wraps(numpy_fn, lax_description=doc, module='numpy')(fn)
else:
return _wraps(numpy_fn)(fn)
return _wraps(numpy_fn, module='numpy')(fn)


def _one_to_one_binop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False):
Expand All @@ -70,9 +70,9 @@ def _one_to_one_binop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False)
fn = jit(fn, inline=True)
if lax_doc:
doc = dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip()
return _wraps(numpy_fn, lax_description=doc)(fn)
return _wraps(numpy_fn, lax_description=doc, module='numpy')(fn)
else:
return _wraps(numpy_fn)(fn)
return _wraps(numpy_fn, module='numpy')(fn)


def _maybe_bool_binop(numpy_fn, lax_fn, bool_lax_fn, lax_doc=False):
Expand All @@ -82,9 +82,9 @@ def fn(x1, x2):
fn = jit(fn, inline=True)
if lax_doc:
doc = dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip()
return _wraps(numpy_fn, lax_description=doc)(fn)
return _wraps(numpy_fn, lax_description=doc, module='numpy')(fn)
else:
return _wraps(numpy_fn)(fn)
return _wraps(numpy_fn, module='numpy')(fn)


def _comparison_op(numpy_fn, lax_fn):
Expand All @@ -100,11 +100,11 @@ def fn(x1, x2):
return lax.select(lax.eq(rx, ry), lax_fn(lax.imag(x1), lax.imag(x2)),
lax_fn(rx, ry))
return lax_fn(x1, x2)
return _wraps(numpy_fn)(fn)
return _wraps(numpy_fn, module='numpy')(fn)


def _logical_op(np_op, bitwise_op):
@_wraps(np_op, update_doc=False)
@_wraps(np_op, update_doc=False, module='numpy')
@partial(jit, inline=True)
def op(*args):
zero = lambda x: lax.full_like(x, shape=(), fill_value=0)
Expand Down Expand Up @@ -165,7 +165,7 @@ def op(*args):
logical_xor = _logical_op(np.logical_xor, lax.bitwise_xor)


@_wraps(np.arccosh)
@_wraps(np.arccosh, module='numpy')
@jit
def arccosh(x):
# Note: arccosh is multi-valued for complex input, and lax.acosh uses a different
Expand All @@ -176,7 +176,7 @@ def arccosh(x):
return out


@_wraps(np.right_shift)
@_wraps(np.right_shift, module='numpy')
@partial(jit, inline=True)
def right_shift(x1, x2):
x1, x2 = _promote_args(np.right_shift.__name__, x1, x2)
Expand All @@ -185,16 +185,16 @@ def right_shift(x1, x2):
return lax_fn(x1, x2)


@_wraps(np.absolute)
@_wraps(np.absolute, module='numpy')
@partial(jit, inline=True)
def absolute(x):
_check_arraylike('absolute', x)
dt = dtypes.dtype(x)
return x if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger) else lax.abs(x)
abs = _wraps(np.abs)(absolute)
abs = _wraps(np.abs, module='numpy')(absolute)


@_wraps(np.rint)
@_wraps(np.rint, module='numpy')
@jit
def rint(x):
_check_arraylike('rint', x)
Expand All @@ -206,7 +206,7 @@ def rint(x):
return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)


@_wraps(np.sign)
@_wraps(np.sign, module='numpy')
@jit
def sign(x):
_check_arraylike('sign', x)
Expand All @@ -218,7 +218,7 @@ def sign(x):
return lax.sign(x)


@_wraps(np.copysign)
@_wraps(np.copysign, module='numpy')
@jit
def copysign(x1, x2):
x1, x2 = _promote_args_inexact("copysign", x1, x2)
Expand All @@ -227,7 +227,7 @@ def copysign(x1, x2):
return _where(signbit(x2).astype(bool), -lax.abs(x1), lax.abs(x1))


@_wraps(np.true_divide)
@_wraps(np.true_divide, module='numpy')
@partial(jit, inline=True)
def true_divide(x1, x2):
x1, x2 = _promote_args_inexact("true_divide", x1, x2)
Expand All @@ -236,7 +236,7 @@ def true_divide(x1, x2):
divide = true_divide


@_wraps(np.floor_divide)
@_wraps(np.floor_divide, module='numpy')
@jit
def floor_divide(x1, x2):
x1, x2 = _promote_args("floor_divide", x1, x2)
Expand All @@ -261,7 +261,7 @@ def floor_divide(x1, x2):
return _float_divmod(x1, x2)[0]


@_wraps(np.divmod)
@_wraps(np.divmod, module='numpy')
@jit
def divmod(x1, x2):
x1, x2 = _promote_args("divmod", x1, x2)
Expand Down Expand Up @@ -305,7 +305,7 @@ def _power(x1, x2):
return acc


@_wraps(np.power)
@_wraps(np.power, module='numpy')
def power(x1, x2):
# Special case for concrete integer scalars: use binary exponentiation.
# Using lax.pow may be imprecise for floating-point values; the goal of this
Expand All @@ -322,7 +322,7 @@ def power(x1, x2):


@custom_jvp
@_wraps(np.logaddexp)
@_wraps(np.logaddexp, module='numpy')
@jit
def logaddexp(x1, x2):
x1, x2 = _promote_args_inexact("logaddexp", x1, x2)
Expand Down Expand Up @@ -360,7 +360,7 @@ def _logaddexp_jvp(primals, tangents):


@custom_jvp
@_wraps(np.logaddexp2)
@_wraps(np.logaddexp2, module='numpy')
@jit
def logaddexp2(x1, x2):
x1, x2 = _promote_args_inexact("logaddexp2", x1, x2)
Expand Down Expand Up @@ -388,28 +388,28 @@ def _logaddexp2_jvp(primals, tangents):
return primal_out, tangent_out


@_wraps(np.log2)
@_wraps(np.log2, module='numpy')
@partial(jit, inline=True)
def log2(x):
x, = _promote_args_inexact("log2", x)
return lax.div(lax.log(x), lax.log(_constant_like(x, 2)))


@_wraps(np.log10)
@_wraps(np.log10, module='numpy')
@partial(jit, inline=True)
def log10(x):
x, = _promote_args_inexact("log10", x)
return lax.div(lax.log(x), lax.log(_constant_like(x, 10)))


@_wraps(np.exp2)
@_wraps(np.exp2, module='numpy')
@partial(jit, inline=True)
def exp2(x):
x, = _promote_args_inexact("exp2", x)
return lax.exp(lax.mul(lax.log(_constant_like(x, 2)), x))


@_wraps(np.signbit)
@_wraps(np.signbit, module='numpy')
@jit
def signbit(x):
x, = _promote_args("signbit", x)
Expand Down Expand Up @@ -446,7 +446,7 @@ def _normalize_float(x):
return lax.bitcast_convert_type(x1, int_type), x2


@_wraps(np.ldexp)
@_wraps(np.ldexp, module='numpy')
@jit
def ldexp(x1, x2):
_check_arraylike("ldexp", x1, x2)
Expand Down Expand Up @@ -495,7 +495,7 @@ def ldexp(x1, x2):
return _where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x)


@_wraps(np.frexp)
@_wraps(np.frexp, module='numpy')
@jit
def frexp(x):
_check_arraylike("frexp", x)
Expand All @@ -519,7 +519,7 @@ def frexp(x):
return _where(cond, x, x1), lax.convert_element_type(x2, np.int32)


@_wraps(np.remainder)
@_wraps(np.remainder, module='numpy')
@jit
def remainder(x1, x2):
x1, x2 = _promote_args("remainder", x1, x2)
Expand All @@ -529,10 +529,10 @@ def remainder(x1, x2):
do_plus = lax.bitwise_and(
lax.ne(lax.lt(trunc_mod, zero), lax.lt(x2, zero)), trunc_mod_not_zero)
return lax.select(do_plus, lax.add(trunc_mod, x2), trunc_mod)
mod = _wraps(np.mod)(remainder)
mod = _wraps(np.mod, module='numpy')(remainder)


@_wraps(np.fmod)
@_wraps(np.fmod, module='numpy')
@jit
def fmod(x1, x2):
_check_arraylike("fmod", x1, x2)
Expand All @@ -541,21 +541,21 @@ def fmod(x1, x2):
return lax.rem(*_promote_args("fmod", x1, x2))


@_wraps(np.square)
@_wraps(np.square, module='numpy')
@partial(jit, inline=True)
def square(x):
_check_arraylike("square", x)
return lax.integer_pow(x, 2)


@_wraps(np.deg2rad)
@_wraps(np.deg2rad, module='numpy')
@partial(jit, inline=True)
def deg2rad(x):
x, = _promote_args_inexact("deg2rad", x)
return lax.mul(x, _lax_const(x, np.pi / 180))


@_wraps(np.rad2deg)
@_wraps(np.rad2deg, module='numpy')
@partial(jit, inline=True)
def rad2deg(x):
x, = _promote_args_inexact("rad2deg", x)
Expand All @@ -566,7 +566,7 @@ def rad2deg(x):
radians = deg2rad


@_wraps(np.conjugate)
@_wraps(np.conjugate, module='numpy')
@partial(jit, inline=True)
def conjugate(x):
_check_arraylike("conjugate", x)
Expand All @@ -587,7 +587,7 @@ def real(val):
_check_arraylike("real", val)
return lax.real(val) if np.iscomplexobj(val) else val

@_wraps(np.modf, skip_params=['out'])
@_wraps(np.modf, module='numpy', skip_params=['out'])
@jit
def modf(x, out=None):
_check_arraylike("modf", x)
Expand All @@ -598,7 +598,7 @@ def modf(x, out=None):
return x - whole, whole


@_wraps(np.isfinite)
@_wraps(np.isfinite, module='numpy')
@jit
def isfinite(x):
_check_arraylike("isfinite", x)
Expand All @@ -611,7 +611,7 @@ def isfinite(x):
return lax.full_like(x, True, dtype=np.bool_)


@_wraps(np.isinf)
@_wraps(np.isinf, module='numpy')
@jit
def isinf(x):
_check_arraylike("isinf", x)
Expand Down Expand Up @@ -649,14 +649,14 @@ def _isposneginf(infinity, x, out):
)


@_wraps(np.isnan)
@_wraps(np.isnan, module='numpy')
@jit
def isnan(x):
_check_arraylike("isnan", x)
return lax.ne(x, x)


@_wraps(np.heaviside)
@_wraps(np.heaviside, module='numpy')
@jit
def heaviside(x1, x2):
_check_arraylike("heaviside", x1, x2)
Expand All @@ -666,7 +666,7 @@ def heaviside(x1, x2):
_where(lax.gt(x1, zero), _lax_const(x1, 1), x2))


@_wraps(np.hypot)
@_wraps(np.hypot, module='numpy')
@jit
def hypot(x1, x2):
_check_arraylike("hypot", x1, x2)
Expand All @@ -677,7 +677,7 @@ def hypot(x1, x2):
return lax.select(x1 == 0, x1, x1 * lax.sqrt(1 + lax.square(lax.div(x2, lax.select(x1 == 0, lax_internal._ones(x1), x1)))))


@_wraps(np.reciprocal)
@_wraps(np.reciprocal, module='numpy')
@partial(jit, inline=True)
def reciprocal(x):
_check_arraylike("reciprocal", x)
Expand Down

0 comments on commit 9769a0a

Please sign in to comment.