# Introduction to Conditional Density Estimation


In [None]:
import os

os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

import math

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from math import pi, log
from torch.utils import data
import numpy as np
from torch import tensor
from torch.distributions import Normal
import matplotlib as mpl

import jupyter_black

jupyter_black.load()

_ = torch.manual_seed(0)

# Density estimation

Density estimation means estimating the probability distribution $p(\theta)$ from samples $\theta_1, \theta_2, \ldots$.

There are many ways to do this:

- kernel density estimation
- non-parametric approaches with order statistics
- many neural network methods (GANs, VAEs, Normalizing flows,...)
- ...

Here, we will discuss (conditional) density estimation with Mixtures of Gaussians.

# Let's move to a simple example

In [None]:
# Ground-truth parameters of the generating Gaussian.
mean, std = 1.0, 0.4

# In a real problem this distribution is unknown...
normal_dist = Normal(loc=mean, scale=std)
# ...and only a finite set of draws from it is available.
samples = normal_dist.sample(torch.Size([50]))

In [None]:
# Show the raw draws as crosses along the horizontal axis (all at height 0).
with mpl.rc_context(fname=".matplotlibrc"):
    fig = plt.figure(figsize=(6, 2))
    marker_style = dict(alpha=0.5, markerfacecolor="none", markersize=6)
    plt.plot(samples, np.zeros_like(samples), "bx", **marker_style)
    _ = plt.xlim([-4, 4])
    _ = plt.ylim([-0.12, 1.2])
    # plt.savefig("figures/samples_gauss.png", dpi=400, bbox_inches="tight")
    # plt.show()

<img src="figures/samples_gauss.png" alt="drawing" width="800"/>

Our goal is to estimate the underlying distribution of these samples.

# Exercise 1: fitting a Gaussian to data with maximum likelihood

### Learning mean and standard deviation

In [None]:
def prob_Gauss(x, mean, std):
    """Gaussian probability density function
    args:
        x: point(s) at which the density is evaluated
        mean: mean of the Gaussian
        std: standard deviation of the Gaussian
    returns:
        probability density of a one dimensional Gaussian, evaluated
        elementwise at x (x, mean and std are broadcast against each other)
    """

    # TODO: Implement the Gaussian probability density function

    # Accept plain Python numbers as well as tensors: torch.sqrt / torch.exp
    # reject non-tensor arguments. Existing tensors (including nn.Parameters)
    # are passed through unchanged, so gradients still flow to them.
    x = torch.as_tensor(x)
    mean = torch.as_tensor(mean)
    std = torch.as_tensor(std)

    prob = (
        1 / torch.sqrt(2 * pi * std**2) * torch.exp(-0.5 / std**2 * (x - mean) ** 2)
    )

    return prob

In [None]:
# Wrap the samples in a dataset and iterate over it in mini-batches of 10.
dataset = data.TensorDataset(samples)
train_loader = data.DataLoader(dataset, batch_size=10)

# Free parameters: the mean and the *log* of the standard deviation
# (the standard deviation itself is recovered with exp() inside the loop).
learned_mean = torch.nn.Parameter(torch.zeros(1))
learned_log_std = torch.nn.Parameter(torch.zeros(1))

opt = optim.Adam([learned_mean, learned_log_std], lr=0.005)

for e in range(500):
    # A TensorDataset with one tensor yields one-element batches: unpack it.
    for (sample_batch,) in train_loader:
        opt.zero_grad()
        learned_std = torch.exp(learned_log_std)
        prob = prob_Gauss(sample_batch, learned_mean, learned_std)
        # Maximum likelihood == minimising the mean negative log-probability.
        loss = -torch.log(prob).mean()
        loss.backward()
        opt.step()

In [None]:
# Report the fitted parameters next to the ground truth.
fitted_mean = learned_mean.item()
fitted_std = torch.exp(learned_log_std).item()
print("Learned mean: ", fitted_mean, ", learned standard deviation: ", fitted_std)

