# Gaussian Distributions Utils
> Functions to help work with gaussian distributions

In [None]:
#| hide
#| default_exp gaussian

In [None]:
from fastcore.test import *

In [None]:
import altair as alt

## Normal Parameters

In [None]:
#| export
from collections import namedtuple
from fastcore.basics import patch

### Normal

In [None]:
import torch

In [None]:
#| export
ListNormal = namedtuple('ListNormal', ['mean', 'std'])

In [None]:
#| export
Normal = namedtuple('Normal', ['mean', 'std'])

In [None]:
#| export
@patch
def __getitem__(self: ListNormal, n: int) -> Normal:
    """Get the mean and std for the nth Normal distribution in the list """
    nth_mean, nth_std = self.mean[n], self.std[n]
    return Normal(mean=nth_mean, std=nth_std)

In [None]:
#| export
@patch
def detach(self: ListNormal) -> ListNormal:
    """Call `.detach()` on both mean and std and return them as a new `ListNormal` """
    detached = (param.detach() for param in (self.mean, self.std))
    return ListNormal(*detached)

In [None]:
ListNormal(torch.rand(10), torch.rand(10))[1]

Normal(mean=tensor(0.9594), std=tensor(0.9120))

### Multivariate Normal

In [None]:
#| export
ListMNormal = namedtuple('ListMultiNormal', ['mean', 'cov'])

In [None]:
#| export
MNormal = namedtuple('MultiNormal', ['mean', 'cov'])

In [None]:
#| export
@patch
def __getitem__(self: ListMNormal, n:int
           )->MNormal:
    """Get the mean and cov for the nth multivariate Normal distribution in the list """
    # the result is a `MNormal` (mean, cov), not a univariate `Normal` (mean, std)
    return MNormal(self.mean[n], self.cov[n])
@patch
def __setitem__(self: ListMNormal, idx, value)->None:
    """Set the mean and cov at `idx`; `value` must unpack into a `(mean, cov)` pair """
    # writes into the stored `mean` and `cov` objects in place, nothing is returned
    self.mean[idx], self.cov[idx] = value

In [None]:
#| export
@patch
def detach(self: ListMNormal) -> ListMNormal:
    """Call `.detach()` on both mean and cov and return them as a new `ListMNormal` """
    detached = (param.detach() for param in (self.mean, self.cov))
    return ListMNormal(*detached)

In [None]:
ListMNormal(torch.rand(2,10), torch.rand(2,10,10))[1]

