
Significant performance drop: jaxlib 0.4.16 vs 0.4.14 #17686

Closed
mehdiataei opened this issue Sep 20, 2023 · 26 comments
Labels
bug Something isn't working

Comments

@mehdiataei
Contributor

Description

In the XLB library, the MLUPS (Millions of Lattice Updates per Second) metric has dropped by about 15% after updating jaxlib from 0.4.14 to 0.4.16.
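For reference, MLUPS counts lattice-site updates per wall-clock second; a minimal sketch of the metric is below (the exact timing logic in XLB's MLUPS3d.py may differ):

def mlups(domain_size: int, n_iters: int, elapsed_s: float) -> float:
    """Millions of Lattice Updates per Second for a cubic domain."""
    n_voxels = domain_size ** 3              # 512**3 = 134,217,728 voxels
    return n_voxels * n_iters / (elapsed_s * 1e6)

# e.g. 512^3 voxels, 200 iterations, ~32.4 s of wall-clock time -> ~828 MLUPS
print(mlups(512, 200, 32.4))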

Results on a single RTX 6000 Ada:

Version 0.4.14:

omega =  0.4918839153959666
XLA backend: gpu
Number of XLA devices available: 1
WARNING: Checkpointing is disabled for this simulation.
Time to create the grid connectivity bitmask: 0.17581534385681152
Time to create the local bitmasks and normal arrays: 6.792882680892944
WARNING: Default initial conditions assumed: density = 1, velocity = 0
         To set explicit initial density and velocity, use self.initialize_macroscopic_fields.
Domain: 512 x 512 x 512
Number of voxels: 134217728
MLUPS: 827.4825740398888

Version 0.4.16:

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1695236429.981993  393900 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
omega =  0.4918839153959666
XLA backend: gpu
Number of XLA devices available: 1
WARNING: Checkpointing is disabled for this simulation.
Time to create the grid connectivity bitmask: 0.17301297187805176
Time to create the local bitmasks and normal arrays: 6.891621112823486
WARNING: Default initial conditions assumed: density = 1, velocity = 0
         To set explicit initial density and velocity, use self.initialize_macroscopic_fields.
W0000 00:00:1695236443.334162  393900 hlo_rematerialization.cc:2946] Can't reduce memory use below 27.08GiB (29077237923 bytes) by rematerialization; only reduced to 29.54GiB (31714181580 bytes), down from 29.54GiB (31714181580 bytes) originally
Domain: 512 x 512 x 512
Number of voxels: 134217728
MLUPS: 726.4662730520944
I0000 00:00:1695236480.893240  393900 tfrt_cpu_pjrt_client.cc:352] TfrtCpuClient destroyed.

Notice the new tfrt_cpu_pjrt_client and hlo_rematerialization messages generated in version 0.4.16.
The hlo_rematerialization warning is not generated for smaller array sizes, and it does not appear at all in 0.4.14. The issue also persists with XLA_PYTHON_CLIENT_MEM_FRACTION set to 90%.
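(For completeness, a minimal sketch of setting that fraction from Python before JAX initializes its GPU backend; exporting the variable in the shell is equivalent.)

import os

# Must be set before jax allocates its GPU memory pool, i.e. before importing jax.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.90"

import jax
print(jax.devices())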

Steps to reproduce:
Follow the instructions in the library and run:

python examples/performance/MLUPS3d.py 512 200

(the first argument is the number of voxels in each dimension and the second is the number of iterations)

What jax/jaxlib version are you using?

No response

Which accelerator(s) are you using?

No response

Additional system info

No response

NVIDIA GPU info

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.54.03              Driver Version: 535.54.03    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA RTX 6000 Ada Gene...    Off | 00000000:41:00.0 Off |                  Off |
| 30%   57C    P0              69W / 300W |      3MiB / 49140MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX 6000 Ada Gene...    Off | 00000000:61:00.0 Off |                  Off |
| 30%   55C    P0              61W / 300W |      3MiB / 49140MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+
@mehdiataei mehdiataei added the bug Something isn't working label Sep 20, 2023
@mattjj
Member

mattjj commented Sep 20, 2023

Thanks for this report!

@mattjj
Member

mattjj commented Sep 20, 2023

Is there an easy way to reproduce the slowdown? We could try to bisect it down (probably easier using Google-internal infra, if we have a repro to run).

@nouiz
Collaborator

nouiz commented Sep 20, 2023

I was able to reproduce this on a V100 16GB like this:

docker run -it --gpus all nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04 /bin/bash
apt-get update
apt-get install -y python3-pip git

pip install jmp pyvista numpy matplotlib Rtree trimesh jmp termcolor orbax
git clone https://github.com/Autodesk/XLB
cd XLB
export PYTHONPATH=.
export CUDA_VISIBLE_DEVICES=0 # only 1 GPU is needed

pip install --upgrade "jax[cuda12_local]"==0.4.14 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
python3 examples/performance/MLUPS3d.py 256 200

pip uninstall -y jax jaxlib
pip install --upgrade "jax[cuda12_local]"==0.4.16 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
python3 examples/performance/MLUPS3d.py 256 200

Got this output:

omega =  0.7905138339920947
XLA backend: gpu
Number of XLA devices available: 1
WARNING: Checkpointing is disabled for this simulation.
Time to create the grid connectivity bitmask: 1.6283340454101562
Time to create the local bitmasks and normal arrays: 5.10488224029541
WARNING: Default initial conditions assumed: density = 1, velocity = 0
         To set explicit initial density and velocity, use self.initialize_macroscopic_fields.
Domain: 256 x 256 x 256
Number of voxels: 16777216
MLUPS: 577.1932827977173
...
omega =  0.7905138339920947
XLA backend: gpu
Number of XLA devices available: 1
WARNING: Checkpointing is disabled for this simulation.
Time to create the grid connectivity bitmask: 1.6950793266296387
Time to create the local bitmasks and normal arrays: 5.629136085510254
WARNING: Default initial conditions assumed: density = 1, velocity = 0
         To set explicit initial density and velocity, use self.initialize_macroscopic_fields.
Domain: 256 x 256 x 256
Number of voxels: 16777216
MLUPS: 536.4963689988614
...

Here the hlo_rematerialization warning doesn't appear, yet I can still reproduce the issue.
Note that the MLUPS line shows the speed; higher is better.

@mattjj
Member

mattjj commented Sep 20, 2023

Thanks, @nouiz ! I'm not sure how well we can translate those setup instructions into a google-internal repro (for bisection purposes), but it's a step forward.

Is it easy to grab a profile and look for where the diff may be?
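One hedged way to grab such a profile (not necessarily what was used here) is JAX's built-in profiler, which writes a trace that can be opened in TensorBoard's profile plugin; a minimal sketch with a toy step function:

import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    return jnp.sin(x) * 2.0

x = jnp.ones((256, 256, 256))

# Trace a few steps; point TensorBoard's profile plugin at /tmp/jax-trace.
with jax.profiler.trace("/tmp/jax-trace"):
    for _ in range(10):
        x = step(x)
    x.block_until_ready()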

@hawkinsp
Member

I'll also note the logs from

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1695236429.981993  393900 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.

are benign and fixed at head, but will need a new jaxlib release.

@mehdiataei
Contributor Author

mehdiataei commented Sep 20, 2023

Hi Matthew, thanks for looking into this. What kind of instructions would you prefer for bisection purposes? I have included a script I use personally which can also reproduce this error.

@nouiz the hlo_rematerialization warning will show up for domains larger than 256, which saturate the VRAM.

import os
import subprocess
import shutil

def run_command(command):
    subprocess.run(command, shell=True)

def extract_mlups(file_name):
    with open(file_name, 'r') as f:
        lines = f.readlines()
        for line in lines:
            if "MLUPS" in line:
                return float(line.strip().split(":")[1].strip())

def setup_and_run_XLB():
    os.chdir('/home/mehdi/XLB')

    # --system-site-packages lets the venv see the jax/jaxlib wheel that
    # main() builds and installs in the outer environment.
    venv_command = "python -m venv --system-site-packages testenv"
    run_command(venv_command)

    # Generate shell script to activate venv and run commands
    with open("run_in_venv.sh", "w") as f:
        f.write("#!/bin/bash\n")
        f.write("source testenv/bin/activate\n")
        
        f.write("pip install pyvista jmp matplotlib portpicker termcolor Rtree trimesh jmp orbax-checkpoint\n")
        
        f.write("python examples/performance/MLUPS3d.py 256 100 > output.txt\n")

    run_command("chmod +x run_in_venv.sh")
    run_command("./run_in_venv.sh")

def main():
    # if os.path.exists("jax"):
    #     shutil.rmtree("jax")
    # if os.path.exists("XLB"):
    #     shutil.rmtree("XLB")

    # run_command("git clone https://github.com/google/jax.git")
    # run_command("git clone https://github.com/Autodesk/XLB.git")
    
    os.chdir("/home/mehdi/jax")
    
    run_command("git bisect start")
    run_command("git bisect good d477b921215d5030d09020a70df7f5cd46bb62dd")
    run_command("git bisect bad 88a60b808c1f91260cc9e75b9aa2508aae5bc9f9")
    
    mlups_report = {}
    
    while True:
        commit_hash = subprocess.getoutput("git rev-parse HEAD")
        
        run_command("pip install numpy wheel build")
        run_command("python build/build.py --enable_cuda")
        run_command("pip install dist/*.whl")

        setup_and_run_XLB()
        mlups_value = extract_mlups("/home/mehdi/XLB/output.txt")
        
        mlups_report[commit_hash] = mlups_value
        
        # The output of `git bisect good`/`git bisect bad` announces the first
        # bad commit once bisection has converged; `git bisect visualize` would
        # open gitk instead, so check this output directly.
        if mlups_value > 1450:
            bisect_output = subprocess.getoutput("git bisect good")
        else:
            bisect_output = subprocess.getoutput("git bisect bad")

        if "is the first bad commit" in bisect_output:
            break
    
    run_command("git bisect reset")
    
    print("MLUPS Comparison Report:")
    with open("MLUPS_Comparison_Report.txt", "w") as f:
        f.write("MLUPS Comparison Report:\n")
        for commit, mlups in mlups_report.items():
            report_line = f"{commit}: {mlups}\n"
            print(report_line.strip())
            f.write(report_line)

if __name__ == "__main__":
    main()

@nouiz
Collaborator

nouiz commented Sep 21, 2023

I found the XLA commit that causes this regression:

3c665e2197320b95cf913dcf146fc9d35dc4ab49 is the first bad commit
commit 3c665e2197320b95cf913dcf146fc9d35dc4ab49
Author: Adrian Kuegel akuegel@google.com
Date:   Thu Aug 24 06:19:06 2023 -0700

    Do not choose bitcasts as fusion roots.

    This just adds indexing overhead without any benefit, and may require extra buffers.
    Bitcasts outside of fusions are no-ops.
    We still allow to fuse a bitcast producer into an already existing fusion.
    Otherwise they would act as fusion blockers.

    PiperOrigin-RevId: 559732964

@mattjj
Member

mattjj commented Sep 21, 2023

Amazing work, @nouiz !

Should we document your process for bisecting XLA/jaxlib in OSS? Maybe share any notes on the process you used.

@nouiz
Collaborator

nouiz commented Sep 21, 2023

I did brute-force testing of the JAX-Toolbox nightlies over a 3-month window.
Then I used custom scripts to check out matching XLA and JAX commits together.
Any idea where to document this in OSS?

@mattjj
Member

mattjj commented Sep 21, 2023

We might want to make a new page on jax.readthedocs.io, like "how to debug performance regressions" or "how to bisect a failure down to a specific change". We could provide some scripts too, and/or example git bisect commands (and build commands). I imagine it'd be useful to know how to bisect over nightlies, how to get commit ranges corresponding to nightlies, and how to build from source to get down to a commit.

@hawkinsp any thoughts?

@mehdiataei
Contributor Author

mehdiataei commented Sep 21, 2023

Bisecting over nightly builds first and then bisecting from there to get down to a commit is a great idea!

@nouiz
Collaborator

nouiz commented Sep 21, 2023

bisecting over nightly builds first and then bisecting from there to get down to a commit is a great idea!

What I have isn't that smart. It just does brute force at these two levels. Because it works at two levels it is faster, and it ends up fast enough to be useful. I don't know how to easily do a proper bisect like that.

I don't know of a tool that can bisect on top of nightly containers. I could build something more complicated, but brute force is so simple that I didn't try.
At the lower level, there is a dependency between JAX and XLA commits; to sidestep that issue, I do brute force at an hour granularity. A rough sketch of the two-level idea is below.

If someone knows how to make this smarter without too much work, I'm happy to hear it.
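(A hedged sketch of that two-level brute force; the container tag format and the run_benchmark helper are hypothetical placeholders, not the actual scripts used.)

import subprocess

def run_benchmark(image: str) -> float:
    """Hypothetical helper: run the XLB benchmark inside a container image
    and return the MLUPS value parsed from its output."""
    out = subprocess.run(
        ["docker", "run", "--gpus", "all", "--rm", image,
         "python3", "examples/performance/MLUPS3d.py", "256", "200"],
        capture_output=True, text=True).stdout
    for line in out.splitlines():
        if "MLUPS" in line:
            return float(line.split(":")[1])
    raise RuntimeError("no MLUPS line in output")

# Level 1: brute force over nightly container dates (placeholder tag format).
dates = ["2023-08-20", "2023-08-23", "2023-08-26"]
results = {d: run_benchmark(f"some-registry/jax-nightly:{d}") for d in dates}

# Level 2: within the first day that regressed, rebuild jax/XLA at roughly
# hourly commit granularity and rerun the same benchmark on each build.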

@nouiz
Collaborator

nouiz commented Oct 2, 2023

Here is a PR with the doc on how I did this investigation:
#17850

@cheshire
Member

cheshire commented Oct 3, 2023

What is the conclusion? Revert the offending PR?

@cheshire
Member

cheshire commented Oct 3, 2023

Tracked internally in b/303225846

@akuegel
Member

akuegel commented Oct 5, 2023

Edit: never mind, I got it running on an A100 cloud instance and could reproduce the slowdown (for A100 with CUDA 11, I got MLUPS: 1182.8535481602703 for the fast version and MLUPS: 1077.4404289392467 for the slow version).
I will extract the HLO dumps and analyze them.

@akuegel
Member

akuegel commented Oct 6, 2023

After analyzing the dumped HloModules, I think I found the reason for the regression. After my change, we now fuse a bitcast as operand 0 of the scatter fusion; before, that bitcast was a fusion root of the "previous" fusion. I could not reproduce the slowdown with just that extracted scatter fusion (I guess it depends on the scatter indices, and random indices don't show the problem), but when run with the whole module, the difference is 3,769,591 ns (now) vs 401,699 ns (before). I don't know for sure why this makes such a big difference, but it could be that when the bitcast is not fused into the scatter, we can optimize the kernel that copies the operand to the output. I believe we normally should not fuse anything into scatter operand 0, because then we can possibly do the scatter in place and save the copying kernel.
All in all, I think my change just triggered this pre-existing performance problem with scatter fusions.
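(For readers wondering where these scatters come from in JAX code: indexed updates such as x.at[idx].add(v) lower to HLO scatter, and XLA can perform them in place when the operand buffer can be reused, for example when the argument is donated. A minimal, hedged illustration:)

import jax
import jax.numpy as jnp

# Indexed updates like x.at[idx].add(v) lower to an HLO scatter. Donating the
# operand lets XLA reuse its buffer and perform the scatter in place.
update = jax.jit(lambda x, idx, v: x.at[idx].add(v), donate_argnums=0)

x = jnp.zeros((1024, 1024))
idx = jnp.array([3, 7, 11])
v = jnp.ones((3, 1024))
y = update(x, idx, v)   # x's buffer is donated and must not be reused afterwards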

@nouiz
Collaborator

nouiz commented Oct 6, 2023

Why not try a small change that limits the risk: only avoid fusing bitcasts into scatter operands?
The reasoning is the same: a bitcast is free when not fused, so we gain nothing from fusing it.
What do you think of that?

@akuegel
Member

akuegel commented Oct 9, 2023

Yes, I think the right fix is to make sure we can do the scatter in place, so don't fuse anything into the first operand of scatter. It will be a separate kernel launch even if we fuse, and if we don't fuse, we also avoid the problem with the useless bitcast kernel.

@akuegel
Member

akuegel commented Oct 10, 2023

I have a pending change for this: openxla/xla#6159
I still need to run benchmarks and get a code review, but as far as I can tell, it seems to work and recovers the lost performance (possibly will even improve it a bit more).

@mehdiataei
Contributor Author

@akuegel
Amazing, Adrian. Let me know when it is available in the nightly builds and I can help verify the performance changes.

@akuegel
Member

akuegel commented Oct 13, 2023

This change has landed in openxla/xla@238685d

@mehdiataei
Contributor Author

mehdiataei commented Oct 16, 2023

OK this is amazing!!!!!! 🚀 🚀

Thank you @akuegel for the change; we now have a 41% improvement over 0.4.16 and 24% over 0.4.14!! This is actually a very significant jump! Hopefully it will show up in other use cases as well.

Thanks @nouiz and @mattjj for the follow-up and finding the issue!

Shall we close the issue, or are we going to continue the discussion on how to investigate the regression based on @nouiz's PR?

Results on a single RTX 6000 Ada:

Build     MLUPS
0.4.14    827.48
0.4.16    726.47
Nightly   1024.04

Comparison           Improvement in MLUPS (%)
0.4.14 -> 0.4.16     -12.21
0.4.14 -> Nightly     23.75
0.4.16 -> Nightly     40.96

@mehdiataei
Contributor Author

I want to seize this opportunity to direct our attention to this issue I opened earlier as well: #15368

donate_argnums can be very powerful, and I hope it can work with shardmap as well as it does with pmap!
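(For context, a minimal hedged sketch of donate_argnums with jax.jit as it works today; whether shard_map gets the same buffer-donation behavior is exactly what #15368 asks about.)

import functools
import jax
import jax.numpy as jnp

# Donate the first argument so its device buffer can be reused for the output,
# avoiding an extra copy on every step of an iterative update loop.
@functools.partial(jax.jit, donate_argnums=(0,))
def step(state, dt):
    return state + dt * jnp.sin(state)

state = jnp.ones((512, 512))
for _ in range(5):
    state = step(state, 0.1)   # the previous `state` buffer is donated each call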

@akuegel
Member

akuegel commented Oct 16, 2023

I think the discussion of how to investigate such regressions could happen in a separate bug. Let's close this one.

@nouiz
Collaborator

nouiz commented Oct 16, 2023

The discussion about how to investigate can continue on the PR itself for now, unless you see a need for a separate issue.
