Skip to content

Commit

Permalink
Improve installation instructions.
Browse files Browse the repository at this point in the history
Simplify the common case installation instructions on the JAX readme.

Move less-commonly used or more complicated options to the installation page of the JAX docs.
  • Loading branch information
hawkinsp committed Sep 22, 2023
1 parent 4269705 commit 8cbd58a
Show file tree
Hide file tree
Showing 5 changed files with 251 additions and 236 deletions.
222 changes: 20 additions & 202 deletions README.md
Expand Up @@ -378,215 +378,33 @@ Some standouts:

## Installation

JAX is written in pure Python, but it depends on XLA, which needs to be
installed as the `jaxlib` package. Use the following instructions to install a
binary package with `pip` or `conda`, to use a
[Docker container](#docker-containers-nvidia-gpu), or to [build JAX from
source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source).
### Supported platforms

| | Linux x86_64 | Linux aarch64 | Mac x86_64 | Mac ARM | Windows x86_64 | Windows WSL2 x86_64 |
|------------|--------------|-------------------------|--------------|----------------|----------------|---------------------|
| CPU | yes | yes (build from source) | yes | yes | yes | yes |
| NVIDIA GPU | yes | yes (build from source) | no | n/a | no | experimental |
| Google TPU | yes | n/a | n/a | n/a | n/a | n/a |
| AMD GPU | experimental (build from source) | no | no | n/a | no | no |
| Apple GPU | n/a | no | experimental | experimental | n/a | n/a |

We support installing or building `jaxlib` on Linux (Ubuntu 20.04 or later) and
macOS (10.12 or later) platforms. There is also *experimental* native Windows
support.

Windows users can use JAX on CPU and GPU via the [Windows Subsystem for
Linux](https://docs.microsoft.com/en-us/windows/wsl/about), or alternatively
they can use the *experimental* native Windows CPU-only support.
### Instructions

### pip installation: CPU
| Hardware | Instructions |
|------------|-----------------------------------------------------------------------------------------------------------------|
| CPU | `pip install -U jax[cpu]` |
| NVIDIA GPU | `pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html` |
| Google TPU | `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` |
| AMD GPU | [Build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). |
| Apple GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). |