print("True mean: ", mean, ", true standard deviation: ", std)

In [None]:
# Distribution objects for the fitted parameters and for the ground truth.
learned_dist = Normal(loc=learned_mean, scale=torch.exp(learned_log_std))
true_dist = Normal(loc=mean, scale=std)

In [None]:
# Evaluate both densities on a common grid; exp(log_prob) gives the pdf.
x = torch.linspace(-4, 4, 100)
# Use the `true_dist` object built for this comparison in the previous cell.
true_probs = torch.exp(true_dist.log_prob(x))
# The learned parameters require gradients, so detach before plotting.
learned_probs = torch.exp(learned_dist.log_prob(x)).detach()

In [None]:
# Overlay the samples, the true density and the fitted density.
with mpl.rc_context(fname=".matplotlibrc"):
    fig = plt.figure(figsize=(6, 2))
    plt.plot(
        samples,
        np.zeros_like(samples),
        "bx",
        alpha=0.5,
        markerfacecolor="none",
        markersize=6,
    )
    plt.plot(x, true_probs)
    plt.plot(x, learned_probs)
    plt.legend(["Samples", "Ground truth", "Learned"], loc="upper left")
    _ = plt.xlim([-4, 4])
    _ = plt.ylim([-0.12, 1.2])
    # plt.savefig("figures/fitted_samples.png", dpi=200, bbox_inches="tight")
    # Show the figure while the rc settings are still active, as in the
    # other plotting cells.
    plt.show()

<img src="figures/fitted_samples.png" alt="drawing" width="800"/>

### Questions to think about

1) Why do we parameterize the logarithm of the standard deviation instead of the standard deviation itself?  
2) We are evaluating the probability in batches of samples. Does this also work for single samples (i.e. non-batched)?

# Conditional density estimation


### Let's look at a simple model
Generate data from $x = \theta + 0.3 \sin(2 \pi \theta) + \epsilon$

In [None]:
# Simulate the training set: parameters theta and their noisy observations x.
n, d = 4000, 1  # number of datapoints, dimensionality of parameters theta
noise_std = 0.05  # standard deviation of the additive Gaussian noise

theta = torch.rand(n, d)


def generate_data(theta, noise_std):
    """Simulate x = theta + 0.3 * sin(2 * pi * theta) + noise.

    args:
        theta: tensor of parameters
        noise_std: standard deviation of the zero-mean Gaussian noise
    returns:
        tensor of observations with the same shape as theta
    """
    gaussian_noise = noise_std * torch.randn(theta.shape)
    return theta + 0.3 * torch.sin(2 * torch.pi * theta) + gaussian_noise


x = generate_data(theta, noise_std)

In [None]:
# set the conditioning value we are interested in: the plots below mark
# theta = val_to_eval and show p(x | theta = val_to_eval)
val_to_eval = torch.tensor(0.1)

In [None]:
# Scatter the first 400 (theta, x) pairs and mark theta = val_to_eval.
with mpl.rc_context(fname=".matplotlibrc"):
    fig = plt.figure(figsize=(4.5, 2.2))
    shown = slice(None, 400)
    plt.plot(theta[shown], x[shown], "go", alpha=0.4, markerfacecolor="none")
    plt.axvline(val_to_eval.numpy())
    plt.xlabel(r"$\theta$")
    plt.ylabel("x")
    # plt.savefig("figures/cde_samples.png", dpi=200, bbox_inches="tight")
    plt.show()

In [None]:
# Density of x at theta = val_to_eval under the data-generating model.

# Noise-free simulator output at the conditioning value. Stored under its own
# name so the globals `mean` (true mean of the first example) and `n` (number
# of datapoints) are not overwritten by this plotting cell.
true_cond_mean = val_to_eval + 0.3 * torch.sin(2 * pi * val_to_eval)

with mpl.rc_context(fname=".matplotlibrc"):
    fig, ax = plt.subplots(1, 1, figsize=(4.5, 2.2))
    x_vals = torch.linspace(-0.2, 1.0, 1000)
    # Use the simulator's noise level instead of repeating the literal 0.05.
    true_cond_dist = Normal(true_cond_mean, noise_std)
    y_vals = true_cond_dist.log_prob(x_vals).exp().numpy()
    x_vals = x_vals.numpy()
    # NOTE(review): the density is drawn mirrored (np.flip); presumably so it
    # lines up in the assembled figure - confirm.
    plt.plot(x_vals, np.flip(y_vals), linewidth=2)
    plt.xticks([])
    plt.yticks([])
    ax.spines["left"].set_visible(False)
    ax.spines["bottom"].set_visible(False)
    ax.set_ylabel(r"$p(x|\theta=0.1)$", rotation=90, fontsize=18)
    # plt.savefig("figures/conditional.png", dpi=300, bbox_inches="tight")
    plt.show()

The goal is to estimate parameters of a conditional distribution:

<img src="figures/cond_assembled.png" alt="drawing" width="1100"/>  

<img src="figures/nn.png" alt="drawing" width="900"/>  

# Exercise 2

### Learning mean and standard deviation conditional on inputs

In [None]:
# Todo: implement the neural net as a fully connected network with ReLU activations
# net = nn.Sequential(...)

# Hint: what is the input dimensionality of the network? And the output dimensionality?
# The output of the net should be the mean and the log_std of the Gaussian distribution.
# One hidden layer with 20 neurons is sufficient.

# Solution: 1 input -> 20 -> 20 -> 2 outputs (column 0: mean, column 1: log_std)
hidden_width = 20
net = nn.Sequential(
    nn.Linear(1, hidden_width),
    nn.ReLU(),
    nn.Linear(hidden_width, hidden_width),
    nn.ReLU(),
    nn.Linear(hidden_width, 2),
)

In [None]:
# Fit the network by minimising the negative log-likelihood of x given theta.
dataset = data.TensorDataset(theta, x)
train_loader = data.DataLoader(dataset, batch_size=20)
opt = optim.Adam(net.parameters(), lr=0.01)

for e in range(50):
    for theta_batch, x_batch in train_loader:
        opt.zero_grad()
        nn_output = net(theta_batch)
        # Column 0 holds the mean, column 1 the log standard deviation;
        # slicing keeps the trailing dimension so both are column vectors.
        mean = nn_output[:, 0:1]
        std = nn_output[:, 1:2].exp()
        prob = prob_Gauss(x_batch, mean, std)
        loss = torch.log(prob).sum().neg()
        loss.backward()
        opt.step()

We now have a neural network which outputs the parameters of a probability distribution given inputs.
In this case, it is a Gaussian distribution whose mean and standard deviation we learned. To evaluate the result, we can look at samples from the returned distribution.

### Inspect the distribution for $\theta = 0.1$

In [None]:
# Query the network at a single theta and compare with the simulator.
theta_test = tensor([0.1])
nn_output = net(theta_test)

# The two network outputs are the mean and the log standard deviation.
pred_mean, pred_log_std = nn_output
conditional_mean = pred_mean.detach().numpy()
conditional_std = pred_log_std.exp().detach().numpy()
print("Learned: Conditional mean: ", conditional_mean, ", conditional std: ", conditional_std)

# Running the simulator without noise gives the true conditional mean.
true_conditional_mean = generate_data(theta_test, noise_std=0.0).item()
print("True: Conditional mean: ", true_conditional_mean, ", conditional std: ", noise_std)

