Skip to content

Commit

Permalink
DOC: update install instructions for GPU & TPU
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jun 24, 2021
1 parent f3fb0c4 commit 3c727ab
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
20 changes: 16 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -396,19 +396,21 @@ There is some initial native Windows support, but since it is still somewhat
immature, there are no binary releases and it must be
[built from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-jaxlib-from-source-on-windows).

### pip installation
### pip installation: CPU

To install a CPU-only version, which might be useful for doing local
To install a CPU-only version of JAX, which might be useful for doing local
development on a laptop, you can run

```bash
pip install --upgrade pip
pip install --upgrade jax jaxlib # CPU-only version
pip install --upgrade jax[cpu]
```

On Linux, it is often necessary to first update `pip` to a version that supports
`manylinux2010` wheels.

### pip installation: GPU (CUDA)

If you want to install JAX with both CPU and NVidia GPU support, you must first
install [CUDA](https://developer.nvidia.com/cuda-downloads) and
[CuDNN](https://developer.nvidia.com/CUDNN),
Expand All @@ -422,7 +424,7 @@ Next, run

```bash
pip install --upgrade pip
pip install --upgrade jax jaxlib==0.1.68+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install --upgrade jax[cuda111] -f https://storage.googleapis.com/jax-releases/jax_releases.html
```

The jaxlib version must correspond to the version of the existing CUDA
Expand Down Expand Up @@ -452,6 +454,16 @@ sudo ln -s /path/to/cuda /usr/local/cuda-X.X
Please let us know on [the issue tracker](https://github.com/google/jax/issues)
if you run into any errors or problems with the prebuilt wheels.

### pip installation: Google Cloud TPU
JAX also provides pre-built wheels for
[Google Cloud TPU](https://cloud.google.com/tpu/docs/users-guide-tpu-vm).
To install JAX along with appropriate versions of `jaxlib` and `libtpu`, you can run
the following in your cloud TPU VM:
```bash
pip install --upgrade pip
pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```

### Building JAX from source
See [Building JAX from
source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source).
Expand Down
2 changes: 1 addition & 1 deletion docs/developer.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pip install jaxlib
```

See the [JAX readme](https://github.com/google/jax#installation) for full
guidance on pip installation (e.g., for GPU support).
guidance on pip installation (e.g., for GPU and TPU support).

### Building `jaxlib` from source

Expand Down

0 comments on commit 3c727ab

Please sign in to comment.