add multivariate_normal, change api of jax.random #269

Closed · wants to merge 1 commit

Conversation

@h3jia (Contributor) commented Jan 19, 2019

I've just added the multivariate normal sampler, and I'd love to hear your opinions on the API, specifically on the ordering of the random samplers' arguments. The general rule in np.random is roughly np.random.sampler(distribution parameters, shape, dtype). For example, there are np.random.normal(loc=0.0, scale=1.0, size=None) and np.random.randint(low, high=None, size=None, dtype='l').

Following the discussion at #260 , we cannot make jax.random's API identical to np.random's, since we need to specify the key explicitly. Given that, I'd suggest we use jax.random.sampler(key, distribution parameters, shape, dtype), i.e. put the key first and keep the rest similar to np.random.
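
For concreteness, a sketch of how the proposed ordering would read for a couple of samplers (the signatures below illustrate the proposal and are hypothetical, not the merged API):

# Proposed ordering: key first, then distribution parameters, then shape and dtype.
key = random.PRNGKey(0)
x = random.normal(key, loc=0.0, scale=1.0, shape=(3,))        # hypothetical signature
u = random.uniform(key, minval=-1.0, maxval=1.0, shape=(3,))  # hypothetical signature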

I've changed the API in this way in the PR. If you agree, I will go on to check and modify the examples and tests, since there are calls like jax.random.normal(key, shape) whose positional arguments become invalid under the new API. I think this is why Travis CI is currently failing. Please let me know if you see the API differently.

@fehiepsi (Member) left a comment


I think this PR proposes many nice points for the jax.random API, so it would be great to see it finished and reviewed.


Args:
key: a PRNGKey used as the random key.
shape: a tuple of nonnegative integers representing the shape.
loc: optional, the mean of the distribution (default 0.0).
scale: optional, the standard deviation of the distribution (default 1.0).

I think we can keep the current behaviour (without loc and scale). Many samplers can be transformed by x -> loc + scale * x, so it is unnecessary to add loc and scale to all of them.
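
For example, a location-scale family can be recovered from the standard sampler with a one-line affine transform (a sketch against the existing jax.random API):

from jax import random

key = random.PRNGKey(0)
loc, scale = 2.0, 0.5
# Standard normal draws, shifted and scaled: mean 2.0, standard deviation 0.5.
samples = loc + scale * random.normal(key, (1000,))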

# Draw u uniformly over (-1, 1); nextafter nudges the lower bound off
# -1 so that erf_inv(u) stays finite.
lo = onp.nextafter(onp.array(-1., dtype), 0., dtype=dtype)
hi = onp.array(1., dtype)
u = uniform(key, lo, hi, (*shape, dim), dtype)
# sqrt(2) * erf_inv maps u to standard normal draws; the affine map
# mean + eps @ L.T then applies the Cholesky factor L of the covariance.
return mean + np.dot(onp.array(onp.sqrt(2), dtype) * lax.erf_inv(u), L.T)
@fehiepsi (Member) commented Mar 30, 2019

I guess this should be np.dot(L, vector) rather than np.dot(vector, L.T)? In addition, we could simplify the implementation by taking eps = normal(key, shape, dtype) instead of reimplementing the normal sampler.
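
A minimal sketch of that simplification, assuming the same mean, L (Cholesky factor of the covariance), shape, dim, and dtype variables as the snippet above; for a batch of row vectors, eps @ L.T applies L to each sample:

# Reuse the standard normal sampler instead of inverting the CDF by hand.
eps = normal(key, (*shape, dim), dtype)  # standard normal draws, shape (*shape, dim)
return mean + np.dot(eps, L.T)           # per-sample x = mean + L @ eps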

@fehiepsi mentioned this pull request Sep 27, 2019
@jekbradbury (Contributor) commented
Closing in favor of #1389. Thanks for the PR, and sorry for letting it languish!

@jekbradbury closed this Oct 4, 2019