We currently release `jaxlib` wheels for the following
operating systems and architectures:
* Linux, x86-64
* Mac, Intel
* Mac, ARM
* Windows, x86-64 (*experimental*)
See [the documentation](https://jax.readthedocs.io/en/latest/installation.html)
for information on alternative installation strategies. These includes compiling
from source, installing with Docker, using other versions of CUDA, a
community-supported conda build, and answers to some frequently-asked questions.

To install a CPU-only version of JAX, which might be useful for doing local
development on a laptop, you can run

```bash
pip install --upgrade pip
pip install --upgrade "jax[cpu]"
```

On Windows, you may also need to install the
[Microsoft Visual Studio 2019 Redistributable](https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist?view=msvc-170#visual-studio-2015-2017-2019-and-2022)
if it is not already installed on your machine.

Other operating systems and architectures require building from source. Trying
to pip install on other operating systems and architectures may lead to `jaxlib`
not being installed alongside `jax`, although `jax` may successfully install
(but fail at runtime).

### pip installation: GPU (CUDA, installed via pip, easier)

There are two ways to install JAX with NVIDIA GPU support: using CUDA and CUDNN
installed from pip wheels, and using a self-installed CUDA/CUDNN. We recommend
installing CUDA and CUDNN using the pip wheels, since it is much easier!

JAX supports NVIDIA GPUs that have SM version 5.2 (Maxwell) or newer.
Note that Kepler-series GPUs are no longer supported by JAX since
NVIDIA has dropped support for Kepler GPUs in its software.

You must first install the NVIDIA driver. We
recommend installing the newest driver available from NVIDIA, but the driver
must be version >= 525.60.13 for CUDA 12 and >= 450.80.02 for CUDA 11 on Linux.
If you need to use a newer CUDA toolkit with an older driver, for example
on a cluster where you cannot update the NVIDIA driver easily, you may be
able to use the
[CUDA forward compatibility packages](https://docs.nvidia.com/deploy/cuda-compatibility/)
that NVIDIA provides for this purpose.


```bash
pip install --upgrade pip

# CUDA 12 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# CUDA 11 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

### pip installation: GPU (CUDA, installed locally, harder)

If you prefer to use a preinstalled copy of CUDA, you must first
install [CUDA](https://developer.nvidia.com/cuda-downloads) and
[CuDNN](https://developer.nvidia.com/CUDNN).

JAX provides pre-built CUDA-compatible wheels for **Linux x86_64 only**. Other
combinations of operating system and architecture are possible, but require
[building from source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source).

You should use an NVIDIA driver version that is at least as new as your
[CUDA toolkit's corresponding driver version](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#cuda-major-component-versions__table-cuda-toolkit-driver-versions).
If you need to use a newer CUDA toolkit with an older driver, for example
on a cluster where you cannot update the NVIDIA driver easily, you may be
able to use the
[CUDA forward compatibility packages](https://docs.nvidia.com/deploy/cuda-compatibility/)
that NVIDIA provides for this purpose.

JAX currently ships two CUDA wheel variants:
* CUDA 12.0 and CuDNN 8.9.
* CUDA 11.8 and CuDNN 8.6.

You may use a JAX wheel provided the major version of your CUDA and CuDNN
installation matches, and the minor version is at least as new as the version
JAX expects. For example, you would be able to use the CUDA 12.0 wheel with
CUDA 12.1 and CuDNN 8.9.

Your CUDA installation must also be new enough to support your GPU. If you have
an Ada Lovelace (e.g., RTX 4080) or Hopper (e.g., H100) GPU,
you must use CUDA 11.8 or newer.


To install, run

```bash
pip install --upgrade pip

# Installs the wheel compatible with CUDA 12 and cuDNN 8.9 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Installs the wheel compatible with CUDA 11 and cuDNN 8.6 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

**These `pip` installations do not work with Windows, and may fail silently; see
[above](#installation).**

You can find your CUDA version with the command:

```bash
nvcc --version
```

Some GPU functionality expects the CUDA installation to be at
`/usr/local/cuda-X.X`, where X.X should be replaced with the CUDA version number
(e.g. `cuda-11.8`). If CUDA is installed elsewhere on your system, you can either
create a symlink:

```bash
sudo ln -s /path/to/cuda /usr/local/cuda-X.X
```

Please let us know on [the issue tracker](https://github.com/google/jax/issues)
if you run into any errors or problems with the prebuilt wheels.

### Docker containers: NVIDIA GPU

NVIDIA provides the [JAX
Toolbox](https://github.com/NVIDIA/JAX-Toolbox) containers, which are
bleeding edge containers containing nightly releases of jax and some
models/frameworks.

### pip installation: Google Cloud TPU

JAX provides pre-built wheels for
[Google Cloud TPU](https://cloud.google.com/tpu/docs/users-guide-tpu-vm).
To install JAX along with appropriate versions of `jaxlib` and `libtpu`, you can run
the following in your cloud TPU VM:
```bash
pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```

For interactive notebook users: Colab TPUs no longer support JAX as of
JAX version 0.4. However, for an interactive TPU notebook in the cloud, you can
use [Kaggle TPU notebooks](https://www.kaggle.com/docs/tpu), which fully
support JAX.

### pip installation: Apple GPUs

Apple provides an experimental Metal plugin for Apple GPU hardware. For details,
see
[Apple's JAX on Metal documentation](https://developer.apple.com/metal/jax/).

There are several caveats with the Metal plugin:
* the Metal plugin is new and experimental and has a number of
[known issues](https://github.com/google/jax/issues?q=is%3Aissue+is%3Aopen+label%3A%22Apple+GPU+%28Metal%29+plugin%22).
Please report any issues on the JAX issue tracker.
* the Metal plugin currently requires very specific versions of `jax` and
`jaxlib`. This restriction will be relaxed over time as the plugin API
matures.

### Conda installation

There is a community-supported Conda build of `jax`. To install using `conda`,
simply run

```bash
conda install jax -c conda-forge
```

To install on a machine with an NVIDIA GPU, run
```bash
conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia
```

Note the `cudatoolkit` distributed by `conda-forge` is missing `ptxas`, which
JAX requires. You must therefore either install the `cuda-nvcc` package from
the `nvidia` channel, or install CUDA on your machine separately so that `ptxas`
is in your path. The channel order above is important (`conda-forge` before
`nvidia`).

If you would like to override which release of CUDA is used by JAX, or to
install the CUDA build on a machine without GPUs, follow the instructions in the
[Tips & tricks](https://conda-forge.org/docs/user/tipsandtricks.html#installing-cuda-enabled-packages-like-tensorflow-and-pytorch)
section of the `conda-forge` website.

See the `conda-forge`
[jaxlib](https://github.com/conda-forge/jaxlib-feedstock#installing-jaxlib) and
[jax](https://github.com/conda-forge/jax-feedstock#installing-jax) repositories
for more details.

### Building JAX from source
See [Building JAX from
source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source).

## Neural network libraries

Expand Down
1 change: 1 addition & 0 deletions docs/developer.md
@@ -1,3 +1,4 @@
(building-from-source)=
# Building from source

First, obtain the JAX source code:
Expand Down
26 changes: 0 additions & 26 deletions docs/index.rst
Expand Up @@ -50,31 +50,6 @@ JAX is Autograd_ and XLA_, brought together for high-performance numerical compu
:class-card: developer-docs


Installation
------------
.. tab-set::

.. tab-item:: CPU

.. code-block:: bash
pip install "jax[cpu]"
.. tab-item:: GPU (CUDA)

.. code-block:: bash
pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
.. tab-item:: TPU (Google Cloud)

.. code-block:: bash
pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
For more information about supported accelerators and platforms, and for other
installation options, see the `Install Guide`_ in the project README.

.. toctree::
:hidden:
:maxdepth: 1
Expand Down Expand Up @@ -116,4 +91,3 @@ installation options, see the `Install Guide`_ in the project README.

.. _Autograd: https://github.com/hips/autograd
.. _XLA: https://www.tensorflow.org/xla
.. _Install Guide: https://github.com/google/jax#installation

0 comments on commit 8cbd58a

Please sign in to comment.