
jax cpp interface #1871

Closed
mgbukov opened this issue Dec 16, 2019 · 9 comments

mgbukov commented Dec 16, 2019

Is there a C++ interface for JAX, or any work toward one?

Collaborator

skye commented Dec 16, 2019

What exactly would you like to do? It's currently possible to save a jit'd version of a function as an XLA program, which can then be invoked from C++ through XLA's C++ client. Check out jax/tools/jax_to_hlo.py (HLO is the representation XLA uses for its programs).
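A minimal sketch of the export step in Python. It assumes a JAX release with the ahead-of-time API, where jax.jit(f).lower(...).compiler_ir(dialect="hlo") returns an XlaComputation wrapping an HloModuleProto. In other releases the entry point may differ, and jax_to_hlo.py remains the reference tool. The function f and the input shapes are illustrative only.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # Illustrative function; any jit-compatible function can be exported.
    return jnp.tanh(x @ y).sum()


# Example inputs fix the shapes and dtypes baked into the exported program.
x = jnp.ones((2, 3), dtype=jnp.float32)
y = jnp.ones((3, 4), dtype=jnp.float32)

# Trace and lower the jit-wrapped function for these shapes (not compiled yet).
lowered = jax.jit(f).lower(x, y)

# Assumption: this JAX version supports dialect="hlo", which yields an
# XlaComputation object backed by an HloModuleProto.
computation = lowered.compiler_ir(dialect="hlo")

hlo_text = computation.as_hlo_text()                      # human-readable HLO
hlo_proto = computation.as_serialized_hlo_module_proto()  # bytes for the C++ side

print(hlo_text.splitlines()[0])
```

The text form is handy for inspection. The serialized proto is what a C++ program would parse back into an HloModuleProto.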

Author

mgbukov commented Dec 16, 2019 via email

Collaborator

skye commented Dec 16, 2019

With the tooling I mentioned above, you could build the network and its gradient function in Python, save the resulting XLA HLO, and then evaluate both quickly through the C++ interface.

(I'm gonna close this issue because I don't think there's any concrete work to do here, but feel free to continue commenting.)
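A sketch of the "network + gradient" case, under the same assumption that Lowered.compiler_ir(dialect="hlo") is available. It uses a hypothetical two-layer MLP with arbitrary sizes and a hypothetical output file name, loss_and_grad.pb. jax.value_and_grad puts the loss and its gradient into one program, so a single exported module serves both evaluations.

```python
import jax
import jax.numpy as jnp


def init_params(key, n_in=4, n_hidden=8, n_out=1):
    # Hypothetical tiny MLP; the sizes are arbitrary and only for illustration.
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (n_in, n_hidden)) * 0.1,
        "b1": jnp.zeros(n_hidden),
        "w2": jax.random.normal(k2, (n_hidden, n_out)) * 0.1,
        "b2": jnp.zeros(n_out),
    }


def loss(params, x, y):
    h = jnp.tanh(x @ params["w1"] + params["b1"])
    pred = h @ params["w2"] + params["b2"]
    return jnp.mean((pred - y) ** 2)


# One jitted program returning (loss, gradient of loss w.r.t. params).
loss_and_grad = jax.jit(jax.value_and_grad(loss))

params = init_params(jax.random.PRNGKey(0))
x = jnp.ones((16, 4), jnp.float32)
y = jnp.zeros((16, 1), jnp.float32)

# Assumption: dialect="hlo" is supported by the installed JAX version.
computation = loss_and_grad.lower(params, x, y).compiler_ir(dialect="hlo")
proto = computation.as_serialized_hlo_module_proto()

# JAX flattens the params dict into separate HLO parameters, with dict leaves
# in sorted-key order (b1, b2, w1, w2), followed by x and y. A C++ caller
# must pass its buffers in that same order.
out_path = "loss_and_grad.pb"  # hypothetical file name
with open(out_path, "wb") as fh:
    fh.write(proto)
```

Shapes are fixed at export time. A different batch size would need a separate export.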

@dawgster

Hi @skye!
Could you give a few pointers on how to get started with this? I've gone through the C++ client you linked but haven't found a way to load an HLO file, at first or second glance. I've already built a network in JAX and converted it to HLO. The final piece of the puzzle is running it through that interface. Any help would be greatly appreciated. Thanks!

@hawkinsp
Member

@dawgster I think the key thing to notice is that the XlaComputation passed to LocalClient::Compile() has a constructor that takes an HloModuleProto. Does that help?

@dawgster

@hawkinsp Yes, that helped quite a bit. Thanks for your help!

@zhangqiaorjc
Member

#5337 shows an explicit example.

It can easily be adapted for the GPU client.


uduse commented May 16, 2021

Do you think calling JAX Python code from C++ is a good alternative?


Roy-Kid commented Sep 6, 2021

> Do you think it's a good alternative to call JAX python code in C++?

Something like #include <Python.h>? I wonder if there is a more elegant way to invoke JAX from C++. I need to write a plugin for OpenMM, and the reason I use JAX is that it is smaller and more flexible than TensorFlow.
