This notebook re-checks the lower bound of the heteroscedastic model.

It shows that the lower bound itself works well; rather, it is the model that is somewhat problematic.

In [None]:
from gaussian_toolbox import pdf, approximate_conditional
import numpy as np 
from jax import numpy as jnp
import jax
from matplotlib import pyplot as plt
from jax import config
config.update("jax_enable_x64", True)

# Problem sizes used by the sanity checks in this section.
Dy = 3
Dx = 1
R = 1
Du = 1

# Seed NumPy's global RNG so that U and W are reproducible on Restart & Run All.
np.random.seed(0)

M = jnp.array([jnp.eye(Dy)])[:, :, :Dx]
b = jnp.array([jnp.zeros(Dy)])
Sigma = jnp.array([jnp.eye(Dy)]) + .1
mat = np.random.randn(Dy, Dy)
# Only the orthonormal factor is needed. Taking element [0] avoids binding the
# triangular factor to `R`, which would silently overwrite the constant above.
Q = np.linalg.qr(mat)[0]
U = Q[:, :Du]
# Column 0 of W is used as a bias and the remaining Dx columns as weights
# (see the cosh check in the next cell).
W = jnp.array(np.random.randn(Du, Dx + 1))

cond = approximate_conditional.FullHCCovGaussianConditional(M=M, b=b, Sigma=Sigma, U=U, W=W)

# Check if cosh is correctly implemented

In [None]:
# Evaluation grid: 100 points in [-5, 5], stored as a (100, 1) column.
x = jnp.array([jnp.linspace(-5,5,100)]).T
# Toolbox value: sum of the two exponential terms exposed by the conditional.
cosh_gt = cond.exp_h_minus(x) + cond.exp_h_plus(x)
# Reference value cosh(w x + bias) built directly from W (column 0 is the bias).
# The bias is kept as a (Du, 1) column so that it broadcasts along the grid axis of
# the (Du, 100) product; the 1-D slice W[:, 0] only lines up when Du == 1.
cosh_np = jnp.cosh(jnp.dot(W[:, 1:], x.T) + W[:, :1])
# Top panel: the function itself; bottom panel: its reciprocal. Dashed = reference.
plt.subplot(2,1,1)
plt.plot(x, cosh_gt.T)
plt.plot(x, cosh_np.T, '--')
plt.subplot(2,1,2)
plt.plot(x, 1./cosh_gt.T)
plt.plot(x, 1./cosh_np.T, '--')

## Check if covariance matrix is computed correctly

In [None]:
cov_gt, prec_gt, ln_det_cov_gt = cond.get_conditional_cov(x, invert=True)
# For every grid point, cov @ prec should equal the identity. The deviation is
# scaled element-wise by |cov| and summed over both matrix axes.
cov_times_prec = jnp.einsum('abc,acd->abd', cov_gt, prec_gt)
identity_rel_err = jnp.abs(cov_times_prec - jnp.eye(Dy)[None]) / jnp.abs(cov_gt)
plt.plot(x, jnp.sum(jnp.sum(identity_rel_err, axis=1), axis=1))

In [None]:
# slogdet returns (sign, log|det|); element [1] is the log-determinant reference.
ln_det_cov_np = jnp.linalg.slogdet(cov_gt)[1]
# Relative error between the toolbox log-determinant and the reference.
ln_det_rel_err = jnp.abs(ln_det_cov_gt - ln_det_cov_np) / jnp.abs(ln_det_cov_gt)
plt.plot(ln_det_rel_err)

In [None]:
# exp_h_minus + exp_h_plus - 1 on the grid x; used in the next cell as the
# x-dependent scaling of the U U' term when rebuilding the covariance by hand.
diag = cond.exp_h_minus(x) + cond.exp_h_plus(x) - 1.
diag.shape

In [None]:
# Rebuild the conditional covariance by hand as L (I + diag * U U') L', where L is
# the Cholesky factor of Sigma, and compare it with the toolbox result.
diag_scaled_U = jnp.einsum('ab,ca->bca', diag, U)
rot_mat = jnp.eye(Dy) + jnp.einsum('abc,dc->abd', diag_scaled_U, U)
L = jnp.linalg.cholesky(Sigma)
L_times_rot = jnp.einsum('abc,acd->abd', L, rot_mat)
cov_mat_np = jnp.einsum('abc,adc->abd', L_times_rot, L)
# Element-wise relative error, summed over both matrix axes, per grid point.
cov_rel_err = jnp.abs(cov_mat_np - cov_gt) / jnp.abs(cov_gt)
plt.plot(x, jnp.sum(jnp.sum(cov_rel_err, axis=1), axis=1))

In [None]:
# 100 standard-normal test observations of dimension Dy (one per grid point in x).
y = jnp.array(np.random.randn(100,Dy))
# Evaluate the conditional built at x on y with element_wise=True and plot the values.
plt.plot(cond(x).evaluate(y, element_wise=True), '.')

In [None]:
from jax import random as jrand
# Standard-normal density over x (single component), used for Monte-Carlo sampling.
pX = pdf.GaussianPDF(mu=jnp.zeros(Dx)[None], Sigma=jnp.eye(Dx)[None])
key = jrand.PRNGKey(0)
x_samples = pX.sample(key, 10000)
# Monte-Carlo estimate: average ln p(y|x) over the samples of x.
int_log_cond_y_np = jnp.mean(cond(x_samples[:,0]).evaluate_ln(y), axis=0)
# The same density tiled 100 times (one copy per row of y) for the toolbox routine.
px = pdf.GaussianPDF(mu=jnp.tile(pX.mu, (100,1)), Sigma=jnp.tile(pX.Sigma, (100,1,1)))
int_log_cond_y_gt = cond.integrate_log_conditional_y(p_x=px, y= y)

plt.plot(int_log_cond_y_np, int_log_cond_y_gt, '.')
# Dashed identity line spanning the joint range of both estimates.
int_log_cond_both = jnp.stack([int_log_cond_y_gt, int_log_cond_y_np])
int_log_cond_span = [jnp.amin(int_log_cond_both), jnp.amax(int_log_cond_both)]
plt.plot(int_log_cond_span, int_log_cond_span, '--')

### Check log determinant

In [None]:
from jax import lax
import jax

num_trials = 50
ln_det_cov_gt_arr = np.zeros(num_trials)
ln_det_cov_np_arr = np.zeros(num_trials)

