install command doesn't work #8279

Closed
xmax1 opened this issue Oct 19, 2021 · 15 comments

@xmax1

xmax1 commented Oct 19, 2021

What should the command to install be?

```
(pansatz) [amawi@sylg nn_ansatz]$ pip install --upgrade pip
Requirement already satisfied: pip in /home/energy/amawi/miniconda3/envs/pansatz/lib/python3.7/site-packages (21.3)
(pansatz) [amawi@sylg nn_ansatz]$ pip install jax[cuda==11, cudnn==82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
ERROR: Invalid requirement: 'jax[cuda==11,'
(pansatz) [amawi@sylg nn_ansatz]$ pip install jax[cuda=11, cudnn=82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
ERROR: Invalid requirement: 'jax[cuda=11,'
Hint: = is not a valid operator. Did you mean == ?
(pansatz) [amawi@sylg nn_ansatz]$ pip install "jax[cuda=11, cudnn=82]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
ERROR: unknown command "install jax[cuda=11, cudnn=82]"
```
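The first two errors happen before pip sees a complete requirement: unquoted, the shell splits the command at the space after the comma, so pip receives `jax[cuda==11,` as one argument. Quoting fixes the splitting, but extras are plain names, so `cuda==11` inside the brackets is still not valid. The sketch below illustrates both points, using `shlex.split` as a stand-in for POSIX shell word splitting and a simplified version of the PEP 508 name rule; `shell_words` and `extras_are_valid` are hypothetical helpers, not pip's actual parser.

```python
from __future__ import annotations

import re
import shlex

# Simplified PEP 508 name rule: an extra is a bare name such as "cuda" or
# "cuda11_cudnn805"; version operators like "==" are not allowed inside [...].
NAME_RE = re.compile(r"[A-Z0-9]|[A-Z0-9][A-Z0-9._-]*[A-Z0-9]", re.IGNORECASE)

# Requirement of the form name[extra1, extra2]; version specifiers not handled.
EXTRAS_RE = re.compile(r"([A-Za-z0-9._-]+)\[([^\]]*)\]")


def shell_words(command: str) -> list[str]:
    """Approximate how a POSIX shell splits a command line into arguments."""
    return shlex.split(command)


def extras_are_valid(requirement: str) -> bool:
    """Return True if every extra in 'name[extra1, extra2]' is a bare name."""
    match = EXTRAS_RE.fullmatch(requirement)
    if match is None:
        return False  # e.g. 'jax[cuda==11,' has no closing bracket
    extras = [extra.strip() for extra in match.group(2).split(",")]
    return all(NAME_RE.fullmatch(extra) for extra in extras)


if __name__ == "__main__":
    url = "https://storage.googleapis.com/jax-releases/jax_releases.html"
    # Unquoted, the space after the comma splits the requirement in two.
    print(shell_words(f"pip install jax[cuda==11, cudnn==82] -f {url}")[2])  # jax[cuda==11,
    # Quoted, it stays one argument, but "cuda==11" is still not a valid extra.
    print(extras_are_valid("jax[cuda==11, cudnn==82]"))  # False
    print(extras_are_valid("jax[cuda11_cudnn805]"))  # True
```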

@jakevdp
Collaborator

jakevdp commented Oct 19, 2021

Sorry, we're working on it – the commands mentioned in the README are new as of yesterday, and we had to yank the release that supports specifying cudnn versions. Until we get that fixed, you can use the installation instructions from the 0.2.22 release: https://github.com/google/jax/blob/jax-v0.2.22/README.md#pip-installation-gpu-cuda

@xmax1
Author

xmax1 commented Oct 19, 2021

Thanks for the quick response; I was so confused.
The 0.2.22 release install works.

@yashk2810
Member

https://pypi.org/project/jax/0.2.24/ is live, which fixes the error.

@GeoffNN

GeoffNN commented Oct 19, 2021

I'm still getting this error from within a (new) conda environment, with both Python 3.9 and 3.10.

When using the previous instructions,
`pip install --upgrade "jax[cuda113]==0.2.22" -f https://storage.googleapis.com/jax-releases/jax_releases.html`,
I get the following warning: `WARNING: jax 0.2.24 does not provide the extra 'cuda113'`.
Finally, I get `ModuleNotFoundError: No module named 'jaxlib'` when running `import jax` from a Python prompt.

I have CUDA 11.3 and cuDNN 8.2.
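The warning is the key line: pip installed jax but skipped the unknown `cuda113` extra, and jax did not depend on jaxlib outside those extras, so nothing installed jaxlib. One way to confirm this kind of state without triggering the import error is to query the package metadata directly. A minimal sketch, assuming Python 3.8+ (`importlib.metadata`) or the `importlib_metadata` backport on 3.7; `installed_version` and `report` are hypothetical helper names:

```python
from __future__ import annotations

import sys

try:
    from importlib.metadata import PackageNotFoundError, version  # Python 3.8+
except ImportError:  # Python 3.7: requires the importlib_metadata backport
    from importlib_metadata import PackageNotFoundError, version


def installed_version(dist_name: str) -> str | None:
    """Installed version of a distribution, or None if it is not installed."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def report(dists: tuple[str, ...] = ("jax", "jaxlib", "scipy")) -> dict[str, str | None]:
    """Versions as seen by the running interpreter, without importing jax."""
    return {name: installed_version(name) for name in dists}


if __name__ == "__main__":
    print("interpreter:", sys.executable)  # shows which conda env is active
    for name, ver in report().items():
        print(f"{name:8} {ver or 'NOT INSTALLED'}")
```

A result with jax present and jaxlib marked as not installed matches the situation described above.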

@yashk2810
Member

We don't have py3.10 builds yet.

> I have CUDA 11.3 and cuDNN 8.2.

Try `pip install -U jax[cuda] -f https://storage.googleapis.com/jax-releases/jax_releases.html`.

The new instructions are here: https://github.com/google/jax/blob/main/README.md#pip-installation-gpu-cuda

@GeoffNN

GeoffNN commented Oct 19, 2021

Thanks!
This gets jaxlib. Strangely, I still get an import error, this time `ModuleNotFoundError: No module named 'scipy.linalg'`, even though scipy is installed.

@yashk2810
Member

What are you trying to import?

@GeoffNN

GeoffNN commented Oct 19, 2021

`import jax` produces this stack trace:

```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File ".../anaconda3/envs/aaa/lib/python3.9/site-packages/jax/__init__.py", line 37, in <module>
    from . import config as _config_module
  File ".../anaconda3/envs/aaa/lib/python3.9/site-packages/jax/config.py", line 18, in <module>
    from jax._src.config import config
  File ".../anaconda3/envs/aaa/lib/python3.9/site-packages/jax/_src/config.py", line 27, in <module>
    from jax._src import lib
  File ".../anaconda3/envs/aaa/lib/python3.9/site-packages/jax/_src/lib/__init__.py", line 75, in <module>
    from jaxlib import lapack
  File ".../anaconda3/envs/aaa/lib/python3.9/site-packages/jaxlib/lapack.py", line 21, in <module>
    from . import _lapack
ImportError: ModuleNotFoundError: No module named 'scipy.linalg'
>>>
```
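When a package is installed but one of its submodules cannot be found, one generic thing to check is which copy of the package the interpreter actually resolves, for example a leftover install from another environment or a local file or package named `scipy` shadowing site-packages. This is a diagnostic sketch only; the thread does not identify the cause, and the problem went away after rebuilding the environment. `module_origin` is a hypothetical helper built on `importlib.util.find_spec`:

```python
from __future__ import annotations

import importlib.util


def module_origin(name: str) -> str | None:
    """File a top-level module would be loaded from, or None if not found.

    find_spec on a dotted name such as "scipy.linalg" imports the parent
    package first, so this sketch only accepts top-level names.
    """
    if "." in name:
        raise ValueError("pass a top-level name such as 'scipy'")
    spec = importlib.util.find_spec(name)
    if spec is None:
        return None
    # For a regular package this is its __init__.py; namespace packages
    # (directories without __init__.py) report None on Python 3.7+.
    return spec.origin


if __name__ == "__main__":
    for mod in ("scipy", "jaxlib", "jax"):
        print(f"{mod:7} -> {module_origin(mod)}")
```

A path outside the active environment's site-packages would point to the shadowing copy.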

@yashk2810
Member

Interesting, I don't see that.

I just tried this in Colab with a GPU runtime: `!pip install --upgrade jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html`

Then `import jax; jax.devices()`.

Can you check on Colab and see if you hit the error?
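The same check can be written so that it reports a broken install instead of crashing. A sketch, assuming only that `jax.devices()` returns device objects with a `.platform` attribute (for example `'cpu'` or `'gpu'`); `visible_platforms` is a hypothetical name. Both failures reported in this thread (missing `jaxlib`, missing `scipy.linalg`) surface as `ImportError`, which is what it catches:

```python
from __future__ import annotations


def visible_platforms() -> set[str]:
    """Platforms of the devices jax can use, e.g. {"gpu"} or {"cpu"}.

    Returns an empty set, after printing the error, if jax cannot be imported.
    """
    try:
        import jax
    except ImportError as err:  # ModuleNotFoundError is a subclass
        print(f"import jax failed: {err}")
        return set()
    return {device.platform for device in jax.devices()}


if __name__ == "__main__":
    platforms = visible_platforms()
    if not platforms:
        print("jax is not usable in this environment")
    elif "gpu" not in platforms:
        print("jax imported, but no GPU is visible:", sorted(platforms))
    else:
        print("GPU visible:", sorted(platforms))
```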

@GeoffNN

GeoffNN commented Oct 19, 2021

On Colab, I only get the warning `WARNING: jax 0.2.21 does not provide the extra 'cuda11_cudnn805'`; the import and `jax.devices()` both work. (But I also see `Requirement already satisfied: jax[cuda11_cudnn805] in /usr/local/lib/python3.7/dist-packages (0.2.21)`.)

The machine I am getting the errors on is running Ubuntu 18.04.5.

@yashk2810
Member

I think you also need `pip install -U jax` to get the latest jax version, which has the new install commands.

@GeoffNN

GeoffNN commented Oct 19, 2021

Do you mean running that before `pip install -U jax[cuda] -f https://storage.googleapis.com/jax-releases/jax_releases.html`? Running it after just gives a bunch of "Requirement already satisfied" messages.

@yashk2810
Member

I think before. `pip install -U jax[cuda11_cudnn805] -f ...` should have installed the latest jax.

But I don't see that error. So maybe something is wrong with your setup?

@GeoffNN

GeoffNN commented Oct 20, 2021

Probably. Running `conda update conda` and then redoing everything from scratch (a new environment, first `pip install -U jax`, then `pip install -U jax[cuda] -f https://storage.googleapis.com/jax-releases/jax_releases.html`) worked.
Thanks!
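The sequence that worked here can also be scripted. A sketch, assuming the `cuda` extra and the find-links URL used in this thread (October 2021; current install instructions differ) and that the fresh conda environment has already been created and activated; `install_steps` and `run` are hypothetical helpers. Invoking pip as `sys.executable -m pip` installs into the same interpreter that will later run `import jax`, and passing argument lists avoids the shell quoting problems from the start of this issue. It only prints the commands unless `dry_run=False`:

```python
from __future__ import annotations

import shlex
import subprocess
import sys

# Value as used in this thread (October 2021); check current instructions.
FIND_LINKS = "https://storage.googleapis.com/jax-releases/jax_releases.html"


def install_steps(extra: str = "cuda") -> list[list[str]]:
    """The two pip invocations that worked above, as argument lists."""
    # `python -m pip` installs into the interpreter that will later import jax.
    pip = [sys.executable, "-m", "pip", "install", "-U"]
    return [
        pip + ["jax"],
        pip + [f"jax[{extra}]", "-f", FIND_LINKS],
    ]


def run(steps: list[list[str]], dry_run: bool = True) -> None:
    """Print each command (shell-quoted for reading); execute unless dry_run."""
    for argv in steps:
        print("$", shlex.join(argv))  # shlex.join requires Python 3.8+
        if not dry_run:
            # An argument list bypasses the shell, so the brackets need no quoting.
            subprocess.run(argv, check=True)


if __name__ == "__main__":
    run(install_steps())  # dry run: prints the commands only
```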

@CornCobs

CornCobs commented Oct 22, 2021

UPDATE: I managed to get it to work! Installing via `jax[cuda]` instead of `jax[cuda11]` got pip to find and install jaxlib correctly:

Collecting jaxlib==0.1.73+cuda11.cudnn82
  Downloading https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.1.73%2Bcuda11.cudnn82-cp37-none-manylinux2010_x86_64.whl (138.6 MB)
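The wheel filename in that download line records what the jaxlib build targets, which is also why some Python versions (such as 3.10 at the time) had no matching build yet. Following the wheel naming convention (PEP 427), `cp37` is the CPython 3.7 tag, and the local version label after `+` (URL-escaped as `%2B`) encodes the CUDA/cuDNN variant. A simplified parser as a sketch; `parse_wheel_filename` is a hypothetical helper, and it skips the optional build tag rather than validating it:

```python
from __future__ import annotations

from typing import NamedTuple
from urllib.parse import unquote


class WheelName(NamedTuple):
    name: str
    version: str
    python_tag: str
    abi_tag: str
    platform_tag: str


def parse_wheel_filename(url_or_name: str) -> WheelName:
    """Split '{name}-{version}[-{build}]-{python}-{abi}-{platform}.whl'."""
    filename = unquote(url_or_name.rsplit("/", 1)[-1])  # "%2B" -> "+"
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel: {filename}")
    parts = filename[: -len(".whl")].split("-")
    if len(parts) not in (5, 6):  # 6 when an optional build tag is present
        raise ValueError(f"unexpected wheel filename: {filename}")
    python_tag, abi_tag, platform_tag = parts[-3:]
    return WheelName(parts[0], parts[1], python_tag, abi_tag, platform_tag)


if __name__ == "__main__":
    url = (
        "https://storage.googleapis.com/jax-releases/cuda11/"
        "jaxlib-0.1.73%2Bcuda11.cudnn82-cp37-none-manylinux2010_x86_64.whl"
    )
    wheel = parse_wheel_filename(url)
    public, _, local = wheel.version.partition("+")
    print(public, local, wheel.python_tag)  # 0.1.73 cuda11.cudnn82 cp37
```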

I'm getting a similar error trying to install jax on a Singularity image running Ubuntu 18.04.
CUDA version: 11.1
cuDNN version: 8.0

The command I ran was `pip install --prefix=[omitted] --upgrade "jax[cuda11]" -f https://storage.googleapis.com/jax-releases/jax_releases.html`, and the output is:

```
Looking in links: https://storage.googleapis.com/jax-releases/jax_releases.html
Collecting jax[cuda11]
  Downloading jax-0.2.24.tar.gz (786 kB)
     |████████████████████████████████| 786 kB 19.7 MB/s
WARNING: jax 0.2.24 does not provide the extra 'cuda11'
Collecting absl-py
  Downloading absl_py-0.15.0-py3-none-any.whl (132 kB)
     |████████████████████████████████| 132 kB 54.1 MB/s
Collecting numpy>=1.18
  Downloading numpy-1.21.3-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (15.7 MB)
     |████████████████████████████████| 15.7 MB 41.1 MB/s
Collecting opt_einsum
  Downloading opt_einsum-3.3.0-py3-none-any.whl (65 kB)
     |████████████████████████████████| 65 kB 1.3 MB/s
Collecting scipy>=1.2.1
  Downloading scipy-1.7.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (28.5 MB)
     |████████████████████████████████| 28.5 MB 114 kB/s
Collecting six
  Downloading six-1.16.0-py2.py3-none-any.whl (11 kB)
Collecting typing_extensions
  Downloading typing_extensions-3.10.0.2-py3-none-any.whl (26 kB)
Building wheels for collected packages: jax
  Building wheel for jax (setup.py) ... done
  Created wheel for jax: filename=jax-0.2.24-py3-none-any.whl size=903114 sha256=9100eaa1bb6616d51d1b79e20451cdfab8f0a5d46994b63ad88c1253b86a61a1
  Stored in directory: /<omitted>/.cache/pip/wheels/28/a9/0f/3497740c85f6e1de8f4d291fd2f77d046d66a87620143d0d0e
Successfully built jax
Installing collected packages: six, numpy, typing-extensions, scipy, opt-einsum, absl-py, jax
```

However, when attempting to import the package I get `ModuleNotFoundError: No module named 'jaxlib'`.
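Here too the warning explains the missing module: jax 0.2.24 did not provide a `cuda11` extra, so pip installed jax and its other dependencies (see the `Installing collected packages` line) but no jaxlib. When installs are scripted, that warning can be caught before anything is imported. A sketch; the regular expression is written against the warning text shown in this thread and may need adjusting for other pip versions, and `ignored_extras` is a hypothetical helper:

```python
from __future__ import annotations

import re

# Written against the warning text in this thread, e.g.
# "WARNING: jax 0.2.24 does not provide the extra 'cuda11'"
IGNORED_EXTRA = re.compile(
    r"WARNING: (?P<dist>\S+) (?P<version>\S+) "
    r"does not provide the extra '(?P<extra>[^']+)'"
)


def ignored_extras(pip_output: str) -> list[tuple[str, str, str]]:
    """(distribution, version, extra) for each extra pip reported as missing."""
    return [m.group("dist", "version", "extra") for m in IGNORED_EXTRA.finditer(pip_output)]


if __name__ == "__main__":
    log = """\
Collecting jax[cuda11]
  Downloading jax-0.2.24.tar.gz (786 kB)
WARNING: jax 0.2.24 does not provide the extra 'cuda11'
Collecting absl-py
"""
    for dist, ver, extra in ignored_extras(log):
        print(f"pip ignored the extra [{extra}] for {dist} {ver}")
```

A wrapper could pass pip's captured output to `ignored_extras` and stop if the list is non-empty.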
