# Conditional Multivariate Normal Distribution

In this notebook we will learn about the [conditional multivariate normal (MVN) distribution](https://en.wikipedia.org/wiki/Multivariate_normal_distribution). In particular, we want to estimate the expected value (or the mean) of some subset of variables given that another subset has been conditioned on. Though the notation is quasi-dense, it is not terribly difficult to produce a conditional MVN from a marginal MVN distribution. 

## Case 1, pair

* $X_0 \rightarrow X_1$

In [1]:
import numpy as np
from numpy.random import normal

In [2]:
N = 10000
x0 = normal(0, 1, N)
x1 = normal(1 + 2 * x0, 1, N)

X = np.hstack([x0.reshape(-1, 1), x1.reshape(-1, 1)])
M = np.mean(X, axis=0)
S = np.cov(X.T)

print(X.shape)
print(M.shape)
print(S.shape)
print('mean', M)
print('cov', S)

(10000, 2)
(2,)
(2, 2)
mean [5.68155884e-04 1.01312455e+00]
cov [[0.99668835 2.01313553]
 [2.01313553 5.05870431]]


In [3]:
M[0] + S[0,1] / S[1,1] * (0.5 - M[1])

-0.2036322093171853

In [4]:
M[1] + S[1,0] / S[0,0] * (0.5 - M[0])

2.021889222105834

In [5]:
S[0,0] - S[0,1] / S[1,1] * S[1,0]

0.19555145212253944

In [6]:
S[1,1] - S[1,0] / S[0,0] * S[1,0]

0.9925238697013707

## Case 2, serial

* $X_0 \rightarrow X_1 \rightarrow X_2$

In [7]:
from collections import namedtuple
from numpy.linalg import inv
import warnings

warnings.filterwarnings('ignore')
COV = namedtuple('COV', 'C11 C12 C21 C22 C22I')

def to_row_indices(indices):
    return [[i] for i in indices]

def to_col_indices(indices):
    return indices

def get_covariances(i1, i2, S):
    r = to_row_indices(i1)
    c = to_col_indices(i1)
    C11 = S[r,c]
    
    r = to_row_indices(i1)
    c = to_col_indices(i2)
    C12 = S[r,c]
    
    r = to_row_indices(i2)
    c = to_col_indices(i1)
    C21 = S[r,c]
    
    r = to_row_indices(i2)
    c = to_col_indices(i2)
    C22 = S[r,c]
    
    C22I = inv(C22)
    
    return COV(C11, C12, C21, C22, C22I)

def compute_means(a, M, C, i1, i2):
    a = np.array([2.0])
    return M[i1] + C.C12.dot(C.C22I).dot(a - M[i2])

def compute_covs(C):
    return C.C11 - C.C12.dot(C.C22I).dot(C.C21)

def update_mean(m, a, M, i1, i2):
    v = np.copy(M)
    for i, mu in zip(i1, m):
        v[i] = mu
    for i, mu in zip(i2, a):
        v[i] = mu
    return v

def update_cov(c, S, i1, i2):
    m = np.copy(S)
    rows, cols = c.shape
    for row in range(rows):
        for col in range(cols):
            m[i1[row],i1[col]] = c[row,col]
    for i in i2:
        m[i,i] = 0.01
    return m

def update_mean_cov(v, iv, M, S):
    if v is None or iv is None or len(v) == 0 or len(iv) == 0:
        return np.copy(M), np.copy(S)
    i2 = iv.copy()
    i1 = [i for i in range(S.shape[0]) if i not in i2]
    
    C = get_covariances(i1, i2, S)
    m = compute_means(v, M, C, i1, i2)
    c = compute_covs(C)
    M_u = update_mean(m, v, M, i1, i2)
    S_u = update_cov(c, S, i1, i2)
    return M_u, S_u

In [8]:
N = 10000
x0 = normal(0, 1, N)
x1 = normal(1 + 2 * x0, 1, N)
x2 = normal(1 + 2 * x1, 1, N)

X = np.hstack([x0.reshape(-1, 1), x1.reshape(-1, 1), x2.reshape(-1, 1)])
M = np.mean(X, axis=0)
S = np.cov(X.T)

print('mean', M)
print('>')
print('cov', S)
print('>')
print('corr', np.corrcoef(X.T))

mean [1.66213799e-03 9.98871604e-01 3.01395229e+00]
>
cov [[ 1.01282022  2.01247522  4.00904807]
 [ 2.01247522  4.98867722  9.94111759]
 [ 4.00904807  9.94111759 20.79087155]]
>
corr [[1.         0.89530633 0.87365252]
 [0.89530633 1.         0.97612663]
 [0.87365252 0.97612663 1.        ]]


In [9]:
M_u, S_u = update_mean_cov(np.array([2.0]), [1], M, S)

print('mean', M_u)
print('>')
print('cov', S_u)
print('>')
print('corr', np.corrcoef(np.random.multivariate_normal(M_u, S_u, N*10).T))

mean [0.40552593 2.         5.00893706]
>
cov [[ 2.00970444e-01  2.01247522e+00 -1.28410016e-03]
 [ 2.01247522e+00  1.00000000e-02  9.94111759e+00]
 [-1.28410016e-03  9.94111759e+00  9.80846857e-01]]
>
corr [[ 1.         -0.01818347  0.78173084]
 [-0.01818347  1.          0.05034101]
 [ 0.78173084  0.05034101  1.        ]]


## Case 3, diverging

* $X_0 \leftarrow X_1 \rightarrow X_2$

In [10]:
N = 10000

x1 = normal(0, 1, N)
x0 = normal(1 + 4.0 * x1, 1, N)
x2 = normal(1 + 2.0 * x1, 1, N)

X = np.hstack([x0.reshape(-1, 1), x1.reshape(-1, 1), x2.reshape(-1, 1)])
M = np.mean(X, axis=0)
S = np.cov(X.T)

print('mean', M)
print('>')
print('cov', S)
print('>')
print('corr', np.corrcoef(X.T))

mean [0.99893644 0.00140921 1.00078216]
>
cov [[17.04367065  4.01955884  8.13217597]
 [ 4.01955884  1.00495818  2.03724162]
 [ 8.13217597  2.03724162  5.12986471]]
>
corr [[1.         0.97123166 0.86970556]
 [0.97123166 1.         0.89725439]
 [0.86970556 0.89725439 1.        ]]


In [11]:
M_u, S_u = update_mean_cov(np.array([2.0]), [1], M, S)

print('mean', M_u)
print('>')
print('cov', S_u)
print('>')
print('corr', np.corrcoef(np.random.multivariate_normal(M_u, S_u, N*10).T))

mean [8.99275494 2.         5.05230632]
>
cov [[ 0.9665307   4.01955884 -0.01623531]
 [ 4.01955884  0.01        2.03724162]
 [-0.01623531  2.03724162  0.99998796]]
>
corr [[1.         0.10108122 0.56025463]
 [0.10108122 1.         0.07837791]
 [0.56025463 0.07837791 1.        ]]


## Case 4, converging

* $X_0 \rightarrow X_1 \leftarrow X_2$

In [12]:
N = 10000

x0 = normal(0, 1, N)
x2 = normal(0, 1, N)
x1 = normal(1 + 2 * x0 + 3 * x2, 1, N)

X = np.hstack([x0.reshape(-1, 1), x1.reshape(-1, 1), x2.reshape(-1, 1)])
M = np.mean(X, axis=0)
S = np.cov(X.T)

print('mean', M)
print('>')
print('cov', S)
print('>')
print('corr', np.corrcoef(X.T))

mean [-0.00160404  1.01593489 -0.00188201]
>
cov [[ 9.99862997e-01  1.98571152e+00 -2.46181590e-03]
 [ 1.98571152e+00  1.41042894e+01  3.03830414e+00]
 [-2.46181590e-03  3.03830414e+00  1.01292067e+00]]
>
corr [[ 1.          0.52877426 -0.00244623]
 [ 0.52877426  1.          0.80383687]
 [-0.00244623  0.80383687  1.        ]]


In [13]:
M_u, S_u = update_mean_cov(np.array([2.0]), [1], M, S)

print('mean', M_u)
print('>')
print('cov', S_u)
print('>')
print('corr', np.corrcoef(np.random.multivariate_normal(M_u, S_u, N*10).T))

mean [0.1369403  2.         0.21010237]
>
cov [[ 0.72029909  1.98571152 -0.4302179 ]
 [ 1.98571152  0.01        3.03830414]
 [-0.4302179   3.03830414  0.35841821]]
>
corr [[1.         0.01033491 0.53523265]
 [0.01033491 1.         0.01582778]
 [0.53523265 0.01582778 1.        ]]
