diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py
index cbc34f3b9cd5..3b4ca8adf0bd 100644
--- a/jax/_src/nn/functions.py
+++ b/jax/_src/nn/functions.py
@@ -317,7 +317,10 @@ def log_softmax(x: Array,
   shifted = x - lax.stop_gradient(x_max)
   shifted_logsumexp = jnp.log(
       jnp.sum(jnp.exp(shifted), axis, where=where, keepdims=True))
-  return shifted - shifted_logsumexp
+  result = shifted - shifted_logsumexp
+  if where is not None:
+    return jnp.where(where, result, -jnp.inf)
+  return result
 
 
 # TODO(phawkins): this jit was found to change numerics in a test. Debug this.
@@ -357,7 +360,10 @@ def _softmax(
     initial: Optional[Array] = None) -> Array:
   x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
   unnormalized = jnp.exp(x - x_max)
-  return unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
+  result = unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
+  if where is not None:
+    result = jnp.where(where, result, 0)
+  return result
 
 @_softmax.defjvp
 def _softmax_jvp(axis, primals, tangents):
@@ -368,7 +374,10 @@ def _softmax_jvp(axis, primals, tangents):
 def _softmax_deprecated(x, axis, where, initial):
   x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
   unnormalized = jnp.exp(x - lax.stop_gradient(x_max))
-  return unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
+  result = unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
+  if where is not None:
+    result = jnp.where(where, result, 0)
+  return result
 
 
 @partial(jax.jit, static_argnames=("axis",))
diff --git a/tests/nn_test.py b/tests/nn_test.py
index e558c0b6b291..46554e5bc2b8 100644
--- a/tests/nn_test.py
+++ b/tests/nn_test.py
@@ -133,13 +133,12 @@ def testHardTanhMemory(self):
   def testSoftmaxWhereMask(self, fn):
     x = jnp.array([5.5, 1.3, -4.2, 0.9])
     m = jnp.array([True, False, True, True])
 
-    x_filtered = jnp.take(x, jnp.array([0, 2, 3]))
-    out_masked = jnp.take(
-        fn(x, where=m, initial=-jnp.inf), jnp.array([0, 2, 3]))
-    out_filtered = fn(x_filtered)
+    out = fn(x, where=m, initial=-jnp.inf)
+    self.assertAllClose(out[m], fn(x[m]))
 
-    self.assertAllClose(out_masked, out_filtered)
+    probs = out if fn is nn.softmax else jnp.exp(out)
+    self.assertAllClose(probs.sum(), 1.0)
 
   # TODO(mattjj): include log_softmax in these extra tests if/when we add a
   # custom_jvp rule for it (since otherwise it doesn't pass the numerical
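
Note (not part of the patch): a minimal usage sketch of the behavior this change targets, assuming a JAX version where softmax/log_softmax still accept the where and initial arguments used in the updated test above. With the patch, masked positions are filled explicitly (0 for softmax, -inf for log_softmax) rather than carrying leftover values from the unmasked normalization, so the probabilities sum to 1 over the unmasked entries.

    import jax.numpy as jnp
    from jax import nn

    x = jnp.array([5.5, 1.3, -4.2, 0.9])
    m = jnp.array([True, False, True, True])

    # Same call pattern as testSoftmaxWhereMask above.
    probs = nn.softmax(x, where=m, initial=-jnp.inf)
    logp = nn.log_softmax(x, where=m, initial=-jnp.inf)

    print(probs)          # masked entry (index 1) is 0.0 after this change
    print(probs.sum())    # ~1.0: only unmasked entries carry probability mass
    print(jnp.exp(logp))  # matches probs; masked entry is exp(-inf) == 0.0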