Skip to content

Commit

Permalink
Add notes about the new CUDA version restrictions to the changelog an…
Browse files Browse the repository at this point in the history
…d installation instructions.
  • Loading branch information
hawkinsp committed Sep 27, 2023
1 parent 87af945 commit b7dfde8
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 11 deletions.
14 changes: 10 additions & 4 deletions CHANGELOG.md
Expand Up @@ -16,18 +16,24 @@ Remember to align the itemized text with the first line of an item within a list
* `random.rbg_key(seed)` becomes `random.PRNGKey(seed, impl='rbg')`
* `random.unsafe_rbg_key(seed)` becomes `random.PRNGKey(seed, impl='unsafe_rbg')`
* Changes:
* {func}`jax.scipy.stats.mode` now returns a 0 count if the mode is taken
across a size-0 axis, matching the behavior of `scipy.stats.mode` in SciPy
1.11.
* CUDA: JAX now verifies that the CUDA libraries it finds are at least as new
as the CUDA libraries that JAX was built against. If older libraries are
found, JAX raises an exception since that is preferable to mysterious
failures and crashes.
* Removed the "No GPU/TPU" found warning. Instead warn if, on Linux, an
NVIDIA GPU or a Google TPU are found but not used and `--jax_platforms` was
not specified.
* {func}`jax.scipy.stats.mode` now returns a 0 count if the mode is taken
across a size-0 axis, matching the behavior of `scipy.stats.mode` in SciPy
1.11.

# jaxlib 0.4.17

* Changes:
* Python 3.12 wheels were added in this release.
* The CUDA 12 wheels now require CUDA 12.2.
* The CUDA 12 wheels now require CUDA 12.2 or newer and cuDNN 8.9.4 or newer.
* This is likely the last release of jaxlib that supports CUDA 11. We intend
to drop CUDA 11 support in the next release.

* Bug fixes:
* Fixed log spam from ABSL when the JAX CPU backend was initialized.
Expand Down
22 changes: 15 additions & 7 deletions docs/installation.md
Expand Up @@ -58,8 +58,9 @@ not being installed alongside `jax`, although `jax` may successfully install
### pip installation: GPU (CUDA, installed via pip, easier)

There are two ways to install JAX with NVIDIA GPU support: using CUDA and CUDNN
installed from pip wheels, and using a self-installed CUDA/CUDNN. We recommend
installing CUDA and CUDNN using the pip wheels, since it is much easier!
installed from pip wheels, and using a self-installed CUDA/CUDNN. We
strongly recommend installing CUDA and CUDNN using the pip wheels, since it is
much easier!

JAX supports NVIDIA GPUs that have SM version 5.2 (Maxwell) or newer.
Note that Kepler-series GPUs are no longer supported by JAX since
Expand Down Expand Up @@ -87,6 +88,13 @@ pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-re
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

If JAX detects the wrong version of the CUDA libraries, there are several things
to check:
* make sure that `LD_LIBRARY_PATH` is not set, since `LD_LIBRARY_PATH` can
override the CUDA libraries.
* make sure that the CUDA libraries installed are those requested by JAX.
Rerunning the installation command above should work.

### pip installation: GPU (CUDA, installed locally, harder)

If you prefer to use a preinstalled copy of CUDA, you must first
Expand All @@ -106,13 +114,13 @@ able to use the
that NVIDIA provides for this purpose.

JAX currently ships two CUDA wheel variants:
* CUDA 12.0 and CuDNN 8.9.
* CUDA 12.2 and CuDNN 8.9.4.
* CUDA 11.8 and CuDNN 8.6.

You may use a JAX wheel provided the major version of your CUDA and CuDNN
installation matches, and the minor version is at least as new as the version
JAX expects. For example, you would be able to use the CUDA 12.0 wheel with
CUDA 12.1 and CuDNN 8.9.
You may use a JAX wheel provided the major version of your CUDA
installation matches, and the minor versions are the same or newer.
JAX checks the versions of your libraries, and will report an error if they are
not sufficiently new.

Your CUDA installation must also be new enough to support your GPU. If you have
an Ada Lovelace (e.g., RTX 4080) or Hopper (e.g., H100) GPU,
Expand Down

0 comments on commit b7dfde8

Please sign in to comment.