In [1]:
from __future__ import annotations

import math
import torch
import torch.nn.functional as F
import pytorch_lightning as pl

import actions
import mcmc
import transforms
import utils

# Short aliases for the torch types used in the annotations of the classes below.
# NOTE(review): `TypeAlias` is not imported anywhere in the visible code; these
# lines only execute because `from __future__ import annotations` (first line of
# this cell) leaves annotations unevaluated. Removing that import would turn
# each of them into a NameError.
Tensor: TypeAlias = torch.Tensor
BoolTensor: TypeAlias = torch.BoolTensor
Module: TypeAlias = torch.nn.Module
IterableDataset: TypeAlias = torch.utils.data.IterableDataset

# Used as the bound of the Hardtanh output range and of the uniform prior.
PI = math.pi

%load_ext lab_black
%load_ext tensorboard

INFO:blib2to3.pgen2.driver:Generating grammar tables from /home/joe/.miniconda3/envs/xy/lib/python3.9/site-packages/blib2to3/Grammar.txt
INFO:blib2to3.pgen2.driver:Writing grammar tables to /home/joe/.cache/black/21.12b0/Grammar3.9.7.final.0.pickle
INFO:blib2to3.pgen2.driver:Writing failed: [Errno 2] No such file or directory: '/home/joe/.cache/black/21.12b0/tmp0pv4oyfk'
INFO:blib2to3.pgen2.driver:Generating grammar tables from /home/joe/.miniconda3/envs/xy/lib/python3.9/site-packages/blib2to3/PatternGrammar.txt
INFO:blib2to3.pgen2.driver:Writing grammar tables to /home/joe/.cache/black/21.12b0/PatternGrammar3.9.7.final.0.pickle
INFO:blib2to3.pgen2.driver:Writing failed: [Errno 2] No such file or directory: '/home/joe/.cache/black/21.12b0/tmpp1c_tsoy'


In [2]:
class CouplingBlock(torch.nn.Module):
    """Pair of coupling layers acting on the two halves of a checkerboard.

    The lattice sites are split by a boolean mask into partition ``a``
    (``mask``) and partition ``b`` (``~mask``). In the forward direction,
    partition ``a`` is transformed with parameters produced by ``net_b`` from
    partition ``b``; partition ``b`` is then transformed with parameters
    produced by ``net_a`` from the already-updated partition ``a``. The
    inverse undoes the two steps in the opposite order.

    Args:
        transform: Callable ``transform(x, params) -> (y, log_det_jacob)``
            which also provides ``inverse(y, params)`` with the same return
            convention, and a ``params_dof`` attribute giving the number of
            parameters required per lattice site.
        lattice_shape: Extent of the lattice along each dimension.
        net_hidden_shape: Widths of the hidden layers of both networks.
        net_activation: Activation applied after every hidden layer.
        net_final_activation: Activation applied after the output layer.

    Note:
        The same activation module instances are inserted into both networks.
    """

    def __init__(
        self,
        transform,
        lattice_shape: list[int],
        net_hidden_shape: list[int],
        net_activation: Module,
        net_final_activation: Module,
    ):
        super().__init__()
        self.transform = transform
        self.register_buffer("mask", utils.make_checkerboard(lattice_shape))

        # Each network maps one half of the lattice to `params_dof` parameters
        # for every site of the other half.
        half_lattice = utils.prod(lattice_shape) // 2
        nodes = [
            half_lattice,
            *net_hidden_shape,
            half_lattice * self.transform.params_dof,
        ]
        activations = [net_activation for _ in net_hidden_shape] + [
            net_final_activation
        ]
        # The two networks are built side by side, layer by layer, so the order
        # in which the Linear layers are created (and hence initialised) is
        # a, b, a, b, ...
        net_a, net_b = [], []
        for d_in, d_out, f_act in zip(nodes[:-1], nodes[1:], activations):
            net_a.append(torch.nn.Linear(d_in, d_out))
            net_b.append(torch.nn.Linear(d_in, d_out))
            net_a.append(f_act)
            net_b.append(f_act)
        self.net_a = torch.nn.Sequential(*net_a)
        self.net_b = torch.nn.Sequential(*net_b)

    @staticmethod
    def _params(net: Module, conditioner: Tensor, target: Tensor) -> Tensor:
        """Evaluate ``net`` on ``conditioner`` and reshape to ``(*target.shape, dof)``.

        The trailing axis is dropped when it has size one.
        """
        return net(conditioner).view(*target.shape, -1).squeeze(dim=-1)

    def _merge(self, reference: Tensor, part_a: Tensor, part_b: Tensor) -> Tensor:
        """Scatter the two partitions back into a tensor shaped like ``reference``."""
        merged = torch.empty_like(reference)
        merged[:, self.mask] = part_a
        merged[:, ~self.mask] = part_b
        return merged

    def forward(
        self, inputs: Tensor, log_det_jacob: Tensor
    ) -> tuple[Tensor, Tensor]:
        """Apply both coupling layers.

        Args:
            inputs: Tensor whose first axis is the batch and whose remaining
                axes are indexed by the checkerboard mask.
            log_det_jacob: Running log-determinant accumulator. It is updated
                IN PLACE with the contribution of both layers.

        Returns:
            ``(outputs, log_det_jacob)`` where ``log_det_jacob`` is the same
            tensor object that was passed in.
        """
        in_a = inputs[:, self.mask]
        in_b = inputs[:, ~self.mask]
        out_a, log_det_jacob_a = self.transform(
            in_a, self._params(self.net_b, in_b, in_a)
        )
        out_b, log_det_jacob_b = self.transform(
            in_b, self._params(self.net_a, out_a, in_b)
        )
        log_det_jacob.add_(log_det_jacob_a + log_det_jacob_b)
        return self._merge(inputs, out_a, out_b), log_det_jacob

    def inverse(
        self, inputs: Tensor, log_det_jacob: Tensor
    ) -> tuple[Tensor, Tensor]:
        """Undo :meth:`forward`.

        Partition ``b`` is inverted first (conditioned on the given partition
        ``a``), then partition ``a`` is inverted conditioned on the recovered
        partition ``b``.

        Args:
            inputs: Tensor produced by :meth:`forward`.
            log_det_jacob: Running log-determinant accumulator, updated IN
                PLACE with the contribution of both inverse layers.

        Returns:
            ``(outputs, log_det_jacob)`` where ``log_det_jacob`` is the same
            tensor object that was passed in.
        """
        in_a = inputs[:, self.mask]
        in_b = inputs[:, ~self.mask]
        out_b, log_det_jacob_b = self.transform.inverse(
            in_b, self._params(self.net_a, in_a, in_b)
        )
        out_a, log_det_jacob_a = self.transform.inverse(
            in_a, self._params(self.net_b, out_b, in_a)
        )
        # Both layers contribute; previously the `b` term was commented out, so
        # the inverse log-determinant did not cancel the forward one.
        log_det_jacob.add_(log_det_jacob_a + log_det_jacob_b)
        return self._merge(inputs, out_a, out_b), log_det_jacob

