# Setup

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tqdm.auto as tqdm
import torch
%matplotlib widget

In [None]:
def grab(x: torch.Tensor) -> np.ndarray:
    """Convert a torch Tensor to a numpy array.

    The tensor is first detached from the autograd graph and moved to host
    memory, so tensors that require grad or that live on an accelerator
    device are accepted as well. For a tensor already on the CPU the `.cpu()`
    call is a no-op.
    """
    return x.detach().cpu().numpy()

In [None]:
def wrap(x):
    """Map angle(s) `x` onto the principal interval [-pi, pi).

    Works elementwise on anything supporting `+`, `%` and `-` with floats
    (python floats, numpy arrays, torch tensors).
    """
    period = 2*np.pi
    shifted = x + np.pi
    return shifted % period - np.pi

# Brief ML primer

Let's quickly demonstrate how training looks in PyTorch. We will train a small neural network to model the function
$$
f(x) = \mathrm{sinc}(x) := \frac{\sin(\pi x)}{\pi x}
$$

In [None]:
def target_fn(x):
    """Target function for the toy regression: the normalized sinc, sin(pi x) / (pi x)."""
    y = torch.sinc(x)
    return y
# Plot the target function on a 51-point grid over [-5, 5].
fig, ax = plt.subplots(1,1, figsize=(3.5, 2.5))
xs = torch.linspace(-5, 5, steps=51)
ys = target_fn(xs)
# matplotlib needs numpy arrays, hence the conversion via grab()
ax.plot(grab(xs), grab(ys))
plt.show()

In [None]:
class ToyModel(torch.nn.Module):
    """Small fully-connected network mapping a scalar input to a scalar output.

    Layer widths are 1 -> 8 -> 8 -> 1 with SiLU activations in between.
    """
    def __init__(self):
        super().__init__()
        width = 8
        layers = [
            torch.nn.Linear(1, width),
            torch.nn.SiLU(),
            torch.nn.Linear(width, width),
            torch.nn.SiLU(),
            torch.nn.Linear(width, 1),
        ]
        self.net = torch.nn.Sequential(*layers)
    def forward(self, x):
        """Evaluate the network on a 1D batch `x`; returns a 1D batch of outputs."""
        assert x.ndim == 1, 'x should just have a batch index'
        # the linear layers expect a trailing feature axis of size 1
        features = x.unsqueeze(-1)
        return self.net(features).squeeze(-1)

In [None]:
def train_model(n_steps=25000, batch_size=128, lr=1e-3, log_every=1000):
    """Fit a fresh `ToyModel` to `target_fn` by minimizing the mean squared error.

    Every step draws a new batch of inputs from a normal distribution with
    standard deviation 3 centered on 0, so the training points concentrate
    around the origin.

    Args:
        n_steps: number of optimizer steps.
        batch_size: number of random inputs drawn per step.
        lr: learning rate passed to Adam.
        log_every: print the current loss every this many steps.

    Returns:
        dict with keys `model` (the trained network) and `loss` (array holding
        the loss recorded at every step).
    """
    model = ToyModel()
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_hist = []
    for i in tqdm.tqdm(range(n_steps)):
        opt.zero_grad()
        # random samples around 0 for the training points
        x = 3*torch.randn((batch_size,))
        model_y = model(x)
        true_y = target_fn(x)
        # mean squared error
        loss = ((true_y - model_y)**2).mean()
        loss.backward()
        opt.step()
        # convert once and reuse for both the history and the log line
        loss_val = grab(loss)
        loss_hist.append(loss_val)
        if (i+1) % log_every == 0:
            print(f'Step {i+1}: Loss {loss_val}')
    # np.array (rather than np.stack) also accepts an empty history (n_steps=0)
    return dict(model=model, loss=np.array(loss_hist))
res = train_model()

In [None]:
# Compare the trained toy model against the target and show the loss history.
fig, axes = plt.subplots(1,2, figsize=(8, 3), tight_layout=True)
xs = torch.linspace(-5, 5, steps=51)
true_ys = target_fn(xs)
model_ys = res['model'](xs)
# left panel: target function vs. model prediction on the same grid
ax = axes[0]
ax.plot(grab(xs), grab(true_ys), color='k', label='target')
ax.plot(grab(xs), grab(model_ys), color='xkcd:red', label='model')
ax.legend()
# right panel: per-step training loss on a log scale
ax = axes[1]
ax.plot(res['loss'])
ax.set_ylabel('loss')
ax.set_yscale('log')
plt.show()

# Action
The general form of the action is
$S(\theta_1, \theta_2; \alpha, \beta) := -\beta \cos(\theta_1 - \theta_2) - \alpha \cos(\theta_1) + \alpha \cos(\theta_2)$.

