In [16]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [17]:
import sys
import os

parent_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.append(parent_dir)

import numpy as np

import matplotlib.pyplot as plt

from src.sampler_nuts import nuts_sampler, nuts_sampler_side_move


# Regular NUTs

In [15]:
# Correlated Gaussian target
Sigma = np.array([[1.0, 0.8],
                  [0.8, 1.5]])
Sigma_inv = np.linalg.inv(Sigma)

def logp(q):
    return -0.5 * np.einsum("bi,ij,bj->b", q, Sigma_inv, q)

def grad_logp(q):
    return -q @ Sigma_inv.T


# Run NUTs Sampler
batch_size = 4

samples, info = nuts_sampler(
    log_prob         = logp,
    grad_log_prob    = grad_logp,
    num_samples      = 5000,
    initial_positions= np.zeros((batch_size, 2)),
    step_size        = 0.3
)

print("Mean acceptance rate:", info["accept_rate_mean"])
print('')
print("Empirical mean:", [np.mean(sample, axis=0) for sample in samples])
print('')
print("Empirical covariance:\n")
for output in [np.cov(sample.T) for sample in samples]:
    print(output)

Mean acceptance rate: [0.99221019 0.99221785 0.99286596 0.99198096]

Empirical mean: [array([0.01840941, 0.00419406]), array([ 0.00032169, -0.02170667]), array([0.02035895, 0.0349395 ]), array([-0.05924829, -0.06577955])]

Empirical covariance:

[[0.87038764 0.69483781]
 [0.69483781 1.31483268]]
[[0.87532261 0.70486718]
 [0.70486718 1.32462156]]
[[0.84902211 0.71232135]
 [0.71232135 1.31397485]]
[[0.97118616 0.80115099]
 [0.80115099 1.44052789]]


# Side-Step NUTs

In [25]:
def logp(q):    
    return -0.5 * np.einsum("bi,ij,bj->b", q, Sigma_inv, q)

def grad_logp(q):
    return -q @ Sigma_inv.T

n_chains_per_group = 3
n_chains = 2 * n_chains_per_group

samples, info = nuts_sampler_side_move(
    log_prob         = logp,
    grad_log_prob    = grad_logp,
    num_samples      = 5000,
    initial_positions= np.zeros((n_chains, 2)),
    n_chains_per_group = n_chains_per_group
)

100%|██████████| 5000/5000 [05:29<00:00, 15.17it/s]


In [26]:
print("Mean acceptance rate:", info["accept_rate_mean"])
print('')
print("Empirical mean:", [np.mean(sample, axis=0) for sample in samples])
print('')
print("Empirical covariance:\n")
for output in [np.cov(sample.T) for sample in samples]:
    print(output)

Mean acceptance rate: [nan nan nan nan nan nan]

Empirical mean: [array([-0.07241323, -0.15616043]), array([0.09077687, 0.02863648]), array([0.14708811, 0.27988946]), array([0.01192773, 0.13302205]), array([0.14302369, 0.30040758]), array([-0.01051927, -0.01828502])]

Empirical covariance:

[[0.96562599 0.88293167]
 [0.88293167 1.75146528]]
[[1.27607709 0.96842618]
 [0.96842618 1.6262488 ]]
[[1.17210086 0.88696659]
 [0.88696659 1.69532555]]
[[0.98345297 0.71976464]
 [0.71976464 1.46594256]]
[[1.07591927 0.84495099]
 [0.84495099 1.52508248]]
[[1.02638834 0.79466393]
 [0.79466393 1.46689468]]
