
inf using multivariate_normal.logpdf with multiple batches #5570

Closed
sidravi1 opened this issue Jan 31, 2021 · 1 comment · Fixed by #5584

Comments

@sidravi1

multivariate_normal.logpdf returns inf for every element when given a batch of covariance matrices (here x and mean have shape (2, 2) and cov has shape (2, 2, 2)):

>>> from jax import scipy
>>> import jax.numpy as np
>>> from jax.scipy.stats import multivariate_normal
>>> multivariate_normal.logpdf(np.array([[1. ,2.],
...                                      [2., 3.]]),
...                            np.array([[1., 2.],
...                                      [2., 3.]]),
...                            np.array([[[1., 0.2],
...                                       [0.2, 1.]],
...                                      [[1., 0.3],
...                                       [0.3, 1.]]]))
DeviceArray([inf, inf], dtype=float32)
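A possible workaround in the meantime, sketched under the assumption that the unbatched path (a single 2-D cov) is unaffected: map multivariate_normal.logpdf over the leading batch axis with jax.vmap, so each call sees one covariance matrix, and compare against SciPy's scipy.stats.multivariate_normal.logpdf evaluated one element at a time. The name batched_logpdf is illustrative only.

```python
import jax
import jax.numpy as jnp
import numpy as onp
from jax.scipy.stats import multivariate_normal
from scipy.stats import multivariate_normal as sp_mvn

# Hypothetical helper: map the unbatched logpdf over axis 0 of every argument,
# so each call receives one x vector, one mean vector and one 2-D covariance.
batched_logpdf = jax.vmap(multivariate_normal.logpdf)

x = jnp.array([[1., 2.], [2., 3.]])
mean = jnp.array([[1., 2.], [2., 3.]])
cov = jnp.array([[[1., 0.2], [0.2, 1.]],
                 [[1., 0.3], [0.3, 1.]]])

result = batched_logpdf(x, mean, cov)

# Reference values from SciPy, computed one batch element at a time.
expected = onp.array([
    sp_mvn.logpdf(onp.asarray(xi), onp.asarray(mi), onp.asarray(ci))
    for xi, mi, ci in zip(x, mean, cov)
])
print(result, expected)
```

Because vmap traces the function with unbatched shapes, any axis arguments inside logpdf refer to the single-matrix layout, which is why this sidesteps the batched code path.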
@sidravi1 changed the title from "Error when broadcasting" to "inf using multivariate_normal.logpdf with multiple batches" on Jan 31, 2021
@jakevdp (Collaborator) commented Jan 31, 2021

Thanks for the report!

I believe the issue is that the diagonal of the Cholesky-decomposed covariance is not handled correctly in the batched case. It looks like changing these lines: https://github.com/google/jax/blob/bb0750f31a35e2661834e8b532781785fe63496e/jax/_src/scipy/stats/multivariate_normal.py#L46-L47
to this:

return (-1/2 * jnp.einsum('...i,...i->...', y, y) - n/2*np.log(2*np.pi)
        - jnp.log(L.diagonal(axis1=-1, axis2=-2)).sum(-1))

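will make things work correctly: the diagonal is then taken from each n × n factor separately and summed over the last axis, giving one value per batch element.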
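For reference, a worked form of what the corrected return statement computes, assuming (as the variable name suggests) that L is the lower Cholesky factor of one n × n covariance matrix Σ and that y solves L y = x - μ:

```latex
\log p(x \mid \mu, \Sigma)
  = -\tfrac{1}{2}\, y^\top y \;-\; \tfrac{n}{2}\log(2\pi) \;-\; \sum_{i=1}^{n} \log L_{ii},
\qquad \Sigma = L L^\top,\quad y = L^{-1}(x - \mu)
```

Since log det Σ = 2 Σᵢ log Lᵢᵢ, the last term equals -½ log det Σ, and it depends only on the diagonal of each individual factor. For the first batch element in the report, x = μ gives y = 0, so log p = -log(2π) - ½ log 0.96 ≈ -1.8175.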

I can prepare a fix with tests sometime tomorrow if nobody gets to it before that.
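A sketch of what a batched implementation and test along these lines might look like. This is not the code merged in #5584; mvn_logpdf_batched is a hypothetical standalone function, and it assumes x, mean and cov share the same leading batch shape. It also illustrates the likely cause of the inf, assuming the original line called diagonal() without axis arguments (as the suggested change implies): on a stacked (batch, n, n) array, the default axes (axis1=0, axis2=1) pair the batch index with the row index, so they pick up entries from the strictly upper triangle of the lower-triangular factors. Those are zero, log(0) is -inf, and the leading minus sign turns the total into +inf.

```python
import jax.numpy as jnp
import numpy as onp
from jax.scipy.linalg import solve_triangular
from scipy.stats import multivariate_normal as sp_mvn


def mvn_logpdf_batched(x, mean, cov):
    """Hypothetical sketch: MVN log-density batched over leading axes.

    Assumes x and mean have shape (..., n) and cov has shape (..., n, n)
    with the same leading batch shape, and cov is positive definite.
    """
    n = mean.shape[-1]
    L = jnp.linalg.cholesky(cov)                        # cov = L @ L^T, L lower
    diff = (x - mean)[..., None]                        # (..., n, 1) right-hand side
    y = solve_triangular(L, diff, lower=True)[..., 0]   # solve L y = x - mean
    # Diagonal over the last two axes, summed per matrix: 0.5 * log det(cov).
    half_logdet = jnp.log(jnp.diagonal(L, axis1=-2, axis2=-1)).sum(-1)
    return (-0.5 * jnp.einsum('...i,...i->...', y, y)
            - 0.5 * n * jnp.log(2 * jnp.pi)
            - half_logdet)


# Inputs from the report above.
x = jnp.array([[1., 2.], [2., 3.]])
mean = jnp.array([[1., 2.], [2., 3.]])
cov = jnp.array([[[1., 0.2], [0.2, 1.]],
                 [[1., 0.3], [0.3, 1.]]])

# Default axes (0, 1) mix the batch axis with the row axis of L and
# include zeros from the upper triangle; the log of those is -inf.
L = jnp.linalg.cholesky(cov)
print(jnp.diagonal(L))
print(jnp.diagonal(L, axis1=-2, axis2=-1))  # one diagonal per matrix

result = mvn_logpdf_batched(x, mean, cov)
# Reference: SciPy evaluated one batch element at a time.
expected = onp.array([
    sp_mvn.logpdf(onp.asarray(xi), onp.asarray(mi), onp.asarray(ci))
    for xi, mi, ci in zip(x, mean, cov)
])
print(result, expected)
```

Using solve_triangular rather than a general solve exploits the triangular structure of L; the explicit trailing axis on diff keeps the right-hand side a batched matrix with the same batch dimensions as L.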

@jakevdp added the bug label on Jan 31, 2021
@jakevdp self-assigned this on Jan 31, 2021