In [None]:
import os
import sys

sys.path.append("../../")

import pickle
import timeit

import bayesflow as bf
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf
from bayesflow.computational_utilities import maximum_mean_discrepancy
from tqdm.autonotebook import tqdm
from train import build_trainer, configurator

In [None]:
# Ask TensorFlow to enable memory growth on the first visible GPU (if there is one).
physical_devices = tf.config.list_physical_devices("GPU")
first_gpu = physical_devices[0] if physical_devices else None
if first_gpu is not None:
    try:
        tf.config.experimental.set_memory_growth(first_gpu, True)
    except (ValueError, RuntimeError):
        # Best effort only: invalid device, or virtual devices were already
        # initialized and can no longer be modified.
        pass

# Set up Forward Inference

In [None]:
# Load the Fashion-MNIST train/test images and labels through the Keras dataset helper.
fashion_mnist = tf.keras.datasets.fashion_mnist
train_split, test_split = fashion_mnist.load_data()
train_images, train_labels = train_split
test_images, test_labels = test_split

In [None]:
# The same images are stored under both keys; presumably `configurator` derives the
# degraded observation from "sim_data" -- confirm in train.py.
forward_train = {"prior_draws": train_images, "sim_data": train_images}

# Carve a fixed-seed random validation subset out of the test split; every remaining
# test image goes into the evaluation set.
num_val = 500
perm = np.random.default_rng(seed=42).permutation(test_images.shape[0])
val_idx = perm[:num_val]
test_idx = perm[num_val:]

forward_val = {name: test_images[val_idx] for name in ("prior_draws", "sim_data")}
forward_test = {name: test_images[test_idx] for name in ("prior_draws", "sim_data")}

val_labels = test_labels[val_idx]
# NOTE: this rebinds `test_labels` to the evaluation subset, so the cell is not
# idempotent -- re-run the data-loading cell before running this one again.
test_labels = test_labels[test_idx]

# Sanity Check

In [None]:
# Visual sanity check: show the first few configured observations next to the
# corresponding ground-truth training images.
how_many = 5
conf = configurator(
    {
        "sim_data": forward_train["sim_data"][:how_many],
        "prior_draws": forward_train["prior_draws"][:how_many],
    }
)

# Use the colormap registry (as the plotting helpers further down do) instead of
# `plt.cm.get_cmap`, which is deprecated and no longer available in recent Matplotlib.
greys = matplotlib.colormaps["Greys"]

# squeeze=False keeps `axarr` two-dimensional even when `how_many` is 1.
f, axarr = plt.subplots(how_many, 2, squeeze=False)
axarr[0, 0].set_title("Blurred")
axarr[0, 1].set_title("True")
for i in range(how_many):
    axarr[i, 0].imshow(conf["summary_conditions"][i, :, :, 0], cmap=greys)
    axarr[i, 1].imshow(forward_train["prior_draws"][i].reshape(28, 28), cmap=greys)
    axarr[i, 0].axis("off")
    axarr[i, 1].axis("off")
f.tight_layout()

## Set up Network, Amortizer and Trainer

In [None]:
# Number of sampling steps for CMPE (two-step sampling); passed as `n_steps` to
# `amortizer.sample` for runs whose method is "cmpe".
cmpe_steps = 2
# Step size for FMPE, following "Flow Matching for Scalable Simulation-Based Inference",
# https://arxiv.org/pdf/2305.17161.pdf; passed as `step_size` to `amortizer.sample`
# for all non-CMPE runs.
fmpe_step_size = 1 / 248

In [None]:
def to_id(method, architecture, num_train):
    """Build the run identifier ``"<method>-<architecture>-<num_train>"``.

    The identifier is used as the key of the checkpoint, argument and trainer
    dictionaries and as the name of the per-run output folders.
    """
    return "{}-{}-{}".format(method, architecture, num_train)

