Skip to content

Commit

Permalink
Merge pull request #18991 from skye:revert_cuda_install
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 591097432
  • Loading branch information
jax authors committed Dec 15, 2023
2 parents 891d44c + 5d26c30 commit a7b6023
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 37 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -394,8 +394,8 @@ Some standouts:

| Hardware | Instructions |
|------------|-----------------------------------------------------------------------------------------------------------------|
| CPU | `pip install -U "jax[cpu]"` |
| NVIDIA GPU | `pip install -U "jax[cuda12]"` |
| CPU | `pip install -U "jax[cpu]"` |
| NVIDIA GPU on x86_64 | `pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html` |
| Google TPU | `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` |
| AMD GPU | Use [Docker](https://hub.docker.com/r/rocm/jax) or [build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). |
| Apple GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). |
Expand Down
69 changes: 34 additions & 35 deletions docs/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ not being installed alongside `jax`, although `jax` may successfully install

## NVIDIA GPU

JAX supports NVIDIA GPUs that have SM version 5.0 (Maxwell) or newer.
JAX supports NVIDIA GPUs that have SM version 5.2 (Maxwell) or newer.
Note that Kepler-series GPUs are no longer supported by JAX since
NVIDIA has dropped support for Kepler GPUs in its software.

Expand All @@ -81,11 +81,11 @@ pip install --upgrade pip

# CUDA 12 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12]"
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# CUDA 11 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda11]"
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

If JAX detects the wrong version of the CUDA libraries, there are several things
Expand Down Expand Up @@ -162,6 +162,37 @@ Toolbox](https://github.com/NVIDIA/JAX-Toolbox) containers, which are
bleeding edge containers containing nightly releases of jax and some
models/frameworks.

## Nightly installation

Nightly releases reflect the state of the main repository at the time they are
built, and may not pass the full test suite.

* JAX:
```bash
pip install -U --pre jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
```

* Jaxlib CPU:
```bash
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
```

* Jaxlib TPU:
```bash
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
pip install -U libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```

* Jaxlib GPU (Cuda 12):
```bash
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html
```

* Jaxlib GPU (Cuda 11):
```bash
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda_releases.html
```

## Google TPU

### pip installation: Google Cloud TPU
Expand Down Expand Up @@ -234,38 +265,6 @@ See the `conda-forge`
[jax](https://github.com/conda-forge/jax-feedstock#installing-jax) repositories
for more details.


## Nightly installation

Nightly releases reflect the state of the main repository at the time they are
built, and may not pass the full test suite.

* JAX:
```bash
pip install -U --pre jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
```

* Jaxlib CPU:
```bash
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
```

* Jaxlib TPU:
```bash
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
pip install -U libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```

* Jaxlib GPU (Cuda 12):
```bash
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html
```

* Jaxlib GPU (Cuda 11):
```bash
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda_releases.html
```

## Building JAX from source
See [Building JAX from source](developer.md#building-from-source).

Expand Down

0 comments on commit a7b6023

Please sign in to comment.