Avoid NaN gradients from Categorical KL and entropy.
PiperOrigin-RevId: 469127064
DistraxDev authored and DistraxDev committed Aug 22, 2022
1 parent 0f5f670 commit ff2d8b3
Showing 3 changed files with 67 additions and 28 deletions.
33 changes: 5 additions & 28 deletions distrax/_src/distributions/categorical.py
@@ -114,23 +114,10 @@ def prob(self, value: Array) -> Array:
   def entropy(self) -> Array:
     """See `Distribution.entropy`."""
     if self._logits is None:
-      return -jnp.sum(
-          math.multiply_no_nan(jnp.log(self._probs), self._probs), axis=-1)
-    # The following result can be derived as follows. Write log(p[i]) as:
-    # s[i]-m-lse(s[i]-m) where m=max(s), then you have:
-    #   sum_i exp(s[i]-m-lse(s-m)) (s[i] - m - lse(s-m))
-    #   = -m - lse(s-m) + sum_i s[i] exp(s[i]-m-lse(s-m))
-    #   = -m - lse(s-m) + (1/exp(lse(s-m))) sum_i s[i] exp(s[i]-m)
-    #   = -m - lse(s-m) + (1/sumexp(s-m)) sum_i s[i] exp(s[i]-m)
-    # Write x[i]=s[i]-m then you have:
-    #   = -m - lse(x) + (1/sum_exp(x)) sum_i s[i] exp(x[i])
-    # Negating all of this result is the Shannon (discrete) entropy.
-    m = jnp.max(self._logits, axis=-1, keepdims=True)
-    x = self._logits - m
-    sum_exp_x = jnp.sum(jnp.exp(x), axis=-1)
-    lse_logits = jnp.squeeze(m, axis=-1) + jnp.log(sum_exp_x)
-    return lse_logits - jnp.sum(
-        math.multiply_no_nan(self._logits, jnp.exp(x)), axis=-1) / sum_exp_x
+      log_probs = jnp.log(self._probs)
+    else:
+      log_probs = jax.nn.log_softmax(self._logits)
+    return -jnp.sum(math.mul_exp(log_probs, log_probs), axis=-1)
 
   def mode(self) -> Array:
     """See `Distribution.mode`."""
@@ -203,20 +190,10 @@ def _kl_divergence_categorical_categorical(
         f'{num_categories1} categories, while the second distribution has '
         f'{num_categories2} categories.')
 
-  # pylint: disable=protected-access
-  if dist1._probs is None:
-    probs1 = jax.nn.softmax(logits1, axis=-1)
-  else:
-    probs1 = dist1.probs
-
   log_probs1 = jax.nn.log_softmax(logits1, axis=-1)
   log_probs2 = jax.nn.log_softmax(logits2, axis=-1)
-
-  # The KL is a sum over the support of `dist1`, that is, over the components of
-  # `dist1` that have non-zero probability. So we exclude terms with
-  # `probs1 == 0` by setting them to zero in the sum below.
   return jnp.sum(
-      jnp.where(probs1 == 0, 0., probs1 * (log_probs1 - log_probs2)), axis=-1)
+      math.mul_exp(log_probs1 - log_probs2, log_probs1), axis=-1)
 
 
 # Register the KL functions with TFP.
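Both changes above route the computation through the new `math.mul_exp` helper (added in `distrax/_src/utils/math.py` further down), which zeroes the value and the gradient wherever `exp(log_probs) == 0`. Below is a minimal standalone sketch of the gradient problem being fixed: it inlines a copy of the helper instead of importing Distrax internals, and the `entropy_where` baseline is an illustrative `jnp.where` formulation rather than the exact code removed above.

```python
import jax
import jax.numpy as jnp


def mul_exp(x, logp):
  """x * exp(logp), with zero output and zero gradient where exp(logp) == 0."""
  p = jnp.exp(logp)
  # Where p == 0 the result must be 0 regardless of x, so replace the possibly
  # non-finite x before multiplying; this also zeroes the gradient there.
  x = jnp.where(p == 0, 0.0, x)
  return x * p


def entropy_where(logits):
  # Naive masking: the forward value is fine, but the backward pass multiplies
  # a zero cotangent by the -inf log-probability, producing NaN.
  log_probs = jax.nn.log_softmax(logits)
  probs = jnp.exp(log_probs)
  return -jnp.sum(jnp.where(probs == 0, 0.0, probs * log_probs))


def entropy_mul_exp(logits):
  log_probs = jax.nn.log_softmax(logits)
  return -jnp.sum(mul_exp(log_probs, log_probs))


logits = jnp.array([-jnp.inf, 2.0, -3.0, 5.0])
print(jax.grad(entropy_where)(logits))    # contains NaN
print(jax.grad(entropy_mul_exp)(logits))  # finite everywhere
```

The masked branch of `jnp.where` still receives a cotangent (of zero), and the chain rule then multiplies that zero by the non-finite intermediate, which is where the NaN comes from; `mul_exp` replaces the non-finite operand before the multiplication ever happens.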
44 changes: 44 additions & 0 deletions distrax/_src/distributions/categorical_test.py
@@ -452,6 +452,50 @@ def summed_dist(logits):
     x = jnp.array([[[0]], [[1]], [[0]]], jnp.int_)
     self.assertion_fn(rtol=1e-6)(actual.log_prob(x), expected.log_prob(x))
 
+  @parameterized.named_parameters(
+      ('-inf logits', np.array([-jnp.inf, 2, -3, -jnp.inf, 5.0])),
+      ('uniform large negative logits', np.array([-1e9] * 11)),
+      ('uniform large positive logits', np.array([1e9] * 11)),
+      ('uniform', np.array([0.0] * 11)),
+      ('typical', np.array([1, 7, -3, 2, 4.0])),
+  )
+  def test_entropy_grad(self, logits):
+    clipped_logits = jnp.maximum(-10000, logits)
+
+    def entropy_fn(logits):
+      return categorical.Categorical(logits).entropy()
+    entropy, grads = jax.value_and_grad(entropy_fn)(logits)
+    expected_entropy, expected_grads = jax.value_and_grad(entropy_fn)(
+        clipped_logits)
+    self.assertion_fn(rtol=1e-6)(expected_entropy, entropy)
+    self.assertion_fn(rtol=1e-6)(expected_grads, grads)
+    self.assertTrue(np.isfinite(entropy).all())
+    self.assertTrue(np.isfinite(grads).all())
+
+  @parameterized.named_parameters(
+      ('-inf logits1', np.array([-jnp.inf, 2, -3, -jnp.inf, 5.0]),
+       np.array([1, 7, -3, 2, 4.0])),
+      ('-inf logits both', np.array([-jnp.inf, 2, -1000, -jnp.inf, 5.0]),
+       np.array([-jnp.inf, 7, -jnp.inf, 2, 4.0])),
+      ('typical', np.array([5, -2, 0, 1, 4.0]),
+       np.array([1, 7, -3, 2, 4.0])),
+  )
+  def test_kl_grad(self, logits1, logits2):
+    clipped_logits1 = jnp.maximum(-10000, logits1)
+    clipped_logits2 = jnp.maximum(-10000, logits2)
+
+    def kl_fn(logits1, logits2):
+      return categorical.Categorical(logits1).kl_divergence(
+          categorical.Categorical(logits2))
+    kl, grads = jax.value_and_grad(
+        kl_fn, argnums=(0, 1))(logits1, logits2)
+    expected_kl, expected_grads = jax.value_and_grad(
+        kl_fn, argnums=(0, 1))(clipped_logits1, clipped_logits2)
+    self.assertion_fn(rtol=1e-6)(expected_kl, kl)
+    self.assertion_fn(rtol=1e-6)(expected_grads, grads)
+    self.assertTrue(np.isfinite(kl).all())
+    self.assertTrue(np.isfinite(grads).all())
+
+
 if __name__ == '__main__':
   absltest.main()
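The new `test_kl_grad` exercises the same property for the KL divergence: summing `mul_exp(log_probs1 - log_probs2, log_probs1)` restricts the sum to the support of the first distribution without sending an infinite difference through the backward pass. A standalone sketch in the same spirit, again inlining the helper and using only public JAX APIs:

```python
import jax
import jax.numpy as jnp


def mul_exp(x, logp):
  # Copy of the helper added in this commit: x * exp(logp), with zero output
  # and zero gradient wherever exp(logp) == 0.
  p = jnp.exp(logp)
  x = jnp.where(p == 0, 0.0, x)
  return x * p


def kl_categorical(logits1, logits2):
  log_p1 = jax.nn.log_softmax(logits1)
  log_p2 = jax.nn.log_softmax(logits2)
  # Terms outside the support of the first distribution (p1 == 0) contribute
  # nothing, even though log_p1 - log_p2 is -inf there.
  return jnp.sum(mul_exp(log_p1 - log_p2, log_p1))


logits1 = jnp.array([-jnp.inf, 2.0, -3.0, 5.0])
logits2 = jnp.array([1.0, 7.0, -3.0, 2.0])
kl, (g1, g2) = jax.value_and_grad(kl_categorical, argnums=(0, 1))(logits1, logits2)
print(kl, g1, g2)  # all finite, no NaNs despite the -inf logit
```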
18 changes: 18 additions & 0 deletions distrax/_src/utils/math.py
@@ -61,6 +61,24 @@ def power_no_nan(x: Array, y: Array) -> Array:
   return jnp.where(y == 0, jnp.ones((), dtype=dtype), jnp.power(x, y))
 
 
+def mul_exp(x: Array, logp: Array) -> Array:
+  """Returns `x * exp(logp)` with zero output if `exp(logp)==0`.
+
+  Args:
+    x: An array.
+    logp: An array.
+
+  Returns:
+    `x * exp(logp)` with zero output and zero gradient if `exp(logp)==0`,
+    even if `x` is NaN or infinite.
+  """
+  p = jnp.exp(logp)
+  # If p==0, the gradient with respect to logp is zero,
+  # so we can replace the possibly non-finite `x` with zero.
+  x = jnp.where(p == 0, 0.0, x)
+  return x * p
+
+
 def normalize(
     *, probs: Optional[Array] = None, logits: Optional[Array] = None) -> Array:
   """Normalize logits (via log_softmax) or probs (ensuring they sum to one)."""
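As a quick sanity check on the pattern inside `mul_exp` (a sketch that duplicates the helper rather than importing `distrax._src.utils.math`): substituting a finite value via `jnp.where` before the multiplication ensures no `0 * inf` product appears in either the forward or the backward pass.

```python
import jax
import jax.numpy as jnp


def mul_exp(x, logp):
  # Copy of the helper above, for a standalone check.
  p = jnp.exp(logp)
  x = jnp.where(p == 0, 0.0, x)
  return x * p


x = jnp.array([-jnp.inf, 2.0])      # non-finite exactly where the weight is zero
logp = jnp.array([-jnp.inf, -1.0])  # exp(logp) == [0.0, 0.3678...]

# Plain x * exp(logp): the first entry is (-inf) * 0 == nan in the forward
# pass, and its gradient with respect to logp is nan as well.
naive = lambda lp: jnp.sum(x * jnp.exp(lp))
print(jax.value_and_grad(naive)(logp))  # (nan, [nan, 0.7357...])

# mul_exp: both the value and the gradient are exactly zero at that entry.
safe = lambda lp: jnp.sum(mul_exp(x, lp))
print(jax.value_and_grad(safe)(logp))   # (0.7357..., [0.0, 0.7357...])
```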
