Fix siglip flaky test_eager_matches_sdpa_inference #40584
@@ -469,10 +469,25 @@ def _test_eager_matches_sdpa_inference(
     logits_sdpa = _logits_sdpa
     logits_eager = _logits_eager

+    # Avoid test flakiness with bf16!
+    # bf16 is not good at precision when the magnitude is larger. We have some models like `SiglipVision` with
+    # this test passing all the time for fp32/fp16 but flaky with bf16. Furthermore, `llama` and `clip` have
+    # this test passing all the time for bf16: it turns out their outputs have smaller magnitudes (0.1 and 1.0)
+    # while `siglip` has outputs with maximal values around 3.0/4.0.
+    outputs_magnitude = float(
+        (torch.max(logits_sdpa.abs().amax(), logits_eager.abs().amax())).detach().to("cpu")
+    )
+    # The choice of `3e-2` in `outputs_magnitude * 3e-2` might not work if a model has even larger outputs
+    # (we could analyze the `rtol` more closely element-wise in the future and adjust the `rtol` instead of `atol`).
+    computed_atol = outputs_magnitude * 3e-2
Comment on lines +480 to +482:
yeah, i think with
(but yeah, the comment here is more for future weird models :-) )
+
+    if dtype == torch.bfloat16:
+        atol = max(atol, computed_atol)
+
     results = [
         torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol)
         for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager)
     ]

     # If 80% batch elements have matched results, it's fine
     if np.mean(results) < 0.8:
         mean_relative_diff = ((logits_sdpa - logits_eager).abs() / (logits_eager.abs() + 1e-12)).mean()
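For intuition, here is a minimal standalone sketch (not part of the PR) of why the fix scales `atol` with output magnitude under bf16: bf16 carries only 8 significand bits, so the absolute rounding error it introduces grows roughly in proportion to the values being represented. The magnitudes 0.1, 1.0, and 4.0 mirror the `llama`/`clip` versus `siglip` output scales mentioned in the new comment.

import torch

# Round-trip fp32 values of increasing magnitude through bf16 and measure the absolute
# error this introduces: a fixed atol that is comfortable at magnitude ~0.1 can be far
# too tight at magnitude ~4.0, which is roughly where siglip's outputs live.
for magnitude in (0.1, 1.0, 4.0):
    x = torch.rand(10_000) * magnitude
    abs_err = (x - x.to(torch.bfloat16).float()).abs().max()
    print(f"magnitude ~{magnitude}: max bf16 rounding error ~{abs_err.item():.5f}")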
Review comment:
even with this (retry 5 times), we still had some failures from time to time. Now we don't need this anymore.
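To make the intent of the change easier to see in isolation, here is a hypothetical, self-contained helper (the name `compare_sdpa_vs_eager` and the default `atol`/`rtol` values are illustrative, not part of the test suite) combining the magnitude-scaled `atol` with the 80% batch-match threshold from the diff above.

import numpy as np
import torch


def compare_sdpa_vs_eager(logits_sdpa, logits_eager, dtype, atol=1e-3, rtol=1e-3):
    # For bf16, absolute rounding error grows with magnitude, so derive atol from the
    # largest output value (3e-2 per unit of magnitude) instead of a fixed threshold.
    if dtype == torch.bfloat16:
        outputs_magnitude = float(torch.max(logits_sdpa.abs().amax(), logits_eager.abs().amax()))
        atol = max(atol, outputs_magnitude * 3e-2)

    # Compare batch element by batch element, then accept if at least 80% of them match.
    results = [
        torch.allclose(s, e, atol=atol, rtol=rtol)
        for s, e in zip(logits_sdpa, logits_eager)
    ]
    return np.mean(results) >= 0.8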