for i in range(num_trials):
    # Fresh random model per trial. M is sliced to Dx columns, exactly as in the
    # first setup cell, so that it is consistent with the Dx-dimensional x and
    # with W (which has Dx + 1 columns).
    M = jnp.array([jnp.eye(Dy)])[:, :, :Dx]
    b = jnp.array([jnp.zeros(Dy)])
    Sigma = jnp.array([jnp.eye(Dy)]) + .1
    mat = np.random.randn(Dy, Dy)
    # Keep only the orthonormal factor; do not overwrite the constant `R`.
    Q = np.linalg.qr(mat)[0]
    U = Q[:, :Du]
    W = jnp.array(np.random.randn(Du, Dx + 1))

    cond = approximate_conditional.FullHCCovGaussianConditional(M=M, b=b, Sigma=Sigma, U=U, W=W)

    # Monte-Carlo reference: mean log-determinant over the x samples drawn earlier.
    ln_det_cov_np = jnp.mean(cond(x_samples[:,0]).ln_det_Sigma)

    # Defined inside the loop on purpose: the body closes over this trial's `cond`,
    # and a fresh function object is handed to lax.scan for every trial.
    def scan_body_function(carry, args_i):
        """Scan step for one (W_i, u_i) pair: returns (uRu_i, log_lb_sum_i)."""
        W_i, u_i = args_i
        omega_star_i, omega_dagger_i, _ = lax.stop_gradient(
            cond._get_omega_star_i(W_i, u_i, px, y)
        )
        uRu_i, log_lb_sum_i = cond._get_lb_i(
            W_i, u_i, omega_star_i, omega_dagger_i, px, y
        )
        return carry, (uRu_i, log_lb_sum_i)

    # Scan over the rows of W paired with the columns of U.
    _, result = lax.scan(scan_body_function, None, (cond.W, cond.U.T))
    uRu, log_lb_sum = result
    ln_det_cov_gt = log_lb_sum + cond.ln_det_Sigma
    ln_det_cov_gt_arr[i] = ln_det_cov_gt[0,0]
    ln_det_cov_np_arr[i] = ln_det_cov_np
#assert jnp.alltrue(jnp.less_equal(ln_det_cov_np_arr, ln_det_cov_gt_arr))

In [None]:
# Scatter of sampled vs. bound log-determinants with a dashed identity line
# spanning the joint range of both arrays.
plt.plot(ln_det_cov_np_arr, ln_det_cov_gt_arr, '.')
ln_det_both = jnp.stack([ln_det_cov_gt_arr, ln_det_cov_np_arr])
ln_det_span = [jnp.amin(ln_det_both), jnp.amax(ln_det_both)]
plt.plot(ln_det_span, ln_det_span, '--')

# Check quadratic term

In [None]:
# Conditional mean at every x sample; [0] drops the leading axis of the result.
mu = cond.get_conditional_mu(x_samples[:,0])[0]
# Conditional covariance, its inverse and log-determinant at every x sample.
# Note: this rebinds `ln_det_cov_gt` from the previous section.
cov, prec, ln_det_cov_gt = cond.get_conditional_cov(x_samples[:,0], invert=True)
# Difference between the x-dependent precision and the inverse of cond.Sigma.
# Note: this rebinds `mat`, which held a random matrix in the setup cells.
mat = prec - jnp.linalg.inv(cond.Sigma)
# Residuals y - mu for every (x sample, observation) pair.
y_mu = y[None] - mu[:,None]

In [None]:
# Monte-Carlo value of the quadratic form (y - mu)' mat (y - mu), averaged over x samples.
mat_times_y_mu = jnp.einsum('abc,adc->adb', mat, y_mu)
res_np = jnp.mean(jnp.einsum('abc, abc->ab', mat_times_y_mu, y_mu), axis=0)

In [None]:
# Toolbox counterpart: negative sum over the leading axis of `uRu`, which was
# produced by the lax.scan in the log-determinant check (last trial of that loop).
res_gt = -jnp.sum(uRu,axis=0)

In [None]:
# Scatter of sampled vs. toolbox quadratic term with a dashed identity line.
plt.plot(res_np, res_gt, '.')
res_both = jnp.stack([res_gt, res_np])
res_span = [jnp.amin(res_both), jnp.amax(res_both)]
plt.plot(res_span, res_span, '--')

### Check diagonal term 

In [None]:
# Sum of the two exponential terms at every x sample (compared against cosh in
# the first check of this notebook).
cosh = cond.exp_h_minus(x_samples[:,0]) + cond.exp_h_plus(x_samples[:,0])
# Monte-Carlo average of (cosh - 1) / cosh over axis 1.
G_np = jnp.mean((cosh-1)/cosh, axis=1)

In [None]:
# Exploratory call of the toolbox helper with the full W and U.T.
# NOTE(review): the lax.scan in the log-determinant check passes one row at a
# time (W_i, u_i); confirm that passing the full matrices here is intended.
cond._get_omega_star_i(W, U.T, px, y)

In [None]:
from gaussian_toolbox import pdf, factor
from jaxtyping import Array, Float
from typing import Tuple

def _get_G_i(
    cond,
    W_i: Float[Array, "Dx+1"],
    u_i: Float[Array, "Dy"],
    omega_star: Float[Array, "N"],
    p_x: pdf.GaussianPDF,
    y: Float[Array, "N Dy"],
) -> Array:
    """Integrate the projected residual outer product under three reweighted measures.

    Three ``OneRankFactor`` objects are built around the expansion point
    ``omega_star``; they share ``v`` and ``g`` and differ in ``nu`` and
    ``ln_beta`` ("plus", "minus" and "one"). Each factor is multiplied onto
    ``p_x`` with ``hadamard`` and the resulting measure integrates
    ``"(Ax+a)(Bx+b)'"`` with ``A = B = -L_inv[0]' M[0]`` and
    ``a = b = (y - b[0]) L_inv[0]``.

    Args:
        cond: Conditional providing ``g``, ``f``, ``M``, ``b`` and ``L_inv``.
        W_i: One row of ``W``; element 0 is the bias, the rest are the weights.
        u_i: Accepted for signature compatibility; not used by this function.
        omega_star: Expansion point passed to ``cond.g`` and ``cond.f``.
        p_x: Density over x; ``p_x.R`` sets how often the weight row is tiled.
        y: Observations.

    Returns:
        ``G_plus + G_minus - G_one``, i.e. the combination of the three integrals.
    """
    num_components = p_x.R
    w_i = W_i[1:].reshape((1, -1))
    v = jnp.tile(w_i, (num_components, 1))
    b_i = W_i[:1]

    g_omega = cond.g(omega_star)
    nu = -g_omega[:, None] * b_i * w_i
    nu_plus = w_i + nu
    nu_minus = -w_i + nu
    f_omega_star = cond.f(omega_star)
    # The where() replaces log(1 + f) by 0 wherever f is numerically zero.
    ln_beta = -jnp.where(
        jnp.isclose(f_omega_star, 0), 0, jnp.log(1 + f_omega_star)
    ) - 0.5 * g_omega * (b_i**2 - omega_star**2)
    ln_beta_plus = ln_beta + b_i - jnp.log(2)
    ln_beta_minus = ln_beta - b_i - jnp.log(2)

    # g is clipped only after nu and ln_beta were computed from the unclipped
    # value, so the clipping affects just the `g` handed to the factors.
    g_omega = jnp.clip(g_omega, a_min=1e-4)
    exp_factor_plus = factor.OneRankFactor(
        v=v, g=g_omega, nu=nu_plus, ln_beta=ln_beta_plus
    )
    exp_factor_minus = factor.OneRankFactor(
        v=v, g=g_omega, nu=nu_minus, ln_beta=ln_beta_minus
    )
    one_factor = factor.OneRankFactor(v=v, g=g_omega, nu=nu, ln_beta=ln_beta)

    # Reweighted measures.
    exp_phi_plus = p_x.hadamard(exp_factor_plus, update_full=True)
    exp_phi_minus = p_x.hadamard(exp_factor_minus, update_full=True)
    phi_one = p_x.hadamard(one_factor, update_full=True)

    # Residual y - (M x + b), expressed as (A x + a) and projected with L_inv.
    mat1 = -cond.M[0]
    vec1 = y - cond.b[0]
    vec1_projected = jnp.einsum("ba,cb->ca", cond.L_inv[0], vec1)
    mat1_projected = jnp.einsum("ba,bc->ac", cond.L_inv[0], mat1)
    outer_product_args = dict(
        A_mat=mat1_projected,
        a_vec=vec1_projected,
        B_mat=mat1_projected,
        b_vec=vec1_projected,
    )
    G_plus = exp_phi_plus.integrate("(Ax+a)(Bx+b)'", **outer_product_args)
    G_minus = exp_phi_minus.integrate("(Ax+a)(Bx+b)'", **outer_product_args)
    G_one = phi_one.integrate("(Ax+a)(Bx+b)'", **outer_product_args)
    return G_plus + G_minus - G_one