In [3]:
class CouplingBlockHalfConditional(torch.nn.Module):
    """Coupling layer on one checkerboard half, conditional resampling of the other.

    Partition ``a`` (``mask``) is transformed with parameters produced by
    ``net`` from partition ``b`` (``~mask``). Partition ``b`` is then replaced
    by fresh von Mises samples whose location and concentration are computed
    from a nearest-neighbour convolution of the cosines and sines of the
    intermediate configuration. Because of the sampling step this block has no
    ``inverse``.

    Args:
        transform: Callable ``transform(x, params) -> (y, log_det_jacob)`` with
            a ``params_dof`` attribute giving the number of parameters required
            per lattice site.
        lattice_shape: Extent of the lattice along each dimension (1, 2 or 3
            dimensions are supported).
        net_hidden_shape: Widths of the hidden layers of the network.
        net_activation: Activation applied after every hidden layer.
        net_final_activation: Activation applied after the output layer.
        coupling_strength: Factor multiplying the magnitude of the local field
            to give the von Mises concentration.
        final_layer: If True, the log-density of the sampled spins is added to
            the running ``log_det_jacob`` accumulator in :meth:`forward`.

    Raises:
        ValueError: If ``lattice_shape`` does not have 1, 2 or 3 dimensions.
    """

    def __init__(
        self,
        transform,
        lattice_shape: list[int],
        net_hidden_shape: list[int],
        net_activation: Module,
        net_final_activation: Module,
        coupling_strength: float,
        final_layer: bool = False,
    ):
        super().__init__()
        self.transform = transform
        self.register_buffer("mask", utils.make_checkerboard(lattice_shape))

        half_lattice = utils.prod(lattice_shape) // 2
        nodes = [
            half_lattice,
            *net_hidden_shape,
            half_lattice * self.transform.params_dof,
        ]
        activations = [net_activation for _ in net_hidden_shape] + [
            net_final_activation
        ]
        net = []
        for d_in, d_out, f_act in zip(nodes[:-1], nodes[1:], activations):
            net.append(torch.nn.Linear(d_in, d_out))
            net.append(f_act)
        self.net = torch.nn.Sequential(*net)

        lattice_dim = len(lattice_shape)
        conv_by_dim = {1: F.conv1d, 2: F.conv2d, 3: F.conv3d}
        if lattice_dim not in conv_by_dim:
            # Fail at construction time; otherwise `self.conv` would be
            # missing and the error would only surface in the first forward.
            raise ValueError(
                f"lattice_shape must have 1, 2 or 3 dimensions, got {lattice_dim}"
            )
        self.register_buffer("kernel", utils.nearest_neighbour_kernel(lattice_dim))
        self.conv = conv_by_dim[lattice_dim]

        # One site of padding on both sides of every lattice axis.
        self.padding = (1,) * (2 * lattice_dim)

        self.coupling_strength = coupling_strength
        self.final_layer = final_layer

    def pad(self, config: Tensor) -> Tensor:
        """Pad ``config`` by one site on each side of every lattice axis, circularly."""
        return F.pad(config, self.padding, "circular")

    def sample_conditional(self, config: Tensor) -> tuple[Tensor, Tensor]:
        """Draw von Mises spins conditioned on the local field of ``config``.

        Args:
            config: Tensor with the batch on the first axis followed by the
                lattice axes. It is NOT modified.

        Returns:
            ``(new_spins, log_density)``: the sampled angles, one per site
            selected by ``mask``, and their log-density summed over axis 1.
        """
        # Add a channel axis for the convolution. This must be out-of-place:
        # an in-place `unsqueeze_` would change the shape of the caller's tensor.
        config = config.unsqueeze(dim=1)
        m1 = self.conv(self.pad(config.cos()), self.kernel).squeeze(dim=1)
        m2 = self.conv(self.pad(config.sin()), self.kernel).squeeze(dim=1)

        # NOTE(review): the field is read at the `mask` sites, whereas forward()
        # writes the sampled spins to the `~mask` sites -- confirm this pairing
        # is intended.
        m1 = m1[..., self.mask]
        m2 = m2[..., self.mask]
        kappa = self.coupling_strength * (m1.pow(2) + m2.pow(2)).sqrt()
        kappa = kappa.clamp(min=0.01)  # otherwise sampling takes AGES
        theta = torch.atan2(m2, m1)

        dist = torch.distributions.VonMises(loc=theta, concentration=kappa)
        new_spins = dist.sample()
        log_density = dist.log_prob(new_spins).sum(dim=1)
        # Returned as-is: squeezing axis 1 here would drop the site axis when
        # the half-lattice contains a single site.
        return new_spins, log_density

    def forward(
        self, inputs: Tensor, log_det_jacob: Tensor
    ) -> tuple[Tensor, Tensor]:
        """Transform partition ``a`` and resample partition ``b``.

        Args:
            inputs: Tensor with the batch on the first axis followed by the
                lattice axes. It is not modified.
            log_det_jacob: Running accumulator, updated IN PLACE with the
                log-determinant of the transform and, when ``final_layer`` is
                set, with the log-density of the sampled spins.

        Returns:
            ``(outputs, log_det_jacob)`` where ``log_det_jacob`` is the same
            tensor object that was passed in.
        """
        in_a = inputs[:, self.mask]
        in_b = inputs[:, ~self.mask]
        out_a, log_det_jacob_a = self.transform(
            in_a, self.net(in_b).view(*in_a.shape, -1).squeeze(dim=-1)
        )
        log_det_jacob.add_(log_det_jacob_a)

        intermediates = torch.clone(inputs)
        intermediates[:, self.mask] = out_a

        out_b, log_density_b = self.sample_conditional(intermediates)
        # Write into a separate copy so the tensor that was fed to the
        # conditional sampler is left untouched.
        outputs = torch.clone(intermediates)
        outputs[:, ~self.mask] = out_b

        if self.final_layer:
            log_det_jacob.add_(log_density_b)

        return outputs, log_det_jacob