In [None]:
# (run id, checkpoint directory) for every trained model that gets evaluated below.
checkpoint_entries = [
    (to_id("cmpe", "naive", 2000), "checkpoints/cmpe-naive-2000-23-12-02-155659/"),
    (to_id("cmpe", "naive", 60000), "checkpoints/cmpe-naive-60000-23-12-02-160801/"),
    (to_id("fmpe", "naive", 2000), "checkpoints/fmpe-naive-2000-23-12-02-161806/"),
    (to_id("fmpe", "naive", 60000), "checkpoints/fmpe-naive-60000-23-12-02-161806/"),
    (to_id("cmpe", "unet", 2000), "checkpoints/cmpe-unet-2000-23-12-02-144825/"),
    (to_id("cmpe", "unet", 60000), "checkpoints/cmpe-unet-60000-23-12-02-161035/"),
    (to_id("fmpe", "unet", 2000), "checkpoints/fmpe-unet-2000-23-12-02-161806/"),
    (to_id("fmpe", "unet", 60000), "checkpoints/fmpe-unet-60000-23-12-02-161806/"),
]
checkpoint_path_dict = dict(checkpoint_entries)

In [None]:
# Load the pickled training arguments stored next to each checkpoint.
# NOTE(security): `pickle.load` can execute arbitrary code -- only point this at
# checkpoint folders you produced yourself or otherwise trust.
arg_dict = {}
for key, checkpoint_path in checkpoint_path_dict.items():
    args_file = os.path.join(checkpoint_path, "args.pickle")
    with open(args_file, "rb") as handle:
        arg_dict[key] = pickle.load(handle)

In [None]:
# One trainer per run id, built from its checkpoint directory and stored arguments.
trainer_dict = {
    run_id: build_trainer(run_dir, arg_dict[run_id])
    for run_id, run_dir in checkpoint_path_dict.items()
}

In [None]:
# Save every run's training/validation loss curves to figures/<run id>/loss_history.pdf.
for key, trainer in trainer_dict.items():
    fig_dir = f"figures/{key}"
    os.makedirs(fig_dir, exist_ok=True)
    history = trainer.loss_history.get_plottable()
    loss_fig = bf.diagnostics.plot_losses(history["train_losses"], history["val_losses"])
    loss_path = os.path.join(fig_dir, "loss_history.pdf")
    loss_fig.savefig(loss_path, bbox_inches="tight", dpi=300)

# Evaluation

In [None]:
# Global Matplotlib styling shared by all evaluation figures.
plt.rcParams.update(
    {
        "axes.labelsize": 24,
        "xtick.labelsize": 16,
        "ytick.labelsize": 16,
        "legend.fontsize": 24,
        "text.usetex": False,
        "font.family": "serif",
        # Single braces: this is a plain raw string, not a format string, so the
        # previous `{{amsmath}}` passed doubled braces through to LaTeX verbatim.
        "text.latex.preamble": r"\usepackage{amsmath}",
    }
)

In [None]:
# Configure the test split once; the evaluation cells below index into
# `conf["parameters"]` and `conf["summary_conditions"]`. This rebinds `conf`,
# replacing the five-image dictionary created in the sanity check.
conf = configurator(forward_test)

## Per-Class Generation: Means and STDs

In [None]:
# Human-readable class names used to label the per-class figures. The plotting
# helpers index this list by position, so the order is presumably the Fashion-MNIST
# label order (0-9) -- confirm against the dataset documentation.
class_names = [
    "T-Shirt/Top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle Boot",
]

# Row labels of the 4-row mean/std figure, in top-to-bottom order.
y_labels = [r"Parameter $\theta$", r"Observation $x$", "Mean", "Std.Dev"]

In [None]:
def random_indices_per_class(labels, seed=42):
    """Pick one pseudo-random example index for every class present in ``labels``.

    Parameters
    ----------
    labels : np.ndarray
        Array of class labels, one per example.
    seed : int, optional
        Seed of the NumPy generator that shuffles the candidate indices.

    Returns
    -------
    dict
        Maps each distinct label (in the ascending order returned by ``np.unique``)
        to the first index of the seeded random permutation carrying that label.
    """
    shuffled = np.random.default_rng(seed).permutation(labels.shape[0])
    shuffled_labels = labels[shuffled]
    picks = {}
    for cls in np.unique(labels):
        # Position (within the shuffled order) of the first example of this class.
        first_hit = np.flatnonzero(shuffled_labels == cls)[0]
        picks[cls] = shuffled[first_hit]
    return picks


def create_mean_std_plots(
    trainer, seed=42, filepath=None, n_samples=500, cmpe_steps=30, fmpe_step_size=1 / 248, method=""
):
    """Helper function for displaying Figure 7 in main paper.

    For one randomly chosen test image per class, plots four rows: the true
    parameters, the observation, and the pixel-wise mean and standard deviation of
    the posterior samples. Reads the notebook-level globals ``conf``,
    ``test_labels``, ``class_names`` and ``y_labels``.

    Parameters
    ----------
    trainer : object
        Trainer whose ``amortizer.sample`` draws the posterior samples.
    seed : int, optional
        Seed forwarded to ``random_indices_per_class`` (default 42).
    filepath : str or None, optional
        If given, the figure is saved to this path.
    n_samples : int, optional
        Number of posterior draws per image.
    cmpe_steps : int, optional
        Passed as ``n_steps`` to the sampler when ``method == "cmpe"``.
    fmpe_step_size : float, optional
        Passed as ``step_size`` to the sampler for every other ``method``.
    method : str, optional
        ``"cmpe"`` selects the ``n_steps`` call; anything else the ``step_size`` call.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure.
    """
    cmap = matplotlib.colormaps["binary"]

    idx_dict = random_indices_per_class(test_labels, seed=seed)
    # squeeze=False keeps `axarr` two-dimensional even if only one class is present.
    f, axarr = plt.subplots(4, len(idx_dict), figsize=(12, 4), squeeze=False)
    for i, (c, idx) in tqdm(enumerate(idx_dict.items()), total=len(idx_dict)):
        # Prepare input dict for network
        inp = {
            "parameters": conf["parameters"][idx : (idx + 1)],
            "summary_conditions": conf["summary_conditions"][idx : (idx + 1)],
        }

        # Obtain samples and clip to prior range, instead of rejecting
        if method == "cmpe":
            samples = trainer.amortizer.sample(inp, n_steps=cmpe_steps, n_samples=n_samples)
        else:
            samples = trainer.amortizer.sample(inp, n_samples=n_samples, step_size=fmpe_step_size)
        samples = np.clip(samples, a_min=-1.01, a_max=1.01)

        # Rows: truth, observation, posterior mean, posterior standard deviation
        axarr[0, i].imshow(inp["parameters"].reshape((28, 28, 1)), cmap=cmap)
        axarr[1, i].imshow(inp["summary_conditions"].reshape((28, 28, 1)), cmap=cmap)
        axarr[2, i].imshow(samples.mean(0).reshape(28, 28, 1), cmap=cmap)
        axarr[3, i].imshow(samples.std(0).reshape(28, 28, 1), cmap=cmap)

        # Title by the class label itself rather than the column position, so the
        # names stay correct even if some class is missing from `test_labels`.
        axarr[0, i].set_title(class_names[int(c)])

    for j, label in enumerate(y_labels):
        axarr[j, 0].set_ylabel(label, rotation=0, labelpad=55, fontsize=12)

    # get rid of axis
    for ax in axarr.flat:
        for side in ("right", "left", "top", "bottom"):
            ax.spines[side].set_visible(False)
        ax.set_yticklabels([])
        ax.set_yticks([])
        ax.set_xticklabels([])
        ax.set_xticks([])
    f.tight_layout()

    if filepath is not None:
        f.savefig(filepath, dpi=300, bbox_inches="tight")
    return f

In [None]:
# Render and save the per-class mean/std figure for every run.
for key, trainer in trainer_dict.items():
    print(key)
    fig_dir = f"figures/{key}"
    os.makedirs(fig_dir, exist_ok=True)
    plot_options = dict(
        seed=42,
        filepath=os.path.join(fig_dir, "main.pdf"),
        method=arg_dict[key].method,
        cmpe_steps=cmpe_steps,
        fmpe_step_size=fmpe_step_size,
    )
    f = create_mean_std_plots(trainer, **plot_options)

## Per-Class Generation: Samples