In [None]:
# `_get_G_i` needs an expansion point. Take it from the toolbox helper, whose first
# return value is omega_star (it is unpacked that way in the log-determinant check).
# The direction vector is a column of U (length Dy), i.e. a row of U.T, and the
# density over x is the notebook-level `px`.
u_0 = U.T[0]
omega_star, _omega_dagger_unused, _aux_unused = cond._get_omega_star_i(W[0], u_0, px, y)
_get_G_i(cond, W[0], u_0, omega_star, px, y)

# Test lower bound

In [None]:
%load_ext autoreload
%autoreload 2

from gaussian_toolbox import pdf, approximate_conditional, factor
import numpy as np 
from jax import numpy as jnp
import jax
from matplotlib import pyplot as plt
from jax import config, lax
config.update("jax_enable_x64", True)

# Problem sizes for the lower-bound test.
Dy = 3
Dx = 2
R = 1
Du = 1

# Seed NumPy's global RNG so that U and W are reproducible on Restart & Run All
# (M is drawn from the JAX key below).
np.random.seed(0)

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
M = jax.random.normal(subkey, shape=(1, Dy, Dx))
b = jnp.array([jnp.zeros(Dy)])
Sigma = jnp.array([jnp.eye(Dy)]) + .1
mat = np.random.randn(Dy, Dy)
# Keep only the orthonormal factor; do not overwrite the constant `R`.
Q = np.linalg.qr(mat)[0]
U = Q[:, :Du]
W = jnp.array(np.random.randn(Du, Dx + 1))

cond = approximate_conditional.FullHCCovGaussianConditional(M=M, b=b, Sigma=Sigma, U=U, W=W)

In [None]:
num_samples = 1000
# Every draw consumes its own subkey; `key` itself is only ever split, never
# passed to a sampler (reusing a key would correlate the draws).
key, subkey = jax.random.split(key)
y = jax.random.normal(subkey, (10, Dy))
key, subkey = jax.random.split(key)
mu = jax.random.normal(subkey, (1,Dx))
key, subkey = jax.random.split(key)
rand_mat = jax.random.uniform(subkey, (Dx, Dx))
# Covariance of x: identity plus a random Gram matrix.
# Note: this rebinds `Sigma`, which held the conditional's covariance in the setup cell.
Sigma = jnp.array(jnp.eye(Dx) + jnp.dot(rand_mat, rand_mat.T))[None]
px = pdf.GaussianPDF(mu=mu, Sigma=Sigma)
# Same density repeated once per row of y.
p_x_tiled = pdf.GaussianPDF(mu=jnp.tile(mu, (y.shape[0], 1)), Sigma=jnp.tile(Sigma, (y.shape[0], 1, 1)))
key, subkey = jax.random.split(key)
x_samples = px.sample(subkey, num_samples)[:,0]

In [None]:
# Monte-Carlo estimate of E_x[ln p(y|x)], assembled term by term.
# Conditional covariance, precision and log-determinant at every x sample.
Sigma_y_x, Lambda_y_x, ln_det_Sigma_y_x = cond.get_conditional_cov(x_samples, True)
mu_y_x = cond.get_conditional_mu(x_samples)
# Residuals: one row per observation (axis 0), broadcast against the per-sample means.
y_mu = y[:,None] - mu_y_x
# -1/2 * mean over x samples of (y - mu)' Lambda (y - mu), per observation.
quadratic_term = - .5 * jnp.mean(jnp.einsum("abc,bcd,abd->ab", y_mu, Lambda_y_x, y_mu), axis=1)
log_det = ln_det_Sigma_y_x.mean()
# Gaussian log-density: quadratic term, log-determinant and normalisation constant.
sampled_log_cond = quadratic_term - .5 * log_det - .5 * Dy * jnp.log(2 * jnp.pi)


In [None]:
# Same Monte-Carlo estimate obtained through the toolbox's evaluate_ln; serves as
# a cross-check of the term-by-term computation in the previous cell.
sampled_log_cond2 = jnp.mean(cond(x_samples).evaluate_ln(y), axis=0)

In [None]:
# Evaluate the toolbox bound one observation at a time, pairing the single-component
# `px` with a single row of y.
num_obs = y.shape[0]
bounded_log_cond = jnp.empty((num_obs,))
integrate_log_conditional_y = jax.jit(cond.integrate_log_conditional_y)
for i in range(num_obs):
    # Call the jitted function created above (it was previously built but never
    # used); arguments are positional, as in the other jitted calls in this notebook.
    bounded_log_cond_i = integrate_log_conditional_y(px, y[i:i+1])
    bounded_log_cond = bounded_log_cond.at[i].set(bounded_log_cond_i[0])

In [None]:
# x-axis: term-by-term Monte-Carlo estimate. First series: evaluate_ln cross-check;
# second series: toolbox bound. Dashed: identity line over the fixed range [-5, 0].
plt.plot(sampled_log_cond, sampled_log_cond2, "o")
plt.plot(sampled_log_cond, bounded_log_cond, "o")
plt.plot([-5, 0], [-5, 0], "--")

This is not a lower bound but an upper bound.

