
Perplexity: add clipping and from_logits#47

Merged
jshin1394 merged 1 commit into google:main from jeffcarp:fix-perplexity
Mar 31, 2025

Conversation

@jeffcarp (Collaborator) commented Mar 27, 2025

It was pointed out that Perplexity returns NaNs for negative values. This is because our implementation did not clip logit values to [0, 1], whereas the Keras implementation does. [1]

Even with that fix, the tests were failing because Keras defaults to the TensorFlow version of the metric, which applies softmax to the outputs unconditionally [2], unlike the JAX implementation which does not. [3]

Also:

  • Added a from_logits arg, similar to Keras, for users who want to pass raw logits and have us apply softmax internally.
  • Forced all Keras metrics in tests to use the JAX backend for parity.

[1] https://github.com/keras-team/keras/blob/3f8b065e82b17884bd43fcfbd4bd79f18a7019fe/keras/src/backend/jax/nn.py#L582
[2] https://www.tensorflow.org/api_docs/python/tf/nn/sparse_softmax_cross_entropy_with_logits
[3] https://github.com/keras-team/keras/blob/3f8b065e82b17884bd43fcfbd4bd79f18a7019fe/keras/src/backend/jax/nn.py#L578-L579
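The two fixes can be sketched roughly as follows. This is a hypothetical helper (shown with NumPy for brevity), not the actual implementation from this PR; the function name, argument names, and epsilon value are illustrative:

```python
import numpy as np

def perplexity(targets, preds, from_logits=False, eps=1e-7):
    """Sketch of a perplexity metric with clipping and a from_logits option.

    targets: integer class ids, shape (N,)
    preds:   probabilities or raw logits, shape (N, num_classes)
    """
    if from_logits:
        # Raw logits: normalize to probabilities with a numerically
        # stable softmax before taking logs.
        z = preds - preds.max(axis=-1, keepdims=True)
        preds = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)
    # Clip so negative or zero values cannot produce NaN/inf in the log.
    preds = np.clip(preds, eps, 1.0)
    # Log-probability assigned to each true class.
    log_p = np.log(np.take_along_axis(preds, targets[..., None], axis=-1))
    # Perplexity = exp(mean cross-entropy).
    return np.exp(-np.mean(log_p))
```

As a sanity check, a uniform distribution over V classes should yield a perplexity of exactly V, and negative inputs should now give a finite value instead of NaN.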

@jeffcarp jeffcarp requested a review from jshin1394 March 27, 2025 23:39
@jeffcarp jeffcarp marked this pull request as draft March 27, 2025 23:48
@jeffcarp (Collaborator, Author) commented Mar 27, 2025

Looking into the test failures... looks like it only fails when the whole test suite is run?

@jeffcarp jeffcarp marked this pull request as ready for review March 28, 2025 23:37
@jeffcarp (Collaborator, Author) commented

Found the issue: when Keras is imported by other test files first, it doesn't have `KERAS_BACKEND` set correctly.
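Keras reads the `KERAS_BACKEND` environment variable once, at first import, so the variable must be set before any test file imports Keras; setting it afterwards has no effect. A minimal sketch of forcing the JAX backend at the top of a test entry point (assuming nothing has imported Keras yet):

```python
import os

# Must run before the first `import keras` anywhere in the process:
# the backend choice is frozen at import time.
os.environ["KERAS_BACKEND"] = "jax"
```

Equivalently, `KERAS_BACKEND=jax` can be exported in the shell that launches the test suite.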

@jshin1394 jshin1394 merged commit 79fefa2 into google:main Mar 31, 2025
3 checks passed
