# Gaussian Distributions Utils
> tools to work with gaussian distributions

In [None]:
#| hide
#| default_exp gaussian

In [None]:
from fastcore.test import *

In [None]:
import altair as alt

## Normal Parameters

In [None]:
#| export
from collections import namedtuple
from dataclasses import dataclass
from fastcore.basics import patch
from torch import Tensor

### Normal

In [None]:
import torch

In [None]:
#| export
# Batch of univariate normals: `mean` and `std` each hold one entry per distribution
ListNormal = namedtuple('ListNormal', 'mean std')

# class ListNormal:
#     mean: Tensor
#     std: Tensor

In [None]:
#| export
# A single univariate normal described by its `mean` and `std`
Normal = namedtuple('Normal', 'mean std')
# @dataclass()
# class Normal:
#     mean: Tensor
#     std: Tensor

In [None]:
#| export
@patch
def __getitem__(self: ListNormal, n: int) -> Normal:
    """Return the `n`-th distribution of the list as a `Normal` (its mean and std)."""
    nth_mean = self.mean[n]
    nth_std = self.std[n]
    return Normal(nth_mean, nth_std)

In [None]:
#| export
@patch
def detach(self: ListNormal) -> ListNormal:
    """Return a new `ListNormal` whose mean and std are both detached."""
    detached_mean, detached_std = self.mean.detach(), self.std.detach()
    return ListNormal(detached_mean, detached_std)

In [None]:
ln = ListNormal(torch.rand(10), torch.rand(10))

In [None]:
ln[5]

Normal(mean=tensor(0.6659), std=tensor(0.7099))

### Multivariate Normal

In [None]:
#| export
# Batch of multivariate normals: `mean` and `cov` each hold one entry per distribution.
# The typename matches the bound name so instances can be pickled (pickle looks the class up by its typename).
ListMNormal = namedtuple('ListMNormal', ['mean', 'cov'])
# @dataclass()
# class ListMNormal:
#     mean: Tensor
#     cov: Tensor

In [None]:
#| export
# A single multivariate normal described by its `mean` vector and `cov` matrix.
# The typename matches the bound name so instances can be pickled (pickle looks the class up by its typename).
MNormal = namedtuple('MNormal', ['mean', 'cov'])
# @dataclass()
# class MNormal:
#     mean: Tensor
#     cov: Tensor

In [None]:
#| export
@patch
def __getitem__(self: ListMNormal, n:int
           )->MNormal:
    """Get the mean and cov of the nth multivariate Normal distribution in the list """
    return MNormal(self.mean[n], self.cov[n])
@patch
def __setitem__(self: ListMNormal, idx, value)->None:
    """Set the mean and cov at `idx`; `value` must unpack into a `(mean, cov)` pair.

    The assignment writes into the existing `mean` and `cov` containers in place.
    """
    self.mean[idx], self.cov[idx] = value

In [None]:
#| export
@patch
def detach(self: ListMNormal) -> ListMNormal:
    """Return a new `ListMNormal` whose mean and cov are both detached."""
    detached_mean, detached_cov = self.mean.detach(), self.cov.detach()
    return ListMNormal(detached_mean, detached_cov)

In [None]:
ListMNormal(torch.rand(2,10), torch.rand(2,10,10))[1]

