# Normalising flow for the von Mises distribution

In [None]:
from __future__ import annotations

import math

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch

import transforms
import utils

# Shorthand for pi; used below for the Uniform(-PI, PI) prior and the
# log(2 * PI) term in the forward-KL weights.
PI = math.pi

# Use a larger default font size for every matplotlib figure in the notebook.
plt.rcParams.update({"font.size": 18})

# IPython magics: load the lab_black and tensorboard notebook extensions.
%load_ext lab_black
%load_ext tensorboard

# Fixed von Mises concentration

In [None]:
class UnconditionalLayer(torch.nn.Module):
    """Wrap a transformation and give it its own set of learnable parameters.

    The parameters are initialised from ``transform.identity_params`` and the
    same parameter vector is used for every element of a batch.

    Args:
        transform: Callable invoked as ``transform(inputs, params)`` that
            returns ``(outputs, log_det_jacob)``. It must also provide an
            ``inverse`` method with the same call signature and an
            ``identity_params`` tensor holding the initial parameter values.
    """

    def __init__(self, transform):
        super().__init__()
        self.transform = transform
        # Copy the initial values. Wrapping the tensor (or a view of it)
        # directly would make the parameter share storage with
        # ``transform.identity_params``, so every optimiser update would
        # silently overwrite the transform's own values.
        initial_params = transform.identity_params.detach().clone()
        self.params = torch.nn.Parameter(initial_params.view(1, 1, -1))

    def _batch_params(self, inputs):
        """Return ``self.params`` expanded along dim 0 to ``inputs.shape[0]``."""
        return self.params.expand(inputs.shape[0], 1, -1)

    def forward(self, inputs, log_det_jacob):
        """Apply the transformation and accumulate its log-det-Jacobian.

        Args:
            inputs: Tensor passed to the transformation; its first dimension
                is taken as the batch size.
            log_det_jacob: Running accumulator. It is updated **in place** and
                the same tensor object is returned.

        Returns:
            Tuple ``(outputs, log_det_jacob)``.
        """
        outputs, log_det_jacob_this = self.transform(
            inputs, self._batch_params(inputs)
        )
        log_det_jacob.add_(log_det_jacob_this)
        return outputs, log_det_jacob

    def inverse(self, inputs, log_det_jacob):
        """Apply ``transform.inverse`` and accumulate its log-det-Jacobian.

        Arguments and return value mirror :meth:`forward`; ``log_det_jacob``
        is likewise updated in place.
        """
        outputs, log_det_jacob_this = self.transform.inverse(
            inputs, self._batch_params(inputs)
        )
        log_det_jacob.add_(log_det_jacob_this)
        return outputs, log_det_jacob


