In [1]:
import sys
import os

module_path = os.path.abspath(os.path.join("../.."))
if module_path not in sys.path:
    sys.path.append(module_path)

from pvi.distributions.base import ExponentialFamilyDistribution
from pvi.distributions.exponential_family_distributions import *
from pvi.distributions.exponential_family_factors import *

import torch

  from .autonotebook import tqdm as notebook_tqdm


# Test t-factors

## Mean-field Gaussian Factor

In [2]:
N = 4
D = 2

mean = torch.zeros(size=(D,))
prec = torch.ones(size=(D,))
natural_parameters = {"np1" : mean / prec,
                      "np2" : -0.5 * prec}

thetas = torch.ones(size=(N, D))

mean_field_gaussian = MeanFieldGaussianFactor(natural_parameters)

In [3]:
mean_field_gaussian(thetas)

tensor([-1., -1., -1., -1.])

In [4]:
dist1 = torch.distributions.Normal(loc=mean, scale=prec ** -0.5)
dist2 = torch.distributions.Normal(loc=mean, scale=(2. * prec) ** -0.5)
print(0.5 * (dist1.scale ** -2 - dist2.scale ** -2 + prec))
mean_field_gaussian.compute_refined_factor(dist1, dist2).natural_parameters

tensor([-1.1921e-07, -1.1921e-07])


AttributeError: 'Normal' object has no attribute 'nat_params'

## Multivariate Gaussian factor

In [5]:
N = 4
D = 2

log_coefficient = 0.
mean = torch.zeros(size=(D,))
prec = torch.eye(D)
natural_parameters = {"np1" : torch.solve(mean[:, None], prec).solution[:, 0],
                      "np2" : -0.5 * prec}

thetas = torch.ones(size=(N, D))

multivariate_gaussian = MultivariateGaussianFactor(natural_parameters=natural_parameters)

torch.linalg.solve has its arguments reversed and does not return the LU factorization.
To get the LU factorization see torch.lu, which can be used with torch.lu_solve or torch.lu_unpack.
X = torch.solve(B, A).solution
should be replaced with
X = torch.linalg.solve(A, B) (Triggered internally at  ../aten/src/ATen/native/BatchLinearAlgebra.cpp:859.)
  import sys


TypeError: __init__() got an unexpected keyword argument 'natural_parameters'

In [6]:
multivariate_gaussian(thetas)

NameError: name 'multivariate_gaussian' is not defined

In [None]:
N = 4
D = 2

mean = torch.zeros(size=(D,))
L = torch.tensor([[1., 0.],
                  [0., 1.]])
prec = torch.mm(L, L.T)
natural_parameters = {"np1" : torch.solve(mean[:, None], prec).solution[:, 0],
                      "np2" : -0.5 * prec}

thetas = torch.ones(size=(N, D))

multivariate_gaussian = MultivariateGaussianFactor(natural_parameters=natural_parameters)

multivariate_gaussian.distribution_from_np(multivariate_gaussian.natural_parameters).scale_tril

In [None]:
dist1 = torch.distributions.MultivariateNormal(loc=mean, covariance_matrix=torch.inverse(prec))
dist2 = torch.distributions.MultivariateNormal(loc=mean, covariance_matrix=torch.inverse(2 * prec))
print(-0.5 * (torch.inverse(dist1.covariance_matrix) - torch.inverse(dist2.covariance_matrix) + prec))
multivariate_gaussian.compute_refined_factor(dist1, dist2).natural_parameters

# Test distribution base classes

## Mean Field Gaussian

In [None]:
N = 4
D = 2

mean = torch.zeros(size=(D,))
scale = torch.ones(size=(D,))

sp = {
    "sp1" : mean,
    "sp2" : scale
    
}

mfgd = MeanFieldGaussianDistribution(std_params=sp,
                                     nat_params=None,
                                     is_trainable=True)

mfgd.rsample().sum().backward()
mfgd._unc_params["up2"].grad

# Multivariate Gaussian

In [None]:
N = 4
D = 2

mean = torch.zeros(size=(D,))
cov = torch.eye(D)

sp = {
    "sp1" : mean,
    "sp2" : cov
    
}

mvgd = MultivariateGaussianDistribution(std_params=sp,
                                        nat_params=None,
                                        is_trainable=True)

mvgd.rsample().sum().backward()

In [None]:
mean = torch.zeros(size=(D,))
cov = torch.eye(D)

d1 = torch.distributions.MultivariateNormal(loc=mean, covariance_matrix=cov)
d2 = torch.distributions.MultivariateNormal(loc=mean, covariance_matrix=cov)

torch.distributions.kl_divergence(d1, d2)

In [None]:
torch.distributions.Dirichlet(torch.tensor([0.5, 1e-6]))._natural_params

## Dirichlet distribution

In [None]:
D = 4

conc = torch.zeros(size=(D,)).uniform_()

sp = {"sp1" : conc}

dird = DirichletDistribution(std_params=sp,
                             nat_params=None,
                             is_trainable=False)

dird.kl_divergence(dird)

# Multinomial distribution

In [None]:
D = 10

p = torch.zeros(size=(D,)).uniform_()
p = p / p.sum()

sp = {
    "sp1" : N,
    "sp2" : p
}

muld = MultinomialDistribution(std_params=sp,
                               nat_params=None,
                               is_trainable=False)

muld.sample()