Add support for other GPUs (than NVIDIA) #2012
Is it possible to run JAX on GPU architectures other than NVIDIA (e.g., Intel, AMD)?
In principle, sure! All we need is for XLA to support that architecture. In practice, that means we currently support CPU, NVIDIA GPU, and TPU. Happily, AMD has been contributing support for AMD GPUs to XLA. We haven't tried it out in JAX, but assuming the XLA support is complete, I see no good reason it wouldn't work with a few small JAX changes. If you are excited about AMD GPUs, we'd certainly welcome contributions enabling that functionality in JAX. I don't think Intel GPUs have XLA support at the moment, but I wouldn't rule it out in the future as the various compiler toolchains (e.g., XLA, MLIR) progress.
The AMDGPU backend for XLA is being actively developed; these PRs probably have the most up-to-date status (it seems like many, but not all, tests pass). One thing to note is that the AMD integration requires rebuilding XLA from source; there's no way to build a single TF or XLA binary that can use both NVIDIA CUDA and AMD ROCm. For Intel hardware, I imagine we'd need something like an MLIR translation from the HLO dialect to the nGraph dialect. I'm guessing nobody is actively working on that, but cc'ing @nmostafa in case Intel has plans in that area.
Glad to see this ROCm effort seems to be funded with full-time developers by AMD. Better late than never, I suppose. I hope they learned at least a little from their misadventures in GPGPU, with OpenCL only half-heartedly supported; in practice, if you wanted to get anything done, you had no choice but to go with the platform that didn't require you to, say, reinvent your FFT libraries from scratch. I hope this time around they realize there is some minimum investment in software they'd be smart to make if they want to offer a competitive ecosystem. It's crazy to see how much money NVIDIA has made off this; in the meanwhile, Google has added a completely new, viable hardware and software alternative in the form of TPUs, and AMD is still working on getting compatibility with any of the software out there. It does not inspire much confidence, to be honest; it seems wise to bet against them ever getting out a robust, feature-complete alternative if they couldn't even get anything out four years ago. But I'd love to be wrong about this, and for there to be some genuine competition in desktop ML acceleration in the future.
Can someone help me with how to use JAX on AMD GPUs? Are there any code snippets we can start with?
Any update on the topic?
There's no technical blocker to using JAX on AMD GPUs. We on the JAX team simply don't have access to any AMD GPUs at the moment to develop or test the necessary changes (which are probably not that large, given that most of the necessary work has been done in the context of TensorFlow). Contributions are welcome!
That's good to know @jekbradbury, thanks
I just wanted to ask: when we are talking about AMD GPUs being supported, is it going to be on all platforms (i.e., including macOS), or are we talking Linux/Windows only?
I believe the AMDGPU backend support for XLA is based on ROCm, which doesn't support macOS.
I was able to build jax with initial support for ROCm (AMD GPUs). The code can be found here: Executing
on my RX480 outputs
which already looks very promising. For those who want to build this:
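A minimal smoke test for a build like this might look as follows (a sketch, not the exact snippet from above; on a working ROCm build, `jax.devices()` should report the AMD GPU rather than falling back to CPU):

```python
import jax
import jax.numpy as jnp

# On a working ROCm build this should list the AMD GPU device(s)
# rather than silently falling back to CPU.
print(jax.devices())

# A tiny jitted computation to confirm the backend compiles and runs.
x = jnp.arange(10.0)
print(jax.jit(lambda v: jnp.dot(v, v))(x))
```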
@inailuig That's exciting progress! Nice work! (Sorry for the slow response; many of us were on vacation this last week.) Technically speaking, the cublas/cusolver and cuda_prng kernels are somewhat optional. The cuda_prng kernel is a compile-time optimization and can be safely omitted (at the cost of increased compile time), and cublas/cusolver are only needed for linear algebra support. So it might be possible to check things in even before those pieces work. I'm curious: is it possible to use upstream TF instead of the ROCm fork? We frequently update our TF (XLA) version, so any ROCm-specific fork is likely to become stale.
@hawkinsp It turns out that all that is missing in upstream TF is actually looking for devices with the right platform, i.e., some changes in
Do you think we could get something like that upstreamed into TF? For cuda_prng and the cublas/cusolver kernels, I was also able to get them running (2 or 3 of the LAPACK functions (cusolver) are not yet implemented in rocsolver, but everything else is there; this also requires a few more changes to TF; I will post more once I have cleaned it up a bit).
We certainly can upstream something like that. That file is really part of JAX, so we can change it as we see fit. You can send PRs to TensorFlow and assign me; I can review.
Thank you for trying out JAX on AMD GPUs. I am on the TF framework team at AMD and would like to get a better understanding of the TF changes that are required to get JAX working. We would be more than happy to help out. I also had a question for you: does JAX have unit tests that run on GPUs, and if so, can you point me to the directions to run them? I would like to get them running internally on our platform. Thanks again, deven
@deven-amd We'll need to wait for @inailuig to send out their remaining changes to get things to build. Once those changes are checked in, the best way to do this is probably something like this:
This builds and installs jaxlib with TF (XLA) from head (rather than whatever version we have pinned). I should note there are probably a few tests that fail at head on Nvidia GPUs also (#5067). See also https://jax.readthedocs.io/en/latest/developer.html#running-the-tests
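A sketch of that kind of invocation (hypothetical paths and flags; `--enable_rocm` assumes the build support from #5114, and the bazel override points at a local TF checkout):

```sh
# Build jaxlib against a local TensorFlow/XLA checkout instead of the
# pinned version, then install the in-tree build (paths illustrative).
python build/build.py --enable_rocm \
  --bazel_options=--override_repository=org_tensorflow=/path/to/tensorflow
pip install -e build
```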
Hi Peter, thanks for the quick response. I will try out the directions you have provided, plus the docs, to get the JAX unit tests working on the ROCm platform. I expect to work on this next week and will ping you if I run into any issues. In case I do, would you rather I email you directly or file an issue on the JAX GitHub repo? Thanks, deven
If there's no reason otherwise, we like to do development in the open so the community can be involved. So I'd file issues/PRs or use GitHub discussions. You can ping me in any issues or PRs if you want to make sure I take a look!
@deven-amd Thanks for reaching out; it would be great if you could in particular help with fixing the tests which are still failing. I just opened #5114 for the remaining build-related stuff in jax. However, there are still some tests failing because of bugs (e.g., stuff related to conv, dot_general, triangular solve, ...). Also, there is this error message
which keeps popping up when the program terminates. @deven-amd, would you be able to look into this? For the BLAS/LAPACK wrappers (i.e., the equivalent of jaxlib/cusolver.py and the related pybind modules, but for rocm), we still need a few changes in TF for this to work:
All of this can be found in https://github.com/inailuig/tensorflow/tree/jax-rocm-gpukernels (there are 2 more commits which are useful for debugging, but not necessary). @hawkinsp How should we proceed?
Retitling this bug to focus on AMD GPUs only; we can open new bugs for other hardware vendors if needed.
Hi team, thanks a lot for supporting ROCm in jax. I have run into some issues (I don't know which is the right way to build jax from source; I saw https://hub.docker.com/r/rocm/jax and checked out the branch jax_preview_release):
1. Running the unit tests (https://jax.readthedocs.io/en/latest/developer.html?highlight=pytest#running-the-tests)
it shows
If I run the program with jaxlib, it also shows this assertion. I think maybe the source code, or the way I built it, is wrong?
2. I tried to use ROCm 4.1 to build jax from source, but it failed in
3. When I set up
Can you help me? Thanks a lot!
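For reference, the linked developer docs describe running the test suite roughly like this (a sketch):

```sh
# Run the JAX test suite in parallel, per the developer docs.
pip install pytest pytest-xdist
python -m pytest -n auto tests
```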
Fixed issue 2; it was a libstdc++ version problem, but I still have the 'Check failed' error.
@coversb, is it possible for you to upgrade your ROCm version?
@reza-amd Is it possible to get a 5.0-series build (drun image + jax)? Is there anything similar for pytorch? Thanks in advance!
Sorry for my slow response.
@reza-amd Thanks for the update! I will try to test things as soon as possible. More broadly, what are the criteria to close this bug? Things seem to be working reasonably well!
See also: #9864
@reza-amd Thank you again for the help getting docker working! I am able to build a docker image with jax and then train networks locally. I did a benchmark using flax + resnet50 + imagenet with a batch size of 256 in fp16 mode. Here are the results on a W6800:
Here are the results of the same code (i.e., fp16 + batch size 256) on a dual NVIDIA 3060 (CUDA) setup:
What else would be needed to mark this bug as resolved? I will start a ViT run next, but that will take a few days to complete!
@brettkoonce Thanks much for your update and testing our recent changes in ROCm-5.0.
@reza-amd I have made a little bit of progress with ViT and am having some issues with numerical precision on the W6800. It is able to train models using a batch size of 128, but I get reduced accuracy compared to a reference run on a TPU. W6800 results:
TPU-v2, batch size of 128 (all other code identical):
Second run with different TPU, same code config:
I had similar results (i.e., lower accuracy on AMD) when I did my tests with ROCm 4.5.0 last year. Do you have any ideas on why this would happen, or suggestions for how to improve things?
@brettkoonce When comparing against TPU, a key thing to be careful of is that the default matmul and convolution precision on TPU is bfloat16, not full float32.
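A sketch of how to request full precision for such comparisons (illustrative; `jax.default_matmul_precision` and `jax.lax.Precision.HIGHEST` are standard JAX APIs, not necessarily the exact suggestion made in this comment):

```python
import jax
import jax.numpy as jnp

a = jnp.ones((128, 128))
b = jnp.ones((128, 128))

# Request full float32 matmul precision globally; on TPU, matmuls and
# convolutions otherwise default to faster bfloat16 accumulation.
with jax.default_matmul_precision("float32"):
    c = jnp.dot(a, b)

# Or per operation:
c = jnp.dot(a, b, precision=jax.lax.Precision.HIGHEST)
```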
Here is the result when using Nvidia hardware (4x3060) with the same configuration:
@brettkoonce Perhaps move this to a new bug? But my suggestion would be: can you minimize it to a small, self-contained test case? That's what I would do if I had access to the hardware and were debugging it. You might consider comparing the results of a single training step between CPU and GPU, or between the two GPUs.
ROCm build scripts have been failing for ~2 months, see #10162. |
Jax 04b751c is building with rocm 5.2! |
Jax 09794be is building with rocm 5.4! |
Hey @brettkoonce, I am trying to compile jax with jaxlib for ROCm on Arch Linux, and I just cannot get a functional combination of things to work. I was able to compile jaxlib 0.4.6 and 0.4.9, but errors occurred at runtime (including segfaults). Are you able to share which commits/releases/tags you used for jax/jaxlib and XLA (or TensorFlow, if you still used that repo), and which build options you used?
Pinging back on this issue after some time: what an excellent discussion. Is anyone lucky enough to run JAX on ARM architectures (such as the Apple Silicon processors)?
@ricardobarroslourenco Yes. JAX has supported CPU-only execution on Apple hardware for many releases, and there is a new and experimental Apple GPU plugin (https://github.com/google/jax#pip-installation-apple-gpus). (Note: experimental.) In fact, I think I'm going to declare this issue fixed, because at this point we have, at my last count, four GPU vendors (NVIDIA, AMD, Apple, Intel) that support JAX to some degree, so I think we can say "we support multiple GPU vendors". We're working on better integration, better testing, and easier release processes for all of them. Feel free to file new bugs specific to particular hardware vendors!
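For reference, the experimental Apple GPU plugin linked above is distributed as a separate package (a sketch; see the linked instructions for current version requirements):

```sh
# Experimental; exact jax/jaxlib version constraints are listed at the
# link above.
pip install jax-metal
```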
Just a quick comment: would it be better to mention the installation guide for ROCm devices in the README, right before the Apple Metal devices section? What do you think, @hawkinsp @brettkoonce?
Grab bag of responses: @hawkinsp, +1 on closing this as well; glad to have helped!
The pattern I have had luck with (ROCm 4.5 and up) is:
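A hypothetical sketch of that docker-based flow (image name and tag are illustrative; the device flags follow ROCm's container docs):

```sh
# Illustrative only: pull a ROCm jax image and expose the GPU devices
# to the container, as described in ROCm's docker documentation.
docker pull rocm/jax:latest
docker run -it \
  --device=/dev/kfd --device=/dev/dri \
  --group-add video --security-opt seccomp=unconfined \
  rocm/jax:latest
```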
It's not super-turnkey, but it definitely works! @JoeyTeng With the amount of customization ROCm requires, keeping it inside the docker build sub-folder (i.e., where it's at right now) would be where I would keep it going forward. The jax part works fine, but ROCm needs more maturity in general before I can recommend it to new ML practitioners (e.g., by having it on the primary README).