In [None]:
def action(th, *, alpha, beta):
    """Evaluate the two-angle action S(th1, th2; alpha, beta).

    `th` carries the angle pair (th1, th2) along its last axis; any leading
    axes are batch axes and are preserved in the result.
    """
    assert th.shape[-1] == 2
    th1, th2 = th.unbind(dim=-1)
    coupling = -beta * torch.cos(th1 - th2)
    field_1 = alpha * torch.cos(th1)
    field_2 = alpha * torch.cos(th2)
    return coupling - field_1 + field_2

def make_action(alpha, beta):
    """Return a single-argument function `th -> action(th)` with the couplings fixed."""
    def bound_action(th):
        return action(th, alpha=alpha, beta=beta)
    return bound_action

# Couplings of the target distribution and the corresponding single-argument
# action; `target_action` is read as a module-level name by later cells.
beta_target = 3.0
alpha_target = 1.0
target_action = make_action(alpha_target, beta_target)

In [None]:
def sample_inds(weights):
    """Resample indices with probability proportional to `weights`.

    Args:
        weights: 1D array-like of non-negative, unnormalized weights. Any real
            dtype is accepted and the input is left unmodified.

    Returns:
        Integer array of `len(weights)` indices drawn with replacement from
        `range(len(weights))`.
    """
    # Normalize out of place in float64: an in-place `/=` on an integer array
    # raises a casting error.
    p = np.asarray(weights, dtype=np.float64)
    p = p / np.sum(p)
    return np.random.choice(len(p), p=p, size=len(p))

def sample(batch_size, action, *, beta0):
    """Draw approximate ground-truth samples from exp(-action) by importance resampling.

    Proposal: th1 is uniform on [0, 2pi) and the difference th1 - th2 is drawn
    from a von Mises distribution centered on 0 with concentration `beta0`.
    The proposals are reweighted by target/proposal density and then
    resampled (with replacement) according to those weights.

    Args:
        batch_size: number of proposals drawn, and number of samples returned.
        action: callable mapping a (batch_size, 2) tensor of angles to the
            (batch_size,) tensor of action values.
        beta0: concentration of the von Mises proposal.

    Returns:
        Tensor of shape (batch_size, 2) holding the resampled angle pairs.
    """
    shape = (batch_size,)
    dist = torch.distributions.VonMises(0.0, beta0)
    delta = dist.sample(shape)
    # Proposal "action" S0 = -log q(delta). The uniform density of th1 is a
    # constant and drops out when the weights are normalized.
    S0 = -dist.log_prob(delta)
    th1 = 2*np.pi*torch.rand(size=shape)
    th2 = (th1 - delta) % (2*np.pi)
    th = torch.stack([th1, th2], dim=-1)
    # importance weight: log w = log p - log q = -S + S0
    logw = -action(th) + S0
    logw = logw - torch.logsumexp(logw, dim=0)
    weight = np.exp(grab(logw))
    # resample
    inds = sample_inds(weight)
    return th[inds]

In [None]:
def make_th_grid(steps):
    """Build a 2D grid of angle pairs at the midpoints of a uniform partition of [-pi, pi].

    `steps` edge points give `steps - 1` midpoints per axis, so the result has
    shape (steps-1, steps-1, 2); component 0 varies along axis 0 and
    component 1 along axis 1.
    """
    edges = torch.linspace(-np.pi, np.pi, steps=steps)
    centers = 0.5 * (edges[:-1] + edges[1:])
    th1, th2 = torch.meshgrid(centers, centers, indexing='ij')
    return torch.stack((th1, th2), dim=-1)
def plot_dist(action, *, ax, nsteps=60):
    """Draw filled contours of the unnormalized density exp(-action) on `ax`.

    The action is evaluated on the grid returned by `make_th_grid(nsteps)`.
    """
    grid = make_th_grid(nsteps)
    density = np.exp(-grab(action(grid)))
    th1, th2 = grab(grid[..., 0]), grab(grid[..., 1])
    ax.contourf(th1, th2, density)
def plot_samples(th, *, ax, nbins=60):
    """Draw a 2D histogram of angle-pair samples `th` on `ax`.

    Samples are wrapped with `wrap` first and binned into `nbins` equal bins
    per axis spanning [-pi, pi].
    """
    edges = np.linspace(-np.pi, np.pi, num=nbins+1)
    wrapped = wrap(grab(th))
    th1, th2 = wrapped[..., 0], wrapped[..., 1]
    ax.hist2d(th1, th2, bins=edges)

# Normalizing flow model

