# Training normalising flows with weights

See Appendix A of [Williams et al. 2023](https://arxiv.org/abs/2302.08526) for details

Michael J. Williams

In [None]:
import copy

from glasflow.flows import RealNVP
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
from sklearn.model_selection import train_test_split
import torch

If `CLIP` is true, samples with weights smaller than the machine epsilon (`torch.finfo().eps`) of the current default torch floating-point dtype will be removed from the training set. Otherwise, all samples will be used.

In [None]:
# When True, samples whose weight is below torch.finfo().eps are dropped
# from the training set; when False, every sample is kept.
CLIP = False

## Training data

In [None]:
# Dimensionality of the toy problem.
dims = 2

# Distribution the training samples are drawn from: a zero-mean Gaussian
# with covariance 25 * I.
training_dist = stats.multivariate_normal(
    mean=np.zeros(dims),
    cov=np.eye(dims) * 25,
)

# Distribution used as the numerator of the weights: a Gaussian centred on
# (2, ..., 2) with covariance 2 * I.
target_dist = stats.multivariate_normal(
    mean=np.full(dims, 2.0),
    cov=np.eye(dims) * 2,
)

In [None]:
# Number of samples drawn from the training distribution.
n_train = 10_000
samples = training_dist.rvs(size=n_train)

# Weight of each sample: ratio of the target density to the density of the
# distribution the sample was actually drawn from.
weights = np.divide(target_dist.pdf(samples), training_dist.pdf(samples))

In [None]:
# Show the training samples (first two dimensions), coloured by weight.
plt.scatter(x=samples[:, 0], y=samples[:, 1], s=1.0, c=weights)
plt.colorbar(label="Weight")
plt.show()

In [None]:
if CLIP:
    # torch.finfo() with no argument describes torch's default floating-point
    # dtype; `eps` is that dtype's machine epsilon.
    remove = weights < torch.finfo().eps
    samples, weights = samples[~remove], weights[~remove]

    # Re-plot the samples that survived the cut, coloured by weight.
    plt.scatter(x=samples[:, 0], y=samples[:, 1], s=1.0, c=weights)
    plt.colorbar(label="Weight")
    plt.show()

## Training

In [None]:
def get_dataloaders(samples, weights, batch_size=1000):
    """Build training and validation data loaders for weighted samples.

    The samples and weights are split together, in order (no shuffling), with
    scikit-learn's ``train_test_split`` using its default split proportions.
    Each subset is converted to float32 tensors and wrapped in a
    ``DataLoader`` that yields ``(samples, weights)`` batches in order.

    Parameters
    ----------
    samples : numpy.ndarray
        Array of samples.
    weights : numpy.ndarray
        Array of weights, one per sample.
    batch_size : int
        Batch size used by both loaders.

    Returns
    -------
    tuple
        ``(train_loader, val_loader)``.
    """

    def _as_loader(x, w):
        # Cast to float32 before handing the arrays to torch.
        tensors = (
            torch.from_numpy(x.astype(np.float32)),
            torch.from_numpy(w.astype(np.float32)),
        )
        dataset = torch.utils.data.TensorDataset(*tensors)
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=False
        )

    x_train, x_val, w_train, w_val = train_test_split(
        samples, weights, shuffle=False
    )
    return _as_loader(x_train, w_train), _as_loader(x_val, w_val)

In [None]:
def loss_fn(log_prob, weights):
    """Return the negative weighted mean of ``log_prob``.

    Each log-probability is multiplied by its weight, the products are summed
    and the sum is normalised by the total weight.
    """
    weighted_total = (log_prob * weights).sum()
    weight_total = weights.sum()
    return (-weighted_total) / weight_total