class FixedConcModel(pl.LightningModule):
    """Module which learns to transform uniform variates into von Mises variates.

    The flow is a single ``UnconditionalLayer`` wrapping a pointwise
    rational-quadratic spline; the target is a von Mises distribution with
    ``loc=0`` and a fixed concentration.

    Args:
        vonmises_conc: Concentration of the target von Mises distribution.
        n_spline_segments: Passed to
            ``transforms.PointwiseRationalQuadraticSplineTransform``.
    """

    def __init__(self, vonmises_conc, n_spline_segments):
        super().__init__()
        self.flow = utils.Flow(
            UnconditionalLayer(
                transforms.PointwiseRationalQuadraticSplineTransform(n_spline_segments)
            )
        )
        self.target = torch.distributions.VonMises(loc=0, concentration=vonmises_conc)

    def forward(self, batch):
        """Push a batch of prior samples through the flow.

        Args:
            batch: Pair ``(z, log_prob_z)`` of prior samples and their
                log-probabilities.

        Returns:
            Tuple ``(x, weights)`` where
            ``weights = log_prob_z - log_det_jacob - log_prob_x``, with
            ``log_prob_x`` the target log-density evaluated at ``x``.
        """
        z, log_prob_z = batch  # shape (n_batch, 1)
        x, log_det_jacob = self.flow(z)
        log_prob_x = self.target.log_prob(x).view_as(log_det_jacob)  # flattened
        weights = log_prob_z - log_det_jacob - log_prob_x
        return x, weights

    def training_step(self, batch, batch_idx):
        """Return the mean of the weights as the loss and log it as ``loss``."""
        _, weights = self.forward(batch)
        loss = weights.mean()
        self.log("loss", loss, logger=True)
        # The LR scheduler is stepped by Lightning after each optimiser step
        # (see the "interval" entry in configure_optimizers), so it is not
        # stepped by hand here.
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the Metropolis acceptance of the batch and return the mean weight."""
        _, weights = self.forward(batch)
        loss = weights.mean()
        self.log(
            "acceptance",
            utils.metropolis_acceptance(weights.flatten()),
            prog_bar=False,
            logger=True,
        )
        return loss

    def configure_optimizers(self):
        """Adam with cosine annealing over ``trainer.max_steps`` steps."""
        optimizer = torch.optim.Adam(self.flow.parameters(), lr=0.001)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.trainer.max_steps
        )
        # "interval": "step" asks Lightning to call scheduler.step() once per
        # optimiser step (its default is once per epoch). This replaces a
        # manual call inside training_step, which ran before optimizer.step().
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

    @torch.no_grad()
    def sample(self, prior, n_iter=1):
        """Draw ``n_iter`` batches from ``prior`` and push them through the flow.

        Args:
            prior: Iterator yielding ``(z, log_prob_z)`` batches.
            n_iter: Number of batches to draw. At least one batch is always
                drawn, even when ``n_iter`` is less than 1.

        Returns:
            Tuple ``(x, weights)`` with all batches concatenated along dim 0.
        """
        # Collect first and concatenate once, rather than re-copying the
        # growing result on every iteration.
        batches = [self.forward(next(prior)) for _ in range(max(n_iter, 1))]
        xs, all_weights = zip(*batches)
        return torch.cat(xs, dim=0), torch.cat(all_weights, dim=0)

In [None]:
# Target distribution and flow resolution.
VONMISES_CONC = 0.5
N_SPLINE_SEGMENTS = 8

# Number of optimiser steps and batch sizes for training / validation.
N_TRAIN = 10000
N_BATCH = 1000
N_BATCH_VAL = 10000

model = FixedConcModel(VONMISES_CONC, N_SPLINE_SEGMENTS)

# Both "dataloaders" draw from the uniform distribution on [-PI, PI).
unif = torch.distributions.Uniform(-PI, PI)
train_dataloader = utils.Prior(distribution=unif, sample_shape=[N_BATCH, 1])
val_dataloader = utils.Prior(distribution=unif, sample_shape=[N_BATCH_VAL, 1])

pbar = utils.JlabProgBar()
lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step")

trainer = pl.Trainer(
    # Use one GPU when CUDA is available; otherwise run on CPU instead of
    # failing at Trainer construction.
    gpus=1 if torch.cuda.is_available() else 0,
    max_steps=N_TRAIN,
    val_check_interval=100,  # how often to run sampling
    limit_val_batches=1,
    callbacks=[pbar, lr_monitor],
    enable_checkpointing=False,
)

# Variable von Mises concentration

In [None]:
class ConditionalLayer(torch.nn.Module):
    """Wrap a transformation whose parameters are produced by a neural network.

    A fully-connected network maps the von Mises concentration (one input
    feature) to ``transform.params_dof`` outputs, which are then handed to the
    transformation as its parameters.

    Args:
        transform: Callable invoked as ``transform(inputs, params)`` returning
            ``(outputs, log_det_jacob)``; must also provide ``inverse`` with
            the same signature and a ``params_dof`` attribute.
        net_hidden_shape: Sequence of hidden-layer widths.
        net_activation: Module inserted after every hidden linear layer (the
            same instance is reused each time). The final linear layer is
            followed by ``torch.nn.Identity``.
    """

    def __init__(self, transform, net_hidden_shape, net_activation):
        super().__init__()
        self.transform = transform

        widths = [1, *net_hidden_shape, transform.params_dof]
        modules = []
        # Hidden layers: linear map followed by the supplied activation.
        for n_in, n_out in zip(widths[:-2], widths[1:-1]):
            modules += [torch.nn.Linear(n_in, n_out), net_activation]
        # Output layer: linear map with no non-linearity.
        modules += [torch.nn.Linear(widths[-2], widths[-1]), torch.nn.Identity()]
        self.network = torch.nn.Sequential(*modules)

    def _transform_params(self, vm_conc):
        """Run the network on ``vm_conc`` and insert a singleton dim at index 1."""
        return self.network(vm_conc).unsqueeze(dim=1)

    def forward(self, inputs, log_det_jacob, vm_conc):
        """Apply the transformation with parameters conditioned on ``vm_conc``.

        ``log_det_jacob`` is updated in place with this layer's contribution
        and the same tensor object is returned alongside the outputs.
        """
        result, ldj_contribution = self.transform(
            inputs, self._transform_params(vm_conc)
        )
        log_det_jacob.add_(ldj_contribution)
        return result, log_det_jacob

    def inverse(self, inputs, log_det_jacob, vm_conc):
        """Apply ``transform.inverse`` with parameters conditioned on ``vm_conc``.

        ``log_det_jacob`` is updated in place, exactly as in :meth:`forward`.
        """
        result, ldj_contribution = self.transform.inverse(
            inputs, self._transform_params(vm_conc)
        )
        log_det_jacob.add_(ldj_contribution)
        return result, log_det_jacob


class VariableConcModel(pl.LightningModule):
    """Module which learns a mapping between uniform and von Mises distributions.

    The flow is conditioned on the von Mises concentration, which is drawn
    afresh from a uniform distribution over ``concentration_range`` at every
    training and validation step.

    Args:
        max_concentration: Recorded as a hyperparameter only; the model does
            not otherwise read it. The range of sampled concentrations is set
            by ``concentration_range``.
        n_spline_segments: Passed to
            ``transforms.PointwiseRationalQuadraticSplineTransform``.
        net_hidden_shape: Hidden-layer widths of the conditioning network.
        net_activation: Activation module of the conditioning network.
        use_fwd_kl: If true, some training steps draw samples from the von
            Mises target and train through the inverse flow.
        fwd_kl_interval: When ``use_fwd_kl`` is true, the inverse-flow branch
            is taken on iterations that are a multiple of this value.
        use_mobius: If true, a ``utils.MobiusLayer`` is placed before the
            spline layer in the flow.
        concentration_range: ``(low, high)`` bounds of the uniform
            distribution from which concentrations are drawn.
    """

    def __init__(
        self,
        max_concentration,
        n_spline_segments,
        net_hidden_shape,
        net_activation,
        use_fwd_kl=False,
        fwd_kl_interval=1,
        use_mobius=False,
        concentration_range=(0.02, 11),  # immutable default (was a shared list)
    ):
        super().__init__()
        self.save_hyperparameters()

        # NOTE: sampling from torch.distributions.VonMises fails at conc=1e-4!!!
        self.concentration_distribution = torch.distributions.Uniform(
            *concentration_range
        )
        # 1-based counter of training steps; selects the forward-KL steps.
        self.curr_iter = 1

        spline_layer = ConditionalLayer(
            transforms.PointwiseRationalQuadraticSplineTransform(n_spline_segments),
            net_hidden_shape,
            net_activation,
        )

        if use_mobius:
            self.flow = utils.Flow(utils.MobiusLayer(), spline_layer)
        else:
            self.flow = utils.Flow(spline_layer)

    def _uniform_to_vonmises(self, z, log_prob_z, concentrations):
        """z ~ Unif -> x ~ vM

        Returns ``(x, weights)`` with
        ``weights = log_prob_z - log_det_jacob - log_prob_x``, where
        ``log_prob_x`` is the von Mises log-density (``loc=0``, the given
        concentrations) evaluated at ``x``.
        """
        x, log_det_jacob = self.flow(z, concentrations)
        log_prob_x = (
            torch.distributions.VonMises(loc=0, concentration=concentrations)
            .log_prob(x)
            .view_as(log_det_jacob)
        )
        weights = log_prob_z - log_det_jacob - log_prob_x
        return x, weights

    def _vonmises_to_uniform(self, x, log_prob_x, concentrations):
        """x ~ vM -> z ~ Unif

        Returns ``(z, weights)`` with
        ``weights = log_prob_x - log_det_jacob + log(2 * PI)``; the last term
        is minus the log-density of the uniform distribution on an interval
        of length ``2 * PI``.
        """
        z, log_det_jacob = self.flow.inverse(x, concentrations)
        weights = log_prob_x - log_det_jacob + math.log(2 * PI)
        return z, weights

    def on_train_start(self):
        """Send the saved hyperparameters to the logger."""
        self.logger.log_hyperparams(self.hparams)

    def training_step(self, batch, batch_idx):
        """Compute the loss for one batch with freshly drawn concentrations.

        On forward-KL steps the prior samples in ``batch`` are used only for
        their shape and device; the inputs are drawn from the von Mises
        target instead.
        """
        z, log_prob_z = batch
        concentrations = self.concentration_distribution.sample(z.shape).to(z.device)

        if (
            self.hparams.use_fwd_kl
            and self.curr_iter % self.hparams.fwd_kl_interval == 0
        ):
            curr_vonmises_prior = torch.distributions.VonMises(
                loc=0, concentration=concentrations
            )
            x = curr_vonmises_prior.sample()
            log_prob_x = curr_vonmises_prior.log_prob(x).flatten()
            _, weights = self._vonmises_to_uniform(x, log_prob_x, concentrations)
        else:
            _, weights = self._uniform_to_vonmises(z, log_prob_z, concentrations)

        loss = weights.mean()
        self.log("loss", loss, logger=True)
        # The LR scheduler is stepped by Lightning after each optimiser step
        # (see the "interval" entry in configure_optimizers), so it is not
        # stepped by hand here.
        self.curr_iter += 1
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the Metropolis acceptance of the batch and return the mean weight."""
        z, log_prob_z = batch
        concentrations = self.concentration_distribution.sample(z.shape).to(z.device)
        x, weights = self._uniform_to_vonmises(z, log_prob_z, concentrations)
        loss = weights.mean()
        self.log(
            "acceptance",
            utils.metropolis_acceptance(weights.flatten()),
            prog_bar=False,
            logger=True,
        )
        return loss

    def configure_optimizers(self):
        """Adam with cosine annealing over ``trainer.max_steps`` steps."""
        optimizer = torch.optim.Adam(self.flow.parameters(), lr=0.001)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=self.trainer.max_steps,
        )
        # "interval": "step" asks Lightning to call scheduler.step() once per
        # optimiser step (its default is once per epoch). This replaces a
        # manual call inside training_step, which ran before optimizer.step().
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

    @torch.no_grad()
    def sample(self, concentration, prior, n_iter=1):
        """Generate samples for a single fixed concentration.

        Args:
            concentration: Scalar von Mises concentration used for every
                sample.
            prior: Iterator yielding ``(z, log_prob_z)`` batches.
            n_iter: Number of batches to draw. At least one batch is always
                drawn, even when ``n_iter`` is less than 1.

        Returns:
            Tuple ``(x, weights)`` with all batches concatenated along dim 0.
        """
        xs, all_weights = [], []
        for _ in range(max(n_iter, 1)):
            z, log_prob_z = next(prior)
            # Match the shape, dtype and device of z, as training_step does,
            # instead of always building a CPU tensor from prior.sample_shape.
            concentrations = torch.full_like(z, concentration)
            x, weights = self._uniform_to_vonmises(z, log_prob_z, concentrations)
            xs.append(x)
            all_weights.append(weights)
        # Concatenate once rather than re-copying the growing result each time.
        return torch.cat(xs, dim=0), torch.cat(all_weights, dim=0)

