In [1]:
from collections import defaultdict
from scipy import stats
import numpy as np

In [2]:
class MCMCSampler:
    """Common state for the samplers below: target density, proposal scale and data.

    Parameters
    ----------
    log_pstar : callable
        Unnormalised log target, called by the subclasses as
        ``log_pstar(theta, data, **kwargs)``.
    covariance : array_like
        Covariance handed to ``stats.multivariate_normal`` in ``transition``.
    data : object
        Passed through unchanged to ``log_pstar``.
    **kwargs
        Extra keyword arguments forwarded to ``log_pstar``.
    """

    def __init__(self, log_pstar, covariance, data, **kwargs):
        self.log_pstar, self.covariance = log_pstar, covariance
        self.data, self.kwargs = data, kwargs

    def transition(self, theta):
        """Propose a new point from a Gaussian centred on ``theta`` with ``self.covariance``."""
        proposal = stats.multivariate_normal(theta, self.covariance)
        return proposal.rvs()

    def get_samples(self, burn_period = 0.2):
        """Return ``self.samples`` without its burn-in prefix.

        An ``int`` ``burn_period`` is the number of leading rows to drop. A
        ``float`` is a fraction of the chain length, converted with
        ``int(fraction * len(samples) + 1)``; for a non-negative fraction this
        drops ``floor(fraction * len) + 1`` rows. Requires ``fit`` to have
        populated ``self.samples``.
        """
        start = burn_period
        if isinstance(start, float):
            start = int(start * len(self.samples) + 1)
        return self.samples[start:]

