Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add cuda support #72

Closed
wants to merge 5 commits into from
Closed

add cuda support #72

wants to merge 5 commits into from

Conversation

wolfv
Copy link
Member

@wolfv wolfv commented Nov 11, 2021

Your wish is my command!

Let's see if this works. Trying to build it "locally" on a server right now.

@conda-forge-linter
Copy link

Hi! This is the friendly automated conda-forge-linting service.

I just wanted to let you know that I linted all conda-recipes in your PR (recipe) and found it was in an excellent condition.

@wolfv wolfv mentioned this pull request Nov 11, 2021
@wolfv
Copy link
Member Author

wolfv commented Nov 11, 2021

hmmm doesn't seem to be that simple!
It downloads the tensorflow cuda bazel rules or something, and then does some more... which I don't fully understand yet.

@wolfv
Copy link
Member Author

wolfv commented Nov 11, 2021

Ok, for my own notes:

bazel cache is in _build_env/share/bazel/... and that's where we can find the local_config_cuda.

@wolfv
Copy link
Member Author

wolfv commented Nov 11, 2021

OK, I think this might work now :)

@wolfv
Copy link
Member Author

wolfv commented Nov 11, 2021

Success!

Python 3.9.7 | packaged by conda-forge | (default, Sep 29 2021, 19:23:11)
[GCC 9.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
jax.>>> jax.devices()
[GpuDevice(id=0, process_index=0)]

@wolfv
Copy link
Member Author

wolfv commented Nov 11, 2021

You can try this one: https://anaconda.org/wolfv/jaxlib/files

@wolfv
Copy link
Member Author

wolfv commented Nov 11, 2021

Which one should be the default variant? Cuda or no cuda? Should we make a jaxlib-gpu meta package?

@ericmjl
Copy link
Contributor

ericmjl commented Nov 11, 2021

@wolfv thank you so much for your work here!

I think no-cuda should be the default. Most of us don't carry around an NVIDIA-equipped laptop. 😄.

jaxlib-gpu could be a good idea. It'd have to be paired with a "jax-gpu" meta-package, would that be right? The reason I ask is because jaxlib is a dependency of jax; I'm not sure how the conda commands would look like when installing jax though.

Trying to work backwords, wondering what your thoughts are on establishing a 'canonical' way of installing JAX via conda/mamba? Here are some ways I can think of:

conda install -c conda-forge jax  # installs cpu-only? gpu-only?
conda install -c conda-forge jax jaxlib-gpu  # this way to install GPU?
conda install -c conda-forge jax-gpu  # this way to install GPU?

@wolfv
Copy link
Member Author

wolfv commented Nov 11, 2021

Hi @ericmjl

the GPU packages will also work with a non-gpu laptop (but the download is pretty big).
I don't know which one you guys prefer as I personally don't use JAX at all, so it's really up to you :)

We could do both, a jax-gpu and jaxlib-gpu package (as well as jax-cpu and jaxlib-cpu).

@ericmjl
Copy link
Contributor

ericmjl commented Nov 11, 2021

These are great thoughts. Thanks a ton, @wolfv 😄.

I personally think that having to worry about the cpu/gpu divide is a bit troublesome. I've never tried installing a GPU-compiled jaxlib on a CPU-only machine; have you tried that? Does the following code block work correctly?

import jax.numpy as np
a = np.arange(3)

If it does, then I can see a path to making the GPU package the default thing installed from conda-forge, under a command conda/mamba install jax. Doing so would be a wonderful quality-of-life enhancement for data scientists! It'd also greatly simplify how we write our environment definitions; no need to worry about whether a GPU is present or not; it just works. Size is usually the least of our concerns; data science containers are large to begin with; our regular hard disks have enough capacity; and cloud storage, so I hear, is infinite 😸.

But if not, then I think having a jax/jaxlib vs. jax-gpu/jaxlib-gpu divide is a sensible thing to do. We can pick and choose based on what we need.

@xhochy
Copy link
Member

xhochy commented Nov 11, 2021

As far as I understood the discussion on the tensorflow-feedstock, we would only install the *-gpu variant if there is a CUDA installation/GPU available? Thus having the GPU variant being the default one would be preferred as such you get the fastest package for your hardware.

@wolfv
Copy link
Member Author

wolfv commented Nov 11, 2021

@xhochy I am afraid that isn't true. unfortunately the __cuda "dependency" is a run-constraint on cudatoolkit which means that it's not considered if completely absent.

However, the package itself works on CPU as well (because it comes with code for CPU and GPU at the same time). I strongly believe this is true for both tensorflow and jaxlib.

Now, since a "cpu-only" package is available, we could introduce a proper run dependency on __cuda in these packages here which would make the solver select the CPU package if no __cuda is available at all.

@ericmjl
Copy link
Contributor

ericmjl commented Nov 20, 2021

Now, since a "cpu-only" package is available, we could introduce a proper run dependency on __cuda in these packages here which would make the solver select the CPU package if no __cuda is available at all.

I think that sounds like a good idea, @wolfv!

@lkhphuc
Copy link

lkhphuc commented Dec 1, 2021

Just want to say thank you to all. I spent yesterday running into all sort of problems and conflicts trying to get jax, tensorflow, tensorflow-probability, distrax etc running on my 3090. Using this package is the only that work.
google/jax#5723

Curious what block it from being merged?

@wolfv
Copy link
Member Author

wolfv commented Dec 1, 2021

I can give this another push today :)

@ngam
Copy link
Contributor

ngam commented May 13, 2022

#97

@ngam ngam mentioned this pull request May 23, 2022
5 tasks
@xhochy xhochy closed this in #103 May 24, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants