# To do

* コレスキー分解 numpy.linalg.cholesky と scipy.linalg.lapack.dpotrf

* Chosolv


---

* テストコード https://www.youtube.com/watch?v=Hl8UNYrp0Vg


# 0. ライブラリのインポート & トイデータの準備

In [35]:
import numpy as np
import pandas as pd
import GPy
import pods

In [36]:
np.random.seed(seed=1)

In [37]:
data = pods.datasets.olympic_100m_men()
X, Y = data["X"], data["Y"]
X_pred = np.linspace(X[:,0].min() - 30,
                     X[:,0].max() + 30,
                     500).reshape(-1,1)

In [38]:
pd.concat([pd.DataFrame(X.T, index=["X"]), pd.DataFrame(Y.T, index=["Y"])], axis=0)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,17,18,19,20,21,22,23,24,25,26
X,1896.0,1900.0,1904.0,1906.0,1908.0,1912.0,1920.0,1924.0,1928.0,1932.0,...,1972.0,1976.0,1980.0,1984.0,1988.0,1992.0,1996.0,2000.0,2004.0,2008.0
Y,12.0,11.0,11.0,11.2,10.8,10.8,10.8,10.6,10.8,10.3,...,10.14,10.06,10.25,9.99,9.92,9.96,9.84,9.87,9.85,9.69


In [39]:
pd.DataFrame(X_pred.T, index=["X_pred"])

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,490,491,492,493,494,495,496,497,498,499
X_pred,1866.0,1866.344689,1866.689379,1867.034068,1867.378758,1867.723447,1868.068136,1868.412826,1868.757515,1869.102204,...,2034.897796,2035.242485,2035.587174,2035.931864,2036.276553,2036.621242,2036.965932,2037.310621,2037.655311,2038.0


# 1. RBFカーネル

$$k(x, x') = variance * \exp\left( -\frac{1}{2} \frac{|x - x'|^2}{lengthscale} \right)$$

In [212]:
class RBF:
    def __init__(self, variance=1., lengthscale=0.1):
        self.variance=variance
        self.lengthscale=lengthscale
        # self.r = self._euclidean_distance
        
    def K(self, X, X2=None):
        return self.variance * np.exp(-0.5 * (self._euc_dist(X, X2) / self.lengthscale)**2)
        # return self._euc_dist(X, X2)
        
    def _euc_dist(self, X, X2):
        if X2 is None:
            # print("X2 is None")
            # print(X2)
            Xsq = np.sum(np.square(X),1)
            r2 = -2.*(np.dot(X, X.T)) + (Xsq[:,None] + Xsq[None,:]) 
            r2 = np.clip(r2, 0, np.inf)
            np.fill_diagonal(r2, 0.)
            return np.sqrt(r2)
        else:
            # print(X)
            # print(X2)
            X1sq = np.sum(np.square(X),1)
            X2sq = np.sum(np.square(X2),1)
            r2 = -2.*np.dot(X, X2.T) + (X1sq[:,None] + X2sq[None,:])
            r2 = np.clip(r2, 0, np.inf)
            return np.sqrt(r2)

In [213]:
kern = RBF()
kern.variance, kern.lengthscale

(1.0, 0.1)

In [76]:
pd.DataFrame(kern.K(X, X_pred)).head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,490,491,492,493,494,495,496,497,498,499
0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [75]:
kern.K(X, X_pred).max(), kern.K(X, X_pred).min()

(0.9927971862436876, 0.0)

ハイパーパラメータの値が重要

In [77]:
kern.lengthscale = 10.
kern.lengthscale

10.0