Let's check the determinant.

In [None]:
# List the keys of px.integration_dict (presumably the integrand strings accepted
# by `integrate` — confirm against the toolbox).
px.integration_dict.keys()

In [None]:
# Bias (column 0) and weights (remaining columns) taken from cond.W.
# NOTE: these rebind the notebook-level names `b` and `w`; `k_tmp` below reads
# both as globals.
b = cond.W[:,0]
w = cond.W[:,1:]
# Square root of the px-integral of (w x + b)'(w x + b).
omega_dagger = jnp.sqrt(px.integrate("(Ax+a)'(Bx+b)", A_mat=w, a_vec=b, B_mat=w, b_vec=b))

def k_tmp(omega_dagger, px):
    """Evaluate ln cosh(omega) + tanh(omega) / (2 omega) * (E[h^2] - omega^2).

    E[h^2] is the px-integral of (w x + b)'(w x + b), with `w` and `b` read from
    the notebook-level globals.
    """
    second_moment = px.integrate("(Ax+a)'(Bx+b)", A_mat=w, a_vec=b, B_mat=w, b_vec=b)
    log_cosh = jnp.log(jnp.cosh(omega_dagger))
    half_tanh_ratio = .5 * jnp.tanh(omega_dagger) / omega_dagger
    return log_cosh + half_tanh_ratio * (second_moment - omega_dagger ** 2)

def lb_log_det_tmp(omega_dagger, px):
    """Return cond.ln_det_Sigma + k_tmp(omega_dagger, px), using the global `cond`."""
    return cond.ln_det_Sigma + k_tmp(omega_dagger, px)

In [None]:
# Functions to integrate in toolbox
def _get_omega_dagger(p_x: pdf.GaussianPDF, W_i: jnp.ndarray) -> jnp.ndarray:
    """Return sqrt of the p_x-integral of (w x + b)'(w x + b) for one row of W.

    Element 0 of `W_i` is the bias b, the remaining elements are the weights w;
    both are given a leading axis of length 1 before integrating.
    """
    bias, weights = W_i[None,:1], W_i[None,1:]
    second_moment = p_x.integrate("(Ax+a)'(Bx+b)", A_mat=weights, a_vec=bias, B_mat=weights, b_vec=bias)
    return jnp.sqrt(second_moment)

def k_func(p_x: pdf.GaussianPDF, W_i: jnp.ndarray, omega_dagger: jnp.ndarray):
    """Evaluate ln cosh(omega) + tanh(omega) / (2 omega) * (E[h^2] - omega^2).

    E[h^2] is the p_x-integral of (w x + b)'(w x + b), where b is element 0 of
    `W_i` and w the remaining elements.
    """
    bias, weights = W_i[None,:1], W_i[None,1:]
    second_moment = p_x.integrate("(Ax+a)'(Bx+b)", A_mat=weights, a_vec=bias, B_mat=weights, b_vec=bias)
    log_cosh = jnp.log(jnp.cosh(omega_dagger))
    half_tanh_ratio = .5 * jnp.tanh(omega_dagger) / omega_dagger
    return log_cosh + half_tanh_ratio * (second_moment - omega_dagger ** 2)

def get_lb_log_det(self, p_x: pdf.GaussianPDF):
    """Return self.ln_det_Sigma plus the sum over rows of W of k_func at omega_dagger.

    omega_dagger is computed per row of ``self.W`` and wrapped in stop_gradient,
    so gradients do not flow through the expansion point.
    """
    def omega_dagger_of_row(W_row):
        return _get_omega_dagger(p_x=p_x, W_i=W_row)

    def k_of_row(W_row, omega):
        return k_func(p_x=p_x, W_i=W_row, omega_dagger=omega)

    omega_dagger = lax.stop_gradient(jax.vmap(omega_dagger_of_row, in_axes=(0,))(self.W))
    k_omega = jax.vmap(k_of_row)(self.W, omega_dagger)
    return self.ln_det_Sigma + jnp.sum(k_omega, axis=0)


In [None]:
# Sweep omega from half to twice omega_dagger (100 steps) and compare the
# global-based k_tmp (solid) with the toolbox-style k_func (dashed); the marker
# shows k_tmp at omega_dagger itself.
omega_range = jnp.linspace(.5 * omega_dagger,2 * omega_dagger, 100)
plt.plot(omega_range, k_tmp(omega_range, px=px)[:,0])
plt.plot(omega_range, jax.vmap(lambda omega: k_func(p_x=px, W_i=cond.W[0], omega_dagger=omega))(omega_range)[:,0], '--')
plt.plot(omega_dagger, k_tmp(omega_dagger, px=px)[0], "o")

In [None]:
# Log-determinant term over the omega sweep (notebook implementation).
lower_bound_log_det_tmp = lb_log_det_tmp(omega_range, px=px)
# NOTE(review): this calls a `get_lb_log_det` method on `cond`, not the notebook
# function of the same name defined above — confirm the toolbox provides it.
lower_bound_log_det = cond.get_lb_log_det(p_x=px)

In [None]:
lower_bound_log_det

In [None]:
# Curve: log-determinant term over the omega sweep. Markers at omega_dagger:
# "v" = notebook get_lb_log_det, "o" = lb_log_det_tmp. Dashed horizontal line:
# the sampled mean log-determinant `log_det` computed earlier.
plt.plot(omega_range, lower_bound_log_det_tmp[:,0])
plt.plot(omega_dagger, get_lb_log_det(cond, p_x=px), "v")
plt.plot(omega_dagger, lb_log_det_tmp(omega_dagger, px), "o")
plt.hlines(log_det, omega_range[0], omega_range[-1], linestyles="--")

# Quadratic term

In [None]:
# [Du, R]
from jax import lax
# omega_dagger for every row of cond.W under the tiled density; out_axes=1 places
# the mapped (row-of-W) axis second in the result. This rebinds `omega_dagger`.
omega_dagger = jax.vmap(lambda W:_get_omega_dagger(p_x=p_x_tiled, W_i=W), out_axes=1)(cond.W)

