
Add support for other GPUs (than NVIDIA) #2012

Closed
ricardobarroslourenco opened this issue Jan 16, 2020 · 90 comments · Fixed by #5114
Labels
AMD GPU: Issues pertaining to AMD GPUs (ROCm)
contributions welcome: The JAX team has not prioritized work on this. Community contributions are welcome.
P2 (eventual): This ought to be addressed, but has no schedule at the moment. (Assignee optional)

Comments

@ricardobarroslourenco

Is it possible to run JAX on GPU architectures other than NVIDIA (e.g., Intel, AMD)?

@hawkinsp
Member

In principle, sure! All we need is XLA to support that architecture.

In practice, that means we currently support CPU, NVIDIA GPU, and TPU.

Happily AMD has been contributing support for AMD GPUs to XLA. We haven't tried it out in JAX, but assuming the XLA support is complete, I see no good reason it wouldn't work with a few small JAX changes. If you are excited about AMD GPUs, we'd certainly welcome contributions enabling that functionality in JAX.

I don't think Intel GPUs have XLA support at the moment, but I wouldn't rule it out in the future as the various compiler toolchains (e.g., XLA, MLIR) progress.

hawkinsp added the "contributions welcome" label Jan 16, 2020
@jekbradbury
Contributor

The AMDGPU backend for XLA is being actively developed; these PRs probably have the most up-to-date status (seems like many but not all tests pass?)

One thing to note is that the AMD integrations require that you rebuild XLA from source; there's no way to build a single TF or XLA binary that can use both NVIDIA CUDA and AMD ROCm.

For Intel hardware, I imagine we'd need something like MLIR translation from HLO dialect to nGraph dialect. I'm guessing nobody is actively working on that, but ccing @nmostafa in case Intel has plans in that area.

@EelcoHoogendoorn

Glad to see this ROCm effort seems to be funded with full-time developers by AMD. Better late than never, I suppose. I hope they learned at least a little from their misadventures in GPGPU, with OpenCL only half-heartedly supported; in practice, if you wanted to get anything done, you had no choice but to go with the platform that didn't require you to, say, reinvent your FFT libraries from scratch. I hope this time around they realize there is some minimum investment in software they'd be smart to make if they want to offer a competitive ecosystem. It's crazy to see how much money NVIDIA has made off this; meanwhile, Google has added a completely new viable hardware and software alternative in the form of TPUs, and AMD is still working on getting compatibility with any of the software out there. It does not inspire much confidence, to be honest; it seems wise to bet against them ever getting out a robust, feature-complete alternative if they couldn't even get anything out four years ago. But I'd love to be wrong about this, and for there to be some genuine competition in desktop ML acceleration in the future.

@Cvikli

Cvikli commented Jun 8, 2020

Can someone help me figure out how to use JAX on AMD GPUs? Are there any code snippets we can start with?

@Sixzero

Sixzero commented Jun 26, 2020

Any update on the topic?
How can it be that TensorFlow supports AMD GPUs but JAX doesn't?
Isn't ROCm the CUDA equivalent for AMD GPUs, and aren't they drop-in replacements for each other?

@hawkinsp
Member

There's no technical blocker to using JAX on AMD GPUs. We on the JAX team simply don't have access to any AMD GPUs at the moment to develop or test the necessary changes (which are probably not that large, given most of the necessary work has been done in the context of TensorFlow.)

Contributions are welcome!

@8bitmp3
Contributor

8bitmp3 commented Jun 26, 2020

The AMDGPU backend for XLA is being actively developed

That's good to know @jekbradbury, thanks

@akuz

akuz commented Aug 28, 2020

I just wanted to ask: when we are talking about AMD GPUs being supported, is it going to be on all platforms (i.e. including macOS), or are we talking Linux/Windows only?

@jekbradbury
Contributor

I believe the AMDGPU backend support for XLA is based on ROCm, which doesn't support macOS.

@inailuig
Contributor

inailuig commented Nov 22, 2020

I was able to build jax with initial support for ROCm (AMD GPUs) by compiling it using XLA from ROCmSoftwarePlatform/tensorflow-upstream (update: after tensorflow/tensorflow#45344 you can use upstream TF) and adding a few options to the build scripts.

The code can be found here: inailuig/jax (update: after #5114 you can use upstream jax)

Executing

import jax
print(jax.devices())
print(jax.devices()[0].device_kind)
x = jax.numpy.array([1.2, 3.4, 5.6])
y = jax.numpy.exp(x)
print(y)

on my RX480 outputs

[GpuDevice(id=0)]
Ellesmere [Radeon RX 470/480/570/570X/580/580X/590]
'+code-object-v3' is not a recognized feature for this target (ignoring feature)
'+code-object-v3' is not a recognized feature for this target (ignoring feature)
'+code-object-v3' is not a recognized feature for this target (ignoring feature)
[  3.3201168  29.964104  270.4264   ]
2020-11-22 20:40:04.841794: E external/org_tensorflow/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc:614] Deallocating stream with pending work
2020-11-22 20:40:04.842168: E external/org_tensorflow/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc:614] Deallocating stream with pending work
2020-11-22 20:40:04.842517: E external/org_tensorflow/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc:614] Deallocating stream with pending work
2020-11-22 20:40:04.842866: E external/org_tensorflow/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc:614] Deallocating stream with pending work
2020-11-22 20:40:04.844206: E external/org_tensorflow/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc:614] Deallocating stream with pending work

which already looks very promising.
However, there are still things missing, such as the custom GPU kernels in jaxlib (cublas, cuda_prng, cusolver).

For those who want to build this:
I am running Ubuntu 20.04.1 with ROCm 3.9.0 installed using the official instructions.
It is also necessary to install these additional packages:
rocm-dev miopen-hip rocfft rocblas rccl hipsparse rocrand rocsolver hipblas
Then the whole thing can be built with
python3 build/build.py --enable_rocm --rocm_path /opt/rocm-3.9.0
Optionally, different amdgpu targets can be specified with --rocm_amdgpu_targets (see here). For now I put in some default targets; however, autodetection also works (by passing "", an empty string, which overrides the default). A condensed sketch of the whole sequence is below.
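To summarize, the build roughly boils down to the following (a sketch assembled from the steps above; the wheel-install step at the end is an assumption following the usual jaxlib build flow rather than anything ROCm-specific):

# Extra ROCm packages on top of a standard ROCm 3.9 install (apt, Ubuntu 20.04)
sudo apt install rocm-dev miopen-hip rocfft rocblas rccl hipsparse rocrand rocsolver hipblas

# Build jaxlib against ROCm; an empty --rocm_amdgpu_targets autodetects the GPU
python3 build/build.py --enable_rocm --rocm_path /opt/rocm-3.9.0 --rocm_amdgpu_targets ""

# Install the freshly built jaxlib wheel plus jax itself
pip install dist/*.whl
pip install -e .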

@hawkinsp
Member

hawkinsp commented Dec 1, 2020

@inailuig That's exciting progress! Nice work! (Sorry for the slow response, many of us were on vacation this last week.)

Technically speaking the cublas/cusolver and cuda_prng kernels are somewhat optional. The cuda_prng kernel is a compile-time optimization and can be safely omitted (at the cost of increased compile time), and cublas/cusolver are only needed for linear algebra support. So it might be possible to check things in even before those pieces work.

I'm curious: is it possible to use upstream TF instead of the ROCm fork? We frequently update our TF (XLA) version, so any ROCm specific fork is likely to be stale.

@inailuig
Contributor

inailuig commented Dec 1, 2020

@hawkinsp Turns out all that is missing in upstream TF is actually looking for devices with the right platform i.e. some changes in tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc (from this commit: ROCm/tensorflow-upstream@0ba0236)

diff --git a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc
index 4863e5e8165..870007f1dca 100644
--- a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc
+++ b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc
@@ -57,11 +57,19 @@ xla::StatusOr<xla::DeviceAssignment> GpuClient::GetDefaultDeviceAssignment(
 
 // Builds an xla::LocalClient for the GPU platform.
 StatusOr<LocalClient*> GetGpuXlaClient() {
+#if GOOGLE_CUDA
   TF_ASSIGN_OR_RETURN(se::Platform * platform,
                       PlatformUtil::GetPlatform("CUDA"));
   if (platform->VisibleDeviceCount() <= 0) {
     return FailedPrecondition("No visible NVidia GPU devices.");
   }
+#else
+  TF_ASSIGN_OR_RETURN(se::Platform * platform,
+                      PlatformUtil::GetPlatform("ROCm"));
+  if (platform->VisibleDeviceCount() <= 0) {
+    return FailedPrecondition("No visible AMD GPU devices.");
+  }
+#endif
   LocalClientOptions options;
   options.set_platform(platform);
   return ClientLibrary::GetOrCreateLocalClient(options);

Do you think we could get something like that upstreamed into TF?

For cuda_prng and the cublas/cusolver kernels, I was also able to get them running. Two or three of the LAPACK functions (cusolver) are not yet implemented in rocsolver, but everything else is there; this also requires a few more changes to TF. I will post more once I have cleaned it up a bit.

@hawkinsp
Member

hawkinsp commented Dec 1, 2020

We certainly can upstream something like that. That file is really part of JAX so we can change it as we see fit. You can send PRs to TensorFlow and assign me; I can review.

@deven-amd

@hawkinsp @inailuig

Thank you for trying out JAX on AMD GPUs. I am on the TF framework team at AMD, and would like to get a better understanding of the TF changes that are required to get JAX working. We would be more than happy to help out.

I also had a question for you: does JAX have unit tests that run on GPUs, and if so, can you point me to the directions to run them? I would like to get them running internally on our platform.

thanks again

deven

@hawkinsp
Member

hawkinsp commented Dec 4, 2020

@deven-amd We'll need to wait for @inailuig to send out their remaining changes to get things to build.

Once those changes are checked in, the best way to do this is probably something like this:

git clone https://github.com/google/jax.git
git clone https://github.com/tensorflow/tensorflow.git /mydir/tensorflow
cd jax
python build/build.py --bazel_options=--override_repository=org_tensorflow=/mydir/tensorflow --enable_rocm
pip install dist/*.whl
pip install -e .
XLA_PYTHON_CLIENT_ALLOCATOR=platform pytest -n auto tests examples

This builds and installs jaxlib with TF (XLA) from head (rather than whatever version we have pinned in our WORKSPACE file). (You can also achieve this by editing the WORKSPACE file; see the comments in that file.)
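For reference, the WORKSPACE route looks roughly like this (a sketch, not the exact pinned contents of the file): comment out the pinned org_tensorflow archive and point Bazel at a local checkout instead.

# In jax's WORKSPACE, in place of the pinned org_tensorflow http_archive:
local_repository(
    name = "org_tensorflow",
    path = "/mydir/tensorflow",  # local TF checkout to build against
)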

The XLA_PYTHON_CLIENT_ALLOCATOR avoids using the BFC allocator which preallocates GPU memory, which means that we should be able to run tests in parallel using multiple processes (-n auto enables this).

I should note there are probably a few tests that fail at head on Nvidia GPUs also (#5067).

@hawkinsp
Member

hawkinsp commented Dec 4, 2020

@deven-amd

deven-amd commented Dec 4, 2020 via email

@hawkinsp
Member

hawkinsp commented Dec 4, 2020

@deven-amd

If there's no reason otherwise, we like to do development in the open so the community can be involved. So I'd file issues/PRs or use Github discussions. You can ping me in any issues or PRs if you want to make sure I take a look!

@inailuig
Contributor

inailuig commented Dec 5, 2020

@deven-amd Thanks for reaching out; it would be great if you could help in particular with fixing the tests that are still failing.

I just opened #5114 for the remaining build-related stuff in jax.
In general, things seem to be working.

However, there are still some tests failing because of bugs (e.g. stuff related to conv, dot_general, triangular solve, ...).
Other features are simply not yet implemented for ROCm in XLA (e.g. TRSM for complex args).
For the latter we will have to identify and skip them.

Also there is this error message

 E external/org_tensorflow/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc:614] Deallocating stream with pending work

which keeps popping up when the program terminates. @deven-amd would you be able to look into this?


For the BLAS/LAPACK wrappers (i.e. jaxlib/cusolver.py and the related pybind modules, but for ROCm),
I mostly followed what @hawkinsp did for CUDA here, since it is just lots of glue code around roc/cu BLAS/Solver routines. This can be found in #5115

For this to work we still need a few changes in TF:

  1. custom_call_thunk needs to be enabled and built for rocm: https://github.com/inailuig/tensorflow/commit/44d3a233c6971344d595aacbad1459e1822264cd
  2. in xla_client.py "CUDA" is hardcoded when you try to register a custom call target.
    I suggest we fix this like so: https://github.com/inailuig/tensorflow/commit/be9602a7666eb05edc33faae7825f8401968e885
    (This still keeps CUDA as the default when you pass 'gpu', unfortunately.)
    Then we can register functions for ROCM like this:
    xla_client.register_custom_call_target(_name, _value, platform="ROCM")
    Everywhere else in jax we can keep 'gpu'.
  3. We need to add rocSolver targets to the build scripts somewhere (I think we should add this to TF, although I guess it would also be possible to add them just to jax).
    For my attempt at this see: https://github.com/inailuig/tensorflow/commit/606d7933b39f4115f8aea61e25bceb906855b5bf
  4. Not strictly necessary but nice to have: rocm_library, see https://github.com/inailuig/tensorflow/commit/e08f34ca8fe49056407eeaa706556af891d6857d

All of this can be found in https://github.com/inailuig/tensorflow/tree/jax-rocm-gpukernels (there are 2 more commits which are useful for debugging, but not necessary)

@hawkinsp How should we proceed?

@hawkinsp
Member

hawkinsp commented Dec 6, 2020

  1. Seems fine: I'd send that as a PR.
  2. Also looks fine to me. I might be tempted to change "gpu" to mean "register both CUDA and ROCM", which we could do by making xla_platform_names a dictionary whose values are lists of names and then registering all of them (see the sketch below).
  3. Seems plausible, and adding it to TF is probably the better place (that way, TF can share the build rules). I'm a bit surprised that TF doesn't have rocSolver hooked up already.
  4. Also seems reasonable to me, but I'm not as sure about this.
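For what it's worth, a rough sketch of what point 2 could look like (illustrative only, not actual jaxlib code; the helper name and the "Host" entry are assumptions):

from jaxlib import xla_client  # or the xla_client bundled with TF at the time

xla_platform_names = {
    "cpu": ["Host"],
    "gpu": ["CUDA", "ROCM"],  # "gpu" now covers both GPU vendors
}

def register_gpu_custom_call_target(name, fn, platform="gpu"):
    # Unknown keys fall back to being treated as a literal platform name.
    for backend in xla_platform_names.get(platform, [platform]):
        xla_client.register_custom_call_target(name, fn, platform=backend)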

hawkinsp reopened this Dec 7, 2020
hawkinsp changed the title from "Using non-nvidia GPU" to "Add support for AMD GPUs" Dec 7, 2020
@hawkinsp
Member

hawkinsp commented Dec 7, 2020

Retitling this bug to focus on AMD GPUs only; we can open new bugs for other hardware vendors if needed.

@coversb

coversb commented Jan 27, 2022

Hi team,

Thanks a lot for supporting ROCm in JAX. I have now run into some issues:

I don't know which is the right way to build jax from source (I saw https://hub.docker.com/r/rocm/jax and checked out the jax_preview_release branch).
I build with rocm4.0.1, and the device is gfx906.

1. Running the unit tests (https://jax.readthedocs.io/en/latest/developer.html?highlight=pytest#running-the-tests) with

python tests/lax_numpy_test.py --num_generated_cases=5

shows

2022-01-27 09:21:25.249401: F external/org_tensorflow/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc:78] Check failed: buffer_slice.offset() + buffer_slice.size() <= base.size() (4 vs. 0)

If I run my own program with this jaxlib, it also hits this assertion. Maybe the source code or the way I built it is wrong?

2. Trying to build jax from source with rocm4.1 fails in the mlir/xla/operator_writer_gen part; I don't know how to get a correct llvm tar package:

bazel-out/k8-opt-exec-50AE0418/bin/external/org_tensorflow/tensorflow/compiler/mlir/xla/operator_writer_gen: symbol lookup error: bazel-out/k8-opt-exec-50AE0418/bin/external/org_tensorflow/tensorflow/compiler/mlir/xla/operator_writer_gen: undefined symbol: _ZTINSt3_V214error_categoryE

3. When I set XLA_PYTHON_CLIENT_MEM_FRACTION or XLA_PYTHON_CLIENT_PREALLOCATE, it doesn't seem to pre-allocate GPU memory according to the fraction I set; it just allocates 2%-14% of GPU RAM.

Can you help me? Thanks a lot!!!

@coversb

coversb commented Jan 29, 2022

(The previous comment is quoted here.)

Fixed issue 2: it was a libstdc++ version problem, but I still get the 'Check failed' error.

@reza-amd
Contributor

reza-amd commented Feb 7, 2022

@coversb, is it possible for you to upgrade your ROCm version?

@coversb

coversb commented Feb 14, 2022

@coversb, is it possible for you to upgrade your ROCm version?
@reza-amd
Thanks for your reply! Yes, which version do you think is OK? I don't know whether there are incompatible changes between a higher version and ROCm 4.0; it seems a lot of header files and lib files changed.

@brettkoonce
Contributor

@reza-amd is it possible to get a 5.0 series build (drun image + jax)? Is there anything similar for PyTorch? Thanks in advance!

@reza-amd
Contributor

reza-amd commented Mar 2, 2022

Sorry for my slow response.
We have recently released ROCm-5.0 and we have updated JAX accordingly.
You can track the status of the PR here: #9584
In the PR source branch, I have provided utility scripts to build a ROCm container with JAX. Please take a look at https://github.com/ROCmSoftwarePlatform/jax/tree/rocm_refactor_jaxlib/build/rocm for more details.

@brettkoonce
Contributor

@reza-amd Thanks for the update! I will try to test things as soon as possible. More broadly, what are the criteria to close this bug? Things seem to be working reasonably well!

@brettkoonce
Contributor

See also: #9864

@brettkoonce
Contributor

@reza-amd Thank you again for the help getting Docker working! I am able to use jax to build a Docker image and then train networks locally. I did a benchmark using flax + resnet50 + imagenet with a batch size of 256 in fp16 mode.

Here are the results on a wx6800:

I0320 11:30:18.089150 140318779365120 logging_writer.py:35] [500400] steps_per_second=1.096407, train_accuracy=0.8140624761581421, train_learning_rate=3.6358835941996404e-09, train_loss=0.7592905163764954, train_scale=65536.0
I0320 11:30:53.003456 140385056196416 train.py:364] eval epoch: 99, loss: 0.9520, accuracy: 76.26
I0320 11:30:53.004454 140318779365120 logging_writer.py:35] [500400] eval_accuracy=0.7626402378082275, eval_loss=0.9520021080970764

Here are the results of the same code (i.e. fp16 + bs256) on a dual NVIDIA 3060 (CUDA) setup:

I0319 08:20:43.101390 140559431702272 logging_writer.py:35] [500400] steps_per_second=3.186386, train_accuracy=0.8107030987739563, train_learning_rate=3.6358835941996404e-09, train_loss=0.7701781988143921, train_scale=65536.0
I0319 08:20:57.018532 140610530277184 train.py:364] eval epoch: 99, loss: 0.9472, accuracy: 76.45
I0319 08:20:57.019293 140559431702272 logging_writer.py:35] [500400] eval_accuracy=0.7645031809806824, eval_loss=0.9471861124038696

What else would be needed to mark this bug as resolved? I will start a ViT run next, but that will take a few days to complete!

@reza-amd
Contributor

@brettkoonce Thanks much for your update and testing our recent changes in ROCm-5.0.

@brettkoonce
Contributor

@reza-amd I have made a little bit of progress with ViT and am having some issues with numerical precision on the w6800. The wx6800 is able to train models using a batch size of 128, but I get reduced accuracy compared to a reference run on a TPU.

w6800 results:

I0415 08:36:18.115095 139965291800320 logging_writer.py:35] [900810] valid_loss=3.451379, valid_prec@1=0.423600

TPU-v2, batch size of 128 (all other code identical):

I0504 16:20:31.839697 140592298116864 logging_writer.py:35] [900810] valid_loss=3.074142, valid_prec@1=0.491300

Second run with different TPU, same code config:

I0504 16:20:37.937061 140423741003520 logging_writer.py:35] [900810] valid_loss=3.013428, valid_prec@1=0.498340

I had similar results (i.e. lower performance on AMD) when I did my tests with 4.5.0 last year. Do you have any ideas on why this would happen, or suggestions for how to improve things?

@hawkinsp
Member

hawkinsp commented May 4, 2022

@brettkoonce When comparing against TPU, a key thing to be careful of is that the default matmul and convolution precision on TPU is bfloat16 inputs with float32 accumulation. Try setting jax_default_matmul_precision to float32, which, although slower, should give numerics closer to typical GPUs; a sketch is below. Just because the AMD GPU loss is worse doesn't mean the AMD GPU implementation is necessarily doing something wrong. (It might be! But I'd try to rule out known quantities first.)
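For example, either of these should work (a minimal sketch; the array shapes are arbitrary):

import jax
import jax.numpy as jnp

# Global flag, equivalent to the config option mentioned above
jax.config.update("jax_default_matmul_precision", "float32")

# Or scoped with a context manager, e.g. around a single matmul
a = jnp.ones((512, 512))
b = jnp.ones((512, 512))
with jax.default_matmul_precision("float32"):
    c = a @ b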

@brettkoonce
Contributor

brettkoonce commented May 4, 2022

@hawkinsp I am using the scenic vit demo in float32 mode (set like so for data / model), for what it's worth. Are there additional settings I should investigate?

I am doing a nvidia run currently with the same configuration and will report when that is finished.

@brettkoonce
Contributor

Here is the result when using Nvidia hardware (4x3060) with the same configuration:

I0518 07:44:17.425569 140335808247552 logging_writer.py:35] [900810] valid_loss=3.027929, valid_prec@1=0.495140

@hawkinsp
Member

@brettkoonce Perhaps move this to a new bug? But my suggestion would be: can you minimize it to a small self-contained test case? That's what I would do if I had access to the hardware and were debugging it. You might consider comparing the results of a single training step between CPU and GPU, or between the two GPUs; a minimal sketch of that kind of comparison is below.
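For instance, something along these lines (illustrative only; the function and sizes are arbitrary stand-ins for the suspect op, and it assumes a GPU backend is available):

import jax
import jax.numpy as jnp
import numpy as np

x = np.random.RandomState(0).randn(256, 256).astype(np.float32)

def f(a):
    # stand-in for a single training step / suspect operation
    return jnp.tanh(a @ a.T).sum()

cpu = jax.devices("cpu")[0]
gpu = jax.devices("gpu")[0]

y_cpu = jax.jit(f, device=cpu)(x)
y_gpu = jax.jit(f, device=gpu)(x)
print(y_cpu, y_gpu, abs(float(y_cpu) - float(y_gpu)))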

@brettkoonce
Contributor

ROCm build scripts have been failing for ~2 months, see #10162.

hawkinsp added the "AMD GPU" label Aug 15, 2022
@brettkoonce
Contributor

Jax 04b751c is building with rocm 5.2!

@brettkoonce
Contributor

Jax 09794be is building with rocm 5.4!

@stephensrmmartin

Hey @brettkoonce

I am trying to compile jax with jaxlib for rocm on arch linux, and just cannot get a functional combination of things to work.

I was able to compile jaxlib 4.6 and 4.9, but errors occurred at runtime (including seg faults).

Are you able to share which commits/releases/tags you used for jax/jaxlib, xla, (tensorflow if you still used that repo), and which build options you used?

ricardobarroslourenco changed the title from "Add support for AMD GPUs" to "Add support for other GPUs (than NVIDIA)" Jun 22, 2023
@ricardobarroslourenco
Author

Coming back to this issue after some time: what an excellent discussion. Has anyone been lucky enough to run JAX on ARM architectures (such as the Apple Silicon processors)?

@hawkinsp
Member

@ricardobarroslourenco Yes. JAX has supported CPU-only execution on Apple hardware for many releases, and there is a new and experimental Apple GPU plugin (https://github.com/google/jax#pip-installation-apple-gpus). (Note: experimental).

In fact, I think I'm going to declare this issue fixed, because at this point we have, by my last count, four GPU vendors (NVIDIA, AMD, Apple, Intel) that support JAX to some degree, so I think we can say "we support multiple GPU vendors". We're working on better integration, better testing, and easier release processes for all of them.

Feel free to file new bugs specific to particular hardware vendors!

@JoeyTeng
Contributor

Just a quick comment: would it be better to mention the installation guide for ROCm devices in the README, right before the Apple Metal devices section? What do you think, @hawkinsp @brettkoonce?

@brettkoonce
Contributor

Grab bag of responses:

@hawkinsp +1 to closing this as well, glad to have helped!

@stephensrmmartin

Are you able to share which commits/releases/tags you used for jax/jaxlib, xla, (tensorflow if you still used that repo), and which build options you used?

The pattern I have had luck with (ROCm 4.5 and up) is:

  1. Latest Ubuntu Linux LTS (supported by ROCm) with the HWE add-on.
  2. Full ROCm install using the installer.
  3. Install Docker with the hardware extensions enabled.
  4. Build rocm + jax inside that container, talking to the device using the instructions in the AMD ROCm guide (a sketch of the container invocation is below).
  5. You should now be able to run python inside the Docker environment, import jax, and call jax.devices() to verify things are working together.
  6. (optional) Pin/freeze said image and use it as a base for experiments.

It's not super-turnkey but it definitely works!
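For step 4, the container invocation is roughly the standard ROCm-on-Docker pattern (a sketch; the image tag here is illustrative rather than an exact published tag):

docker run -it \
  --device=/dev/kfd --device=/dev/dri \
  --security-opt seccomp=unconfined --group-add video \
  rocm/jax:latest \
  python3 -c "import jax; print(jax.devices())"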

@JoeyTeng With the amount of customization ROCm requires, keeping it inside the docker build sub-folder (i.e. where it is right now) is where I would keep it going forward. The jax part works fine, but ROCm needs more maturity in general before I can recommend it to new ML practitioners (e.g. by putting it on the primary README).