In [None]:
# Training data with the learned mean +/- one standard deviation at theta_test.
with mpl.rc_context(fname=".matplotlibrc"):
    fig = plt.figure(figsize=(4.5, 2.2))
    plt.plot(theta[:400], x[:400], "go", alpha=0.4, markerfacecolor="none", zorder=-100)
    lower = conditional_mean - conditional_std
    upper = conditional_mean + conditional_std
    plt.plot([theta_test, theta_test], [lower, upper], c="r", linewidth=2)
    plt.scatter(theta_test, conditional_mean, c="r", s=30, alpha=1.0)
    plt.xlabel(r"$\theta$")
    plt.ylabel("x")
    # plt.savefig("figures/cde_fitted.png", dpi=200, bbox_inches="tight")
    plt.show()

<img src="figures/cde_fitted.png" alt="drawing" width="600"/>

### Alternative evaluation: for every $\theta$, sample from the Gaussian

In [None]:
# Draw one x from the learned Gaussian for each theta on a regular grid.
theta_test = torch.linspace(0.0, 1.0, 500).unsqueeze(1)

drawn = []
for single_theta in theta_test:
    network_outs = net(single_theta.unsqueeze(1))
    # Column 0: mean, column 1: log standard deviation.
    conditional_distribution = Normal(
        network_outs[:, 0], torch.exp(network_outs[:, 1])
    )
    drawn.append(conditional_distribution.sample((1,)))

samples = torch.cat(drawn).data

In [None]:
# Training data (green) against samples from the learned conditional (red).
with mpl.rc_context(fname=".matplotlibrc"):
    fig = plt.figure(figsize=(4.5, 2.2))
    plt.plot(theta[:400], x[:400], "go", alpha=0.5, markerfacecolor="none")
    model_draws = samples.squeeze().detach()
    plt.plot(theta_test, model_draws, "ro", linewidth=3.0)
    plt.xlabel(r"$\theta$")
    plt.ylabel("x")
    # plt.savefig("figures/cde_fitted_samples.png", dpi=200, bbox_inches="tight")
    plt.show()

<img src="figures/cde_fitted_samples.png" alt="drawing" width="500"/>

# But what if the conditional distribution is not Gaussian?

Again, generate data from $x = \theta + 0.3 \sin(2 \pi \theta) + \epsilon$. 

But now, predict $\theta$ from $x$, i.e. learn $p(\theta | x)$.

In [None]:
# Same data with the axes swapped: theta is now shown as a function of x.
with mpl.rc_context(fname=".matplotlibrc"):
    fig = plt.figure(figsize=(4.5, 2.2))
    shown = slice(None, 400)
    plt.plot(x[shown], theta[shown], "go", alpha=0.5, markerfacecolor="none")
    plt.xlabel("x")
    plt.ylabel(r"$\theta$")
    # plt.savefig("figures/cde_samples_inv.png", dpi=200, bbox_inches="tight")
    plt.show()

<img src="figures/cde_samples_inv.png" alt="drawing" width="600"/>

In [None]:
# Reset PyTorch's random seed before the next network is created and trained.
_ = torch.manual_seed(1)

### Let's train this just like before, but now predicting $\theta$ from $x$.

In [None]:
# Same Gaussian model as before, with the roles swapped: x is the network
# input and theta is the quantity whose density is modelled.
dataset = data.TensorDataset(x, theta)
train_loader = data.DataLoader(dataset, batch_size=20)

hidden_width = 20
net = nn.Sequential(
    nn.Linear(1, hidden_width),
    nn.ReLU(),
    nn.Linear(hidden_width, hidden_width),
    nn.ReLU(),
    nn.Linear(hidden_width, 2),
)

opt = optim.Adam(net.parameters(), lr=0.01)
for e in range(100):
    for x_batch, theta_batch in train_loader:
        opt.zero_grad()
        nn_output = net(x_batch)
        # Column 0 holds the mean, column 1 the log standard deviation.
        mean = nn_output[:, 0:1]
        std = nn_output[:, 1:2].exp()
        prob = prob_Gauss(theta_batch, mean, std)
        loss = torch.log(prob).sum().neg()
        loss.backward()
        opt.step()

In [None]:
# Generate samples from the learned conditional distribution

x_test = torch.linspace(-0.1, 1.1, 500).unsqueeze(1)

# One forward pass for the whole grid: column 0 is the mean, column 1 the
# log standard deviation.
network_outs = net(x_test)
mean = network_outs[:, 0]
std = torch.exp(network_outs[:, 1])
conditional_distribution = Normal(mean, std)
# Store the draws as `samples` (shape (1, 500)), the name the plotting cell
# reads. Assigning to `sample` left `samples` holding the stale draws of the
# earlier theta -> x model, which were then plotted against x_test.
samples = conditional_distribution.sample((1,))

In [None]:
# Training data (green) against the samples stored in `samples` (red).
with mpl.rc_context(fname=".matplotlibrc"):
    fig = plt.figure(figsize=(4.5, 2.2))
    plt.plot(x[:400], theta[:400], "go", alpha=0.5, markerfacecolor="none")
    model_draws = samples.squeeze().detach()
    plt.plot(x_test, model_draws, "ro", linewidth=3.0)
    plt.xlabel("x")
    plt.ylabel(r"$\theta$")
    # plt.savefig("figures/cde_fitted_inv_samples.png", dpi=200, bbox_inches="tight")
    plt.show()

<img src="figures/cde_fitted_inv_samples.png" alt="drawing" width="900"/>

Can you explain this result?


# Mixture Density Networks
$$
$$
<img src="figures/architecture.png" alt="drawing" width="1200"/>

The loss is now the negative log-probability of the data under this mixture of Gaussians  <br/>

$\mathcal{L} = -\sum_i \log q(\theta_i | x_i) = -\sum_i \log \sum_j \alpha_{j,i} \, \mathcal{N}(\theta_i; \mu_{j,i}(x_i), \sigma^2_{j,i}(x_i))$  <br/>

We learn the mixture coefficients $\alpha_j$, the means $\mu_j$, and the variances $\sigma^2_j$.

At the moment we will restrict ourselves to diagonal covariance matrices, meaning we cannot learn any correlation structure. 

In [None]:
class MultivariateGaussianMDN(nn.Module):
    """
    Mixture density network whose output is a mixture of Gaussians with
    diagonal covariance matrices.

    For a documented version of this code, see:
    https://github.com/mackelab/pyknos/blob/main/pyknos/mdn/mdn.py
    """

    def __init__(self, features, hidden_net, num_components, hidden_features):
        """
        args:
            features: dimensionality of the modelled variable
            hidden_net: module mapping the context to hidden features
            num_components: number of Gaussians in the mixture
            hidden_features: size of the output of hidden_net
        """
        super().__init__()

        self._features = features
        self._num_components = num_components
        self._hidden_net = hidden_net

        # One linear head per group of mixture parameters.
        per_component_outputs = num_components * features
        self._logits_layer = nn.Linear(hidden_features, num_components)
        self._means_layer = nn.Linear(hidden_features, per_component_outputs)
        self._unconstrained_diagonal_layer = nn.Linear(
            hidden_features, per_component_outputs
        )

    def get_mixture_components(self, context):
        """Return (logits, means, variances) of the mixture for each context.

        logits are normalised log mixture coefficients; means and variances
        are reshaped to (-1, num_components, features).
        """
        component_shape = (-1, self._num_components, self._features)
        embedding = self._hidden_net(context)

        # mixture coefficients in log space, normalised over the components
        raw_logits = self._logits_layer(embedding)
        logits = raw_logits - torch.logsumexp(raw_logits, dim=1, keepdim=True)

        means = self._means_layer(embedding).view(*component_shape)

        # The head predicts log variances of a diagonal covariance matrix;
        # exp() makes them positive.
        # otherwise: Cholesky decomposition s.t. Cov = AA^T, A is lower triangular.
        log_variances = self._unconstrained_diagonal_layer(embedding)
        variances = log_variances.view(*component_shape).exp()

        return logits, means, variances

