In [1]:
from __future__ import annotations

import math
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
import matplotlib.pyplot as plt

import actions
import mcmc
import transforms
import utils

# Type aliases used in the annotations of the functions defined in this notebook.
# NOTE(review): `TypeAlias` is not imported in this cell. These lines only run
# because `from __future__ import annotations` (first line of the cell) stores
# annotations as strings instead of evaluating them -- confirm before removing
# that import or before adding `from typing import TypeAlias` on an interpreter
# older than 3.10.
Tensor: TypeAlias = torch.Tensor
BoolTensor: TypeAlias = torch.BoolTensor
Module: TypeAlias = torch.nn.Module
IterableDataset: TypeAlias = torch.utils.data.IterableDataset

# Shorthand for pi, used by the sampling and normalisation code below.
PI = math.pi

%load_ext lab_black
%load_ext tensorboard

INFO:blib2to3.pgen2.driver:Generating grammar tables from /home/joe/.miniconda3/envs/xy/lib/python3.9/site-packages/blib2to3/Grammar.txt
INFO:blib2to3.pgen2.driver:Writing grammar tables to /home/joe/.cache/black/22.1.0/Grammar3.9.7.final.0.pickle
INFO:blib2to3.pgen2.driver:Writing failed: [Errno 2] No such file or directory: '/home/joe/.cache/black/22.1.0/tmpsuwiwxuf'
INFO:blib2to3.pgen2.driver:Generating grammar tables from /home/joe/.miniconda3/envs/xy/lib/python3.9/site-packages/blib2to3/PatternGrammar.txt
INFO:blib2to3.pgen2.driver:Writing grammar tables to /home/joe/.cache/black/22.1.0/PatternGrammar3.9.7.final.0.pickle
INFO:blib2to3.pgen2.driver:Writing failed: [Errno 2] No such file or directory: '/home/joe/.cache/black/22.1.0/tmpf1irt2sf'


## 1D XY chain with free boundary conditions

In [2]:
def action(phi: Tensor, beta: float) -> Tensor:
    """Return the XY-chain action of every configuration in a batch.

    Args:
        phi: Spin angles indexed as (configuration, site).
        beta: Coupling that multiplies the nearest-neighbour cosine sum.

    Returns:
        One value per configuration, ``-beta * sum_i cos(phi[i + 1] - phi[i])``.
    """
    left_spins, right_spins = phi[:, :-1], phi[:, 1:]
    # Free boundaries: only the links between adjacent sites contribute, so a
    # chain of n sites has n - 1 terms and no wrap-around link.
    cos_sum = torch.cos(right_spins - left_spins).sum(dim=1)
    return cos_sum.mul(beta).neg()


def magnetisation_sq(phi: Tensor) -> Tensor:
    """Return the squared magnetisation of every configuration in a batch.

    Each spin is the unit vector ``(cos(phi), sin(phi))``; the result is the
    squared length of the site-averaged vector, one value per row of ``phi``.
    """
    mean_cos = phi.cos().mean(dim=1)
    mean_sin = phi.sin().mean(dim=1)
    return mean_cos.pow(2) + mean_sin.pow(2)


def local_field_strength(phi: Tensor) -> tuple[Tensor, Tensor]:
    """Return the two components of the neighbour field acting on every spin.

    Args:
        phi: Spin angles indexed as (configuration, site).

    Returns:
        ``(m1, m2)``, each with the same shape as ``phi``, where
        ``m1[:, i] = cos(phi[:, i - 1]) + cos(phi[:, i + 1])`` and
        ``m2[:, i] = sin(phi[:, i - 1]) + sin(phi[:, i + 1])``. A neighbour
        that falls off either end of the chain contributes zero.
    """
    # Zero-pad one column on each side so the two end spins see a single
    # neighbour (free boundary conditions) without special-casing them.
    cos_phi, sin_phi = F.pad(phi.cos(), (1, 1)), F.pad(phi.sin(), (1, 1))
    # In padded coordinates site i sits at column i + 1, so [:, 2:] is the
    # right-hand neighbour and [:, :-2] the left-hand neighbour.
    m1 = cos_phi[:, 2:] + cos_phi[:, :-2]
    m2 = sin_phi[:, 2:] + sin_phi[:, :-2]
    return m1, m2


