
Implementing Numpy and JAX substrates using exoplanet-core #132

Draft: wants to merge 18 commits into base: main
Conversation

@dfm (Member) commented Dec 31, 2020

This is a big one and it'll be nice when it's finished!

To do:

  • Implement JAX compat
  • Update dev docs
  • Update tutorials
  • Write tests for substrates
  • Write tutorial(s?) for substrates

@dfm dfm marked this pull request as draft December 31, 2020 02:46
@dfm dfm added this to In progress in Major release with JAX support via automation Dec 31, 2020
@dfm (Member Author) commented Jan 4, 2021

There was some discussion on Twitter and elsewhere about the interface for this change. Here's my summary:

My original proposal:

import exoplanet.pymc as xo
# and
import exoplanet.jax as xo
# and
import exoplanet.numpy as xo

was met with some criticism because it makes it seem like exoplanet now exports, for example, numpy (which I certainly agree is a bit odd!). Some alternatives were proposed:

  1. From @barentsen: xo.use("...") for consistency with matplotlib syntax. This won't work for us here because the particular interface/backend chosen will depend on the user's code. For example, when using PyMC3 for inference, users will generally want to import the theano/pymc implementation, whereas users of TensorFlow Probability or numpyro will want to import the jax implementation. Even though the exoplanet interface (things like xo.orbits.KeplerianOrbit) will have the same syntax, users will really care about which backend they have imported. Furthermore, there are many good reasons why users might want to use both the numpy interface and another within the same program (in order to simulate data and then fit it, respectively).

  2. From @jedbrown: from exoplanet.jax import xo to have a less confusing export. This would certainly work, but I'm not sure that it's far better than the original or my proposal below. I would probably rephrase this as from exoplanet.jax import exoplanet as xo since xo is more like the np for numpy and I don't want to start a pl vs plt flame war :D
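To make the objection to option 1 concrete, here is a hypothetical sketch (none of these names are exoplanet's actual API) of what a matplotlib-style use() switch looks like: it keeps one active backend in process-global mutable state, so a script that simulates with one backend and fits with another has to keep toggling that state, whereas explicit imports keep both backends available side by side.

```python
import math

# Stand-in backend tables; the real ones would hold e.g. numpy vs jax.numpy ops.
_BACKENDS = {
    "numpy": {"sin": math.sin},
    "jax": {"sin": math.sin},  # placeholder; the real entry would be jax.numpy.sin
}
_active = {"name": "numpy"}

def use(name):
    """Globally select a backend, matplotlib-style."""
    _active["name"] = name

def sin(x):
    """Dispatch to whichever backend was selected last, anywhere in the process."""
    return _BACKENDS[_active["name"]]["sin"](x)
```

Every call site silently depends on the last use() call, which is exactly the shared-state problem that per-backend imports avoid.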

Some other options include:

  1. Having completely separate packages exoplanet_jax, exoplanet_pymc, etc., but that adds some maintenance overhead and I'm still not convinced that import exoplanet_jax as xo is much less confusing than import exoplanet.jax as xo.

  2. Automatically detecting the context in which the library is being used. This would be slick, but it seems hard to do properly and it might be tricky to support multiple backends within the same script. I think explicit is better.

Finally, there were also some words of warning from @twiecki that supporting multiple backends might not be worth it, but this isn't quite the whole story here because of the general design of exoplanet. The key is that exoplanet is not a high-level interface for doing inference. Instead, exoplanet provides the building blocks for constructing probabilistic models inside of higher-level frameworks. At its core, it's really just a couple of custom C++ ops and backpropagation rules that evaluate exoplanet/astrophysics-specific models. While I agree that it adds some maintenance and contribution overhead to provide interfaces to these ops that support multiple frameworks, I think it's worth it! It's not obvious which inference library (PyMC 3+, TensorFlow Probability, numpyro, emcee, dynesty, ...) is best suited for all exoplanet inference, and if a small amount of code is sufficient to expose support for all of these frameworks then I'm all for it.
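To illustrate the "custom op plus backpropagation rule" shape that the core of exoplanet is described as having, here is a minimal pure-Python sketch (not exoplanet-core's actual code) of the classic example in this domain: a Newton solver for Kepler's equation, paired with its analytic derivative obtained by implicit differentiation rather than by differentiating through the iteration. It's exactly this pair, the forward kernel and its hand-written gradient rule, that each framework (theano, JAX, ...) then wraps.

```python
import math

def kepler_solve(M, e, tol=1e-12, max_iter=50):
    """Solve Kepler's equation M = E - e*sin(E) for E by Newton iteration."""
    E = M if e < 0.8 else math.pi  # standard starting guess
    for _ in range(max_iter):
        f = E - e * math.sin(E) - M
        E -= f / (1.0 - e * math.cos(E))
        if abs(f) < tol:
            break
    return E

def kepler_grad(E, e):
    """Backprop rule via implicit differentiation of M = E - e*sin(E):
    dE/dM = 1 / (1 - e*cos(E)),  dE/de = sin(E) / (1 - e*cos(E))."""
    denom = 1.0 - e * math.cos(E)
    return 1.0 / denom, math.sin(E) / denom
```

The gradient rule is cheap and exact, which is why exposing the same kernel to several autodiff frameworks can be a thin layer over one shared implementation.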

It's also worth noting that there are plans to migrate PyMC to a JAX backend, which means it was going to be necessary to implement JAX-compatible XLA ops for all of the exoplanet code anyway.

My proposed API

I expect that, in the short term, PyMC (perhaps implicitly with JAX used behind the scenes) will still be the primary interface, so supporting import exoplanet as xo as an alias for the pymc interface would be good. Then I propose to expose the other interfaces using either the same syntax as TensorFlow Probability:

import exoplanet.substrates.jax as xo

or the word interfaces:

import exoplanet.interfaces.jax as xo

since the goal is not quite the same as TFP's substrates. This still has the issue that it's strange to export modules called numpy or jax, but it does have a precedent in a similar domain. If folks have thoughts about this, I'd always love to hear them and apologies in advance if I'm a little stubborn and over-committed to what I've done so far.
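One way the substrates layout can stay cheap to maintain is to write the interface once in terms of backend primitives and stamp out one copy per backend. The following is a hypothetical sketch under that assumption (the module names, `_make_substrate`, and the toy op are all invented for illustration; the real backends would supply numpy, jax.numpy, or theano.tensor functions):

```python
import math
from types import SimpleNamespace

def _make_substrate(ops):
    """Build one copy of the public interface bound to a backend's primitives."""
    def mean_to_ecc_anomaly_step(E, M, e):
        # One Newton step of Kepler's equation, written only in terms of the
        # backend's sin/cos so the same source serves every substrate.
        return E - (E - e * ops.sin(E) - M) / (1.0 - e * ops.cos(E))
    return SimpleNamespace(mean_to_ecc_anomaly_step=mean_to_ecc_anomaly_step)

# What `import exoplanet.substrates.jax as xo` could resolve to; both entries
# use math here as a stand-in, where the jax one would use jax.numpy.
substrates = SimpleNamespace(
    numpy=_make_substrate(SimpleNamespace(sin=math.sin, cos=math.cos)),
    jax=_make_substrate(SimpleNamespace(sin=math.sin, cos=math.cos)),
)
```

With this shape, adding a backend means supplying its primitive table, not reimplementing the interface.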

@jedbrown commented Jan 5, 2021

I don't see import numpy as np as a pattern that needs repeating. It's used because import np would claim the nondescript np in the global namespace, and yet people get tired of writing numpy.array and numpy.exp. You're already under the exoplanet namespace, so you don't have to worry about collisions.

Re your "other option 2": do these interfaces only need to be accessed inside a with lib.Model() as model: block, or do you need a compatible set outside the model block? What about this, which is very explicit and self-contained:

with lib.Model() as model:
    xo = exoplanet.api(model)
    ...

You could maybe offer a convenience wrapper to save a line

with exoplanet.Model(lib) as (xo, model):
    ...
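The convenience wrapper above can be realized as a small context manager that enters the backend's own model context and yields the matching interface alongside it. This is a sketch of one way to do it under the assumptions in this thread; exoplanet.api, the Model wrapper, and BackendModel (a stand-in for e.g. pymc3.Model) are all hypothetical names, not real exoplanet or PyMC3 API:

```python
from types import SimpleNamespace

def api(model):
    """Stand-in for the proposed exoplanet.api: an interface bound to `model`."""
    return SimpleNamespace(model=model)

class BackendModel:
    """Stand-in for a backend model context like pymc3.Model."""
    def __enter__(self):
        return self
    def __exit__(self, *exc):
        return False  # don't swallow exceptions

class Model:
    """Usage: with Model(BackendModel) as (xo, model): ..."""
    def __init__(self, lib_model_cls):
        self._model = lib_model_cls()
    def __enter__(self):
        # Enter the backend context, then pair its model with the interface.
        return api(self._model.__enter__()), self._model
    def __exit__(self, *exc):
        return self._model.__exit__(*exc)
```

The tuple unpacking in the with statement is plain Python, so nothing backend-specific is needed beyond the model class itself.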

@twiecki commented Jan 6, 2021

Looks like a great proposal, but I would still ask what the benefit to the user is, or whether you've gotten requests from users for this. Bambi supported Stan and PyMC3 because the authors thought it would be good to let the user switch between them, but it really made no difference: you still specified your model in the same way and got the same answers. A lot of developer time and code complexity was spent on it that could have been spent on improving UX, adding features, documentation, etc. There were now two areas of bugs and things to support, and all new features needed to be implemented for both backends, so dev velocity was severely affected. So while it might be cool, I would still think carefully about the cost-benefit tradeoff.

@dfm (Member Author) commented Jan 6, 2021

@jedbrown: Good point about namespaces! I'll think a little more about that part. For PyMC, your proposed model-context interface would work, but it's definitely trickier in other cases: numpyro and emcee, two use cases that I want to support, each use a different model syntax. But the exoplanet.api(...) syntax (perhaps with something else, like a string, in the function call) is definitely something to consider - thanks!

@twiecki: I totally agree that it's not obvious and that it's important to do the cost-benefit analysis! I do get lots of requests for the numpy API because there are lots of use cases where the Theano dependency and model compilation overhead are not beneficial. In fact, I'd already implemented a subset of the API bypassing Theano/PyMC entirely. As for JAX, that's just for me so far, but (in my opinion) I'm my most important customer so I think it's worth it for now :D. I think I'm going to push forward with this while making sure that I market the PyMC interface as the stable one. As I said, the JAX ops are necessary (and I'm already getting some nice benefits out of them, more soon!) but a full API implementation is just the cherry on top so I agree that I should be careful about promising too much!

Thanks, both, for the feedback!

@twiecki commented Jan 6, 2021

@dfm Sounds good 👍, excited for the JAX backend!
