Skip to content

Commit

Permalink
Add a link to the Apple Metal plugin to the JAX README.
Browse files Browse the repository at this point in the history
Remove references to cloud TPU on colab, since support was dropped in JAX 0.4. Users should use TPUs via Kaggle or via Cloud TPU VMs.
  • Loading branch information
hawkinsp committed Jun 22, 2023
1 parent 06f76bc commit 85a84fd
Showing 1 changed file with 20 additions and 13 deletions.
33 changes: 20 additions & 13 deletions README.md
Expand Up @@ -517,26 +517,33 @@ Please let us know on [the issue tracker](https://github.com/google/jax/issues)
if you run into any errors or problems with the prebuilt wheels.

### pip installation: Google Cloud TPU
JAX also provides pre-built wheels for

JAX provides pre-built wheels for
[Google Cloud TPU](https://cloud.google.com/tpu/docs/users-guide-tpu-vm).
To install JAX along with appropriate versions of `jaxlib` and `libtpu`, you can run
the following in your cloud TPU VM:
```bash
pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```

### pip installation: Colab TPU
Colab TPU runtimes use an older TPU architecture than Cloud TPU VMs, so the installation instructions differ.
The Colab TPU runtime comes with JAX pre-installed, but before importing JAX you must run the following code to initialize the TPU:
```python
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
```
Note that Colab TPU runtimes are not compatible with JAX version 0.4.0 or newer.
If you need to re-install JAX on a Colab TPU runtime, you can use the following command:
```
!pip install jax<=0.3.25 jaxlib<=0.3.25
```
For interactive notebook users: Colab TPUs are no longer supported by JAX as of
version 0.4. However, for an interactive TPU notebook in the cloud, you can use
[Kaggle TPU notebooks](https://www.kaggle.com/docs/tpu), which are fully
supported by JAX.

### pip installation: Apple GPUs

Apple provides an experimental Metal plugin for Apple GPU hardware. For details,
see
[Apple's JAX on Metal documentation](https://developer.apple.com/metal/jax/).

There are several caveats with the Metal plugin:
* the Metal plugin is new and experimental and has a number of
[known issues](https://github.com/google/jax/issues?q=is%3Aissue+is%3Aopen+label%3A%22Apple+GPU+%28Metal%29+plugin%22).
Please report any issues on the JAX issue tracker.
* the Metal plugin currently requires very specific versions of `jax` and
`jaxlib`. This restriction will be relaxed over time as the plugin API
matures.

### Conda installation

Expand Down

0 comments on commit 85a84fd

Please sign in to comment.