def _update_omega_star(cond, p_x: pdf.GaussianPDF, y: jnp.ndarray, W_i: jnp.ndarray, U_i: jnp.ndarray, omega_star: jnp.ndarray):
    """Perform one update of omega_star.

    p_x is reweighted by a OneRankFactor built at the current `omega_star`
    (measure phi_1) and additionally by two LinearFactors with +w / -w
    (phi_plus / phi_minus). For an integrand I, the combination
    ``-phi_1(I) + phi_plus(I) + phi_minus(I)`` is formed once for the quartic
    integrand (h-term times projected residual term) and once for the quadratic
    integrand (projected residual term only). The new value is
    ``sqrt(quartic / quadratic)[0]``.
    """
    b = W_i[None,:1]
    w = W_i[None,1:]

    tanh_ratio = jnp.tanh(omega_star) / omega_star
    rank_one = factor.OneRankFactor(
        v=jnp.tile(w, (omega_star.shape[0], 1)),
        g=tanh_ratio,
        nu=- tanh_ratio[:,None] * b * w,
        ln_beta=- jnp.log(jnp.cosh(omega_star)) - .5 * tanh_ratio * (b[0] ** 2 - omega_star ** 2),
    )
    phi_1 = p_x.hadamard(rank_one, update_full=True)
    phi_plus = phi_1.hadamard(factor.LinearFactor(nu=w, ln_beta=b-jnp.log(2.)), update_full=True)
    phi_minus = phi_1.hadamard(factor.LinearFactor(nu=-w, ln_beta=-b-jnp.log(2.)), update_full=True)

    # Residual y - (M x + b) projected with L_inv and then onto the direction U_i.
    projected_M = jnp.einsum('acb,acd->abd', cond.L_inv, cond.M)
    projected_yb = jnp.einsum('acb,ac->ab', cond.L_inv, y - cond.b)
    U_projected_M = jnp.einsum('ab,cad->cbd', U_i[:,None], projected_M)
    U_projected_yb = jnp.einsum('ab,ca->cb', U_i[:,None], projected_yb)

    def signed_combination(integrand, **terms):
        """Return -phi_1 + phi_plus + phi_minus of the same integrand."""
        part_1 = phi_1.integrate(integrand, **terms)
        part_plus = phi_plus.integrate(integrand, **terms)
        part_minus = phi_minus.integrate(integrand, **terms)
        return - part_1 + part_plus + part_minus

    quartic_integral = signed_combination(
        "(Ax+a)'(Bx+b)(Cx+c)'(Dx+d)",
        A_mat=w, a_vec=b, B_mat=w, b_vec=b,
        C_mat=-U_projected_M, c_vec=U_projected_yb, D_mat=-U_projected_M, d_vec=U_projected_yb,
    )
    quadratic_integral = signed_combination(
        "(Ax+a)'(Bx+b)",
        A_mat=-U_projected_M, a_vec=U_projected_yb, B_mat=-U_projected_M, b_vec=U_projected_yb,
    )

    return jnp.sqrt(quartic_integral / quadratic_integral)[0]

def _get_omega_star(cond, p_x: pdf.GaussianPDF, y: jnp.ndarray, W_i: jnp.ndarray, U_i: jnp.ndarray):
    """Iterate `_update_omega_star` until successive iterates differ by at most 1e-5.

    The iteration starts at omega_dagger. The loop state is (current, previous);
    `previous` is initialised to current + 1 so that the first check passes.
    There is no cap on the number of iterations.
    """
    def not_converged(state):
        current, previous = state
        return jnp.max(jnp.abs(current - previous)) > 1e-5

    def iterate(state):
        current = state[0]
        updated = _update_omega_star(cond, p_x=p_x, y=y, W_i=W_i, U_i=U_i, omega_star=current)
        return (updated, current)

    start = _get_omega_dagger(p_x=p_x, W_i=W_i)
    omega_star, _ = lax.while_loop(not_converged, iterate, (start, start + 1.))
    return omega_star

def get_lb_heteroscedastic_term_i(cond, p_x: pdf.GaussianPDF, y: jnp.ndarray, W_i: jnp.ndarray, U_i: jnp.ndarray):
    """Return -phi_1(Q) + phi_plus(Q) + phi_minus(Q) for one (W_i, U_i) pair.

    Q is the squared residual y - (M x + b), projected with L_inv and onto U_i.
    The measures are p_x reweighted by a OneRankFactor built at the converged
    omega_star (phi_1) and additionally by LinearFactors with +w / -w. omega_star
    is wrapped in stop_gradient.
    """
    omega_star = lax.stop_gradient(_get_omega_star(cond, p_x=p_x, y=y, W_i=W_i, U_i=U_i))
    b = W_i[None,:1]
    w = W_i[None,1:]

    tanh_ratio = jnp.tanh(omega_star) / omega_star
    # NOTE(review): `_update_omega_star` uses `b[0] ** 2` in the corresponding
    # expression; here `b ** 2` keeps the leading axis of b. Confirm which is intended.
    rank_one = factor.OneRankFactor(
        v=jnp.tile(w, (omega_star.shape[0], 1)),
        g=tanh_ratio,
        nu=- tanh_ratio[:,None] * b * w,
        ln_beta=- jnp.log(jnp.cosh(omega_star)) - .5 * tanh_ratio * (b ** 2 - omega_star ** 2),
    )
    phi_1 = p_x.hadamard(rank_one, update_full=True)
    phi_plus = phi_1.hadamard(factor.LinearFactor(nu=w, ln_beta=b-jnp.log(2.)), update_full=True)
    phi_minus = phi_1.hadamard(factor.LinearFactor(nu=-w, ln_beta=-b-jnp.log(2.)), update_full=True)

    # Projected residual terms shared by all three integrals.
    projected_M = jnp.einsum('acb,acd->abd', cond.L_inv, cond.M)
    projected_yb = jnp.einsum('acb,ac->ab', cond.L_inv, y - cond.b)
    U_projected_M = jnp.einsum('ab,cad->cbd', U_i[:,None], projected_M)
    U_projected_yb = jnp.einsum('ab,ca->cb', U_i[:,None], projected_yb)
    residual_terms = dict(A_mat=-U_projected_M, a_vec=U_projected_yb, B_mat=-U_projected_M, b_vec=U_projected_yb)

    quadratic_1 = phi_1.integrate("(Ax+a)'(Bx+b)", **residual_terms)
    quadratic_plus = phi_plus.integrate("(Ax+a)'(Bx+b)", **residual_terms)
    quadratic_minus = phi_minus.integrate("(Ax+a)'(Bx+b)", **residual_terms)

    return - quadratic_1 + quadratic_plus + quadratic_minus

def get_lb_quadratic_term(cond, p_x: pdf.GaussianPDF, y: jnp.ndarray):
    """Return the homoscedastic quadratic term minus the summed heteroscedastic terms.

    Args:
        cond: Conditional providing ``L_inv``, ``M``, ``b``, ``W`` and ``U``.
        p_x: Density over x; used for both the homoscedastic and the
            heteroscedastic part.
        y: Observations.

    Returns:
        The p_x-integral of the squared projected residual, minus the sum over
        (row of W, column of U) pairs of ``get_lb_heteroscedastic_term_i``.
    """
    projected_M = jnp.einsum('acb,acd->abd', cond.L_inv, cond.M)
    projected_yb = jnp.einsum('acb,ac->ab', cond.L_inv, y - cond.b)
    homoscedastic_term = p_x.integrate("(Ax+a)'(Bx+b)", A_mat=-projected_M, a_vec=projected_yb, B_mat=-projected_M, b_vec=projected_yb)
    # Use the `p_x` argument here, not the notebook-level `p_x_tiled`, so that the
    # function honours the density it is called with.
    heteroscedastic_term = jnp.sum(
        jax.vmap(lambda W, U: get_lb_heteroscedastic_term_i(cond, p_x, y, W, U))(cond.W, cond.U.T),
        axis=0,
    )
    return homoscedastic_term - heteroscedastic_term
                   


