Upgrade the logistic (sigmoid) function into a lax primitive.
This allows us to lower it to `mhlo.logistic`, which lets XLA generate more efficient code.

PiperOrigin-RevId: 469841487
jax authors committed Aug 24, 2022
1 parent 5f2b226 commit 3e3542b
Showing 14 changed files with 21 additions and 42 deletions.
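For reference, the function at issue is the element-wise logistic (sigmoid). Below is a minimal runnable sketch using only public JAX APIs; the names and sample values are illustrative and not part of the commit:

import jax
import jax.numpy as jnp

def sigmoid_reference(x):
    # Explicit decomposition: negate, exponentiate, add, divide.
    return 1.0 / (1.0 + jnp.exp(-x))

x = jnp.linspace(-5.0, 5.0, 11)
print(jax.nn.sigmoid(x))     # library sigmoid
print(sigmoid_reference(x))  # same values from the explicit formula

Treating the logistic function as a single lax primitive lets the lowering emit one `mhlo.logistic` op instead of this chain of element-wise ops, which is the efficiency gain the commit message refers to.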
1 change: 0 additions & 1 deletion docs/jax.lax.rst
@@ -103,7 +103,6 @@ Operators
lgamma
log
log1p
logistic
max
min
mul
8 changes: 0 additions & 8 deletions jax/_src/lax/lax.py
@@ -300,10 +300,6 @@ def tanh(x: Array) -> Array:
r"""Elementwise hyperbolic tangent: :math:`\mathrm{tanh}(x)`."""
return tanh_p.bind(x)

def logistic(x: Array) -> Array:
r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`."""
return logistic_p.bind(x)

def sin(x: Array) -> Array:
r"""Elementwise sine: :math:`\mathrm{sin}(x)`."""
return sin_p.bind(x)
@@ -1734,10 +1730,6 @@ def _round_lower(ctx, x, *, rounding_method):
sub(_one(x), ans)))
mlir.register_lowering(tanh_p, partial(_nary_lower_mhlo, mhlo.TanhOp))

logistic_p = standard_unop(_float | _complex, 'logistic')
ad.defjvp2(logistic_p, lambda g, ans, x: mul(g, mul(ans, sub(_one(ans), ans))))
mlir.register_lowering(logistic_p, partial(_nary_lower_mhlo, mhlo.LogisticOp))

sin_p = standard_unop(_float | _complex, 'sin')
ad.defjvp(sin_p, lambda g, x: mul(g, cos(x)))
if mlir_api_version < 27:
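The `ad.defjvp2` rule for `logistic_p` shown in the hunk above encodes the standard sigmoid derivative identity, stated here for reference (not part of the diff):

\sigma(x) = \frac{1}{1 + e^{-x}}, \qquad \sigma'(x) = \sigma(x)\,\bigl(1 - \sigma(x)\bigr)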
1 change: 0 additions & 1 deletion jax/_src/lax_reference.py
@@ -69,7 +69,6 @@ def round(x):
acosh = np.arccosh
atanh = np.arctanh

def logistic(x): return 1 / (1 + np.exp(-x))
def betainc(a, b, x): return scipy.special.betainc(a, b, x).astype(x.dtype)
def lgamma(x): return scipy.special.gammaln(x).astype(x.dtype)
def digamma(x): return scipy.special.digamma(x).astype(x.dtype)
2 changes: 1 addition & 1 deletion jax/_src/nn/functions.py
@@ -91,7 +91,7 @@ def sigmoid(x: Array) -> Array:
Args:
x : input array
"""
return lax.logistic(x)
return expit(x)

@jax.jit
def silu(x: Array) -> Array:
5 changes: 4 additions & 1 deletion jax/_src/scipy/special.py
@@ -96,10 +96,13 @@ def logit(x):
lambda g, ans, x: lax.div(g, lax.mul(x, lax.sub(_lax_const(x, 1), x))))


@api.custom_jvp
@_wraps(osp_special.expit, module='scipy.special', update_doc=False)
def expit(x):
x, = _promote_args_inexact("expit", x)
return lax.logistic(x)
one = _lax_const(x, 1)
return lax.div(one, lax.add(one, lax.exp(lax.neg(x))))
expit.defjvps(lambda g, ans, x: g * ans * (_lax_const(ans, 1) - ans))


@_wraps(osp_special.logsumexp, module='scipy.special')
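The scipy hunk above attaches a custom JVP to `expit`, so differentiation uses the closed-form sigmoid derivative instead of differentiating through `exp`. A minimal sketch of the same pattern with public JAX APIs (`my_expit` is an illustrative name, not part of the change):

import jax
import jax.numpy as jnp

@jax.custom_jvp
def my_expit(x):
    # Same explicit formula as in the diff: 1 / (1 + exp(-x)).
    return 1.0 / (1.0 + jnp.exp(-x))

@my_expit.defjvp
def my_expit_jvp(primals, tangents):
    x, = primals
    g, = tangents
    ans = my_expit(x)
    # Closed-form rule: sigmoid' = sigmoid * (1 - sigmoid).
    return ans, g * ans * (1.0 - ans)

print(jax.grad(my_expit)(0.0))  # 0.25 = 0.5 * (1 - 0.5)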
2 changes: 0 additions & 2 deletions jax/experimental/jax2tf/jax2tf.py
@@ -1339,8 +1339,6 @@ def _integer_pow(x, *, y: int, _in_avals: Sequence[core.ShapedArray],
tf_impl_with_avals[lax.atan_p] = _convert_jax_impl(
lax_internal.atan_impl, multiple_results=False)

tf_impl[lax.logistic_p] = tf.math.sigmoid

def _atan2(y, x, **kwargs):
if x.dtype.is_complex or y.dtype.is_complex:
complex_component_dtype = {
18 changes: 9 additions & 9 deletions jax/experimental/jax2tf/tests/jax2tf_limitations.py
@@ -130,15 +130,15 @@ def limitations_for_harness(
"cos", "cosh", "complex", "conj", "convert_element_type", "cummax",
"cummin", "device_put", "dynamic_slice", "dynamic_update_slice", "exp",
"eq", "floor", "gather", "ge", "gt", "imag", "iota", "is_finite", "le",
"logistic", "lt", "log", "mul", "ne", "neg", "not", "or", "pad",
"population_count", "random_categorical", "random_uniform",
"random_randint", "reduce", "reduce_and", "reduce_prod", "reduce_or",
"reduce_sum", "reduce_window_mul", "reduce_window_min",
"reduce_window_max", "real", "reshape", "rev", "rsqrt", "scatter_max",
"scatter_min", "select_n", "select_and_scatter_add", "shift_left",
"shift_right_logical", "shift_right_arithmetic", "sign", "sin", "sinh",
"slice", "sqrt", "squeeze", "stop_gradient", "sub", "tie_in", "transpose",
"xor", "zeros_like"
"lt", "log", "mul", "ne", "neg", "not", "or", "pad", "population_count",
"random_categorical", "random_uniform", "random_randint",
"reduce", "reduce_and", "reduce_prod", "reduce_or", "reduce_sum",
"reduce_window_mul", "reduce_window_min", "reduce_window_max", "real",
"reshape", "rev", "rsqrt", "scatter_max", "scatter_min", "select_n",
"select_and_scatter_add", "shift_left", "shift_right_logical",
"shift_right_arithmetic", "sign", "sin", "sinh", "slice", "sqrt",
"squeeze", "stop_gradient", "sub", "tie_in", "transpose", "xor",
"zeros_like"
}

@classmethod
1 change: 0 additions & 1 deletion jax/experimental/jax2tf/tests/primitive_harness.py
@@ -432,7 +432,6 @@ def _make_unary_elementwise_harness(*, prim, shape=(20, 20), dtype):
_make_unary_elementwise_harness(prim=lax.sqrt_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.tan_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.tanh_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.logistic_p, dtype=dtype)

for dtype in jtu.dtypes.all_floating:
_make_unary_elementwise_harness(prim=lax.bessel_i0e_p, dtype=dtype)
9 changes: 3 additions & 6 deletions jax/experimental/jet.py
@@ -498,11 +498,11 @@ def _integer_pow_taylor(primals_in, series_in, *, y):
jet_rules[lax.integer_pow_p] = _integer_pow_taylor


def _logistic_taylor(primals_in, series_in):
def _expit_taylor(primals_in, series_in):
x, = primals_in
series, = series_in
u = [x] + series
v = [lax.logistic(x)] + [None] * len(series)
v = [jax.scipy.special.expit(x)] + [None] * len(series)
e = [v[0] * (1 - v[0])] + [None] * len(series) # terms for sigmoid' = sigmoid * (1 - sigmoid)
for k in range(1, len(v)):
v[k] = fact(k-1) * sum(_scale(k, j) * e[k-j] * u[j] for j in range(1, k+1))
@@ -511,15 +511,12 @@ def _logistic_taylor(primals_in, series_in):
primal_out, *series_out = v
return primal_out, series_out

jet_rules[lax.logistic_p] = _logistic_taylor


def _tanh_taylor(primals_in, series_in):
x, = primals_in
series, = series_in
u = [2*x] + [2 * series_ for series_ in series]
primals_in, *series_in = u
primal_out, series_out = _logistic_taylor((primals_in, ), (series_in, ))
primal_out, series_out = _expit_taylor((primals_in, ), (series_in, ))
series_out = [2 * series_ for series_ in series_out]
return 2 * primal_out - 1, series_out
jet_rules[lax.tanh_p] = _tanh_taylor
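For reference, `_tanh_taylor` above reuses the expit Taylor rule via the standard identity (not part of the diff):

\tanh(x) = 2\,\sigma(2x) - 1

which is why the inputs are doubled before calling `_expit_taylor` and the outputs are rescaled afterwards.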
2 changes: 0 additions & 2 deletions jax/lax/__init__.py
@@ -139,8 +139,6 @@
log1p as log1p,
log1p_p as log1p_p,
log_p as log_p,
logistic as logistic,
logistic_p as logistic_p,
lt as lt,
lt_p as lt_p,
make_bint as make_bint,
4 changes: 2 additions & 2 deletions tests/jet_test.py
@@ -188,7 +188,7 @@ def expit_check(self, lims=(-2, 2), order=3):
primals = (primal_in, )
series = (terms_in, )

y, terms = jax.experimental.jet._logistic_taylor(primals, series)
y, terms = jax.experimental.jet._expit_taylor(primals, series)
expected_y, expected_terms = jvp_taylor(jax.scipy.special.expit, primals, series)

atol = 1e-4
@@ -283,7 +283,7 @@ def test_cosh(self): self.unary_check(jnp.cosh)
@jtu.skip_on_devices("tpu")
def test_tanh(self): self.unary_check(jnp.tanh, lims=[-500, 500], order=5)
@jtu.skip_on_devices("tpu")
def test_logistic(self): self.unary_check(lax.logistic, lims=[-100, 100], order=5)
def test_expit(self): self.unary_check(jax.scipy.special.expit, lims=[-100, 100], order=5)
@jtu.skip_on_devices("tpu")
def test_expit2(self): self.expit_check(lims=[-500, 500], order=5)
@jtu.skip_on_devices("tpu")
3 changes: 0 additions & 3 deletions tests/lax_autodiff_test.py
@@ -134,9 +134,6 @@ def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None):
dtypes=grad_complex_dtypes),
grad_test_spec(lax.cbrt, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_float_dtypes, tol={np.float64: 3e-5}),
grad_test_spec(lax.logistic, nargs=1, order=2,
rng_factory=jtu.rand_default,
dtypes=grad_inexact_dtypes),

grad_test_spec(lax.add, nargs=2, order=2, rng_factory=jtu.rand_default,
dtypes=grad_inexact_dtypes),
3 changes: 0 additions & 3 deletions tests/lax_test.py
@@ -114,7 +114,6 @@ def op_record(op, nargs, dtypes, rng_factory, tol=None):
# TODO(b/143135720): on GPU, tanh has only ~float32 precision.
op_record("tanh", 1, float_dtypes + complex_dtypes, jtu.rand_small,
{np.float64: 1e-9, np.complex128: 1e-7}),
op_record("logistic", 1, float_dtypes + complex_dtypes, jtu.rand_default),
op_record("sin", 1, float_dtypes + complex_dtypes, jtu.rand_default),
op_record("cos", 1, float_dtypes + complex_dtypes, jtu.rand_default),
op_record("atan2", 2, float_dtypes, jtu.rand_default),
@@ -2959,8 +2958,6 @@ def testArgMaxOfNanChoosesNaN(self):
for op, dtypes in unary_op_types.items()))
def testUnaryWeakTypes(self, op_name, rec_dtypes):
"""Test that all lax unary ops propagate weak_type information appropriately."""
if op_name == "bitwise_not":
raise unittest.SkipTest("https://github.com/google/jax/issues/12066")
# Find a valid dtype for the function.
for dtype in [np.float_, np.int_, np.complex_, np.bool_]:
dtype = dtypes.canonicalize_dtype(dtype)
4 changes: 2 additions & 2 deletions tests/scipy_stats_test.py
@@ -351,7 +351,7 @@ def args_maker():

with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=3e-5)
tol=1e-6)
self._CompileAndCheck(lax_fun, args_maker)

@genNamedParametersNArgs(1)
@@ -397,7 +397,7 @@ def args_maker():
return list(map(rng, shapes, dtypes))

self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=2e-5)
tol=1e-6)
self._CompileAndCheck(lax_fun, args_maker)

@genNamedParametersNArgs(3)