def log_norm(n_spins: int, beta: float):
    """Return the log normalisation constant of the free-boundary XY chain.

    The value is ``n_spins * log(2 * pi) + (n_spins - 1) * log(I0(beta))``,
    accumulated as tensor sums and returned as a tensor of shape ``(1,)``.
    """
    # One log(2*pi) term per spin.
    angular_term = torch.full((1, n_spins), 2 * PI).log().sum(dim=1)
    # One log(I0(beta)) term per link; a chain of n_spins sites has n_spins - 1.
    bessel_term = torch.full((1, n_spins - 1), beta).i0().log().sum(dim=1)
    return angular_term + bessel_term

In [20]:
def _conditional_von_mises(
    phi: Tensor, update_indices: Tensor, beta: float
) -> torch.distributions.VonMises:
    """Build the von Mises conditional of the selected spins given their neighbours."""
    m1, m2 = local_field_strength(phi)
    m1 = m1.index_select(dim=1, index=update_indices)
    m2 = m2.index_select(dim=1, index=update_indices)

    # The concentration is beta times the length of the neighbour field; it is
    # floored at 0.01, presumably to avoid a degenerate distribution when the
    # neighbouring spins cancel -- confirm the intended threshold.
    kappa = (beta * (m1.pow(2) + m2.pow(2)).sqrt()).clamp_(min=0.01)
    theta = torch.atan2(m2, m1)
    return torch.distributions.VonMises(loc=theta, concentration=kappa)


def gibbs_update(
    phi: Tensor,
    update_indices: Tensor,
    calc_prob: bool = False,
    beta: float | None = None,
) -> Tensor | tuple[Tensor, Tensor]:
    """Resample the selected spins from their von Mises conditional distribution.

    Args:
        phi: Spin angles indexed as (configuration, site). Not modified.
        update_indices: 1-D integer tensor of the sites to resample; the same
            sites are resampled in every configuration.
        calc_prob: When true, also return the log ratio of the forward and
            reverse transition densities.
        beta: Coupling used for the conditional. When omitted, the
            module-level ``beta`` is used, which is what existing calls
            without this argument rely on.

    Returns:
        The updated configurations, or ``(updated, log_ratio)`` when
        ``calc_prob`` is true, where ``log_ratio`` holds one value per
        configuration: log q(new | old) - log q(old | new), summed over the
        updated sites.

    Raises:
        NameError: If ``beta`` is omitted and no module-level ``beta`` exists.

    NOTE(review): all selected sites are drawn from fields computed once from
    the incoming ``phi``, so this is presumably only a valid Gibbs step when
    the selected sites are not neighbours of each other -- confirm at call sites.
    """
    if beta is None:
        # Backward compatibility: older call sites omit beta and expect the
        # value defined at notebook level to be picked up at call time.
        try:
            beta = globals()["beta"]
        except KeyError:
            raise NameError(
                "gibbs_update needs `beta`: pass it explicitly or define it globally"
            ) from None

    prev_spins = phi.index_select(dim=1, index=update_indices)

    # Sample from the forward conditional distribution.
    forward_dist = _conditional_von_mises(phi, update_indices, beta)
    new_spins = forward_dist.sample()

    phi_new = phi.clone()
    phi_new[:, update_indices] = new_spins

    if not calc_prob:
        return phi_new

    log_prob_forward = forward_dist.log_prob(new_spins).sum(dim=1)

    # The reverse move proposes the previous spins from the updated configuration.
    reverse_dist = _conditional_von_mises(phi_new, update_indices, beta)
    log_prob_reverse = reverse_dist.log_prob(prev_spins).sum(dim=1)

    return phi_new, log_prob_forward - log_prob_reverse

## Check exact sampling of links

