# Bridge Proposals for SDEs

**Martin Lysy, University of Waterloo** 

**January 1, 2022**

## Formula for a Conditional Normal Distribution

Suppose that we have

$$
\begin{aligned}
\WW & \sim \N(\mmu_W, \SSi_W) \\
\XX \mid \WW & \sim \N(\WW + \mmu_{X|W}, \SSi_{X|W}) \\
\YY \mid \XX, \WW & \sim \N(\AA \XX, \OOm).
\end{aligned}
$$

Then

$$
\begin{bmatrix} \WW \\ \XX \\ \YY \end{bmatrix} \sim \N\left(\begin{bmatrix} \mmu_W \\ \mmu_W + \mmu_{X|W} \\ \mmu_{Y} \end{bmatrix}, \begin{bmatrix} \SSi_W & \SSi_W & \SSi_W \AA' \\
\SSi_W & \SSi_W + \SSi_{X|W} & (\SSi_W + \SSi_{X|W}) \AA' \\
\AA \SSi_W & \AA (\SSi_W + \SSi_{X|W}) & \SSi_Y\end{bmatrix} \right),
$$

where $\mmu_{Y} = \AA[\mmu_W + \mmu_{X|W}]$ and $\SSi_Y = \AA (\SSi_W + \SSi_{X|W}) \AA' + \OOm$, such that

$$
\WW \mid \YY \sim \N\left(\mmu_W + \SSi_W \AA' \SSi_Y^{-1}(\YY - \mmu_Y), \SSi_W - \SSi_W \AA' \SSi_Y^{-1} \AA \SSi_W \right).
$$

### Numerical Verification

We first verify the factorization $p(\WW, \XX, \YY) = p(\WW) p(\XX \mid \WW) p(\YY \mid \XX, \WW)$, where the RHS is the product of the three normal PDFs in the first equation and the LHS is the multivariate normal in the second equation.

In [2]:
import jax.numpy as jnp
import jax.scipy as jsp
import jax.random as random

# root PRNG key (seed 0); it is split further down to draw the random inputs
key = random.PRNGKey(0)


def var_sim(key, size):
    """
    Simulate a random `size x size` variance matrix.

    A square matrix of standard normal draws is generated from the PRNG
    `key`, and its Gram matrix (transpose times itself) is returned, which
    is symmetric and positive semi-definite by construction.
    """
    normals = random.normal(key, shape=(size, size))
    return normals.T @ normals


def mvn_bridge_pars(mu_W, Sigma_W, mu_XW, Sigma_XW, Y, A, Omega):
    """
    Mean and variance of the conditional normal distribution `p(W | Y)`.

    The underlying model is

        W          ~ N(mu_W, Sigma_W)
        X | W      ~ N(W + mu_XW, Sigma_XW)
        Y | X, W   ~ N(A X, Omega)

    Returns the tuple `(mean, variance)` with

        mean     = mu_W + Sigma_W A' Sigma_Y^{-1} (Y - mu_Y)
        variance = Sigma_W - Sigma_W A' Sigma_Y^{-1} A Sigma_W

    where `mu_Y = A (mu_W + mu_XW)` and
    `Sigma_Y = A (Sigma_W + Sigma_XW) A' + Omega`.
    """
    # marginal mean and variance of Y
    mean_Y = A @ (mu_W + mu_XW)
    var_Y = jnp.linalg.multi_dot([A, Sigma_W + Sigma_XW, A.T]) + Omega

    # cov(Y, W) = A Sigma_W
    cov_YW = A @ Sigma_W

    # put the residual (as a column) next to cov(Y, W) so that a single
    # linear solve against Sigma_Y serves both the mean and the variance
    resid_col = (Y - mean_Y)[:, None]
    rhs = jnp.hstack([resid_col, cov_YW])
    gain = cov_YW.T @ jnp.linalg.solve(var_Y, rhs)

    # first column is the mean shift, remaining columns the variance reduction
    mean_shift = jnp.squeeze(gain[:, 0])
    var_drop = gain[:, 1:]
    return mu_W + mean_shift, Sigma_W - var_drop


n_lat = 3  # number of dimensions of W and X
n_obs = 2  # number of dimensions of Y

# shorthand for the multivariate normal log-density
mvn_logpdf = jsp.stats.multivariate_normal.logpdf

# Random parameters and evaluation points, one model layer at a time.
# Each split keeps one key for later use and hands out three fresh subkeys.

# layer 1: W ~ N(mu_W, Sigma_W)
key, *subkeys = random.split(key, num=4)
mu_W = random.normal(subkeys[0], (n_lat,))
Sigma_W = var_sim(subkeys[1], n_lat)
W = random.normal(subkeys[2], (n_lat,))

# layer 2: X | W ~ N(W + mu_XW, Sigma_XW)
key, *subkeys = random.split(key, num=4)
mu_XW = random.normal(subkeys[0], (n_lat,))
Sigma_XW = var_sim(subkeys[1], n_lat)
X = random.normal(subkeys[2], (n_lat,))

