Skip to content

Commit

Permalink
Add @jit decorators to another tranche of jax.numpy functions.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 390419775
  • Loading branch information
hawkinsp authored and jax authors committed Aug 12, 2021
1 parent 729b21b commit d82341d
Showing 1 changed file with 39 additions and 0 deletions.
39 changes: 39 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,10 +381,12 @@ def _constant_like(x, const):
### implementations of numpy functions in terms of lax

@_wraps(np.fmin)
@jit
def fmin(x1, x2):
return where((x1 < x2) | isnan(x2), x1, x2)

@_wraps(np.fmax)
@jit
def fmax(x1, x2):
return where((x1 > x2) | isnan(x2), x1, x2)

Expand Down Expand Up @@ -483,6 +485,7 @@ def fn(x1, x2):
nextafter = _one_to_one_binop(np.nextafter, lax.nextafter, True, True)

@_wraps(np.arccosh)
@jit
def arccosh(x):
# Note: arccosh is multi-valued for complex input, and lax.acosh uses a different
# convention than np.arccosh.
Expand All @@ -492,6 +495,8 @@ def arccosh(x):
return out

def _comparison_op(numpy_fn, lax_fn):
# TODO(https://github.com/google/jax/issues/6713): decorate this function with
# jit, after fixing a surprising interaction with remat(..., concrete=True).
def fn(x1, x2):
x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
# Comparison on complex types are defined as a lexicographic ordering on
Expand All @@ -512,6 +517,7 @@ def fn(x1, x2):