In [None]:
class ModelVelocity(torch.nn.Module):
    """Time-dependent velocity field on a pair of angles, parametrized by a small MLP.

    The network input is (cos th, sin th, t), so the output is 2pi-periodic in
    each angle by construction.
    """
    def __init__(self):
        super().__init__()
        n_in = 2 + 2 + 1  # cos(th), sin(th) and the time t
        hidden = 32
        self.net = torch.nn.Sequential(
            torch.nn.Linear(n_in, hidden),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden, hidden),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden, 2),
        )
    def value(self, th, t):
        """Velocity at angles `th` (last axis of size 2) and time `t`."""
        assert th.shape[-1] == 2
        # broadcast t over the batch axes of th and give it a feature axis
        batch_shape = th.shape[:-1]
        t_col = (t * torch.ones(batch_shape)).unsqueeze(-1)
        features = torch.cat([torch.cos(th), torch.sin(th), t_col], dim=-1)
        return self.net(features)
    def div(self, th, t):
        """Divergence of the velocity with respect to `th`, one value per batch entry."""
        # NOTE: for high dimensions, this is expensive!
        # jacfwd gives the Jacobian for a single (th, t) pair; vmap vectorizes
        # it over the leading batch index of both arguments
        jac_single = torch.func.jacfwd(self.value, argnums=0)
        J = torch.func.vmap(jac_single)(th, t)
        # trace over the two trailing Jacobian axes
        return torch.einsum('...ii->...', J)

In [None]:
def apply_flow(th, model, *, nsteps, tf=1.0, reverse=False):
    """Integrate the model's velocity field with `nsteps` explicit Euler steps of size tf/nsteps.

    With `reverse=True` the time steps are visited in decreasing order and the
    velocity is subtracted instead of added. Angles are re-wrapped after each
    step. The divergence term `div * dt` is accumulated with the same sign in
    both directions.

    Returns the final angles and the accumulated divergence integral `logJ`.
    """
    dt = tf/nsteps
    sign = -1 if reverse else 1
    step_inds = range(nsteps-1, -1, -1) if reverse else range(nsteps)
    logJ = 0
    for i in step_inds:
        # one time value per batch entry, as model.div vectorizes over both arguments
        t = torch.ones((th.shape[0],)) * (i*dt)
        v = model.value(th, t)
        assert v.shape == th.shape
        div = model.div(th, t)
        th = wrap(th + sign * dt * v)
        logJ = logJ + div * dt
    return th, logJ

In [None]:
def action_from_flow(model, *, nsteps, tf=1.0):
    """Build the effective action S(th) = -log q(th) of the flow-model density.

    The model density is the uniform prior on the two angles pushed through
    the flow, log q = logr - logJ with logr = -2 log(2 pi), where logJ is
    obtained by running the flow in reverse from `th`.

    Args:
        model: velocity model handed to `apply_flow`.
        nsteps: number of integration steps.
        tf: final flow time.

    Returns:
        Callable mapping a batch of angle pairs to the batch of action values.
    """
    # log-density of the uniform prior on [-pi, pi)^2
    logr = -2*np.log(2*np.pi)
    def action(th):
        _, logJ = apply_flow(th, model, nsteps=nsteps, tf=tf, reverse=True)
        # S = -log q = -(logr - logJ)
        return logJ - logr
    return action

In [None]:
def compute_ess(logw):
    """Effective sample size <w>^2 / <w^2> computed from log-weights.

    The weights may be unnormalized; the ratio is evaluated in log space via
    logsumexp over the leading axis.
    """
    n = len(logw)
    log_sum_w = torch.logsumexp(logw, dim=0)
    log_sum_w2 = torch.logsumexp(2*logw, dim=0)
    # (sum w)^2 / (sum w^2) / n  ==  <w>^2 / <w^2>
    return torch.exp(2*log_sum_w - log_sum_w2) / n