# layer 3: Y | X, W ~ N(A X, Omega)
key, *subkeys = random.split(key, num=4)
A = random.normal(subkeys[0], (n_obs, n_lat))
Omega = var_sim(subkeys[1], n_obs)
Y = random.normal(subkeys[2], (n_obs,))

# joint log-density as the sum of the three conditional log-densities
lpdf1 = (
    mvn_logpdf(W, mu_W, Sigma_W)
    + mvn_logpdf(X, W + mu_XW, Sigma_XW)
    + mvn_logpdf(Y, A @ X, Omega)
)

# joint log-density from the single multivariate normal of (W, X, Y)
AS_W = A @ Sigma_W                 # cov(Y, W)
AS_XW = A @ (Sigma_W + Sigma_XW)   # cov(Y, X)
mu_Y = A @ (mu_W + mu_XW)
Sigma_Y = jnp.linalg.multi_dot([A, Sigma_W + Sigma_XW, A.T]) + Omega
mu = jnp.concatenate([mu_W, mu_W + mu_XW, mu_Y])
Sigma = jnp.block([
    [Sigma_W, Sigma_W, AS_W.T],
    [Sigma_W, Sigma_W + Sigma_XW, AS_XW.T],
    [AS_W, AS_XW, Sigma_Y],
])
lpdf2 = mvn_logpdf(jnp.concatenate([W, X, Y]), mu, Sigma)

# the two values should agree
(lpdf1, lpdf2)



(DeviceArray(-500.63470941, dtype=float64),
 DeviceArray(-500.63470941, dtype=float64))

Next, we verify that $p(\WW, \YY) = p(\YY) p(\WW \mid \YY)$, where the LHS is the multivariate normal with mean and variance subset from those of $p(\WW, \XX, \YY)$, and the RHS is the product of the conditional distribution of interest and the marginal of $\YY$, which is the bottom right corner of $p(\WW, \XX, \YY)$.

In [89]:
# shorthand for the multivariate normal log-density
mvn_logpdf = jsp.stats.multivariate_normal.logpdf

# log-density of (W, Y) via the factorization p(Y) p(W | Y)
mu_WY, Sigma_WY = mvn_bridge_pars(mu_W, Sigma_W, mu_XW, Sigma_XW, Y, A, Omega)
lpdf1 = mvn_logpdf(Y, mu_Y, Sigma_Y) + mvn_logpdf(W, mu_WY, Sigma_WY)

# log-density of (W, Y) from the joint normal of (W, X, Y): keep the W entries
# (the first n_lat) and the Y entries (those after the 2*n_lat entries of W and X)
ind = jnp.concatenate([
    jnp.arange(n_lat),
    jnp.arange(2*n_lat, 2*n_lat + n_obs),
])
lpdf2 = mvn_logpdf(jnp.concatenate([W, Y]), mu[ind], Sigma[ind][:, ind])

# the two values should agree
(lpdf1, lpdf2)

(DeviceArray(-254.89534275, dtype=float64),
 DeviceArray(-254.89534275, dtype=float64))

In [91]:
# scratch check: slicing with [:-1] keeps every element except the last
x = jnp.arange(10)
x[:-1]

DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=int64)

In [92]:
# scratch check: index -1 selects the last element
x[-1]

DeviceArray(9, dtype=int64)

In [93]:
# display the full array for comparison
x

DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int64)

In [11]:
# scratch check: stack a 1-D array on top of a 2-D array by giving x a
# leading axis (so it becomes a single row) and appending y along axis 0
x = jnp.arange(3)
y = jnp.zeros((5, 3))
jnp.append(jnp.expand_dims(x, axis=0), y, axis=0)

DeviceArray([[0., 1., 2.],
             [0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.]], dtype=float64)

In [18]:
# scratch check: stack x on top of y using a nested list with jnp.block
jnp.block([[x], [y]])

DeviceArray([[0., 1., 2.],
             [0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.]], dtype=float64)

In [21]:
# scratch check, one dimension up: prepend a (2, 3) array to a stack of four
# (2, 3) arrays by adding a leading axis to x and appending along axis 0
x = jnp.ones((2, 3))
y = jnp.zeros((4, 2, 3))
jnp.append(jnp.expand_dims(x, axis=0), y, axis=0)

DeviceArray([[[1., 1., 1.],
              [1., 1., 1.]],

             [[0., 0., 0.],
              [0., 0., 0.]],

             [[0., 0., 0.],
              [0., 0., 0.]],

             [[0., 0., 0.],
              [0., 0., 0.]],

             [[0., 0., 0.],
              [0., 0., 0.]]], dtype=float64)

In [24]:
# scratch check: tuple concatenation prepends a leading dimension to x's shape
(7,) + x.shape

(7, 2, 3)