In [21]:
n_sample, n_spins, beta = 1000, 10, 0.8

# The n_spins - 1 links are drawn independently from a von Mises distribution
# centred on zero with concentration beta.
links_sampler = torch.distributions.VonMises(loc=0, concentration=beta)
links = links_sampler.sample((n_sample, n_spins - 1))

# Density of the sampling procedure: the link densities times the uniform
# 1 / (2 * pi) density of the first spin.
log_prob_model = (
    links_sampler.log_prob(links).sum(dim=1) - torch.full([n_sample], 2 * PI).log()
)

# Spins are the running sum of a uniformly drawn first angle and the links.
first_spin = torch.empty((n_sample, 1)).uniform_(-PI, PI)
spins = torch.cat([first_spin, links], dim=1).cumsum(dim=1)
assert torch.allclose(torch.diff(spins, dim=1), links, atol=1e-5)

# The normalised Boltzmann density must agree with the sampler's density.
log_prob_target = -action(spins, beta) - log_norm(n_spins, beta)
weights = log_prob_target - log_prob_model
assert torch.allclose(log_prob_target, log_prob_model)

## Check Gibbs sampling works

In [22]:
n_sample = 1000
n_spins = 10
beta = 0.8

# Rebuild the link sampler from this cell's beta instead of reusing the one an
# earlier cell left in the kernel; otherwise changing beta here would silently
# keep sampling with the old concentration.
links_sampler = torch.distributions.VonMises(loc=0, concentration=beta)

# Exact sample: a uniformly drawn first spin followed by the running sum of
# independent von Mises links, wrapped into [0, 2 * pi).
exact_sample = torch.cumsum(
    torch.cat(
        [
            torch.empty([n_sample, 1]).uniform_(-PI, PI),
            links_sampler.sample([n_sample, n_spins - 1]),
        ],
        dim=1,
    ),
    dim=1,
).remainder(2 * PI)

# Reference observables that the Markov-chain samplers below are compared against.
reference_action = action(exact_sample, beta)
reference_mag_sq = magnetisation_sq(exact_sample)

# Each line reports the sample mean and its standard error (std / sqrt(N)).
print("Exact:")
print(
    f"Energy: {reference_action.mean():.4g} +/- {reference_action.std()/math.sqrt(reference_action.shape[0]):.2g}"
)
print(
    f"Mag sq: {reference_mag_sq.mean():.4g} +/- {reference_mag_sq.std()/math.sqrt(reference_mag_sq.shape[0]):.2g}"
)

Exact:
Energy: -2.672 +/- 0.049
Mag sq: 0.2034 +/- 0.0054


In [23]:
# Heat bath
n_sweeps = 10
n_updates = n_spins * n_sweeps  # one sweep = n_spins single-site updates

# Draw every site index up front; each row holds the single site that is
# resampled (in all configurations at once) during that update.
update_indices = torch.randint(0, n_spins, size=(n_updates, 1))

heatbath_sample = exact_sample.clone()
for site in update_indices:
    heatbath_sample = gibbs_update(heatbath_sample, site)

heatbath_action = action(heatbath_sample, beta)
heatbath_mag_sq = magnetisation_sq(heatbath_sample)

# Standard error of the mean: std / sqrt(number of configurations).
heatbath_action_err = heatbath_action.std() / math.sqrt(heatbath_action.shape[0])
heatbath_mag_sq_err = heatbath_mag_sq.std() / math.sqrt(heatbath_mag_sq.shape[0])

print("Heat bath:")
print(f"Energy: {heatbath_action.mean():.4g} +/- {heatbath_action_err:.2g}")
print(f"Mag sq: {heatbath_mag_sq.mean():.4g} +/- {heatbath_mag_sq_err:.2g}")

Heat bath:
Energy: -2.7 +/- 0.047
Mag sq: 0.1966 +/- 0.0052


In [24]:
# Checkerboard
even_sites = torch.arange(start=0, end=n_spins, step=2)
odd_sites = torch.arange(start=1, end=n_spins, step=2)