In [None]:
get_lb_quadratic_term(cond, p_x_tiled, y)

In [None]:
def get_log_p_y(cond, p_x: pdf.GaussianPDF, y: jnp.ndarray):
    """Combine quadratic term, log-determinant term and Gaussian constant.

    Returns ``-0.5 * (quadratic + log_det + Dy * ln(2 pi))[0]``, where the two
    terms come from ``get_lb_quadratic_term`` and ``get_lb_log_det``.
    """
    lb_quadratic_term = get_lb_quadratic_term(cond, p_x, y)
    lb_log_det = get_lb_log_det(cond, p_x)
    # Take the observation dimension from y itself instead of the notebook-level
    # `Dy`, which is reassigned between sections of this notebook.
    dim_y = y.shape[-1]
    lb_log_p_y = -.5 * (lb_quadratic_term + lb_log_det + dim_y * jnp.log(2. * jnp.pi))[0]
    return lb_log_p_y

In [None]:
# Jit-compile and time the notebook implementation of the bound.
intgerate_jit_locally = jax.jit(get_log_p_y)
%timeit intgerate_jit_locally(cond, p_x_tiled, y)

In [None]:
# Jit-compile and time the toolbox implementation for comparison.
integrate_jit = jax.jit(cond.integrate_log_conditional_y)
%timeit integrate_jit(p_x_tiled, y)

In [None]:
# Monte-Carlo reference for the bound.
# NOTE: this resets `key` to PRNGKey(0) and rebinds `x_samples` without the [:,0]
# slice used earlier; the slice is applied in the next line instead.
key = jax.random.PRNGKey(0)
x_samples = px.sample(key, 1000)
log_p_y_sampled = jnp.mean(cond(x_samples[:,0]).evaluate_ln(y), axis=0)

In [None]:
# Difference between the notebook and the toolbox implementation of the bound.
intgerate_jit_locally(cond, p_x_tiled, y) - integrate_jit(p_x_tiled, y)

In [None]:
# `lb_log_p_y` / `lb_log_p_y2` were never defined at notebook level. Compute them
# from the two jitted implementations compared in the previous cell.
lb_log_p_y = intgerate_jit_locally(cond, p_x_tiled, y)
lb_log_p_y2 = integrate_jit(p_x_tiled, y)
# Both bounds plotted against the Monte-Carlo reference.
plt.plot(log_p_y_sampled, lb_log_p_y, 'o')
plt.plot(log_p_y_sampled, lb_log_p_y2, 'o')

In [None]:
cond.U.shape

In [None]:
# Visualise the iteration behind omega_star: plot the update map over a range of
# omega (solid), the identity line (dashed) and the converged omega_star (marker).
omega_star = _get_omega_star(cond, p_x_tiled, y, cond.W[0], cond.U[:,0])
omega_range = jnp.linspace(1e-3, 10., 100)[:,None]
# vmap hands one length-1 omega to each call of the update.
rhs = jax.vmap(lambda omega: _update_omega_star(cond, p_x=p_x_tiled, y=y, W_i=cond.W[0], U_i=cond.U[:,0], omega_star=omega))(omega_range)

plt.plot(omega_range, rhs, "-")
plt.plot([0,10], [0,10], "--")
plt.plot(omega_star, omega_star, "o")
plt.ylim((jnp.min(rhs), jnp.max(rhs)))

# Testing optimization

In [None]:
%load_ext autoreload
%autoreload 2

from gaussian_toolbox import pdf, approximate_conditional, factor
import numpy as np 
from jax import numpy as jnp
import jax
from matplotlib import pyplot as plt
from jax import config, lax
config.update("jax_enable_x64", True)

# Problem sizes for the optimisation test (one-dimensional x and y).
Dy = 1
Dx = 1
R = 1
Du = 1
N = 1000

# Seed NumPy's global RNG so that U and the true W are reproducible on Restart &
# Run All (M is drawn from the JAX key below).
np.random.seed(0)

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
M = jax.random.normal(subkey, shape=(1, Dy, Dx))
b = jnp.array([jnp.zeros(Dy)])
Sigma_cond = .1 * jnp.array([jnp.eye(Dy)])
mat = np.random.randn(Dy, Dy)
# Keep only the orthonormal factor; do not overwrite the constant `R`.
Q = np.linalg.qr(mat)[0]
U = Q[:, :Du]
W = jnp.array(np.random.randn(Du, Dx + 1))

cond = approximate_conditional.FullHCCovGaussianConditional(M=M, b=b, Sigma=Sigma_cond, U=U, W=W)

In [None]:
# Every draw consumes its own subkey; `key` itself is only ever split.
key, subkey = jax.random.split(key)
# Preliminary y; it is replaced by samples from the model in the next cell.
y = jax.random.normal(subkey, (10, Dy))
key, subkey = jax.random.split(key)
mu = jax.random.normal(subkey, (1,Dx))
key, subkey = jax.random.split(key)
rand_mat = jax.random.uniform(subkey, (Dx, Dx))
# Covariance of x: identity plus a random Gram matrix.
Sigma_x = 1. * jnp.array(jnp.eye(Dx) + jnp.dot(rand_mat, rand_mat.T))[None]
px = pdf.GaussianPDF(mu=mu, Sigma=Sigma_x)
key, subkey = jax.random.split(key)
x_samples = px.sample(subkey, N)[:,0]

In [None]:
key, subkey = jax.random.split(key)
# One draw of y from the true conditional at every sampled x.
y = cond(x_samples).sample(subkey, 1)[0]
# The density over x repeated once per observation.
p_x_tiled = pdf.GaussianPDF(mu=jnp.tile(mu, (y.shape[0], 1)), Sigma=jnp.tile(Sigma_x, (y.shape[0], 1, 1)))

In [None]:
# Data together with the true conditional mean and a +/- one standard deviation band.
x_range = jnp.linspace(-5, 5, 100)[:,None]
plt.plot(x_samples, y, '.')
cond_on_grid = cond(x_range)
grid_mean = cond_on_grid.mu
grid_std = jnp.sqrt(cond_on_grid.Sigma[:,0,0])
plt.plot(x_range, grid_mean)
plt.fill_between(x_range[:,0], grid_mean[:,0] - grid_std, grid_mean[:,0] + grid_std, alpha=.2)