In [79]:
pd.DataFrame(kern.K(X, X_pred)).head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,490,491,492,493,494,495,496,497,498,499
0,0.011109,0.012312,0.013629,0.015069,0.016641,0.018356,0.020223,0.022253,0.024459,0.026851,...,1.27837e-42,7.915405e-43,4.895237e-43,3.023837e-43,1.865636e-43,1.1496870000000001e-43,7.076463999999999e-44,4.350478e-44,2.6714169999999996e-44,1.638439e-44
1,0.003089,0.003471,0.003895,0.004367,0.004889,0.005468,0.006108,0.006814,0.007593,0.008452,...,3.0536190000000002e-40,1.916988e-40,1.202009e-40,7.528012e-41,4.709088e-41,2.942234e-41,1.836123e-41,1.144486e-41,7.125295e-42,4.4307719999999996e-42
2,0.000732,0.000834,0.000949,0.001078,0.001224,0.001388,0.001572,0.001778,0.002009,0.002267,...,6.215642e-38,3.9562009999999995e-38,2.515097e-38,1.597037e-38,1.0128829999999999e-38,6.416346000000001e-39,4.0597580000000005e-39,2.565645e-39,1.619485e-39,1.0210370000000001e-39
3,0.000335,0.000385,0.000441,0.000505,0.000577,0.000659,0.000751,0.000855,0.000973,0.001106,...,8.351489e-37,5.35242e-37,3.426261e-37,2.190659e-37,1.398985e-37,8.923502999999999e-38,5.685147e-38,3.617696e-38,2.2993569999999997e-38,1.4597039999999998e-38
4,0.000148,0.000171,0.000197,0.000227,0.000261,0.0003,0.000345,0.000395,0.000453,0.000518,...,1.078127e-35,6.957453e-36,4.4845049999999995e-36,2.8871059999999998e-36,1.8565e-36,1.1923699999999999e-36,7.649119e-37,4.901124e-37,3.136635e-37,2.005009e-37


In [80]:
kern.K(X, X_pred).max(), kern.K(X, X_pred).min()

(0.9999992771123355, 1.638439217006888e-44)

# 2. コレスキー分解

`custom_cholesky` 関数を作成

In [52]:
import numpy as np
from scipy import linalg
from scipy.linalg import lapack, blas
import traceback
import logging


def jitchol(A, maxtries=5):
    A = np.ascontiguousarray(A)
    L, info = lapack.dpotrf(A, lower=1)
    if info == 0:
        return L
    else:
        diagA = np.diag(A)
        if np.any(diagA <= 0.):
            raise linalg.LinAlgError("not pd: non-positive diagonal elements")
        jitter = diagA.mean() * 1e-6
        num_tries = 1
        while num_tries <= maxtries and np.isfinite(jitter):
            try:
                L = linalg.cholesky(A + np.eye(A.shape[0]) * jitter, lower=True)
                return L
            except:
                jitter *= 10
            finally:
                num_tries += 1
        raise linalg.LinAlgError("not positive definite, even with jitter.")

    try:
        raise
    except:
        logging.warning('\n'.join(['Added jitter of {:.10e}'.format(jitter),
                                   '  in ' + traceback.format_list(traceback.extract_stack(limit=3)[-2:-1])[0][2:]]))
    return L

In [135]:
def generate_non_pd_mat():    
    # Create PD matrix
    A = np.random.randn(20, 100)
    A = A.dot(A.T)
    # Compute Eigdecomp
    vals, vectors = np.linalg.eig(A)
    # Set smallest eigenval to be negative with 5 rounds worth of jitter
    vals[vals.argmin()] = 0
    default_jitter = 1e-6 * np.mean(vals)
    vals[vals.argmin()] = -default_jitter * (10 ** 3.5)
    A_corrupt = (vectors * vals).dot(vectors.T)
    return A_corrupt

In [136]:
A = generate_non_pd_mat()
np.linalg.cholesky(A)

LinAlgError: Matrix is not positive definite

In [137]:
lapack.dpotrf(A)