In [None]:
def train(flow, samples, weights, epochs=100, device="cpu", **kwargs):
    """Train a flow on weighted samples and return the best copy found.

    The flow is optimised with Adam on the weighted loss from ``loss_fn``.
    After every epoch the loss is evaluated on the validation loader, and a
    deep copy of the flow is kept whenever the validation loss improves.
    Every 20 epochs the current losses are printed and samples drawn from the
    flow are plotted over the training samples.

    Parameters
    ----------
    flow
        Model providing ``to``, ``parameters``, ``train``, ``eval``,
        ``log_prob`` and ``sample``.
    samples : numpy.ndarray
        Training samples; the first two columns are used for the plots.
    weights : numpy.ndarray
        Weight of each sample.
    epochs : int
        Number of passes over the training loader.
    device : str or torch.device
        Device the flow and each batch are moved to.
    **kwargs
        Forwarded to ``get_dataloaders`` (e.g. ``batch_size``).

    Returns
    -------
    best_flow
        Deep copy of the flow at the epoch with the lowest validation loss,
        or ``None`` if no epoch improved on the initial value of infinity
        (for example when ``epochs`` is 0).
    loss : dict
        Per-epoch mean losses as Python floats under ``"train"`` and
        ``"val"``.
    """
    flow = flow.to(device)

    train_loader, val_loader = get_dataloaders(samples, weights, **kwargs)

    loss = dict(
        train=[],
        val=[],
    )

    best_val_loss = np.inf
    best_epoch = np.nan
    best_flow = None

    optimiser = torch.optim.Adam(flow.parameters(), lr=0.001, weight_decay=1e-5)

    for i in range(epochs):
        flow.train()
        train_loss = 0.0
        for x, w in train_loader:
            x = x.to(device)
            w = w.to(device)
            optimiser.zero_grad()
            _loss = loss_fn(flow.log_prob(x), w)
            _loss.backward()
            optimiser.step()
            train_loss += _loss.item()
        loss["train"].append(train_loss / len(train_loader))

        flow.eval()
        val_loss = 0.0
        with torch.no_grad():
            for x, w in val_loader:
                x = x.to(device)
                w = w.to(device)
                # Accumulate a Python float (as for the training loss) so the
                # recorded history does not hold tensors living on `device`.
                val_loss += loss_fn(flow.log_prob(x), w).item()
        val_loss /= len(val_loader)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = i
            # Snapshot the flow; later epochs keep modifying `flow` in place.
            best_flow = copy.deepcopy(flow)

        loss["val"].append(val_loss)
        if not i % 20:
            print(
                f"Epoch {i} - train: {loss['train'][-1]:.3f}, val: {loss['val'][-1]:.3f}"
            )
            # Diagnostic plot: training samples with current flow samples on top.
            with torch.no_grad():
                new_samples = flow.sample(1000).cpu().numpy()
            plt.scatter(samples[:, 0], samples[:, 1], s=1)
            plt.scatter(new_samples[:, 0], new_samples[:, 1], s=1)
            plt.show()
    print(f"Returning best flow from epoch: {best_epoch}")
    return best_flow, loss

In [None]:
# RealNVP flow with four transforms of 32 neurons each. The input size is
# taken from `dims` so the flow always matches the dimensionality of the
# samples defined above.
flow = RealNVP(
    n_inputs=dims,
    n_transforms=4,
    n_neurons=32,
    batch_norm_between_transforms=False,
    linear_transform=None,
)

In [None]:
# Train on the weighted samples. `flow` is rebound to the copy returned by
# train() (the one with the lowest validation loss) and `loss` holds the
# per-epoch training and validation losses.
flow, loss = train(flow, samples, weights)

In [None]:
# Training and validation loss histories; train() records one value per epoch.
plt.plot(loss["train"], label="Train")
plt.plot(loss["val"], label="Val.")
plt.xlabel("Epoch")
plt.ylabel("Loss")
# Without a legend the labels passed to plot() are never displayed.
plt.legend()
plt.show()

## Examining the trained flow

In [None]:
# Draw samples from the trained flow in evaluation mode, with autograd
# disabled, and convert them to a NumPy array on the CPU for plotting.
flow.eval()
with torch.inference_mode():
    new_samples = flow.sample(10_000).cpu().numpy()

In [None]:
# Overlay samples drawn from the trained flow on the (unweighted) training
# samples; only the first two dimensions are shown.
plt.scatter(x=samples[:, 0], y=samples[:, 1], s=1.0, label="Training data")
plt.scatter(
    x=new_samples[:, 0], y=new_samples[:, 1], s=1.0, label="New samples"
)
plt.legend()
plt.show()