multivariate_normal.logpdf gives infs when there are multiple batches:
```python
>>> from jax import scipy
>>> import jax.numpy as np
>>> from jax.scipy.stats import multivariate_normal
>>> multivariate_normal.logpdf(np.array([[1., 2.],
...                                      [2., 3.]]),
...                            np.array([[1., 2.],
...                                      [2., 3.]]),
...                            np.array([[[1., 0.2],
...                                       [0.2, 1.]],
...                                      [[1., 0.3],
...                                       [0.3, 1.]]]))
DeviceArray([inf, inf], dtype=float32)
```
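Until a fix lands, one possible workaround is to map the unbatched call over the leading batch axis with `jax.vmap`. This is a sketch that assumes the single-example code path (one `(n,)` point and one `(n, n)` covariance) is unaffected by the bug; it reuses the inputs from the report above.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

x = jnp.array([[1., 2.],
               [2., 3.]])
mean = jnp.array([[1., 2.],
                  [2., 3.]])
cov = jnp.array([[[1., 0.2],
                  [0.2, 1.]],
                 [[1., 0.3],
                  [0.3, 1.]]])

# vmap maps over axis 0 of every argument, so each traced call sees a single
# (2,) point, (2,) mean and (2, 2) covariance instead of the stacked batch.
batched_logpdf = jax.vmap(multivariate_normal.logpdf)
result = batched_logpdf(x, mean, cov)
print(result)
```

Because `x` equals `mean` here, each entry should reduce to `-log(2*pi) - 0.5*log(det(cov))`, i.e. roughly -1.8175 and -1.7907, which gives a quick sanity check for any fix.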
Thanks for the report!
I believe the issue is that the diagonal of the decomposed covariance is not handled correctly in the batched case. It looks like changing this line: https://github.com/google/jax/blob/bb0750f31a35e2661834e8b532781785fe63496e/jax/_src/scipy/stats/multivariate_normal.py#L46-L47 to this:
```python
return (-1/2 * jnp.einsum('...i,...i->...', y, y) - n/2*np.log(2*np.pi)
        - jnp.log(L.diagonal(axis1=-1, axis2=-2)).sum(-1))
```
will make things work correctly.
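For context, and assuming the original line called `L.diagonal()` with default arguments: NumPy-style `diagonal` defaults to `axis1=0, axis2=1`, so on a stacked `(batch, n, n)` Cholesky factor it pulls entries across batch elements, including zeros from an upper triangle. `log(0)` is `-inf`, which becomes `+inf` after negation, and a bare `.sum()` also reduces across batches. Below is a minimal standalone sketch of the corrected computation. It is not the library's code: `batched_mvn_logpdf` is a hypothetical helper, and the sketch assumes `jax.scipy.linalg.solve_triangular` broadcasts over leading batch dimensions.

```python
import numpy as np
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def batched_mvn_logpdf(x, mean, cov):
    """Log-density of N(mean, cov) at x, batched over leading axes.

    Hypothetical helper for illustration. Shapes: x and mean are (..., n),
    cov is (..., n, n).
    """
    n = mean.shape[-1]
    # Lower-triangular Cholesky factor of every covariance in the batch.
    L = jnp.linalg.cholesky(cov)
    # Solve L @ y = (x - mean) per batch element; the trailing axis of size 1
    # turns the right-hand side into an explicit (..., n, 1) matrix.
    y = solve_triangular(L, (x - mean)[..., None], lower=True)[..., 0]
    # Squared Mahalanobis distance per batch element.
    maha = jnp.einsum('...i,...i->...', y, y)
    # 0.5 * log|cov| = sum(log(diag(L))): take the diagonal over the LAST two
    # axes and sum over the last axis only, so batches are never mixed.
    half_logdet = jnp.log(jnp.diagonal(L, axis1=-2, axis2=-1)).sum(-1)
    return -0.5 * maha - n / 2 * np.log(2 * np.pi) - half_logdet


x = jnp.array([[1., 2.], [2., 3.]])
cov = jnp.array([[[1., 0.2], [0.2, 1.]],
                 [[1., 0.3], [0.3, 1.]]])
print(batched_mvn_logpdf(x, x, cov))  # finite values, one per batch element

# Root-cause illustration: with the default axes (0, 1) the "diagonal" of the
# stacked factor picks L[i, i, k], which includes the zero upper-triangle
# entry L[0, 0, 1] of batch 0.
L = jnp.linalg.cholesky(cov)
print(jnp.diagonal(L))
```

Taking the Cholesky factor and solving a triangular system avoids forming `inv(cov)` explicitly; the only batching-sensitive step is which axes the diagonal and the final sum run over.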
I can prepare a fix with tests sometime tomorrow if nobody gets to it before that.
jakevdp