In [None]:
# NOTE: VariableConcModel stores max_concentration as a hyperparameter but
# does not otherwise read it; the range of sampled concentrations comes from
# its concentration_range argument (left at its default here).
VONMISES_MAX_CONC = 10
N_SPLINE_SEGMENTS = 16
NET_HIDDEN_SHAPE = [8, 16, 32]
NET_ACTIVATION = torch.nn.LeakyReLU()
USE_FWD_KL = False
FWD_KL_INTERVAL = 2
USE_MOBIUS = True

# Number of optimiser steps and batch sizes for training / validation.
N_TRAIN = 32000
N_BATCH = 8000
N_BATCH_VAL = 16000

model = VariableConcModel(
    VONMISES_MAX_CONC,
    N_SPLINE_SEGMENTS,
    NET_HIDDEN_SHAPE,
    NET_ACTIVATION,
    USE_FWD_KL,
    FWD_KL_INTERVAL,
    USE_MOBIUS,
)

# Both "dataloaders" draw from the uniform distribution on [-PI, PI).
unif = torch.distributions.Uniform(-PI, PI)
train_dataloader = utils.Prior(distribution=unif, sample_shape=[N_BATCH, 1])
val_dataloader = utils.Prior(distribution=unif, sample_shape=[N_BATCH_VAL, 1])

