Vectorized named parameters and generalized allowed input #462

Closed
zachjweiner wants to merge 5 commits

Conversation

zachjweiner commented Mar 1, 2023

This PR first adds support for vectorized named parameters. I also figured that requiring the dict input to have values of type list[int] was needlessly restrictive, so I generalized it to allow slices, ints, or arbitrary sequences of ints. This includes multidimensional arrays, i.e., the named parameters can now have arbitrary shapes (which I think was originally desired in #386).

I also simplified the input validation (and replaced asserts with exceptions). The messages could be more specific (e.g., "you have too many/too few/duplicate indices"), but I think they're sufficient as is. As far as I can tell, the one kind of semi-invalid input that could sneak through is overslicing, i.e., passing {"x": slice(ndim + 1)}. Presumably this would still trigger an exception in, say, log_prob if one really expected x to have size larger than ndim.
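For concreteness, here's a rough sketch of the kind of spec this enables (the names and shapes are just for illustration, and I'm assuming a plain int index yields a single scalar entry):

```python
import numpy as np
import emcee

ndim, nwalkers = 5, 32

def log_prob(params):
    # params arrives as a dict of named blocks: "x" has shape (3,),
    # "mu" is a single entry (assuming an int index yields a scalar),
    # and "log_sigma" has shape (1,).
    sigma = np.exp(params["log_sigma"])
    return -0.5 * np.sum(((params["x"] - params["mu"]) / sigma) ** 2)

# Each value may now be a slice, an int, or an arbitrary sequence of ints,
# rather than only a list[int]:
parameter_names = {"x": slice(0, 3), "mu": 3, "log_sigma": [4]}

sampler = emcee.EnsembleSampler(
    nwalkers, ndim, log_prob, parameter_names=parameter_names
)
sampler.run_mcmc(np.random.randn(nwalkers, ndim), 100)
```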

I will add tests for the extra cases allowed; I just put this together quickly (and got it working for my use case) and wanted to post it for feedback.

Small note: I removed the try/except import fallback for Python 2 when I added an import from collections.abc.

(Edit: force-pushed to replace checks for Iterables with Sequences, since dictionaries are the former but not the latter.)
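For reference, the distinction behind the force push:

```python
from collections.abc import Iterable, Sequence

isinstance({"x": [0, 1]}, Iterable)  # True: dicts iterate over their keys
isinstance({"x": [0, 1]}, Sequence)  # False: no positional indexing
isinstance([0, 1], Sequence)         # True
```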

(Resolved review comment on src/emcee/ensemble.py.)
dfm (Owner) commented Mar 5, 2023

Thanks! I think this looks great, although I don't really use this interface, so I'm not extremely familiar with the details. A couple of tests to demonstrate this behavior would be awesome!

zachjweiner (Author) commented

Thanks for taking a look! Will work on the tests soon (the current failure is just a preexisting test that expects vectorization + named parameters to raise an assertion error).

I suppose it would be prudent to add examples to the docs, too, while I'm at it. In fact, I don't think there's an example of vectorization in there - probably good to point out, since for cheap likelihoods it can be much faster than multiprocessing (given the latter's added overhead).
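Something along these lines is what I have in mind for a docs example - it just uses the existing vectorize flag, so the only assumption here is the toy model itself:

```python
import numpy as np
import emcee

ndim, nwalkers = 5, 32

def log_prob(coords):
    # With vectorize=True, coords is a 2D array of walker positions with
    # shape (n, ndim), and the function returns one log-probability per row.
    return -0.5 * np.sum(coords**2, axis=-1)

sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob, vectorize=True)
sampler.run_mcmc(np.random.randn(nwalkers, ndim), 1000)
```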

Another idea: would it not be better UX to forgo the parameter_names argument and simply allow users to pass dict-like initial_states and infer the mapping to a flat array automatically? Maybe this goes beyond what you'd want to support, given that you don't use the interface yourself. (Feel free to ping me if issues arise here either way.)
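To sketch what I mean (the helper name and return convention here are purely hypothetical - nothing like this exists in emcee yet):

```python
import numpy as np

def infer_parameter_names(initial_state):
    # Hypothetical helper: given dict-valued walker positions like
    # {"x": (nwalkers, 3) array, "mu": (nwalkers,) array}, stack them into
    # flat (nwalkers, ndim) coordinates and record a name -> slice mapping.
    names, blocks, index = {}, [], 0
    for name, value in initial_state.items():
        block = np.asarray(value).reshape(len(value), -1)
        names[name] = slice(index, index + block.shape[1])
        index += block.shape[1]
        blocks.append(block)
    return np.concatenate(blocks, axis=1), names

coords, parameter_names = infer_parameter_names(
    {"x": np.random.randn(32, 3), "mu": np.random.randn(32)}
)
# coords.shape == (32, 4); parameter_names == {"x": slice(0, 3), "mu": slice(3, 4)}
```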

dfm (Owner) commented Mar 5, 2023

> Another idea: would it not be better UX to forgo the parameter_names argument and simply allow users to pass dict-like initial_states and infer the mapping to a flat array automatically? Maybe this goes beyond what you'd want to support, given that you don't use the interface yourself. (Feel free to ping me if issues arise here either way.)

This is an interesting idea! I've always felt that this interface didn't have the best UX. One alternative idea: Recently, I've been using JAX a lot and I love the Pytree interface, which easily covers all the use cases I would expect us to need here. I'm not very keen to bring in JAX as a dependency, but we could at least write a tutorial talking through how you could use JAX to build an interface for emcee models. As a demo, I put together a notebook that demonstrates this and it really only requires a tiny bit of boilerplate code, even with support for vectorized models: https://gist.github.com/dfm/3e637d60d452e1306b2c4077e33f103a
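The core trick is essentially jax.flatten_util.ravel_pytree; here's a bare-bones illustration of the general idea (a sketch, not the notebook's actual code):

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Any nested dict/list/tuple of arrays is a pytree; ravel_pytree flattens it
# into the 1D vector emcee expects and returns the inverse transform.
params = {"mean": jnp.zeros(2), "noise": {"log_sigma": jnp.zeros(2)}}
flat, unravel = ravel_pytree(params)  # flat has shape (4,)

def log_prob_tree(p):
    sigma = jnp.exp(p["noise"]["log_sigma"])
    return -0.5 * jnp.sum((p["mean"] / sigma) ** 2)

# Wrap for emcee: accept a flat vector, rebuild the pytree, evaluate.
log_prob_flat = jax.jit(lambda theta: log_prob_tree(unravel(theta)))
```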

What do you think about using this interface instead of trying to reinvent the wheel ourselves over here?

zachjweiner (Author) commented

I like the idea in principle - thanks for the demo (I haven't taken the opportunity to try JAX before). I've certainly long yearned for a seamless flat array <-> dict-like data container converter (having hacked such interfaces together for scipy.integrate.solve_ivp, which itself only works with flat arrays).

However, I'm going to advocate for my previous idea (inspecting initial_state) for the following reasons:

  • I think dict[str, float | np.ndarray] is quite straightforward to support, has a well-defined scope, and provides a significant QOL improvement for (I suspect) the vast majority of use cases. (I've never used emcee without writing a wrapper to emulate this!) So I think it makes sense for emcee, even given its intended minimalism re: interface. I wouldn't want to sacrifice it to instead suggest users take on a (beefy) additional dependency and add rather than remove boilerplate, all to enable arbitrarily nested data structures - it seems like overkill. (That said, I'd be interested if you have examples where the full power of Pytrees was uniquely useful - maybe my view is too narrow!)
  • The overhead is killer. Locally, I observe that the wrapped/JIT-ed function adds 330 µs of overhead per evaluation. Sampling in your demo (with 50 walkers) is 40x slower than the vanilla, unvectorized case. Even with vectorization, a factor of 5 remains. The implementation of named parameters in this PR adds no observable overhead - the conversion takes half a microsecond (which makes sense, as taking views of arrays and constructing dicts should be nearly free).

More broadly, I have played with writing a thin wrapper to automate some of the usual boilerplate: specifying parameters with names and priors, using the names to create dicts from parameter vectors (i.e., as parameter_names does), evaluating the priors and skipping likelihood evaluations if log_prior == -np.inf, sampling from the priors to generate initial state vectors, handling blobs, and converting chains to xarray.Dataset/arviz.InferenceData. I'd be happy to discuss more if it's something you're interested in pursuing (external to emcee, of course, if not generalized beyond it as well).

dfm (Owner) commented Mar 5, 2023

Those benchmarking numbers actually don't scare me - stating them as relative numbers for such a cheap model does make them sound scary, but in most real-world applications I'd expect this to be a fairly small price to pay. But no, I wasn't suggesting that we make JAX a dependency of emcee!

All that being said, if you're keen to implement a version that inspects the input state and you're up for testing it thoroughly then that sounds like a great emcee contribution to me, and I'd love to review it! Do you want to use this PR for it or open a new one?

zachjweiner (Author) commented Mar 5, 2023

> Those benchmarking numbers actually don't scare me - stating them as relative numbers for such a cheap model does make them sound scary, but in most real-world applications I'd expect this to be a fairly small price to pay.

I totally agree - I just meant to frame it as overhead relative to emcee's existing internal overhead, which I think is the right comparison. 300 µs per evaluation is slow even compared to the more expensive moves (like KDE). (And I certainly do have use cases, including the one that motivated this PR, with likelihoods far cheaper than 300 µs, where vectorization helps a lot.)

> But no, I wasn't suggesting that we make JAX a dependency of emcee!

Yeah, I didn't think you were suggesting that - but it would be a dependency for anyone who wants to follow your example! That's what I meant by overkill.

> All that being said, if you're keen to implement a version that inspects the input state and you're up for testing it thoroughly then that sounds like a great emcee contribution to me, and I'd love to review it! Do you want to use this PR for it or open a new one?

It sounds like the way to go to me! I'll open a new PR. Thanks for hashing this out.

zachjweiner closed this Mar 5, 2023