Some docs improvements (#1761)
PhilipVinc committed Apr 1, 2024
1 parent 5647b73 commit a3ae955
Showing 3 changed files with 31 additions and 5 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -120,7 +120,7 @@
"jax": ("https://jax.readthedocs.io/en/latest/", None),
"flax": ("https://flax.readthedocs.io/en/latest/", None),
"igraph": ("https://igraph.org/python/api/latest", None),
"qutip": ("https://qutip.org/docs/latest/", None),
"qutip": ("https://qutip.readthedocs.io/en/latest/", None),
"pyscf": ("https://pyscf.org/", None),
}

29 changes: 28 additions & 1 deletion docs/docs/install.md
@@ -45,12 +45,29 @@ At the time of writing, installing a GPU version of jaxlib is as simple as running

```bash
pip install --upgrade pip
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
```

Where the jaxlib version must correspond to the version of the existing CUDA installation you want to use.
Refer to the JAX documentation to learn more about matching CUDA versions with Python wheels.
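
As a quick sanity check (not part of the official instructions), you can verify that the CUDA-enabled build is actually picked up after installation:

```python
# Minimal check that the CUDA-enabled jaxlib is being used.
import jax

print(jax.devices())          # should list GPU devices, not only CPU
print(jax.default_backend())  # expected to print "gpu" on a working CUDA install
```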

````{admonition} CUDA
:class: warning
JAX supports two ways to install the CUDA version: `cuda12_pip` and `cuda12_local`. The `_local` variant uses the CUDA installation provided by the user/cluster admins and picks it up through `LD_LIBRARY_PATH`; you need CUDA, cuDNN, and a few other dependencies already installed.
If you choose this approach, it is your responsibility to ensure that the CUDA version is correct and that all dependencies are present.
**This approach is required if you wish to use MPI**.
`_pip`, instead, installs a self-contained CUDA through `pip` into the current environment and ignores any CUDA installed system-wide. This approach is usually much simpler to use but **is not compatible with MPI on GPUs**.
Note that if you installed the `_pip` version and want to switch to the `_local` version, you must uninstall all NVIDIA-related dependencies. The simplest way is to delete the environment and start from scratch. If you don't want to do that, you can try the following command, but it might not remove everything:
```bash
pip freeze | grep nvidia-cuda | xargs pip uninstall -y
```
````
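
If you are unsure which variant is active in your environment, here is a minimal sketch (an illustrative check, not an official tool) that lists the pip-installed NVIDIA packages:

```python
# List pip-installed nvidia-* packages; if any appear, the `_pip` CUDA libraries
# are present and may shadow a system-wide (`_local`) CUDA installation.
from importlib import metadata

nvidia_pkgs = sorted(
    dist.metadata["Name"]
    for dist in metadata.distributions()
    if (dist.metadata["Name"] or "").startswith("nvidia-")
)
print(nvidia_pkgs or "no pip-installed NVIDIA packages found")
```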



(install_mpi)=
## MPI
@@ -92,6 +109,16 @@ pip install --upgrade "netket[mpi]"
Subsequently, NetKet will exploit MPI-level parallelism for the Monte-Carlo sampling.
See {ref}`this block <warn-mpi-sampling>` to understand how NetKet behaves under MPI.
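
To confirm that MPI is picked up, a minimal sketch (assuming `nk.utils.mpi.rank` and `nk.utils.mpi.n_nodes` are available, as in recent NetKet versions); run it with e.g. `mpirun -np 2 python check_mpi.py`:

```python
# Each rank prints its index and the total number of ranks NetKet sees.
import netket as nk

print(f"rank {nk.utils.mpi.rank} of {nk.utils.mpi.n_nodes}")
```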

````{admonition} CUDA
:class: warning
If you wish to use multi-GPU setups through MPI, you **must** install jax as `'jax[cuda12_local]'` and cannot use the `'jax[cuda12_pip]'` variant.
This is because `cuda12_pip` installs CUDA through pip, which does not include the NVIDIA compiler `nvcc`, which in turn is needed to build `mpi4jax`. If, for any reason, you already have `nvcc` but are using `cuda12_pip`, the installation might not fail, but you will get a runtime error due to mismatched CUDA versions.
````
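
A minimal smoke test, assuming `mpi4py` and `mpi4jax` installed successfully, to check that MPI communication of JAX arrays works at runtime (run under `mpirun`):

```python
# Allreduce a small array across ranks; with a working `cuda12_local` setup this
# should run without CUDA version-mismatch errors.
from mpi4py import MPI
import jax.numpy as jnp
import mpi4jax

result, _ = mpi4jax.allreduce(jnp.ones(3), op=MPI.SUM, comm=MPI.COMM_WORLD)
print(MPI.COMM_WORLD.Get_rank(), result)
```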


(conda)=
## Conda

Expand Down
5 changes: 2 additions & 3 deletions netket/sampler/base.py
@@ -114,9 +114,8 @@ def n_chains_per_rank(self) -> int:
.. code:: python
-from netket.jax import sharding
-sharding.device_count()
-sampler.n_chains // sharding.device_count()
+from netket.jax import sharding
+sampler.n_chains // sharding.device_count()
"""
n_devices = sharding.device_count()
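
For reference, a runnable version of the corrected docstring example (a sketch; the Hilbert space and sampler here are just illustrative choices):

```python
import netket as nk
from netket.jax import sharding

hi = nk.hilbert.Spin(s=1 / 2, N=4)
sampler = nk.sampler.MetropolisLocal(hi, n_chains=16)

# The total number of chains is split across devices/ranks.
print(sampler.n_chains // sharding.device_count())
print(sampler.n_chains_per_rank)
```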