pbar = utils.JlabProgBar()
lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step")

trainer = pl.Trainer(
    # Use one GPU when CUDA is available; otherwise run on CPU instead of
    # failing at Trainer construction.
    gpus=1 if torch.cuda.is_available() else 0,
    max_steps=N_TRAIN,
    val_check_interval=100,  # how often to run sampling
    limit_val_batches=1,
    callbacks=[pbar, lr_monitor],
    enable_checkpointing=False,
)

In [None]:
# Train whichever `model` / `trainer` pair was defined by the most recently
# executed setup cell, using the prior-sample dataloaders for training and
# validation.
trainer.fit(model, train_dataloader, val_dataloader)

In [None]:
# Number of concentrations to visualise and number of subplot columns.
N_CONC = 12
N_COLS = 4

# Evenly spaced concentrations spanning the range sampled during training.
conc_dist = model.concentration_distribution
concentrations = torch.linspace(conc_dist.low, conc_dist.high, N_CONC).tolist()
# Grid of angles on which the target density is evaluated.
dom = torch.linspace(-math.pi, math.pi, 100)

n_rows = math.ceil(N_CONC / N_COLS)
fig, axes = plt.subplots(n_rows, N_COLS, sharex=True, sharey=True, figsize=(20, 16))

acceptances = []