# Each sweep resamples all even sites together, then all odd sites together.
gibbs_sample = exact_sample.clone()
for _ in range(n_sweeps):
    for sublattice in (even_sites, odd_sites):
        gibbs_sample = gibbs_update(gibbs_sample, sublattice)

gibbs_action = action(gibbs_sample, beta)
gibbs_mag_sq = magnetisation_sq(gibbs_sample)

# Standard error of the mean: std / sqrt(number of configurations).
gibbs_action_err = gibbs_action.std() / math.sqrt(gibbs_action.shape[0])
gibbs_mag_sq_err = gibbs_mag_sq.std() / math.sqrt(gibbs_mag_sq.shape[0])

print("Gibbs:")
print(f"Energy: {gibbs_action.mean():.4g} +/- {gibbs_action_err:.2g}")
print(f"Mag sq: {gibbs_mag_sq.mean():.4g} +/- {gibbs_mag_sq_err:.2g}")

Gibbs:
Energy: -2.632 +/- 0.046
Mag sq: 0.1949 +/- 0.0052


## Gibbs sampling step after exact sampling

In [35]:
n_sample = 4
n_spins = 10
beta = 0.8


def _print_log_probs(stage: str, from_action: Tensor, from_von_mises: Tensor) -> None:
    """Print both log-densities of the current sample under a stage heading.

    Keeping the labels in one place ties "from action" to the density computed
    from the action and "from von mises" to the density tracked from the
    von Mises sampler; the two were previously printed under each other's label.
    """
    print(stage)
    print("log prob from action: \n", from_action)
    print("log prob from von mises: \n", from_von_mises)


# Density tracked from the sampler: independent von Mises links times the
# uniform 1 / (2 * pi) density of the first spin.
links_sampler = torch.distributions.VonMises(loc=0, concentration=beta)
links = links_sampler.sample((n_sample, n_spins - 1))
log_prob_model = (
    links_sampler.log_prob(links).sum(dim=1) - torch.full([n_sample], 2 * PI).log()
)

exact_sample = torch.cumsum(
    torch.cat([torch.empty((n_sample, 1)).uniform_(-PI, PI), links], dim=1),
    dim=1,
)

# Density computed from the action and its normalisation constant.
log_prob_target = -action(exact_sample, beta) - log_norm(n_spins, beta)
weights = log_prob_target - log_prob_model

_print_log_probs("Exact sampling from von Mises...", log_prob_target, log_prob_model)

# Single-site update of site 1: the tracked density picks up the log ratio of
# forward and reverse transition densities returned by gibbs_update.
phi = exact_sample.clone()
phi, log_transition_prob = gibbs_update(phi, torch.tensor([1]), calc_prob=True)
log_prob_model.add_(log_transition_prob)
log_prob_target = -action(phi, beta) - log_norm(n_spins, beta)

_print_log_probs("After one Heatbath update...", log_prob_target, log_prob_model)

# Simultaneous update of all even sites.
phi, log_transition_prob = gibbs_update(
    phi, torch.arange(0, n_spins, 2), calc_prob=True
)
log_prob_model.add_(log_transition_prob)
log_prob_target = -action(phi, beta) - log_norm(n_spins, beta)

_print_log_probs("After one Checkerboard update...", log_prob_target, log_prob_model)

Exact sampling from von Mises...
log prob from action: 
 tensor([-14.8371, -15.9839, -19.2774, -18.7554])
log prob from von mises: 
 tensor([-14.8371, -15.9839, -19.2774, -18.7554])
After one Heatbath update...
log prob from action: 
 tensor([-14.9967, -16.1953, -20.6349, -18.8216])
log prob from von mises: 
 tensor([-14.9967, -16.1953, -20.6349, -18.8216])
After one Checkerboard update...
log prob from action: 
 tensor([-16.2260, -18.6741, -17.8651, -14.6803])
log prob from von mises: 
 tensor([-16.2260, -18.6741, -17.8651, -14.6803])
