Example of training a `ConditionalDiagonalNormal` distribution from nflows (imported via `glasflow.nflows`).

Michael J. Williams 2023

In [None]:
from glasflow.nflows.distributions import ConditionalDiagonalNormal
from glasflow.nets import MLP
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import optim

In [None]:
# Two 2-D unit-variance Gaussian blobs of ``n`` points each, centred at
# (-5, -5) and (+5, +5). Points in the first blob get the conditional label 1,
# points in the second blob get label 0.
n = 10000
data = torch.cat((torch.randn(n, 2) - 5, torch.randn(n, 2) + 5), dim=0)
conditional = torch.cat((torch.ones(n, 1), torch.zeros(n, 1)), dim=0)

In [None]:
# Shuffle the rows so the two labelled clusters are interleaved. The same
# random permutation indexes both tensors, which keeps every sample paired
# with its label.
idx = np.random.permutation(len(conditional))
data = data[idx]
conditional = conditional[idx]

In [None]:
# Scatter the training data, coloured by its conditional label.
plt.scatter(*data.T, c=conditional)
plt.show()

In [None]:
# Context encoder: maps the 1-D context to 2 * 2 outputs. ConditionalDiagonalNormal
# splits this output in half into means and log-standard-deviations for the
# 2 data dimensions. Two hidden layers of 64 units.
encoder = MLP(1, 2 * 2, [64] * 2)

In [None]:
# Diagonal Gaussian over 2-D samples whose parameters are computed from the
# context by ``encoder``.
dist = ConditionalDiagonalNormal(
    (2,),
    context_encoder=encoder,
)

In [None]:
def train(dist, n_iter, data):
    """Fit ``dist`` to ``data`` by full-batch maximum likelihood.

    Runs ``n_iter`` Adam steps (default PyTorch Adam hyper-parameters), each
    minimising the mean negative log-likelihood of the samples given their
    context. The whole dataset is used at every step.

    Args:
        dist: A ``torch.nn.Module`` distribution exposing ``parameters()`` and
            ``log_prob(x, context=...)``.
        n_iter: Number of optimisation steps.
        data: Pair ``(x, context)`` of tensors; they are expected to share
            their first (batch) dimension.

    Returns:
        list of float: The loss recorded at every step, before that step's
        parameter update.
    """
    # Unpack once: the same full batch is reused for every step.
    x, context = data
    # Ensure training mode even if ``dist.eval()`` was called earlier, e.g.
    # when this cell is re-run after the sampling cell.
    dist.train()
    optimizer = optim.Adam(dist.parameters())
    train_loss = []
    for _ in range(n_iter):
        optimizer.zero_grad()
        loss = -dist.log_prob(x, context=context).mean()
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())
    return train_loss

In [None]:
# 500 full-batch optimisation steps; ``loss`` holds the per-step loss history.
loss = train(dist, n_iter=500, data=(data, conditional))

In [None]:
# Training curve: one point per optimisation step.
plt.plot(range(len(loss)), loss)
plt.gca().set(xlabel="Iteration", ylabel="Loss")
plt.show()

**Note:** the output of `sample` has shape `(len(context), n_samples, n_dims)`, since it draws `n_samples` samples for each entry in the context tensor.

In [None]:
n_sample = 1000
# Switch to evaluation mode for sampling (``train`` switches back if re-run).
dist.eval()
with torch.no_grad():
    # ``sample`` returns shape (len(context), n_sample, n_dims). Index the single
    # context row instead of ``np.squeeze``, which would also collapse the
    # sample axis when n_sample == 1 and break the ``[:, 0]`` indexing below.
    samples_0 = dist.sample(n_sample, context=torch.zeros(1, 1))[0].numpy()
    samples_1 = dist.sample(n_sample, context=torch.ones(1, 1))[0].numpy()

In [None]:
# Samples drawn for each context value. The legend shows which cloud came from
# which context, so it can be compared with the labelled training data.
plt.scatter(samples_0[:, 0], samples_0[:, 1], label="context = 0")
plt.scatter(samples_1[:, 0], samples_1[:, 1], label="context = 1")
plt.legend()
plt.show()