In [4]:
class Model(pl.LightningModule):
    """Module which learns to sample from XY model.

    A stack of ``n_blocks`` spline coupling blocks (optionally interleaved
    with phase-shift coupling blocks and random rotation layers) is wrapped in
    ``utils.Flow`` and trained by minimising the batch mean of
    ``log_prob_z - log_det_jacob + action(x)``.

    Args:
        xy_coupling: Coupling passed to ``actions.XYSpinAction`` and used as
            the ``coupling_strength`` of the conditional blocks.
        lattice_shape: Extent of the lattice along each dimension.
        n_blocks: Number of spline coupling blocks.
        conditional_blocks: Indices (in ``range(n_blocks)``) of the blocks that
            are built as ``CouplingBlockHalfConditional`` instead of
            ``CouplingBlock``.
        n_spline_segments: Number of segments of the rational quadratic spline.
        net_hidden_shape: Widths of the hidden layers of the coupling networks.
        net_activation: Activation used in the hidden layers of the spline
            blocks' networks.
        use_shift_coupling_layers: If True, a phase-shift ``CouplingBlock`` is
            inserted before every spline block.
        use_random_rotations: If True, a ``utils.RandomRotationLayer`` is
            inserted before every spline block.
    """

    def __init__(
        self,
        *,
        xy_coupling: float,
        lattice_shape: list[int],
        n_blocks: int,
        conditional_blocks: list[int],
        n_spline_segments: int,
        net_hidden_shape: list[int],
        net_activation: torch.nn.Module,
        use_shift_coupling_layers: bool = False,
        use_random_rotations: bool = False,
    ):
        super().__init__()
        self.save_hyperparameters()

        layers = []
        for i in range(n_blocks):
            if use_shift_coupling_layers:
                layers.append(
                    CouplingBlock(
                        transforms.PointwisePhaseShift(),
                        lattice_shape,
                        net_hidden_shape,
                        net_activation=torch.nn.Tanh(),
                        net_final_activation=torch.nn.Hardtanh(-PI, PI),
                    )
                )
            if use_random_rotations:
                layers.append(utils.RandomRotationLayer())

            if i in conditional_blocks:
                layers.append(
                    CouplingBlockHalfConditional(
                        transforms.PointwiseRationalQuadraticSplineTransform(
                            n_spline_segments
                        ),
                        lattice_shape,
                        net_hidden_shape,
                        net_activation,
                        net_final_activation=torch.nn.Identity(),
                        coupling_strength=xy_coupling,
                        # Only the last block adds its sampling log-density.
                        final_layer=(i == n_blocks - 1),
                    )
                )
            else:
                layers.append(
                    CouplingBlock(
                        transforms.PointwiseRationalQuadraticSplineTransform(
                            n_spline_segments
                        ),
                        lattice_shape,
                        net_hidden_shape,
                        net_activation,
                        net_final_activation=torch.nn.Identity(),
                    )
                )

        self.flow = utils.Flow(*layers)
        self.action = actions.XYSpinAction(xy_coupling)
        # Number of completed calls to `training_step`; used in checkpoint names.
        self.curr_iter = 0

    def _save_checkpoint(self):
        """Hack: should be done automatically by Lightning but I never progress
        past epoch=0 and this is easier than fiddling with other things.

        Writes ``checkpoints/iter_<curr_iter>.ckpt`` under the logger's
        ``log_dir``.
        """
        self.trainer.save_checkpoint(
            self.logger.log_dir + f"/checkpoints/iter_{self.curr_iter}.ckpt"
        )

    def forward(self, batch):
        """z ~ Unif -> x ~ XY

        Args:
            batch: Pair ``(z, log_prob_z)`` of prior samples and their
                log-density.

        Returns:
            ``(x, weights)`` where ``x`` is the output of the flow and
            ``weights = log_prob_z - log_det_jacob + action(x)``.
        """
        z, log_prob_z = batch
        x, log_det_jacob = self.flow(z)
        weights = log_prob_z - log_det_jacob + self.action(x)
        return x, weights

    def on_train_start(self):
        """Save an initial checkpoint and log the hyperparameters."""
        self._save_checkpoint()
        # Dirty but more convenient than saving / loading config files atm
        self.logger.log_hyperparams(self.hparams)

    def on_train_end(self):
        """Save a final checkpoint."""
        self._save_checkpoint()

    def training_step(self, batch, batch_idx):
        """Return the batch-mean of the weights as the training loss."""
        _, weights = self.forward(batch)
        loss = weights.mean()
        self.log("loss", loss, logger=True)
        # The learning-rate scheduler is stepped by Lightning after every
        # optimizer step (see `configure_optimizers`). Stepping it here by hand
        # would run before the optimizer step of this iteration.
        self.curr_iter += 1
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the loss, Metropolis acceptance and squared magnetisation."""
        x, weights = self.forward(batch)
        loss = weights.mean()
        metrics = dict(
            loss=loss,
            acceptance=utils.metropolis_acceptance(weights),
            mag_sq=utils.magnetisation_sq(x).mean(),
        )
        self.log_dict(
            metrics,
            prog_bar=False,
            logger=True,
        )
        return loss

    def configure_optimizers(self):
        """Adam with a cosine-annealed learning rate over ``trainer.max_steps``.

        The scheduler is registered with ``interval="step"`` so that Lightning
        advances it once per optimizer step instead of once per epoch.
        """
        optimizer = torch.optim.Adam(self.flow.parameters(), lr=0.001)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.trainer.max_steps
        )
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

    @torch.no_grad()
    def sample(self, prior: IterableDataset, n_iter: int = 1):
        """Draw ``n_iter`` batches from ``prior`` and push them through the flow.

        Args:
            prior: Object supporting ``next(prior)`` that yields
                ``(z, log_prob_z)`` batches.
            n_iter: Number of batches to draw. At least one batch is always
                drawn, including when ``n_iter`` is less than one.

        Returns:
            ``(x, weights)`` concatenated over all batches along axis 0.
        """
        # Collect first and concatenate once, instead of re-copying the
        # accumulated tensors on every iteration.
        batches = [self.forward(next(prior)) for _ in range(max(n_iter, 1))]
        x = torch.cat([_x for _x, _ in batches], dim=0)
        weights = torch.cat([_weights for _, _weights in batches], dim=0)
        return x, weights

## Training

In [5]:
# Reproducibility: seed torch before the model is initialised and before any
# prior samples are drawn, so that Restart & Run All gives the same run.
SEED = 0
torch.manual_seed(SEED)

# XY parameters
LATTICE_SHAPE = [6]  # extent of the lattice along each dimension
XY_COUPLING = 0.8

# Model parameters
N_BLOCKS = 4
CONDITIONAL_BLOCKS = [0, 1, 2, 3]  # indices of blocks built as half-conditional
N_SPLINE_SEGMENTS = 8
NET_HIDDEN_SHAPE = [32]
NET_ACTIVATION = torch.nn.Tanh()
USE_SHIFT_COUPLING_LAYERS = False
USE_RANDOM_ROTATIONS = False

# Training hyperparameters
N_TRAIN = 5000  # passed to the Trainer as max_steps
N_BATCH = 1000  # leading dimension of each training sample
N_BATCH_VAL = 10000  # leading dimension of each validation sample

In [6]:
# Build the model from the constants of the configuration cell.
model = Model(
    xy_coupling=XY_COUPLING,
    lattice_shape=LATTICE_SHAPE,
    n_blocks=N_BLOCKS,
    conditional_blocks=CONDITIONAL_BLOCKS,
    n_spline_segments=N_SPLINE_SEGMENTS,
    net_hidden_shape=NET_HIDDEN_SHAPE,
    net_activation=NET_ACTIVATION,
    use_shift_coupling_layers=USE_SHIFT_COUPLING_LAYERS,
    use_random_rotations=USE_RANDOM_ROTATIONS,
)

# Could wrap in DataSet(IterableDataset, batch_size=None) but can't currently see benefit
# Training and validation share the same uniform prior on [-PI, PI) and differ
# only in the leading (batch) dimension of the sample shape.
unif = torch.distributions.Uniform(-PI, PI)
train_dataloader = utils.Prior(
    distribution=unif, sample_shape=[N_BATCH, *LATTICE_SHAPE]
)
val_dataloader = utils.Prior(
    distribution=unif, sample_shape=[N_BATCH_VAL, *LATTICE_SHAPE]
)

# NOTE: this logger is currently NOT passed to the Trainer below (the
# `logger=` argument is commented out), so the Trainer uses its default logger.
logger = pl.loggers.TensorBoardLogger(
    save_dir="test",
    name="spins",
)
pbar = utils.JlabProgBar()
lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step")

trainer = pl.Trainer(
    # Use one GPU when CUDA is available, otherwise fall back to the CPU
    # instead of failing on machines without a GPU.
    gpus=1 if torch.cuda.is_available() else 0,
    max_steps=N_TRAIN,  # total number of training steps
    val_check_interval=100,  # how often to run sampling
    limit_val_batches=1,  # one batch for each val step
    # logger=logger,
    callbacks=[pbar, lr_monitor],
    enable_checkpointing=False,  # manually saving checkpoints
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [7]:
# Train for the number of steps configured on the Trainer, drawing training
# batches from `train_dataloader` and validation batches from `val_dataloader`.
trainer.fit(model, train_dataloader, val_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type | Params
------------------------------
0 | flow | Flow | 10.0 K
------------------------------
10.0 K    Trainable params
0         Non-trainable params
10.0 K    Total params
0.040     Total estimated model params size (MB)


Epoch 0: : 5101it [02:17, 37.21it/s, loss=-6.68, v_num=4]             


In [8]:
# Open TensorBoard on `lightning_logs` -- presumably where the Trainer's default
# logger writes, since the custom TensorBoardLogger is not passed to the Trainer.
%tensorboard --logdir lightning_logs