Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NVIDIA] Reduce number of tests for jax.nn.dot_product_attention #23223

Closed
wants to merge 1 commit into from

Conversation

kaixih
Copy link
Contributor

@kaixih kaixih commented Aug 23, 2024

While adding the new mask mode, sliding_window, I noticed that the number of tests has become quite large. Currently, every time we introduce a new option, it requires testing all possible combinations with existing options, which makes the number of tests increase exponentially. For example, we already have 864 parameterized tests for this API. This PR aims to address this issue by reducing the number of tests through grouping.

For the new tests, we categorize them as follows:

  1. Non-mask tests: These verify the basic functionality of the API, including data types, vmap, groups, etc.
  2. Mask tests: These cover different masking scenarios, such as causal, padding, or other commonly used combinations.

Additionally, we will no longer maintain separate tests for inference and training.

@kaixih
Copy link
Contributor Author

kaixih commented Aug 23, 2024

@sbodenstein Can you take a look?

tests/nn_test.py Outdated
def _get_causal_mask(T, S):
causal_mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))
return causal_mask[jnp.newaxis, jnp.newaxis, :, :]
def _check_cudnn_backend(fn, *args, **kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about uses_cudnn_backend to reflect that this is a predicate and it will not raise if the condition is not met?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. PTAL.

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Aug 28, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants