# Two Moons: Tackling Bimodal Posteriors

_Authors: Lars Kühmichel, Marvin Schmitt, Valentin Pratz, Stefan T. Radev_

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

import keras

# For BayesFlow devs: this ensures that the latest dev version can be found
import sys
sys.path.append('../')

import bayesflow as bf

## Simulator<a class="anchor" id="simulator"></a>

This example demonstrates amortized estimation of a somewhat unusual Bayesian model whose posterior, conditioned on the observation $x = (0, 0)$, resembles two crescent moons. The forward process is a noisy nonlinear transformation in the 2D plane:

$$
\begin{align}
x_1 &= -|\theta_1 + \theta_2|/\sqrt{2} + r \cos(\alpha) + 0.25\\
x_2 &= (-\theta_1 + \theta_2)/\sqrt{2} + r \sin(\alpha)
\end{align}
$$

with $x = (x_1, x_2)$ playing the role of the "observables" (the data to be learned from). The latent variables $\alpha \sim \text{Uniform}(-\pi/2, \pi/2)$ and $r \sim \text{Normal}(0.1, 0.01)$ (mean 0.1 and standard deviation 0.01, as in the code below) inject noise into the data, and $\theta = (\theta_1, \theta_2)$ are the parameters that we will later infer from new $x$. We set their priors to

$$
\begin{align}
\theta_1, \theta_2 \sim \text{Uniform}(-1, 1).
\end{align}
$$
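The two modes can be traced to the absolute value in the first equation: reflecting $\theta$ across the line $\theta_2 = -\theta_1$ flips the sign of $\theta_1 + \theta_2$ but leaves $-\theta_1 + \theta_2$ unchanged, so both parameter vectors produce the same $x$ for the same noise. Since the uniform prior is also symmetric under this reflection, the posterior should have two mirror-image lobes. The following is a minimal sketch of that check; the helper name `forward_fixed_noise` and the specific values of `theta`, `r`, and `alpha` are assumptions chosen only for illustration.

```python
import numpy as np


def forward_fixed_noise(theta, r, alpha):
    """Forward map from the equations above, with the noise variables held fixed.

    Hypothetical helper for illustration; it is not part of BayesFlow.
    """
    x1 = -np.abs(theta[0] + theta[1]) / np.sqrt(2) + r * np.cos(alpha) + 0.25
    x2 = (-theta[0] + theta[1]) / np.sqrt(2) + r * np.sin(alpha)
    return np.array([x1, x2])


theta = np.array([0.3, -0.7])
# Reflection across the line theta_2 = -theta_1: (a, b) -> (-b, -a)
mirrored = np.array([-theta[1], -theta[0]])

# Arbitrary fixed noise values, used for both parameter vectors
r, alpha = 0.1, 0.4

x_a = forward_fixed_noise(theta, r, alpha)
x_b = forward_fixed_noise(mirrored, r, alpha)

# Two different parameter vectors map to the same observation
print(np.allclose(x_a, x_b))
```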

This model is commonly used for benchmarking simulation-based inference (SBI) methods (see https://arxiv.org/pdf/2101.04653). Any method for amortized Bayesian inference should be capable of recovering the two moons posterior *without* an excessive number of simulations. Note that this is a considerably harder task than modeling the unconditional two moons data set that is often used in the context of normalizing flows.
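For a rough picture of the target posterior at $x = (0, 0)$, one option is a brute-force rejection sampler: simulate many $(\theta, x)$ pairs from the model and keep the $\theta$ whose $x$ lands close to the origin. This is plain rejection ABC, not the amortized approach used in this notebook, and it is only a sketch: the sample size `n`, the tolerance `epsilon`, and the random seed are assumptions picked for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000  # number of prior draws (illustrative choice)

# Vectorized draws from the prior and the noise distributions
theta = rng.uniform(-1, 1, size=(n, 2))
alpha = rng.uniform(-np.pi / 2, np.pi / 2, size=n)
r = rng.normal(0.1, 0.01, size=n)

# Forward model, applied to all draws at once
x1 = -np.abs(theta[:, 0] + theta[:, 1]) / np.sqrt(2) + r * np.cos(alpha) + 0.25
x2 = (-theta[:, 0] + theta[:, 1]) / np.sqrt(2) + r * np.sin(alpha)
x = np.stack([x1, x2], axis=1)

# Keep parameters whose simulated data fall within epsilon of the origin
epsilon = 0.05  # acceptance tolerance (illustrative choice)
accepted = theta[np.linalg.norm(x, axis=1) < epsilon]

# The sign of theta_1 + theta_2 tells the two lobes apart
side = np.sign(accepted[:, 0] + accepted[:, 1])
print("accepted:", len(accepted))
print("lobe with theta_1 + theta_2 > 0:", int((side > 0).sum()))
print("lobe with theta_1 + theta_2 < 0:", int((side < 0).sum()))
```

A scatter plot of `accepted` should show the two crescents. The approach wastes almost all simulations and has to be repeated for every new observation, which is the cost that amortized inference is meant to avoid.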

BayesFlow offers many ways to define your data-generating process. Here, we use sequential functions to build a simulator object for online training. Within this composite simulator, each function has access to the outputs of the previous functions. This effectively allows you to define any generative graph.

In [3]:
def theta_prior():
    theta = np.random.uniform(-1, 1, 2)
    return dict(theta=theta)

def forward_model(theta):
    alpha = np.random.uniform(-np.pi / 2, np.pi / 2)
    r = np.random.normal(0.1, 0.01)
    x1 = -np.abs(theta[0] + theta[1]) / np.sqrt(2) + r * np.cos(alpha) + 0.25
    x2 = (-theta[0] + theta[1]) / np.sqrt(2) + r * np.sin(alpha)
    return dict(x=np.array([x1, x2]))

Within the composite simulator, every simulator has access to the outputs of the previous simulators in the list. Here, the last simulator `forward_model` receives the output `theta` of `theta_prior`.
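To make the idea of sequential composition concrete, here is a minimal sketch in plain Python. It is not BayesFlow's implementation: `compose_simulators` is a hypothetical helper that runs the functions in order, collects their returned dictionaries, and passes each function the earlier outputs whose keys match its argument names.

```python
import inspect

import numpy as np


def compose_simulators(functions):
    """Chain simulator functions (hypothetical helper, not the BayesFlow API)."""

    def sample():
        outputs = {}
        for fn in functions:
            # Pass only those earlier outputs that the function asks for by name
            wanted = inspect.signature(fn).parameters
            kwargs = {name: outputs[name] for name in wanted if name in outputs}
            outputs.update(fn(**kwargs))
        return outputs

    return sample


def theta_prior():
    theta = np.random.uniform(-1, 1, 2)
    return dict(theta=theta)


def forward_model(theta):
    alpha = np.random.uniform(-np.pi / 2, np.pi / 2)
    r = np.random.normal(0.1, 0.01)
    x1 = -np.abs(theta[0] + theta[1]) / np.sqrt(2) + r * np.cos(alpha) + 0.25
    x2 = (-theta[0] + theta[1]) / np.sqrt(2) + r * np.sin(alpha)
    return dict(x=np.array([x1, x2]))


draw = compose_simulators([theta_prior, forward_model])
sample = draw()
print(sorted(sample))  # ['theta', 'x']
print(sample["theta"].shape, sample["x"].shape)  # (2,) (2,)
```

Matching by argument name keeps each function unaware of the others; any directed acyclic graph of dependencies can be expressed by listing the functions in a valid order.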

In [None]:
simulator = bf.make_simulator([theta_prior, forward_model])

## Workflow

In [None]:
workflow = bf.BasicWorkflow(simulator=simulator)

history = workflow.fit_online(epochs=10, validation_data=300)

INFO:bayesflow:Fitting on dataset instance of OnlineDataset.
INFO:bayesflow:Building on a test batch.


Epoch 1/10
100/100 ━━━━━━━━━━━━━━━━━━━━ 15s 20ms/step - loss: -0.0079 - loss/inference_loss: -0.0079 - val_loss: -0.3098 - val_loss/inference_loss: -0.3098
Epoch 2/10
100/100 ━━━━━━━━━━━━━━━━━━━━ 1s 12ms/step - loss: -1.1417 - loss/inference_loss: -1.1417 - val_loss: -1.7498 - val_loss/inference_loss: -1.7498
Epoch 3/10
100/100 ━━━━━━━━━━━━━━━━━━━━ 1s 12ms/step - loss: -1.9484 - loss/inference_loss: -1.9484 - val_loss: -1.4330 - val_loss/inference_loss: -1.4330
Epoch 4/10
100/100 ━━━━━━━━━━━━━━━━━━━━ 1s 12ms/step - loss: -2.1626 - loss/inference_loss: -2.1626 - val_loss: -2.4202 - val_loss/inference_loss: -2.4202
Epoch 5/10
100/100 ━━━━━━━━━━━━━━━━━━━━ 1s 13ms/step - loss: -2.3003 - loss/inference_loss: -2.3003 - val_loss: -2.6998 - val_loss/inference_loss: -2.6998
Epoch 6/10
100/100 ━━━━━━━━━━━━━━━━━━━━

In [None]:
figs = workflow.plot_diagnostics(
    test_data=300,
    filter_names=["beta1", "beta2"],
    diagnostics=["recovery", "calibration"],
)

metrics_dict = workflow.compute_diagnostics(
    test_data=300,
    diagnostics=["recovery", "calibration"],
)