# Bayesian Inference Modelling
## Goal: estimation of both evidence and posterior in one go.

# 1. WSABI model
- We use the WSABI-M model in the following paper.
- Gunter T, Osborne MA, Garnett R, Hennig P, Roberts SJ. Sampling for inference in probabilistic models with fast Bayesian quadrature. Advances in neural information processing systems 27 (2014).
- You can also easily try WSABI-L by changing label to "wsabil"

In [1]:
import torch
import time
from BASQ._wsabi import WsabiGP
from gpytorch.kernels import ScaleKernel, RBFKernel

def set_and_opt_gp(X, Y):
    kernel = ScaleKernel(RBFKernel())
    model = WsabiGP(X, Y, kernel, tm.device, alpha_factor=1, label="wsabim", optimiser="BoTorch")
    return model

In [2]:
from SOBER._utils import TensorManager
from SOBER.BASQ._basq import BASQ
from BASQ.experiment.gmm import GMM
tm = TensorManager()

num_dim = 10  # Number of dimensions of the true likelihood to be estimated
mu_pi = torch.zeros(num_dim).to(tm.device, tm.dtype)  # the mean vactor of Gaussian prior
cov_pi = 2 * torch.eye(num_dim).to(tm.device, tm.dtype)  # the covariance matrix of Gaussian prior

from SOBER._prior import Gaussian
prior = Gaussian(mu_pi, cov_pi)
true_likelihood = GMM(num_dim, mu_pi, cov_pi, tm.device)  # true likelihood to be estimated

In [3]:
n_init = 2           # number of initial guess
n_iterations = 10    # number of iterations (batches)
n_cand = 20000       # number of candidates
n_nys = 500          # number of Nyström samples
n_batch = 100        # batch size

# Run!

In [4]:
Z_true = 1                             # true integral
x_test = prior.sample(10000)           # test data for evaluating posterior using KL divergence

torch.manual_seed(0)                    # fix random seed for reproducibility
X = prior.sample(n_init)               # inital dataset X
Y = true_likelihood(X).to(tm.dtype)    # initial guess Y

# CAUTION! You need to specify the model is warped GP by making it True.
basq = BASQ(n_cand, n_nys, prior, warped_gp=True)      # set up BASQ instance
model = set_and_opt_gp(X, Y)           # set up the GP surroage model

for ith_round in range(n_iterations):
    tik = time.monotonic()
    X_batch, _ = basq.batch_uncertainty_sampling(model, n_batch)  # run BASQ algorithm to select 100 batch points
    tok = time.monotonic()
    overhead = tok - tik               # overhead of batch query
    
    Y_batch = true_likelihood(X_batch) # parallel query to true likelihood function
    X = torch.cat([X, X_batch])        # concatenate the observations for X
    Y = torch.cat([Y, Y_batch])        # concatenate the observations for X
    
    # Evaluation for integral
    model = set_and_opt_gp(X, Y)       # retrain GP model
    integral_estimated = basq.quadrature(model, 500)  # integral estimation
    logMAE = (Z_true - integral_estimated).abs()      # evaluate the estimated integral value to true one
    # EZ, VZ = basq.full_quadrature(model, 500)       # You can estimate integral variance (but takes more time)
    
    # Evaluation for the posterior
    KL = basq.KLdivergence(Z_true, x_test, true_likelihood, model)  # compute the KL divergence
    print('Iter %d - overhead: %.3f [s]  logMAE of Integral: %.3f   logKL of posterior: %.3f' % (
        ith_round, overhead, logMAE.log().item(), KL.log().item()
    ))

Iter 0 - overhead: 0.997 [s]  logMAE of Integral: -1.912   logKL of posterior: -7.398
Iter 1 - overhead: 1.132 [s]  logMAE of Integral: -0.640   logKL of posterior: -8.998
Iter 2 - overhead: 1.725 [s]  logMAE of Integral: -1.355   logKL of posterior: -8.720
Iter 3 - overhead: 1.823 [s]  logMAE of Integral: -1.692   logKL of posterior: -9.039
Iter 4 - overhead: 1.769 [s]  logMAE of Integral: -2.508   logKL of posterior: -8.729
Iter 5 - overhead: 2.059 [s]  logMAE of Integral: -1.663   logKL of posterior: -10.367
Iter 6 - overhead: 2.272 [s]  logMAE of Integral: -1.075   logKL of posterior: -9.662
Iter 7 - overhead: 2.597 [s]  logMAE of Integral: -1.485   logKL of posterior: -11.935
Iter 8 - overhead: 2.555 [s]  logMAE of Integral: -1.185   logKL of posterior: -9.623
Iter 9 - overhead: 2.886 [s]  logMAE of Integral: -1.359   logKL of posterior: -10.310