In [None]:
def create_sample_plots(trainer, seed=42, filepath=None, cmpe_steps=30, fmpe_step_size=1 / 248, method=""):
    """Plot individual posterior samples for one random test image per class.

    Each row shows the true parameters, the observation and five posterior draws
    for one class. Reads the notebook-level globals ``conf``, ``test_labels`` and
    ``class_names``.

    NOTE(review): the previous docstring read "Figure 7 in main paper", identical to
    ``create_mean_std_plots`` -- verify which figure this helper actually produces.

    Parameters
    ----------
    trainer : object
        Trainer whose ``amortizer.sample`` draws the posterior samples.
    seed : int, optional
        Seed forwarded to ``random_indices_per_class`` (default 42).
    filepath : str or None, optional
        If given, the figure is saved to this path.
    cmpe_steps : int, optional
        Passed as ``n_steps`` to the sampler when ``method == "cmpe"``.
    fmpe_step_size : float, optional
        Passed as ``step_size`` to the sampler for every other ``method``.
    method : str, optional
        ``"cmpe"`` selects the ``n_steps`` call; anything else the ``step_size`` call.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure.
    """
    cmap = matplotlib.colormaps["binary"]

    idx_dict = random_indices_per_class(test_labels, seed=seed)
    n_samples = 5
    # squeeze=False keeps `axarr` two-dimensional even if only one class is present.
    f, axarr = plt.subplots(len(idx_dict), 2 + n_samples, figsize=(8.27, 11.69), squeeze=False)
    titles = [r"Param. $\theta$", r"Obs. $x$"] + n_samples * ["Sample"]
    for i, (c, idx) in tqdm(enumerate(idx_dict.items()), total=len(idx_dict)):
        # Prepare input dict for network
        inp = {
            "parameters": conf["parameters"][idx : (idx + 1)],
            "summary_conditions": conf["summary_conditions"][idx : (idx + 1)],
        }

        # Obtain samples and clip to prior range, instead of rejecting
        if method == "cmpe":
            samples = trainer.amortizer.sample(inp, n_steps=cmpe_steps, n_samples=n_samples)
        else:
            samples = trainer.amortizer.sample(inp, n_samples=n_samples, step_size=fmpe_step_size)
        samples = np.clip(samples, a_min=-1.01, a_max=1.01)

        # Columns: truth, observation, then one column per posterior draw
        axarr[i, 0].imshow(inp["parameters"].reshape((28, 28, 1)), cmap=cmap)
        axarr[i, 1].imshow(inp["summary_conditions"].reshape((28, 28, 1)), cmap=cmap)
        for j in range(n_samples):
            axarr[i, 2 + j].imshow(samples[j].reshape(28, 28, 1), cmap=cmap)

        # Label by the class label itself rather than the row position, so the
        # names stay correct even if some class is missing from `test_labels`.
        axarr[i, 0].set_ylabel(class_names[int(c)], fontsize=12)

    for i, title in enumerate(titles):
        axarr[0, i].set_title(title, fontsize=12)

    # get rid of axis
    for ax in axarr.flat:
        for side in ("right", "left", "top", "bottom"):
            ax.spines[side].set_visible(False)
        ax.set_yticklabels([])
        ax.set_yticks([])
        ax.set_xticklabels([])
        ax.set_xticks([])
    f.tight_layout()

    if filepath is not None:
        f.savefig(filepath, dpi=300, bbox_inches="tight")
    return f

In [None]:
# Render, save and show the per-class sample figure for every run.
for key, trainer in trainer_dict.items():
    print(key)
    fig_dir = f"figures/{key}"
    os.makedirs(fig_dir, exist_ok=True)
    plot_options = dict(
        seed=42,
        filepath=os.path.join(fig_dir, "samples_main.pdf"),
        method=arg_dict[key].method,
        cmpe_steps=cmpe_steps,
        fmpe_step_size=fmpe_step_size,
    )
    f = create_sample_plots(trainer, **plot_options)
    f.show()

### Averaged RMSE

In [None]:
# Draw `n_samples` posterior samples for each of the first `n_datasets` test
# observations, time the sampling, and store the aggregated RMSE per run.
n_samples = 100
n_datasets = 100
parameters = conf["parameters"][:n_datasets]