In [None]:
def train(model, *, n_train, batch_size, nsteps=50, lr=1e-3):
    """Train the flow `model` against the module-level `target_action`.

    Each iteration pushes a batch of uniform prior samples through the flow
    and minimizes the batch mean of log q - log p (log p given by
    `-target_action`, i.e. up to normalization).

    Args:
        model: velocity model to optimize in place.
        n_train: number of optimizer steps.
        batch_size: number of prior samples per step.
        nsteps: number of integration steps used for both the forward and the
            reverse flow.
        lr: learning rate passed to Adam.

    Returns:
        dict with lists `loss` and `ess` holding one entry per iteration.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    hist = dict(loss=[], ess=[])
    # log-density of the uniform prior on the two angles
    logr = -2*np.log(2*np.pi)
    for _ in tqdm.tqdm(range(n_train)):
        optimizer.zero_grad()
        prior_th = wrap(2*np.pi*torch.rand((batch_size, 2)))
        flow_th, _ = apply_flow(prior_th, model, nsteps=nsteps)
        # path gradients evaluation of logq: the parameters are frozen while
        # logJ is evaluated, so gradients reach them only through flow_th
        model.requires_grad_(False)
        try:
            _, logJ = apply_flow(flow_th, model, nsteps=nsteps, reverse=True)
        finally:
            # never leave the model frozen, even if the reverse flow raises
            model.requires_grad_(True)
        logq = logr - logJ
        logp = -target_action(flow_th)
        loss = (logq - logp).mean()
        hist['loss'].append(grab(loss))
        with torch.no_grad():
            ess = compute_ess(logp - logq)
        hist['ess'].append(grab(ess))
        loss.backward()
        optimizer.step()
    return hist

In [None]:
# Instantiate a fresh velocity model and train it; `hist` holds the
# per-iteration loss and ESS returned by train().
model = ModelVelocity()
hist = train(model, n_train=200, batch_size=1024)

In [None]:
# Training curves: loss (top) and effective sample size (bottom) per iteration.
fig, axes = plt.subplots(2,1)
axes[0].plot(hist['loss'])
axes[0].set_ylabel('loss')
axes[1].plot(hist['ess'])
axes[1].set_ylabel('ess')
axes[1].set_xlabel('train iter')
plt.show()

In [None]:
# Push 16000 uniform prior samples through the trained flow (100 integration
# steps) and histogram the resulting angle pairs.
fig, ax = plt.subplots(1,1)
prior_th = wrap(2*np.pi*torch.rand((16000, 2)))
# no gradients are needed for plain sampling
with torch.no_grad():
    flow_th = apply_flow(prior_th, model, nsteps=100)[0]
plot_samples(flow_th, ax=ax)
ax.set_aspect(1.0)
plt.show()

**EXERCISE:** Adjust the model and training parameters to optimize the final loss and ESS.

In [None]:
def measure_coeffs_grid(S):
    """Extract Wilson-like coefficients of a 2D array `S` from its inverse DFT.

    Returns a dict with the constant mode `c` and, for each of the lowest
    modes along axis 1 (`a1`), axis 0 (`a2`), the diagonal (`b1`) and the
    anti-diagonal (`b2`), the sum of the mode and its negative-frequency
    partner. Values are complex, as returned by the FFT.
    """
    Sk = np.fft.ifft2(S)
    def pair(k0, k1):
        # a mode plus its partner at the negated frequencies
        return Sk[k0, k1] + Sk[-k0, -k1]
    return dict(
        c=Sk[0, 0],
        a1=pair(0, 1),
        a2=pair(1, 0),
        b1=pair(1, 1),
        b2=pair(1, -1),
    )
def measure_coeffs(action):
    """Evaluate `action` on the `make_th_grid(50)` grid and extract its Fourier coefficients.

    The grid is flattened to a single batch axis for the call and the result
    is reshaped back to the 2D grid before the transform.
    """
    th = make_th_grid(50)
    flat_th = th.flatten(0, 1)
    S = action(flat_th).reshape(th.shape[:-1])
    return measure_coeffs_grid(grab(S))
def plot_coeffs(ts, coeffs, x='a1', y='b2', *, ax, cmap, marker='.', label=None):
    """Scatter one point per coefficient dict in the (coeff[x], coeff[y]) plane.

    Args:
        ts: one value per entry of `coeffs`, mapped through the colormap to
            color the points.
        coeffs: sequence of dicts containing at least the keys `x` and `y`.
        x, y: dict keys plotted on the horizontal / vertical axis.
        ax: matplotlib axes to draw on.
        cmap: colormap name (or object) resolved with `plt.get_cmap`.
        marker: marker style passed to `ax.scatter`.
        label: legend label passed to `ax.scatter`.
    """
    pts = np.stack([(coeff[x], coeff[y]) for coeff in coeffs], axis=1)
    # Coefficients coming out of an FFT are complex-valued; plot the real part
    # explicitly rather than relying on an implicit complex -> real cast.
    pts = np.real(pts)
    cmap = plt.get_cmap(cmap)
    ax.scatter(*pts, marker=marker, s=3, color=cmap(ts), label=label)

In [None]:
# Reference point: Fourier coefficients of the target action.
target_coeffs = measure_coeffs(target_action)

In [None]:
# Coefficients of the flow's effective action at 11 flow times tf in [0, 1].
# Every evaluation uses 100 integration steps, so the step size scales with tf.
ts = np.linspace(0.0, 1.0, num=11)
with torch.no_grad():
    flow_coeffs = [
        measure_coeffs(action_from_flow(model, nsteps=100, tf=t))
        for t in ts
    ]

We can finally look at the path through distribution space learned by the continuous normalizing flow:

In [None]:
fig, ax = plt.subplots(1,1, figsize=(3, 3), tight_layout=True)
plot_coeffs(ts, flow_coeffs, ax=ax, cmap='Reds_r', marker='o', label='Normalizing flow')
plot_coeffs([1.0], [target_coeffs], ax=ax, cmap='Greys', marker='x', label='Target')
ax.set_xlabel(r'$\alpha$')
ax.set_ylabel(r'$\beta$')
ax.legend()
plt.show()