# Gaussian Distributions Utils
> tools to work with gaussian distributions

In [None]:
#| hide
#| default_exp gaussian

In [None]:
from fastcore.test import *

In [None]:
import altair as alt

## Normal Parameters

In [None]:
#| export
from collections import namedtuple
from fastcore.basics import patch

### Normal

In [None]:
import torch

In [None]:
#| export
ListNormal = namedtuple('ListNormal', ['mean', 'std'])

In [None]:
#| export
Normal = namedtuple('Normal', ['mean', 'std'])

In [None]:
#| export
@patch
def __getitem__(self: ListNormal, n:int
           )->Normal:
    """Get the nth univariate distribution of the batch as a `Normal` (its `mean` and `std`).

    Note: this replaces tuple indexing on `ListNormal`, so `obj[0]` is the first distribution,
    not the `mean` field; use `obj.mean` / `obj.std` to access the fields.
    """
    return Normal(self.mean[n], self.std[n])

In [None]:
#| export
@patch
def detach(self: ListNormal)->ListNormal:
    """Return a new `ListNormal` with both `mean` and `std` detached from the autograd graph."""
    return ListNormal(self.mean.detach(), self.std.detach())

In [None]:
ListNormal(torch.rand(10), torch.rand(10))[1]

Normal(mean=tensor(0.4322), std=tensor(0.0450))

### Multivariate Normal

In [None]:
#| export
ListMNormal = namedtuple('ListMultiNormal', ['mean', 'cov'])

In [None]:
#| export
MNormal = namedtuple('MultiNormal', ['mean', 'cov'])

In [None]:
#| export
@patch
def __getitem__(self: ListMNormal, n:int
           )->MNormal:
    """Get the mean and cov of the nth multivariate Normal in the batch, as an `MNormal`.

    Note: this replaces tuple indexing on `ListMNormal`; use `obj.mean` / `obj.cov` for the fields.
    """
    return MNormal(self.mean[n], self.cov[n])
@patch
def __setitem__(self: ListMNormal, idx, value)->None:
    """Set the mean and cov of the distribution(s) at `idx`.

    `value` is unpacked as `(mean, cov)` (e.g. an `MNormal`) and written into the existing
    `mean` and `cov` tensors, which are therefore modified in place.
    """
    self.mean[idx], self.cov[idx] = value

In [None]:
#| export
@patch
def detach(self: ListMNormal)->ListMNormal:
    """Return a copy of this batch of multivariate Normals whose `mean` and `cov` are detached from autograd."""
    detached_fields = (field.detach() for field in (self.mean, self.cov))
    return ListMNormal(*detached_fields)

In [None]:
ListMNormal(torch.rand(2,10), torch.rand(2,10,10))[1]

MultiNormal(mean=tensor([0.0229, 0.4637, 0.1999, 0.9229, 0.1036, 0.5479, 0.4829, 0.2972, 0.9683,
        0.9519]), cov=tensor([[0.0374, 0.5933, 0.5046, 0.2290, 0.7677, 0.5853, 0.0074, 0.3667, 0.6025,
         0.8323],
        [0.1473, 0.9688, 0.2277, 0.1073, 0.8594, 0.6005, 0.5353, 0.9306, 0.1585,
         0.5595],
        [0.4591, 0.4949, 0.6432, 0.0630, 0.6998, 0.2254, 0.9016, 0.7314, 0.9867,
         0.1340],
        [0.9946, 0.6823, 0.8888, 0.3387, 0.9628, 0.6730, 0.4966, 0.8522, 0.3186,
         0.3757],
        [0.4234, 0.1553, 0.3145, 0.0338, 0.3227, 0.2926, 0.5984, 0.6980, 0.2820,
         0.8597],
        [0.3834, 0.1146, 0.0706, 0.0280, 0.0724, 0.6593, 0.0788, 0.0611, 0.8379,
         0.8059],
        [0.5841, 0.7931, 0.8659, 0.2980, 0.7672, 0.9661, 0.6043, 0.1415, 0.9610,
         0.4112],
        [0.6061, 0.2894, 0.0213, 0.1567, 0.8274, 0.6500, 0.9066, 0.3175, 0.2687,
         0.3781],
        [0.1840, 0.7383, 0.7131, 0.8098, 0.1568, 0.8432, 0.5871, 0.9134, 0.2612,
        

## Positive Definite

The covariance matrices need to be [positive definite](https://en.wikipedia.org/wiki/Definite_matrix).
These are utility functions to check if a matrix is positive definite and to make any matrix positive definite.

#### Other libraries

Most libraries that implement Kalman Filters use manually specified parameters, which often don't have the issue of the positive definite constraint (eg. `pykalman`)

From `statsmodels` statespace models:
>Cholesky decomposition [...] requires that the matrix be positive definite. While this
          should generally be true, it may not be in every case. [source](https://www.statsmodels.org/stable/generated/statsmodels.tsa.statespace.kalman_filter.KalmanFilter.set_inversion_method.html#statsmodels.tsa.statespace.kalman_filter.KalmanFilter.set_inversion_method)

which suggests that they account for the fact that the matrices computed during the filter may not be positive definite.



In [None]:
#| export
import pandas as pd
from torch import Tensor

In [None]:
A = torch.rand(2,3,3) # batched random matrix used for testing

#### Symmetry

In [None]:
#| export
def is_symmetric(value, atol=1e-5):
    "Check whether `value` (a square matrix or a batch of them) equals its transpose within `atol`, reducing over the last two dims"
    elementwise_close = torch.isclose(value, value.mT, atol=atol)
    return elementwise_close.flatten(start_dim=-2).all(dim=-1)

In [None]:
is_symmetric(A)

tensor([False, False])

In [None]:
#| export
def symmetric_upto(value, start=-8):
    """Return the smallest exponent `e` in `start..2` such that `value` is symmetric with `atol=10**e`.

    `value` must be a single matrix, because `is_symmetric` has to return a single boolean.
    If no tolerance in the range works, the last exponent tried is returned.
    """
    exp = start  # defined even if the range below is empty
    for exp in range(start, 3):
        # float base: with an integer tensor exponent, `10**exp` uses integer pow, which
        # truncates every negative exponent to 0 (i.e. atol=0)
        if is_symmetric(value, atol=10.0**exp):
            return exp
    return exp

def symmetric_upto_batched(value, start=-8):
    "Apply `symmetric_upto` (with the same `start`) to every matrix in the batch `value`"
    return torch.tensor([symmetric_upto(v, start=start) for v in value])

In [None]:
symmetric_upto_batched(A)

tensor([0, 0])

#### is posdef

Default pytorch check (uses symmetry + cholesky decomposition)

In [None]:
#| export
def is_posdef(cov):
    "Check positive definiteness of `cov` (or of each matrix in a batch) with PyTorch's `positive_definite` constraint (symmetry + Cholesky)"
    pd_constraint = torch.distributions.constraints.positive_definite
    return pd_constraint.check(cov)

In [None]:
is_posdef(A)

tensor([False, False])

Check if it is positive definite using eigenvalues: a positive definite matrix has all (strictly) positive eigenvalues.

In [None]:
torch.linalg.eigvalsh(A)

tensor([[ 0.1256,  0.4735,  1.0205],
        [-0.4617,  0.2060,  1.6219]])

In [None]:
#| export
def is_posdef_eigv(cov):
    """Check positive definiteness of `cov` (or of each matrix in a batch) through its eigenvalues.

    A symmetric matrix is positive definite iff all its eigenvalues are strictly positive.
    `torch.linalg.eigvalsh` only reads the lower triangle, so symmetry is *not* checked here.

    Returns a tuple `(is_pd, eigv)`: a boolean (one per matrix for batched input) and the eigenvalues.
    """
    eigv = torch.linalg.eigvalsh(cov)
    # strictly greater: a zero eigenvalue means the matrix is only positive *semi*-definite
    return eigv.gt(0).all(-1), eigv

In [None]:
is_posdef_eigv(A)

(tensor([ True, False]),
 tensor([[ 0.1256,  0.4735,  1.0205],
         [-0.4617,  0.2060,  1.6219]]))

### Pytorch constraint

Note that `is_posdef` and `is_posdef_eigv` can return different values; in general `is_posdef_eigv` is more tolerant (for example, it does not check symmetry).

Transform any matrix $A$ into a positive definite matrix ($PD$) using the following formula:

$PD = AA^T + aI$ 

where $AA^T$ is a positive semi-definite matrix and $a$ is a small positive number that is added on the diagonal to ensure that the resulting matrix is positive definite (not semi-definite)

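The inverse transformation uses the Cholesky decomposition $PD = LL^T$ and returns $L$. The $aI$ term is not removed, so the round trip is exact only up to $a$.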


Another approach would be to multiply a lower triangular matrix by its transpose ($LL^T$). However, that requires a positive diagonal, which is harder to obtain; see [https://en.wikipedia.org/wiki/Definite_matrix#Cholesky_decomposition](https://en.wikipedia.org/wiki/Definite_matrix#Cholesky_decomposition).

The API is inspired by gpytorch constraints.

In [None]:
#| export
from fastcore.foundation import docs
from fastcore.basics import store_attr

In [None]:
from meteo_imp.utils import *

In [None]:
#| export
@docs
class PosDef(): 
    def __init__(self, noise=1e-5): store_attr()
    def transform(self,raw):
        # raw of shape (..., n, m) gives raw @ raw.mT of shape (..., n, n), so the identity
        # must be n x n (shape[-2]); for square inputs this equals shape[-1]
        n = raw.shape[-2]
        return raw @ raw.mT + self.noise * torch.eye(n, device=raw.device, dtype=raw.dtype)
    def inverse_transform(self, value): return torch.linalg.cholesky(value)
    
    _docs = {'cls_doc': "Positive Definite Constraint for PyTorch parameters: `transform(raw) = raw @ raw.mT + noise * I`",
             'transform':"transform any (batch of) matrix of shape `(..., n, m)` into a positive definite `(..., n, n)` matrix",
             'inverse_transform': "transform a positive definite matrix into its lower Cholesky factor, which can be back-transformed using `transform` (up to the added `noise`)"}

to_posdef = PosDef().transform

In [None]:
constraint = PosDef()

posdef = constraint.transform(A)

In [None]:
constraint.noise

1e-05

In [None]:
A

tensor([[[0.9268, 0.5991, 0.2592],
         [0.0158, 0.4280, 0.4779],
         [0.2563, 0.1366, 0.2649]],

        [[0.7287, 0.4174, 0.5989],
         [0.8998, 0.5918, 0.8489],
         [0.0034, 0.4378, 0.0456]]])

In [None]:
posdef

tensor([[[1.2850, 0.3949, 0.3880],
         [0.3949, 0.4118, 0.1891],
         [0.3880, 0.1891, 0.1546]],

        [[1.0639, 1.4111, 0.2125],
         [1.4111, 1.8805, 0.3009],
         [0.2125, 0.3009, 0.1938]]])

In [None]:
show_as_row(is_posdef(torch.stack([posdef,A])), is_posdef_eigv(torch.stack([posdef,A])), is_symmetric(torch.stack([posdef,A])))

In [None]:
test_eq(is_posdef(posdef).all(), True)

In [None]:
constraint.inverse_transform(posdef)

tensor([[[1.1336, 0.0000, 0.0000],
         [0.3483, 0.5390, 0.0000],
         [0.3423, 0.1296, 0.1434]],

        [[1.0315, 0.0000, 0.0000],
         [1.3681, 0.0944, 0.0000],
         [0.2060, 0.2012, 0.3330]]])

In [None]:
test_close(posdef, constraint.transform(constraint.inverse_transform(posdef)), eps=2e-5)

In [None]:
symmetric_upto(posdef[0])

-8

In [None]:
reset_seed()

### Fuzzer

In [None]:
run_fuzzer = True # set to False to skip the slow fuzzer plots in `plot_posdef_simulation`

In [None]:
def random_posdef(bs=10,n=100,noise=1e-5,n_range=(0,1), **kwargs):
    "Sample `bs` random `n x n` matrices uniformly in `n_range` and make them positive definite with `PosDef(noise)`; `kwargs` go to `torch.rand`"
    low, high = n_range[0], n_range[1]
    raw = low + (high - low) * torch.rand(bs, n, n, **kwargs)
    return PosDef(noise).transform(raw)

In [None]:
# fuzzer
def fuzz_posdef(bs=10,n=100,noise=1e-5,n_range=(0,1), **kwargs):
    """Generate `bs` random positive definite matrices and report which fraction passes each check.

    Returns a one-row DataFrame with the generation parameters and the fraction of matrices
    passing `is_posdef`, `is_symmetric` and `is_posdef_eigv`.
    `kwargs` (e.g. `device`, `dtype`) are forwarded to `torch.rand`.
    """
    # forward `n_range` so the matrices are sampled in the range reported in the `range` column
    posdef = random_posdef(bs, n, noise, n_range, **kwargs)
    return pd.DataFrame(
        {'n': [n], 'noise': f"{noise:.0e}", 'range': str(n_range), 'n_samples': bs,
         'posdef': is_posdef(posdef).sum().item() / bs,
         'sym': is_symmetric(posdef).sum().item() / bs,
         'posdef_eigv': is_posdef_eigv(posdef)[0].sum().item() / bs
    })

In [None]:
fuzz_posdef()

Unnamed: 0,n,noise,range,n_samples,posdef,sym,posdef_eigv
0,100,1e-05,"(0, 1)",10,1.0,1.0,1.0


In [None]:
n_min, n_max = -1, 1
A = torch.rand(2,100,100)  * (n_max-n_min) + n_min

In [None]:
is_posdef(to_posdef(A))

tensor([True, True])

In [None]:
ma = torch.tensor([[1, 7],
                   [-3, 4]])

In [None]:
is_posdef(to_posdef(ma))

tensor(True)

In [None]:
fuzz_posdef(device='cuda')

Unnamed: 0,n,noise,range,n_samples,posdef,sym,posdef_eigv
0,100,1e-05,"(0, 1)",10,1.0,1.0,1.0


In [None]:
# %time fuzz_posdef(bs=100, device='cuda')

In [None]:
rate_posdef = pd.concat([fuzz_posdef(n=n, noise=noise, bs=100, n_range=n_range, device='cuda') 
               for n in [10, 100]
               for noise in [1e-2, 1e-5, 1e-7]
               for n_range in [(-1,1),(0,1)]])

In [None]:
import altair as alt
from altair import datum

In [None]:
rate_posdef.head()

Unnamed: 0,n,noise,range,n_samples,posdef,sym,posdef_eigv
0,10,0.01,"(-1, 1)",100,1.0,1.0,1.0
0,10,0.01,"(0, 1)",100,1.0,1.0,1.0
0,10,1e-05,"(-1, 1)",100,1.0,1.0,1.0
0,10,1e-05,"(0, 1)",100,1.0,1.0,1.0
0,10,1e-07,"(-1, 1)",100,1.0,1.0,1.0


In [None]:
def _plot_var(df, var, x='n:N', row='range', column='noise:N', y_domain=(0,1), height=70, width=50):
    """Faceted bar chart of `var` against `x`, with the value printed on each bar.

    Facets by `row` and `column`; the y axis of the bars is fixed to `y_domain`.
    """
    bar = alt.Chart(df).mark_bar().encode(
        x = alt.X(x),
        y = alt.Y(var, scale=alt.Scale(domain=y_domain)),
        color = x,
    ).properties(height=height, width=width)

    # label each bar with its value (2 decimals)
    text = alt.Chart(df).mark_text(dy=10, color='white').encode(
        x = alt.X(x),
        y = alt.Y(var),
        text = alt.Text(var, format=".2f")
    )

    return (bar + text).facet(
        column=column,
        row=row).properties(title=var)
    

In [None]:
def _plot_var_box(df, var, x='n:N', row='range', column='noise:N', height=70, width=50, title=''):
    "Faceted boxplot of `var` grouped (and colored) by `x`, with facets by `row` and `column`"
    encoding = dict(x=alt.X(x), y=alt.Y(var), color=x)
    box = (alt.Chart(df)
           .mark_boxplot()
           .encode(**encoding)
           .properties(height=height, width=width))
    return box.facet(column=column, row=row).properties(title=title)


In [None]:
from IPython import display
import vl_convert as vlc
from functools import partial

#### Generation of Random positive definite matrices 

In [None]:
def plot_posdef_simulation(n_s, noise_s, range_s, bs=100, **kwargs):
    "Run `fuzz_posdef` on the GPU for every (size, noise, range) combination and display bar charts of the pass rates"
    if not run_fuzzer:
        return
    runs = []
    for n in n_s:
        for noise in noise_s:
            for n_range in range_s:
                runs.append(fuzz_posdef(n=n, noise=noise, bs=bs, n_range=n_range, device='cuda', **kwargs))
    rate_posdef = pd.concat(runs)

    charts = [_plot_var(rate_posdef, var) for var in ('posdef', 'posdef_eigv')]
    vl_spec = alt.hconcat(*charts).to_json()
    # render to SVG with vl_convert to work around https://github.com/altair-viz/altair/issues/2742
    svg = vlc.vegalite_to_svg(vl_spec, vl_version='v5.3')
    display.display(display.HTML(svg))

In [None]:
plot_posdef_simulation(n_s = [10, 100], noise_s = [1e-2, 1e-5, 1e-7], range_s = [(-1, 1)], bs=1000)

Let's go big by using a matrix `1000x1000`

In [None]:
plot_posdef_simulation(n_s = [1000], noise_s = [1e-3, 1e-4, 1e-5], range_s = [(-.1, 1)], bs=100)

With the standard noise on the diagonal, less than half of the random `1000x1000` matrices are positive definite.

Let's have a look at one of such matrices

In [None]:
posdef = random_posdef(100, 1000)
not_pd = posdef[torch.argwhere(~is_posdef_eigv(posdef)[0])[0]]

This should be positive definite but actually it's not ...

In [None]:
not_pd

tensor([[[335.6587, 255.2734, 259.6248,  ..., 246.4368, 257.4700, 255.6423],
         [255.2734, 339.6811, 265.9074,  ..., 248.7385, 259.2237, 258.6045],
         [259.6248, 265.9074, 348.2930,  ..., 256.4056, 264.2218, 258.5888],
         ...,
         [246.4368, 248.7385, 256.4056,  ..., 322.1650, 248.4815, 249.9758],
         [257.4700, 259.2237, 264.2218,  ..., 248.4815, 336.6406, 257.6953],
         [255.6423, 258.6045, 258.5888,  ..., 249.9758, 257.6953, 341.0876]]])

Trying with `float64` (because of GPU memory constraints, only using `700x700` matrices):

In [None]:
plot_posdef_simulation(n_s = [700], noise_s = [1e-3, 1e-4, 1e-5], range_s = [(-.1, 1)], bs=100)

In [None]:
plot_posdef_simulation(n_s = [700], noise_s = [1e-4, 1e-5], range_s = [(-.1, 1)], bs=100, dtype=torch.float64)

All matrices are now positive definite.

#### Multiplication

Check whether multiplying matrices breaks the positive definite constraint.

If $A$ and $B$ are both positive definite matrices, $ABA$ is also positive definite (the fuzzer below repeatedly computes $BAB$).
[https://en.wikipedia.org/wiki/Definite_matrix#Multiplication](https://en.wikipedia.org/wiki/Definite_matrix#Multiplication)

In [None]:
def fuzz_op(op, # operation that takes 2 pos def matrices and return one pos def matrix
            fn_check = is_posdef,
                  n=100, # size of matrix
                  max_t=1000, # number of multiplications
                  noise=1e-5, # noise to add on diagonal
                  bs=10, # batch size
                  n_range=(0,1), # range of random numbers
                  **kwargs):
    """Repeatedly apply `pd1 = op(pd1, pd2)` to random positive definite matrices and record,
    for each matrix in the batch, the first iteration at which `fn_check` fails.

    Matrices that never fail get the last iteration index. The loop stops early once every
    matrix fails the check at the same iteration. `kwargs` (e.g. `device`, `dtype`) are
    forwarded to the tensor constructors.
    Returns a one-row DataFrame with the parameters and the stop-time statistics.
    """
    pd1 = random_posdef(bs, n, noise, n_range, **kwargs)
    pd2 = random_posdef(bs, n, noise, n_range, **kwargs)
    stop_times = torch.zeros(bs, **kwargs)
    # explicit mask of matrices that already failed: using `stop_times == 0` as "not failed"
    # would mislabel failures at t=0 and overwrite them at the end
    failed = torch.zeros(bs, dtype=torch.bool, device=stop_times.device)

    t = torch.tensor(0)  # defined even if max_t == 0 and the loop does not run
    for t in torch.arange(max_t):
        pd1 = op(pd1, pd2)
        check = fn_check(pd1)
        newly_failed = ~check & ~failed
        stop_times[newly_failed] = t
        failed |= newly_failed
        if not check.any(): break

    stop_times[~failed] = t
    return pd.DataFrame(
        {'n': [n], 'noise': f"{noise:.0e}", 'range': str(n_range), 'n_samples': bs, 'last_t': t.item(),
         'mean_stop': stop_times.mean().item(),
         'std_stop': stop_times.std().item(),
         'stop_times': [stop_times.cpu().numpy()]})

In [None]:
# fuzz repeated multiplication `pd2 @ pd1 @ pd2`, which should keep pd1 positive definite
fuzz_multiply = partial(fuzz_op, lambda pd1, pd2: pd2 @ pd1 @ pd2)
# same operation, but checked with the eigenvalue-based `is_posdef_eigv` instead of `is_posdef`
fuzz_multiply_eigv = partial(fuzz_multiply, fn_check = lambda pd1: is_posdef_eigv(pd1)[0])

In [None]:
def plot_multiply_simulation(n_s, noise_s, max_mult=1000, bs=100, **kwargs):
    """Fuzz repeated multiplication on the GPU and plot when matrices stop being positive definite.

    Two boxplots of the stop times are shown side by side: one checked with `is_posdef`
    and one with `is_posdef_eigv`. `max_mult` is the maximum number of multiplications.
    Returns the two DataFrames `(mult, mult_eigv)`, exploded to one row per matrix.
    """
    def _run(fuzz_fn):
        # one fuzz run per (size, noise) combination, one row per matrix
        return pd.concat([fuzz_fn(n=n, noise=noise, bs=bs, max_t=max_mult, device='cuda', **kwargs)
                          for n in n_s for noise in noise_s]).explode('stop_times')

    mult = _run(fuzz_multiply)
    mult_eigv = _run(fuzz_multiply_eigv)  # eigenvalue-based check

    charts = [_plot_var_box(df, 'stop_times', title=title)
              for df, title in [(mult, 'is_posdef'), (mult_eigv, 'is_posdef_eigv')]]
    vl_spec = alt.hconcat(*charts).to_json()
    # workaround for bug in vegalite see https://github.com/altair-viz/altair/issues/2742
    svg = vlc.vegalite_to_svg(vl_spec, vl_version='v5.3')
    display.display(display.HTML(svg))
    return (mult, mult_eigv)

In [None]:
plot_multiply_simulation(n_s=[2,3,10, 100], noise_s=[1e-3, 1e-4, 1e-5], bs=100);

#### Addition

Check whether adding matrices breaks the positive definite constraint.

If $A$ and $B$ are both positive definite matrices $A+B$ is also positive definite
[https://en.wikipedia.org/wiki/Definite_matrix#Addition](https://en.wikipedia.org/wiki/Definite_matrix#Addition)

In [None]:
pd1 = random_posdef(10, 100)
pd2 = random_posdef(10, 100)

In [None]:
is_posdef(pd1 + pd2).all()

tensor(True)

In [None]:
# fuzz repeated addition `pd1 + pd2`; a sum of positive definite matrices is positive definite
fuzz_add = partial(fuzz_op, lambda pd1, pd2: pd1 + pd2)

In [None]:
%time fuzz_add(max_t=1e5, device='cuda')

CPU times: user 5.06 s, sys: 43.7 ms, total: 5.1 s
Wall time: 4.11 s


Unnamed: 0,n,noise,range,n_samples,last_t,mean_stop,std_stop,stop_times
0,100,1e-05,"(0, 1)",10,12471.0,8596.5,2447.240234,"[12417.0, 9269.0, 12471.0, 6977.0, 8576.0, 748..."


In [None]:
def plot_add_simulation(n_s, noise_s, max_ts=[1000], bs=100, **kwargs):
    "Fuzz repeated addition on the GPU for every (size, noise, max iterations) combination and display a boxplot of the stop times"
    results = [fuzz_add(n=n, noise=noise, bs=bs, max_t=max_t, device='cuda', **kwargs)
               for n in n_s
               for noise in noise_s
               for max_t in max_ts]
    add = pd.concat(results).explode('stop_times')

    chart = _plot_var_box(add, var='stop_times', height=150, width=150)
    svg = vlc.vegalite_to_svg(chart.to_json(), vl_version='v5.3')
    display.display(display.HTML(svg))

In [None]:
cache_disk("add_plot")(lambda: plot_add_simulation(n_s=[50, 100, 150], noise_s=[1e-3, 1e-4, 1e-5], bs=100, max_ts=[1e5]))()

#### Numpy posdef

In [None]:
import numpy as np

In [None]:
arr = np.random.rand(2,3,3)

In [None]:
arr.shape

In [None]:
arr.transpose(0,2,1) == np.moveaxis(arr, -1, -2)

In [None]:
def to_posdef_np(x, noise=1e-5):
    """NumPy version of `PosDef.transform`: `x @ x^T + noise * I` for a matrix or a batch of matrices `x`.

    The identity uses `x`'s own dtype, so float32 inputs stay float32.
    """
    # x of shape (..., n, m) gives x @ x^T of shape (..., n, n): the identity is n x n
    return x @ np.moveaxis(x, -1, -2) + (noise * np.eye(x.shape[-2], dtype=x.dtype))

In [None]:
to_posdef_np(arr)

In [None]:
# fuzzer
def fuzz_posdef_np(n=100, noise=1e-5, bs=10, range=(0,1), dtype=np.float32):
    "NumPy counterpart of `fuzz_posdef`: build `bs` matrices with `to_posdef_np` and report the fraction passing each check"
    lo, hi = range[0], range[1]
    raw = np.random.rand(bs, n, n).astype(dtype) * (hi - lo) + lo
    posdef = torch.from_numpy(to_posdef_np(raw, noise))
    fractions = {
        'posdef': is_posdef(posdef).sum().item() / bs,
        'sym': is_symmetric(posdef).sum().item() / bs,
        'posdef_eigv': is_posdef_eigv(posdef)[0].sum().item() / bs,
    }
    return pd.DataFrame({'n': [n], 'noise': f"{noise:.0e}", 'range': str(range), 'n_samples': bs, **fractions})

In [None]:
fuzz_posdef_np(n=1000, dtype=np.float32)

### Checker Positive Definite

This is to help finding matrices that aren't positive definite and debug the issues.
Returns a detailed dataframe row with info about the matrix and optionally logs everything to a global object

In [None]:
#| export
from warnings import warn
from fastcore.basics import store_attr

In [None]:
#| export
class CheckPosDef():
    """Debug helper that checks whether matrices are positive definite.

    `check` returns one DataFrame row per matrix: the eigenvalue and Cholesky checks, symmetry
    info, the eigenvalues and a CPU copy of the matrix. It optionally appends the rows to
    `self.log` and warns if any matrix is not positive definite.
    """
    def __init__(self,
                do_check:bool = False, # set to True to actually check matrix
                use_log:bool = True, # keep internal log
                warning:bool = True, # show a warning if a matrix is not pos def 
                ):
        store_attr()
        self.log = pd.DataFrame()
        self.extra_args = {}
    def add_args(self, **kwargs):
        """Add extra columns to the rows produced by the next call of `check`.

        Arguments already queued take precedence over `kwargs` with the same name.
        Returns `self`, so calls can be chained.
        """
        self.extra_args = {**kwargs, **self.extra_args}
        return self
    
    def check(self,
              x: Tensor, # square matrix `[n, n]` or batch of them `[..., n, n]`
              **extra_args # extra columns for the rows of this call
             ) -> pd.DataFrame:
        """Check `x` and return one row per matrix (returns `None` when `do_check` is False)."""
        if not self.do_check: return
        
        self.add_args(**extra_args)
        
        # flatten all leading batch dims so `_check_matrix` always receives a single matrix
        x = x.reshape(-1, *x.shape[-2:]) if x.dim() > 2 else [x]
        infos = pd.concat([*map(self._check_matrix, x)])
        
        if self.use_log: self.log = pd.concat([self.log, infos])
        # `not` (instead of `~`) is correct even if `.all()` returns a plain Python bool
        if self.warning and not (infos['is_pd_eigv'].all() and infos['is_pd_chol'].all()):
            warn("Matrix is not positive definite")
        
        self.extra_args = {} 
        return infos
    
    def _check_matrix(self,
                     x: Tensor # square matrix
                    ) -> pd.DataFrame:
        "Build a one-row DataFrame describing the single square matrix `x`"
        x = x.detach().cpu().clone() # free GPU memory and ensure that there is a copy
        sym_upto = symmetric_upto(x)

        is_pd_eigv, eigv = is_posdef_eigv(x)
        is_pd_chol = torch.linalg.cholesky_ex(x).info.eq(0).all().item() # skip pytorch too strict symmetry check
        is_sym = is_symmetric(x).item()

        info = pd.DataFrame({
            'is_pd_eigv': is_pd_eigv.item(),
            'is_pd_chol': is_pd_chol,
            'is_sym': is_sym,
            'sym_upto': sym_upto,
            'eigv': [eigv.detach().numpy()],
            'matrix': [x.detach().numpy()],
            **self.extra_args
        })

        return info

In [None]:
CheckPosDef(True).check(A)

In [None]:
CheckPosDef(True).check(A[0])

In [None]:
checker = CheckPosDef(True)

checker.check(A, my_arg="my arg") # this will be another col in the log

In [None]:
checker.log

In [None]:
checker.add_args(show="only once")
checker.check(posdef)
checker.check(A)
checker.log

In [None]:
B = torch.rand(2,3,3) # a batch of matrices

In [None]:
checker.check(B)

In [None]:
test_close(B[0] @ A, (B @ A)[0]) # example batched matrix multiplication

### Diagonal Positive Definite Constraint

This is a simpler constraint that makes the matrix diagonal and positive definite, by forcing it to have positive numbers on the diagonal.

given a random vector $a$ it is transformed into a diagonal positive definite matrix using:

$A_{diag\ pos\ def} = \mathrm{diag}(e^{a})$

the inverse transformation is the log of the diagonal

In [None]:
#| export
@docs
class DiagPosDef(): 
    # transform: vector a -> diag(exp(a)); inverse_transform: log of the diagonal
    def transform(self, raw):
        positive = torch.exp(raw)  # element-wise exp makes every diagonal entry > 0
        return torch.diag(positive)
    def inverse_transform(self, value):
        diagonal = torch.diag(value)
        return torch.log(diagonal)
    
    _docs = {'cls_doc': "Diagonal Positive Definite Constraint for PyTorch parameters",
             'transform':"transform any vector into a diagonal positive definite matrix",
             'inverse_transform': "tranform diagonal positive definite matrix into a vector that can be back transformed using `transform`"}

to_diagposdef = DiagPosDef().transform

In [None]:
dpd_const = DiagPosDef()
a = torch.rand(3)

In [None]:
dpd_const.transform(a)

tensor([[1.1899, 0.0000, 0.0000],
        [0.0000, 2.0571, 0.0000],
        [0.0000, 0.0000, 1.4777]])

In [None]:
test_close(a, dpd_const.inverse_transform(dpd_const.transform(a)))

## Conditional Predictions

To make conditional predictions we need to compute the conditional distribution of a multivariate normal ^[https://cs.nyu.edu/~roweis/notes/gaussid.pdf eq. 5a, 5d]

$$ X = \left[\begin{array}{c} x \\ o \end{array} \right] $$

$$ p(X) = N\left(\left[ \begin{array}{c} \mu_x \\ \mu_o \end{array} \right], \left[\begin{array}{cc} \Sigma_{xx} & \Sigma_{xo} \\ \Sigma_{ox} & \Sigma_{oo} \end{array} \right]\right)$$

where $x$ is the vector of variables that need to be predicted and $o$ is the vector of the variables that have been observed


then the conditional distribution is:

$$p(x|o) = N(\mu_x + \Sigma_{xo}\Sigma_{oo}^{-1}(o - \mu_o), \Sigma_{xx} - \Sigma_{xo}\Sigma_{oo}^{-1}\Sigma_{ox})$$

In [None]:
#| export
import torch
from torch.distributions import MultivariateNormal
from torch.linalg import cholesky
from torch import cholesky_inverse
from torch import Tensor

from fastcore.test import *
from meteo_imp.utils import *
from typing import List

In [None]:
#| export
def conditional_guassian(
                         μ: Tensor, # mean with shape `[n_vars]`
                         Σ: Tensor, # cov with shape `[n_vars, n_vars] `
                         obs: Tensor, # Observations with shape `[n_obs]`, where `n_obs = mask.sum()`
                         mask: Tensor # Boolean tensor specifying for each variable is observed (True) or not (False). Shape `[n_vars]`
                        ) -> ListMNormal: # Distribution conditioned on observations. shape `[n_vars - n_obs]`
    """Condition the multivariate normal N(μ, Σ) on the variables selected by `mask` taking the values `obs`.

    Returns mean `μ_x + Σ_xo Σ_oo⁻¹ (obs - μ_o)` and covariance `Σ_xx - Σ_xo Σ_oo⁻¹ Σ_ox`
    of the unobserved variables.
    """
    assert μ.shape[0] == mask.shape[0], "`μ` and `mask` must have the same number of variables"
    assert obs.shape[0] == int(mask.sum()), "`obs` must have one value per observed variable"
    
    μ_x = μ[~mask]
    μ_o = μ[mask]
    # the double square brackets `:][:` are needed to keep the dimensionality even for empty tensors 
    Σ_xx = Σ[~mask,:][:, ~mask]
    Σ_xo = Σ[~mask,:][:,  mask]
    Σ_ox = Σ[ mask,:][:, ~mask]
    Σ_oo = Σ[ mask,:][:,  mask]
    
    # K = Σ_xo Σ_oo⁻¹, obtained by solving Σ_ooᵀ Kᵀ = Σ_xoᵀ instead of forming the explicit
    # inverse (more stable and cheaper than `torch.inverse`)
    K = torch.linalg.solve(Σ_oo.mT, Σ_xo.mT).mT
    
    mean = μ_x + K @ (obs - μ_o)
    cov = Σ_xx - K @ Σ_ox
    
    return ListMNormal(mean, cov)
    

In [None]:
# example distribution with only 2 variables
μ = torch.tensor([.5, 1.])
Σ = torch.tensor([[1., .5], [.5 ,1.]])


mask = torch.tensor([True, False]) # second variable is the observed one

obs = torch.tensor([5.]) # value of second variable

gauss_cond = conditional_guassian(μ, Σ, obs, mask)

# hardcoded values to test that the code is working, see also for alternative implementation https://python.quantecon.org/multivariate_normal.html
test_close(3.25, gauss_cond.mean.item())
test_close(.75, gauss_cond.cov.item())

### Batches

Proper batch support is not possible, or at least not straightforward, because the shape of the output can differ between batch elements (each element can have a different number of observed variables).

As a temporary workaround, a for-loop over the batch is used.

In [None]:
#| export
def cond_gaussian_batched(dist: ListMNormal,
                         obs, # this needs to have the same shape of the mask !!! 
                         mask
                         ) -> List[ListMNormal]: # lists of distributions for element in the batch
    """Condition each distribution of the batch `dist` on its own observations.

    `obs` and `mask` have the same shape, with the batch on the first dim; for element `b`
    the observed values are `obs[b][mask[b]]`. Returns a Python list, one entry per element.
    """
    conditioned = []
    for b in range(obs.shape[0]):
        observed = mask[b]
        conditioned.append(
            conditional_guassian(dist.mean[b], dist.cov[b], obs[b][observed], observed))
    return conditioned
        

In [None]:
reset_seed(10)
mean = torch.rand(2,3) # batch
cov = to_posdef(torch.rand(2,3,3))
mask = torch.rand(2,3) > .3
obs = torch.rand(2,3)

In [None]:
# the batched helper is `cond_gaussian_batched` and takes the distribution as a `ListMNormal`
cond_gaussian_batched(ListMNormal(mean, cov), obs, mask)

In [None]:
mask.shape, obs.shape

In [None]:
assert mean.shape == mask.shape
assert mean.dim() == 2

In [None]:
obs.shape

In [None]:
mean_x = mean[~mask]
mean_o = mean[mask]

In [None]:
mask

In [None]:
mean_x

In [None]:
cov.shape

In [None]:
cov[~mask]

In [None]:
cov

In [None]:
cov[0][~mask[0], ~mask[0]]

In [None]:
cov[0][mask[0],:][:, mask[0]].shape

### Performance

analysis of the performance of inverting a positive definite matrix

Use `cholesky` decomposition and `cholesky_solve` to improve performance of matrix inversion

See the [Probabilistic Machine Learning course from the University of Tübingen](https://uni-tuebingen.de/en/180804), specifically the code from the [Gaussian Regression Notebook](https://uni-tuebingen.de/fileadmin/Uni_Tuebingen/Fakultaeten/MatNat/Fachbereiche/Informatik/Lehrstuehle/MethMaschLern/Probabilistic_ML/Notebook_Vorlesung_7___9/Gaussian_Linear_Regression.ipynb), for details.

This is the direct implementation of the equations

In [None]:
def _conditional_guassian_base(
                         μ: Tensor, # mean with shape `[n_vars]`
                         Σ: Tensor, # cov with shape `[n_vars, n_vars] `
                         obs: Tensor, # Observations of the observed variables only, shape `[n_obs]` where `n_obs = idx.sum()`
                         idx: Tensor # Boolean tensor specifying for each variable is observed (True) or not (False). Shape `[n_vars]`
                        ) -> ListNormal: # Distribution conditioned on observations (the `std` field holds the covariance)
    """Reference implementation of Gaussian conditioning using the explicit inverse of `Σ_oo`.

    Baseline to compare `conditional_guassian` against. The result is packed in a `ListNormal`,
    so its second field (`std`) actually contains the conditional *covariance*
    `Σ_xx - Σ_xo Σ_oo⁻¹ Σ_ox`, not a standard deviation.
    """
    μ_x = μ[~idx]
    μ_o = μ[idx]
    
    Σ_xx = Σ[~idx,:][:, ~idx]
    Σ_xo = Σ[~idx,:][:, idx]
    Σ_ox = Σ[idx,:][:, ~idx]
    Σ_oo = Σ[idx,:][:, idx]
    
    Σ_oo_inv = torch.linalg.inv(Σ_oo)
    
    mean = μ_x + Σ_xo@Σ_oo_inv@(obs - μ_o)
    cov = Σ_xx - Σ_xo@Σ_oo_inv@Σ_ox
    
    return ListNormal(mean, cov)
    

Comparison with a faster version based on the Cholesky decomposition:

In [None]:
n_var = 5
mean = torch.rand(n_var, dtype=torch.float64)
cov = to_posdef(torch.rand(n_var, n_var, dtype=torch.float64))
dist = MultivariateNormal(mean, cov)
idx = torch.rand(n_var, dtype=torch.float64) > .5
obs = torch.rand(n_var, dtype=torch.float64)[idx]

In [None]:
torch.linalg.inv(cov) 

In [None]:
(torch.linalg.inv(cov) - cholesky_inverse(torch.linalg.cholesky(cov))).max()

In [None]:
test_close(torch.linalg.inv(cov), cholesky_inverse(torch.linalg.cholesky(cov)), eps=1e-2)

In [None]:
reset_seed()
A = to_posdef(torch.rand(1000, 1000, dtype=torch.float64)) + torch.eye(1000) * 1e-3 # noise to ensure is positive definite

In [None]:
is_symmetric(A)

In [None]:
is_posdef(A)

In [None]:
%timeit torch.linalg.inv(A)

In [None]:
%timeit cholesky_inverse(torch.linalg.cholesky(A))

The second version is way faster

In [None]:
test_close(conditional_guassian(mean, cov, obs, idx).mean, _conditional_guassian_base(mean, cov, obs, idx).mean)

In [None]:
B = to_posdef(torch.rand(n_var, n_var, dtype=torch.float64))

In [None]:
B @ torch.inverse(cov)

In [None]:
# `torch.cholesky_solve(B, L)` solves `cov @ X = B` given the lower Cholesky factor `L` of `cov`,
# i.e. it returns `torch.inverse(cov) @ B` (the transpose of the cell above, as both are symmetric)
torch.cholesky_solve(B, cholesky(cov))

## Helper

### cov2std

In [None]:
x = torch.stack([torch.eye(3)*i for i in  range(1,4)])

In [None]:
x

In [None]:
torch.diagonal(x, dim1=1, dim2=2)

In [None]:
#| export
def cov2std(x):
    "Convert a covariance matrix (or a batch of them) into standard deviations: the square root of the diagonal over the last two dims"
    variances = x.diagonal(dim1=-2, dim2=-1)
    return variances.sqrt()

## Export

In [None]:
#| hide
from nbdev import nbdev_export
nbdev_export()