Skip to content

Commit

Permalink
Add a default cuda installation path and more explicit installation p…
Browse files Browse the repository at this point in the history
…aths for CUDA jaxlib.

```
# Installs Cuda 11 with Cudnn 8.2
$ pip install jax[cuda] -f https://storage.googleapis.com/jax-releases/jax_releases.html

$ pip install jax[cuda=11,cudnn=82] -f https://storage.googleapis.com/jax-releases/jax_releases.html

$ pip install jax[cuda=11,cudnn=805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
```

PiperOrigin-RevId: 404134291
  • Loading branch information
yashk2810 authored and jax authors committed Oct 19, 2021
1 parent 6bd0c72 commit 4d8bce1
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 10 deletions.
25 changes: 24 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ Remember to align the itemized text with the first line of an item within a list
PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
-->

## jaxlib 0.1.74 (Unreleased)

## jax 0.2.23 (Unreleased)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.22...main).
Expand All @@ -19,7 +21,28 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* `jax.numpy.take` and `jax.numpy.take_along_axis` now require array-like inputs
(see {jax-issue}`#7737`)

## jaxlib 0.1.73 (Unreleased)
## jaxlib 0.1.73 (Oct 18, 2021)

* Multiple cuDNN versions are now supported for jaxlib GPU `cuda11` wheels.
* cuDNN 8.2 or newer. We recommend using the cuDNN 8.2 wheel if your cuDNN
installation is new enough, since it supports additional functionality.
* cuDNN 8.0.5 or newer.

* Breaking changes:
* The install commands for GPU jaxlib are as follows:

```bash
pip install --upgrade pip

# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

# Installs the wheel compatible with Cuda 11 and cudnn 8.2 or newer.
pip install jax[cuda=11,cudnn=82] -f https://storage.googleapis.com/jax-releases/jax_releases.html

# Installs the wheel compatible with Cuda 11 and cudnn 8.0.5 or newer.
pip install jax[cuda=11,cudnn=805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
```

## jax 0.2.22 (Oct 12, 2021)
* [GitHub
Expand Down
21 changes: 19 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -429,14 +429,31 @@ Next, run

```bash
pip install --upgrade pip
pip install --upgrade "jax[cuda111]" -f https://storage.googleapis.com/jax-releases/jax_releases.html # Note: wheels only available on linux.
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html # Note: wheels only available on linux.
```

The jaxlib version must correspond to the version of the existing CUDA
installation you want to use:
* For CUDA 11.1, 11.2, or 11.3, use `cuda111`. The same wheel should work for
* For CUDA 11.1 or newer use `cuda11`. The same wheel should work for
CUDA 11.x releases from 11.1 onwards.
* Older CUDA versions are not supported.
* The supported cuDNN versions for `cuda11` are:
* cuDNN 8.2 or newer. We recommend using the cuDNN 8.2 wheel if your cuDNN
installation is new enough, since it supports additional functionality.
* cuDNN 8.0.5 or newer.

You can specify a particular CUDA and cuDNN version for jaxlib explicitly:

```bash
pip install --upgrade pip

# Installs the wheel compatible with Cuda 11 and cudnn 8.2 or newer.
pip install jax[cuda=11,cudnn=82] -f https://storage.googleapis.com/jax-releases/jax_releases.html

# Installs the wheel compatible with Cuda 11 and cudnn 8.0.5 or newer.
pip install jax[cuda=11,cudnn=805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
```

You can find your CUDA version with the command:

Expand Down
2 changes: 1 addition & 1 deletion jaxlib/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
# reflect the most recent available binaries.
# __version__ should be increased after releasing the current version
# (i.e. on main, this is always the next version to be released).
__version__ = "0.1.73"
__version__ = "0.1.74"
21 changes: 15 additions & 6 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,19 @@
from setuptools import setup, find_packages

# The following should be updated with each new jaxlib release.
_current_jaxlib_version = '0.1.72'
_available_cuda_versions = ['111']
_current_jaxlib_version = '0.1.73'
_available_cuda_versions = ['11']
_default_cuda_version = '11'
_available_cudnn_versions = ['82', '805']
_default_cudnn_version = '82'

_dct = {}
with open('jax/version.py') as f:
exec(f.read(), _dct)
__version__ = _dct['__version__']
_minimum_jaxlib_version = _dct['_minimum_jaxlib_version']

_libtpu_version = '0.1.dev20211012'
_libtpu_version = '0.1.dev20211018'

setup(
name='jax',
Expand Down Expand Up @@ -58,9 +61,15 @@
'requests'],

# CUDA installations require adding jax releases URL; e.g.
# $ pip install jax[cuda110] -f https://storage.googleapis.com/jax-releases/jax_releases.html
**{f'cuda{version}': f"jaxlib=={_current_jaxlib_version}+cuda{version}"
for version in _available_cuda_versions}
# Cuda installation defaulting to a CUDA and Cudnn version defined above.
# $ pip install jax[cuda] -f https://storage.googleapis.com/jax-releases/jax_releases.html
'cuda': [f"jaxlib=={_current_jaxlib_version}+cuda{_default_cuda_version}.cudnn{_default_cudnn_version}"],

# CUDA installations require adding jax releases URL; e.g.
# $ pip install jax[cuda=11,cudnn=82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
# $ pip install jax[cuda=11,cudnn=805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
**{f'cuda={cuda_version},cudnn={cudnn_version}': f"jaxlib=={_current_jaxlib_version}+cuda{cuda_version}.cudnn{cudnn_version}"
for cuda_version in _available_cuda_versions for cudnn_version in _available_cudnn_versions}
},
url='https://github.com/google/jax',
license='Apache-2.0',
Expand Down

0 comments on commit 4d8bce1

Please sign in to comment.