Fix categorical KL computation when the second distribution has zero probability on entries where the first distribution also has zero probability.

PiperOrigin-RevId: 432144782
franrruiz authored and DistraxDev committed Mar 8, 2022
1 parent 20d3df2 commit 2f7113f
Showing 2 changed files with 29 additions and 8 deletions.
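
For context: the previous implementation zeroed out `log_probs1` where `probs1 == 0`, so when `dist2` also assigns zero probability to such an entry (`log_probs2 == -inf`) the corresponding term became `0 * (0 - (-inf)) = 0 * inf = nan` and poisoned the whole sum. A minimal sketch of the failure mode and of the fix, using hypothetical distributions (not part of the commit):

import jax
import jax.numpy as jnp

# dist1 puts zero mass on the last category; dist2 does too (a -inf logit).
probs1 = jnp.asarray([0.4, 0.6, 0.0])
logits1 = jnp.log(probs1)
logits2 = jnp.asarray([0.0, 0.1, -jnp.inf])

log_probs1 = jax.nn.log_softmax(logits1, axis=-1)
log_probs2 = jax.nn.log_softmax(logits2, axis=-1)

# Old approach: only log_probs1 was masked, so the last term is
# 0 * (0 - (-inf)) = 0 * inf = nan, which propagates through the sum.
old_log_probs1 = jnp.where(probs1 == 0, 0., log_probs1)
print(jnp.sum(probs1 * (old_log_probs1 - log_probs2), axis=-1))  # nan

# New approach: mask the whole term, so the sum only runs over the
# support of dist1 and the result is finite.
print(jnp.sum(
    jnp.where(probs1 == 0, 0., probs1 * (log_probs1 - log_probs2)), axis=-1))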
13 changes: 6 additions & 7 deletions distrax/_src/distributions/categorical.py
@@ -202,15 +202,14 @@ def _kl_divergence_categorical_categorical(
   else:
     probs1 = dist1.probs
 
-  # If any probabilities of the first distribution are 0, we ignore those
-  # components and set the corresponding log probabilities to 0 instead of
-  # computing its log softmax. By doing so, we still output a valid KL
-  # divergence because 0 * log(0) = 0 for those specific components.
-  log_probs1 = jnp.where(
-      probs1 == 0, 0., jax.nn.log_softmax(logits1, axis=-1))
+  log_probs1 = jax.nn.log_softmax(logits1, axis=-1)
   log_probs2 = jax.nn.log_softmax(logits2, axis=-1)
 
-  return jnp.sum((probs1 * (log_probs1 - log_probs2)), axis=-1)
+  # The KL is a sum over the support of `dist1`, that is, over the components of
+  # `dist1` that have non-zero probability. So we exclude terms with
+  # `probs1 == 0` by setting them to zero in the sum below.
+  return jnp.sum(
+      jnp.where(probs1 == 0, 0., probs1 * (log_probs1 - log_probs2)), axis=-1)
 
 
 # Register the KL functions with TFP.
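
At the distribution level, the new test below exercises exactly this case; a short usage sketch (assuming the public `distrax.Categorical` and `kl_divergence` API, with made-up probabilities):

import distrax
import jax.numpy as jnp

dist1 = distrax.Categorical(probs=jnp.asarray([0.4, 0.6, 0.0]))
dist2 = distrax.Categorical(logits=jnp.asarray([0.0, 0.1, -jnp.inf]))
# Both distributions put zero mass on the last category: with this fix the
# KL is finite instead of nan.
print(dist1.kl_divergence(dist2))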
24 changes: 23 additions & 1 deletion distrax/_src/distributions/categorical_test.py
@@ -415,13 +415,35 @@ def test_with_two_distributions(self, function_string, mode_string):
         mode_string=mode_string,
         dist1_kwargs={
             'probs':
-                jnp.asarray([[0.4, 0.0, 0.6], [0.1, 0.5, 0.4], [0.2, 0.4, 0.4]])
+                jnp.asarray([[0.1, 0.5, 0.4], [0.2, 0.4, 0.4]])
         },
         dist2_kwargs={
             'logits': jnp.asarray([0.0, 0.1, 0.1]),
         },
         assertion_fn=self.assertion_fn)
 
+  @chex.all_variants(with_pmap=False)
+  @parameterized.named_parameters(
+      ('kl distrax_to_distrax', 'kl_divergence', 'distrax_to_distrax'),
+      ('kl distrax_to_tfp', 'kl_divergence', 'distrax_to_tfp'),
+      ('kl tfp_to_distrax', 'kl_divergence', 'tfp_to_distrax'),
+      ('cross-ent distrax_to_distrax', 'cross_entropy', 'distrax_to_distrax'),
+      ('cross-ent distrax_to_tfp', 'cross_entropy', 'distrax_to_tfp'),
+      ('cross-ent tfp_to_distrax', 'cross_entropy', 'tfp_to_distrax'))
+  def test_with_two_distributions_extreme_cases(
+      self, function_string, mode_string):
+    super()._test_with_two_distributions(
+        attribute_string=function_string,
+        mode_string=mode_string,
+        dist1_kwargs={
+            'probs':
+                jnp.asarray([[0.1, 0.5, 0.4], [0.4, 0.0, 0.6], [0.4, 0.6, 0.]])
+        },
+        dist2_kwargs={
+            'logits': jnp.asarray([0.0, 0.1, -jnp.inf]),
+        },
+        assertion_fn=self.assertion_fn)
+
   def test_jittable(self):
     super()._test_jittable((np.array([0., 4., -1., 4.]),))
 
