# Thermodynamic integration

This is based on the discrete Langevin sampler tested in[potts-locally-informed-mh.ipynb](https://github.com/fritzo/notebooks/blob/master/potts-locally-informed-mh.ipynb).

In [1]:
import itertools
import math
from collections import Counter

import torch
import torch.distributions as dist
import matplotlib
import matplotlib.pyplot as plt
from opt_einsum import contract as einsum
from tqdm.auto import tqdm

matplotlib.rcParams.update({'figure.facecolor': "white"})
matplotlib.rcParams.update({'figure.dpi': 200})

We'll use sample efficiency to test correctness of distribution. For good samplers, sample efficiency should be close to one; for bad samplers it should be close to zero.

In [2]:
def sample_efficiency(counts: torch.Tensor, probs: torch.Tensor) -> float:
    total = counts.sum()
    chisq = (counts - total * probs).square().sum().div(total)
    return float(1 / chisq)

Consider a simple coupled non-normalized probability distribution.

In [3]:
class Potts:
    def __init__(self, p, k, temperature=2):
        self.v = torch.randn(p, k) / temperature
        self.e = torch.randn(p, p, k, k) / temperature
    def __call__(self, x):
        return einsum("...vi,vi", x, self.v) + einsum("...ui,...vj,uvij", x, x, self.e)
    def enumerate_support(self):
        p, k = self.v.shape
        support = torch.tensor(list(itertools.product(*[range(k)] * p)))
        return torch.nn.functional.one_hot(support).float()

Let's use a parallel adaptive discrete Langevin MH sampler. 

In [4]:
def logp_and_nbhd(f, x):
    x.requires_grad_()
    logp = f(x)
    assert logp.shape == x.shape[:-2]
    nbhd = torch.autograd.grad(logp.sum(), [x])[0].detach()
    x.detach_()
    logp.detach_()
    nbhd -= (nbhd * x).sum(-1, True)
    nbhd /= 2
    return logp, nbhd

def make_proposal(x, nbhd, temperature):
    logq = nbhd / temperature
    logq += x * logq.exp().sum(-1, True).max(-2, True).values.sub(1).clamp(min=1).log()
    logq -= logq.logsumexp(-1, True)
    return logq
    
def anneal(
    f: callable, p: int, k: int, num_steps: int, num_samples: int
) -> torch.Tensor:
    log_Z = p * math.log(k)  # entropy of uniform distribution
    x0 = dist.OneHotCategorical(torch.ones(p, k)).sample([num_samples])
    logp0, nbhd0 = logp_and_nbhd(f, x0)
    beta_old = 0
    acceptance_rate = 0
    for beta_new in tqdm(torch.linspace(1 / num_steps, 1, num_steps)):
        logq0 = make_proposal(x0, nbhd0, 1 / beta_new)
        x1 = dist.OneHotCategorical(logits=logq0).sample()
        if (x0 == x1).all():
            acceptance_rate += 1
        else:
            logp1, nbhd1 = logp_and_nbhd(f, x1)
            logq1 = make_proposal(x1, nbhd1, 1 / beta_new)
            logq10 = einsum("bpk,bpk->b", logq1, x0)
            logq01 = einsum("bpk,bpk->b", logq0, x1)
            ratio = (logp1 - logp0 + logq10 - logq01).exp()
            accept = ratio > torch.rand(ratio.shape)
            acceptance_rate += accept.float().mean().item()
            if accept.all():
                x0, logp0, nbhd0 = x1, logp1, nbhd1
            elif accept.any():
                x0[accept] = x1[accept]
                logp0[accept] = logp1[accept]
                nbhd0[accept] = nbhd1[accept]
        log_Z += (beta_new - beta_old) * logp0
        beta_old = beta_new
    print(f"acceptance rate = {acceptance_rate / num_steps:0.3g}")
    return x0, log_Z

Now we'll test accuracy of the log partition function and of the sampled distribution. This only works for small problems where we can completely enumerate the support.

In [5]:
def test_sampler(p, k, num_steps=2000, num_samples=5000):
    print(f"Simulating model with {k**p} states")
    torch.manual_seed(20220714)
    f = Potts(p, k)
    xs = f.enumerate_support()
    logp = f(xs)
    true_log_Z = logp.logsumexp(-1)
    logp -= true_log_Z
    probs = logp.exp()
    entropy = -probs @ logp
    print(f"perplexity = {entropy.exp():0.1f}")

    sample, log_Z = anneal(f, p, k, num_steps=num_steps, num_samples=num_samples)
    counts_dict = Counter(map(tuple, sample.max(-1).indices.tolist()))
    counts = torch.zeros_like(probs)
    for i, x in enumerate(xs.max(-1).indices.tolist()):
        counts[i] = counts_dict[tuple(x)]
    print(f"sample efficiency = {sample_efficiency(counts, probs):0.3g}")
    print(f"True log Z = {true_log_Z:0.5g}")
    print(f"Estimated log Z = {log_Z.mean():0.5g} +- {log_Z.std():0.3g}")

In [6]:
test_sampler(3, 3)

Simulating model with 27 states
perplexity = 9.8


  0%|          | 0/2000 [00:00<?, ?it/s]

acceptance rate = 0.56
sample efficiency = 1.06
True log Z = 3.527
Estimated log Z = 4.5366 +- 0.0612


In [7]:
test_sampler(4, 4)

Simulating model with 256 states
perplexity = 51.3


  0%|          | 0/2000 [00:00<?, ?it/s]

acceptance rate = 0.492
sample efficiency = 1.15
True log Z = 8.2319
Estimated log Z = 9.829 +- 0.0981


In [8]:
test_sampler(5, 5)

Simulating model with 3125 states
perplexity = 168.1


  0%|          | 0/2000 [00:00<?, ?it/s]

acceptance rate = 0.386
sample efficiency = 0.865
True log Z = 11.63
Estimated log Z = 14.508 +- 0.215


In [9]:
test_sampler(10, 3)

Simulating model with 59049 states
perplexity = 132.2


  0%|          | 0/2000 [00:00<?, ?it/s]

acceptance rate = 0.238
sample efficiency = 1.37
True log Z = 23.281
Estimated log Z = 29.026 +- 0.33


In [10]:
test_sampler(4, 20)

Simulating model with 160000 states
perplexity = 15026.9


  0%|          | 0/2000 [00:00<?, ?it/s]

acceptance rate = 0.458
sample efficiency = 1.03
True log Z = 14.413
Estimated log Z = 16.727 +- 0.293
