https://medium.com/@tinonucera/building-quadratic-approximation-in-bayesian-inference-from-scratch-step-by-step-example-f17167cbe9eb

https://bjlkeng.io/posts/normal-approximations-to-the-posterior-distribution/

https://rdrr.io/github/rmcelreath/rethinking/man/quap.html

https://www2.stat.duke.edu/~st118/sta250/laplace.pdf

https://bookdown.org/rdpeng/advstatcomp/laplace-approximation.html  (see example 5.1.1.1)

https://www.uio.no/studier/emner/matnat/math/STK4021/h23/stk4021_chapters_1_to_5_new.pdf

https://medium.com/@pallavisinha12/create-python-package-automate-publishing-with-github-actions-a-quick-guide-35b82aa4684c

In [1]:
import pymc as pm
import numpy as np
import scipy.stats as st
import arviz as az
import xarray as xr
from pymc.step_methods.arraystep import ArrayStep
from pymc.util import get_value_vars_from_user_vars

In [2]:
# Fixed seed for the notebook-wide NumPy generator below.
RANDOM_SEED = 3137
# Shared Generator; `quap` draws its posterior samples from it.
rng = np.random.default_rng(seed=RANDOM_SEED)

In [3]:
def quap(vars, start=None, draws=1_000, chains=1):
    """Quadratic (Laplace) approximation of the posterior around the MAP.

    Must be called inside a ``with pm.Model():`` block: the model is taken
    from the ambient context via ``pm.modelcontext(None)``.

    Parameters
    ----------
    vars : list
        Model random variables to approximate. Their order fixes the order
        of the entries in the mean vector and covariance matrix.
    start : dict, optional
        Starting point forwarded to ``pm.find_MAP``.
    draws : int
        Number of samples drawn per chain from the Gaussian approximation.
    chains : int
        Number of (independent) pseudo-chains to draw.

    Returns
    -------
    idata
        Result of ``az.convert_to_inference_data`` on the drawn samples, one
        variable per entry of ``vars`` with dims ``("chain", "draw", ...)``.
    posterior : scipy.stats multivariate normal (frozen)
        The Gaussian approximation itself (mean = MAP, cov = inverse Hessian).

    Notes
    -----
    * Side effect: for every variable in ``vars`` that has a transform, the
      model's ``rvs_to_transforms`` entry is set to ``None`` and its value
      variable is renamed to the RV's name. The model stays modified after
      this call.
    * Samples are drawn from the module-level ``rng`` generator.
    """
    # `map_point` rather than `map`, so the builtin is not shadowed.
    map_point = pm.find_MAP(vars=vars, start=start)

    model = pm.modelcontext(None)

    # Drop the transforms *after* the MAP search so the Hessian below is
    # taken with respect to the untransformed variables.
    for var in vars:
        if model.rvs_to_transforms[var] is not None:
            model.rvs_to_transforms[var] = None
            # Rename the value variable so `map_point[var.name]` addresses it.
            value_var = model.rvs_to_values[var]
            value_var.name = var.name

    hessian = pm.find_hessian(map_point, vars=vars)
    cov = np.linalg.inv(hessian)

    # Flatten every MAP value; a scalar becomes a length-1 vector.
    # NOTE(review): assumes find_hessian flattens variables in the same
    # (C) order as ndarray.ravel -- confirm for variables with ndim >= 2.
    map_values = [np.asarray(map_point[var.name]) for var in vars]
    mean = np.concatenate([value.ravel() for value in map_values])
    posterior = st.multivariate_normal(mean=mean, cov=cov)

    samples = rng.multivariate_normal(mean, cov, size=(chains, draws))

    # Each variable owns `value.size` consecutive columns of `samples`.
    # Indexing by the variable's position (one column each) is only correct
    # when every variable is scalar, so slice by running offset instead and
    # restore the variable's own shape.
    data_vars = {}
    offset = 0
    for var, value in zip(vars, map_values):
        name = str(var)
        columns = samples[:, :, offset : offset + value.size]
        extra_dims = tuple(f"{name}_dim_{axis}" for axis in range(value.ndim))
        data_vars[name] = xr.DataArray(
            columns.reshape((chains, draws) + value.shape),
            dims=("chain", "draw") + extra_dims,
        )
        offset += value.size

    coords = {"chain": np.arange(chains), "draw": np.arange(draws)}
    ds = xr.Dataset(data_vars, coords=coords)

    idata = az.convert_to_inference_data(ds)

    return idata, posterior

In [4]:
# Observed data: the three measurements repeated ten times (30 values).
y = np.tile([2642, 3503, 4358], 10)

In [5]:
with pm.Model() as m:
    # Flat priors on the log standard deviation and on the mean.
    logsigma = pm.Uniform("logsigma", lower=1, upper=100)
    mu = pm.Uniform("mu", lower=-10000, upper=10000)
    # Normal likelihood; sigma is recovered by exponentiating logsigma.
    yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
    # Quadratic approximation; must run inside the model context.
    idata, posterior = quap([mu, logsigma])

Output()

  return f(*args, **kwargs)