MultiNormal(mean=tensor([0.6145, 0.9545, 0.9506, 0.3708, 0.8807, 0.3148, 0.8657, 0.8969, 0.5861,
        0.1759]), cov=tensor([[0.4695, 0.6334, 0.8019, 0.1890, 0.1260, 0.7271, 0.3380, 0.4822, 0.3967,
         0.6806],
        [0.4955, 0.5878, 0.1718, 0.1387, 0.7825, 0.9125, 0.8619, 0.6222, 0.2173,
         0.5443],
        [0.7097, 0.3823, 0.4766, 0.0937, 0.4554, 0.9178, 0.2110, 0.0275, 0.5144,
         0.2217],
        [0.3205, 0.6029, 0.0214, 0.4978, 0.6771, 0.8186, 0.5360, 0.1226, 0.3451,
         0.0198],
        [0.9501, 0.9777, 0.0515, 0.2082, 0.2785, 0.2175, 0.4271, 0.3827, 0.2398,
         0.9729],
        [0.1488, 0.8439, 0.2854, 0.7357, 0.9889, 0.2667, 0.0230, 0.1091, 0.1631,
         0.8153],
        [0.7805, 0.6039, 0.0336, 0.8994, 0.1771, 0.6934, 0.4101, 0.1552, 0.7183,
         0.2132],
        [0.8785, 0.0896, 0.9429, 0.4760, 0.5316, 0.7215, 0.3914, 0.2091, 0.2201,
         0.9366],
        [0.0676, 0.6766, 0.8490, 0.7209, 0.3308, 0.3180, 0.7028, 0.9616, 0.9841,
        

## Positive Definite

The covariance matrices need to be [positive definite](https://en.wikipedia.org/wiki/Definite_matrix)
These are utility functions to check whether a matrix is positive definite and to make any matrix positive definite

#### Other libraries

Most libraries that implement Kalman Filters use manually specified parameters, which often don't have the issue of the positive definite constraint (eg. `pykalman`)

From `statsmodels` statespace models:
>Cholesky decomposition [...] requires that the matrix be positive definite. While this
          should generally be true, it may not be in every case. [source](https://www.statsmodels.org/stable/generated/statsmodels.tsa.statespace.kalman_filter.KalmanFilter.set_inversion_method.html#statsmodels.tsa.statespace.kalman_filter.KalmanFilter.set_inversion_method)

which seems to mean that they take into account the fact that, during the filter calculations, the matrices may not be positive definite



In [None]:
#| export
import pandas as pd
from torch import Tensor
import torch

In [None]:
A = torch.rand(2,3,3) # batched random matrix used for testing

#### Symmetry

In [None]:
#| export
def is_symmetric(value, atol=1e-5):
    "Check, for each matrix in the last two dims of `value`, that it is close to its transpose"
    close_to_transpose = torch.isclose(value, value.mT, atol=atol)
    # reduce over columns, then over rows: one boolean per matrix
    return close_to_transpose.all(dim=-1).all(dim=-1)

In [None]:
is_symmetric(A)

tensor([False, False])

In [None]:
#| export
def symmetric_upto(value, start=-8):
    """Smallest exponent `e` in `[start, 2]` such that `value` is symmetric with `atol=10**e`.

    `value` is a single square matrix (use `symmetric_upto_batched` for batches).
    If no tolerance in the range makes the matrix symmetric, the largest exponent (2) is returned.
    """
    last_exp = 2
    for exp in range(start, last_exp + 1):
        # Python float power: with an integer tensor exponent `10**exp` is an integer power,
        # which cannot represent tolerances such as 10**-8
        if is_symmetric(value, atol=10.0 ** exp):
            return exp
    return last_exp

def symmetric_upto_batched(value, start=-8):
    "Apply `symmetric_upto` to every matrix in the batch `value`, starting the search at exponent `start`"
    return torch.tensor([symmetric_upto(v, start) for v in value])

In [None]:
symmetric_upto_batched(A)

tensor([0, 0])

#### is posdef

Default pytorch check (uses symmetry + cholesky decomposition)

In [None]:
#| export
def is_posdef(cov):
    "Check `cov` against the `positive_definite` constraint of `torch.distributions`"
    pd_constraint = torch.distributions.constraints.positive_definite
    return pd_constraint.check(cov)

In [None]:
is_posdef(A)

tensor([False, False])

Check whether the matrix is positive definite using its eigenvalues. A positive definite matrix has all positive eigenvalues

In [None]:
torch.linalg.eigvalsh(A)

tensor([[-0.7223,  0.3265,  1.4477],
        [-0.8472,  0.4187,  1.3116]])

In [None]:
#| export
def is_posdef_eigv(cov):
    """Check definiteness of `cov` through the eigenvalues computed by `torch.linalg.eigvalsh`.

    Returns a tuple `(check, eigv)`: `check` is True where all eigenvalues are `>= 0`
    (a zero eigenvalue is accepted) and `eigv` holds the eigenvalues.
    If the eigenvalue computation raises a linear algebra error, `(tensor(False), tensor(nan))` is returned.
    """
    try:
        eigv = torch.linalg.eigvalsh(cov)
    except torch.linalg.LinAlgError:  # public name of `torch._C._LinAlgError`
        return torch.tensor(False), torch.tensor(torch.nan)
    return eigv.ge(0).all(-1), eigv

In [None]:
is_posdef_eigv(A)

(tensor([False, False]),
 tensor([[-0.7223,  0.3265,  1.4477],
         [-0.8472,  0.4187,  1.3116]]))

Note that `is_posdef` and `is_posdef_eigv` can return different values; in general `is_posdef_eigv` is more tolerant

### Pytorch constraint

transform any matrix $A$ into a positive definite matrix ($PD$) using the following formula

$PD = AA^T + aI$ 

where $AA^T$ is a positive semi-definite matrix and $a$ is a small positive number that is added on the diagonal to ensure that the resulting matrix is positive definite (not semi-definite)

the inverse transformation uses cholesky decomposition


Another approach would be to multiply two lower triangular matrices, but they'd require a positive diagonal, which is harder to obtain; see [https://en.wikipedia.org/wiki/Definite_matrix#Cholesky_decomposition](https://en.wikipedia.org/wiki/Definite_matrix#Cholesky_decomposition)

The API inspired by gpytorch constraints

In [None]:
#| export
from fastcore.foundation import docs
from fastcore.basics import store_attr

In [None]:
from meteo_imp.utils import *

In [None]:
#| export
def inv_softplus(x):
    "Inverse of softplus, computed as `x + log(-expm1(-x))`"
    log_term = torch.log(-torch.expm1(-x))
    return x + log_term
softplus = torch.nn.Softplus()

In [None]:
#| export
def batch_diagonal(x):
    "Diagonal of every matrix stored in the last two dims of `x`"
    return x.diagonal(dim1=-2, dim2=-1)
def batch_diag_scatter(input, src):
    "Copy of `input` where the diagonal of every matrix in the last two dims is replaced by `src`"
    return input.diagonal_scatter(src, dim1=-2, dim2=-1)
def batch_diag_embed(x):
    "Build diagonal matrices whose diagonals are the vectors in the last dim of `x`"
    # the torch function is `diag_embed`; `torch.diagonal_embed` does not exist
    return torch.diag_embed(x, dim1=-2, dim2=-1)

In [None]:
#| export
@docs
class PosDef():
    # Maps an unconstrained square matrix to a positive definite one by building a
    # lower triangular factor with a strictly positive diagonal and multiplying it by its transpose.
    # Method docstrings are attached by `@docs` from the `_docs` dict below.
    def __init__(self, min_diag: float=1e-5 # min value for diagonal to ensure num stability
                ):
        store_attr()

    def transform_triangular(self, raw):
        return raw.tril()

    def transform_pos_diag(self, raw):
        # softplus makes the diagonal positive, `min_diag` keeps it away from zero
        raw_diag = batch_diagonal(raw)
        pos_diag = softplus(raw_diag) + self.min_diag
        return batch_diag_scatter(raw, pos_diag)

    def transform_cho_factor(self, raw):
        lower = self.transform_triangular(raw)
        return self.transform_pos_diag(lower)

    def transform(self,raw):
        factor = self.transform_cho_factor(raw)
        return factor @ factor.mT

    def inverse_transform(self, value):
        # undo `transform`: recover the factor, then undo the softplus + min_diag applied to its diagonal
        factor = torch.linalg.cholesky(value)
        raw_diag = inv_softplus(batch_diagonal(factor) - self.min_diag)
        return batch_diag_scatter(factor, raw_diag)

    _docs = {'cls_doc': "Positive Definite Constraint for PyTorch parameters",
             'transform_triangular': "transform to lower triangular matrix",
             'transform_pos_diag': "transform to matrix with positive diagonal",
             'transform_cho_factor': "trasform to Choleksy factor (lower triangular matrix with positive diagonal)",
             'transform':"transform any square matrix into a positive definite one",
             'inverse_transform': "tranform positive definite matrix into a matrix that can be back transformed using `transform`"}

# module-level shortcut using the default `min_diag`
to_posdef = PosDef().transform

In [None]:
constraint = PosDef()

posdef = constraint.transform(A)

In [None]:
A = torch.randn(2, 3,3)
triang = constraint.transform_triangular(A)
p_diag = constraint.transform_pos_diag(triang)
cho_fact = constraint.transform_cho_factor(A)
posdef = constraint.transform(A)
show_as_row(A, triang, p_diag, cho_fact, posdef)

In [None]:
show_as_row(is_posdef(torch.stack([posdef,A])), is_posdef_eigv(torch.stack([posdef,A])), is_symmetric(torch.stack([posdef,A])))

In [None]:
test_eq(is_posdef(posdef).all(), True)

In [None]:
test_close(posdef, constraint.transform(constraint.inverse_transform(posdef)))

In [None]:
symmetric_upto(posdef[0])

-8

In [None]:
is_posdef_eigv(to_posdef(torch.rand(1000, 1000)))[0]

tensor(False)

### Fuzzer

In [None]:
run_fuzzer = True # temporly disable for performance reasons

In [None]:
def random_posdef(bs=10,n=100,n_range=(0,1), **kwargs):
    "Batch of `bs` matrices `n x n`: uniform random values in `n_range` passed through `PosDef().transform`"
    low, high = n_range[0], n_range[1]
    # rescale samples from [0, 1) to [low, high); `kwargs` go to `torch.rand` (e.g. device, dtype)
    raw = torch.rand(bs, n, n, **kwargs) * (high - low) + low
    return PosDef().transform(raw)

In [None]:
# fuzzer
def fuzz_posdef(bs=10,n=100,n_range=(0,1), **kwargs):
    """Generate `bs` random matrices with `random_posdef` and report the fraction passing each check.

    Returns a one-row DataFrame with the settings and the rate of matrices that pass
    `is_posdef`, `is_symmetric` and `is_posdef_eigv`. `kwargs` are forwarded to `random_posdef`.
    """
    # forward `n_range` so the sampled range is the one reported in the 'range' column
    posdef = random_posdef(bs, n, n_range, **kwargs)
    return pd.DataFrame(
        {'n': [n], 'range': str(n_range), 'n_samples': bs,
         'posdef': is_posdef(posdef).sum().item() / bs,
         'sym': is_symmetric(posdef).sum().item() / bs, 
         'posdef_eigv': is_posdef_eigv(posdef)[0].sum().item() / bs
    })

In [None]:
fuzz_posdef()

Unnamed: 0,n,range,n_samples,posdef,sym,posdef_eigv
0,100,"(0, 1)",10,0.9,1.0,0.8


In [None]:
n_min, n_max = -1, 1
A = torch.rand(2,100,100)  * (n_max-n_min) + n_min

In [None]:
is_posdef(to_posdef(A))

tensor([False, False])

In [None]:
ma = torch.tensor([[1., 7],
                   [-3, 4]])

In [None]:
is_posdef(to_posdef(ma))

tensor(True)

In [None]:
fuzz_posdef(device='cuda')

Unnamed: 0,n,range,n_samples,posdef,sym,posdef_eigv
0,100,"(0, 1)",10,1.0,1.0,1.0


In [None]:
# %time fuzz_posdef(bs=100, device='cuda')

In [None]:
rate_posdef = pd.concat([fuzz_posdef(n=n, bs=100, n_range=n_range, device='cuda') 
               for n in [10, 100]
               for n_range in [(-1,1),(0,1)]])

In [None]:
import altair as alt
from altair import datum

In [None]:
rate_posdef.head()

Unnamed: 0,n,range,n_samples,posdef,sym,posdef_eigv
0,10,"(-1, 1)",100,1.0,1.0,1.0
0,10,"(0, 1)",100,1.0,1.0,1.0
0,100,"(-1, 1)",100,0.96,1.0,0.98
0,100,"(0, 1)",100,1.0,1.0,1.0


In [None]:
def _plot_var(df, var, x='n:N', row='range', y_domain=(0,1), height=70, width=50):
    """Bar chart of `var` against `x`, with the value printed on each bar, faceted by `row`.

    `x` is used for the x axis and for the bar colour; `y_domain` fixes the y scale of the bars.
    """
    bar = alt.Chart(df).mark_bar().encode(
        x = alt.X(x),
        y = alt.Y(var, scale=alt.Scale(domain=y_domain)),
        color = x,
    ).properties(height=height, width=width)

    # value labels drawn on top of the bars
    text = alt.Chart(df).mark_text(dy=10, color='white').encode(
        x = alt.X(x),
        y = alt.Y(var),
        text = alt.Text(var, format=".2f")
    )

    return (bar + text).facet(
        row=row).properties(title=var, )
    

In [None]:
def _plot_var_box(df, var, x='n:N', row='range', column='noise:N', height=70, width=50, title=''):
    "Box plot of `var` against `x` (also used as colour), faceted by `row` and `column`"
    encoding = dict(x=alt.X(x), y=alt.Y(var), color=x)
    box = (alt.Chart(df)
           .mark_boxplot()
           .encode(**encoding)
           .properties(height=height, width=width))
    faceted = box.facet(column=column, row=row)
    return faceted.properties(title=title)


In [None]:
from IPython import display
import vl_convert as vlc
from functools import partial

#### Generation of Random positive definite matrices 

In [None]:
def plot_posdef_simulation(n_s, range_s, bs=100, **kwargs):
    "Run `fuzz_posdef` on cuda for every size in `n_s` and range in `range_s`, print the rates and plot them"
    if not run_fuzzer:
        return
    runs = []
    for n in n_s:
        for n_range in range_s:
            runs.append(fuzz_posdef(n=n, bs=bs, n_range=n_range, device='cuda', **kwargs))
    rate_posdef = pd.concat(runs)

    print(rate_posdef)
    charts = [_plot_var(rate_posdef, var) for var in ['posdef', 'posdef_eigv']]
    vl_spec = alt.hconcat(*charts).to_json()
    # workaround for bug in vegalite see https://github.com/altair-viz/altair/issues/2742
    svg = vlc.vegalite_to_svg(vl_spec, vl_version='v5.3')
    display.display(display.HTML(svg))

In [None]:
plot_posdef_simulation(n_s = [10, 100], range_s = [(-1, 1)], bs=1000)

     n    range  n_samples  posdef  sym  posdef_eigv
0   10  (-1, 1)       1000   1.000  1.0        1.000
0  100  (-1, 1)       1000   0.989  1.0        0.999


Let's go big by using a matrix `1000x1000`

In [None]:
plot_posdef_simulation(n_s = [1000], range_s = [(10, 20)], bs=100)

      n     range  n_samples  posdef  sym  posdef_eigv
0  1000  (10, 20)        100     0.0  1.0          0.0


for a standard noise on the diagonal less than half of the random matrices that are 1000 in size are positive definite.

Let's have a look at one of such matrices

In [None]:
posdef = random_posdef(100, 1000)
not_pd = posdef[torch.argwhere(~is_posdef_eigv(posdef)[0])[0]]

This should be positive definite but actually it's not ...

In [None]:
not_pd

tensor([[[7.8163e-01, 2.2600e-01, 4.9864e-01,  ..., 4.9707e-02,
          2.2801e-02, 5.4999e-01],
         [2.2600e-01, 1.4331e+00, 2.3493e-01,  ..., 9.0204e-01,
          4.1434e-01, 1.0401e+00],
         [4.9864e-01, 2.3493e-01, 1.4841e+00,  ..., 9.4301e-01,
          3.6869e-01, 1.0390e+00],
         ...,
         [4.9707e-02, 9.0204e-01, 9.4301e-01,  ..., 3.4198e+02,
          2.5398e+02, 2.4810e+02],
         [2.2801e-02, 4.1434e-01, 3.6869e-01,  ..., 2.5398e+02,
          3.3169e+02, 2.4080e+02],
         [5.4999e-01, 1.0401e+00, 1.0390e+00,  ..., 2.4810e+02,
          2.4080e+02, 3.2264e+02]]])

trying with `float64` (for memory constraint on the GPU only using a `700x700` matrix)

In [None]:
plot_posdef_simulation(n_s = [700], range_s = [(-.1, 1)], bs=100)

     n      range  n_samples  posdef  sym  posdef_eigv
0  700  (-0.1, 1)        100     0.0  1.0          0.0


In [None]:
plot_posdef_simulation(n_s = [700], range_s = [(-.1, 1)], bs=100, dtype=torch.float64)

     n      range  n_samples  posdef  sym  posdef_eigv
0  700  (-0.1, 1)        100     0.0  1.0          0.0


All matrices now are positive definite

#### Multiplication

Check that multiplication of matrices does not break the positive definite constraint

If $A$ and $B$ are both positive definite matrices $ABA$ is also positive definite
[https://en.wikipedia.org/wiki/Definite_matrix#Multiplication](https://en.wikipedia.org/wiki/Definite_matrix#Multiplication)

In [None]:
def fuzz_op(op, # operation that takes 2 pos def matrices and return one pos def matrix
            fn_check = is_posdef,
                  n=100, # size of matrix
                  max_t=1000, # number of multiplications
                  noise=1e-5, # noise to add on diagonal
                  bs=10, # batch size
                  n_range=(0,1), # range of random numbers
                  **kwargs):
    """Apply `op` repeatedly to random positive definite matrices and record when `fn_check` first fails.

    Returns a one-row DataFrame with the settings, the last iteration reached and, for every
    sample of the batch, the iteration of its first failed check (`stop_times`).
    Samples that never fail get the last iteration reached.
    """
    low, high = n_range[0], n_range[1]
    # `random_posdef` has no noise argument: build the matrices here, using `noise`
    # as the minimum value of the Cholesky factor diagonal
    constraint = PosDef(min_diag=noise)
    def _sample():
        return constraint.transform(torch.rand(bs, n, n, **kwargs) * (high - low) + low)
    pd1, pd2 = _sample(), _sample()

    stop_times = torch.zeros(bs, **kwargs)
    # explicit mask: a failure at t=0 must not be confused with "not stopped yet"
    stopped = torch.zeros(bs, dtype=torch.bool, device=stop_times.device)

    for t in torch.arange(max_t):
        pd1 = op(pd1, pd2)
        failed_now = ~fn_check(pd1) & ~stopped
        stop_times[failed_now] = t
        stopped |= failed_now
        if stopped.all(): break # every sample has a recorded stop time

    stop_times[~stopped] = t
    return pd.DataFrame(
        {'n': [n], 'noise': f"{noise:.0e}", 'range': str(n_range), 'n_samples': bs, 'last_t': t.item(),
         'mean_stop': stop_times.mean().item(),
         'std_stop': stop_times.std().item(),
         'stop_times': [stop_times.cpu().numpy()]})

In [None]:
# fuzz the product `pd2 @ pd1 @ pd2`, checked with the default `fn_check` of `fuzz_op`
fuzz_multiply = partial(fuzz_op, lambda pd1, pd2: pd2 @ pd1 @ pd2)
# same operation, checked with the boolean (first element) returned by `is_posdef_eigv`
fuzz_multiply_eigv = partial(fuzz_multiply, fn_check = lambda pd1: is_posdef_eigv(pd1)[0])

In [None]:
def plot_multiply_simulation(n_s, noise_s, max_mult=1000, bs=100, **kwargs):
    """Fuzz repeated multiplications on cuda for every size in `n_s` and noise in `noise_s` and plot the stop times.

    Two simulations are run, one with `fuzz_multiply` and one with `fuzz_multiply_eigv`.
    `max_mult` is the maximum number of multiplications (an explicit `max_t` in `kwargs` takes precedence).
    Returns the two exploded DataFrames `(mult, mult_eigv)`.
    """
    kwargs.setdefault('max_t', max_mult)
    settings = [(n, noise) for n in n_s for noise in noise_s]

    mult = pd.concat([fuzz_multiply(n=n, noise=noise, bs=bs, device='cuda', **kwargs)
               for n, noise in settings]).explode('stop_times')

    # second simulation uses the eigenvalue-based check
    mult_eigv = pd.concat([fuzz_multiply_eigv(n=n, noise=noise, bs=bs, device='cuda', **kwargs)
               for n, noise in settings]).explode('stop_times')

    vl_spec = alt.hconcat(*[_plot_var_box(df, 'stop_times') for df in [mult, mult_eigv]]).to_json()
    # workaround for bug in vegalite see https://github.com/altair-viz/altair/issues/2742
    svg = vlc.vegalite_to_svg(vl_spec, vl_version='v5.3')
    display.display(display.HTML(svg))
    return (mult, mult_eigv)

In [None]:
plot_multiply_simulation(n_s=[2,3,10, 100], noise_s=[1e-3, 1e-4, 1e-5], bs=100);

#### Addition

Check that addition of matrices does not break the positive definite constraint

If $A$ and $B$ are both positive definite matrices $A+B$ is also positive definite
[https://en.wikipedia.org/wiki/Definite_matrix#Addition](https://en.wikipedia.org/wiki/Definite_matrix#Addition)

In [None]:
pd1 = random_posdef(10, 100)
pd2 = random_posdef(10, 100)

In [None]:
is_posdef(pd1 + pd2).all()

In [None]:
# fuzz the sum `pd1 + pd2`, checked with the default `fn_check` of `fuzz_op`
fuzz_add = partial(fuzz_op, lambda pd1, pd2: pd1 + pd2)

In [None]:
%time fuzz_add(max_t=1e5, device='cuda')

In [None]:
def plot_add_simulation(n_s, noise_s, max_ts=(1000,), bs=100, **kwargs):
    """Fuzz repeated additions on cuda for every combination of `n_s`, `noise_s` and `max_ts` and plot the stop times.

    `max_ts` is any iterable with the maximum number of additions of each run
    (immutable default, so it cannot be changed between calls).
    """
    add = pd.concat([fuzz_add(n=n, noise=noise, bs=bs, max_t=max_t, device='cuda', **kwargs) 
               for n in n_s for noise in noise_s for max_t in max_ts]).explode('stop_times')
    
    vl_spec = _plot_var_box(add, var='stop_times', height=150, width=150).to_json()

    # rendered to SVG through vl_convert, as in the other simulation plots
    svg = vlc.vegalite_to_svg(vl_spec, vl_version='v5.3')
    display.display(display.HTML(svg))

In [None]:
cache_disk("add_plot")(lambda: plot_add_simulation(n_s=[50, 100, 150], noise_s=[1e-3, 1e-4, 1e-5], bs=100, max_ts=[1e5]))()

#### Numpy posdef

In [None]:
import numpy as np

In [None]:
arr = np.random.rand(2,3,3)

In [None]:
arr.shape

In [None]:
arr.transpose(0,2,1) == np.moveaxis(arr, -1, -2)

In [None]:
def to_posdef_np(x, noise=1e-5):
    """Numpy version of the transformation `x @ x^T + noise * I`.

    `x` is a (batch of) square matrix; the transpose is taken over the last two axes.
    The identity uses the dtype of `x`, so the function does not depend on any global array.
    """
    return x @ np.moveaxis(x, -1, -2) + (noise * np.eye(x.shape[-1], dtype=x.dtype))

In [None]:
to_posdef_np(arr)

In [None]:
# fuzzer
def fuzz_posdef_np(n=100, noise=1e-5, bs=10, range=(0,1), dtype=np.float32):
    "Numpy counterpart of `fuzz_posdef`: rate of random matrices from `to_posdef_np` that pass each check"
    low, high = range[0], range[1]
    A = np.random.rand(bs,n,n).astype(dtype) * (high - low) + low
    # checks are implemented for tensors, so convert the numpy result
    posdef = torch.from_numpy(to_posdef_np(A, noise))
    rates = {
        'posdef': is_posdef(posdef).sum().item() / bs,
        'sym': is_symmetric(posdef).sum().item() / bs,
        'posdef_eigv': is_posdef_eigv(posdef)[0].sum().item() / bs,
    }
    return pd.DataFrame(
        {'n': [n], 'noise': f"{noise:.0e}", 'range': str(range), 'n_samples': bs, **rates})

In [None]:
fuzz_posdef_np(n=1000, dtype=np.float32)

### Checker Positive Definite

This helps to find matrices that aren't positive definite and to debug the issues.
Returns a detailed dataframe row with info about the matrix and optionally logs everything to a global object

In [None]:
#| export
from warnings import warn
from fastcore.basics import store_attr

In [None]:
#| export
class CheckPosDef():
    """Debug helper that collects diagnostics about matrices that should be positive definite.

    Every checked matrix produces one DataFrame row (eigenvalue check, Cholesky check, symmetry,
    eigenvalues, the matrix itself and any extra argument). Rows are optionally accumulated in `self.log`.
    """
    def __init__(self,
                do_check:bool = False, # set to True to actually check matrix
                use_log:bool = True, # keep internal log
                warning:bool = True, # show a warning if a matrix is not pos def 
                ):
        store_attr()
        self.log = pd.DataFrame()
        self.extra_args = {}
    def add_args(self, **kwargs):
        """Add an extra argument to the next call of check_posdef """
        # arguments that are already stored take precedence over the new ones
        self.extra_args = {**kwargs, **self.extra_args}
        return self
    
    def check(self,
              x: Tensor, # (batch of) square matrix
              **extra_args
             ) -> pd.DataFrame: # one row per matrix, `None` when `do_check` is False
        """Check `x` and return the diagnostics; extra arguments become additional columns."""
        
        if not self.do_check: return
        
        self.add_args(**extra_args)
        
        try:
            x = x if x.dim() > 2 else [x]
            infos = pd.concat([*map(self._check_matrix, x, range(len(x)))])
            
            if self.use_log: self.log = pd.concat([self.log, infos])
            # `not` instead of `~`: on a Python bool `~` is a bitwise not and is always truthy
            all_pd = infos['is_pd_eigv'].all() and infos['is_pd_chol'].all()
            if self.warning and not all_pd:
                warn("Matrix is not positive definite")
        finally:
            # extra arguments are valid for one call only, even if the check fails
            self.extra_args = {} 
        return infos
    
    def _check_matrix(self,
                     x: Tensor, # square matrix
                     batch_n = 0,
                    ) -> pd.DataFrame:
        """Diagnostics for one square matrix, as a one-row DataFrame."""
        
        x = x.detach().clone() # ensure that there is a copy
        sym_upto = symmetric_upto(x)

        is_pd_eigv, eigv = is_posdef_eigv(x)
        is_pd_chol = torch.linalg.cholesky_ex(x).info.eq(0).all().item() # skip pytorch too strict symmetry check
        is_sym = is_symmetric(x).item()

        info = pd.DataFrame({
            'is_pd_eigv': is_pd_eigv.item(),
            'is_pd_chol': is_pd_chol,
            'is_sym': is_sym,
            'sym_upto': sym_upto,
            'eigv': [eigv.cpu().numpy()],
            'matrix': [x.cpu().numpy()],
            'batch_n': batch_n,
            **self.extra_args
        })

        return info

In [None]:
CheckPosDef(True).check(A)

In [None]:
CheckPosDef(True).check(A[0])

In [None]:
checker = CheckPosDef(True)

checker.check(A, my_arg="my arg") # this will be another col in the log

In [None]:
checker.log

In [None]:
checker.add_args(show="only once")
checker.check(posdef)
checker.check(A)
checker.log

In [None]:
B = torch.rand(2,3,3) # a batch of matrices

In [None]:
is_symmetric(B).shape

In [None]:
checker.check(B)

In [None]:
test_close(B[0] @ A, (B @ A)[0]) # example batched matrix multiplication

### Diagonal Positive Definite Constraint

This is a simpler constraint that makes the matrix diagonal and positive definite, by forcing it to have positive numbers on the diagonal.

Given a vector $a$, it is transformed into a diagonal positive definite matrix using:

$A_{diag\ pos\ def} = a^2 I$

the inverse transformation is the square root of the diagonal

In [None]:
from meteo_imp.utils import *

In [None]:
#| export
@docs
class DiagPosDef():
    # Method docstrings are attached by `@docs` from the `_docs` dict below.
    def transform(self,raw):
        # squaring the raw values gives the diagonal entries
        squared = raw.pow(2)
        return torch.diag_embed(squared, dim1=-2, dim2=-1)

    def inverse_transform(self, value):
        if not is_diagonal(value):
            warn("Only diagonal of parameter considered")
        diag = torch.diagonal(value, dim1=-2, dim2=-1)
        return torch.sqrt(diag)

    _docs = {'cls_doc': "Diagonal Positive Definite Constraint for PyTorch parameters",
             'transform':"transform any vector into a diagonal positive definite matrix",
             'inverse_transform': "tranform diagonal positive definite matrix into a vector that can be back transformed using `transform`"}

def to_diagposdef(x):
    "Apply `DiagPosDef().transform` to the diagonal of (each matrix in) `x`"
    diag = torch.diagonal(x, dim1=-2, dim2=-1)
    return DiagPosDef().transform(diag)

In [None]:
DiagPosDef().transform(torch.rand(3))

tensor([[0.0008, 0.0000, 0.0000],
        [0.0000, 0.5319, 0.0000],
        [0.0000, 0.0000, 0.4327]])

In [None]:
DiagPosDef().transform(torch.rand(2, 3))

tensor([[[0.2185, 0.0000, 0.0000],
         [0.0000, 0.0034, 0.0000],
         [0.0000, 0.0000, 0.3096]],

        [[0.8915, 0.0000, 0.0000],
         [0.0000, 0.0632, 0.0000],
         [0.0000, 0.0000, 0.7199]]])

In [None]:
v = -1.2 * torch.ones(2,3)

In [None]:
DiagPosDef().inverse_transform(DiagPosDef().transform(v))

tensor([[-1.2000, -1.2000, -1.2000],
        [-1.2000, -1.2000, -1.2000]])

In [None]:
to_diagposdef(torch.rand(3,3))

tensor([[2.3585, 0.0000, 0.0000],
        [0.0000, 1.0629, 0.0000],
        [0.0000, 0.0000, 2.3738]])

In [None]:
dpd_const = DiagPosDef()
a = torch.rand(3)

In [None]:
dpd_const.transform(a)

tensor([[1.1899, 0.0000, 0.0000],
        [0.0000, 2.0571, 0.0000],
        [0.0000, 0.0000, 1.4777]])

In [None]:
test_close(a, dpd_const.inverse_transform(dpd_const.transform(a)))

## Conditional Predictions

Therefore we need to compute the conditional distribution of a normal ^[https://cs.nyu.edu/~roweis/notes/gaussid.pdf eq, 5a, 5d]

$$ X = \left[\begin{array}{c} x \\ o \end{array} \right] $$

$$ p(X) = N\left(\left[ \begin{array}{c} \mu_x \\ \mu_o \end{array} \right], \left[\begin{array}{cc} \Sigma_{xx} & \Sigma_{xo} \\ \Sigma_{ox} & \Sigma_{oo} \end{array} \right]\right)$$

where $x$ is a vector of the variables that need to be predicted and $o$ is a vector of the variables that have been observed


then the conditional distribution is:

$$p(x|o) = N(\mu_x + \Sigma_{xo}\Sigma_{oo}^{-1}(o - \mu_o), \Sigma_{xx} - \Sigma_{xo}\Sigma_{oo}^{-1}\Sigma_{ox})$$

In [None]:
#| export
import torch
from torch.distributions import MultivariateNormal
from torch.linalg import cholesky
from torch import cholesky_inverse
from torch import Tensor

from fastcore.test import *
from meteo_imp.utils import *
from typing import List

In [None]:
#| export
def conditional_guassian(
                         μ: Tensor, # mean with shape `[n_vars]`
                         Σ: Tensor, # cov with shape `[n_vars, n_vars] `
                         obs: Tensor, # Observations with shape `[n_obs]`, where `n_obs = sum(mask)`
                         mask: Tensor # Boolean tensor specifying for each variable is observed (True) or not (False). Shape `[n_vars]`
                        ) -> ListMNormal: # Distribution conditioned on observations. shape `[n_vars - n_obs]`
    "Distribution of the non-observed variables of a multivariate normal, given the observed ones"
    assert μ.shape[0] == mask.shape[0]
    assert obs.shape[0] == mask.sum()

    missing = ~mask
    μ_x, μ_o = μ[missing], μ[mask]
    # rows and columns are selected in two steps (`[rows,:][:, cols]`) to keep the dimensionality even for empty tensors
    Σ_xx = Σ[missing, :][:, missing]
    Σ_xo = Σ[missing, :][:, mask]
    Σ_ox = Σ[mask, :][:, missing]
    Σ_oo = Σ[mask, :][:, mask]

    # explicit inverse; the alternative noted here was cholesky_inverse(cholesky(Σ_oo))
    Σ_oo_inv = torch.inverse(Σ_oo)
    # Σ_xo Σ_oo⁻¹ appears in both the conditional mean and the conditional covariance
    gain = Σ_xo @ Σ_oo_inv

    mean = μ_x + gain @ (obs - μ_o)
    cov = Σ_xx - gain @ Σ_ox

    return ListMNormal(mean, cov)
    

In [None]:
# example distribution with only 2 variables
μ = torch.tensor([.5, 1.])
Σ = torch.tensor([[1., .5], [.5 ,1.]])


mask = torch.tensor([True, False]) # second variable is the observed one

obs = torch.tensor([5.]) # value of second variable

gauss_cond = conditional_guassian(μ, Σ, obs, mask)

# hardcoded values to test that the code is working, see also for alternative implementation https://python.quantecon.org/multivariate_normal.html
test_close(3.25, gauss_cond.mean.item())
test_close(.75, gauss_cond.cov.item())

### Batches

Proper batch support is not possible, or at least not in a straightforward way, as the shape of the output would be different for the different elements of the batch.

So a for-loop is used as a temporary fix

In [None]:
#| export
def cond_gaussian_batched(dist: ListMNormal,
                         obs, # this needs to have the same shape of the mask !!! 
                         mask
                         ) -> List[ListMNormal]: # lists of distributions for element in the batch
    "Apply `conditional_guassian` to every element of the batch, one at a time"
    conditioned = []
    for i in range(obs.shape[0]):
        mask_i = mask[i]
        # only the observed entries of `obs[i]` are passed on
        conditioned.append(conditional_guassian(dist.mean[i], dist.cov[i], obs[i][mask_i], mask_i))
    return conditioned
        

In [None]:
reset_seed(10)
mean = torch.rand(2,3) # batch
cov = to_posdef(torch.rand(2,3,3))
mask = torch.rand(2,3) > .3
obs = torch.rand(2,3)

In [None]:
# the batched function is `cond_gaussian_batched` and takes the distribution as a `ListMNormal`
cond_gaussian_batched(ListMNormal(mean, cov), obs, mask)

In [None]:
mask.shape, obs.shape

In [None]:
assert mean.shape == mask.shape
assert mean.dim() == 2

In [None]:
obs.shape

In [None]:
mean_x = mean[~mask]
mean_o = mean[mask]

In [None]:
mask

In [None]:
mean_x

In [None]:
cov.shape

In [None]:
cov[~mask]

In [None]:
cov

In [None]:
cov[0][~mask[0], ~mask[0]]

In [None]:
cov[0][mask[0],:][:, mask[0]].shape

### Performance

analysis of the performance of inverting a positive definite matrix

Use `cholesky` decomposition and `cholesky_solve` to improve performance of matrix inversion

see the [Probabilistic machine learning course from Uni Tübingen](https://uni-tuebingen.de/en/180804), specifically the code from the [Gaussian Regression Notebook](https://uni-tuebingen.de/fileadmin/Uni_Tuebingen/Fakultaeten/MatNat/Fachbereiche/Informatik/Lehrstuehle/MethMaschLern/Probabilistic_ML/Notebook_Vorlesung_7___9/Gaussian_Linear_Regression.ipynb) for details

This is the direct implementation of the equations

In [None]:
def _conditional_guassian_base(
                         μ: Tensor, # mean with shape `[n_vars]`
                         Σ: Tensor, # cov with shape `[n_vars, n_vars] `
                         obs: Tensor, # Observed values, one for each True entry of `idx`
                         idx: Tensor # Boolean tensor specifying for each variable is observed (True) or not (False). Shape `[n_vars]`
                        ) -> ListMNormal: # Distribution conditioned on observations
    "Reference implementation of the conditional distribution, written directly from the equations"
    μ_x = μ[~idx]
    μ_o = μ[idx]
    
    Σ_xx = Σ[~idx,:][:, ~idx]
    Σ_xo = Σ[~idx,:][:, idx]
    Σ_ox = Σ[idx,:][:, ~idx]
    Σ_oo = Σ[idx,:][:, idx]
    
    Σ_oo_inv = torch.linalg.inv(Σ_oo)
    
    mean = μ_x + Σ_xo@Σ_oo_inv@(obs - μ_o)
    cov = Σ_xx - Σ_xo@Σ_oo_inv@Σ_ox
    
    # the second element is a covariance matrix, so use the (mean, cov) container
    # like `conditional_guassian` instead of the (mean, std) one
    return ListMNormal(mean, cov)
    

 faster version

In [None]:
n_var = 5
mean = torch.rand(n_var, dtype=torch.float64)
cov = to_posdef(torch.rand(n_var, n_var, dtype=torch.float64))
dist = MultivariateNormal(mean, cov)
idx = torch.rand(n_var, dtype=torch.float64) > .5
obs = torch.rand(n_var, dtype=torch.float64)[idx]

In [None]:
torch.linalg.inv(cov) 

In [None]:
(torch.linalg.inv(cov) - cholesky_inverse(torch.linalg.cholesky(cov))).max()

In [None]:
test_close(torch.linalg.inv(cov), cholesky_inverse(torch.linalg.cholesky(cov)), eps=1e-2)

In [None]:
reset_seed()
A = to_posdef(torch.rand(1000, 1000, dtype=torch.float64)) + torch.eye(1000) * 1e-3 # noise to ensure is positive definite

In [None]:
is_symmetric(A)

In [None]:
is_posdef(A)

In [None]:
%timeit torch.linalg.inv(A)

In [None]:
%timeit cholesky_inverse(torch.linalg.cholesky(A))

The second version is way faster

In [None]:
test_close(conditional_guassian(mean, cov, obs, idx).mean, _conditional_guassian_base(mean, cov, obs, idx).mean)

In [None]:
B = to_posdef(torch.rand(n_var, n_var, dtype=torch.float64))

In [None]:
B @ torch.inverse(cov)

In [None]:
# `cholesky_solve(b, L)` solves `cov @ X = b`, so `B @ inv(cov)` is the transpose of the solution for `b = B^T`
torch.cholesky_solve(B.mT, cholesky(cov)).mT

## Helper

### cov2std

In [None]:
x = torch.stack([torch.eye(3)*i for i in  range(1,4)])

In [None]:
x

In [None]:
torch.diagonal(x, dim1=1, dim2=2)

In [None]:
#| export
def cov2std(x):
    "standard deviations from (a batch of) covariance matrices: square root of the diagonals"
    variances = x.diagonal(dim1=-2, dim2=-1)
    return variances.sqrt()

## Export

In [None]:
#| hide
from nbdev import nbdev_export
nbdev_export()