MultiNormal(mean=tensor([0.3448, 0.9602, 0.4728, 0.2222, 0.2587, 0.4202, 0.2224, 0.0836, 0.2493,
        0.8830]), cov=tensor([[3.3335e-02, 9.3149e-01, 5.8070e-01, 8.6764e-01, 6.3861e-01, 1.4773e-01,
         8.8264e-01, 5.2911e-01, 4.7929e-01, 9.1570e-01],
        [1.8673e-01, 6.1650e-02, 3.7119e-01, 4.4148e-01, 6.5621e-01, 6.7826e-01,
         2.7303e-01, 5.1167e-01, 3.1728e-01, 7.7031e-01],
        [7.2650e-01, 4.6095e-01, 4.3758e-01, 1.4045e-01, 8.8102e-01, 2.4084e-01,
         1.4318e-01, 3.7561e-01, 9.7127e-01, 1.4582e-02],
        [6.3298e-01, 6.4655e-01, 6.0602e-01, 3.4855e-01, 7.3572e-01, 8.6268e-01,
         8.5429e-01, 1.6555e-01, 6.4872e-01, 2.9136e-01],
        [4.9059e-01, 7.4218e-01, 8.3842e-01, 1.5727e-01, 8.4987e-01, 6.9178e-01,
         2.3442e-01, 7.8525e-01, 9.5179e-01, 1.1631e-01],
        [7.5024e-01, 3.6468e-01, 1.2647e-01, 8.5569e-01, 2.3933e-01, 9.0536e-01,
         5.1002e-01, 8.6443e-01, 3.7047e-01, 6.8500e-01],
        [9.8393e-01, 1.1033e-01, 5.9461e-01, 5.

## Positive Definite

The covariance matrices need to be [positive definite](https://en.wikipedia.org/wiki/Definite_matrix).
These are utility functions to check whether a matrix is positive definite and to make any matrix positive definite.

In [None]:
#| export
import pandas as pd
from torch import Tensor

In [None]:
A = torch.rand(2,3,3) # batched random matrix used for testing

#### Symmetry

In [None]:
#| export
def is_symmetric(value, atol=1e-5):
    "Check, for every matrix in `value`, whether it is close to its own transpose (absolute tolerance `atol`)"
    # element-wise comparison of each matrix with its transpose (last two dims swapped)
    elementwise = torch.isclose(value, value.mT, atol=atol)
    # reduce over the two matrix dimensions, keeping only the leading (batch) dimensions
    return elementwise.all(dim=-1).all(dim=-1)

In [None]:
is_symmetric(A)

tensor([False, False])

In [None]:
#| export
def symmetric_upto(value, start=-8):
    """Smallest exponent `exp` in `start..2` such that `value` is symmetric with `atol=10**exp`.

    The last exponent tried is returned when the matrix is not symmetric for any tolerance.
    If `start > 2` no tolerance is tried and `start` itself is returned.
    """
    exp = start  # keeps `exp` defined when the range below is empty
    # plain Python ints and a float base: `10 ** <int64 tensor>` is an integer power,
    # which cannot represent the tolerances 1e-8 ... 1e-1
    for exp in range(start, 3):
        if is_symmetric(value, atol=10.0 ** exp):
            return exp
    return exp

def symmetric_upto_batched(value, start=-8):
    "Apply `symmetric_upto` to every matrix in the batch `value`, starting from exponent `start`"
    return torch.tensor([symmetric_upto(v, start) for v in value])

In [None]:
symmetric_upto_batched(A)

tensor([0, 0])

#### is posdef

Default pytorch check (uses symmetry + cholesky decomposition)

In [None]:
#| export
def is_posdef(cov):
    "Check `cov` against the `positive_definite` constraint shipped with `torch.distributions`"
    pd_constraint = torch.distributions.constraints.positive_definite
    return pd_constraint.check(cov)

In [None]:
is_posdef(A)

tensor([False, False])

Check whether it is positive definite using the eigenvalues: a positive definite matrix has all positive eigenvalues.

In [None]:
torch.linalg.eigvalsh(A)

tensor([[-0.1517,  0.0101,  2.3280],
        [-0.5852,  0.0236,  1.5179]])

In [None]:
#| export
def is_posdef_eigv(cov):
    "Check the sign of the eigenvalues of `cov`; returns `(flag per matrix, eigenvalues)`"
    eigenvalues = torch.linalg.eigvalsh(cov)
    # NOTE: the comparison is `>= 0`, so matrices with zero eigenvalues are accepted as well
    all_non_negative = (eigenvalues >= 0).all(dim=-1)
    return all_non_negative, eigenvalues

In [None]:
is_posdef_eigv(A)

(tensor([False, False]),
 tensor([[-0.1517,  0.0101,  2.3280],
         [-0.5852,  0.0236,  1.5179]]))

### Pytorch constraint

Note that `is_posdef` and `is_posdef_eigv` can return different values; in general `is_posdef_eigv` is more tolerant.

transform any matrix $A$ into a positive definite matrix ($PD$) using the following formula

$PD = AA^T + aI$ 

where $AA^T$ is a positive semi-definite matrix and $a$ is a small positive number that is added on the diagonal to ensure that the resulting matrix is positive definite (not semi-definite)

the inverse transformation uses cholesky decomposition


Another approach would be to multiply two lower triangular matrices, but they would require a positive diagonal, which is harder to obtain; see [https://en.wikipedia.org/wiki/Definite_matrix#Cholesky_decomposition](https://en.wikipedia.org/wiki/Definite_matrix#Cholesky_decomposition)

The API is inspired by gpytorch constraints.

In [None]:
#| export
from fastcore.foundation import docs
from fastcore.basics import store_attr

In [None]:
from meteo_imp.utils import *

In [None]:
#| export
@docs
class PosDef():
    # NOTE: class and method docstrings are attached by `@docs` from `_docs` below,
    # so the methods are documented with comments only
    def __init__(self, a=1e-5):
        # `a` is the value added on the diagonal by `transform`
        store_attr()

    def transform(self, raw):
        size = raw.shape[-1]
        jitter = self.a * torch.eye(size, device=raw.device, dtype=raw.dtype)
        # raw @ raw^T plus `a` on the diagonal
        return raw @ raw.mT + jitter

    def inverse_transform(self, value):
        # Cholesky factor of `value`; the `a * I` term added by `transform` is not removed here
        factor = torch.linalg.cholesky(value)
        return factor

    _docs = {'cls_doc': "Positive Definite Constraint for PyTorch parameters",
             'transform':"transform any square matrix into a positive definite one",
             'inverse_transform': "tranform positive definite matrix into a matrix that can be back_transformed using `transform`"}

# module-level shortcut that uses the default `a`
to_posdef = PosDef().transform

In [None]:
constraint = PosDef()

posdef = constraint.transform(A)

In [None]:
A

tensor([[[0.6904, 0.6166, 0.5863],
         [0.7641, 0.8669, 0.0038],
         [0.7567, 0.8677, 0.6291]],

        [[0.3774, 0.0545, 0.4272],
         [0.5981, 0.4032, 0.8922],
         [0.3197, 0.8341, 0.1758]]])

In [None]:
posdef

tensor([[[1.2005, 1.0642, 1.4262],
         [1.0642, 1.3354, 1.3328],
         [1.4262, 1.3328, 1.7212]],

        [[0.3279, 0.6289, 0.2412],
         [0.6289, 1.3163, 0.6843],
         [0.2412, 0.6843, 0.8288]]])

In [None]:
show_as_row(is_posdef(torch.stack([posdef,A])), is_posdef_eigv(torch.stack([posdef,A])), is_symmetric(torch.stack([posdef,A])))

In [None]:
test_eq(is_posdef(posdef).all(), True)

In [None]:
constraint.inverse_transform(posdef)

tensor([[[1.0957, 0.0000, 0.0000],
         [0.9713, 0.6261, 0.0000],
         [1.3017, 0.1094, 0.1222]],

        [[0.5726, 0.0000, 0.0000],
         [1.0982, 0.3321, 0.0000],
         [0.4212, 0.6675, 0.4536]]])

In [None]:
test_close(posdef, constraint.transform(constraint.inverse_transform(posdef)), eps=2e-5)

In [None]:
symmetric_upto(posdef[0])

-8

In [None]:
reset_seed()

In [None]:
# fuzzer
def fuzz_posdef(n=1000, noise=1e-5, bs=10, range=(0,1), **kwargs):
    "Transform `bs` random `n x n` matrices with `PosDef(noise)` and report the fraction passing each check"
    low, high = range[0], range[1]
    # `torch.rand` samples are rescaled to the requested interval; `kwargs` go to `torch.rand`
    samples = torch.rand(bs, n, n, **kwargs) * (high - low) + low
    candidates = PosDef(noise).transform(samples)

    def _rate(flags):
        # fraction of the batch for which the check returned True
        return flags.sum().item() / bs

    summary = {
        'n': [n], 'noise': f"{noise:.0e}", 'range': str(range), 'n_samples': bs,
        'posdef': _rate(is_posdef(candidates)),
        'sym': _rate(is_symmetric(candidates)),
        'posdef_eigv': _rate(is_posdef_eigv(candidates)[0]),
    }
    return pd.DataFrame(summary)

In [None]:
fuzz_posdef()

Unnamed: 0,n,noise,range,n_samples,posdef,sym,posdef_eigv
0,1000,1e-05,"(0, 1)",10,0.8,1.0,0.9


In [None]:
n_min, n_max = -1, 1
A = torch.rand(2,100,100)  * (n_max-n_min) + n_min

In [None]:
is_posdef(to_posdef(A))

tensor([True, True])

In [None]:
ma = torch.tensor([[1, 7],
                   [-3, 4]])

In [None]:
is_posdef(to_posdef(ma))

tensor(True)

In [None]:
fuzz_posdef(device='cuda')

Unnamed: 0,n,noise,range,n_samples,posdef,sym,posdef_eigv
0,1000,1e-05,"(0, 1)",10,0.5,1.0,0.6


In [None]:
# %time fuzz_posdef(bs=100, device='cuda')

In [None]:
rate_posdef = pd.concat([fuzz_posdef(n, a, bs=100, range=range, device='cuda') 
               for n in [10, 100]
               for a in [1e-2, 1e-5, 1e-7]
               for range in [(-1,1),(0,1)]])

In [None]:
import altair as alt
from altair import datum

In [None]:
rate_posdef.head()

Unnamed: 0,n,noise,range,n_samples,posdef,sym,posdef_eigv
0,10,0.01,"(-1, 1)",100,1.0,1.0,1.0
0,10,0.01,"(0, 1)",100,1.0,1.0,1.0
0,10,1e-05,"(-1, 1)",100,1.0,1.0,1.0
0,10,1e-05,"(0, 1)",100,1.0,1.0,1.0
0,10,1e-07,"(-1, 1)",100,1.0,1.0,1.0


In [None]:
def _plot_var(df, var, x='n:N', row='range', column='noise:N', y_domain=(0,1), height=70, width=50):
    "Bar chart of `var` against `x` with the value printed on each bar, faceted by `row` and `column`"
    bar = alt.Chart(df).mark_bar().encode(
        x = alt.X(x), # honour the `x` argument instead of a hard-coded field
        y = alt.Y(var, scale=alt.Scale(domain=y_domain)),
        color = 'n:N',
    ).properties(height=height, width=width, ) 
    
    # value labels, shifted by `dy` pixels relative to the bar top
    text = alt.Chart(df).mark_text(dy=10, color='white').encode(
        x = alt.X(x),
        y = alt.Y(var),
        text = alt.Text(var, format=".2f")
    )
    
    # layer bars and labels first, then split the layered chart into facets
    return (bar + text).facet(
        column=column,
        row=row).properties(title=var, )
    

In [None]:
from IPython import display
import vl_convert as vlc

In [None]:
def plot_posdef_simulation(n_s, noise_s, range_s, bs=100, **kwargs):
    "Run `fuzz_posdef` for every combination of `n_s`, `noise_s` and `range_s` and display the rates as bar charts"
    # still defaults to the GPU, but a `device=` passed by the caller no longer clashes
    # with a hard-coded keyword argument
    kwargs.setdefault('device', 'cuda')
    # `rng` rather than `range`, to avoid shadowing the builtin inside the comprehension
    rate_posdef = pd.concat([fuzz_posdef(n, noise, bs=bs, range=rng, **kwargs) 
               for n in n_s for noise in noise_s for rng in range_s])
    
    vl_spec = alt.hconcat(*[_plot_var(rate_posdef, var) for var in ['posdef', 'posdef_eigv', 'sym']]).to_json()
    # workaround for bug in vegalite see https://github.com/altair-viz/altair/issues/2742
    svg = vlc.vegalite_to_svg(vl_spec, vl_version='v5.3')
    display.display(display.HTML(svg))

In [None]:
plot_posdef_simulation(n_s = [10, 100], noise_s = [1e-2, 1e-5, 1e-7], range_s = [(-1, 1)], bs=1)

In [None]:
plot_posdef_simulation(n_s = [10, 100], noise_s = [1e-2, 1e-5, 1e-7], range_s = [(-1, 1)], bs=1000)

Let's go big by using a matrix `1000x1000`

In [None]:
plot_posdef_simulation(n_s = [1000], noise_s = [1e-3, 1e-4, 1e-5], range_s = [(-.1, 1)], bs=100)

In [None]:
plot_posdef_simulation(n_s = [700], noise_s = [1e-3, 1e-4, 1e-5], range_s = [(-.1, 1)], bs=100)

In [None]:
plot_posdef_simulation(n_s = [700], noise_s = [1e-4, 1e-5], range_s = [(-.1, 1)], bs=100, dtype=torch.float64)

Check whether repeated multiplication of matrices breaks the positive definite constraint.

If $A$ and $B$ are both positive definite matrices $ABA$ is also positive definite
[https://en.wikipedia.org/wiki/Definite_matrix#Multiplication](https://en.wikipedia.org/wiki/Definite_matrix#Multiplication)

In [None]:
def fuzz_multiply(n=100, # size of matrix
                  max_mult=1000, # number of multiplications
                  noise=1e-5, # noise to add on diagonal
                  bs=10, # batch size
                  n_range=(0,1), # range of random numbers
                  **kwargs):
    """Repeatedly compute `posdef @ pd2 @ posdef` and record when each matrix stops passing `is_posdef`.

    `stop_times[i]` is the number of multiplications matrix `i` survived before its first failed check:
    `0` means the check failed right after the first multiplication, `max_mult` means it never failed.
    """
    low, high = n_range
    A = torch.rand(bs,n,n, **kwargs)  * (high-low) + low
    posdef = PosDef(noise).transform(A)
    
    B = torch.rand(bs,n,n, **kwargs)  * (high-low) + low
    pd2 = PosDef(noise).transform(B)
    
    # Start from `max_mult` so matrices that never fail are not counted as failing at step 0.
    # A float fill value keeps the tensor floating point, which `.mean()` below requires.
    stop_times = torch.full((bs,), float(max_mult), **kwargs)
    # Explicit mask of matrices that have not failed yet; 0 is a valid stop time,
    # so it cannot double as the "not stopped" marker.
    alive = torch.ones(bs, dtype=torch.bool, device=stop_times.device)
    
    for n_mult in range(max_mult):
        posdef = posdef @ pd2 @ posdef
        check = is_posdef(posdef)
        stop_times[alive & ~check] = n_mult # only the first failure is recorded
        alive = alive & check
        if not alive.any(): break # every matrix has already failed
         
    return pd.DataFrame(
        {'n': [n], 'noise': f"{noise:.0e}", 'range': str(n_range), 'n_samples': bs,
         'avg_stop': stop_times.mean().item(), 'stop_times': [stop_times.cpu().numpy()]}) 

In [None]:
fuzz_multiply()

Unnamed: 0,n,noise,range,n_samples,avg_stop,stop_times
0,100,1e-05,"(0, 1)",10,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."


In [None]:
def plot_multiply_simulation(n_s, noise_s, max_mult=1000, bs=100, **kwargs):
    "Run `fuzz_multiply` for every combination of `n_s` and `noise_s` and plot the average stop time"
    # still defaults to the GPU, but a `device=` passed by the caller no longer clashes
    # with a hard-coded keyword argument
    kwargs.setdefault('device', 'cuda')
    # `max_mult` is forwarded; it used to be accepted and then ignored
    mult = pd.concat([fuzz_multiply(n=n, max_mult=max_mult, noise=noise, bs=bs, **kwargs) 
               for n in n_s for noise in noise_s])
    
    return _plot_var(mult, var='avg_stop', x='n:N', y_domain=alt.Undefined, height=200, width=250)

In [None]:
plot_multiply_simulation(n_s=[1,2,3,10, 100], noise_s=[1e-3, 1e-4, 1e-5], bs=100, dtype=torch.float64)

torch jit is just useless ...

In [None]:
@torch.jit.script
def loop_mult_jit(A, B):
    # TorchScript benchmark: 1000 rounds of `A @ B @ A`; the result is discarded and nothing is returned
    for _step in range(1000):
        A = A @ B @ A

In [None]:
def loop_mult(A, B):
    "Benchmark helper: 1000 rounds of `A @ B @ A`; the result is discarded and nothing is returned"
    for _step in range(1000):
        A = A @ B @ A

In [None]:
%timeit loop_mult(A,A)

53.1 ms ± 1.18 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
%timeit loop_mult_jit(A,A)

51.7 ms ± 1.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


#### Check pos def

This is to help finding matrices that aren't positive definite and debug the issues.
Returns a detailed dataframe row with info about the matrix and optionally logs everything to a global object

In [None]:
#| export
from warnings import warn
from fastcore.basics import store_attr

In [None]:
#| export
class CheckPosDef():
    "Check whether matrices are positive definite, optionally keeping a log and warning on failures"
    def __init__(self,
                do_check:bool = False, # set to True to actually check matrix
                use_log:bool = True, # keep internal log
                warning:bool = True, # show a warning if a matrix is not pos def 
                ):
        store_attr()
        self.log = pd.DataFrame()
        self.extra_args = {}
    def add_args(self, **kwargs):
        """Add extra columns to the next call of `check`; a value already stored wins over a new one with the same name """
        self.extra_args = {**kwargs, **self.extra_args}
        return self
    
    def check(self,
              x: Tensor, # (batch of) square matrix
              **extra_args
             ) -> pd.DataFrame:
        "Check every matrix in `x` and return one row per matrix; returns `None` when `do_check` is False"
        if not self.do_check: return
        
        self.add_args(**extra_args)
        
        # flatten any number of leading batch dimensions, so a single matrix,
        # a batch and a batch of batches are all handled the same way
        matrices = x.reshape(-1, *x.shape[-2:])
        infos = pd.concat([self._check_matrix(m) for m in matrices])
        
        if self.use_log: self.log = pd.concat([self.log, infos])
        # use `not` rather than `~`: on a plain Python bool `~` is an integer bitwise
        # operation whose result is always truthy
        all_pd = infos['is_pd_eigv'].all() and infos['is_pd_chol'].all()
        if self.warning and not all_pd:
             warn("Matrix is not positive definite")
        
        self.extra_args = {} # extra columns apply to one call only
        return infos
    
    def _check_matrix(self,
                     x: Tensor # square matrix
                    ) -> pd.DataFrame:
        "Build the one-row report for a single matrix"
        x = x.detach().cpu().clone() # free GPU memory and ensure that there is a copy
        sym_upto = symmetric_upto(x)

        is_pd_eigv, eigv = is_posdef_eigv(x)
        is_pd_chol = torch.linalg.cholesky_ex(x).info.eq(0).all().item() # skip pytorch too strict symmetry check
        is_sym = is_symmetric(x)

        info = pd.DataFrame({
            # `.item()` stores plain Python bools rather than 0-dim tensors
            'is_pd_eigv': is_pd_eigv.item(),
            'is_pd_chol': is_pd_chol,
            'is_sym': is_sym.item(),
            'sym_upto': sym_upto,
            'eigv': [eigv.detach().numpy()],
            'matrix': [x.detach().numpy()],
            **self.extra_args
        })

        return info

In [None]:
CheckPosDef(True).check(A)

In [None]:
checker = CheckPosDef(True)

checker.check(A, my_arg="my arg") # this will be another col in the log

In [None]:
checker.log

In [None]:
checker.add_args(show="only once")
checker.check(posdef)
checker.check(A)
checker.log

In [None]:
B = torch.rand(2,3,3) # a batch of matrices

In [None]:
checker.check(B)

In [None]:
test_close(B[0] @ A, (B @ A)[0]) # example batched matrix multiplication

## Conditional Predictions

Therefore we need to compute the conditional distribution of a normal ^[https://cs.nyu.edu/~roweis/notes/gaussid.pdf eq, 5a, 5d]

$$ X = \left[\begin{array}{c} x \\ o \end{array} \right] $$

$$ p(X) = N\left(\left[ \begin{array}{c} \mu_x \\ \mu_o \end{array} \right], \left[\begin{array}{cc} \Sigma_{xx} & \Sigma_{xo} \\ \Sigma_{ox} & \Sigma_{oo} \end{array} \right]\right)$$

where $x$ is a vector of the variables that need to be predicted and $o$ is a vector of the variables that have been observed


then the conditional distribution is:

$$p(x|o) = N(\mu_x + \Sigma_{xo}\Sigma_{oo}^{-1}(o - \mu_o), \Sigma_{xx} - \Sigma_{xo}\Sigma_{oo}^{-1}\Sigma_{ox})$$

In [None]:
#| export
import torch
from torch.distributions import MultivariateNormal
from torch.linalg import cholesky
from torch import cholesky_inverse
from torch import Tensor

from fastcore.test import *
from meteo_imp.utils import *
from typing import List

In [None]:
#| export
def conditional_guassian(
                         μ: Tensor, # mean with shape `[n_vars]`
                         Σ: Tensor, # cov with shape `[n_vars, n_vars] `
                         obs: Tensor, # Observations with shape `[n_obs]`, where `n_obs = sum(mask)`
                         mask: Tensor # Boolean tensor specifying for each variable is observed (True) or not (False). Shape `[n_vars]`
                        ) -> ListMNormal: # Distribution conditioned on observations. shape `[n_vars - n_obs]`
    "Condition the Gaussian `N(μ, Σ)` on the observed variables selected by `mask`"
    assert μ.shape[0] == mask.shape[0]
    assert obs.shape[0] == sum(mask)

    hidden = ~mask # variables that are not observed

    # rows are selected first and columns second (`[rows, :][:, cols]`),
    # which keeps the dimensionality even for empty selections
    rows_x, rows_o = Σ[hidden, :], Σ[mask, :]
    Σ_xx, Σ_xo = rows_x[:, hidden], rows_x[:, mask]
    Σ_ox, Σ_oo = rows_o[:, hidden], rows_o[:, mask]

    # Σ_xo Σ_oo^{-1}, shared by the mean and the covariance update;
    # the inverse is computed from the Cholesky factor of Σ_oo
    gain = Σ_xo @ cholesky_inverse(cholesky(Σ_oo))

    mean = μ[hidden] + gain @ (obs - μ[mask])
    cov = Σ_xx - gain @ Σ_ox

    return ListMNormal(mean, cov)
    

In [None]:
# example distribution with only 2 variables
μ = torch.tensor([.5, 1.])
Σ = torch.tensor([[1., .5], [.5 ,1.]])


mask = torch.tensor([True, False]) # second variable is the observed one

obs = torch.tensor([5.]) # value of second variable

gauss_cond = conditional_guassian(μ, Σ, obs, mask)

# hardcoded values to test that the code is working, see also for alternative implementation https://python.quantecon.org/multivariate_normal.html
test_close(3.25, gauss_cond.mean.item())
test_close(.75, gauss_cond.cov.item())

### Batches

We cannot have proper batch support, or at least not in a straightforward way, as the shape of the output would be different for the different elements of the batch.

So a for-loop is used to temporarily fix the situation.

In [None]:
#| export
def cond_gaussian_batched(dist: ListMNormal,
                         obs, # this needs to have the same shape of the mask !!! 
                         mask
                         ) -> List[ListMNormal]: # lists of distributions for element in the batch
    "Apply `conditional_guassian` to every element of the batch, one at a time"
    conditioned = []
    for i in range(obs.shape[0]):
        obs_i = obs[i][mask[i]] # keep only the entries flagged as observed
        conditioned.append(conditional_guassian(dist.mean[i], dist.cov[i], obs_i, mask[i]))
    return conditioned
        

In [None]:
reset_seed(10)
mean = torch.rand(2,3) # batch
cov = to_posdef(torch.rand(2,3,3))
mask = torch.rand(2,3) > .3
obs = torch.rand(2,3)

In [None]:
# the exported function is `cond_gaussian_batched` and takes the distribution as a single `ListMNormal`
cond_gaussian_batched(ListMNormal(mean, cov), obs, mask)

In [None]:
mask.shape, obs.shape

In [None]:
assert mean.shape == mask.shape
assert mean.dim() == 2

In [None]:
obs.shape

In [None]:
mean_x = mean[~mask]
mean_o = mean[mask]

In [None]:
mask

In [None]:
mean_x

In [None]:
cov.shape

In [None]:
cov[~mask]

In [None]:
cov

In [None]:
cov[0][~mask[0], ~mask[0]]

In [None]:
cov[0][mask[0],:][:, mask[0]].shape

### Performance

analysis of the performance of inverting a positive definite matrix

Use `cholesky` decomposition and `cholesky_solve` to improve performance of matrix inversion

see the [Probabilistic machine learning course from Uni Tübingen](https://uni-tuebingen.de/en/180804), specifically the code from the [Gaussian Regression Notebook](https://uni-tuebingen.de/fileadmin/Uni_Tuebingen/Fakultaeten/MatNat/Fachbereiche/Informatik/Lehrstuehle/MethMaschLern/Probabilistic_ML/Notebook_Vorlesung_7___9/Gaussian_Linear_Regression.ipynb) for details

This is the direct implementation of the equations

In [None]:
def _conditional_guassian_base(
                         μ: Tensor, # mean with shape `[n_vars]`
                         Σ: Tensor, # cov with shape `[n_vars, n_vars] `
                         obs: Tensor, # Observations with shape `[n_obs]`, where `n_obs = sum(idx)`
                         idx: Tensor # Boolean tensor specifying for each variable is observed (True) or not (False). Shape `[n_vars]`
                        ) -> ListMNormal: # Distribution conditioned on observations
    "Reference implementation of the conditional Gaussian using a plain matrix inverse"
    μ_x = μ[~idx]
    μ_o = μ[idx]
    
    Σ_xx = Σ[~idx,:][:, ~idx]
    Σ_xo = Σ[~idx,:][:, idx]
    Σ_ox = Σ[idx,:][:, ~idx]
    Σ_oo = Σ[idx,:][:, idx]
    
    Σ_oo_inv = torch.linalg.inv(Σ_oo)
    
    mean = μ_x + Σ_xo@Σ_oo_inv@(obs - μ_o)
    cov = Σ_xx - Σ_xo@Σ_oo_inv@Σ_ox
    
    # the second field is a covariance matrix, so use the (mean, cov) tuple
    # instead of `ListNormal`, whose second field is named `std`
    return ListMNormal(mean, cov)
    

 faster version

In [None]:
n_var = 5
mean = torch.rand(n_var, dtype=torch.float64)
cov = to_posdef(torch.rand(n_var, n_var, dtype=torch.float64))
dist = MultivariateNormal(mean, cov)
idx = torch.rand(n_var, dtype=torch.float64) > .5
obs = torch.rand(n_var, dtype=torch.float64)[idx]

In [None]:
torch.linalg.inv(cov) 

In [None]:
(torch.linalg.inv(cov) - cholesky_inverse(torch.linalg.cholesky(cov))).max()

In [None]:
test_close(torch.linalg.inv(cov), cholesky_inverse(torch.linalg.cholesky(cov)), eps=1e-2)

In [None]:
reset_seed()
A = to_posdef(torch.rand(1000, 1000, dtype=torch.float64)) + torch.eye(1000) * 1e-3 # noise to ensure is positive definite

In [None]:
is_symmetric(A)

In [None]:
is_posdef(A)

In [None]:
%timeit torch.linalg.inv(A)

In [None]:
%timeit cholesky_inverse(torch.linalg.cholesky(A))

The second version is way faster

In [None]:
test_close(conditional_guassian(mean, cov, obs, idx).mean, _conditional_guassian_base(mean, cov, obs, idx).mean)

In [None]:
B = to_posdef(torch.rand(n_var, n_var, dtype=torch.float64))

In [None]:
B @ torch.inverse(cov)

In [None]:
# `cholesky_solve(B, L)` takes the right-hand side first and the Cholesky factor second and
# returns inv(cov) @ B; `cov` and `B` are symmetric, so its transpose equals `B @ inv(cov)` above
torch.cholesky_solve(B, cholesky(cov)).mT

## Helper

### cov2std

In [None]:
x = torch.stack([torch.eye(3)*i for i in  range(1,4)])

In [None]:
x

In [None]:
torch.diagonal(x, dim1=1, dim2=2)

In [None]:
#| export
def cov2std(x):
    "Convert a covariance matrix (or a batch of them) to the standard deviations on its diagonal"
    variances = x.diagonal(dim1=-2, dim2=-1) # diagonal of the last two dims
    return variances.sqrt()

## Export

In [None]:
#| hide
from nbdev import nbdev_export
nbdev_export()