In [None]:
def objective(params):
    """Negative sum over observations of the toolbox bound, as a function of W.

    Only `params['W']` is varied; M, b, Sigma_cond, U, p_x_tiled and y are read
    from the notebook namespace.
    """
    model = approximate_conditional.FullHCCovGaussianConditional(
        M=M, b=b, Sigma=Sigma_cond, U=U, W=params['W']
    )
    bound_per_observation = model.integrate_log_conditional_y(p_x_tiled, y)
    return -jnp.sum(bound_per_observation)

In [None]:
import jaxopt

key, subkey = jax.random.split(key)
# Conjugate-gradient minimisation of `objective`, started from a small random W.
solver = jaxopt.ScipyMinimize(method='CG', fun=objective)
res = solver.run(init_params={'W': 1e-2 * jax.random.normal(subkey, W.shape)})

In [None]:
jnp.sign(cond.W), cond.W

In [None]:
# 100 random restarts; keep the fitted parameters and final objective of each run.
# Unlike the single run above, the initial W is not scaled by 1e-2 here.
results = []
objective_results = []
for i in range(100):
    key, subkey = jax.random.split(key)
    res = solver.run(init_params={'W': jax.random.normal(subkey, W.shape)})
    results.append(res.params)
    objective_results.append(res.state.fun_val)

In [None]:
jnp.nanargmin(jnp.array(objective_results))

In [None]:
# Display the parameters of the best restart. This is looked up directly from
# `results` because `res_opt` is only assigned in the following cell.
results[jnp.nanargmin(jnp.array(objective_results))]

In [None]:
# Pick the restart with the smallest (non-NaN) objective and rebuild the fitted model.
best_restart = jnp.nanargmin(jnp.array(objective_results))
res_opt = results[best_restart]
cond_fit = approximate_conditional.FullHCCovGaussianConditional(
    M=M, b=b, Sigma=Sigma_cond, U=U, W=res_opt['W']
)

In [None]:
# Data, then mean and +/- one standard deviation band for the true and the fitted model.
x_range = jnp.linspace(-5, 5, 100)[:,None]
plt.plot(x_samples, y, '.')
for model_to_plot in (cond, cond_fit):
    model_on_grid = model_to_plot(x_range)
    grid_mean = model_on_grid.mu
    grid_std = jnp.sqrt(model_on_grid.Sigma[:,0,0])
    plt.plot(x_range, grid_mean)
    plt.fill_between(x_range[:,0], grid_mean[:,0] - grid_std, grid_mean[:,0] + grid_std, alpha=.2)


In [None]:
def objective_sample(params, key):
    """Monte-Carlo counterpart of `objective`.

    Draws 1000 samples of x from the notebook-level `px`, averages ln p(y|x) over
    them for the model with `params['W']`, and returns the negative sum over
    observations together with the advanced PRNG key.
    """
    key, sample_key = jax.random.split(key)
    model = approximate_conditional.FullHCCovGaussianConditional(
        M=M, b=b, Sigma=Sigma_cond, U=U, W=params['W']
    )
    x_draws = px.sample(sample_key, 1000)[:,0]
    mean_log_lik = jnp.mean(model(x_draws).evaluate_ln(y), axis=0)
    return -jnp.sum(mean_log_lik), key

In [None]:
# Evaluate both objectives on a 100 x 100 grid over (bias, weight) in [-4, 4]^2.
b_range = jnp.linspace(-4, 4, 100)
W_range = jnp.linspace(-4, 4, 100)
mesh = jnp.meshgrid(b_range, W_range)
# One candidate W per row: column 0 = bias, column 1 = weight.
W_mesh = jnp.stack([mesh[0].flatten(), mesh[1].flatten()], axis=-1)

objective_jit = jax.jit(objective)
objective_sample_jit = jax.jit(objective_sample)

objective_arr = []
objective_sample_arr = []
for idx in range(W_mesh.shape[0]):
    candidate = {'W': W_mesh[idx:idx+1]}
    objective_arr.append(objective_jit(candidate))
    # The sampled objective hands back the advanced key for the next grid point.
    val, key = objective_sample_jit(candidate, key)
    objective_sample_arr.append(val)

# Back to the grid layout for plotting.
objective_arr = jnp.array(objective_arr).reshape(mesh[0].shape)
objective_sample_arr = jnp.array(objective_sample_arr).reshape(mesh[0].shape)

In [None]:
# Left: bound-based objective over the (bias, weight) grid, with the best fit (red
# star) and the true W (blue square). Right: sampled objective on the same grid.
plt.figure(figsize=(10,5))
plt.subplot(1,2,1)
plt.pcolor(mesh[0], mesh[1], objective_arr)
plt.plot(res_opt['W'][0,0], res_opt['W'][0,1], 'r*')
plt.plot(W[0,0], W[0,1], 'bs')
plt.colorbar()
plt.subplot(1,2,2)
plt.pcolor(mesh[0], mesh[1], objective_sample_arr)
plt.colorbar()

In [None]:
%load_ext autoreload
%autoreload 2

from gaussian_toolbox import pdf, approximate_conditional, factor
import numpy as np 
from jax import numpy as jnp
import jax
from matplotlib import pyplot as plt
from jax import config, lax
config.update("jax_enable_x64", True)

# Problem sizes for the second optimisation test (N observations).
Dy = 1
Dx = 1
R = 1
Du = 1
N = 100

# Seed NumPy's global RNG so that U and the true W are reproducible on Restart &
# Run All (M is drawn from the JAX key below).
np.random.seed(0)

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
M = jax.random.normal(subkey, shape=(1, Dy, Dx))
b = jnp.array([jnp.zeros(Dy)])
Sigma_cond = .1 * jnp.array([jnp.eye(Dy)])
mat = np.random.randn(Dy, Dy)
# Keep only the orthonormal factor; do not overwrite the constant `R`.
Q = np.linalg.qr(mat)[0]
U = Q[:, :Du]
W = jnp.array(np.random.randn(Du, Dx + 1))

cond = approximate_conditional.FullHCCovGaussianConditional(M=M, b=b, Sigma=Sigma_cond, U=U, W=W)

In [None]:
# Every draw consumes its own subkey; `key` itself is only ever split.
key, subkey = jax.random.split(key)
# Preliminary y; it is replaced by samples from the model in the next cell.
y = jax.random.normal(subkey, (10, Dy))
key, subkey = jax.random.split(key)
# One mean per observation: N components sharing the covariance Sigma_x.
mu = jax.random.normal(subkey, (N, Dx))
key, subkey = jax.random.split(key)
rand_mat = jax.random.uniform(subkey, (Dx, Dx))
Sigma_x = 1. * jnp.array(jnp.eye(Dx) + jnp.dot(rand_mat, rand_mat.T))[None]
px = pdf.GaussianPDF(mu=mu, Sigma=Sigma_x)
key, subkey = jax.random.split(key)
# A single draw; [0] drops the sample axis.
x_samples = px.sample(subkey, 1)[0]