def check_verbose(N, verbose):
    """Resolve the ``verbose`` option of the samplers' ``fit`` into a print interval.

    Parameters
    ----------
    N : int
        Total number of chain states (new samples + the initial point).
    verbose : 'auto', float, int, or a falsy value (None / False / 0)
        * ``'auto'`` - an interval of the form 1, 2 or 5 times a power of ten,
          chosen so that about ten progress rows are printed (at least 1).
        * ``float``  - fraction of ``N``, rounded to the nearest iteration
          (a positive fraction gives at least 1).
        * ``int``    - explicit interval (numpy integers are accepted).
        * falsy      - disable progress output.

    Returns
    -------
    int
        Interval in ``(0, N)``, or ``0`` when progress output is disabled.

    Raises
    ------
    ValueError
        If ``verbose`` is a string other than ``'auto'``.
    AssertionError
        If the resolved interval is not an int strictly between 0 and N.
    """
    # Falsy values switch progress printing off; every caller guards with `if verbose:`.
    if not isinstance(verbose, str) and not verbose:
        return 0

    if isinstance(verbose, str):
        if verbose != 'auto':
            raise ValueError(
                f"verbose must be 'auto', a float, an int or None; got {verbose!r}")
        # power of ten one order of magnitude below N
        temp = 10 ** np.floor(np.log10(N) - 1)
        k    = np.array([1, 2, 5])
        # pick the multiple that gives closest to 10 printed rows
        arg  = np.fabs(N // temp / k - 10).argmin()
        # for N < 10 the power of ten is < 1 and the product truncates to 0
        verbose = max(int(temp * k[arg]), 1)
    elif isinstance(verbose, (float, np.floating)):
        interval = int(N * verbose + 0.5)
        # a tiny positive fraction must still print at least every iteration
        verbose = max(interval, 1) if verbose > 0 else interval
    elif isinstance(verbose, np.integer):
        verbose = int(verbose)

    assert isinstance(verbose, int), \
        f'verbose must resolve to an int, got {type(verbose).__name__}'
    assert 0 < verbose < N, f'verbose interval must lie in (0, {N}), got {verbose}'

    return verbose

class Verbose():
    """Prints a two-column progress table: iteration number and current log p* value."""

    def __init__(self, N, verbose):
        """Store the run length and print interval, then print the table header.

        The header is left-padded when the iteration numbers (formatted with
        thousands separators) are wider than the word 'iteration'. This keeps
        the '|' column aligned with the rows printed by ``print``.
        """
        self.N = N
        self.verbose = verbose
        # width of the iteration column
        self.num = max(len(f'{N:,d}'), len('iteration'))

        # extra width needed beyond the header word (never negative because of the max above)
        space = self.num - len('iteration')

        print(' ' * space + 'iteration | log pstar')
        print('-' * space + '----------+----------')

    def print(self, i, val, end = '\n'):
        """Print one row for iteration ``i`` with value ``val``.

        Each row begins with a carriage return, so a row printed with
        ``end=''`` is overwritten by the next one.
        """
        print(f'\r{i:>{self.num},d} | {val:+.2e}', end = end)

class MetropolisHastingsSampler(MCMCSampler):
    """Random-walk Metropolis sampler that updates the whole parameter vector at once.

    Each step proposes a point with the inherited ``transition`` (a
    ``stats.multivariate_normal`` draw centred on the current sample with
    ``self.covariance``). The point is accepted with probability
    ``min(1, exp(log_pstar(new) - log_pstar(old)))``. No Hastings correction
    term is applied, which is only appropriate for a symmetric proposal such
    as that Gaussian.
    """

    def __init__(self, log_pstar, covariance, data, **kwargs):
        super().__init__(log_pstar, covariance, data, **kwargs)

    def fit(self, n_samples, theta0, verbose = 'auto', random_state = None):
        """Run the chain for ``n_samples`` steps starting from ``theta0``.

        Parameters
        ----------
        n_samples : int
            Number of new states to generate (must be > 0).
        theta0 : np.ndarray
            1-D starting point; stored as ``self.samples[0]``.
        verbose : 'auto', float, int or None
            Progress-print interval, resolved by ``check_verbose``.
        random_state : int, optional
            If given, passed to ``np.random.seed`` before sampling.

        Returns
        -------
        self
            With ``samples`` of shape (n_samples + 1, m), ``log_pstars`` of
            shape (n_samples + 1,) and a boolean ``acceptance`` of shape
            (n_samples,) populated.
        """
        if random_state is not None:
            np.random.seed(random_state)

        N                  = n_samples + 1 # add 1 to include theta0
        m                  = len(theta0)   # no. of parameters

        assert n_samples > 0
        assert isinstance(theta0, np.ndarray) and (theta0.ndim == 1)

        verbose            = check_verbose(N, verbose)

        # samples
        self.samples       = np.empty((N, m))
        self.samples[0]    = theta0

        # record of all log_pstar evaluations
        self.log_pstars    = np.empty(N)
        self.log_pstars[0] = self.log_pstar(theta0, self.data, **self.kwargs)

        # acceptance indicator for all new samples
        self.acceptance    = np.zeros(N - 1, dtype = bool)

        if verbose:
            message = Verbose(N, verbose)
            message.print(0, self.log_pstars[0])

        for i in range(1, N):

            # sample a new theta
            theta = self.transition(self.samples[i - 1])

            # compute log pstar of new theta
            logp  = self.log_pstar(theta, self.data, **self.kwargs)

            # accept with probability min(1, p_new / p_old), compared in log space
            if np.log(np.random.uniform()) < (logp - self.log_pstars[i - 1]):
                self.samples[i]        = theta
                self.log_pstars[i]     = logp
                self.acceptance[i - 1] = True

            # reject and repeat the previous sample
            else:
                self.samples[i]        = self.samples[i - 1]
                self.log_pstars[i]     = self.log_pstars[i - 1]

            if verbose:
                # rows off the interval end with '' so the next row overwrites them
                message.print(i, self.log_pstars[i], '' if i % verbose else '\n')

        # terminate the last progress row if it was printed without a newline,
        # otherwise the next output would be appended to it
        if verbose and (N - 1) % verbose:
            print()

        return self

class GibbsSampler(MCMCSampler):
    """Component-wise (Metropolis-within-Gibbs) sampler.

    Each sweep visits every parameter j in order. It proposes a new value
    from ``stats.norm(current, deviation[j])`` and accepts or rejects that
    single-coordinate move with the Metropolis rule.

    Parameters
    ----------
    log_pstar : callable
        Unnormalised log target, called as ``log_pstar(theta, data, **kwargs)``.
    deviation : array_like
        Per-parameter proposal standard deviations (stored as ``self.deviation``).
    data : object
        Passed through to ``log_pstar``.
    **kwargs
        Forwarded to ``log_pstar``.
    """

    def __init__(self, log_pstar, deviation, data, **kwargs):
        super().__init__(log_pstar, deviation, data, **kwargs)

        # the base class stores the scale as `covariance`; rename it for this sampler
        self.deviation = self.__dict__.pop('covariance')

    def transition(self, value, j):
        """Propose a new value for parameter ``j`` from N(value, deviation[j]**2)."""
        return stats.norm(value, self.deviation[j]).rvs()

    def fit(self, n_samples, theta0, verbose = 'auto', random_state = None):
        """Run ``n_samples`` full sweeps starting from ``theta0``.

        Parameters
        ----------
        n_samples : int
            Number of sweeps / new states to generate (must be > 0).
        theta0 : np.ndarray
            1-D starting point; stored as ``self.samples[0]``.
        verbose : 'auto', float, int or None
            Progress-print interval, resolved by ``check_verbose``.
        random_state : int, optional
            If given, passed to ``np.random.seed`` before sampling.

        Returns
        -------
        self
            With ``samples`` (n_samples + 1, m), ``log_pstars`` (n_samples + 1,)
            and a boolean ``acceptance`` of shape (n_samples, m) populated.
            ``acceptance[i - 1, j]`` is True when the update of parameter j
            during sweep i was accepted.
        """
        if random_state is not None:
            np.random.seed(random_state)

        N                  = n_samples + 1 # add 1 to include theta0
        m                  = len(theta0)   # no. of parameters

        assert n_samples > 0
        assert isinstance(theta0, np.ndarray) and (theta0.ndim == 1)

        verbose = check_verbose(N, verbose)

        # samples
        self.samples       = np.empty((N, m))
        self.samples[0]    = theta0

        # record of all log_pstar evaluations
        self.log_pstars    = np.empty(N)
        self.log_pstars[0] = self.log_pstar(theta0, self.data, **self.kwargs)

        # acceptance indicator for all new samples (per parameter); False = rejected
        self.acceptance    = np.zeros((N - 1, m), dtype = bool)

        if verbose:
            message = Verbose(N, verbose)
            message.print(0, self.log_pstars[0])

        for i in range(1, N):

            # theta and logp baseline
            theta_baseline = self.samples[i - 1]
            logp_baseline  = self.log_pstars[i - 1]

            # loop through each parameter in theta
            for j in range(m):

                # copy most recent theta baseline
                theta    = theta_baseline.copy()

                # sample the j-th element
                theta[j] = self.transition(theta[j], j)

                # compute log_pstar for theta
                logp     = self.log_pstar(theta, self.data, **self.kwargs)

                # accept with probability min(1, p_new / p_old); a rejected
                # move keeps the baseline and leaves acceptance[i - 1, j] False
                if np.log(np.random.uniform()) < (logp - logp_baseline):
                    theta_baseline           = theta
                    logp_baseline            = logp
                    self.acceptance[i - 1,j] = True

            # append new sample and log_pstar values
            self.samples[i]    = theta_baseline
            self.log_pstars[i] = logp_baseline

            if verbose:
                # rows off the interval end with '' so the next row overwrites them
                message.print(i, self.log_pstars[i], '' if i % verbose else '\n')

        # terminate the last progress row if it was printed without a newline
        if verbose and (N - 1) % verbose:
            print()

        return self


class AdaptiveGibbsSampler(MCMCSampler):
    """Component-wise Metropolis sampler with an adaptive choice of which parameter to update.

    Each sweep makes ``m`` single-coordinate updates. The coordinate ``j`` is
    chosen as the argmax of a draw from ``stats.dirichlet(alpha)``, where
    ``log_alpha[j]`` grows by the change in log p* whenever a move on ``j`` is
    accepted. After every update ``log_alpha`` is shifted so its maximum is 0
    and then multiplied by ``rate``.

    NOTE(review): the coordinate-selection probabilities depend on the chain's
    own history, so the usual stationarity guarantees of a fixed-kernel MCMC
    may not hold - confirm this is acceptable for the intended use.

    Parameters
    ----------
    log_pstar : callable
        Unnormalised log target, called as ``log_pstar(theta, data, **kwargs)``.
    deviation : array_like
        Per-parameter proposal standard deviations (stored as ``self.deviation``).
    data : object
        Passed through to ``log_pstar``.
    rate : float in (0, 1)
        Multiplicative shrinkage applied to ``log_alpha`` after every update.
    **kwargs
        Forwarded to ``log_pstar``.
    """

    def __init__(self, log_pstar, deviation, data, rate = 0.9, **kwargs):

        assert isinstance(rate, float) and 0 < rate < 1

        super().__init__(log_pstar, deviation, data, **kwargs)

        # the base class stores the scale as `covariance`; rename it for this sampler
        self.deviation = self.__dict__.pop('covariance')
        self.rate = rate

    def transition(self, value, j):
        """Propose a new value for parameter ``j`` from N(value, deviation[j]**2)."""
        return stats.norm(value, self.deviation[j]).rvs()

    def fit(self, n_samples, theta0, verbose = 'auto', random_state = None):
        """Run ``n_samples`` sweeps of ``m`` adaptive single-coordinate updates.

        Parameters
        ----------
        n_samples : int
            Number of sweeps / new states to generate (must be > 0).
        theta0 : np.ndarray
            1-D starting point; stored as ``self.samples[0]``.
        verbose : 'auto', float, int or None
            Progress-print interval, resolved by ``check_verbose``.
        random_state : int, optional
            If given, passed to ``np.random.seed`` before sampling.

        Returns
        -------
        self
            With ``samples`` (n_samples + 1, m) and ``log_pstars`` (n_samples + 1,)
            populated. ``acceptance[i]`` is a ``defaultdict(list)`` mapping each
            parameter index chosen in sweep i to its list of accept/reject flags.
        """
        if random_state is not None:
            np.random.seed(random_state)

        N                  = n_samples + 1 # add 1 to include theta0
        m                  = len(theta0)   # no. of parameters

        assert n_samples > 0
        assert isinstance(theta0, np.ndarray) and (theta0.ndim == 1)

        verbose = check_verbose(N, verbose)

        # samples
        self.samples       = np.empty((N, m))
        self.samples[0]    = theta0

        # record of all log_pstar evaluations
        self.log_pstars    = np.empty(N)
        self.log_pstars[0] = self.log_pstar(theta0, self.data, **self.kwargs)

        # acceptance flags per sweep, keyed by parameter index
        self.acceptance    = {}

        # log of the Dirichlet concentration used to pick the parameter to update
        log_alpha          = np.zeros(m)

        if verbose:
            message = Verbose(N, verbose)
            message.print(0, self.log_pstars[0])

        for i in range(1, N):

            # theta and logp baseline
            theta_baseline = self.samples[i - 1]
            logp_baseline  = self.log_pstars[i - 1]

            # per-sweep record: parameter index -> list of accept/reject flags
            self.acceptance[i] = defaultdict(list)

            for _ in range(m):

                # Dirichlet concentrations must be strictly positive
                alpha    = np.exp(log_alpha) + 1e-8

                # choose the parameter whose component of a Dirichlet draw is largest
                j        = stats.dirichlet(alpha).rvs().argmax()

                # copy most recent theta baseline
                theta    = theta_baseline.copy()

                # sample the j-th element
                theta[j] = self.transition(theta[j], j)

                # compute log_pstar for theta
                logp     = self.log_pstar(theta, self.data, **self.kwargs)

                # accept with probability min(1, p_new / p_old)
                if np.log(np.random.uniform()) < (logp - logp_baseline):

                    # shift the j-th weight by the change in log_pstar (may be negative)
                    log_alpha[j]  += logp - logp_baseline

                    theta_baseline = theta
                    logp_baseline  = logp
                    self.acceptance[i][j].append(True)

                # reject the new sample
                else:
                    self.acceptance[i][j].append(False)

                # shift so the largest log weight is 0 (keeps exp() from overflowing)
                log_alpha -= log_alpha.max()

                # pull all values towards 0 (this prevents exploding values)
                log_alpha *= self.rate

            self.samples[i]    = theta_baseline
            self.log_pstars[i] = logp_baseline

            if verbose:
                # rows off the interval end with '' so the next row overwrites them
                message.print(i, self.log_pstars[i], '' if i % verbose else '\n')

        # terminate the last progress row if it was printed without a newline
        if verbose and (N - 1) % verbose:
            print()

        return self


In [3]:
# Synthetic linear-regression problem: y = X @ w + b + Gaussian noise (sd 0.2).
# After seeding, the draws happen in a fixed order (X, w, b, noise), so the
# "true" parameters used for comparison below are reproducible.
np.random.seed(0)

N_OBS, N_FEATURES = 100, 5

X = np.random.normal(size = (N_OBS, N_FEATURES))   # design matrix
w = np.random.normal(scale = 3, size = N_FEATURES) # true weights
b = np.random.normal(scale = 10)                   # true intercept

y = X @ w + b + np.random.normal(scale = 0.2, size = N_OBS)

In [4]:
def log_pstar(theta, data, **kwargs):
    """Unnormalised log posterior for the linear-regression example.

    This is the Gaussian log-likelihood of ``y`` around ``X @ w + b``. No prior
    term is added, which amounts to an implicit flat prior on ``theta``.

    Parameters
    ----------
    theta : np.ndarray
        ``theta[0]`` is the intercept b; ``theta[1:]`` are the weights w.
    data : tuple
        ``(X, y)`` - design matrix and targets.
    **kwargs
        ``sigma`` (float, default 0.2): noise standard deviation. The samplers
        forward their extra keyword arguments here; other keys are ignored.

    Returns
    -------
    float
        Sum of the per-observation normal log-densities.
    """
    X, y = data
    b, w = theta[0], theta[1:]
    # 0.2 matches the noise scale used to generate the synthetic data
    sigma = kwargs.get('sigma', 0.2)
    return stats.norm(X @ w + b, sigma).logpdf(y).sum()

In [5]:
# shared start point, proposal scale and data for all three samplers
theta0    = np.zeros(6)
deviation = np.ones(6) * 0.05
data      = (X, y)

# NOTE: MetropolisHastingsSampler passes its second argument to
# stats.multivariate_normal as a covariance, so `deviation` acts here as
# per-parameter proposal *variances* (a 1-D cov is treated as diagonal),
# whereas the Gibbs samplers below use it as a standard deviation.
mcmc = MetropolisHastingsSampler(log_pstar, deviation, data).fit(1000, theta0)

true_theta     = np.append(b, w)
posterior_mean = mcmc.get_samples(500).mean(axis = 0)

# compare the average of the post-burn-in samples to the true values
print(posterior_mean, true_theta)

# norm distance
np.linalg.norm(posterior_mean - true_theta)

iteration | log pstar
----------+----------
        0 | -6.52e+04
      100 | -3.78e+03
      200 | -7.67e+01
      300 | -7.38e+01
      400 | -4.50e+01
      500 | -2.49e+01
      600 | -2.49e+01
      700 | -2.49e+01
      800 | -2.49e+01
      900 | -2.49e+01
    1,000 | -2.49e+01
[-5.89406933  1.16883667 -0.20996651  3.22307695 -0.78225874 -0.91443858] [-5.81268477  1.14819729 -0.10272684  3.28904054 -0.7026474  -1.04235196]


0.21354336994483278

In [6]:
# component-wise sampler with the same start point, proposal scale and data
mcmc = GibbsSampler(log_pstar, deviation, data).fit(1000, theta0)

true_theta     = np.append(b, w)
posterior_mean = mcmc.get_samples(500).mean(axis = 0)

# compare the average of the post-burn-in samples to the true values
print(posterior_mean, true_theta)

# norm distance
np.linalg.norm(posterior_mean - true_theta)

iteration | log pstar
----------+----------
        0 | -6.52e+04
      100 | -2.31e+04
      200 | -4.41e+03
      300 | -7.84e+01
      400 | +1.65e+01
      500 | +1.57e+01
      600 | +1.42e+01
      700 | +1.65e+01
      800 | +1.70e+01
      900 | +1.48e+01
    1,000 | +1.61e+01
[-5.85782613  1.13758115 -0.11418702  3.27399197 -0.71067281 -1.0822089 ] [-5.81268477  1.14819729 -0.10272684  3.28904054 -0.7026474  -1.04235196]


0.06450754387547969

In [7]:
# adaptive component-wise sampler with the same start point, proposal scale and data
mcmc = AdaptiveGibbsSampler(log_pstar, deviation, data).fit(1000, theta0)

true_theta     = np.append(b, w)
posterior_mean = mcmc.get_samples(500).mean(axis = 0)

# compare the average of the post-burn-in samples to the true values
print(posterior_mean, true_theta)

# norm distance
np.linalg.norm(posterior_mean - true_theta)

iteration | log pstar
----------+----------
        0 | -6.52e+04
      100 | -3.85e+03
      200 | +1.62e+01
      300 | +1.33e+01
      400 | +1.55e+01
      500 | +1.50e+01
      600 | +1.43e+01
      700 | +1.43e+01
      800 | +1.59e+01
      900 | +1.25e+01
    1,000 | +1.61e+01
[-5.86363205  1.13474773 -0.11487291  3.26878705 -0.71104131 -1.08056184] [-5.81268477  1.14819729 -0.10272684  3.28904054 -0.7026474  -1.04235196]


0.06974740087091971