In [None]:
def mog_log_prob(theta, logits, means, variances):
    """Log probability of a mixture of Gaussians with diagonal covariances.

    args:
        theta: parameters; reshaped internally to (-1, 1, theta_dim)
        logits: log mixture coefficients
        means: means of the Gaussians, a 3-dimensional tensor whose last
            dimension is theta_dim
        variances: variances of the Gaussians (diagonal of the covariance)
    returns:
        log probability of the mixture of Gaussians
    """
    _, _, theta_dim = means.size()
    theta = theta.view(-1, 1, theta_dim)

    # -0.5 * log det(Cov). The determinant of a diagonal covariance is the
    # product of its variances; summing the logs instead of taking the log of
    # the product avoids under/overflow (and hence +-inf) when the variances
    # are small/large or theta_dim is high.
    log_cov_det = -0.5 * torch.sum(torch.log(variances), dim=2)

    a = logits
    b = -(theta_dim / 2.0) * log(2 * pi)
    c = log_cov_det
    d1 = theta.expand_as(means) - means
    precisions = 1.0 / variances
    # Quadratic form of the Gaussian exponent for each component.
    exponent = -0.5 * torch.sum(d1 * precisions * d1, dim=2)

    # log sum_j exp(log alpha_j + log N_j), computed stably over components.
    return torch.logsumexp(a + b + c + exponent, dim=-1)


def mog_sample(logits, means, variances):
    """Sample from a (batch of) mixture(s) of Gaussians.

    args:
        logits: log mixture coefficients, one row per batch entry
        means: means of the Gaussians
        variances: variances of the Gaussians
    returns:
        samples from the mixture of Gaussians, one row per row of logits
    """

    coefficients = F.softmax(logits, dim=-1)
    # One component index per row of logits.
    choices = torch.multinomial(coefficients, num_samples=1, replacement=True).view(-1)

    num_rows = choices.shape[0]
    if means.shape[0] == num_rows:
        # Pair every batch row with the parameters of its own mixture.
        # (Always reading batch position 0 was only right for batch size 1.)
        row_index = torch.arange(num_rows, device=choices.device)
    else:
        # Batch sizes do not line up row by row: keep the historical
        # behaviour of reading the first batch position for every draw.
        row_index = torch.zeros_like(choices)

    chosen_means = means[row_index, choices, :]
    chosen_variances = variances[row_index, choices, :]

    # Independent standard-normal noise for every row and output dimension.
    standard_normal_samples = torch.randn_like(chosen_means)
    zero_mean_samples = standard_normal_samples * torch.sqrt(chosen_variances)
    samples = chosen_means + zero_mean_samples

    return samples

## Training the network

In [None]:
# Reset PyTorch's random seed before the mixture density network is created and trained.
_ = torch.manual_seed(1)

In [None]:
# Mini-batches of (x, theta) pairs: x is the network input, theta the target.
dataset = data.TensorDataset(x, theta)
train_loader = data.DataLoader(dataset, batch_size=50)

# TODO: implement the neural net as fully connected network and ReLU activation function
# use 30 hidden dimensions
# hidden_net = nn.Sequential(...)

###########
hidden_dim = 30
hidden_net = nn.Sequential(
    nn.Linear(1, hidden_dim),
    nn.ReLU(),
    nn.Linear(hidden_dim, hidden_dim),
    nn.ReLU(),
    nn.Linear(hidden_dim, hidden_dim),
    nn.ReLU(),
)
###########

mdn = MultivariateGaussianMDN(
    features=1,
    hidden_net=hidden_net,
    num_components=5,
    hidden_features=hidden_dim,
)
opt = optim.Adam(mdn.parameters(), lr=0.001)

