Significant performance drop: jaxlib 0.4.16 vs 0.4.14 #17686
Comments
Thanks for this report!
Is there an easy way to reproduce the slowdown? We could try to bisect it down (probably easier using google-internal infra, if we have a repro to run).
I was able to repro on a V100 16GB like this:
Got this output:
There the hlo_rematerialization warning doesn't trigger, and I still reproduce the issue.
Thanks, @nouiz! I'm not sure how well we can translate those setup instructions into a google-internal repro (for bisection purposes), but it's a step forward. Is it easy to grab a profile and look for where the diff may be?
I'll also note the logs from `tfrt_cpu_pjrt_client` are benign and fixed at head, but will need a new jaxlib release.
Hi Matthew, thanks for looking into this. What type of instructions would you prefer for bisection purposes? I have included a script I use personally which can also reproduce this error. @nouiz, the hlo_rematerialization error will show up for domains larger than 256, which saturate the VRAM.
I found the XLA commit that caused this regression: 3c665e2197320b95cf913dcf146fc9d35dc4ab49 is the first bad commit. Its description: "Do not choose bitcasts as fusion roots. This just adds indexing overhead without any benefit, and may require extra buffers." PiperOrigin-RevId: 559732964
Amazing work, @nouiz! Should we document your process of bisecting XLA / jaxlib in OSS? Maybe share any notes on the process you used.
I did brute-force testing of the JAX-Toolbox nightlies over 3 months.
We might want to make a new page on jax.readthedocs.io, like "how to debug performance regressions" or "how to bisect a failure down to a specific change". We could provide some scripts too, and/or example git bisect commands and build commands (see the sketch below). I imagine it'd be useful to know how to bisect over nightlies, how to get commit ranges corresponding to nightlies, and how to build from source to get down to a commit. @hawkinsp any thoughts?
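For the final "down to a commit" stage, one option is `git bisect run` with a small wrapper script. A minimal sketch, assuming an openxla/xla checkout and a hypothetical `check_perf.sh` that rebuilds jaxlib against the current XLA commit, runs the benchmark, and exits non-zero when it is slow (the SHAs and threshold are placeholders):

```bash
# In an openxla/xla checkout. check_perf.sh (hypothetical) rebuilds jaxlib
# against the checked-out XLA commit, runs the benchmark, and exits 0 if
# MLUPS is above a chosen threshold, non-zero otherwise.
git bisect start
git bisect bad HEAD               # known-slow commit
git bisect good <known-fast-sha>  # e.g. the commit pinned by jaxlib 0.4.14
git bisect run ./check_perf.sh    # git walks the range automatically
```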
Bisecting over nightly builds first and then bisecting from there to get down to a commit is a great idea!
What I have isn't that smart; it just does brute force at these two levels. Because it works at two levels, it ends up fast enough to be useful (see the sketch below). I don't know how to easily do a proper bisect like that, and I don't know of a tool that bisects over nightly containers. I could make something more complicated, but brute force is so simple that I didn't try. If someone knows how to make this smarter without too much work, I'm happy to hear it.
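A minimal sketch of the first, container-level pass, assuming a nightly image tag scheme like `ghcr.io/nvidia/jax:nightly-YYYY-MM-DD` (hypothetical), an XLB checkout mounted into the container, and that the container already has the benchmark's dependencies:

```bash
#!/usr/bin/env bash
# Brute-force level 1: run the benchmark in each nightly container and print
# one MLUPS line per day; the first slow day brackets the commit range.
set -euo pipefail

for day in 2023-07-{01..31} 2023-08-{01..31}; do
  img="ghcr.io/nvidia/jax:nightly-${day}"           # assumed tag scheme
  docker pull -q "${img}" 2>/dev/null || continue   # skip days with no build
  docker run --rm --gpus all -v "${PWD}:/xlb" -w /xlb "${img}" \
    python examples/performance/MLUPS3d.py 512 200 \
    | grep -o 'MLUPS: [0-9.]*' | sed "s/^/${day}: /" \
    || echo "${day}: run failed"
done
```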
Here is a PR with the doc on how I did this investigation: |
What is the conclusion? Revert the offending PR?
Tracked internally in b/303225846 |
Edit: never mind, I got it running on an A100 cloud instance and could reproduce the slowdown (on A100 with CUDA 11, I got MLUPS: 1182.8535481602703 for the fast version and MLUPS: 1077.4404289392467 for the slow version).
After analyzing the dumped HloModules, I think I found the reason for the regression. After my change, we now fuse a bitcast as operand 0 of a scatter fusion; before, that bitcast was a fusion root of the "previous" fusion. I could not reproduce the slowdown with just that extracted scatter fusion (I guess it depends on the scatter indices, and random indices don't show the problem), but when run with the whole module, the difference is 3,769,591 ns (now) vs 401,699 ns (before). I don't know for sure why this makes such a big difference, but it could be that in the case where the bitcast is not fused into the scatter, we can optimize the kernel that copies the operand to the output. I believe we normally should not fuse anything into scatter operand 0, because then we can possibly do the scatter in place and save the copying kernel.
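For anyone retracing this analysis, the HLO modules can be dumped with XLA's standard dump flags; a minimal sketch (the output directory is arbitrary):

```bash
# Dump optimized HLO as text so fusion decisions (e.g. a bitcast fused into
# scatter operand 0) can be inspected after the run.
XLA_FLAGS="--xla_dump_to=/tmp/xlb_hlo --xla_dump_hlo_as_text" \
  python examples/performance/MLUPS3d.py 512 200
grep -l ' scatter(' /tmp/xlb_hlo/*.txt   # modules containing scatter ops
```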
Why not try a small change that limits the risk: only avoid fusing bitcasts into scatter operands?
Yes, I think the right fix is to try to make sure we can do scatter in place, so don't fuse anything into the first operand of scatter. It will be a separate kernel launch even if we fuse, and if we don't fuse, we also avoid the problem with the useless bitcast kernel.
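For context, the scatter in question corresponds to an indexed update in JAX; a minimal sketch with invented shapes (not XLB's actual code):

```python
import jax
import jax.numpy as jnp

@jax.jit
def scatter_set(f, indices, updates):
    # x.at[...].set(...) lowers to an HLO scatter. If nothing is fused into
    # scatter operand 0, XLA can often reuse the operand buffer in place
    # instead of launching an extra kernel that copies operand to output.
    return f.at[indices].set(updates)

f = jnp.zeros((512, 512, 512, 19))   # invented lattice shape
idx = jnp.array([0, 3, 7])
upd = jnp.ones((3, 512, 512, 19))
out = scatter_set(f, idx, upd)
```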
I have a pending change for this: openxla/xla#6159
@akuegel
This change has landed in openxla/xla@238685d
OK, this is amazing! 🚀 🚀 Thank you @akuegel for the change; we now have a 41% improvement vs 0.4.16 and 24% over 0.4.14! This is actually a very significant jump; hopefully it will show itself in other use cases as well. Thanks @nouiz and @mattjj for the follow-up and for finding the issue! Shall we close the issue, or are we going to continue the discussion on how to investigate the regression based on @nouiz's PR? Results on a single RTX 6000 Ada:
I want to seize this opportunity to direct our attention to an issue I opened earlier as well: #15368. donate_argnums can be very powerful, and I hope it can work with shard_map as well as it does with pmap!
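For readers unfamiliar with donation, a minimal sketch of `donate_argnums` with `jax.jit` (the function and shapes are invented):

```python
import functools
import jax
import jax.numpy as jnp

# donate_argnums=0 tells XLA it may reuse the first argument's buffer for
# the output, avoiding an extra copy of a large array on every step.
@functools.partial(jax.jit, donate_argnums=0)
def update(f):
    return f * 0.99 + 0.01

f = jnp.ones((512, 512, 512))
f = update(f)   # old buffer may be reused; the donated input must not be read again
```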
I think the discussion of how to investigate such regressions could happen in a separate bug. Let's close this one.
The discussion about how to investigate can continue for now on the PR itself unless you see a need for an issue.
Description
In the XLB library, MLUPS (millions of lattice updates per second) has dropped by about 15% after updating jaxlib from 0.4.14 to 0.4.16.
Results on a single RTX 6000 Ada:
Version 0.4.14:
Version 0.4.16:
Notice the new warnings `tfrt_cpu_pjrt_client` and `hlo_rematerialization` generated in version 0.4.16. The `hlo_rematerialization` warning is not generated for lower array sizes, and neither warning appears in 0.4.14. Also, the error persists with XLA_PYTHON_CLIENT_MEM_FRACTION at 90% as well.

Steps to reproduce:
Follow the instructions in the library and run:
python examples/performance/MLUPS3d.py 512 200
(the first argument is the number of voxels in each dimension and the second is the number of iterations)
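For reference, MLUPS is presumably computed as total lattice-site updates divided by wall time; a sketch with the benchmark's parameters and a made-up wall time:

```python
# Assumed definition of the metric; elapsed_s is hypothetical.
nx = ny = nz = 512      # voxels per dimension (first argument above)
iters = 200             # iterations (second argument above)
elapsed_s = 24.4        # made-up wall-clock time for the timed loop
mlups = nx * ny * nz * iters / elapsed_s / 1e6   # ~1100 MLUPS
```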
What jax/jaxlib version are you using?
No response
Which accelerator(s) are you using?
No response
Additional system info
No response
NVIDIA GPU info