def _logical_op(np_op, bitwise_op):
@_wraps(np_op, update_doc=False)
@partial(jit, inline=True)
def op(*args):
zero = lambda x: lax.full_like(x, shape=(), fill_value=0)
args = (x if issubdtype(_dtype(x), bool_) else lax.ne(x, zero(x))
Expand All @@ -526,6 +532,7 @@ def op(*args):


@_wraps(np.right_shift)
@partial(jit, inline=True)
def right_shift(x1, x2):
x1, x2 = _promote_args(np.right_shift.__name__, x1, x2)
lax_fn = lax.shift_right_logical if \
Expand All @@ -534,6 +541,7 @@ def right_shift(x1, x2):


@_wraps(np.absolute)
@partial(jit, inline=True)
def absolute(x):
_check_arraylike('absolute', x)
dt = _dtype(x)
Expand All @@ -542,6 +550,7 @@ def absolute(x):


@_wraps(np.rint)
@jit
def rint(x):
_check_arraylike('rint', x)
dtype = _dtype(x)
Expand All @@ -553,6 +562,7 @@ def rint(x):


@_wraps(np.sign)
@jit
def sign(x):
_check_arraylike('sign', x)
dtype = _dtype(x)
Expand All @@ -564,6 +574,7 @@ def sign(x):


@_wraps(np.copysign)
@jit
def copysign(x1, x2):
x1, x2 = _promote_args_inexact("copysign", x1, x2)
if issubdtype(_dtype(x1), complexfloating):
Expand All @@ -572,13 +583,15 @@ def copysign(x1, x2):


@_wraps(np.true_divide)
@partial(jit, inline=True)
def true_divide(x1, x2):
x1, x2 = _promote_args_inexact("true_divide", x1, x2)
return lax.div(x1, x2)

divide = true_divide

@_wraps(np.floor_divide)
@jit
def floor_divide(x1, x2):
x1, x2 = _promote_args("floor_divide", x1, x2)
dtype = _dtype(x1)
Expand All @@ -603,6 +616,7 @@ def floor_divide(x1, x2):


@_wraps(np.divmod)
@jit
def divmod(x1, x2):
x1, x2 = _promote_args("divmod", x1, x2)
if issubdtype(_dtype(x1), integer):
Expand Down Expand Up @@ -723,23 +737,27 @@ def _logaddexp2_jvp(primals, tangents):


@_wraps(np.log2)
@partial(jit, inline=True)
def log2(x):
x, = _promote_dtypes_inexact(x)
return lax.div(lax.log(x), lax.log(_constant_like(x, 2)))


@_wraps(np.log10)
@partial(jit, inline=True)
def log10(x):
x, = _promote_dtypes_inexact(x)
return lax.div(lax.log(x), lax.log(_constant_like(x, 10)))


@_wraps(np.exp2)
@partial(jit, inline=True)
def exp2(x):
x, = _promote_dtypes_inexact(x)
return lax.exp(lax.mul(lax.log(_constant_like(x, 2)), x))

@_wraps(np.signbit)
@jit
def signbit(x):
x, = _promote_shapes("signbit", x)
dtype = _dtype(x)
Expand Down Expand Up @@ -767,6 +785,7 @@ def signbit(x):


@_wraps(np.trapz)
@partial(jit, static_argnames=('axis',))
def trapz(y, x=None, dx=1.0, axis: int = -1):
_check_arraylike('trapz', y)
y = moveaxis(y, axis, -1)
Expand All @@ -779,6 +798,7 @@ def trapz(y, x=None, dx=1.0, axis: int = -1):


@_wraps(np.trunc)
@jit
def trunc(x):
_check_arraylike('trunc', x)
return where(lax.lt(x, lax._const(x, 0)), ceil(x), floor(x))
Expand Down Expand Up @@ -818,12 +838,14 @@ def _conv(x, y, mode, op, precision):


@_wraps(np.convolve, lax_description=_PRECISION_DOC)
@partial(jit, static_argnames=('mode', 'precision'))
def convolve(a, v, mode='full', *, precision=None):
_check_arraylike("convolve", a, v)
return _conv(a, v, mode, 'convolve', precision)


@_wraps(np.correlate, lax_description=_PRECISION_DOC)
@partial(jit, static_argnames=('mode', 'precision'))
def correlate(a, v, mode='valid', *, precision=None):
_check_arraylike("correlate", a, v)
return _conv(a, v, mode, 'correlate', precision)
Expand Down Expand Up @@ -905,6 +927,7 @@ def frexp(x):


@_wraps(np.remainder)
@jit
def remainder(x1, x2):
x1, x2 = _promote_args("remainder", x1, x2)
zero = _constant_like(x1, 0)
Expand All @@ -917,6 +940,7 @@ def remainder(x1, x2):


@_wraps(np.fmod)
@jit
def fmod(x1, x2):
_check_arraylike("fmod", x1, x2)
if issubdtype(_dtype(x1, x2), integer):
Expand All @@ -925,19 +949,22 @@ def fmod(x1, x2):


@_wraps(np.square)
@partial(jit, inline=True)
def square(x):
_check_arraylike("square", x)
return lax.integer_pow(x, 2)


@_wraps(np.deg2rad)
@partial(jit, inline=True)
def deg2rad(x):
_check_arraylike("deg2rad", x)
x, = _promote_dtypes_inexact(x)
return lax.mul(x, lax._const(x, pi / 180))


@_wraps(np.rad2deg)
@partial(jit, inline=True)
def rad2deg(x):
_check_arraylike("rad2deg", x)
x, = _promote_dtypes_inexact(x)
Expand Down Expand Up @@ -1050,6 +1077,7 @@ def histogramdd(sample, bins=10, range=None, weights=None, density=None):
return hist, bin_edges_by_dim

@_wraps(np.heaviside)
@jit
def heaviside(x1, x2):
_check_arraylike("heaviside", x1, x2)
x1, x2 = _promote_dtypes_inexact(x1, x2)
Expand All @@ -1059,6 +1087,7 @@ def heaviside(x1, x2):


@_wraps(np.hypot)
@jit
def hypot(x1, x2):
_check_arraylike("hypot", x1, x2)
x1, x2 = _promote_dtypes_inexact(x1, x2)
Expand All @@ -1069,13 +1098,15 @@ def hypot(x1, x2):


@_wraps(np.reciprocal)
@partial(jit, inline=True)
def reciprocal(x):
_check_arraylike("reciprocal", x)
x, = _promote_dtypes_inexact(x)
return lax.integer_pow(x, -1)


@_wraps(np.sinc, update_doc=False)
@jit
def sinc(x):
_check_arraylike("sinc", x)
x, = _promote_dtypes_inexact(x)
Expand Down Expand Up @@ -1153,35 +1184,41 @@ def flipud(m):


@_wraps(np.conjugate)
@partial(jit, inline=True)
def conjugate(x):
_check_arraylike("conjugate", x)
return lax.conj(x) if iscomplexobj(x) else x
conj = conjugate


@_wraps(np.imag)
@partial(jit, inline=True)
def imag(val):
_check_arraylike("imag", val)
return lax.imag(val) if iscomplexobj(val) else zeros_like(val)


@_wraps(np.real)
@partial(jit, inline=True)
def real(val):
_check_arraylike("real", val)
return lax.real(val) if iscomplexobj(val) else val


@_wraps(np.iscomplex)
@jit
def iscomplex(x):
i = imag(x)
return lax.ne(i, lax._const(i, 0))

@_wraps(np.isreal)
@jit
def isreal(x):
i = imag(x)
return lax.eq(i, lax._const(i, 0))

@_wraps(np.angle)
@jit
def angle(z):
re = real(z)
im = imag(z)
Expand All @@ -1195,6 +1232,7 @@ def angle(z):


@_wraps(np.diff)
@partial(jit, static_argnames=('n', 'axis'))
def diff(a, n=1, axis: int = -1, prepend=None, append=None):
_check_arraylike("diff", a)
n = core.concrete_or_error(operator.index, n, "'n' argument of jnp.diff")
Expand Down Expand Up @@ -1375,6 +1413,7 @@ def _transpose(a, *args):
return transpose(a, axis)

@_wraps(np.ravel, lax_description=_ARRAY_VIEW_DOC)
@partial(jit, static_argnames=('order',), inline=True)
def ravel(a, order="C"):
_check_arraylike("ravel", a)
if order == "K":
Expand Down

0 comments on commit d82341d

Please sign in to comment.