-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Make jax_debug_nans and jax_debug_infs work with pmap, xmap, and pjit.
Note that unlike in the jit case, this doesn't rerun the function in op-by-op mode when it finds a nan, since we don't have op-by-op parallel execution yet :) This change doesn't appear to regress performance: ``` ---------Benchmark summary for pmap_shard_outputs--------- nouts nshards mean %std relative mean/baseline ------- --------- --------- -------- ---------- --------------- 10 8 0.105598 5.06671 1 1.00693 100 8 0.287756 0.870751 2.72502 0.973204 500 8 1.20119 0.823624 11.3752 0.955185 1000 8 2.56071 0 24.2497 0.983063 5000 8 12.909 0 122.247 0.965925 100 2 0.173727 5.15115 1.64518 0.98918 100 4 0.207774 3.71411 1.9676 0.955849 100 8 0.286103 1.60243 2.70937 0.971869 100 100 2.34168 0 22.1755 0.904475 100 500 15.9558 0 151.1 1.00483 ``` Fixes #6044
- Loading branch information
Showing
3 changed files
with
80 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters