# Sampling from the XY Model with Spline Flows: Direct Sampling of Spins

In [None]:
from __future__ import annotations

import math
import torch
import pytorch_lightning as pl

import actions
import transforms
import utils

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # `TypeAlias` is only needed by static type checkers: the file's
    # `from __future__ import annotations` keeps the annotations below
    # unevaluated at runtime. Guarding the import also avoids an ImportError
    # on Python < 3.10, where `typing.TypeAlias` does not exist.
    from typing import TypeAlias

# Short aliases for the torch types used in signatures throughout the notebook.
Tensor: TypeAlias = torch.Tensor
BoolTensor: TypeAlias = torch.BoolTensor
Module: TypeAlias = torch.nn.Module
IterableDataset: TypeAlias = torch.utils.data.IterableDataset

# Used below as the bounds (-PI, PI) of the uniform prior and of the Hardtanh
# output activation.
PI = math.pi

%load_ext lab_black
%load_ext tensorboard

# Definitions

In [None]:
class CouplingBlock(torch.nn.Module):
    """Pair of coupling layers acting on complementary checkerboard partitions.

    A boolean checkerboard mask splits the lattice sites into an "a" partition
    (mask is True) and a "b" partition (mask is False). ``forward`` first
    transforms a conditioned on b, then transforms b conditioned on the
    *updated* a; ``inverse`` applies the inverse transforms in the opposite
    order.

    Args:
        transform: Pointwise transformation. It is called as
            ``transform(x, params)`` and ``transform.inverse(x, params)``, both
            returning ``(outputs, log_det_jacob)``, and must expose a
            ``params_dof`` attribute (number of parameters per lattice site).
        lattice_shape: Lattice dimensions, forwarded to
            ``utils.make_checkerboard`` and ``utils.prod``.
        net_hidden_shape: Widths of the hidden linear layers of both
            conditioner networks.
        net_activation: Module applied after every hidden linear layer.
        net_final_activation: Module applied after the output linear layer.

    Note:
        The activation module *instances* are shared between both networks and
        between all hidden layers.
    """

    def __init__(
        self,
        transform,
        lattice_shape: list[int],
        net_hidden_shape: list[int],
        net_activation: Module,
        net_final_activation: Module,
    ):
        super().__init__()
        self.transform = transform
        self.register_buffer("mask", utils.make_checkerboard(lattice_shape))

        # Each network reads one partition and emits `params_dof` values for
        # every site of the other partition.
        # NOTE(review): this sizing assumes both partitions hold exactly
        # prod(lattice_shape) // 2 sites, i.e. an even number of sites.
        half_lattice = utils.prod(lattice_shape) // 2
        nodes = [
            half_lattice,
            *net_hidden_shape,
            half_lattice * self.transform.params_dof,
        ]
        activations = [net_activation] * len(net_hidden_shape) + [
            net_final_activation
        ]

        # The Linear layers of the two networks are created alternately
        # (a, b, a, b, ...) so that weight initialisation consumes the RNG in a
        # fixed, reproducible order.
        layers_a, layers_b = [], []
        for d_in, d_out, f_act in zip(nodes[:-1], nodes[1:], activations):
            layers_a += [torch.nn.Linear(d_in, d_out), f_act]
            layers_b += [torch.nn.Linear(d_in, d_out), f_act]
        self.net_a = torch.nn.Sequential(*layers_a)
        self.net_b = torch.nn.Sequential(*layers_b)

    def _conditioner_params(
        self, net: Module, conditioner: Tensor, target: Tensor
    ) -> Tensor:
        """Run ``net`` on ``conditioner`` and reshape to match ``target``.

        The flat network output is viewed as ``(*target.shape, -1)``; the
        trailing axis is dropped when it has length one (presumably the
        ``params_dof == 1`` case).
        """
        return net(conditioner).view(*target.shape, -1).squeeze(dim=-1)

    def _interleave(self, template: Tensor, part_a: Tensor, part_b: Tensor) -> Tensor:
        """Scatter the two partitions back onto a tensor shaped like ``template``."""
        merged = torch.empty_like(template)
        merged[:, self.mask] = part_a
        merged[:, ~self.mask] = part_b
        return merged

    def forward(
        self, inputs: Tensor, log_det_jacob: Tensor
    ) -> tuple[Tensor, Tensor]:
        """Apply the block: update partition a given b, then b given the new a.

        Args:
            inputs: Batch of configurations; axis 0 is the batch axis and the
                remaining axes are indexed by the checkerboard mask.
            log_det_jacob: Running log-det-Jacobian. It is updated **in place**
                with this block's contribution.

        Returns:
            The transformed configurations and the (same, mutated)
            ``log_det_jacob`` tensor.
        """
        in_a = inputs[:, self.mask]
        in_b = inputs[:, ~self.mask]
        out_a, log_det_jacob_a = self.transform(
            in_a, self._conditioner_params(self.net_b, in_b, in_a)
        )
        # The second layer is conditioned on the already-transformed a.
        out_b, log_det_jacob_b = self.transform(
            in_b, self._conditioner_params(self.net_a, out_a, in_b)
        )
        log_det_jacob.add_(log_det_jacob_a + log_det_jacob_b)
        return self._interleave(inputs, out_a, out_b), log_det_jacob

    def inverse(
        self, inputs: Tensor, log_det_jacob: Tensor
    ) -> tuple[Tensor, Tensor]:
        """Undo ``forward``: invert b given a, then a given the recovered b.

        Args:
            inputs: Batch of configurations produced by ``forward``.
            log_det_jacob: Running log-det-Jacobian. It is updated **in place**
                with the contribution of the inverse transforms.

        Returns:
            The recovered configurations and the (same, mutated)
            ``log_det_jacob`` tensor.
        """
        in_a = inputs[:, self.mask]
        in_b = inputs[:, ~self.mask]
        # Layers are inverted in reverse order: b was transformed last.
        out_b, log_det_jacob_b = self.transform.inverse(
            in_b, self._conditioner_params(self.net_a, in_a, in_b)
        )
        out_a, log_det_jacob_a = self.transform.inverse(
            in_a, self._conditioner_params(self.net_b, out_b, in_a)
        )
        log_det_jacob.add_(log_det_jacob_a + log_det_jacob_b)
        return self._interleave(inputs, out_a, out_b), log_det_jacob

In [None]:
class Model(pl.LightningModule):
    """Module which learns to sample from XY model.

    A stack of coupling blocks (wrapped in ``utils.Flow``) maps prior samples
    ``z`` to spin configurations ``x``. The training loss is the batch mean of
    ``log_prob_z - log_det_jacob + action(x)``.

    Args:
        xy_coupling: Coupling passed to ``actions.XYSpinAction``.
        lattice_shape: Lattice dimensions, forwarded to every coupling block.
        n_blocks: Number of repetitions of the layer group described below.
        n_spline_segments: Number of segments of the rational quadratic
            spline transform.
        net_hidden_shape: Hidden layer widths of the conditioner networks.
        net_activation: Hidden activation of the spline conditioner networks.
        use_shift_coupling_layers: If True, each group starts with a
            phase-shift coupling block (Tanh hidden activation, Hardtanh
            output clamped to [-PI, PI]).
        use_random_rotations: If True, a ``utils.RandomRotationLayer`` is
            inserted before each spline coupling block.
    """

    def __init__(
        self,
        *,
        xy_coupling: float,
        lattice_shape: list[int],
        n_blocks: int,
        n_spline_segments: int,
        net_hidden_shape: list[int],
        net_activation: torch.nn.Module,
        use_shift_coupling_layers: bool = False,
        use_random_rotations: bool = False,
    ):
        super().__init__()
        self.save_hyperparameters()

        # Each group is: [phase-shift block] -> [random rotation] -> spline block,
        # where the bracketed layers are optional.
        layers = []
        for _ in range(n_blocks):
            if use_shift_coupling_layers:
                layers.append(
                    CouplingBlock(
                        transforms.PointwisePhaseShift(),
                        lattice_shape,
                        net_hidden_shape,
                        net_activation=torch.nn.Tanh(),
                        net_final_activation=torch.nn.Hardtanh(-PI, PI),
                    )
                )
            if use_random_rotations:
                layers.append(utils.RandomRotationLayer())
            layers.append(
                CouplingBlock(
                    transforms.PointwiseRationalQuadraticSplineTransform(
                        n_spline_segments
                    ),
                    lattice_shape,
                    net_hidden_shape,
                    net_activation,
                    net_final_activation=torch.nn.Identity(),
                )
            )

        self.flow = utils.Flow(*layers)
        self.action = actions.XYSpinAction(xy_coupling)
        # Number of completed training steps; used to name checkpoints.
        self.curr_iter = 0

    def _save_checkpoint(self):
        """Hack: should be done automatically by Lightning but I never progress
        past epoch=0 and this is easier than fiddling with other things.

        Writes ``<logger.log_dir>/checkpoints/iter_<curr_iter>.ckpt``.
        """
        self.trainer.save_checkpoint(
            self.logger.log_dir + f"/checkpoints/iter_{self.curr_iter}.ckpt"
        )

    def forward(self, batch) -> tuple[Tensor, Tensor]:
        """z ~ Unif -> x ~ XY

        Args:
            batch: Pair ``(z, log_prob_z)`` of prior samples and their
                log-probabilities.

        Returns:
            ``(x, weights)`` where ``x`` is the flowed configuration and
            ``weights = log_prob_z - log_det_jacob + action(x)``.
            NOTE(review): assuming ``log_det_jacob`` is log|det dx/dz|, this is
            the model log-density plus the action, i.e. minus the log
            importance weight up to a constant -- confirm against utils.Flow.
        """
        z, log_prob_z = batch
        x, log_det_jacob = self.flow(z)
        weights = log_prob_z - log_det_jacob + self.action(x)
        return x, weights

    def on_train_start(self):
        """Save an initial checkpoint and log the hyperparameters."""
        self._save_checkpoint()
        # Dirty but more convenient than saving / loading config files atm
        self.logger.log_hyperparams(self.hparams)

    def on_train_end(self):
        """Save the final checkpoint."""
        self._save_checkpoint()

    def training_step(self, batch, batch_idx):
        """Compute and log the mean of ``weights`` over the batch as the loss."""
        _, weights = self.forward(batch)
        loss = weights.mean()
        self.log("loss", loss, logger=True)
        # The LR scheduler is stepped by Lightning after each optimizer step
        # (interval="step" in configure_optimizers). Stepping it manually here
        # would run before the optimizer step of the same iteration.
        self.curr_iter += 1
        return loss

    def validation_step(self, batch, batch_idx):
        """Log loss, Metropolis acceptance and mean squared magnetisation."""
        x, weights = self.forward(batch)
        loss = weights.mean()
        metrics = dict(
            loss=loss,
            acceptance=utils.metropolis_acceptance(weights),
            mag_sq=utils.magnetisation_sq(x).mean(),
        )
        self.log_dict(
            metrics,
            prog_bar=False,
            logger=True,
        )
        return loss

    def configure_optimizers(self):
        """Adam with a cosine-annealed learning rate over ``trainer.max_steps``.

        The scheduler is registered with ``interval="step"`` because training
        never leaves epoch 0, so Lightning's default per-epoch stepping would
        never fire.
        """
        optimizer = torch.optim.Adam(self.flow.parameters(), lr=0.001)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.trainer.max_steps
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1,
            },
        }

    @torch.no_grad()
    def sample(
        self, prior: IterableDataset, n_iter: int = 1
    ) -> tuple[Tensor, Tensor]:
        """Draw ``n_iter`` batches from ``prior`` and push them through the flow.

        Args:
            prior: Iterator yielding ``(z, log_prob_z)`` batches; it is
                advanced with ``next``.
            n_iter: Number of batches to draw. At least one batch is always
                drawn, even if ``n_iter`` is smaller than 1.

        Returns:
            ``(x, weights)`` concatenated along the batch axis.
        """
        xs, all_weights = [], []
        for _ in range(max(n_iter, 1)):
            x, weights = self.forward(next(prior))
            xs.append(x)
            all_weights.append(weights)
        # Concatenate once at the end rather than re-copying the accumulated
        # result on every iteration.
        return torch.cat(xs, dim=0), torch.cat(all_weights, dim=0)

# Training

In [None]:
# ---- XY model -------------------------------------------------------------
# Lattice dimensions and the coupling handed to the Model constructor.
LATTICE_SHAPE = [6]
XY_COUPLING = 1

# ---- Flow architecture ----------------------------------------------------
N_BLOCKS = 2
N_SPLINE_SEGMENTS = 8
NET_HIDDEN_SHAPE = [32]
NET_ACTIVATION = torch.nn.Tanh()
USE_SHIFT_COUPLING_LAYERS = True
USE_RANDOM_ROTATIONS = True

# ---- Training -------------------------------------------------------------
# N_TRAIN is the number of training steps; the batch sizes set the leading
# dimension of the prior samples for training and validation respectively.
N_TRAIN = 10_000
N_BATCH = 1_000
N_BATCH_VAL = 10_000

In [None]:
# Build the flow model from the configuration constants defined above.
model = Model(
    xy_coupling=XY_COUPLING,
    lattice_shape=LATTICE_SHAPE,
    n_blocks=N_BLOCKS,
    n_spline_segments=N_SPLINE_SEGMENTS,
    net_hidden_shape=NET_HIDDEN_SHAPE,
    net_activation=NET_ACTIVATION,
    use_shift_coupling_layers=USE_SHIFT_COUPLING_LAYERS,
    use_random_rotations=USE_RANDOM_ROTATIONS,
)

# Could wrap in DataSet(IterableDataset, batch_size=None) but can't currently see benefit
# Both "dataloaders" sample from the same uniform prior on (-PI, PI) and differ
# only in the batch size (leading entry of sample_shape).
# NOTE(review): presumably utils.Prior yields (sample, log_prob) pairs, since
# Model.forward unpacks each batch into two values -- confirm in utils.
unif = torch.distributions.Uniform(-PI, PI)
train_dataloader = utils.Prior(
    distribution=unif, sample_shape=[N_BATCH, *LATTICE_SHAPE]
)
val_dataloader = utils.Prior(
    distribution=unif, sample_shape=[N_BATCH_VAL, *LATTICE_SHAPE]
)

# NOTE(review): this logger is constructed but never handed to the Trainer
# below (the `logger=logger` argument is commented out), so the Trainer falls
# back to its own default logger and `save_dir="test"` / `name="spins"` have no
# effect on where logs and checkpoints are written.
logger = pl.loggers.TensorBoardLogger(
    save_dir="test",
    name="spins",
)
pbar = utils.JlabProgBar()
# Records the learning rate at every training step.
lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step")

trainer = pl.Trainer(
    gpus=1,
    max_steps=N_TRAIN,  # total number of training steps
    val_check_interval=100,  # how often to run sampling
    limit_val_batches=1,  # one batch for each val step
    # logger=logger,
    callbacks=[pbar, lr_monitor],
    enable_checkpointing=False,  # manually saving checkpoints
)

In [None]:
# Train for `max_steps` steps on the training prior, validating on one batch
# of the validation prior every `val_check_interval` steps (see Trainer above).
trainer.fit(model, train_dataloader, val_dataloader)

In [None]:
# Test that loading checkpoint works
# The path mirrors the one written by Model._save_checkpoint; the final
# checkpoint is named iter_{N_TRAIN} provided training ran for exactly N_TRAIN
# steps.
_reloaded_model = Model.load_from_checkpoint(
    trainer.log_dir + f"/checkpoints/iter_{N_TRAIN}.ckpt",
    # action=XYAction(J=1.1),  # could change some of the parameters
)

# Analysis

In [None]:
# Draw ten validation-sized batches and log histograms of the sampled spins
# and of quantities derived from them.
x, weights = model.sample(val_dataloader, n_iter=10)

summary_writer = model.logger.experiment
epoch = trainer.current_epoch

# Histogram tag -> values, logged in insertion order.
histograms = {
    "spins": x.flatten(),
    "links": utils.spins_to_links(x).flatten(),
    "weights": weights,
    "magnetisation_sq": utils.magnetisation_sq(x),
}
for tag, values in histograms.items():
    summary_writer.add_histogram(tag, values, epoch)

In [None]:
%tensorboard --logdir lightning_logs

## Checking the O(2) symmetry

In [None]:
import matplotlib.pyplot as plt

N_BATCH = 1000
N_ROTATIONS = 100

# NOTE(review): linspace includes both end points, so the first and last
# angles (0 and 2*PI) describe the same rotation.
angles = torch.linspace(0, 2 * PI, N_ROTATIONS)
x, _ = model.sample(utils.Prior(unif, [N_BATCH, *LATTICE_SHAPE]))
action = model.action(x)
# Row = sampled configuration, column = rotation angle.
log_prob_model = torch.empty((N_BATCH, N_ROTATIONS))

# Pure evaluation: disable autograd so the 100 inverse passes do not each
# build and retain a computation graph (and `log_prob_model` stays a plain
# tensor).
with torch.no_grad():
    for i, angle in enumerate(angles):
        # Rotate every spin by `angle` and wrap back into [-PI, PI).
        x_rotated = x.add(PI).add(angle).fmod(2 * PI).sub(PI)
        # Sanity check: the action must be unchanged by a global rotation.
        assert torch.allclose(action, model.action(x_rotated), atol=1e-5)
        # NOTE(review): only the inverse log-det-Jacobian is stored; with a
        # uniform prior this presumably equals the model log-density up to a
        # constant -- confirm against utils.Flow.inverse.
        _, log_det_jacob_inv = model.flow.inverse(x_rotated)
        log_prob_model[:, i] = log_det_jacob_inv

# 2x2 summary of how the model log-density varies under global rotations.
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
(ax1, ax2), (ax3, ax4) = axes

angle_values = angles.tolist()
# Per-configuration spread of log ptilde over all rotation angles.
spread_per_config = log_prob_model.std(dim=1).tolist()

# Top left: spread versus the log-density of the unrotated configuration.
ax1.set(xlabel="log ptilde (of unrotated config)", ylabel="std dev of log ptilde")
ax1.scatter(log_prob_model[:, 0].tolist(), spread_per_config)

# Top right: distribution of the spread over configurations.
ax2.set(xlabel="std dev of log ptilde", ylabel="frequency")
ax2.hist(spread_per_config)

# Bottom left: log-density against angle for the first six configurations.
ax3.set(xlabel="rotation angle", ylabel="log ptilde")
for config_idx in range(6):
    ax3.plot(angle_values, log_prob_model[config_idx].tolist())

# Bottom right: log-density averaged over configurations, against angle.
ax4.set(xlabel="rotation_angle", ylabel="<log ptilde>")
ax4.plot(angle_values, log_prob_model.mean(dim=0).tolist())