In [None]:
key, subkey = jax.random.split(key)
# One draw of y from the true conditional at every sampled x.
y = cond(x_samples).sample(subkey, 1)[0]

In [None]:
# Data together with the true conditional mean and a +/- one standard deviation band.
x_range = jnp.linspace(-5, 5, 100)[:,None]
plt.plot(x_samples, y, '.')
cond_on_grid = cond(x_range)
grid_mean = cond_on_grid.mu
grid_std = jnp.sqrt(cond_on_grid.Sigma[:,0,0])
plt.plot(x_range, grid_mean)
plt.fill_between(x_range[:,0], grid_mean[:,0] - grid_std, grid_mean[:,0] + grid_std, alpha=.2)

In [None]:
def objective(params):
    """Negative sum over observations of the toolbox bound, as a function of W.

    Only `params['W']` is varied; M, b, Sigma_cond, U, px and y are read from the
    notebook namespace. This redefines the `objective` of the previous section,
    using `px` instead of `p_x_tiled`.
    """
    model = approximate_conditional.FullHCCovGaussianConditional(
        M=M, b=b, Sigma=Sigma_cond, U=U, W=params['W']
    )
    bound_per_observation = model.integrate_log_conditional_y(px, y)
    return -jnp.sum(bound_per_observation)

import jaxopt

key, subkey = jax.random.split(key)
# Conjugate-gradient minimisation of `objective`, started from a small random W;
# the fitted model is rebuilt from the optimised W.
solver = jaxopt.ScipyMinimize(method='CG', fun=objective)
res = solver.run(init_params={'W': 1e-2 * jax.random.normal(subkey, W.shape)})
cond_fit = approximate_conditional.FullHCCovGaussianConditional(M=M, b=b, Sigma=Sigma_cond, U=U, W=res.params['W'])

In [None]:
# Data, +/- one standard deviation bands of the true and the fitted model, and the
# component densities of p(x) shifted down by 6 so they sit below the data.
x_range = jnp.linspace(-5, 5, 100)[:,None]
plt.plot(x_samples, y, 'k.', label='data')
for model_to_plot, band_label in ((cond, 'true'), (cond_fit, 'fit')):
    model_on_grid = model_to_plot(x_range)
    band_center = model_on_grid.mu[:,0]
    band_halfwidth = jnp.sqrt(model_on_grid.Sigma[:,0,0])
    plt.fill_between(x_range[:,0], band_center - band_halfwidth, band_center + band_halfwidth,
                     alpha=.2, label=band_label)
plt.plot(x_range, px.evaluate(x_range).T - 6, 'k', alpha=.5, label='p(x)')
plt.xlabel('x')
plt.ylabel('y')
# Explicit label list: only four legend entries are created.
plt.legend(['data', 'true', 'fit', 'p(x)'])
plt.title('High variance of x')
plt.show()


In [None]:
# Debugging leftover: inspects the shape of one slice of x_samples.
# NOTE(review): `i` is not set in this section before this cell (it is assigned
# further below); this relies on a value left over from earlier execution.
x_samples[:,i].shape

In [None]:
def objective_sample(params, key):
    """Monte-Carlo objective with one density component per observation.

    Draws 1000 samples from the notebook-level `px`; for observation i, the
    samples x_samples[:, i] are paired with y[i], ln p(y_i | x) is averaged over
    the samples, and the averages are summed over all observations.

    Args:
        params: Dict with entry 'W' (the parameters being varied).
        key: JAX PRNG key.

    Returns:
        Tuple of (negative summed log-likelihood, advanced PRNG key).
    """
    W = params['W']
    key, subkey = jax.random.split(key)
    cond_tmp = approximate_conditional.FullHCCovGaussianConditional(M=M, b=b, Sigma=Sigma_cond, U=U, W=W)
    x_samples = px.sample(subkey, 1000)
    # Loop over however many observations y holds instead of a hard-coded 100.
    num_obs = y.shape[0]

    def add_observation(i, log_likelihood):
        return log_likelihood + jnp.mean(cond_tmp(x_samples[:,i]).evaluate_ln(y[i][None]))

    # Start from a float so the loop carry has the same kind of value throughout.
    log_likelihood = lax.fori_loop(0, num_obs, add_observation, 0.)
    return -log_likelihood, key

In [None]:
y.shape

In [None]:
# Manual check of one term of `objective_sample` (observation i = 99).
# NOTE: this rebinds the notebook-level `x_samples` and `i`.
key, subkey = jax.random.split(key)
cond_tmp = approximate_conditional.FullHCCovGaussianConditional(M=M, b=b, Sigma=Sigma_cond, U=U, W=W)
x_samples = px.sample(subkey, 1000)
i = 99
jnp.mean(cond_tmp(x_samples[:,i]).evaluate_ln(y[i:i+1]))

In [None]:
# Evaluate both objectives on a 100 x 100 grid over (bias, weight) in [-4, 4]^2.
b_range = jnp.linspace(-4, 4, 100)
W_range = jnp.linspace(-4, 4, 100)
mesh = jnp.meshgrid(b_range, W_range)
# One candidate W per row: column 0 = bias, column 1 = weight.
W_mesh = jnp.stack([mesh[0].flatten(), mesh[1].flatten()], axis=-1)

objective_jit = jax.jit(objective)
objective_sample_jit = jax.jit(objective_sample)

objective_arr = []
objective_sample_arr = []
for idx in range(W_mesh.shape[0]):
    candidate = {'W': W_mesh[idx:idx+1]}
    objective_arr.append(objective_jit(candidate))
    # The sampled objective hands back the advanced key for the next grid point.
    val, key = objective_sample_jit(candidate, key)
    objective_sample_arr.append(val)

# Back to the grid layout for plotting.
objective_arr = jnp.array(objective_arr).reshape(mesh[0].shape)
objective_sample_arr = jnp.array(objective_sample_arr).reshape(mesh[0].shape)

In [None]:
# Left: bound-based objective over the (bias, weight) grid with the fitted W (red
# star) and the true W (blue square). Right: sampled objective on the same grid.
# Both panels show the objective values, i.e. the negated quantities named in the titles.
plt.figure(figsize=(10,5))
plt.subplot(1,2,1)
plt.pcolor(mesh[0], mesh[1], objective_arr)
plt.plot(res.params['W'][0,0], res.params['W'][0,1], 'r*', label='fit')
plt.plot(W[0,0], W[0,1], 'bs', label='true')
plt.legend()
plt.xlabel('b')
plt.ylabel('W')
plt.colorbar()
plt.title('lb loglikelihood')
plt.subplot(1,2,2)
plt.pcolor(mesh[0], mesh[1], objective_sample_arr)
plt.colorbar()
plt.title('sampled log likelihood')
plt.xlabel('b')

In [None]:
objective({"W": cond.W})

In [None]:
objective({"W": cond_fit.W})

In [None]:
y