for key, trainer in trainer_dict.items():
    print(key, end="")

    # The sampler-specific keyword is fixed per run, so resolve it once instead of
    # branching again for the warm-up call and inside the timed loop.
    if arg_dict[key].method == "cmpe":
        sampler_kwargs = {"n_steps": cmpe_steps}
    else:
        sampler_kwargs = {"step_size": fmpe_step_size}

    # sample once, to avoid contaminating timing with tracing
    c = conf["summary_conditions"][0, None]
    print(" Initializing...")
    trainer.amortizer.sample({"summary_conditions": c}, n_samples=n_samples, **sampler_kwargs)

    # store samples
    post_samples = np.zeros((n_datasets, n_samples, conf["parameters"].shape[-1]))

    tic = timeit.default_timer()
    for i in range(n_datasets):
        print(f"{i+1:03}/{n_datasets}", end="\r")
        c = conf["summary_conditions"][i, None]
        post_samples[i] = trainer.amortizer.sample(
            {"summary_conditions": c}, n_samples=n_samples, **sampler_kwargs
        )
    toc = timeit.default_timer()

    duration = toc - tic
    rmse = bf.computational_utilities.aggregated_rmse(parameters, post_samples)

    output_dir = f"evaluation/{key}"
    os.makedirs(output_dir, exist_ok=True)
    # The CSV stores the total wall-clock duration of the loop; the printed value
    # below is the same duration expressed in milliseconds per posterior draw.
    with open(os.path.join(output_dir, "rmse.csv"), "w") as csv_file:
        csv_file.write(f"duration,rmse\n{duration},{float(rmse)}\n")
    np.save(os.path.join(output_dir, "rmse_samples.npy"), post_samples)
    print(f"duration: {duration/(n_datasets * n_samples) * 1000:.2f}ms\nRMSE:{float(rmse):.3f}")

RMSE for predicting only zeros:

In [None]:
# Baseline: a single all-zero "draw" per data set, laid out as
# (n_datasets, n_draws, n_params) like `post_samples` in the RMSE evaluation,
# instead of a 5-D tensor shaped after the observation images.
zero_prediction = np.zeros((parameters.shape[0], 1, parameters.shape[-1]))
bf.computational_utilities.aggregated_rmse(parameters, zero_prediction)

RMSE for predicting the noisy image:

In [None]:
# Baseline: use the observed (noisy) image itself as the single posterior "draw".
# Each image is flattened so the prediction has the (n_datasets, n_draws, n_params)
# layout of `post_samples`; the previous 5-D input would presumably be broadcast
# against the flat parameter vectors rather than compared pixel by pixel.
# NOTE(review): assumes the flattened observation uses the same pixel order as the
# parameter vector -- confirm against `configurator`.
noisy_prediction = np.reshape(conf["summary_conditions"][:n_datasets], (n_datasets, 1, -1))
bf.computational_utilities.aggregated_rmse(parameters, noisy_prediction)

### MMD

Split the test images into six parts (due to memory limits) and calculate the maximum mean discrepancy on each part.

In [None]:
# NOTE: rebinds `parameters` to the full configured set; the RMSE baseline cells
# expect the `n_datasets`-row slice, so re-run the RMSE cell before re-running them.
parameters = conf["parameters"]
split_size = 1583
# Only num_splits * split_size = 9498 rows are evaluated; rows beyond that are ignored.
num_splits = 6

for key, trainer in trainer_dict.items():
    print(key)

    # One posterior draw per configured data set, kept as a tensor (to_numpy=False).
    if arg_dict[key].method == "cmpe":
        sampler_kwargs = {"n_steps": cmpe_steps}
    else:
        sampler_kwargs = {"step_size": fmpe_step_size}
    samples = trainer.amortizer.sample(conf, n_samples=1, to_numpy=False, **sampler_kwargs)

    # MMD between true parameters and posterior draws, computed split by split.
    mmds = np.zeros((num_splits,))
    for i in range(num_splits):
        start = i * split_size
        stop = (i + 1) * split_size
        split_mmd = maximum_mean_discrepancy(parameters[start:stop], samples[start:stop, 0])
        mmds[i] = split_mmd.numpy()

    output_dir = f"evaluation/{key}"
    os.makedirs(output_dir, exist_ok=True)
    np.save(os.path.join(output_dir, "mmds.npy"), mmds)
    print(f"{mmds.mean():.5f}, std: {mmds.std():.5f}")