(array([[ 9.65081412,  0.33582105, -1.68565454, -1.14139877, -0.37228016,
         -0.17195544, -0.7136588 ,  1.37156911, -0.77773783,  0.02386796,
          0.57549029, -1.85623697, -1.63678659, -1.97397285, -0.29475409,
          0.62568007,  0.69315421, -1.35694038, -0.07766662, -1.14655859],
        [ 0.        ,  8.77726227, -1.00289126, -1.73135024, -0.30927861,
         -0.51627505,  0.49946822, -0.90580208, -0.73003844, -1.16834791,
          1.26743552,  0.76388433, -0.34926526, -0.35713272, -0.30200429,
          0.26370762,  1.18092124,  1.08678571, -1.55762634,  0.04723639],
        [ 0.        ,  0.        ,  8.26585999,  2.42290979,  1.72457824,
         -0.91897196,  1.7103156 ,  2.75443181, -1.34556061,  2.46243475,
         -1.48284766, -2.09251147, -0.89632598, -0.03501279,  0.65255423,
          2.77734342,  0.82520118, -1.91278654, -0.18600112, -0.98087357],
        [ 0.        ,  0.        ,  0.        ,  9.07474373,  0.01938929,
         -1.49210495,  1.05049838, 

In [138]:
jitchol(A)

array([[ 9.70003086,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.33411714,  8.83141325,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [-1.67710174, -0.9973907 ,  8.32565856,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [-1.13560746, -1.72117356,  2.41020951,  9.13307523,  0.        ,
         0.        ,  0.        ,  0.        ,  

In [139]:
pd.DataFrame(jitchol(A))

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
0,9.700031,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0.334117,8.831413,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,-1.677102,-0.997391,8.325659,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,-1.135607,-1.721174,2.41021,9.133075,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,-0.370391,-0.307526,1.713369,0.025768,10.915215,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5,-0.171083,-0.513176,-0.911307,-1.484499,0.028086,10.28878,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
6,-0.710038,0.496131,1.698764,1.048933,-0.812588,-1.080877,10.198181,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
7,1.36461,-0.89972,2.733164,-0.116862,0.078521,-1.269719,-0.829489,9.260613,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
8,-0.773792,-0.725861,-1.33332,0.630858,0.03653,0.336257,0.405599,1.45275,8.641236,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
9,0.023747,-1.161175,2.446331,-0.857159,0.318678,-1.4374,-1.052958,-1.175322,0.111691,9.409254,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [140]:
import numpy as np
from scipy.linalg import cholesky, LinAlgError


def custom_cholesky(A, max_tries=5):
    A = np.ascontiguousarray(A) # パフォーマンス向上 計算結果にも影響
    diag_A = np.diag(A)
    jitter = diag_A.mean() * 1e-6
    num_tries = 0
    
    try:
        L = cholesky(A, lower=True)
        return L
    except LinAlgError:
        num_tries += 1
        
    while num_tries <= max_tries and np.isfinite(jitter):
        try:
            L = cholesky(A + np.eye(A.shape[0]) * jitter, lower=True)
            return L
        except LinAlgError:
            jitter *= 10
            num_tries += 1
    
    raise LinAlgError("Matrix is not positive definite, even with jitter.")

In [141]:
for i in range(50):
    A = generate_non_pd_mat()
    B = jitchol(A)
    C = custom_cholesky(A)
    print(np.unique(B == C))

[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]


In [142]:
def np_custom_cholesky(A, max_tries=5):
    A = np.ascontiguousarray(A) # パフォーマンス向上 計算結果にも影響
    diag_A = np.diag(A)
    jitter = diag_A.mean() * 1e-6
    num_tries = 0
    
    try:
        L = np.linalg.cholesky(A)
        return L
    except LinAlgError:
        num_tries += 1
        
    while num_tries <= max_tries and np.isfinite(jitter):
        try:
            L = np.linalg.cholesky(A + np.eye(A.shape[0]) * jitter)
            return L
        except LinAlgError:
            jitter *= 10
            num_tries += 1
    
    raise LinAlgError("Matrix is not positive definite, even with jitter.")

In [205]:
for i in range(50):
    A = generate_non_pd_mat()
    # B = jitchol(A)
    C = custom_cholesky(A)
    D = np_custom_cholesky(A)
    print(np.unique(C == D))

[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]
[ True]


# 3. 連立方程式のコレスキー分解を利用した解法

$$\begin{align*}
p( y^* | X^*, \mathcal D) = \mathcal N  ( y^* | K_{N*}^T ( \sigma^2 I &+ K_{NN})^{-1} y, K_{**} - K_{N*}^T (\sigma^2 I + K_{NN})^{-1}  K_{N*})
\end{align*}$$

In [215]:
from scipy.linalg import solve_triangular

In [222]:
noise = 1.
noise

1.0

In [219]:
kern = RBF()
kern.lengthscale = 10.
kern.variance, kern.lengthscale

(1.0, 10.0)

In [226]:
K = kern.K(X, X)
print(K.shape)
pd.DataFrame(K).head()

(27, 27)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,17,18,19,20,21,22,23,24,25,26
0,1.0,0.923116,0.726149,0.606531,0.486752,0.278037,0.056135,0.019841,0.005976,0.001534,...,2.867975e-13,1.266417e-14,4.765305e-16,1.52798e-17,4.1750099999999994e-19,9.720985e-21,1.92875e-22,3.2610269999999998e-24,4.698355e-26,5.7683300000000005e-28
1,0.923116,1.0,0.923116,0.83527,0.726149,0.486752,0.135335,0.056135,0.019841,0.005976,...,5.53461e-12,2.867975e-13,1.266417e-14,4.765305e-16,1.52798e-17,4.1750099999999994e-19,9.720985e-21,1.92875e-22,3.2610269999999998e-24,4.698355e-26
2,0.726149,0.923116,1.0,0.980199,0.923116,0.726149,0.278037,0.135335,0.056135,0.019841,...,9.101471e-11,5.53461e-12,2.867975e-13,1.266417e-14,4.765305e-16,1.52798e-17,4.1750099999999994e-19,9.720985e-21,1.92875e-22,3.2610269999999998e-24
3,0.606531,0.83527,0.980199,1.0,0.980199,0.83527,0.375311,0.197899,0.088922,0.034047,...,3.475891e-10,2.289735e-11,1.285337e-12,6.148396e-14,2.506222e-15,8.705427000000001e-17,2.5767570000000002e-18,6.499348e-20,1.396944e-21,2.5585920000000003e-23
4,0.486752,0.726149,0.923116,0.980199,1.0,0.923116,0.486752,0.278037,0.135335,0.056135,...,1.275408e-09,9.101471e-11,5.53461e-12,2.867975e-13,1.266417e-14,4.765305e-16,1.52798e-17,4.1750099999999994e-19,9.720985e-21,1.92875e-22


In [232]:
L = custom_cholesky(K + noise*np.eye(K.shape[0]))
LT = L.T

print(L.shape)
display(pd.DataFrame(L).head())
print(LT.shape)
display(pd.DataFrame(LT).head())

(27, 27)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,17,18,19,20,21,22,23,24,25,26
0,1.414214,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0.652742,1.254563,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0.513465,0.468654,1.231551,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,0.428882,0.442641,0.448652,1.191151,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,0.344186,0.399729,0.453944,0.379452,1.171198,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


(27, 27)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,17,18,19,20,21,22,23,24,25,26
0,1.414214,0.652742,0.513465,0.428882,0.344186,0.196602,0.039693,0.01403,0.004226,0.001085,...,2.027965e-13,8.954917e-15,3.369579e-16,1.0804450000000001e-17,2.952178e-19,6.873774e-21,1.363832e-22,2.305894e-24,3.3222389999999997e-26,4.0788250000000005e-28
1,0.0,1.254563,0.468654,0.442641,0.399729,0.285695,0.087222,0.037445,0.013617,0.004199,...,4.30607e-12,2.239443e-13,9.919167e-15,3.742164e-16,1.2025780000000002e-17,3.2920959999999996e-19,7.677544e-21,1.52539e-22,2.5820479999999998e-24,3.723791e-26
2,0.0,0.0,1.231551,0.448652,0.453944,0.398935,0.176021,0.089792,0.038637,0.014061,...,7.217935e-11,4.405064e-12,2.2896e-13,1.01362e-14,3.82236e-16,1.227882e-17,3.3602589999999997e-19,7.83428e-21,1.5561510000000001e-22,2.633563e-24
3,0.0,0.0,0.0,1.191151,0.379452,0.374015,0.202079,0.113354,0.053518,0.021337,...,2.629496e-10,1.747725e-11,9.890255e-13,4.765648e-14,1.955488e-15,6.833448e-17,2.033782e-18,5.155526e-20,1.113184e-21,2.047407e-23
4,0.0,0.0,0.0,0.0,1.171198,0.357099,0.240473,0.148965,0.077349,0.033815,...,9.742798e-10,7.026198e-11,4.31294e-12,2.253757e-13,1.002711e-14,3.798615e-16,1.2254480000000001e-17,3.3668109999999997e-19,7.878174e-21,1.570149e-22


In [246]:
alpha = solve_triangular(LT, solve_triangular(L, Y, lower=True))

In [247]:
K_starN = kern.K(X, X_pred).T
K_starN.shape

(500, 27)

In [248]:
mu = K_starN @ alpha

In [251]:
pd.DataFrame(mu.T, index=['mu'])

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,490,491,492,493,494,495,496,497,498,499
mu,0.049292,0.054763,0.060772,0.067363,0.074585,0.082486,0.091122,0.100547,0.110821,0.122007,...,0.097011,0.088061,0.079846,0.072316,0.065422,0.059118,0.053361,0.04811,0.043326,0.038974


## ↑ GPyの計算結果と一致した

In [220]:
mu = kern.K(X, X_pred) @ 

array([[1.11089965e-02, 1.23119217e-02, 1.36289021e-02, ...,
        4.35047795e-44, 2.67141677e-44, 1.63843922e-44],
       [3.08871541e-03, 3.47069761e-03, 3.89528887e-03, ...,
        1.14448560e-41, 7.12529547e-42, 4.43077231e-42],
       [7.31802419e-04, 8.33720712e-04, 9.48705355e-04, ...,
        2.56564510e-39, 1.61948523e-39, 1.02103685e-39],
       ...,
       [1.02103685e-39, 1.61948523e-39, 2.56564510e-39, ...,
        9.48705355e-04, 8.33720712e-04, 7.31802419e-04],
       [4.43077231e-42, 7.12529547e-42, 1.14448560e-41, ...,
        3.89528887e-03, 3.47069761e-03, 3.08871541e-03],
       [1.63843922e-44, 2.67141677e-44, 4.35047795e-44, ...,
        1.36289021e-02, 1.23119217e-02, 1.11089965e-02]])

In [None]:
from scipy.linalg import solve_triangular

def _raw_predict(self, kern, Xnew, pred_var, full_cov=False):
    woodbury_vector = self.woodbury_vector
    woodbury_inv = self.woodbury_inv

    if not isinstance(Xnew, VariationalPosterior):
        Kx = kern.K(pred_var, Xnew)
        mu = np.dot(Kx.T, woodbury_vector)
        if len(mu.shape) == 1:
            mu = mu.reshape(-1, 1)
        if full_cov:
            Kxx = kern.K(Xnew)
            if woodbury_inv.ndim == 2:
                var = Kxx - np.dot(Kx.T, np.dot(woodbury_inv, Kx))
            elif woodbury_inv.ndim == 3:  # Missing data
                var = np.empty((Kxx.shape[0], Kxx.shape[1], woodbury_inv.shape[2]))
                from ...util.linalg import mdot
                for i in range(var.shape[2]):
                    var[:, :, i] = (Kxx - mdot(Kx.T, woodbury_inv[:, :, i], Kx))
            var = var
        else:
            Kxx = kern.Kdiag(Xnew)
            if woodbury_inv.ndim == 2:
                var = (Kxx - np.sum(np.dot(woodbury_inv.T, Kx) * Kx, 0))[:, None]
            elif woodbury_inv.ndim == 3:  # Missing data
                var = np.empty((Kxx.shape[0], woodbury_inv.shape[2]))
                for i in range(var.shape[1]):
                    var[:, i] = (Kxx - (np.sum(np.dot(woodbury_inv[:, :, i].T, Kx) * Kx, 0)))
            var = var
            var = np.clip(var, 1e-15, np.inf)
    return mu, var

In [None]:
_K = 
_K_chol = 
_woodbury_chol = 
_woodbury_vector = 
