add multivariate_normal, change api of jax.random #269
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
I've just added the multivariate normal sampler, while I'd love to hear about your opinions on the API, specifically on the order of the random samplers' arguments. I think the general rule for this in
np.random
is roughlynp.random.sampler(distribution parameters, shape, dtype)
. For example, they havenp.random.normal(loc=0.0, scale=1.0, size=None)
andrandom.randint(low, high=None, size=None, dtype='l')
.Following the discussion at #260 , we cannot make
jax.random
's API the same asnp.random
's, since we need to explicitly specify the key. In this context, I'd suggest we usejax.random.sampler(key, distribution parameters, shape, dtype)
, i.e. to put the key at the beginning and keep the others similar tonp.random
.I've changed the API in this way in the PR. If you agree with this, I will continue to check and modify the examples and tests, since there're things like
jax.random.normal(key, shape)
, where the positional arguments become invalid under the new API. I think this should be the reason why Travis CI failed currently. Or please tell me if you have different attitudes on the API.