Vectorized named parameters and generalized allowed input #462
Thanks! I think this looks great, although I don't really use this interface, so I'm not extremely familiar with the details. A couple of tests to demonstrate this behavior would be awesome!
Thanks for taking a look! Will work on the tests soon (the current failure is just a preexisting check that vectorization + named parameters raised an assertion error). I suppose it would be prudent to add examples to the docs, too, while I'm at it. In fact, I don't think there's an example of vectorization in there - probably good to point out, since for cheap likelihoods it can be much faster than multiprocessing (given the latter's added overhead). Another idea: would it not be better UX to forgo the
This is an interesting idea! I've always felt that this interface didn't have the best UX. One alternative idea: Recently, I've been using JAX a lot and I love the Pytree interface, which easily covers all the use cases I would expect us to need here. I'm not very keen to bring in JAX as a dependency, but we could at least write a tutorial talking through how you could use JAX to build an interface for emcee models. As a demo, I put together a notebook that demonstrates this and it really only requires a tiny bit of boilerplate code, even with support for What do you think about using this interface and using this instead of trying to reinvent the wheel ourselves over here?
I like the idea in principle - thanks for the demo (I haven't taken the opportunity to try jax before). I've certainly long yearned for a seamless flat array <-> dict-like data container converter (having hacked such interfaces for However, I'm going to advocate for my previous idea (inspecting
More broadly, I have played with writing a thin wrapper to automate some of the usual boilerplate---specifying parameters with names and priors, using the names to create dicts from parameter vectors (i.e., as |
Those benchmarking numbers actually don't scare me - stating them as relative for such a cheap model does make them sound scary, but in most real-world applications, I'd expect this to be a fairly small price to pay. But no, I wasn't suggesting that we make JAX a dependency of emcee! All that being said, if you're keen to implement a version that inspects the input state and you're up for testing it thoroughly, then that sounds like a great emcee contribution to me, and I'd love to review it! Do you want to use this PR for it or open a new one?
I totally agree - I just meant to frame it as an overhead compared to emcee's existing internal overhead, which I think is the right thing to compare to. 300us/eval is slow even compared to the more expensive moves (like KDE). (And I certainly do have use cases---the one that motivated this PR---with << 300us likelihoods where vectorization helps a lot.)
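For context, here is a minimal sketch of what "vectorized" means for a log-probability function. emcee's `EnsembleSampler` accepts `vectorize=True`, in which case the function is called once with the whole array of walker positions instead of once per walker. The Gaussian density and the function names below are illustrative assumptions, not code from this PR.

```python
import numpy as np


def log_prob_single(theta):
    """Unit-Gaussian log-density for one walker; theta has shape (ndim,)."""
    return -0.5 * np.sum(theta**2)


def log_prob_vectorized(coords):
    """Same density for all walkers at once; coords has shape (nwalkers, ndim)."""
    return -0.5 * np.sum(coords**2, axis=-1)


rng = np.random.default_rng(0)
coords = rng.normal(size=(32, 5))  # 32 walkers, 5 parameters

# One Python-level call per walker versus a single batched call.
looped = np.array([log_prob_single(theta) for theta in coords])
batched = log_prob_vectorized(coords)
```

Both versions return one value per walker; the batched one avoids the per-walker Python call, which is where the savings for very cheap likelihoods would come from.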
Yeah, I didn't think you were suggesting that - but it would be a dependency for those who want to follow your example! That's what I meant as overkill.
It sounds like the way to go to me! I'll open a new PR. Thanks for hashing this out.
This first adds support for vectorized named parameters. I also figured requiring the dict input to have values of type `list[int]` was needlessly restrictive, so I generalized it to allow slices, ints, or arbitrary sequences of ints. This includes multidimensional arrays, i.e., the named parameters now can have arbitrary shapes (which I think was originally desired in #386).

I also simplified the input validation (and replaced `assert`s with exceptions)---the messages could be more specific (e.g., "you have too many/too few/duplicate indices"), but I think they're sufficient as is. So far as I have thought of, the one kind of semi-invalid input that could sneak through is overslicing, i.e., passing `{"x": slice(ndim+1)}`. Presumably this would still, say, trigger an exception in `log_prob` if one was really expecting `x` to have size larger than `ndim`.

I will add tests for the extra cases allowed; I just put this together quickly (and got it working for my use case) and wanted to post for feedback.

Small note: I removed the import try/except for Python 2 when I added an import from `collections.abc`.

(Edit: force pushed to replace checks for `Iterable`s with `Sequence`s, as dictionaries are the former but not the latter.)
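To make the description above concrete, here is a rough sketch of how a name-to-index mapping with ints, slices, or (nested) sequences of ints could be validated and applied to vectorized coordinates. The helper names `expand_index`, `validate_names`, and `to_named` are hypothetical and chosen for illustration; this is not the code in the PR, only one way the described behavior might look.

```python
from collections.abc import Sequence

import numpy as np


def expand_index(spec, ndim):
    """Hypothetical helper: normalize an int, slice, or sequence of ints."""
    if isinstance(spec, slice):
        # Slicing clips silently, so slice(ndim + 1) is not caught here.
        return np.arange(ndim)[spec]
    if isinstance(spec, (int, np.integer)):
        return int(spec)
    if isinstance(spec, np.ndarray) or (
        isinstance(spec, Sequence) and not isinstance(spec, str)
    ):
        # Nested sequences become multidimensional index arrays.
        return np.asarray(spec, dtype=int)
    # Dicts are Iterable but not Sequence, so they are rejected here.
    raise TypeError(f"unsupported index specification: {spec!r}")


def validate_names(names, ndim):
    """Check that the indices cover 0..ndim-1 exactly once (an assumed rule)."""
    indices = {key: expand_index(spec, ndim) for key, spec in names.items()}
    flat = np.concatenate([np.ravel(ix) for ix in indices.values()])
    if sorted(flat.tolist()) != list(range(ndim)):
        raise ValueError("indices must cover every parameter exactly once")
    return indices


def to_named(coords, indices):
    """Index the last axis, so (ndim,) and (nwalkers, ndim) inputs both work."""
    coords = np.asarray(coords)
    return {key: coords[..., ix] for key, ix in indices.items()}


names = {"amp": 0, "center": slice(1, 3), "cov": [[3, 4], [5, 6]]}
indices = validate_names(names, ndim=7)
coords = np.arange(21.0).reshape(3, 7)  # 3 walkers, 7 parameters
params = to_named(coords, indices)
```

Indexing with `coords[..., ix]` keeps any leading walker axis intact, so each named parameter comes back with shape `(nwalkers, *index_shape)` in the vectorized case.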