jax.nn.softmax: fix fill value when where is specified
jakevdp committed Jun 1, 2023
1 parent ae9d149 commit c474de4
Showing 2 changed files with 16 additions and 8 deletions.
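
As a rough sketch (not taken from the commit itself) of the behavior the softmax change below establishes, using the same inputs as the updated test; the commented values are assumptions based on the new 0 fill value:

import jax.numpy as jnp
from jax import nn

x = jnp.array([5.5, 1.3, -4.2, 0.9])
m = jnp.array([True, False, True, True])

# Softmax computed over the unmasked entries only; with this change the
# masked-out position is filled with an explicit 0.
probs = nn.softmax(x, where=m, initial=-jnp.inf)
# expected: probs[1] == 0.0 and probs.sum() == 1.0
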
15 changes: 12 additions & 3 deletions jax/_src/nn/functions.py
@@ -317,7 +317,10 @@ def log_softmax(x: Array,
   shifted = x - lax.stop_gradient(x_max)
   shifted_logsumexp = jnp.log(
       jnp.sum(jnp.exp(shifted), axis, where=where, keepdims=True))
-  return shifted - shifted_logsumexp
+  result = shifted - shifted_logsumexp
+  if where is not None:
+    return jnp.where(where, result, -jnp.inf)
+  return result
 
 
 # TODO(phawkins): this jit was found to change numerics in a test. Debug this.
@@ -357,7 +360,10 @@ def _softmax(
     initial: Optional[Array] = None) -> Array:
   x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
   unnormalized = jnp.exp(x - x_max)
-  return unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
+  result = unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
+  if where is not None:
+    result = jnp.where(where, result, 0)
+  return result
 
 @_softmax.defjvp
 def _softmax_jvp(axis, primals, tangents):
@@ -368,7 +374,10 @@ def _softmax_jvp(axis, primals, tangents):
 def _softmax_deprecated(x, axis, where, initial):
   x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
   unnormalized = jnp.exp(x - lax.stop_gradient(x_max))
-  return unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
+  result = unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
+  if where is not None:
+    result = jnp.where(where, result, 0)
+  return result
 
 
 @partial(jax.jit, static_argnames=("axis",))
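A small hedged check (not part of the commit) of how the two fill values above relate: log_softmax now fills masked positions with -inf while softmax fills them with 0, and since exp(-inf) == 0, exponentiating the masked log_softmax should recover the masked softmax, which is the identity the updated test below relies on:

import jax.numpy as jnp
from jax import nn

x = jnp.array([5.5, 1.3, -4.2, 0.9])
m = jnp.array([True, False, True, True])

log_p = nn.log_softmax(x, where=m, initial=-jnp.inf)  # -inf at the masked position
p = nn.softmax(x, where=m, initial=-jnp.inf)          # 0 at the masked position
assert jnp.allclose(jnp.exp(log_p), p)
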
9 changes: 4 additions & 5 deletions tests/nn_test.py
@@ -133,13 +133,12 @@ def testHardTanhMemory(self):
   def testSoftmaxWhereMask(self, fn):
     x = jnp.array([5.5, 1.3, -4.2, 0.9])
     m = jnp.array([True, False, True, True])
-    x_filtered = jnp.take(x, jnp.array([0, 2, 3]))
 
-    out_masked = jnp.take(
-        fn(x, where=m, initial=-jnp.inf), jnp.array([0, 2, 3]))
-    out_filtered = fn(x_filtered)
+    out = fn(x, where=m, initial=-jnp.inf)
+    self.assertAllClose(out[m], fn(x[m]))
 
-    self.assertAllClose(out_masked, out_filtered)
+    probs = out if fn is nn.softmax else jnp.exp(out)
+    self.assertAllClose(probs.sum(), 1.0)
 
   # TODO(mattjj): include log_softmax in these extra tests if/when we add a
   # custom_jvp rule for it (since otherwise it doesn't pass the numerical
