Support for AMD Radeon GPU on Mac OS #7163

Open
dbl001 opened this issue Jul 1, 2021 · 6 comments
Labels
Apple GPU (Metal) plugin; contributions welcome (The JAX team has not prioritized work on this. Community contributions are welcome.); enhancement (New feature or request); P3 (no schedule) (We have no plan to work on this and, if it is unassigned, we would be happy to review a PR)

Comments

@dbl001

dbl001 commented Jul 1, 2021

Will JAX ever support AMD's Radeon Pro 5700 XT on macOS Big Sur, via Metal or Apple's machine learning framework?

PlaidML can see the GPU, and the GPU works with the Keras backend:

 source plaidml-venv/bin/activate
(plaidml-venv) (base) dbl@x86_64-apple-darwin13 ~ % plaidml-setup

PlaidML Setup (0.7.0)

Thanks for using PlaidML!

The feedback we have received from our users indicates an ever-increasing need
for performance, programmability, and portability. During the past few months,
we have been restructuring PlaidML to address those needs.  To make all the
changes we need to make while supporting our current user base, all development
of PlaidML has moved to a branch — plaidml-v1. We will continue to maintain and
support the master branch of PlaidML and the stable 0.7.0 release.

Read more here: https://github.com/plaidml/plaidml 

Some Notes:
  * Bugs and other issues: https://github.com/plaidml/plaidml/issues
  * Questions: https://stackoverflow.com/questions/tagged/plaidml
  * Say hello: https://groups.google.com/forum/#!forum/plaidml-dev
  * PlaidML is licensed under the Apache License 2.0
 

Default Config Devices:
   llvm_cpu.0 : CPU (via LLVM)
   metal_amd_radeon_pro_5700_xt.0 : AMD Radeon Pro 5700 XT (Metal)

Experimental Config Devices:
   llvm_cpu.0 : CPU (via LLVM)
   opencl_amd_radeon_pro_5700_xt_compute_engine.0 : AMD AMD Radeon Pro 5700 XT Compute Engine (OpenCL)
   metal_amd_radeon_pro_5700_xt.0 : AMD Radeon Pro 5700 XT (Metal)

Using experimental devices can cause poor performance, crashes, and other nastiness.

Enable experimental device support? (y,n)[n]:y

Multiple devices detected (You can override by setting PLAIDML_DEVICE_IDS).
Please choose a default device:

   1 : llvm_cpu.0
   2 : opencl_amd_radeon_pro_5700_xt_compute_engine.0
   3 : metal_amd_radeon_pro_5700_xt.0

Default device? (1,2,3)[1]:3

Selected device:
    metal_amd_radeon_pro_5700_xt.0

Almost done. Multiplying some matrices...
Tile code:
  function (B[X,Z], C[Z,Y]) -> (A) { A[x,y : X,Y] = +(B[x,z] * C[z,y]); }
Whew. That worked.

Save settings to /Users/dbl/.plaidml? (y,n)[y]:y
Success!

dbl001 added the enhancement label Jul 1, 2021
hawkinsp changed the title from "Support for AMD Radeon GPU" to "Support for AMD Radeon GPU on Mac OS" Jul 7, 2021
hawkinsp added the P3 (no schedule) label Jul 7, 2021
@hawkinsp
Collaborator

hawkinsp commented Jul 7, 2021

We support AMD GPUs already via AMD's ROCm framework, but I believe ROCm does not support Mac OS.

Targeting Metal would require either adding Metal support to XLA, or plugging an alternative compiler into JAX. One candidate is IREE, which does not yet support Metal but does have it on its road map.

We have no plans to work on this at the moment, but would welcome contributions.
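To make the "alternative compiler" option concrete: a JAX backend never sees Python code. It receives a function that JAX has traced and lowered to an intermediate representation (StableHLO in recent releases), which XLA, or in principle another compiler such as IREE, then compiles for the target device. A minimal sketch, assuming a recent CPU-only `jax` install, that prints the module a Metal-capable compiler would have to consume:

```python
import jax
import jax.numpy as jnp

def f(x):
    # A tiny program: matrix product followed by an elementwise sine.
    return jnp.sin(x @ x.T)

x = jnp.ones((4, 4), dtype=jnp.float32)

# Trace and lower f for this input shape and dtype without executing it.
lowered = jax.jit(f).lower(x)

# Text form of the lowered module; a backend compiler starts from this,
# not from the Python source.
ir_text = lowered.as_text()
print(ir_text)
```

The printed module contains only array operations (a dot product and a sine), which is why supporting a new device is a compiler/runtime problem rather than a change to JAX's Python API.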

hawkinsp added the contributions welcome label Jul 7, 2021
@dbl001
Author

dbl001 commented Jun 6, 2022

Can JAX support tensorflow-macos and tensorflow-metal?

@dbl001
Author

dbl001 commented Aug 22, 2023

I was able to build jax-metal on my 27-inch iMac running macOS 13.5 with an AMD Radeon Pro 5700 GPU.
Does jax-metal utilize AMD GPUs, or only Apple M1/M2 GPUs?

% ipython                        
Python 3.9.17 (main, Jun 10 2023, 11:04:38) 
Type 'copyright', 'credits' or 'license' for more information
IPython 8.14.0 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import jax
   ...: import jax.numpy as jnp
   ...: 
   ...: # Set environment variable to enable MPS backend
   ...: import os
   ...: os.environ['JAX_PLATFORM_NAME'] = 'mps'
   ...: 
   ...: # Import Jax modules after setting env var
   ...: from jax import random
   ...: import jax.lax as lax
   ...: 
   ...: # Create a simple random matrix
   ...: key = random.PRNGKey(0)
   ...: x = random.normal(key, (10,10))
   ...: 
   ...: # Do some Jax operations on GPU
   ...: y = x @ x.T
   ...: z = lax.sin(y)
   ...: 
   ...: print(jnp.sum(z))
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1692732316.471622       1 tfrt_cpu_pjrt_client.cc:466] TfrtCpuClient created.
-4.521248

In [2]: jax.lib.xla_bridge.get_backend()
Out[2]: <jaxlib.xla_extension.Client at 0x1119356b0>
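Two details in this transcript point to the CPU: the `TfrtCpuClient created.` log line, and `get_backend()` returning only a client object repr that does not name its platform. In addition, `JAX_PLATFORM_NAME` is set after `import jax` has already run (JAX reads this configuration at import time), and `'mps'` is PyTorch's name for Apple's GPU backend rather than a JAX platform name. A minimal sketch for checking which platform JAX actually picked, assuming a recent `jax` release where `JAX_PLATFORMS` is the documented selection variable and `jax.Array.devices()` exists. It requests `cpu` so it runs anywhere; the jax-metal platform name mentioned in the comment is an assumption, so compare it against what `jax.devices()` prints on your machine:

```python
import os

# JAX reads its platform configuration when it is imported, so set this first.
# "cpu" keeps the sketch portable; on a jax-metal install you would substitute
# the plugin's platform name (assumed, not confirmed here: "metal").
REQUESTED_PLATFORM = "cpu"
os.environ["JAX_PLATFORMS"] = REQUESTED_PLATFORM

import jax
import jax.numpy as jnp

# Same computation as the transcript above.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (10, 10))
z = jax.lax.sin(x @ x.T)
total = jnp.sum(z)

# Human-readable answers to "where did this run?"
backend = jax.default_backend()                      # e.g. "cpu" or "gpu"
device_kinds = [d.device_kind for d in jax.devices()]
result_platforms = {d.platform for d in total.devices()}

print("default backend:", backend)
print("visible devices:", device_kinds)
print("result stored on:", result_platforms)
```

Unlike `get_backend()`, `jax.default_backend()` and `jax.devices()` name the platform directly, so they answer the question without reading log output.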

@mattjj
Collaborator

mattjj commented Aug 22, 2023

@kulinseth @shuhand0 is #7163 (comment) easy to answer?

@mattjj
Collaborator

mattjj commented Aug 22, 2023

FWIW this page suggests to me that AMD GPUs should work.

@dbl001
Author

dbl001 commented Aug 22, 2023

% pip show jax
Name: jax
Version: 0.4.15.dev20230822
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /Users/davidlaxer/jax-metal/lib/python3.9/site-packages
Requires: importlib-metadata, ml-dtypes, numpy, opt-einsum, scipy
Required-by: 
(jax-metal) davidlaxer@bluediamond tests % pip show jaxlib
Name: jaxlib
Version: 0.4.15.dev20230822
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /Users/davidlaxer/jax-metal/lib/python3.9/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by: 

 % ipython
Python 3.9.17 (main, Jun 10 2023, 11:04:38) 
Type 'copyright', 'credits' or 'license' for more information
IPython 8.14.0 -- An enhanced Interactive Python. Type '?' for help.

In [1]: from jax import jit
   ...: import jax.numpy as jnp
   ...: 
   ...: # define the cube function
   ...: def cube(x):
   ...:     return x * x * x
   ...: 
   ...: # generate data
   ...: x = jnp.ones((10000, 10000))
   ...: 
   ...: jit_cube = jit(cube)
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1692734199.312967       1 tfrt_cpu_pjrt_client.cc:466] TfrtCpuClient created.

Does this imply the computation was done on the CPU?
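By itself the `TfrtCpuClient created.` line does not settle the question, since JAX normally initializes a CPU client alongside any accelerator backend; here it was most likely triggered by `jnp.ones` allocating `x`. Note also that `jit_cube` is defined but never called, so `cube` was neither compiled nor run. A minimal sketch that actually invokes the jitted function and then asks the result where it lives, using a 1000 x 1000 array (an arbitrary size chosen to keep memory low) and assuming a recent `jax` with `jax.Array.devices()`:

```python
import jax
import jax.numpy as jnp
from jax import jit

# define the cube function (as in the transcript above)
def cube(x):
    return x * x * x

jit_cube = jit(cube)

# Smaller than the transcript's 10000 x 10000 array to keep memory use low.
x = jnp.ones((1000, 1000))

# jit() only wraps the function; tracing, compilation and execution happen on
# the first call. block_until_ready() waits for asynchronous dispatch to finish.
y = jit_cube(x).block_until_ready()

# Every jax.Array records the device(s) holding its data.
platforms = {d.platform for d in y.devices()}
print("computed on:", platforms, "| default backend:", jax.default_backend())
```

If the Metal plugin were selected as the default backend, one would expect a non-CPU platform in that set; a result of `{'cpu'}` means the plugin was not used for this computation.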