# 2. MMLT model
- We use the MMLT model in the following papers.
- This allows **true "log" likelihood**. This is more natural for most Bayesian inference.
- Thus, the observed values Y should be in log space
- MMLT is good if likelihood takes very large dynamic range. Otherwise, non-warped GP works better.

- Chai HR, Garnett R. Improving quadrature for constrained integrands. InThe 22nd International Conference on Artificial Intelligence and Statistics 2019 Apr 11 (pp. 2751-2759). PMLR.

In [5]:
import torch
import time
from SOBER.BASQ._scale_mmlt import ScaleMmltGP
from gpytorch.kernels import ScaleKernel, RBFKernel

def set_and_opt_gp(X, Y):
    kernel = ScaleKernel(RBFKernel())
    model = ScaleMmltGP(X, Y, kernel, tm.device, optimiser="BoTorch")
    return model

In [6]:
from SOBER._utils import TensorManager
from SOBER.BASQ._basq import BASQ
from BASQ.experiment.gmm import GMM
tm = TensorManager()

num_dim = 10  # Number of dimensions of the true likelihood to be estimated
mu_pi = torch.zeros(num_dim).to(tm.device, tm.dtype)  # the mean vactor of Gaussian prior
cov_pi = 2 * torch.eye(num_dim).to(tm.device, tm.dtype)  # the covariance matrix of Gaussian prior

from SOBER._prior import Gaussian
prior = Gaussian(mu_pi, cov_pi)
true_likelihood = GMM(num_dim, mu_pi, cov_pi, tm.device)  # true likelihood to be estimated

In [7]:
n_init = 2           # number of initial guess
n_iterations = 10    # number of iterations (batches)
n_cand = 20000       # number of candidates
n_nys = 500          # number of Nyström samples
n_batch = 100        # batch size

# Run!

In [8]:
Z_true = 1                             # true integral
x_test = prior.sample(10000)           # test data for evaluating posterior using KL divergence

torch.manual_seed(0)                    # fix random seed for reproducibility
X = prior.sample(n_init)               # inital dataset X
# CAUTION! You need to give "log" likelihood
Y = true_likelihood(X).log().to(tm.dtype)    # initial guess Y

# CAUTION! You need to specify the model is warped GP by making it True.
basq = BASQ(n_cand, n_nys, prior, warped_gp=True)      # set up BASQ instance
model = set_and_opt_gp(X, Y)           # set up the GP surroage model

for ith_round in range(n_iterations):
    tik = time.monotonic()
    X_batch, _ = basq.batch_uncertainty_sampling(model, n_batch)  # run BASQ algorithm to select 100 batch points
    tok = time.monotonic()
    overhead = tok - tik               # overhead of batch query
    
    Y_batch = true_likelihood(X_batch).log().to(tm.dtype) # parallel query to true likelihood function
    X = torch.cat([X, X_batch])        # concatenate the observations for X
    Y = torch.cat([Y, Y_batch])        # concatenate the observations for X
    
    # Evaluation for integral
    model = set_and_opt_gp(X, Y)       # retrain GP model
    integral_estimated = basq.quadrature(model, 500)  # integral estimation
    logMAE = (Z_true - integral_estimated).abs()      # evaluate the estimated integral value to true one
    # EZ, VZ = basq.full_quadrature(model, 500)       # You can estimate integral variance (but takes more time)
    
    # Evaluation for the posterior
    KL = basq.KLdivergence(Z_true, x_test, true_likelihood, model)  # compute the KL divergence
    print('Iter %d - overhead: %.3f [s]  logMAE of Integral: %.3f   logKL of posterior: %.3f' % (
        ith_round, overhead, logMAE.log().item(), KL.log().item()
    ))

Iter 0 - overhead: 1.181 [s]  logMAE of Integral: -0.202   logKL of posterior: -8.846
Iter 1 - overhead: 1.441 [s]  logMAE of Integral: -0.080   logKL of posterior: -8.546
Iter 2 - overhead: 1.276 [s]  logMAE of Integral: -0.046   logKL of posterior: -8.012
Iter 3 - overhead: 1.386 [s]  logMAE of Integral: -0.040   logKL of posterior: -7.971
Iter 4 - overhead: 1.792 [s]  logMAE of Integral: -0.046   logKL of posterior: -8.816
Iter 5 - overhead: 1.521 [s]  logMAE of Integral: -0.044   logKL of posterior: -8.584
Iter 6 - overhead: 1.697 [s]  logMAE of Integral: -0.047   logKL of posterior: -9.621
Iter 7 - overhead: 1.861 [s]  logMAE of Integral: -0.047   logKL of posterior: -9.490
Iter 8 - overhead: 2.386 [s]  logMAE of Integral: -0.022   logKL of posterior: -8.923
Iter 9 - overhead: 2.747 [s]  logMAE of Integral: -0.021   logKL of posterior: -8.458
