Skip to content

Commit

Permalink
distrax: Use a safer log_prob in KL divergence between categoricals.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 427788284
  • Loading branch information
suryabhupa authored and DistraxDev committed Feb 23, 2022
1 parent 6318bce commit 27fc742
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
10 changes: 9 additions & 1 deletion distrax/_src/distributions/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ def _kl_divergence_categorical_categorical(
) -> Array:
"""Obtain the KL divergence `KL(dist1 || dist2)` between two Categoricals.
The KL computation takes into account that `0 * log(0) = 0`; therefore,
`dist1` may have zeros in its probability vector.
Args:
dist1: A Categorical distribution.
dist2: A Categorical distribution.
Expand Down Expand Up @@ -195,7 +198,12 @@ def _kl_divergence_categorical_categorical(
else:
probs1 = dist1.probs

log_probs1 = jax.nn.log_softmax(logits1, axis=-1)
# If any probabilities of the first distribution are 0, we ignore those
# components and set the corresponding log probabilities to 0 instead of
# computing its log softmax. By doing so, we still output a valid KL
# divergence because 0 * log(0) = 0 for those specific components.
log_probs1 = jnp.where(
probs1 == 0, 0., jax.nn.log_softmax(logits1, axis=-1))
log_probs2 = jax.nn.log_softmax(logits2, axis=-1)

return jnp.sum((probs1 * (log_probs1 - log_probs2)), axis=-1)
Expand Down
9 changes: 7 additions & 2 deletions distrax/_src/distributions/categorical_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,8 +405,13 @@ def test_with_two_distributions(self, function_string, mode_string):
super()._test_with_two_distributions(
attribute_string=function_string,
mode_string=mode_string,
dist1_kwargs={'probs': jnp.asarray([[0.1, 0.5, 0.4], [0.2, 0.4, 0.4]])},
dist2_kwargs={'logits': jnp.asarray([0.0, 0.1, 0.1]),},
dist1_kwargs={
'probs':
jnp.asarray([[0.4, 0.0, 0.6], [0.1, 0.5, 0.4], [0.2, 0.4, 0.4]])
},
dist2_kwargs={
'logits': jnp.asarray([0.0, 0.1, 0.1]),
},
assertion_fn=self.assertion_fn)

def test_jittable(self):
Expand Down

0 comments on commit 27fc742

Please sign in to comment.