# Minimise the summed negative log-probability of theta under the predicted mixture.
for e in range(200):
    for x_batch, theta_batch in train_loader:
        opt.zero_grad()
        logits, means, variances = mdn.get_mixture_components(x_batch)
        log_probs = mog_log_prob(theta_batch, logits, means, variances)
        loss = log_probs.sum().neg()
        loss.backward()
        opt.step()

In [None]:
# Draw one theta from the learned mixture for each x on a regular grid.
x_test = torch.linspace(-0.1, 1.1, 500).unsqueeze(1)

drawn = []
for single_x in x_test:
    ### TODO: sample from the learned distribution
    # logits, means, variances = ...
    # sample = ...

    ###########
    logits, means, variances = mdn.get_mixture_components(single_x.unsqueeze(1))
    sample = mog_sample(logits, means, variances)
    ###########

    drawn.append(sample)

samples = torch.cat(drawn).data

In [None]:
# Evaluate the learned density at a single x_demo to get the one-dimensional
# mixture-of-Gaussians density over theta.
x_demo = torch.as_tensor([[0.6]])
# parameters of the mixture predicted for x_demo
logits, means, variances = mdn.get_mixture_components(x_demo)

# log-probability of every theta value on the grid, given x_demo
theta_grid = torch.linspace(-0.1, 1.0, 100)
demo_log_probs = [
    mog_log_prob(torch.as_tensor([t_demo]), logits, means, variances)
    for t_demo in theta_grid
]
demo_probs = torch.stack(demo_log_probs).detach().exp()

In [None]:
# Training data, MDN samples, and the density at x = 0.6 drawn sideways
# (scaled by 0.1 and shifted so that it starts at the dashed line).
with mpl.rc_context(fname=".matplotlibrc"):
    fig = plt.figure(figsize=(4.5, 2.2))
    x_anchor = 0.6
    theta_axis = np.linspace(-0.1, 1, 100)
    plt.plot(x[:400], theta[:400], "go", alpha=0.5, markerfacecolor="none")
    plt.plot(x_test, samples.squeeze().detach(), "ro", linewidth=3.0)
    plt.plot(demo_probs.numpy() * 0.1 + x_anchor, theta_axis)
    plt.plot([x_anchor] * 100, theta_axis, linestyle="--", color="k")
    plt.ylim([-0.05, 1.1])
    plt.xlabel("x")
    plt.ylabel(r"$\theta$")
    # plt.savefig("figures/cde_fitted_inv_samples_mdn.png", dpi=200, bbox_inches="tight")
    plt.show()

<img src="figures/cde_fitted_inv_samples_mdn.png" alt="drawing" width="1000"/>

# What are the limitations?
<br/>
- One has to decide on hyperparameters: e.g. how many components (i.e. Gaussians) will I need? <br/>
<br/>
- Training often does not converge perfectly. <br/>
<br/>
- In practice, high-dimensional distributions might not be captured perfectly even with a high number of components. <br/>
<br/>
- Normalizing flows, VAEs, and GANs are more flexible. However, Mixtures of Gaussians often allow for useful computations in closed form and are still sometimes used in practice.<br/>
<br/>

## Summary: mixture density networks

<br/>
- Mixture density networks predict parameters of a mixture of Gaussians <br/>
<br/>
- Because of this, they can capture variability and multi-modality <br/>
<br/>
- To look at their predictions, one can either draw samples from the predicted mixture of Gaussians or evaluate the probability of the prediction

# Acknowledgments
<br/>

An initial version of this notebook was created by Michael Deistler.<br/>
Parts of the code from: https://mikedusenberry.com/mixture-density-networks  <br/>
Code of MDNs based on Conor Durkan's `lfi` package.  <br/>
Bishop, 1994  <br/>
[MDN graphic](https://towardsdatascience.com/a-hitchhikers-guide-to-mixture-density-networks-76b435826cca)  
Pedro Gonçalves et al. for figure.