In [6]:
class QuadraticApproximation(ArrayStep):
    """Step method that draws from a Gaussian approximation of the posterior.

    At construction time the MAP point and the Hessian at that point are
    computed once; every subsequent step ignores the current position and
    returns an independent draw from ``Normal(mode, inverse Hessian)``.

    Parameters
    ----------
    vars : list
        Model random variables to sample; their order defines the order of
        the entries of ``self.mode`` / ``self.covariance``.
    model : pm.Model
        The model the variables belong to. It is used for the MAP search and
        the Hessian, so the step can be built outside a ``with`` block.
    **kwargs
        Forwarded to ``ArrayStep.__init__``.

    Notes
    -----
    Side effect: transforms of ``vars`` are removed from ``model``
    (``rvs_to_transforms[var] = None``) and the matching value variables are
    renamed to the RV names. The model stays modified afterwards.
    """

    def __init__(self, vars, model, **kwargs):
        self.model = model
        self.vars = vars
        self.varnames = [var.name for var in vars]

        # Must run before the value variables are looked up below, because it
        # strips the transforms and renames the value variables.
        self.mode, self.covariance = self._compute_mode_and_covariance()

        vars = get_value_vars_from_user_vars(vars, model)

        # Register the value variables and the (unused by astep) logp function.
        super().__init__(vars, [self._logp_fn], **kwargs)

    def _point_to_array(self, point):
        """Stack the entries of ``point`` named in ``self.varnames`` into an array."""
        return np.array([point[varname] for varname in self.varnames])

    def _array_to_point(self, array):
        """Pair ``self.varnames`` with the entries of ``array`` as a dict."""
        return {varname: val for varname, val in zip(self.varnames, array)}

    def _logp_fn(self, x):
        """Function handed to ``ArrayStep``; its value is never used by ``astep``.

        NOTE(review): whether ``model.logp`` accepts a point dict depends on
        the PyMC version -- confirm before relying on this function.
        """
        point = self._array_to_point(x)
        return self.model.logp(point)

    def _compute_mode_and_covariance(self):
        """Return ``(mode, covariance)`` of the Gaussian approximation.

        ``mode`` is the flattened MAP estimate and ``covariance`` the inverse
        of the Hessian evaluated at the MAP point.
        """
        # Use the model given to the constructor rather than whatever model
        # happens to be on the context stack.
        model = pm.modelcontext(self.model)

        # `map_point` rather than `map`, so the builtin is not shadowed.
        map_point = pm.find_MAP(vars=self.vars, model=model)

        # Drop the transforms after the MAP search so the Hessian is taken
        # with respect to the untransformed variables.
        for var in self.vars:
            if model.rvs_to_transforms[var] is not None:
                model.rvs_to_transforms[var] = None
                # change name so that we can use `map_point[var.name]` value
                value_var = model.rvs_to_values[var]
                value_var.name = var.name

        hessian = pm.find_hessian(map_point, vars=self.vars, model=model)
        cov = np.linalg.inv(hessian)
        mean = np.concatenate([np.atleast_1d(map_point[v.name]) for v in self.vars])

        return mean, cov

    def astep(self, q0, logp):
        """Return one draw from the Gaussian approximation and empty stats.

        ``q0`` (current position) and ``logp`` are ignored: draws are
        independent of the previous state.
        """
        # Prefer a per-step Generator when the base class provides one, so
        # seeding through pm.sample is honoured; otherwise fall back to the
        # global NumPy RNG as before.
        # NOTE(review): `self.rng` is assumed to exist only on newer PyMC
        # releases -- confirm for the installed version.
        generator = getattr(self, "rng", None)
        if generator is None:
            generator = np.random
        sample = generator.multivariate_normal(self.mode, self.covariance)
        return sample, []

# Example usage:


In [7]:
with pm.Model() as model:
    # Define a simple Gaussian model
    logsigma = pm.Uniform("logsigma", 1, 100)
    mu = pm.Uniform("mu", -10000, 10000)
    yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)

    # Instantiate and use the custom sampler
    custom_step = QuadraticApproximation(vars=[mu, logsigma], model=model)
    # Pass the notebook seed so the run is repeatable.
    # NOTE(review): how the seed reaches a custom step's RNG depends on the
    # PyMC version -- confirm that re-running reproduces the summary below.
    trace = pm.sample(1000, step=custom_step, random_seed=RANDOM_SEED)
    

Output()

  return f(*args, **kwargs)
Multiprocess sampling (4 chains in 4 jobs)
QuadraticApproximation: [mu, logsigma]


Output()

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 11 seconds.


In [10]:
# Display the object returned by pm.sample with the custom step.
trace

In [8]:
# Summary statistics (mean, sd, 89% HDI) of the custom-step trace, rounded to 2 decimals.
az.summary(trace, kind="stats", hdi_prob=0.89).round(2)

Unnamed: 0,mean,sd,hdi_5.5%,hdi_94.5%
logsigma,6.55,0.13,6.34,6.74
mu,3503.01,126.0,3295.06,3702.32


In [9]:
# Same summary for the samples returned by quap(), for comparison with the custom-step trace.
az.summary(idata, kind="stats", hdi_prob=0.89).round(2)

Unnamed: 0,mean,sd,hdi_5.5%,hdi_94.5%
mu,3503.71,128.76,3320.47,3724.41
logsigma,6.55,0.13,6.37,6.76