# One panel per concentration: histogram of model samples with the target
# von Mises density drawn on top.
for ax, conc in zip(axes.flat, concentrations):
    x, weights = model.sample(conc, val_dataloader, n_iter=10)

    acceptance = utils.metropolis_acceptance(weights.flatten())
    acceptances.append(acceptance)

    vonmises = torch.distributions.VonMises(loc=0, concentration=conc)
    target = torch.exp(vonmises.log_prob(dom))

    ax.hist(x.flatten().tolist(), bins=50, density=True)
    ax.plot(dom.tolist(), target.tolist(), "r-", label=f"conc = {conc:.2g}")
    ax.legend()

fig.suptitle("Histograms of model outputs")

# Separate figure: Metropolis acceptance as a function of concentration.
acceptance_fig, acceptance_ax = plt.subplots()
acceptance_ax.plot(concentrations, acceptances, "o-")
acceptance_ax.set_xlabel("concentration")
acceptance_ax.set_ylabel("acceptance")

In [None]:
# IPython magic: open TensorBoard inline, reading runs from ./lightning_logs.
%tensorboard --logdir lightning_logs

In [None]:
# Write the trainer's current state to a checkpoint file. Automatic
# checkpointing is disabled in the Trainer (enable_checkpointing=False), so
# this explicit call is the only checkpoint saved by this notebook.
trainer.save_checkpoint("trained_vonmises_sampler/model.ckpt")