In [2]:
import pyro
import torch
import pyro.distributions as pdist
import torch.distributions as tdist
import arviz as az
import numpy as np
from pyro.infer.mcmc import NUTS, MCMC

In [3]:
# We need constraints to specify the integration interval
from torch.distributions import constraints


# Define the unnormalized density function
class MyDensity(pdist.TorchDistribution):
    # The integration interval
    support = constraints.interval(-3.0, 3.0)
    # Constraint for the starting value used in the sampling
    arg_constraints = {"start": support}

    def __init__(self, start=torch.tensor(0.0)):
        # start = starting value for HMC sampling, default 0
        self.start = start
        super().__init__(event_shape=torch.Size())

    def sample(self, sample_shape=torch.Size()):
        # This is only used to start the HMC sampling
        # It simply returns the starting value for the sampling
        return self.start

    def log_prob(self, x):
        # Return log of the (unnormalized) density
        term1 = -x**2 / 2
        term2 = (torch.sin(x)**2) + 3 * (torch.cos(x)**2) * (torch.sin(7 * x)**2) + 1
        return term1 + torch.log(term2)


In [4]:
# Specify the model, which in our case is just our MyDensity distribution

def model():
    return pyro.sample("x", MyDensity())

In [6]:
# Run HMC / NUTS
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=200, num_chains=2, mp_context="spawn")
mcmc.run()

Warmup [1]:   0%|          | 0/1200 [00:00, ?it/s]Traceback (most recent call last):
  File "<string>", line 1, in <module>
Traceback (most recent call last):
  File "/opt/homebrew/Cellar/python@3.12/3.12.6/Frameworks/Python.framework/Versions/3.12/lib/python3.12/multiprocessing/spawn.py", line 122, in spawn_main
  File "<string>", line 1, in <module>
  File "/opt/homebrew/Cellar/python@3.12/3.12.6/Frameworks/Python.framework/Versions/3.12/lib/python3.12/multiprocessing/spawn.py", line 122, in spawn_main
    exitcode = _main(fd, parent_sentinel)    exitcode = _main(fd, parent_sentinel)

                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/opt/homebrew/Cellar/python@3.12/3.12.6/Frameworks/Python.framework/Versions/3.12/lib/python3.12/multiprocessing/spawn.py", line 132, in _main
  File "/opt/homebrew/Cellar/python@3.12/3.12.6/Frameworks/Python.framework/Versions/3.12/lib/python3.12/multiprocessing/spawn.py", line 132, in _main
    self = reduction

KeyboardInterrupt: 

In [None]:
samples = mcmc.get_samples()["x"]
expected_val = torch.mean(samples**2).item()

In [None]:
# TODO use arviz to summarize and investigate: plot
idata = az.from_pyro(mcmc)

print("Expected value of x^2:", expected_val)
az.plot_trace(idata)

In [None]:
# HMC / NUTS
def HMC(n_samples):
    # Run HMC / NUTS
    nuts_kernel=None
    mcmc= None
    # Do something with the mcmc object here ...

    # TODO Get the samples
    samples = None

    # TODO: Calculate E(x^2)
    expected_val = None
    return expected_val

In [None]:
# TODO make figures w.r.t. different sample sizes, samplers, etc.