In [1]:
import numpy as np

In [2]:
from scipy.integrate import quad

In [3]:
from quadrature import log_y_f, expectation

In [4]:
ys = np.random.randint(0, 2, size=10)

In [5]:
vars = np.random.randn(10)**2

In [6]:
means = np.random.randn(10)

In [7]:
expectations = expectation(ys, vars, means)



In [8]:
expectations.shape

(10,)

In [9]:
expectations

DeviceArray([-1.21908844, -0.09600639, -0.26643178, -0.29923171,
             -2.06203485, -2.99533677, -1.05473292, -2.93919373,
             -5.02521944, -0.51868159], dtype=float32)

In [10]:
# Cross-check `expectation` against direct numerical integration with scipy.integrate.quad

In [28]:
from functools import partial
from jax import jit
from jax.scipy.stats import norm
import jax.numpy as jnp

@jit
def to_quadrature(f, cur_y, cur_mean, cur_var):
    """Integrand used to cross-check `expectation` with scipy's `quad`.

    Evaluates ``log_y_f(cur_y, f) * N(f; cur_mean, cur_var)``, i.e. the
    value of `log_y_f` at the latent point ``f`` weighted by a Gaussian
    density with mean ``cur_mean`` and standard deviation
    ``sqrt(cur_var)``. Integrating over ``f`` therefore yields the
    expectation of ``log_y_f(cur_y, .)`` under that Gaussian.
    (Presumably `log_y_f` is the label log-likelihood log p(y | f); confirm
    in the `quadrature` module.)

    Args:
        f: integration variable; `quad` supplies it positionally.
        cur_y: label forwarded to `log_y_f`.
        cur_mean: mean of the Gaussian weight.
        cur_var: variance of the Gaussian weight.

    Returns:
        A JAX array (0-d when called with scalars, as `quad` does).
    """
    gaussian_weight = norm.pdf(f, loc=cur_mean, scale=jnp.sqrt(cur_var))
    return gaussian_weight * log_y_f(cur_y, f)



In [31]:
quad_res = np.zeros_like(means)

# Integrate over a finite window of +/- N_SD standard deviations around each
# mean instead of (-inf, inf). Variances come from randn()**2, so some
# Gaussians are very narrow, and QUADPACK's infinite-interval transform can
# under-resolve such a sharp peak (a likely source of the roundoff
# IntegrationWarning seen with infinite bounds). The Gaussian mass beyond
# 10 sd is ~1.5e-23, so the truncation error is negligible as long as
# log_y_f grows only modestly in f (assumption; confirm for log_y_f).
N_SD = 10.0

for i, (cur_y, cur_mean, cur_var) in enumerate(zip(ys, means, vars)):
    # Bind everything but the integration variable `f`, which quad supplies.
    quad_fun = partial(to_quadrature, cur_y=cur_y,
                       cur_mean=cur_mean,
                       cur_var=cur_var)

    half_width = N_SD * np.sqrt(cur_var)
    quad_res[i] = quad(quad_fun,
                       cur_mean - half_width,
                       cur_mean + half_width)[0]

IntegrationWarning: The occurrence of roundoff error is detected, which prevents 
  the requested tolerance from being achieved.  The error may be 
  underestimated.
  # Remove the CWD from sys.path while we load stuff.


In [33]:
# Fail loudly on Restart & Run All instead of merely displaying False.
# rtol/atol are np.allclose's defaults; assert_allclose's own default
# (rtol=1e-7, atol=0) is too strict for the float32 `expectations`.
np.testing.assert_allclose(quad_res, expectations, rtol=1e-5, atol=1e-8)
np.allclose(quad_res, expectations)

True

In [34]:
def random_spd_matrix(dim):
    """Return a random dim x dim symmetric positive-definite matrix.

    Draws ``A = np.random.randn(dim, dim)`` from NumPy's global RNG and
    returns ``A @ A.T + I``. A @ A.T is positive semi-definite, and adding
    the identity lifts every eigenvalue to at least 1.
    """
    factor = np.random.randn(dim, dim)
    return factor @ factor.T + np.eye(dim)


cov1 = random_spd_matrix(5)
cov2 = random_spd_matrix(5)

In [35]:
# np.linalg.slogdet returns (sign, log|det|). log|det| equals log det only
# when the sign is +1, so check both determinants are positive before the
# second component is used as a log-determinant below.
sign_cov1, logdet_cov1 = np.linalg.slogdet(cov1)
sign_cov2, logdet_cov2 = np.linalg.slogdet(cov2)
assert sign_cov1 > 0 and sign_cov2 > 0, "covariances must have positive determinant"

In [38]:
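# Compare two ways of computing 0.5 * log det(cov1 @ inv(cov2)): directly, and as 0.5 * (log det cov1 - log det cov2)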

In [39]:
as_written = 0.5 * np.linalg.slogdet(cov1 @ np.linalg.inv(cov2))[1]

In [40]:
as_written

-0.5360926532145042

In [41]:
alternative = 0.5 * (logdet_cov1 - logdet_cov2)

In [42]:
alternative

-0.5360926532145043

In [43]:
# Sanity check: `@` on two 1-D arrays is an inner product and yields a
# scalar. Draw order matches the original one-liner (left operand first),
# so later RNG draws are unaffected.
u = np.random.randn(5)
v = np.random.randn(5)
assert np.ndim(u @ v) == 0
u @ v

-0.6585360052276374

In [44]:
from kl import mvn_kl

In [45]:
# Two independent random mean vectors in R^5 (drawn in order: mu1, then mu2).
mu1, mu2 = np.random.randn(5), np.random.randn(5)

In [48]:
# Presumably KL(N(mu1, cov1) || N(mu2, cov2)) -- confirm the argument order
# in kl.mvn_kl. Any KL divergence is non-negative, so guard against a sign
# error in the implementation.
kl_value = mvn_kl(mu1, cov1, mu2, cov2)
assert float(kl_value) >= 0, kl_value
kl_value

DeviceArray(4.516652, dtype=float32)