diff --git a/CHANGELOG.md b/CHANGELOG.md index c8ab9448dc99..82fb1c94bc3e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,18 +16,24 @@ Remember to align the itemized text with the first line of an item within a list * `random.rbg_key(seed)` becomes `random.PRNGKey(seed, impl='rbg')` * `random.unsafe_rbg_key(seed)` becomes `random.PRNGKey(seed, impl='unsafe_rbg')` * Changes: - * {func}`jax.scipy.stats.mode` now returns a 0 count if the mode is taken - across a size-0 axis, matching the behavior of `scipy.stats.mode` in SciPy - 1.11. + * CUDA: JAX now verifies that the CUDA libraries it finds are at least as new + as the CUDA libraries that JAX was built against. If older libraries are + found, JAX raises an exception since that is preferable to mysterious + failures and crashes. * Removed the "No GPU/TPU" found warning. Instead warn if, on Linux, an NVIDIA GPU or a Google TPU are found but not used and `--jax_platforms` was not specified. + * {func}`jax.scipy.stats.mode` now returns a 0 count if the mode is taken + across a size-0 axis, matching the behavior of `scipy.stats.mode` in SciPy + 1.11. # jaxlib 0.4.17 * Changes: * Python 3.12 wheels were added in this release. - * The CUDA 12 wheels now require CUDA 12.2. + * The CUDA 12 wheels now require CUDA 12.2 or newer and cuDNN 8.9.4 or newer. + * This is likely the last release of jaxlib that supports CUDA 11. We intend + to drop CUDA 11 support in the next release. * Bug fixes: * Fixed log spam from ABSL when the JAX CPU backend was initialized. diff --git a/docs/installation.md b/docs/installation.md index f15a49922fbf..525adf9739c3 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -58,8 +58,9 @@ not being installed alongside `jax`, although `jax` may successfully install ### pip installation: GPU (CUDA, installed via pip, easier) There are two ways to install JAX with NVIDIA GPU support: using CUDA and CUDNN -installed from pip wheels, and using a self-installed CUDA/CUDNN. We recommend -installing CUDA and CUDNN using the pip wheels, since it is much easier! +installed from pip wheels, and using a self-installed CUDA/CUDNN. We +strongly recommend installing CUDA and CUDNN using the pip wheels, since it is +much easier! JAX supports NVIDIA GPUs that have SM version 5.2 (Maxwell) or newer. Note that Kepler-series GPUs are no longer supported by JAX since @@ -87,6 +88,13 @@ pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-re pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html ``` +If JAX detects the wrong version of the CUDA libraries, there are several things +to check: +* make sure that `LD_LIBRARY_PATH` is not set, since `LD_LIBRARY_PATH` can + override the CUDA libraries. +* make sure that the CUDA libraries installed are those requested by JAX. + Rerunning the installation command above should work. + ### pip installation: GPU (CUDA, installed locally, harder) If you prefer to use a preinstalled copy of CUDA, you must first @@ -106,13 +114,13 @@ able to use the that NVIDIA provides for this purpose. JAX currently ships two CUDA wheel variants: -* CUDA 12.0 and CuDNN 8.9. +* CUDA 12.2 and CuDNN 8.9.4. * CUDA 11.8 and CuDNN 8.6. -You may use a JAX wheel provided the major version of your CUDA and CuDNN -installation matches, and the minor version is at least as new as the version -JAX expects. For example, you would be able to use the CUDA 12.0 wheel with -CUDA 12.1 and CuDNN 8.9. +You may use a JAX wheel provided the major version of your CUDA +installation matches, and the minor versions are the same or newer. +JAX checks the versions of your libraries, and will report an error if they are +not sufficiently new. Your CUDA installation must also be new enough to support your GPU. If you have an Ada Lovelace (e.g., RTX 4080) or Hopper (e.g., H100) GPU,