# Neural Posterior Estimation for simulation-based inference

<br/>


In [None]:
%load_ext autoreload
%autoreload 2

import matplotlib as mpl
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data
import numpy as np
from sbi.analysis import pairplot
from sbi.utils import BoxUniform
from torch.distributions import Normal

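import os
import pickle

# work around the "OMP: Error #15" crash caused by duplicate OpenMP runtimes
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

# the ball-throw simulator (implemented in `simulators/ball_throw.py`; that folder must be importable)
from ball_throw import throw

# auto-format notebook cells with black
import jupyter_black

jupyter_black.load()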


_ = torch.manual_seed(0)

# The main idea

In Neural Posterior Estimation (NPE), we want to use conditional density estimation to learn the posterior $p(\theta | x)$.

As a first step, we have to generate a dataset that follows the joint density $p(\theta, x)$.

We can obtain this by sampling from $p(\theta)$ (the prior) and then sampling from the likelihood $p(x | \theta)$ (i.e. simulating).

The resulting $(\theta, x)$ pairs follow $p(\theta, x) = p(\theta)p(x|\theta)$.

### Neural Posterior Estimation: recipe

- sample parameters $\theta$ from prior $p(\theta)$
- run each of these parameters through the (stochastic) simulator to obtain $x \sim p(x | \theta)$
- train a conditional density estimator on these data to learn $p(\theta | x)$:

<img src="figures/npe_illustration.png" alt="drawing" width="1000"/>
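The recipe can be sketched end to end on a toy problem. This is a minimal, hedged example: it assumes a hypothetical one-dimensional simulator $x = \theta + \epsilon$ with $\epsilon \sim \mathcal{N}(0, 0.1^2)$ instead of the ball throw, and a single conditional Gaussian instead of the mixture density networks used later in this notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# 1) sample parameters from a uniform prior p(theta) = U(-2, 2)
num_simulations = 5_000
theta = torch.rand(num_simulations, 1) * 4.0 - 2.0


# 2) hypothetical stochastic simulator: x = theta + Gaussian noise with std 0.1
def toy_simulator(theta):
    return theta + 0.1 * torch.randn_like(theta)


x = toy_simulator(theta)


# 3) conditional density estimator q(theta | x): a Gaussian whose mean and
#    log standard deviation are predicted from x by a small network
class ConditionalGaussian(nn.Module):
    def __init__(self, hidden=32):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(1, hidden), nn.ReLU(), nn.Linear(hidden, 2))

    def distribution(self, x):
        mean, log_std = self.net(x).chunk(2, dim=-1)
        return torch.distributions.Normal(mean, log_std.exp())

    def log_prob(self, theta, x):
        return self.distribution(x).log_prob(theta).sum(-1)


q = ConditionalGaussian()
opt = torch.optim.Adam(q.parameters(), lr=1e-2)
for step in range(2_000):
    opt.zero_grad()
    loss = -q.log_prob(theta, x).mean()  # maximize log q(theta | x) over the (theta, x) pairs
    loss.backward()
    opt.step()

# 4) plug in an observation to obtain the approximate posterior q(theta | x_obs)
x_obs = torch.tensor([[0.5]])
with torch.no_grad():
    posterior = q.distribution(x_obs)
print(posterior.mean.item(), posterior.stddev.item())
```

For this toy model the true posterior at $x = 0.5$ is close to $\mathcal{N}(0.5, 0.1^2)$, so the printed mean and standard deviation should land near those values.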

### The simulator
In the following example, we will use the physical example of a ball throw. 

A nice animation and explanation can be found here: http://www.physics.smu.edu/fattarus/ballistic.html

We have three free parameters $\theta = (\theta_1,\theta_2,\theta_3)$ for this simulator:
- $\theta_1$: speed, the magnitude of the initial velocity (m/s)
- $\theta_2$: angle, the launch angle with the horizontal (degrees)
- $\theta_3$: drag, the drag coefficient

We assume that we can only observe a noisy version of the trajectory, because the height can only be measured imprecisely.
We simulate this by adding independent Gaussian noise to the height.

The implementation can be found in `simulators/ball_throw.py`.
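To make the idea of a stochastic simulator concrete, here is a hypothetical stand-in (`toy_throw` is not the notebook's `throw`; we have not reproduced `ball_throw.py`). It assumes a linear drag term, explicit Euler time steps, and i.i.d. Gaussian noise on the height only.

```python
import numpy as np


def toy_throw(speed, angle_deg, drag, noise_std=0.5, num_points=151, dt=0.05, g=9.81, rng=None):
    """Hypothetical ball-throw simulator (NOT the notebook's `throw`).

    Assumes linear drag (deceleration = drag * velocity) and explicit Euler steps.
    The trajectory is not stopped at the ground. Returns an array of shape
    (2, num_points) holding (distance, noisy height).
    """
    rng = np.random.default_rng() if rng is None else rng
    angle = np.deg2rad(angle_deg)
    vx, vy = speed * np.cos(angle), speed * np.sin(angle)
    pos_x, pos_y = 0.0, 0.0
    distance, height = np.empty(num_points), np.empty(num_points)
    for i in range(num_points):
        distance[i], height[i] = pos_x, pos_y
        # move with the current velocity ...
        pos_x, pos_y = pos_x + vx * dt, pos_y + vy * dt
        # ... then update the velocity: gravity pulls down, drag opposes the motion
        vx, vy = vx - drag * vx * dt, vy - (g + drag * vy) * dt
    # measurement noise on the height only
    noisy_height = height + noise_std * rng.normal(size=num_points)
    return np.stack([distance, noisy_height])
```

Calling `toy_throw(40, 30, 0.2)` twice returns the same distances but different noisy heights, which is the kind of randomness that makes repeated simulations with identical parameters differ.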


In [None]:
# Let's run the simulator
velocity = 40
angle = 30
drag = 0.2
sim1 = throw(velocity, angle, drag)

In [None]:
# and look at the simulation
plt.plot(sim1[0], sim1[1])
plt.ylim(
    0,
)
plt.title("Ball throw")
plt.xlabel("distance [m]")
plt.ylabel("height [m]")

### Summary statistics
In principle, we can run NPE on the raw trajectory (more on that later).
However, it is often preferable to define summary statistics that capture the features of interest and to infer the parameters from those alone.
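The summary statistics used here (`get_landing_distance`, `get_distance_at_highest_point`, `get_highest_point`) live in `ball_throw.py`. As an assumption about what such statistics compute, here is a self-contained sketch on a (distance, height) array; `toy_summary_statistics` is a hypothetical helper, not the module's code.

```python
import numpy as np


def toy_summary_statistics(trajectory):
    """Hypothetical summary statistics of a (distance, height) trajectory of shape (2, T).

    Returns [landing distance, distance at the highest point, highest point].
    The landing distance is NaN if the ball does not come back down within the trace.
    """
    distance, height = trajectory
    peak = np.argmax(height)
    # first time step at or after the peak at which the height is at or below zero
    landed = np.nonzero(height[peak:] <= 0)[0]
    landing_distance = distance[peak + landed[0]] if landed.size > 0 else np.nan
    return np.array([landing_distance, distance[peak], height[peak]])
```

Returning `NaN` when the ball never lands is one way summary statistics can produce the `NaN` values that are filtered out in the pre-processing step; we do not know whether this is how `ball_throw.py` behaves.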

In [None]:
from ball_throw import (
    get_landing_distance,
    get_distance_at_highest_point,
    get_highest_point,
)

In [None]:
def calculate_summary_statistics(x):
    """Calculate summary statistics (landing distance, distance at highest point, highest point) of a simulated trajectory x"""

    return np.array(
        [
            get_landing_distance(x),
            get_distance_at_highest_point(x),
            get_highest_point(x),
        ]
    )

In [None]:
def sbi_throw_with_sumstats(theta, return_raw_sims=False):
    """Wrapper for the throw function to work with sbi on batches of parameters.

    Arguments:
        theta: parameters (batch, 3): speed, angle, drag
        return_raw_sims: if True, also return the raw simulated heights
    Returns:
        tensor: summary stats (batch, 3)
        (optional) array: raw heights (batch, num_time_points), with NaNs replaced by 0
    """

    sumstats = torch.zeros(theta.shape[0], 3)
    # run one simulation to determine the length of a trace
    sim1 = throw(*theta[0])
    sims = np.zeros((theta.shape[0], sim1.shape[-1]))
    for i, theta1 in enumerate(theta):
        sim1 = throw(*theta1)
        sumstats[i] = torch.from_numpy(calculate_summary_statistics(sim1))
        sims[i] = sim1[1]  # keep only the heights

    sims[np.isnan(sims)] = 0
    if return_raw_sims:
        return sumstats, sims
    else:
        return sumstats

In [None]:
# Let's check for two different parameter sets
theta = torch.tensor([[21, 40, 0.1], [31, 72, 0.01]])
sbi_throw_with_sumstats(theta)

*Question:* Why are the summary statistics not always the same for one parameter $\theta$?

### The prior
We then have to define a prior (a "first guess of plausible values"). Here, we pick a uniform distribution within some bounds of reasonable values.
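A "box" uniform prior is a product of independent uniform distributions, one per parameter, treated as a single multi-dimensional distribution. We believe sbi's `BoxUniform` is built on this construction; the sketch below only uses plain `torch.distributions` and the bounds this notebook picks, so it is an illustration rather than sbi's implementation.

```python
import torch
from torch.distributions import Independent, Uniform

# bounds for (speed [m/s], angle [deg], drag)
low = torch.tensor([10.0, 20.0, 0.1])
high = torch.tensor([50.0, 80.0, 1.0])

# independent uniforms, reinterpreted as one 3-dimensional event
box_prior = Independent(Uniform(low, high), 1)

prior_draws = box_prior.sample((1_000,))
print(prior_draws.shape)  # torch.Size([1000, 3])

# inside the box the density is constant: 1 / (volume of the box)
print(box_prior.log_prob(torch.tensor([30.0, 45.0, 0.5])), -torch.log(high - low).sum())
```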

In [None]:
from sbi.utils import BoxUniform

*Question:* What are meaningful boundaries for a `BoxUniform` distribution over the velocity, angle and drag?

In [None]:
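# we want to define a BoxUniform prior and specify its boundaries here
prior_speed = (10, 50)  # m/s
prior_angle = (20, 80)  # degree
prior_drag = (0.1, 1)  # drag coefficient
# define the prior U(low, high), a box in 3 dimensions:
# zip(...) collects the lower bounds (10, 20, 0.1) and the upper bounds (50, 80, 1)
low, high = zip(prior_speed, prior_angle, prior_drag)
prior = BoxUniform(
    low=torch.tensor(low, dtype=torch.float32),
    high=torch.tensor(high, dtype=torch.float32),
)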

In [None]:
# Let's test by drawing some samples from the prior
theta = prior.sample((10,))
print(theta)

# Generate simulated data
We will run N simulations that will be used to train the conditional density estimator.

In [None]:
N = 100  # number of simulations

thetas = prior.sample((N,))

xs, sims = sbi_throw_with_sumstats(thetas, return_raw_sims=True)

In [None]:
# save the data
# data_dict = dict(thetas=thetas, xs=xs, sims=sims)
# with open("throw_dataset.pickle", "wb") as f:
#    pickle.dump(data_dict, f)

To save simulation time, we have already simulated 10,000 traces with parameters drawn from a prior with these boundaries:

- `prior_speed = (10, 50)  # m/s`
- `prior_angle = (20, 80)  # degree`
- `prior_drag = (0.1, 1)  # drag`

In [None]:
num_simulations = 10_000

with open("throw_dataset.pickle", "rb") as f:
    data_dict = pickle.load(f)

thetas = data_dict["thetas"][:num_simulations]
xs = np.array(data_dict["xs"][:num_simulations])

# Data pre-processing

Let's inspect our simulation results. We will notice that some simulations produce `NaN`:

In [None]:
print("The summary statistics of the 1st simulation: ", xs[0])
print("The summary statistics of the 8th simulation: ", xs[7])
print("The summary statistics of some simulation: ", xs[1342])

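In (Sequential) Neural **Posterior** Estimation (SNPE), we can simply exclude from training all simulations for which at least one summary statistic is `NaN`. Restricting the training pairs to valid $x$ leaves $p(\theta | x)$ unchanged for every valid $x$, and valid observations are the only ones we condition on: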

In [None]:
contains_no_nan = np.invert(np.any(np.isnan(xs), axis=1))
thetas_train = thetas[contains_no_nan]
xs_train = xs[contains_no_nan]

Sometimes, we also have to exclude very large values, since they could break neural network training.
This is not the case here, but it will be important later (e.g. for the Lotka-Volterra model).

In [None]:
# keep only simulations whose summary statistics all lie within [-1e6, 1e6]
within_bounds = np.all(np.abs(xs_train) <= 1e6, axis=1)
thetas_train = thetas_train[within_bounds]
xs_train = xs_train[within_bounds]

We also have to standardize (i.e. z-score) the data $x$ as well as the parameters $\theta$, so that all dimensions have a comparable scale, which makes neural network training easier:

In [None]:
thetas_torch = torch.as_tensor(thetas_train, dtype=torch.float32)
xs_torch = torch.as_tensor(xs_train, dtype=torch.float32)

xs_mean = torch.mean(xs_torch, dim=0)
xs_std = torch.std(xs_torch, dim=0)
xs_zscored = (xs_torch - xs_mean) / xs_std

theta_mean = torch.mean(thetas_torch, dim=0)
theta_std = torch.std(thetas_torch, dim=0)
theta_zscored = (thetas_torch - theta_mean) / theta_std
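Z-scoring $\theta$ is an affine change of variables, $\theta = z \cdot \sigma + \mu$. Samples can simply be de-standardized (as the notebook does after sampling), but log-densities on the original parameter scale need the Jacobian term: $\log q(\theta) = \log q(z) - \sum_i \log \sigma_i$. A small check, assuming for illustration a standard normal density in z-scored space and made-up standardization constants:

```python
import torch
from torch.distributions import Normal

# made-up standardization constants for (speed, angle, drag), for illustration only
mean_demo = torch.tensor([30.0, 50.0, 0.55])
std_demo = torch.tensor([11.5, 17.3, 0.26])

# assume the estimator's density in z-scored space is a standard normal per dimension
q_z = Normal(torch.zeros(3), torch.ones(3))

theta_demo = torch.tensor([[40.0, 30.0, 0.3]])
z = (theta_demo - mean_demo) / std_demo

# change of variables: log q(theta) = log q(z) - sum(log std)
log_q_theta = q_z.log_prob(z).sum(-1) - std_demo.log().sum()

# the same density written directly in parameter space, for comparison
direct = Normal(mean_demo, std_demo).log_prob(theta_demo).sum(-1)
print(torch.allclose(log_q_theta, direct))  # True
```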

# Exercise 1: Train a neural network to learn $p(\theta | x)$

We now use a mixture density network (MDN) to learn the conditional density $p(\theta | x)$ (i.e. the posterior).
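An MDN maps $x$ to the weights, means and variances of a mixture of Gaussians over $\theta$ and is trained by minimizing $-\log q(\theta | x)$. Here is a sketch of that density for diagonal covariances. `diag_mog_log_prob` is a hypothetical helper; the `mdn` module's `mog_log_prob` may use different argument conventions (e.g. normalized weights instead of logits).

```python
import math

import torch


def diag_mog_log_prob(theta, logits, means, variances):
    """log q(theta | x) for a mixture of K Gaussians with diagonal covariances.

    theta: (batch, D); logits: (batch, K) unnormalized mixture weights;
    means, variances: (batch, K, D), as an MDN would predict them from x.
    """
    log_weights = torch.log_softmax(logits, dim=-1)  # log pi_k, normalized over components
    diff = theta.unsqueeze(1) - means  # (batch, K, D)
    # log N(theta; mu_k, diag(sigma_k^2)), summed over the D independent dimensions
    log_normals = -0.5 * (diff**2 / variances + torch.log(2 * math.pi * variances)).sum(-1)
    # log sum_k pi_k N_k, computed in a numerically stable way
    return torch.logsumexp(log_weights + log_normals, dim=-1)
```

Working with `logsumexp` over log-weights avoids underflow when individual component densities are tiny, which matters because the loss sums these values over a whole batch.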


In [None]:
from mdn import MultivariateGaussianMDN as MultivariateGaussianMDN_diag
from mdn import mog_log_prob, mog_sample

_ = torch.manual_seed(0)

In [None]:
dataset = data.TensorDataset(theta_zscored, xs_zscored)
train_loader = data.DataLoader(
    dataset,
    batch_size=500,
    shuffle=True,  # reshuffle the training pairs in every epoch
)
mdn_diag = MultivariateGaussianMDN_diag(
    features=3,  # theta dim
    hidden_net=nn.Sequential(
        nn.Linear(3, 10),  # input dim: number of summary statistics
        nn.ReLU(),
        nn.Linear(10, 10),
        nn.ReLU(),
        nn.Linear(10, 20),  # the output size of the last layer must match hidden_features
    ),
    num_components=4,
    hidden_features=20,  # what is a meaningful number here?
)

opt = optim.Adam(mdn_diag.parameters(), lr=0.001)
training_loss = []
for e in range(100):
    for theta_batch, x_batch in train_loader:
        opt.zero_grad()
        weights_of_gaussians, means, variances = mdn_diag.get_mixture_components(
            x_batch
        )
        out = mog_log_prob(theta_batch, weights_of_gaussians, means, variances)
        loss = -out.sum()
        loss.backward()
        opt.step()
        training_loss.append(loss.detach().item())

In [None]:
# let's have a look at the loss and see if the network converged
plt.plot(np.arange(len(training_loss)) / len(train_loader), training_loss)
plt.xlabel("training epochs")
plt.ylabel("-log p")
plt.show()

# Define an observation

We will now define an **observation** $x_o$, i.e. the data for which we want to infer the posterior $p(\theta|x_o)$. In real problems, this will be an experimentally measured trace, and the ground-truth parameters will be unknown.

In [None]:
# Let's run the simulator with specific parameter values (in real problems, these are unknown).
velocity = 40
angle = 30
drag = 0.3
# put this into one tensor
theta_gt = torch.tensor([velocity, angle, drag])

sim_o = throw(velocity, angle, drag)
x_o = torch.tensor(calculate_summary_statistics(sim_o))
# xo = sbi_throw_with_sumstats(torch.tensor([velocity, angle, drag]).unsqueeze(0))
print("summary stats for this simulation:", x_o)

plt.plot(sim_o[0], sim_o[1])
plt.ylim(
    0,
)
plt.title("Ball throw")
plt.xlabel("distance [m]")
plt.ylabel("height [m]")

Because we trained the neural network on z-scored data, we also have to z-score the summary stats of $x_o$:

In [None]:
xo_zscored = (x_o - xs_mean) / xs_std
xo_torch = torch.as_tensor(xo_zscored, dtype=torch.float32).unsqueeze(0)

### Draw samples from the posterior 

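As we are working with summary statistics, we have:
$p(\theta | x_o) = p(\theta | s(\mathrm{sim}_o))$, where $s(\cdot)$ denotes the summary statistics computed by `calculate_summary_statistics`.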

In [None]:
n = 10_000
weights_of_gaussians, means, variances = mdn_diag.get_mixture_components(xo_torch)

samples = []
for _ in range(n):
    samples.append(mog_sample(weights_of_gaussians, means, variances))

samples = torch.cat(samples).detach()
samples = samples * theta_std + theta_mean  # de-standardize the parameters
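Drawing one sample per loop iteration is slow for large `n`. Sampling from a mixture can be vectorized: first draw a component index for every sample according to the mixture weights, then draw from the chosen Gaussians. A hypothetical sketch for a single observation (shapes without a batch dimension; we have not checked how `mog_sample` lays out its arguments):

```python
import torch
from torch.distributions import Categorical, Normal


def diag_mog_sample(num_samples, logits, means, variances):
    """Draw num_samples from one diagonal mixture of Gaussians.

    logits: (K,) unnormalized mixture weights; means, variances: (K, D).
    Returns a tensor of shape (num_samples, D).
    """
    # 1) pick a component for every sample according to the mixture weights
    components = Categorical(logits=logits).sample((num_samples,))  # (num_samples,)
    # 2) sample from the chosen Gaussian of each draw
    return Normal(means[components], variances[components].sqrt()).sample()
```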

In [None]:
_ = pairplot(
    samples,
    limits=[prior_speed, prior_angle, prior_drag],
    points=[theta_gt],
    figsize=(7.5, 7.5),
    points_colors="r",
    labels=["speed [m/s]", "angle [deg]", "drag"],
)

*Questions:* 
- What can you observe?
- What are potential problems?



### Full MoG

Let's look at how this changes for a mixture of Gaussians (MoG) with full covariance matrices.

Here is an implementation of this:
https://github.com/mackelab/pyknos/blob/main/pyknos/mdn/mdn.py
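A full covariance matrix must stay symmetric positive definite. One common way for an MDN to guarantee this is to predict a Cholesky factor: a lower-triangular matrix whose diagonal is made positive (here with `exp`). The sketch below uses `torch.distributions` with random stand-ins for network outputs; it is not necessarily how pyknos parameterizes the covariance, so see the linked source for that.

```python
import torch
from torch.distributions import Categorical, MixtureSameFamily, MultivariateNormal


def to_scale_tril(raw, dim):
    """Map unconstrained values of shape (..., dim*(dim+1)//2) to lower-triangular
    matrices with a positive diagonal, i.e. valid Cholesky factors L (Sigma = L L^T)."""
    rows, cols = torch.tril_indices(dim, dim)
    L = raw.new_zeros(raw.shape[:-1] + (dim, dim))
    L[..., rows, cols] = raw
    # exponentiate the diagonal so that it is strictly positive
    diag = torch.diagonal(L, dim1=-2, dim2=-1)
    return L - torch.diag_embed(diag) + torch.diag_embed(diag.exp())


# random stand-ins for network outputs: 5 observations, K=3 components, D=3 parameters
torch.manual_seed(0)
batch, K, D = 5, 3, 3
logits = torch.randn(batch, K)
mix_means = torch.randn(batch, K, D)
scale_tril = to_scale_tril(torch.randn(batch, K, D * (D + 1) // 2), D)

mog = MixtureSameFamily(Categorical(logits=logits), MultivariateNormal(mix_means, scale_tril=scale_tril))
theta_draws = mog.sample()  # (batch, D)
log_q = mog.log_prob(theta_draws)  # (batch,)
```

Predicting the Cholesky factor rather than the covariance itself means every network output corresponds to a valid Gaussian, so no constraint has to be enforced during training.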


In [None]:
from pyknos.mdn.mdn import MultivariateGaussianMDN

For this we need to change our code slightly:

In [None]:
dataset = data.TensorDataset(theta_zscored, xs_zscored)
train_loader = data.DataLoader(
    dataset,
    batch_size=500,
    shuffle=True,  # reshuffle the training pairs in every epoch
)

mdn = MultivariateGaussianMDN(
    features=3,  # theta dim
    context_features=3,  # Dimension of inputs.
    hidden_features=10,  #  Dimension of final layer of `hidden_net`.
    hidden_net=nn.Sequential(
        nn.Linear(3, 10),  # input dim
        nn.ReLU(),
        nn.Linear(10, 10),
        nn.ReLU(),
        nn.Linear(10, 10),
    ),
    num_components=3,
)

opt = optim.Adam(mdn.parameters(), lr=0.001)
training_loss = []
for e in range(50):
    for theta_batch, x_batch in train_loader:
        opt.zero_grad()

        out = mdn.log_prob(theta_batch, x_batch)
        # weights_of_gaussians, means, variances = mdn.get_mixture_components(x_batch)
        # out = mog_log_prob(theta_batch, weights_of_gaussians, means, variances)
        loss = -out.sum()
        loss.backward()
        opt.step()
        training_loss.append(loss.detach().item())

In [None]:
# let's have a look at the loss and see if the network converged
plt.plot(np.arange(len(training_loss)) / len(train_loader), training_loss)
plt.xlabel("training epochs")
plt.ylabel("-log p")
plt.show()

In [None]:
n = 1_000

samples_posterior = mdn.sample(n, xo_torch).detach().squeeze()
samples_posterior = (
    samples_posterior * theta_std + theta_mean
)  # de-standardize the parameters

In [None]:
_ = pairplot(
    samples_posterior,
    limits=[prior_speed, prior_angle, prior_drag],
    points=[theta_gt],
    figsize=(7.5, 7.5),
    points_colors="r",
    labels=["speed [m/s]", "angle [deg]", "drag"],
)

# How can we check that this is correct?

More on this later! A quick sanity check is a **posterior predictive check**: we draw parameters from the posterior, simulate them, and inspect whether the resulting traces match the observation $x_o$.
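Besides inspecting traces by eye, a predictive check can be summarized by a number, e.g. the median distance between predictive simulations and the observation, compared between prior and posterior. A self-contained sketch on a hypothetical 2-D toy simulator, with a stand-in for posterior samples (for the ball throw, one would apply the same idea to z-scored summary statistics):

```python
import numpy as np

rng = np.random.default_rng(0)


def toy_simulator(theta):
    """Hypothetical simulator: x = theta + Gaussian noise with std 0.1."""
    return theta + 0.1 * rng.normal(size=np.shape(theta))


def predictive_discrepancy(theta_samples, simulator, x_obs):
    """Median Euclidean distance between simulations from theta_samples and x_obs."""
    sims = np.stack([simulator(t) for t in theta_samples])
    return np.median(np.linalg.norm(sims - x_obs, axis=1))


theta_true = np.array([0.5, -1.0])
x_obs = toy_simulator(theta_true)

# parameters from the prior U(-2, 2)^2 ...
prior_samples = rng.uniform(-2, 2, size=(500, 2))
# ... and a stand-in for posterior samples: for this toy model and broad prior,
# the posterior is approximately N(x_obs, 0.1^2 I)
posterior_samples = x_obs + 0.1 * rng.normal(size=(500, 2))

prior_disc = predictive_discrepancy(prior_samples, toy_simulator, x_obs)
posterior_disc = predictive_discrepancy(posterior_samples, toy_simulator, x_obs)
print(f"prior predictive: {prior_disc:.2f}, posterior predictive: {posterior_disc:.2f}")
```

A good posterior should yield a much smaller discrepancy than the prior; a value close to the prior's suggests the estimator has not learned much from $x_o$.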

In [None]:
# get the ground truth simulation
gt_simulation = throw(*theta_gt)

# get the posterior simulation
posterior_simulation = [throw(*samples_posterior[i]) for i in range(20)]

# get the prior simulation for comparison
prior_simulation = [throw(*thetas[i]) for i in range(20)]

In [None]:
d = gt_simulation[0]
with mpl.rc_context(fname=".matplotlibrc"):
    fig, ax = plt.subplots(1, 2, figsize=(10, 3))

    ax[0].plot(
        d,
        prior_simulation[0][1],
        "black",
        label="prior predictive",
        alpha=0.5,
        lw=0.5,
    )
    for i in range(1, 20):
        ax[0].plot(d, prior_simulation[i][1], "black", alpha=0.5, lw=0.5)

    ax[0].plot(d, gt_simulation[1], color="r", label="ground truth")

    ax[0].legend()
    ax[0].set_ylim(0, 30)
    ax[0].set_xlim(0, 150)
    ax[0].set_title("Prior predictive")
    ax[0].set_xlabel("distance [m]")
    ax[0].set_ylabel("height [m]")

    ax[1].plot(
        d,
        posterior_simulation[0][1],
        "b-",
        label="posterior predictive",
        alpha=0.5,
        lw=0.5,
    )
    for i in range(1, 20):
        ax[1].plot(d, posterior_simulation[i][1], "b-", alpha=0.5, lw=0.5)

    ax[1].plot(d, gt_simulation[1], color="r", label="ground truth")

    ax[1].legend()
    ax[1].set_ylim(0, 30)
    ax[1].set_xlim(0, 150)
    ax[1].set_title("Posterior predictive")
    ax[1].set_xlabel("distance [m]")

    # plt.savefig("figures/post_predictives.png", dpi=200, bbox_inches="tight")

# Exercise 2: Try different number of training samples

- First, familiarize yourself with the code above and make sure you understand what's going on.
- Then, go back to the cell in which we loaded the presimulated data:
```python
num_simulations = 10_000

with open("throw_dataset.pickle", "rb") as f:
    data_dict = pickle.load(f)

thetas = data_dict["thetas"][:num_simulations]
xs = np.array(data_dict["xs"][:num_simulations])
```
- Train the neural network and evaluate the posterior with fewer simulations. What do you observe as you go down to around $500$ (or even fewer) simulations?

# Congrats, you understood the basics of NPE!

Let's move on to some cool features...

# Amortization

One of the cool features of NPE is that the posterior is **amortized**. This means that, once the simulations are done and the network is trained, one can quickly obtain the posterior for any new observation $x_o$ with a single forward pass through the neural network, without running new simulations or retraining.

# Exercise 3: Test amortization for a few different $x_o$

Use the code cells below to test amortization. In other words: change the parameters used to generate the observed data (`theta_gt2 = ...`) and inspect whether the posterior samples and their simulations match the observed data.

In [None]:
theta_gt2 = torch.tensor([40.0, 70.0, 0.2])  # [velocity, angle, drag]
# remember the prior bounds:
# prior_speed = (10, 50)  # m/s
# prior_angle = (20, 80)  # degree
# prior_drag = (0.1, 1)  # drag

sim_o2 = throw(*theta_gt2)
xo2 = torch.tensor(calculate_summary_statistics(sim_o2), dtype=torch.float)
# xo = sbi_throw_with_sumstats(torch.tensor([velocity, angle, drag]).unsqueeze(0))
print("summary stats for this simulation:", xo2)
# z-score the summary stats
xo_torch2 = (xo2 - xs_mean) / xs_std

*Question:* What happens if you put in a `theta_gt2` which is outside of the prior bounds?

In [None]:
n = 1000

samples = mdn.sample(n, xo_torch2.unsqueeze(0)).detach().squeeze()
samples = samples * theta_std + theta_mean  # de-standardize the parameters
_ = pairplot(
    samples,
    limits=[prior_speed, prior_angle, prior_drag],
    points=[theta_gt2],
    figsize=(7.5, 7.5),
    points_colors="r",
    labels=["speed [m/s]", "angle [deg]", "drag"],
)

In [None]:
gt_simulation2 = throw(*theta_gt2)

posterior_simulation2 = [throw(*samples[i]) for i in range(20)]


prior_simulation = [throw(*thetas[i]) for i in range(20)]

In [None]:
d = gt_simulation2[0]
with mpl.rc_context(fname=".matplotlibrc"):
    fig, ax = plt.subplots(1, 2, figsize=(10, 3))

    # Prior
    ax[0].plot(
        d,
        prior_simulation[0][1],
        "black",
        label="prior predictive",
        alpha=0.5,
        lw=0.5,
    )
    for i in range(1, 20):
        ax[0].plot(d, prior_simulation[i][1], "black", alpha=0.5, lw=0.5)

    ax[0].plot(d, gt_simulation2[1], color="r", label="ground truth")

    ax[0].legend()
    ax[0].set_ylim(0, 60)
    ax[0].set_xlim(0, 150)
    ax[0].set_title("Prior predictive")
    ax[0].set_xlabel("distance [m]")
    ax[0].set_ylabel("height [m]")

    # Posterior
    ax[1].plot(
        d,
        posterior_simulation2[0][1],
        "b-",
        label="posterior predictive",
        alpha=0.5,
        lw=0.5,
    )
    for i in range(1, 20):
        ax[1].plot(d, posterior_simulation2[i][1], "b-", alpha=0.5, lw=0.5)

    ax[1].plot(d, gt_simulation2[1], color="r", label="ground truth")

    ax[1].legend()
    ax[1].set_ylim(0, 60)
    ax[1].set_xlim(0, 150)
    ax[1].set_title("Posterior predictive")
    ax[1].set_xlabel("distance [m]")

    # plt.savefig("figures/post_predictives2.png", dpi=200, bbox_inches="tight")

# Embedding network

So far, we used summary statistics of the raw trace (e.g. landing distance, highest point).

In some cases, you might not want to (or cannot) define summary statistics. What to do then?

We can learn summary statistics! One can pass the simulated data $x$ through **any** neural network (e.g. a CNN, LSTM, GNN, ...) before predicting the mixture parameters.

<img src="figures/cnn.png" alt="drawing" width="1000"/>

This embedding network is trained jointly with the density estimator, using the same loss, so it learns to extract features that are informative about $\theta$.
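As an example of a non-fully-connected embedding, here is a hypothetical 1-D CNN for raw traces of length 151 that outputs 20 features. The layer sizes are arbitrary choices for illustration. Assuming the MDN calls `hidden_net` directly on the raw context, such a module could be passed as `hidden_net` together with `hidden_features=20` and `context_features=151`.

```python
import torch
import torch.nn as nn


class CNNEmbedding(nn.Module):
    """Hypothetical 1-D CNN embedding net: raw trace (batch, 151) -> features (batch, 20)."""

    def __init__(self, trace_length=151, num_features=20):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(1, 8, kernel_size=5, padding=2),  # length preserved
            nn.ReLU(),
            nn.MaxPool1d(2),  # length // 2
            nn.Conv1d(8, 16, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool1d(2),  # length // 4 (rounded down)
        )
        self.fc = nn.Linear(16 * ((trace_length // 2) // 2), num_features)

    def forward(self, x):
        h = self.conv(x.unsqueeze(1))  # add a channel dimension: (batch, 1, length)
        return self.fc(h.flatten(start_dim=1))


embedding = CNNEmbedding()
print(embedding(torch.randn(4, 151)).shape)  # torch.Size([4, 20])
```

Convolutions share weights along the trace, so the same local pattern (e.g. the peak of the trajectory) is detected wherever it occurs, with fewer parameters than a fully connected first layer.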

In [None]:
# load the data
num_simulations = 10_000

with open("throw_dataset.pickle", "rb") as f:
    data_dict = pickle.load(f)

thetas = data_dict["thetas"][:num_simulations]
xs = np.array(
    data_dict["sims"][:num_simulations]
)  # <-- we now load the raw simulations instead

# Filter nans
contains_no_nan = np.invert(np.any(np.isnan(xs), axis=1))
thetas_train = thetas[contains_no_nan]
xs_train = xs[contains_no_nan]

# z-score data
thetas_torch = torch.as_tensor(thetas_train, dtype=torch.float32)
xs_torch = torch.as_tensor(xs_train, dtype=torch.float32)

xs_mean = torch.mean(xs_torch, dim=0)
xs_std = torch.std(xs_torch, dim=0)
xs_zscored = (xs_torch - xs_mean) / xs_std

theta_mean = torch.mean(thetas_torch, dim=0)
theta_std = torch.std(thetas_torch, dim=0)
theta_zscored = (thetas_torch - theta_mean) / theta_std

In [None]:
# let's check the shape of our observations
xs.shape

In [None]:
# TODO: define the embedding network as a fully connected network
hidden_net = nn.Sequential(
   
)

In [None]:
# SOLUTION:
# define the embedding network
hidden_net = nn.Sequential(
    nn.Linear(151, 50),  # input dim: number of points in a raw trace (see xs.shape)
    nn.ReLU(),
    nn.Linear(50, 50),
    nn.ReLU(),
    nn.Linear(50, 20),
)

In [None]:
dataset = data.TensorDataset(theta_zscored, xs_zscored)
train_loader = data.DataLoader(
    dataset,
    batch_size=500,
    shuffle=True,  # reshuffle the training pairs in every epoch
)


mdn = MultivariateGaussianMDN(
    features=3,  # theta dim
    context_features=151,  # Dimension of inputs. this is our raw data dimension now.
    hidden_features=20,  #  Dimension of final layer of `hidden_net`.
    hidden_net=hidden_net,  # <--here goes the CNN, LSTM, GNN,..
    num_components=3,
)


opt = optim.Adam(mdn.parameters(), lr=0.001)
training_loss = []
for e in range(100):
    for theta_batch, x_batch in train_loader:
        opt.zero_grad()

        out = mdn.log_prob(theta_batch, x_batch)
        # weights_of_gaussians, means, variances = mdn.get_mixture_components(x_batch)
        # out = mog_log_prob(theta_batch, weights_of_gaussians, means, variances)
        loss = -out.sum()
        loss.backward()
        opt.step()
        training_loss.append(loss.detach().item())

In [None]:
# let's have a look at the loss and see if the network converged
plt.plot(np.arange(len(training_loss)) / len(train_loader), training_loss)
plt.xlabel("training epochs")
plt.ylabel("-log p")
plt.show()

In [None]:
theta_gt = torch.tensor([40.0, 50.0, 0.6])  # [velocity, angle, drag]
# remember the prior bounds:
# prior_speed = (10, 50)  # m/s
# prior_angle = (20, 80)  # degree
# prior_drag = (0.1, 1)  # drag


sumstatsxo, xo = sbi_throw_with_sumstats(theta_gt.unsqueeze(0), return_raw_sims=True)
# z-score the simulation
xo = torch.tensor(xo, dtype=torch.float)
xo_torch = (xo - xs_mean) / xs_std

In [None]:
n = 2000

samples = mdn.sample(n, xo_torch).detach().squeeze()
samples = samples * theta_std + theta_mean  # de-standardize the parameters
_ = pairplot(
    samples,
    limits=[prior_speed, prior_angle, prior_drag],
    points=[theta_gt],
    figsize=(7.5, 7.5),
    points_colors="r",
    labels=["speed [m/s]", "angle [deg]", "drag"],
)

*Question:*

The posterior marginals seem to be tighter than with the hand-crafted summary statistics. Can you explain why?

### Should I use an embedding net?

Advantages:
- No need for hand-selected features
- Possible insights into which features are learned by the CNN, LSTM, ...

Disadvantages:
- Typically more training data is needed to learn useful features
- The embedding net can pick up simulator artifacts (e.g. the initial value) that are not meaningful for real data but highly informative for the posterior


# Summary

Neural Posterior Estimation (NPE) works as follows:
- sample the prior: $\theta \sim p(\theta)$
- run the simulator for each parameter: $x \sim p(x | \theta)$
- train a conditional density estimator $q(\theta | x)$.
- after training, plug the observed data $x_o$ into the network to obtain the posterior.

Benefits:
- after training, the posterior is **amortized**, i.e. it can rapidly be evaluated for new data (no new simulations or retraining)
- NPE can automatically learn summary statistics with the embedding net

In the last week, we will see Sequential Neural Posterior Estimation (SNPE), which performs inference over multiple rounds.
- This can improve simulation efficiency
- But, because later rounds draw parameters from a proposal instead of the prior, it requires changes to the loss function. The SNPE algorithms differ in how they deal with this.

# Thank you for your attention!