From 1f55216b987fcb0d911cf523e567447746d76941 Mon Sep 17 00:00:00 2001 From: Jerry Date: Fri, 14 Jun 2024 00:22:53 -0400 Subject: [PATCH 01/22] Reset to align with origin/streamlined-backend --- .../diagnostics/plot_calibration_curves.py | 110 +++++++++++ .../diagnostics/plot_confusion_matrix.py | 124 ++++++++++++ .../diagnostics/plot_distribution_2d.py | 85 +++++++++ .../diagnostics/plot_latent_space_2d.py | 36 ++++ .../experimental/diagnostics/plot_losses.py | 121 ++++++++++++ .../diagnostics/plot_mmd_hypothesis_test.py | 100 ++++++++++ .../diagnostics/plot_posterior_2d.py | 134 +++++++++++++ .../experimental/diagnostics/plot_prior_2d.py | 49 +++++ .../experimental/diagnostics/plot_recovery.py | 164 ++++++++++++++++ .../experimental/diagnostics/plot_sbc_ecdf.py | 177 ++++++++++++++++++ .../diagnostics/plot_sbc_histograms.py | 137 ++++++++++++++ .../diagnostics/plot_z_score_contraction.py | 115 ++++++++++++ 12 files changed, 1352 insertions(+) create mode 100644 bayesflow/experimental/diagnostics/plot_calibration_curves.py create mode 100644 bayesflow/experimental/diagnostics/plot_confusion_matrix.py create mode 100644 bayesflow/experimental/diagnostics/plot_distribution_2d.py create mode 100644 bayesflow/experimental/diagnostics/plot_latent_space_2d.py create mode 100644 bayesflow/experimental/diagnostics/plot_losses.py create mode 100644 bayesflow/experimental/diagnostics/plot_mmd_hypothesis_test.py create mode 100644 bayesflow/experimental/diagnostics/plot_posterior_2d.py create mode 100644 bayesflow/experimental/diagnostics/plot_prior_2d.py create mode 100644 bayesflow/experimental/diagnostics/plot_recovery.py create mode 100644 bayesflow/experimental/diagnostics/plot_sbc_ecdf.py create mode 100644 bayesflow/experimental/diagnostics/plot_sbc_histograms.py create mode 100644 bayesflow/experimental/diagnostics/plot_z_score_contraction.py diff --git a/bayesflow/experimental/diagnostics/plot_calibration_curves.py 
b/bayesflow/experimental/diagnostics/plot_calibration_curves.py new file mode 100644 index 000000000..7efe79ca2 --- /dev/null +++ b/bayesflow/experimental/diagnostics/plot_calibration_curves.py @@ -0,0 +1,110 @@ +from ..utils.plotutils import preprocess, postprocess +from ..utils.computils import expected_calibration_error +from keras import ops + + +def plot_calibration_curves( + true_models, + pred_models, + model_names: list = None, + num_bins: int = 10, + label_fontsize: int = 16, + legend_fontsize: int = 14, + title_fontsize: int = 18, + tick_fontsize: int = 12, + epsilon: float = 0.02, + fig_size: tuple = None, + color: str | tuple = "#8f2727", + x_label: str = "Predicted probability", + y_label: str = "True probability", + n_row: int = None, + n_col: int = None, +): + """Plots the calibration curves, the ECEs and the marginal histograms of predicted posterior model probabilities + for a model comparison problem. The marginal histograms inform about the fraction of predictions in each bin. + Depends on the ``expected_calibration_error`` function for computing the ECE. + + Parameters + ---------- + true_models : np.ndarray of shape (num_data_sets, num_models) + The one-hot-encoded true model indices per data set. + pred_models : np.ndarray of shape (num_data_sets, num_models) + The predicted posterior model probabilities (PMPs) per data set. + model_names : list or None, optional, default: None + The model names for nice plot titles. Inferred if None. + num_bins : int, optional, default: 10 + The number of bins to use for the calibration curves (and marginal histograms). + label_fontsize : int, optional, default: 16 + The font size of the x-label and y-label texts + legend_fontsize : int, optional, default: 14 + The font size of the legend text (ECE value) + title_fontsize : int, optional, default: 18 + The font size of the title text. 
Only relevant if `stacked=False` + tick_fontsize : int, optional, default: 12 + The font size of the axis ticklabels + epsilon : float, optional, default: 0.02 + A small amount to pad the [0, 1]-bounded axes from both side. + fig_size : tuple or None, optional, default: None + The figure size passed to the ``matplotlib`` constructor. Inferred if ``None`` + color : str, optional, default: '#8f2727' + The color of the calibration curves + x_label : str, optional, default: Predicted probability + The x-axis label + y_label : str, optional, default: True probability + The y-axis label + n_row : int, optional, default: None + The number of rows for the subplots. Dynamically determined if None. + n_col : int, optional, default: None + The number of columns for the subplots. Dynamically determined if None. + + Returns + ------- + fig : plt.Figure - the figure instance for optional saving + """ + + f, axarr, ax, n_row, n_col, num_models, model_names = preprocess(true_models, pred_models, fig_size=fig_size) + + # Compute calibration + cal_errs, probs_true, probs_pred = expected_calibration_error(true_models, pred_models, num_bins) + + # Plot marginal calibration curves in a loop + for j in range(num_models): + # Plot calibration curve + ax[j].plot(probs_pred[j], probs_true[j], "o-", color=color) + + # Plot PMP distribution over bins + uniform_bins = ops.linspace(0.0, 1.0, num_bins + 1) + norm_weights = ops.ones_like(pred_models) / len(pred_models) + ax[j].hist(pred_models[:, j], bins=uniform_bins, weights=norm_weights[:, j], color="grey", alpha=0.3) + + # Plot AB line + ax[j].plot((0, 1), (0, 1), "--", color="black", alpha=0.9) + + # Tweak plot + ax[j].tick_params(axis="both", which="major", labelsize=tick_fontsize) + ax[j].tick_params(axis="both", which="minor", labelsize=tick_fontsize) + ax[j].set_title(model_names[j], fontsize=title_fontsize) + ax[j].spines["right"].set_visible(False) + ax[j].spines["top"].set_visible(False) + ax[j].set_xlim([0 - epsilon, 1 + epsilon]) + 
ax[j].set_ylim([0 - epsilon, 1 + epsilon]) + ax[j].set_xticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) + ax[j].set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) + ax[j].grid(alpha=0.5) + + # Add ECE label + ax[j].text( + 0.1, + 0.9, + r"$\widehat{{\mathrm{{ECE}}}}$ = {0:.3f}".format(cal_errs[j]), + horizontalalignment="left", + verticalalignment="center", + transform=ax[j].transAxes, + size=legend_fontsize, + ) + + # Post-processing + postprocess(axarr, ax, n_row, n_col, num_models, x_label, y_label, label_fontsize) + + f.tight_layout() + return f diff --git a/bayesflow/experimental/diagnostics/plot_confusion_matrix.py b/bayesflow/experimental/diagnostics/plot_confusion_matrix.py new file mode 100644 index 000000000..ac522c778 --- /dev/null +++ b/bayesflow/experimental/diagnostics/plot_confusion_matrix.py @@ -0,0 +1,124 @@ +import matplotlib.pyplot as plt + +from keras import ops +from keras import backend as K +from sklearn.metrics import confusion_matrix +from matplotlib.colors import LinearSegmentedColormap +from ..utils.plotutils import initialize_figure + + +def plot_confusion_matrix( + true_models, + pred_models, + model_names: list = None, + fig_size=(5, 5), + label_fontsize: int = 16, + title_fontsize: int = 18, + value_fontsize: int = 10, + tick_fontsize: int = 12, + xtick_rotation: int = None, + ytick_rotation: int = None, + normalize: bool = True, + cmap=None, + title: bool = True, +): + """Plots a confusion matrix for validating a neural network trained for Bayesian model comparison. + + Parameters + ---------- + true_models : np.ndarray of shape (num_data_sets, num_models) + The one-hot-encoded true model indices per data set. + pred_models : np.ndarray of shape (num_data_sets, num_models) + The predicted posterior model probabilities (PMPs) per data set. + model_names : list or None, optional, default: None + The model names for nice plot titles. Inferred if None. 
+ fig_size : tuple or None, optional, default: (5, 5) + The figure size passed to the ``matplotlib`` constructor. Inferred if ``None`` + label_fontsize : int, optional, default: 16 + The font size of the x-label and y-label texts + title_fontsize : int, optional, default: 18 + The font size of the title text. + value_fontsize : int, optional, default: 10 + The font size of the text annotations and the colorbar tick labels. + tick_fontsize : int, optional, default: 12 + The font size of the axis tick labels (the model names). + xtick_rotation: int, optional, default: None + Rotation of x-axis tick labels (helps with long model names). + ytick_rotation: int, optional, default: None + Rotation of y-axis tick labels (helps with long model names). + normalize : bool, optional, default: True + A flag for normalization of the confusion matrix. + If True, each row of the confusion matrix is normalized to sum to 1. + cmap : matplotlib.colors.Colormap or str, optional, default: None + Colormap to be used for the cells. If a str, it should be the name of a registered colormap, + e.g., 'viridis'. Default colormap matches the BayesFlow defaults by ranging from white to red. + title : bool, optional, default: True + A flag for adding 'Confusion Matrix' above the matrix. 
+ + Returns + ------- + fig : plt.Figure - the figure instance for optional saving + """ + + if model_names is None: + num_models = true_models.shape[-1] + model_names = [rf"$M_{{{m}}}$" for m in range(1, num_models + 1)] + + if cmap is None: + cmap = LinearSegmentedColormap.from_list("", ["white", "#8f2727"]) + + # Flatten input + true_models = ops.argmax(true_models, axis=1) + pred_models = ops.argmax(pred_models, axis=1) + + # Compute confusion matrix + cm = confusion_matrix(true_models, pred_models) + + if normalize: + # Convert to Keras tensor + cm_tensor = K.constant(cm, dtype='float32') + + # Sum along rows and keep dimensions for broadcasting + cm_sum = K.sum(cm_tensor, axis=1, keepdims=True) + + # Broadcast division for normalization + cm_normalized = cm_tensor / cm_sum + + # Since we might need to use this outside of a session, evaluate using K.eval() if necessary + cm_normalized = K.eval(cm_normalized) + + # Initialize figure + fig, ax = initialize_figure(1, 1, fig_size=fig_size) + # fig, ax = plt.subplots(1, 1, figsize=fig_size) + im = ax.imshow(cm, interpolation="nearest", cmap=cmap) + cbar = ax.figure.colorbar(im, ax=ax, shrink=0.75) + + cbar.ax.tick_params(labelsize=value_fontsize) + + ax.set(xticks=ops.arange(cm.shape[1]), yticks=ops.arange(cm.shape[0])) + ax.set_xticklabels(model_names, fontsize=tick_fontsize) + if xtick_rotation: + plt.xticks(rotation=xtick_rotation, ha="right") + ax.set_yticklabels(model_names, fontsize=tick_fontsize) + if ytick_rotation: + plt.yticks(rotation=ytick_rotation) + ax.set_xlabel("Predicted model", fontsize=label_fontsize) + ax.set_ylabel("True model", fontsize=label_fontsize) + + # Loop over data dimensions and create text annotations + fmt = ".2f" if normalize else "d" + thresh = cm.max() / 2.0 + for i in range(cm.shape[0]): + for j in range(cm.shape[1]): + ax.text( + j, + i, + format(cm[i, j], fmt), + fontsize=value_fontsize, + ha="center", + va="center", + color="white" if cm[i, j] > thresh else "black", + ) + if 
title: + ax.set_title("Confusion Matrix", fontsize=title_fontsize) + return fig \ No newline at end of file diff --git a/bayesflow/experimental/diagnostics/plot_distribution_2d.py b/bayesflow/experimental/diagnostics/plot_distribution_2d.py new file mode 100644 index 000000000..9b6855b10 --- /dev/null +++ b/bayesflow/experimental/diagnostics/plot_distribution_2d.py @@ -0,0 +1,85 @@ +import logging +import seaborn as sns +import pandas as pd + + +def plot_distribution_2d( + samples, + context: str = None, + height: float = 2.5, + color: str | tuple = "#8f2727", + alpha: float = 0.9, + n_params: int = None, + param_names: list = None, + render: bool = True, + **kwargs +): + """ + A more flexible pairplot function for multiple distributions based upon collected samples. + + Parameters + ---------- + samples : np.ndarray or tf.Tensor of shape (n_sim, n_params) + Sample draws from any dataset + context : str + The context that the sample represents + height : float, optional, default: 2.5 + The height of the pair plot + color : str, optional, default : '#8f2727' + The color of the plot + alpha : float in [0, 1], optional, default: 0.9 + The opacity of the plot + n_params : int, optional, default: None + The number of parameters in the collection of distributions + param_names : list or None, optional, default: None + The parameter names for nice plot titles. Inferred if None + render : bool, optional, default: True + The boolean that determines whether to render the plot visually. If true, then the plot will render; otherwise, the plot will go through further steps for postprocessing + **kwargs : dict, optional + Additional keyword arguments passed to the sns.PairGrid constructor + """ + # Get latent dimensions + dim = samples.shape[-1] + + # Get number of params + if n_params is None: + n_params = dim + + # Generate context if there is none + if context is None: + context = "Generic" + + # Generate titles + if param_names is None: + titles = [f"{context} Param. 
{i}" for i in range(1, dim + 1)] + else: + titles = [f"{context} {p}" for p in param_names] + + # Convert samples to pd.DataFrame + data_to_plot = pd.DataFrame(samples, columns=titles) + + # Generate plots + g = sns.PairGrid(data_to_plot, height=height, **kwargs) + + g.map_diag(sns.histplot, fill=True, color=color, alpha=alpha, kde=True) + + # Incorporate exceptions for generating KDE plots + try: + g.map_lower(sns.kdeplot, fill=True, color=color, alpha=alpha) + except Exception as e: + logging.warning("KDE failed due to the following exception:\n" + repr(e) + "\nSubstituting scatter plot.") + g.map_lower(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color) + + g.map_upper(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color) + + if render: + # Generate grids + for i in range(dim): + for j in range(dim): + g.axes[i, j].grid(alpha=0.5) + + # Return figure + g.tight_layout() + return g + else: + return g diff --git a/bayesflow/experimental/diagnostics/plot_latent_space_2d.py b/bayesflow/experimental/diagnostics/plot_latent_space_2d.py new file mode 100644 index 000000000..7e91b2a11 --- /dev/null +++ b/bayesflow/experimental/diagnostics/plot_latent_space_2d.py @@ -0,0 +1,36 @@ +from .plot_distribution_2d import plot_distribution_2d + +from keras import backend as K + + +def plot_latent_space_2d( + z_samples, + height: float = 2.5, + color="#8f2727", + **kwargs +): + """Creates pair plots for the latent space learned by the inference network. Enables + visual inspection of the latent space and whether its structure corresponds to the + one enforced by the optimization criterion. + + Parameters + ---------- + z_samples : np.ndarray or tf.Tensor of shape (n_sim, n_params) + The latent samples computed through a forward pass of the inference network. + height : float, optional, default: 2.5 + The height of the pair plot. 
+ color : str, optional, default : '#8f2727' + The color of the plot + **kwargs : dict, optional + Additional keyword arguments passed to the sns.PairGrid constructor + + Returns + ------- + f : plt.Figure - the figure instance for optional saving + """ + + # Try to convert z_samples, if eventually tf.Tensor is passed + if not isinstance(z_samples, K.tf.Tensor): + z_samples = K.constant(z_samples) + + plot_distribution_2d(z_samples, context="Latent Dim", height=height, color=color, render=True, **kwargs) diff --git a/bayesflow/experimental/diagnostics/plot_losses.py b/bayesflow/experimental/diagnostics/plot_losses.py new file mode 100644 index 000000000..0b0557c4e --- /dev/null +++ b/bayesflow/experimental/diagnostics/plot_losses.py @@ -0,0 +1,121 @@ +import seaborn as sns + +from keras import ops +from ..utils.plotutils import initialize_figure + + +def plot_losses( + train_losses, + val_losses=None, + moving_average: bool = False, + ma_window_fraction: float = 0.01, + fig_size=None, + train_color: str = "#8f2727", + val_color: str = "black", + lw_train: int = 2, + lw_val: int = 3, + grid_alpha: float = 0.5, + legend_fontsize: int = 14, + label_fontsize: int = 14, + title_fontsize: int = 16, +): + """A generic helper function to plot the losses of a series of training epochs and runs. + + Parameters + ---------- + + train_losses : pd.DataFrame + The (plottable) history as returned by a train_[...] method of a ``Trainer`` instance. + Alternatively, you can just pass a data frame of validation losses instead of train losses, + if you only want to plot the validation loss. + val_losses : pd.DataFrame or None, optional, default: None + The (plottable) validation history as returned by a train_[...] method of a ``Trainer`` instance. + If left ``None``, only train losses are plotted. Should have the same number of columns + as ``train_losses``. + moving_average : bool, optional, default: False + A flag for adding a moving average line of the train_losses. 
+ ma_window_fraction : float, optional, default: 0.01 + Window size for the moving average as a fraction of total training steps. + fig_size : tuple or None, optional, default: None + The figure size passed to the ``matplotlib`` constructor. Inferred if ``None`` + train_color : str, optional, default: '#8f2727' + The color for the train loss trajectory + val_color : str, optional, default: black + The color for the optional validation loss trajectory + lw_train : int, optional, default: 2 + The linewidth for the training loss curve + lw_val : int, optional, default: 3 + The linewidth for the validation loss curve + grid_alpha : float, optional, default: 0.5 + The opacity factor for the background gridlines + legend_fontsize : int, optional, default: 14 + The font size of the legend text + label_fontsize : int, optional, default: 14 + The font size of the x-label and y-label texts + title_fontsize : int, optional, default: 16 + The font size of the title text + + Returns + ------- + f : plt.Figure - the figure instance for optional saving + + Raises + ------ + AssertionError + If the number of columns in ``train_losses`` does not match the + number of columns in ``val_losses``. 
+ """ + + # Determine the number of rows for plot + n_row = len(train_losses.columns) + + # Initialize figure + f, axarr = initialize_figure(n_row=n_row, n_col=1, fig_size=(16, int(4 * n_row))) + + # if fig_size is None: + # fig_size = (16, int(4 * n_row)) + # f, axarr = plt.subplots(n_row, 1, figsize=fig_size) + + # Get the number of steps as an array + train_step_index = ops.arange(1, len(train_losses) + 1) + if val_losses is not None: + val_step = int(ops.floor(len(train_losses) / len(val_losses))) + val_step_index = train_step_index[(val_step - 1) :: val_step] + + # If unequal length due to some reason, attempt a fix + if val_step_index.shape[0] > val_losses.shape[0]: + val_step_index = val_step_index[: val_losses.shape[0]] + + # Loop through loss entries and populate plot + looper = [axarr] if n_row == 1 else axarr.flat + for i, ax in enumerate(looper): + # Plot train curve + ax.plot(train_step_index, train_losses.iloc[:, i], color=train_color, lw=lw_train, alpha=0.9, label="Training") + if moving_average and train_losses.columns[i] == "Loss": + moving_average_window = int(train_losses.shape[0] * ma_window_fraction) + smoothed_loss = train_losses.iloc[:, i].rolling(window=moving_average_window).mean() + ax.plot(train_step_index, smoothed_loss, color="grey", lw=lw_train, label="Training (Moving Average)") + + # Plot optional val curve + if val_losses is not None: + if i < val_losses.shape[1]: + ax.plot( + val_step_index, + val_losses.iloc[:, i], + linestyle="--", + marker="o", + color=val_color, + lw=lw_val, + label="Validation", + ) + # Schmuck + ax.set_xlabel("Training step #", fontsize=label_fontsize) + ax.set_ylabel("Value", fontsize=label_fontsize) + sns.despine(ax=ax) + ax.grid(alpha=grid_alpha) + ax.set_title(train_losses.columns[i], fontsize=title_fontsize) + # Only add legend if there is a validation curve + if val_losses is not None or moving_average: + ax.legend(fontsize=legend_fontsize) + f.tight_layout() + return f diff --git 
a/bayesflow/experimental/diagnostics/plot_mmd_hypothesis_test.py b/bayesflow/experimental/diagnostics/plot_mmd_hypothesis_test.py new file mode 100644 index 000000000..01b336094 --- /dev/null +++ b/bayesflow/experimental/diagnostics/plot_mmd_hypothesis_test.py @@ -0,0 +1,100 @@ +import matplotlib.pyplot as plt +import seaborn as sns + +from keras import ops + + +def plot_mmd_hypothesis_test( + mmd_null, + mmd_observed: float = None, + alpha_level: float = 0.05, + null_color: str | tuple = (0.16407, 0.020171, 0.577478), + observed_color: str | tuple = "red", + alpha_color: str | tuple = "orange", + truncate_v_lines_at_kde: bool = False, + x_min: float = None, + x_max: float = None, + bw_factor: float = 1.5, +): + """ + + Parameters + ---------- + mmd_null : np.ndarray + The samples from the MMD sampling distribution under the null hypothesis "the model is well-specified" + mmd_observed : float + The observed MMD value + alpha_level : float, optional, default: 0.05 + The rejection probability (type I error) + null_color : str or tuple, optional, default: (0.16407, 0.020171, 0.577478) + The color of the H0 sampling distribution + observed_color : str or tuple, optional, default: "red" + The color of the observed MMD + alpha_color : str or tuple, optional, default: "orange" + The color of the rejection area + truncate_v_lines_at_kde: bool, optional, default: False + true: cut off the vlines at the kde + false: continue kde lines across the plot + x_min : float, optional, default: None + The lower x-axis limit + x_max : float, optional, default: None + The upper x-axis limit + bw_factor : float, optional, default: 1.5 + bandwidth (aka. 
smoothing parameter) of the kernel density estimate + + Returns + ------- + f : plt.Figure - the figure instance for optional saving + """ + + def draw_v_line_to_kde(x, kde_object, color, label=None, **kwargs): + kde_x, kde_y = kde_object.lines[0].get_data() + idx = ops.argmin(ops.abs(kde_x - x)) + plt.vlines(x=x, ymin=0, ymax=kde_y[idx], color=color, linewidth=3, label=label, **kwargs) + + def fill_area_under_kde(kde_object, x_start, x_end=None, **kwargs): + kde_x, kde_y = kde_object.lines[0].get_data() + if x_end is not None: + plt.fill_between(kde_x, kde_y, where=(kde_x >= x_start) & (kde_x <= x_end), interpolate=True, **kwargs) + else: + plt.fill_between(kde_x, kde_y, where=(kde_x >= x_start), interpolate=True, **kwargs) + + f = plt.figure(figsize=(8, 4)) + + kde = sns.kdeplot(mmd_null, fill=False, linewidth=0, bw_adjust=bw_factor) + sns.kdeplot(mmd_null, fill=True, alpha=0.12, color=null_color, bw_adjust=bw_factor) + + if truncate_v_lines_at_kde: + draw_v_line_to_kde(x=mmd_observed, kde_object=kde, color=observed_color, label=r"Observed data") + else: + plt.vlines( + x=mmd_observed, + ymin=0, + ymax=plt.gca().get_ylim()[1], + color=observed_color, + linewidth=3, + label=r"Observed data", + ) + + mmd_critical = ops.quantile(mmd_null, 1 - alpha_level) + fill_area_under_kde( + kde, mmd_critical, color=alpha_color, alpha=0.5, label=rf"{int(alpha_level*100)}% rejection area" + ) + + if truncate_v_lines_at_kde: + draw_v_line_to_kde(x=mmd_critical, kde_object=kde, color=alpha_color) + else: + plt.vlines(x=mmd_critical, color=alpha_color, linewidth=3, ymin=0, ymax=plt.gca().get_ylim()[1]) + + sns.kdeplot(mmd_null, fill=False, linewidth=3, color=null_color, label=r"$H_0$", bw_adjust=bw_factor) + + plt.xlabel(r"MMD", fontsize=20) + plt.ylabel("") + plt.yticks([]) + plt.xlim(x_min, x_max) + plt.tick_params(axis="both", which="major", labelsize=16) + + plt.legend(fontsize=20) + sns.despine() + + return f diff --git a/bayesflow/experimental/diagnostics/plot_posterior_2d.py 
b/bayesflow/experimental/diagnostics/plot_posterior_2d.py new file mode 100644 index 000000000..c546627c1 --- /dev/null +++ b/bayesflow/experimental/diagnostics/plot_posterior_2d.py @@ -0,0 +1,134 @@ +import pandas as pd +import seaborn as sns + +from matplotlib.lines import Line2D +from .plot_distribution_2d import plot_distribution_2d + + +def plot_posterior_2d( + posterior_draws, + prior=None, + prior_draws=None, + param_names: list = None, + height: int = 3, + label_fontsize: int = 14, + legend_fontsize: int = 16, + tick_fontsize: int = 12, + post_color: str | tuple = "#8f2727", + prior_color: str | tuple = "gray", + post_alpha: float = 0.9, + prior_alpha: float = 0.7, + **kwargs +): + """Generates a bivariate pairplot given posterior draws and optional prior or prior draws. + + posterior_draws : np.ndarray of shape (n_post_draws, n_params) + The posterior draws obtained for a SINGLE observed data set. + prior : bayesflow.forward_inference.Prior instance or None, optional, default: None + The optional prior object having an input-output signature as given by bayesflow.forward_inference.Prior + prior_draws : np.ndarray of shape (n_prior_draws, n_params) or None, optional, default: None + The optional prior draws obtained from the prior. If both prior and prior_draws are provided, prior_draws + will be used. + param_names : list or None, optional, default: None + The parameter names for nice plot titles. 
Inferred if None + height : float, optional, default: 3 + The height of the pairplot + label_fontsize : int, optional, default: 14 + The font size of the x and y-label texts (parameter names) + legend_fontsize : int, optional, default: 16 + The font size of the legend text + tick_fontsize : int, optional, default: 12 + The font size of the axis ticklabels + post_color : str, optional, default: '#8f2727' + The color for the posterior histograms and KDEs + prior_color : str, optional, default: gray + The color for the optional prior histograms and KDEs + post_alpha : float in [0, 1], optional, default: 0.9 + The opacity of the posterior plots + prior_alpha : float in [0, 1], optional, default: 0.7 + The opacity of the prior plots + + Returns + ------- + f : plt.Figure - the figure instance for optional saving + + Raises + ------ + AssertionError + If the shape of posterior_draws is not 2-dimensional. + """ + + # Ensure correct shape + assert ( + len(posterior_draws.shape) + ) == 2, "Shape of `posterior_samples` for a single data set should be 2 dimensional!" 
+ + # Plot posterior first + g = plot_distribution_2d( + posterior_draws, + context="\\theta", + param_names=param_names, + render=False, + **kwargs + ) + + # Obtain n_draws and n_params + n_draws, n_params = posterior_draws.shape + + # If prior object is given and no draws, obtain draws + if prior is not None and prior_draws is None: + draws = prior(n_draws) + if type(draws) is dict: + prior_draws = draws["prior_draws"] + else: + prior_draws = draws + + # Attempt to determine parameter names + if param_names is None: + if hasattr(prior, "param_names"): + if prior.param_names is not None: + param_names = prior.param_names + else: + param_names = [f"$\\theta_{{{i}}}$" for i in range(1, n_params + 1)] + else: + param_names = [f"$\\theta_{{{i}}}$" for i in range(1, n_params + 1)] + + # Add prior, if given + if prior_draws is not None: + prior_draws_df = pd.DataFrame(prior_draws, columns=param_names) + g.data = prior_draws_df + g.map_diag(sns.histplot, fill=True, color=prior_color, alpha=prior_alpha, kde=True, zorder=-1) + g.map_lower(sns.kdeplot, fill=True, color=prior_color, alpha=prior_alpha, zorder=-1) + + # Add legend, if prior also given + if prior_draws is not None or prior is not None: + handles = [ + Line2D(xdata=[], ydata=[], color=post_color, lw=3, alpha=post_alpha), + Line2D(xdata=[], ydata=[], color=prior_color, lw=3, alpha=prior_alpha), + ] + g.legend(handles, ["Posterior", "Prior"], fontsize=legend_fontsize, loc="center right") + + n_row, n_col = g.axes.shape + + for i in range(n_row): + # Remove upper axis + for j in range(i+1, n_col): + g.axes[i, j].axis("off") + + # Modify tick sizes + for j in range(i + 1): + g.axes[i, j].tick_params(axis="both", which="major", labelsize=tick_fontsize) + g.axes[i, j].tick_params(axis="both", which="minor", labelsize=tick_fontsize) + + # Add nice labels + for i, param_name in enumerate(param_names): + g.axes[i, 0].set_ylabel(param_name, fontsize=label_fontsize) + g.axes[len(param_names) - 1, i].set_xlabel(param_name, 
fontsize=label_fontsize) + + # Add grids + for i in range(n_params): + for j in range(n_params): + g.axes[i, j].grid(alpha=0.5) + + g.tight_layout() + return g diff --git a/bayesflow/experimental/diagnostics/plot_prior_2d.py b/bayesflow/experimental/diagnostics/plot_prior_2d.py new file mode 100644 index 000000000..5280486e4 --- /dev/null +++ b/bayesflow/experimental/diagnostics/plot_prior_2d.py @@ -0,0 +1,49 @@ +from .plot_distribution_2d import plot_distribution_2d + + +def plot_prior_2d( + prior, + param_names: list = None, + n_samples: int = 2000, + height: float = 2.5, + color: str | tuple = "#8f2727", + **kwargs +): + """Creates pair-plots for a given joint prior. + + Parameters + ---------- + prior : callable + The prior object which takes a single integer argument and generates random draws. + param_names : list of str or None, optional, default: None + An optional list of strings which name the parameters (used for the plot titles). + n_samples : int, optional, default: 2000 + The number of random draws from the joint prior + height : float, optional, default: 2.5 + The height of the pair plot + color : str, optional, default : '#8f2727' + The color of the plot + **kwargs : dict, optional + Additional keyword arguments passed to the sns.PairGrid constructor + + Returns + ------- + f : plt.Figure - the figure instance for optional saving + """ + + # Generate prior draws + prior_samples = prior(n_samples) + + # Handle dict type + if type(prior_samples) is dict: + prior_samples = prior_samples["prior_draws"] + + plot_distribution_2d( + prior_samples, + context="Prior", + height=height, + color=color, + param_names=param_names, + render=True, + **kwargs + ) diff --git a/bayesflow/experimental/diagnostics/plot_recovery.py b/bayesflow/experimental/diagnostics/plot_recovery.py new file mode 100644 index 000000000..d8a4f83b7 --- /dev/null +++ b/bayesflow/experimental/diagnostics/plot_recovery.py @@ -0,0 +1,164 @@ +import numpy as np +from scipy.stats import median_abs_deviation +from sklearn.metrics import 
def plot_recovery(
    post_samples,
    prior_samples,
    point_agg=np.median,
    uncertainty_agg=median_abs_deviation,
    param_names: list = None,
    fig_size: tuple = None,
    label_fontsize: int = 16,
    title_fontsize: int = 18,
    metric_fontsize: int = 16,
    tick_fontsize: int = 12,
    add_corr: bool = True,
    add_r2: bool = True,
    color: str | tuple = "#8f2727",
    n_col: int = None,
    n_row: int = None,
    xlabel: str = "Ground truth",
    ylabel: str = "Estimated",
    **kwargs,
):
    """Creates and plots publication-ready recovery plot with true vs. point estimate + uncertainty.
    The point estimate can be controlled with the ``point_agg`` argument, and the uncertainty estimate
    can be controlled with the ``uncertainty_agg`` argument.

    This plot yields similar information as the "posterior z-score", but allows for generic
    point and uncertainty estimates:

    https://betanalpha.github.io/assets/case_studies/principled_bayesian_workflow.html

    Important: Posterior aggregates play no special role in Bayesian inference and should only
    be used heuristically. For instance, in the case of multi-modal posteriors, common point
    estimates, such as mean, (geometric) median, or maximum a posteriori (MAP) mean nothing.

    Parameters
    ----------
    post_samples      : np.ndarray of shape (n_data_sets, n_post_draws, n_params)
        The posterior draws obtained from n_data_sets
    prior_samples     : np.ndarray of shape (n_data_sets, n_params)
        The prior draws (true parameters) obtained for generating the n_data_sets
    point_agg         : callable, optional, default: ``np.median``
        The function to apply to the posterior draws to get a point estimate for each marginal.
        It is called as ``point_agg(post_samples, axis=1)``.
    uncertainty_agg   : callable or None, optional, default: scipy.stats.median_abs_deviation
        The function to apply to the posterior draws to get an uncertainty estimate.
        It is called as ``uncertainty_agg(post_samples, axis=1)``.
        If ``None`` provided, a simple scatter using only ``point_agg`` will be plotted.
    param_names       : list or None, optional, default: None
        The parameter names for nice plot titles. If None, the names returned by ``preprocess`` are used.
    fig_size          : tuple or None, optional, default : None
        The figure size passed to ``preprocess``. Inferred if None.
    label_fontsize    : int, optional, default: 16
        The font size of the axis label texts
    title_fontsize    : int, optional, default: 18
        The font size of the title text
    metric_fontsize   : int, optional, default: 16
        The font size of the goodness-of-fit metric (if provided)
    tick_fontsize     : int, optional, default: 12
        The font size of the axis tick labels
    add_corr          : bool, optional, default: True
        A flag for adding correlation between true and estimates to the plot
    add_r2            : bool, optional, default: True
        A flag for adding R^2 between true and estimates to the plot
    color             : str, optional, default: '#8f2727'
        The color for the true vs. estimated scatter points and error bars
    n_row             : int, optional, default: None
        The number of rows for the subplots. Dynamically determined if None.
    n_col             : int, optional, default: None
        The number of columns for the subplots. Dynamically determined if None.
    xlabel            : str, optional, default: 'Ground truth'
        The label on the x-axis of the plot
    ylabel            : str, optional, default: 'Estimated'
        The label on the y-axis of the plot
    **kwargs          : optional
        Additional keyword arguments passed to ax.errorbar or ax.scatter.
        Example: `rasterized=True` to reduce PDF file size with many dots

    Returns
    -------
    f : plt.Figure - the figure instance for optional saving

    Raises
    ------
    ShapeError
        If there is a deviation from the expected shapes of ``post_samples`` and ``prior_samples``
        (not raised in this function itself; presumably raised by ``preprocess`` -- confirm).
    """

    # Build the figure; the helper also reports the grid layout and inferred parameter names
    f, axarr, axarr_it, n_row, n_col, n_params, inferred_names = preprocess(
        post_samples, prior_samples, fig_size=fig_size
    )
    # NOTE(review): the caller's ``n_row`` / ``n_col`` are replaced by the values returned above and are
    # never forwarded to ``preprocess`` -- confirm whether the helper can accept them.

    # Keep caller-supplied titles; previously they were overwritten by the inferred names
    if param_names is None:
        param_names = inferred_names

    # Compute point estimates and (optionally) uncertainties per data set and parameter
    est = point_agg(post_samples, axis=1)
    u = uncertainty_agg(post_samples, axis=1) if uncertainty_agg is not None else None

    # Loop and plot
    for i, ax in enumerate(axarr_it):
        # The axes grid may hold more cells than there are parameters
        if i >= n_params:
            break

        truth = prior_samples[:, i]
        estimate = est[:, i]

        # Add scatter and error bars
        if u is not None:
            _ = ax.errorbar(truth, estimate, yerr=u[:, i], fmt="o", alpha=0.5, color=color, **kwargs)
        else:
            _ = ax.scatter(truth, estimate, alpha=0.5, color=color, **kwargs)

        # Make plots quadratic to avoid visual illusions
        lower = min(truth.min(), estimate.min())
        upper = max(truth.max(), estimate.max())
        eps = (upper - lower) * 0.1
        ax.set_xlim([lower - eps, upper + eps])
        ax.set_ylim([lower - eps, upper + eps])

        # Identity line across the (now identical) x and y ranges
        ax.plot(
            [ax.get_xlim()[0], ax.get_xlim()[1]],
            [ax.get_ylim()[0], ax.get_ylim()[1]],
            color="black",
            alpha=0.9,
            linestyle="dashed",
        )

        # Add optional metrics and title
        if add_r2:
            r2 = r2_score(truth, estimate)
            ax.text(
                0.1,
                0.9,
                "$R^2$ = {:.3f}".format(r2),
                horizontalalignment="left",
                verticalalignment="center",
                transform=ax.transAxes,
                size=metric_fontsize,
            )
        if add_corr:
            corr = np.corrcoef(truth, estimate)[0, 1]
            ax.text(
                0.1,
                0.8,
                "$r$ = {:.3f}".format(corr),
                horizontalalignment="left",
                verticalalignment="center",
                transform=ax.transAxes,
                size=metric_fontsize,
            )
        ax.set_title(param_names[i], fontsize=title_fontsize)

        # Prettify
        sns.despine(ax=ax)
        ax.grid(alpha=0.5)
        ax.tick_params(axis="both", which="major", labelsize=tick_fontsize)
        ax.tick_params(axis="both", which="minor", labelsize=tick_fontsize)

    postprocess(axarr, axarr_it, n_row, n_col, n_params, xlabel, ylabel, label_fontsize)

    f.tight_layout()
    return f
def plot_sbc_ecdf(
    post_samples,
    prior_samples,
    difference: bool = False,
    stacked: bool = False,
    fig_size: tuple = None,
    param_names: list = None,
    label_fontsize: int = 16,
    legend_fontsize: int = 14,
    title_fontsize: int = 18,
    tick_fontsize: int = 12,
    rank_ecdf_color: str | tuple = "#a34f4f",
    fill_color: str | tuple = "grey",
    n_row: int = None,
    n_col: int = None,
    **kwargs,
):
    """Creates the empirical CDFs for each marginal rank distribution and plots it against
    a uniform ECDF. ECDF simultaneous bands are drawn using simulations from the uniform,
    as proposed by [1].

    For models with many parameters, use `stacked=True` to obtain an idea of the overall calibration
    of a posterior approximator.

    [1] Säilynoja, T., Bürkner, P. C., & Vehtari, A. (2022). Graphical test for discrete uniformity and
    its applications in goodness-of-fit evaluation and multiple sample comparison. Statistics and Computing,
    32(2), 1-21. https://arxiv.org/abs/2103.10522

    Parameters
    ----------
    post_samples      : np.ndarray of shape (n_data_sets, n_post_draws, n_params)
        The posterior draws obtained from n_data_sets
    prior_samples     : np.ndarray of shape (n_data_sets, n_params)
        The prior draws obtained for generating n_data_sets
    difference        : bool, optional, default: False
        If `True`, plots the ECDF difference. Enables a more dynamic visualization range.
    stacked           : bool, optional, default: False
        If `True`, all ECDFs will be plotted on the same plot. If `False`, each ECDF will
        have its own subplot, similar to the behavior of `plot_sbc_histograms`.
    param_names       : list or None, optional, default: None
        The parameter names for nice plot titles. Inferred if None. Only relevant if `stacked=False`.
    fig_size          : tuple or None, optional, default: None
        The figure size passed to ``preprocess``. Inferred if None.
    label_fontsize    : int, optional, default: 16
        The font size of the x-label and y-label texts
    legend_fontsize   : int, optional, default: 14
        The font size of the legend text
    title_fontsize    : int, optional, default: 18
        The font size of the title text. Only relevant if `stacked=False`
    tick_fontsize     : int, optional, default: 12
        The font size of the axis ticklabels
    rank_ecdf_color   : str, optional, default: '#a34f4f'
        The color to use for the rank ECDFs
    fill_color        : str, optional, default: 'grey'
        The color of the fill arguments.
    n_row             : int, optional, default: None
        The number of rows for the subplots. Dynamically determined if None.
    n_col             : int, optional, default: None
        The number of columns for the subplots. Dynamically determined if None.
    **kwargs          : dict, optional, default: {}
        Keyword arguments can be passed to control the behavior of ECDF simultaneous band computation
        through the ``ecdf_bands_kwargs`` dictionary. See `simultaneous_ecdf_bands` for keyword arguments

    Returns
    -------
    f : plt.Figure - the figure instance for optional saving

    Raises
    ------
    ShapeError
        If there is a deviation from the expected shapes of `post_samples` and `prior_samples`
        (not raised in this function itself; presumably raised by ``preprocess`` -- confirm).
    """

    f, ax, ax_it, n_row, n_col, n_params, inferred_names = preprocess(
        post_samples, prior_samples, collapse=False, fig_size=fig_size
    )
    # NOTE(review): the caller's ``n_row`` / ``n_col`` are replaced by the values returned above and
    # ``stacked`` is not communicated to ``preprocess`` -- confirm the helper's contract.

    # Keep caller-supplied titles; previously they were overwritten by the inferred names
    if param_names is None:
        param_names = inferred_names

    # Compute fractional ranks via broadcasting with the backend-agnostic ``keras.ops`` API
    # (the legacy ``keras.backend`` tensor functions used before do not ship alongside ``keras.ops``).
    post_samples = ops.convert_to_tensor(post_samples)
    prior_samples = ops.convert_to_tensor(prior_samples)
    post_shape = ops.shape(post_samples)
    num_data_sets, num_draws = int(post_shape[0]), int(post_shape[1])

    # (n_data_sets, n_draws, n_params) < (n_data_sets, 1, n_params): fraction of draws below the true value
    below_truth = ops.less(post_samples, ops.expand_dims(prior_samples, axis=1))
    ranks = ops.sum(ops.cast(below_truth, "float32"), axis=1) / float(num_draws)

    # ECDF heights 1/N, 2/N, ..., 1 are identical for every parameter, so build them once.
    # A float dtype is requested so the division is valid on every backend.
    ecdf_steps = ops.convert_to_numpy(ops.arange(1, num_data_sets + 1, dtype="float32") / float(num_data_sets))

    # Plot individual ecdf of parameters
    for j in range(int(ops.shape(ranks)[-1])):
        # Convert to NumPy before handing the values to matplotlib
        xx = ops.convert_to_numpy(ops.sort(ranks[:, j]))

        # Difference, if specified (no in-place update: ``ecdf_steps`` is shared across parameters)
        yy = ecdf_steps - xx if difference else ecdf_steps

        if stacked:
            if j == 0:
                ax.plot(xx, yy, color=rank_ecdf_color, alpha=0.95, label="Rank ECDFs")
            else:
                ax.plot(xx, yy, color=rank_ecdf_color, alpha=0.95)
        else:
            ax.flat[j].plot(xx, yy, color=rank_ecdf_color, alpha=0.95, label="Rank ECDF")

    # Compute uniform ECDF and bands
    alpha, z, L, H = simultaneous_ecdf_bands(num_data_sets, **kwargs.pop("ecdf_bands_kwargs", {}))

    # Difference, if specified
    if difference:
        L = L - z
        H = H - z
        ylab = "ECDF difference"
    else:
        ylab = "ECDF"

    # Add simultaneous bounds
    if stacked:
        titles = [None]
        axes = [ax]
    else:
        axes = ax.flat
        if param_names is None:
            titles = [f"$\\theta_{{{i}}}$" for i in range(1, n_params + 1)]
        else:
            titles = param_names

    for _ax, title in zip(axes, titles):
        _ax.fill_between(z, L, H, color=fill_color, alpha=0.2, label=rf"{int((1-alpha) * 100)}$\%$ Confidence Bands")

        # Prettify plot
        sns.despine(ax=_ax)
        _ax.grid(alpha=0.35)
        _ax.legend(fontsize=legend_fontsize)
        _ax.set_title(title, fontsize=title_fontsize)
        _ax.tick_params(axis="both", which="major", labelsize=tick_fontsize)
        _ax.tick_params(axis="both", which="minor", labelsize=tick_fontsize)

    # Only add x-labels to the bottom row
    if stacked:
        bottom_row = [ax]
    else:
        bottom_row = ax if n_row == 1 else ax[-1, :]
    for _ax in bottom_row:
        _ax.set_xlabel("Fractional rank statistic", fontsize=label_fontsize)

    # Only add y-labels to the left-most column
    if n_row == 1:  # if there is only one row, the ax array is 1D
        axes[0].set_ylabel(ylab, fontsize=label_fontsize)
    else:  # if there is more than one row, the ax array is 2D
        for _ax in ax[:, 0]:
            _ax.set_ylabel(ylab, fontsize=label_fontsize)

    # Remove unused axes entirely
    # NOTE(review): ``plot_sbc_histograms`` calls this helper as ``remove_unused_axes(axarr, n_params)``;
    # confirm which call signature the helper expects.
    remove_unused_axes(ax)

    f.tight_layout()
    return f
def plot_sbc_histograms(
    post_samples,
    prior_samples,
    param_names: list = None,
    fig_size: tuple = None,
    num_bins: int = None,
    binomial_interval: float = 0.99,
    label_fontsize: int = 16,
    title_fontsize: int = 18,
    tick_fontsize: int = 12,
    hist_color: str | tuple = "#a34f4f",
    n_row: int = None,
    n_col: int = None,
):
    """Creates and plots publication-ready histograms of rank statistics for simulation-based calibration
    (SBC) checks according to [1].

    Any deviation from uniformity indicates miscalibration and thus poor convergence
    of the networks or poor combination between generative model / networks.

    [1] Talts, S., Betancourt, M., Simpson, D., Vehtari, A., & Gelman, A. (2018).
    Validating Bayesian inference algorithms with simulation-based calibration.
    arXiv preprint arXiv:1804.06788.

    Parameters
    ----------
    post_samples      : np.ndarray of shape (n_data_sets, n_post_draws, n_params)
        The posterior draws obtained from n_data_sets
    prior_samples     : np.ndarray of shape (n_data_sets, n_params)
        The prior draws obtained for generating n_data_sets
    param_names       : list or None, optional, default: None
        The parameter names for nice plot titles. If None, the names returned by ``preprocess`` are used.
    fig_size          : tuple or None, optional, default : None
        The figure size passed to ``preprocess``. Inferred if None
    num_bins          : int or None, optional, default: None
        The number of bins to use for each marginal histogram. If None, it is set to half the ratio of
        data sets to posterior draws, falling back to 5 when that heuristic yields fewer than two bins.
    binomial_interval : float in (0, 1), optional, default: 0.99
        The width of the confidence interval for the binomial distribution
    label_fontsize    : int, optional, default: 16
        The font size of the x-label text
    title_fontsize    : int, optional, default: 18
        The font size of the title text
    tick_fontsize     : int, optional, default: 12
        The font size of the axis ticklabels
    hist_color        : str, optional, default '#a34f4f'
        The color to use for the histogram body
    n_row             : int, optional, default: None
        The number of rows for the subplots. Dynamically determined if None.
    n_col             : int, optional, default: None
        The number of columns for the subplots. Dynamically determined if None.

    Returns
    -------
    f : plt.Figure - the figure instance for optional saving

    Raises
    ------
    ShapeError
        If there is a deviation from the expected shapes of `post_samples` and `prior_samples`
        (not raised in this function itself; presumably raised by ``preprocess`` -- confirm).
    """

    f, axarr, ax, n_row, n_col, n_params, inferred_names = preprocess(post_samples, prior_samples, fig_size=fig_size)
    # NOTE(review): the caller's ``n_row`` / ``n_col`` are replaced by the values returned above and are
    # never forwarded to ``preprocess`` -- confirm whether the helper can accept them.

    # Keep caller-supplied titles; previously they were overwritten by the inferred names
    if param_names is None:
        param_names = inferred_names

    # Determine the ratio of simulations to posterior draws
    n_sim, n_draws, _ = post_samples.shape
    ratio = int(n_sim / n_draws)

    # Warn if the N/B ratio recommended by Talts et al. (2018) is not reached. A module-level logger is
    # used so that the application's root logger configuration is left untouched.
    if ratio < 20:
        logging.getLogger(__name__).warning(
            "The ratio of simulations / posterior draws should be > 20 "
            "for reliable variance reduction, but your ratio is %s. "
            "Confidence intervals might be unreliable!",
            ratio,
        )

    # Set n_bins automatically, if nothing provided
    if num_bins is None:
        num_bins = int(ratio / 2)
        # The heuristic yields 0 or 1 for small ratios; 0 would make ``1 / num_bins`` below fail and a
        # single bin is uninformative, so fall back to a fixed number of bins.
        if num_bins < 2:
            num_bins = 5

    # Compute ranks via broadcasting with the backend-agnostic ``keras.ops`` API
    # (the legacy ``keras.backend`` tensor functions used before do not ship alongside ``keras.ops``).
    post_tensor = ops.convert_to_tensor(post_samples)
    prior_tensor = ops.convert_to_tensor(prior_samples)

    # (n_data_sets, n_draws, n_params) < (n_data_sets, 1, n_params): number of draws below the true value
    below_truth = ops.less(post_tensor, ops.expand_dims(prior_tensor, axis=1))
    # Convert to NumPy before handing the values to seaborn
    ranks = ops.convert_to_numpy(ops.sum(ops.cast(below_truth, "float32"), axis=1))

    # Compute confidence interval and mean
    N = int(ops.shape(prior_tensor)[0])
    # uniform distribution expected -> for all bins: equal probability
    # p = 1 / num_bins that a rank lands in that bin
    endpoints = binom.interval(binomial_interval, N, 1 / num_bins)
    mean = N / num_bins  # corresponds to binom.mean(N, 1 / num_bins)

    # Plot marginal histograms in a loop
    for j in range(n_params):
        ax[j].axhspan(endpoints[0], endpoints[1], facecolor="gray", alpha=0.3)
        ax[j].axhline(mean, color="gray", zorder=0, alpha=0.9)
        sns.histplot(ranks[:, j], kde=False, ax=ax[j], color=hist_color, bins=num_bins, alpha=0.95)
        ax[j].set_title(param_names[j], fontsize=title_fontsize)
        ax[j].spines["right"].set_visible(False)
        ax[j].spines["top"].set_visible(False)
        ax[j].get_yaxis().set_ticks([])
        ax[j].set_ylabel("")
        ax[j].tick_params(axis="both", which="major", labelsize=tick_fontsize)
        ax[j].tick_params(axis="both", which="minor", labelsize=tick_fontsize)

    # Only add x-labels to the bottom row
    # NOTE(review): for ``n_col == 1`` this picks ``axarr[0]``; whether that is an iterable of axes
    # depends on the shape ``preprocess`` returns -- confirm.
    bottom_row = axarr if n_row == 1 else axarr[0] if n_col == 1 else axarr[n_row - 1, :]
    for _ax in bottom_row:
        _ax.set_xlabel("Rank statistic", fontsize=label_fontsize)

    # Remove unused axes entirely
    remove_unused_axes(axarr, n_params)

    f.tight_layout()
    return f
def plot_z_score_contraction(
    post_samples,
    prior_samples,
    param_names: list = None,
    fig_size: tuple = None,
    label_fontsize: int = 16,
    title_fontsize: int = 18,
    tick_fontsize: int = 12,
    color: str | tuple = "#8f2727",
    x_label: str = "Posterior contraction",
    y_label: str = "Posterior z-score",
    n_col: int = None,
    n_row: int = None,
):
    """Implements a graphical check for global model sensitivity by plotting the posterior
    z-score over the posterior contraction for each set of posterior samples in ``post_samples``
    according to [1].

    - The definition of the posterior z-score is:

    post_z_score = (posterior_mean - true_parameters) / posterior_std

    And the score is adequate if it centers around zero and spreads roughly in the interval [-3, 3]

    - The definition of posterior contraction is:

    post_contraction = 1 - (posterior_variance / prior_variance)

    In other words, the posterior contraction is a proxy for the reduction in uncertainty gained by
    replacing the prior with the posterior. The ideal posterior contraction tends to 1.
    Contraction near zero indicates that the posterior variance is almost identical to
    the prior variance for the particular marginal parameter distribution.

    Note: Means and variances will be estimated via their sample-based estimators.

    [1] Schad, D. J., Betancourt, M., & Vasishth, S. (2021).
    Toward a principled Bayesian workflow in cognitive science.
    Psychological methods, 26(1), 103.

    Paper also available at https://arxiv.org/abs/1904.12765

    Parameters
    ----------
    post_samples      : np.ndarray of shape (n_data_sets, n_post_draws, n_params)
        The posterior draws obtained from n_data_sets
    prior_samples     : np.ndarray of shape (n_data_sets, n_params)
        The prior draws (true parameters) obtained for generating the n_data_sets
    param_names       : list or None, optional, default: None
        The parameter names for nice plot titles. If None, the names returned by ``preprocess`` are used.
    fig_size          : tuple or None, optional, default : None
        The figure size passed to ``preprocess``. Inferred if None.
    label_fontsize    : int, optional, default: 16
        The font size of the axis label texts
    title_fontsize    : int, optional, default: 18
        The font size of the title text
    tick_fontsize     : int, optional, default: 12
        The font size of the axis ticklabels
    color             : str, optional, default: '#8f2727'
        The color for the scatter points
    x_label           : str, optional, default: Posterior contraction
        The label for the x-axis
    y_label           : str, optional, default: Posterior z-score
        The label for the y-axis
    n_row             : int, optional, default: None
        The number of rows for the subplots. Dynamically determined if None.
    n_col             : int, optional, default: None
        The number of columns for the subplots. Dynamically determined if None.

    Returns
    -------
    f : plt.Figure - the figure instance for optional saving

    Raises
    ------
    ShapeError
        If there is a deviation from the expected shapes of ``post_samples`` and ``prior_samples``
        (not raised in this function itself; presumably raised by ``preprocess`` -- confirm).
    """

    f, axarr, axarr_it, n_row, n_col, n_params, inferred_names = preprocess(
        post_samples, prior_samples, fig_size=fig_size
    )
    # NOTE(review): the caller's ``n_row`` / ``n_col`` are replaced by the values returned above and are
    # never forwarded to ``preprocess`` -- confirm whether the helper can accept them.

    # Keep caller-supplied titles; previously they were overwritten by the inferred names
    if param_names is None:
        param_names = inferred_names

    # Estimate posterior means, stds and variances over the draw axis (unbiased estimators, ddof=1)
    post_means = post_samples.mean(axis=1)
    post_stds = post_samples.std(axis=1, ddof=1)
    post_vars = post_samples.var(axis=1, ddof=1)

    # Estimate prior variance across data sets; keepdims lets it broadcast against post_vars
    prior_vars = prior_samples.var(axis=0, keepdims=True, ddof=1)

    # Compute contraction
    post_cont = 1 - (post_vars / prior_vars)

    # Compute posterior z score
    z_score = (post_means - prior_samples) / post_stds

    # Loop and plot
    for i, ax in enumerate(axarr_it):
        # The axes grid may hold more cells than there are parameters
        if i >= n_params:
            break

        ax.scatter(post_cont[:, i], z_score[:, i], color=color, alpha=0.5)
        ax.set_title(param_names[i], fontsize=title_fontsize)
        sns.despine(ax=ax)
        ax.grid(alpha=0.5)
        ax.tick_params(axis="both", which="major", labelsize=tick_fontsize)
        ax.tick_params(axis="both", which="minor", labelsize=tick_fontsize)
        # Fixed x-range with a small margin around the ideal contraction interval [0, 1]
        ax.set_xlim([-0.05, 1.05])

    postprocess(axarr, axarr_it, n_row, n_col, n_params, x_label, y_label, label_fontsize)

    f.tight_layout()
    return f
.../diagnostics/plot_confusion_matrix.py | 1 + .../diagnostics/plot_distribution_2d.py | 6 ++++-- .../diagnostics/plot_latent_space_2d.py | 1 + .../experimental/diagnostics/plot_losses.py | 1 + .../diagnostics/plot_mmd_hypothesis_test.py | 1 + .../diagnostics/plot_posterior_2d.py | 1 + .../experimental/diagnostics/plot_prior_2d.py | 1 + .../experimental/diagnostics/plot_recovery.py | 1 + .../experimental/diagnostics/plot_sbc_ecdf.py | 1 + .../diagnostics/plot_sbc_histograms.py | 1 + .../diagnostics/plot_z_score_contraction.py | 1 + tests/test_diagnostics/__init__.py | 0 tests/test_diagnostics/test_diagnostics.py | 16 ++++++++++++++++ 14 files changed, 31 insertions(+), 2 deletions(-) create mode 100644 tests/test_diagnostics/__init__.py create mode 100644 tests/test_diagnostics/test_diagnostics.py diff --git a/bayesflow/experimental/diagnostics/plot_calibration_curves.py b/bayesflow/experimental/diagnostics/plot_calibration_curves.py index 7efe79ca2..21bdb36a4 100644 --- a/bayesflow/experimental/diagnostics/plot_calibration_curves.py +++ b/bayesflow/experimental/diagnostics/plot_calibration_curves.py @@ -1,3 +1,4 @@ + from ..utils.plotutils import preprocess, postprocess from ..utils.computils import expected_calibration_error from keras import ops diff --git a/bayesflow/experimental/diagnostics/plot_confusion_matrix.py b/bayesflow/experimental/diagnostics/plot_confusion_matrix.py index ac522c778..8345b7276 100644 --- a/bayesflow/experimental/diagnostics/plot_confusion_matrix.py +++ b/bayesflow/experimental/diagnostics/plot_confusion_matrix.py @@ -1,3 +1,4 @@ + import matplotlib.pyplot as plt from keras import ops diff --git a/bayesflow/experimental/diagnostics/plot_distribution_2d.py b/bayesflow/experimental/diagnostics/plot_distribution_2d.py index 9b6855b10..28da8faba 100644 --- a/bayesflow/experimental/diagnostics/plot_distribution_2d.py +++ b/bayesflow/experimental/diagnostics/plot_distribution_2d.py @@ -1,3 +1,4 @@ + import logging import seaborn as sns 
import pandas as pd @@ -10,7 +11,7 @@ def plot_distribution_2d( color: str | tuple = "#8f2727", alpha: float = 0.9, n_params: int = None, - param_names: list = None, + param_names: list[str] = None, render: bool = True, **kwargs ): @@ -34,7 +35,8 @@ def plot_distribution_2d( param_names : list or None, optional, default: None The parameter names for nice plot titles. Inferred if None render : bool, optional, default: True - The boolean that determines whether to render the plot visually. If true, then the plot will render; otherwise, the plot will go through further steps for postprocessing + The boolean that determines whether to render the plot visually. If true, then the plot will render; + otherwise, the plot will go through further steps for postprocessing **kwargs : dict, optional Additional keyword arguments passed to the sns.PairGrid constructor """ diff --git a/bayesflow/experimental/diagnostics/plot_latent_space_2d.py b/bayesflow/experimental/diagnostics/plot_latent_space_2d.py index 7e91b2a11..d7426f7cd 100644 --- a/bayesflow/experimental/diagnostics/plot_latent_space_2d.py +++ b/bayesflow/experimental/diagnostics/plot_latent_space_2d.py @@ -1,3 +1,4 @@ + from .plot_distribution_2d import plot_distribution_2d from keras import backend as K diff --git a/bayesflow/experimental/diagnostics/plot_losses.py b/bayesflow/experimental/diagnostics/plot_losses.py index 0b0557c4e..45ff284da 100644 --- a/bayesflow/experimental/diagnostics/plot_losses.py +++ b/bayesflow/experimental/diagnostics/plot_losses.py @@ -1,3 +1,4 @@ + import seaborn as sns from keras import ops diff --git a/bayesflow/experimental/diagnostics/plot_mmd_hypothesis_test.py b/bayesflow/experimental/diagnostics/plot_mmd_hypothesis_test.py index 01b336094..1a935d60c 100644 --- a/bayesflow/experimental/diagnostics/plot_mmd_hypothesis_test.py +++ b/bayesflow/experimental/diagnostics/plot_mmd_hypothesis_test.py @@ -1,3 +1,4 @@ + import matplotlib.pyplot as plt import seaborn as sns diff --git 
a/bayesflow/experimental/diagnostics/plot_posterior_2d.py b/bayesflow/experimental/diagnostics/plot_posterior_2d.py index c546627c1..d2f96da67 100644 --- a/bayesflow/experimental/diagnostics/plot_posterior_2d.py +++ b/bayesflow/experimental/diagnostics/plot_posterior_2d.py @@ -1,3 +1,4 @@ + import pandas as pd import seaborn as sns diff --git a/bayesflow/experimental/diagnostics/plot_prior_2d.py b/bayesflow/experimental/diagnostics/plot_prior_2d.py index 5280486e4..6cc2ff7ce 100644 --- a/bayesflow/experimental/diagnostics/plot_prior_2d.py +++ b/bayesflow/experimental/diagnostics/plot_prior_2d.py @@ -1,3 +1,4 @@ + from .plot_distribution_2d import plot_distribution_2d diff --git a/bayesflow/experimental/diagnostics/plot_recovery.py b/bayesflow/experimental/diagnostics/plot_recovery.py index d8a4f83b7..65258fe2c 100644 --- a/bayesflow/experimental/diagnostics/plot_recovery.py +++ b/bayesflow/experimental/diagnostics/plot_recovery.py @@ -1,3 +1,4 @@ + import numpy as np from scipy.stats import median_abs_deviation from sklearn.metrics import r2_score diff --git a/bayesflow/experimental/diagnostics/plot_sbc_ecdf.py b/bayesflow/experimental/diagnostics/plot_sbc_ecdf.py index 1863d6620..d9546db44 100644 --- a/bayesflow/experimental/diagnostics/plot_sbc_ecdf.py +++ b/bayesflow/experimental/diagnostics/plot_sbc_ecdf.py @@ -1,3 +1,4 @@ + import seaborn as sns from keras import ops diff --git a/bayesflow/experimental/diagnostics/plot_sbc_histograms.py b/bayesflow/experimental/diagnostics/plot_sbc_histograms.py index 51d29bd26..6e74f5823 100644 --- a/bayesflow/experimental/diagnostics/plot_sbc_histograms.py +++ b/bayesflow/experimental/diagnostics/plot_sbc_histograms.py @@ -1,3 +1,4 @@ + import logging import seaborn as sns diff --git a/bayesflow/experimental/diagnostics/plot_z_score_contraction.py b/bayesflow/experimental/diagnostics/plot_z_score_contraction.py index 41ed78be4..5454957ff 100644 --- a/bayesflow/experimental/diagnostics/plot_z_score_contraction.py +++ 
b/bayesflow/experimental/diagnostics/plot_z_score_contraction.py @@ -1,3 +1,4 @@ + import seaborn as sns from ..utils.plotutils import preprocess, postprocess diff --git a/tests/test_diagnostics/__init__.py b/tests/test_diagnostics/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_diagnostics/test_diagnostics.py b/tests/test_diagnostics/test_diagnostics.py new file mode 100644 index 000000000..4018ee577 --- /dev/null +++ b/tests/test_diagnostics/test_diagnostics.py @@ -0,0 +1,16 @@ + +import keras +import pytest + +from bayesflow.experimental.diagnostics import ( + plot_distribution_2d, + plot_prior_2d, + plot_posterior_2d +) + + +@pytest.fixture() +def test_plot_distribution_2d(): + pass + +#TODO \ No newline at end of file From 16ac39e5d85947b067bfa948b8630079cb060b45 Mon Sep 17 00:00:00 2001 From: Jerry Date: Thu, 22 Aug 2024 09:44:37 -0400 Subject: [PATCH 03/22] Miscellaneous --- .../diagnostics/plot_calibration_curves.py | 111 ----------- .../diagnostics/plot_confusion_matrix.py | 125 ------------ .../diagnostics/plot_distribution_2d.py | 67 +++---- .../diagnostics/plot_latent_space_2d.py | 37 ---- .../experimental/diagnostics/plot_losses.py | 122 ------------ .../diagnostics/plot_mmd_hypothesis_test.py | 101 ---------- .../diagnostics/plot_posterior_2d.py | 135 ------------- .../experimental/diagnostics/plot_prior_2d.py | 50 ----- .../experimental/diagnostics/plot_recovery.py | 165 ---------------- .../experimental/diagnostics/plot_sbc_ecdf.py | 178 ------------------ .../diagnostics/plot_sbc_histograms.py | 138 -------------- .../diagnostics/plot_z_score_contraction.py | 116 ------------ 12 files changed, 34 insertions(+), 1311 deletions(-) diff --git a/bayesflow/experimental/diagnostics/plot_calibration_curves.py b/bayesflow/experimental/diagnostics/plot_calibration_curves.py index 21bdb36a4..e69de29bb 100644 --- a/bayesflow/experimental/diagnostics/plot_calibration_curves.py +++ 
b/bayesflow/experimental/diagnostics/plot_calibration_curves.py @@ -1,111 +0,0 @@ - -from ..utils.plotutils import preprocess, postprocess -from ..utils.computils import expected_calibration_error -from keras import ops - - -def plot_calibration_curves( - true_models, - pred_models, - model_names: list = None, - num_bins: int = 10, - label_fontsize: int = 16, - legend_fontsize: int = 14, - title_fontsize: int = 18, - tick_fontsize: int = 12, - epsilon: float = 0.02, - fig_size: tuple = None, - color: str | tuple = "#8f2727", - x_label: str = "Predicted probability", - y_label: str = "True probability", - n_row: int = None, - n_col: int = None, -): - """Plots the calibration curves, the ECEs and the marginal histograms of predicted posterior model probabilities - for a model comparison problem. The marginal histograms inform about the fraction of predictions in each bin. - Depends on the ``expected_calibration_error`` function for computing the ECE. - - Parameters - ---------- - true_models : np.ndarray of shape (num_data_sets, num_models) - The one-hot-encoded true model indices per data set. - pred_models : np.ndarray of shape (num_data_sets, num_models) - The predicted posterior model probabilities (PMPs) per data set. - model_names : list or None, optional, default: None - The model names for nice plot titles. Inferred if None. - num_bins : int, optional, default: 10 - The number of bins to use for the calibration curves (and marginal histograms). - label_fontsize : int, optional, default: 16 - The font size of the y-label and y-label texts - legend_fontsize : int, optional, default: 14 - The font size of the legend text (ECE value) - title_fontsize : int, optional, default: 18 - The font size of the title text. Only relevant if `stacked=False` - tick_fontsize : int, optional, default: 12 - The font size of the axis ticklabels - epsilon : float, optional, default: 0.02 - A small amount to pad the [0, 1]-bounded axes from both side. 
- fig_size : tuple or None, optional, default: None - The figure size passed to the ``matplotlib`` constructor. Inferred if ``None`` - color : str, optional, default: '#8f2727' - The color of the calibration curves - x_label : str, optional, default: Predicted probability - The x-axis label - y_label : str, optional, default: True probability - The y-axis label - n_row : int, optional, default: None - The number of rows for the subplots. Dynamically determined if None. - n_col : int, optional, default: None - The number of columns for the subplots. Dynamically determined if None. - - Returns - ------- - fig : plt.Figure - the figure instance for optional saving - """ - - f, axarr, ax, n_row, n_col, num_models, model_names = preprocess(true_models, pred_models, fig_size=fig_size) - - # Compute calibration - cal_errs, probs_true, probs_pred = expected_calibration_error(true_models, pred_models, num_bins) - - # Plot marginal calibration curves in a loop - for j in range(num_models): - # Plot calibration curve - ax[j].plot(probs_pred[j], probs_true[j], "o-", color=color) - - # Plot PMP distribution over bins - uniform_bins = ops.linspace(0.0, 1.0, num_bins + 1) - norm_weights = ops.ones_like(pred_models) / len(pred_models) - ax[j].hist(pred_models[:, j], bins=uniform_bins, weights=norm_weights[:, j], color="grey", alpha=0.3) - - # Plot AB line - ax[j].plot((0, 1), (0, 1), "--", color="black", alpha=0.9) - - # Tweak plot - ax[j].tick_params(axis="both", which="major", labelsize=tick_fontsize) - ax[j].tick_params(axis="both", which="minor", labelsize=tick_fontsize) - ax[j].set_title(model_names[j], fontsize=title_fontsize) - ax[j].spines["right"].set_visible(False) - ax[j].spines["top"].set_visible(False) - ax[j].set_xlim([0 - epsilon, 1 + epsilon]) - ax[j].set_ylim([0 - epsilon, 1 + epsilon]) - ax[j].set_xticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) - ax[j].set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) - ax[j].grid(alpha=0.5) - - # Add ECE label - ax[j].text( - 0.1, - 0.9, - 
r"$\widehat{{\mathrm{{ECE}}}}$ = {0:.3f}".format(cal_errs[j]), - horizontalalignment="left", - verticalalignment="center", - transform=ax[j].transAxes, - size=legend_fontsize, - ) - - # Post-processing - postprocess(axarr, ax, n_row, n_col, num_models, x_label, y_label, label_fontsize) - - f.tight_layout() - return f diff --git a/bayesflow/experimental/diagnostics/plot_confusion_matrix.py b/bayesflow/experimental/diagnostics/plot_confusion_matrix.py index 8345b7276..e69de29bb 100644 --- a/bayesflow/experimental/diagnostics/plot_confusion_matrix.py +++ b/bayesflow/experimental/diagnostics/plot_confusion_matrix.py @@ -1,125 +0,0 @@ - -import matplotlib.pyplot as plt - -from keras import ops -from keras import backend as K -from sklearn.metrics import confusion_matrix -from matplotlib.colors import LinearSegmentedColormap -from ..utils.plotutils import initialize_figure - - -def plot_confusion_matrix( - true_models, - pred_models, - model_names: list = None, - fig_size=(5, 5), - label_fontsize: int = 16, - title_fontsize: int = 18, - value_fontsize: int = 10, - tick_fontsize: int = 12, - xtick_rotation: int = None, - ytick_rotation: int = None, - normalize: bool = True, - cmap=None, - title: bool = True, -): - """Plots a confusion matrix for validating a neural network trained for Bayesian model comparison. - - Parameters - ---------- - true_models : np.ndarray of shape (num_data_sets, num_models) - The one-hot-encoded true model indices per data set. - pred_models : np.ndarray of shape (num_data_sets, num_models) - The predicted posterior model probabilities (PMPs) per data set. - model_names : list or None, optional, default: None - The model names for nice plot titles. Inferred if None. - fig_size : tuple or None, optional, default: (5, 5) - The figure size passed to the ``matplotlib`` constructor. 
Inferred if ``None`` - label_fontsize : int, optional, default: 16 - The font size of the y-label and y-label texts - title_fontsize : int, optional, default: 18 - The font size of the title text. - value_fontsize : int, optional, default: 10 - The font size of the text annotations and the colorbar tick labels. - tick_fontsize : int, optional, default: 12 - The font size of the axis label and model name texts. - xtick_rotation: int, optional, default: None - Rotation of x-axis tick labels (helps with long model names). - ytick_rotation: int, optional, default: None - Rotation of y-axis tick labels (helps with long model names). - normalize : bool, optional, default: True - A flag for normalization of the confusion matrix. - If True, each row of the confusion matrix is normalized to sum to 1. - cmap : matplotlib.colors.Colormap or str, optional, default: None - Colormap to be used for the cells. If a str, it should be the name of a registered colormap, - e.g., 'viridis'. Default colormap matches the BayesFlow defaults by ranging from white to red. - title : bool, optional, default True - A flag for adding 'Confusion Matrix' above the matrix. 
- - Returns - ------- - fig : plt.Figure - the figure instance for optional saving - """ - - if model_names is None: - num_models = true_models.shape[-1] - model_names = [rf"$M_{{{m}}}$" for m in range(1, num_models + 1)] - - if cmap is None: - cmap = LinearSegmentedColormap.from_list("", ["white", "#8f2727"]) - - # Flatten input - true_models = ops.argmax(true_models, axis=1) - pred_models = ops.argmax(pred_models, axis=1) - - # Compute confusion matrix - cm = confusion_matrix(true_models, pred_models) - - if normalize: - # Convert to Keras tensor - cm_tensor = K.constant(cm, dtype='float32') - - # Sum along rows and keep dimensions for broadcasting - cm_sum = K.sum(cm_tensor, axis=1, keepdims=True) - - # Broadcast division for normalization - cm_normalized = cm_tensor / cm_sum - - # Since we might need to use this outside of a session, evaluate using K.eval() if necessary - cm_normalized = K.eval(cm_normalized) - - # Initialize figure - fig, ax = initialize_figure(1, 1, fig_size=fig_size) - # fig, ax = plt.subplots(1, 1, figsize=fig_size) - im = ax.imshow(cm, interpolation="nearest", cmap=cmap) - cbar = ax.figure.colorbar(im, ax=ax, shrink=0.75) - - cbar.ax.tick_params(labelsize=value_fontsize) - - ax.set(xticks=ops.arange(cm.shape[1]), yticks=ops.arange(cm.shape[0])) - ax.set_xticklabels(model_names, fontsize=tick_fontsize) - if xtick_rotation: - plt.xticks(rotation=xtick_rotation, ha="right") - ax.set_yticklabels(model_names, fontsize=tick_fontsize) - if ytick_rotation: - plt.yticks(rotation=ytick_rotation) - ax.set_xlabel("Predicted model", fontsize=label_fontsize) - ax.set_ylabel("True model", fontsize=label_fontsize) - - # Loop over data dimensions and create text annotations - fmt = ".2f" if normalize else "d" - thresh = cm.max() / 2.0 - for i in range(cm.shape[0]): - for j in range(cm.shape[1]): - ax.text( - j, - i, - format(cm[i, j], fmt), - fontsize=value_fontsize, - ha="center", - va="center", - color="white" if cm[i, j] > thresh else "black", - ) - if 
title: - ax.set_title("Confusion Matrix", fontsize=title_fontsize) - return fig \ No newline at end of file diff --git a/bayesflow/experimental/diagnostics/plot_distribution_2d.py b/bayesflow/experimental/diagnostics/plot_distribution_2d.py index 28da8faba..f03ba7907 100644 --- a/bayesflow/experimental/diagnostics/plot_distribution_2d.py +++ b/bayesflow/experimental/diagnostics/plot_distribution_2d.py @@ -3,40 +3,42 @@ import seaborn as sns import pandas as pd +from bayesflow.types import Tensor def plot_distribution_2d( - samples, - context: str = None, - height: float = 2.5, - color: str | tuple = "#8f2727", - alpha: float = 0.9, - n_params: int = None, - param_names: list[str] = None, - render: bool = True, - **kwargs + samples: dict[str, Tensor] = None, + parameters: str = None, + n_params: int = None, + param_names: list = None, + height: float = 2.5, + color: str | tuple = "#8f2727", + alpha: float = 0.9, + render: bool = True, + **kwargs ): """ - A more flexible pairplot function for multiple distributions based upon collected samples. + A more flexible pair plot function for multiple distributions based upon collected samples. Parameters ---------- - samples : np.ndarray or tf.Tensor of shape (n_sim, n_params) + samples : dict[str, Tensor], default: None Sample draws from any dataset - context : str + parameters : str, default: None The context that the sample represents height : float, optional, default: 2.5 The height of the pair plot color : str, optional, default : '#8f2727' The color of the plot - alpha : float in [0, 1], optonal, default: 0.9 + alpha : float in [0, 1], optional, default: 0.9 The opacity of the plot n_params : int, optional, default: None The number of parameters in the collection of distributions param_names : list or None, optional, default: None The parameter names for nice plot titles. Inferred if None render : bool, optional, default: True - The boolean that determines whether to render the plot visually. 
If true, then the plot will render; - otherwise, the plot will go through further steps for postprocessing + The boolean that determines whether to render the plot visually. + If true, then the plot will render; + otherwise, the plot will go through further steps for postprocessing. **kwargs : dict, optional Additional keyword arguments passed to the sns.PairGrid constructor """ @@ -48,40 +50,39 @@ def plot_distribution_2d( n_params = dim # Generate context if there is none - if context is None: - context = "Generic" + if parameters is None: + parameters = "Parameter" # Generate titles if param_names is None: - titles = [f"{context} Param. {i}" for i in range(1, dim + 1)] + titles = [f"{parameters} {i}" for i in range(1, dim + 1)] else: - titles = [f"{context} {p}" for p in param_names] - + titles = [f"{parameters} {p}" for p in param_names] + # Convert samples to pd.DataFrame data_to_plot = pd.DataFrame(samples, columns=titles) # Generate plots - g = sns.PairGrid(data_to_plot, height=height, **kwargs) + artist = sns.PairGrid(data_to_plot, height=height, **kwargs) - g.map_diag(sns.histplot, fill=True, color=color, alpha=alpha, kde=True) + artist.map_diag(sns.histplot, fill=True, color=color, alpha=alpha, kde=True) # Incorporate exceptions for generating KDE plots - try: - g.map_lower(sns.kdeplot, fill=True, color=color, alpha=alpha) + try: + artist.map_lower(sns.kdeplot, fill=True, color=color, alpha=alpha) except Exception as e: logging.warning("KDE failed due to the following exception:\n" + repr(e) + "\nSubstituting scatter plot.") - g.map_lower(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color) - - g.map_upper(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color) + artist.map_lower(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color) + + artist.map_upper(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color) if render: # Generate grids for i in range(dim): for j in range(dim): - g.axes[i, j].grid(alpha=0.5) - + artist.axes[i, 
j].grid(alpha=0.5) + # Return figure - g.tight_layout() - return g - else: - return g + artist.tight_layout() + + return artist diff --git a/bayesflow/experimental/diagnostics/plot_latent_space_2d.py b/bayesflow/experimental/diagnostics/plot_latent_space_2d.py index d7426f7cd..e69de29bb 100644 --- a/bayesflow/experimental/diagnostics/plot_latent_space_2d.py +++ b/bayesflow/experimental/diagnostics/plot_latent_space_2d.py @@ -1,37 +0,0 @@ - -from .plot_distribution_2d import plot_distribution_2d - -from keras import backend as K - - -def plot_latent_space_2d( - z_samples, - height: float = 2.5, - color="#8f2727", - **kwargs -): - """Creates pair plots for the latent space learned by the inference network. Enables - visual inspection of the latent space and whether its structure corresponds to the - one enforced by the optimization criterion. - - Parameters - ---------- - z_samples : np.ndarray or tf.Tensor of shape (n_sim, n_params) - The latent samples computed through a forward pass of the inference network. - height : float, optional, default: 2.5 - The height of the pair plot. 
- color : str, optional, default : '#8f2727' - The color of the plot - **kwargs : dict, optional - Additional keyword arguments passed to the sns.PairGrid constructor - - Returns - ------- - f : plt.Figure - the figure instance for optional saving - """ - - # Try to convert z_samples, if eventually tf.Tensor is passed - if not isinstance(z_samples, K.tf.Tensor): - z_samples = K.constant(z_samples) - - plot_distribution_2d(z_samples, context="Latent Dim", height=height, color=color, render=True, **kwargs) diff --git a/bayesflow/experimental/diagnostics/plot_losses.py b/bayesflow/experimental/diagnostics/plot_losses.py index 45ff284da..e69de29bb 100644 --- a/bayesflow/experimental/diagnostics/plot_losses.py +++ b/bayesflow/experimental/diagnostics/plot_losses.py @@ -1,122 +0,0 @@ - -import seaborn as sns - -from keras import ops -from ..utils.plotutils import initialize_figure - - -def plot_losses( - train_losses, - val_losses=None, - moving_average: bool = False, - ma_window_fraction: float = 0.01, - fig_size=None, - train_color: str = "#8f2727", - val_color: str = "black", - lw_train: int = 2, - lw_val: int = 3, - grid_alpha: float = 0.5, - legend_fontsize: int = 14, - label_fontsize: int = 14, - title_fontsize: int = 16, -): - """A generic helper function to plot the losses of a series of training epochs and runs. - - Parameters - ---------- - - train_losses : pd.DataFrame - The (plottable) history as returned by a train_[...] method of a ``Trainer`` instance. - Alternatively, you can just pass a data frame of validation losses instead of train losses, - if you only want to plot the validation loss. - val_losses : pd.DataFrame or None, optional, default: None - The (plottable) validation history as returned by a train_[...] method of a ``Trainer`` instance. - If left ``None``, only train losses are plotted. Should have the same number of columns - as ``train_losses``. 
- moving_average : bool, optional, default: False - A flag for adding a moving average line of the train_losses. - ma_window_fraction : int, optional, default: 0.01 - Window size for the moving average as a fraction of total training steps. - fig_size : tuple or None, optional, default: None - The figure size passed to the ``matplotlib`` constructor. Inferred if ``None`` - train_color : str, optional, default: '#8f2727' - The color for the train loss trajectory - val_color : str, optional, default: black - The color for the optional validation loss trajectory - lw_train : int, optional, default: 2 - The linewidth for the training loss curve - lw_val : int, optional, default: 3 - The linewidth for the validation loss curve - grid_alpha : float, optional, default 0.5 - The opacity factor for the background gridlines - legend_fontsize : int, optional, default: 14 - The font size of the legend text - label_fontsize : int, optional, default: 14 - The font size of the y-label text - title_fontsize : int, optional, default: 16 - The font size of the title text - - Returns - ------- - f : plt.Figure - the figure instance for optional saving - - Raises - ------ - AssertionError - If the number of columns in ``train_losses`` does not match the - number of columns in ``val_losses``. 
- """ - - # Determine the number of rows for plot - n_row = len(train_losses.columns) - - # Initialize figure - f, axarr = initialize_figure(n_row=n_row, n_col=1, fig_size=(16, int(4 * n_row))) - - # if fig_size is None: - # fig_size = (16, int(4 * n_row)) - # f, axarr = plt.subplots(n_row, 1, figsize=fig_size) - - # Get the number of steps as an array - train_step_index = ops.arange(1, len(train_losses) + 1) - if val_losses is not None: - val_step = int(ops.floor(len(train_losses) / len(val_losses))) - val_step_index = train_step_index[(val_step - 1) :: val_step] - - # If unequal length due to some reason, attempt a fix - if val_step_index.shape[0] > val_losses.shape[0]: - val_step_index = val_step_index[: val_losses.shape[0]] - - # Loop through loss entries and populate plot - looper = [axarr] if n_row == 1 else axarr.flat - for i, ax in enumerate(looper): - # Plot train curve - ax.plot(train_step_index, train_losses.iloc[:, i], color=train_color, lw=lw_train, alpha=0.9, label="Training") - if moving_average and train_losses.columns[i] == "Loss": - moving_average_window = int(train_losses.shape[0] * ma_window_fraction) - smoothed_loss = train_losses.iloc[:, i].rolling(window=moving_average_window).mean() - ax.plot(train_step_index, smoothed_loss, color="grey", lw=lw_train, label="Training (Moving Average)") - - # Plot optional val curve - if val_losses is not None: - if i < val_losses.shape[1]: - ax.plot( - val_step_index, - val_losses.iloc[:, i], - linestyle="--", - marker="o", - color=val_color, - lw=lw_val, - label="Validation", - ) - # Schmuck - ax.set_xlabel("Training step #", fontsize=label_fontsize) - ax.set_ylabel("Value", fontsize=label_fontsize) - sns.despine(ax=ax) - ax.grid(alpha=grid_alpha) - ax.set_title(train_losses.columns[i], fontsize=title_fontsize) - # Only add legend if there is a validation curve - if val_losses is not None or moving_average: - ax.legend(fontsize=legend_fontsize) - f.tight_layout() - return f diff --git 
a/bayesflow/experimental/diagnostics/plot_mmd_hypothesis_test.py b/bayesflow/experimental/diagnostics/plot_mmd_hypothesis_test.py index 1a935d60c..e69de29bb 100644 --- a/bayesflow/experimental/diagnostics/plot_mmd_hypothesis_test.py +++ b/bayesflow/experimental/diagnostics/plot_mmd_hypothesis_test.py @@ -1,101 +0,0 @@ - -import matplotlib.pyplot as plt -import seaborn as sns - -from keras import ops - - -def plot_mmd_hypothesis_test( - mmd_null, - mmd_observed: float = None, - alpha_level: float = 0.05, - null_color: str | tuple = (0.16407, 0.020171, 0.577478), - observed_color: str | tuple = "red", - alpha_color: str | tuple = "orange", - truncate_v_lines_at_kde: bool = False, - x_min: float = None, - x_max: float = None, - bw_factor: float = 1.5, -): - """ - - Parameters - ---------- - mmd_null : np.ndarray - The samples from the MMD sampling distribution under the null hypothesis "the model is well-specified" - mmd_observed : float - The observed MMD value - alpha_level : float, optional, default: 0.05 - The rejection probability (type I error) - null_color : str or tuple, optional, default: (0.16407, 0.020171, 0.577478) - The color of the H0 sampling distribution - observed_color : str or tuple, optional, default: "red" - The color of the observed MMD - alpha_color : str or tuple, optional, default: "orange" - The color of the rejection area - truncate_v_lines_at_kde: bool, optional, default: False - true: cut off the vlines at the kde - false: continue kde lines across the plot - x_min : float, optional, default: None - The lower x-axis limit - x_max : float, optional, default: None - The upper x-axis limit - bw_factor : float, optional, default: 1.5 - bandwidth (aka. 
smoothing parameter) of the kernel density estimate - - Returns - ------- - f : plt.Figure - the figure instance for optional saving - """ - - def draw_v_line_to_kde(x, kde_object, color, label=None, **kwargs): - kde_x, kde_y = kde_object.lines[0].get_data() - idx = ops.argmin(ops.abs(kde_x - x)) - plt.vlines(x=x, ymin=0, ymax=kde_y[idx], color=color, linewidth=3, label=label, **kwargs) - - def fill_area_under_kde(kde_object, x_start, x_end=None, **kwargs): - kde_x, kde_y = kde_object.lines[0].get_data() - if x_end is not None: - plt.fill_between(kde_x, kde_y, where=(kde_x >= x_start) & (kde_x <= x_end), interpolate=True, **kwargs) - else: - plt.fill_between(kde_x, kde_y, where=(kde_x >= x_start), interpolate=True, **kwargs) - - f = plt.figure(figsize=(8, 4)) - - kde = sns.kdeplot(mmd_null, fill=False, linewidth=0, bw_adjust=bw_factor) - sns.kdeplot(mmd_null, fill=True, alpha=0.12, color=null_color, bw_adjust=bw_factor) - - if truncate_v_lines_at_kde: - draw_v_line_to_kde(x=mmd_observed, kde_object=kde, color=observed_color, label=r"Observed data") - else: - plt.vlines( - x=mmd_observed, - ymin=0, - ymax=plt.gca().get_ylim()[1], - color=observed_color, - linewidth=3, - label=r"Observed data", - ) - - mmd_critical = ops.quantile(mmd_null, 1 - alpha_level) - fill_area_under_kde( - kde, mmd_critical, color=alpha_color, alpha=0.5, label=rf"{int(alpha_level*100)}% rejection area" - ) - - if truncate_v_lines_at_kde: - draw_v_line_to_kde(x=mmd_critical, kde_object=kde, color=alpha_color) - else: - plt.vlines(x=mmd_critical, color=alpha_color, linewidth=3, ymin=0, ymax=plt.gca().get_ylim()[1]) - - sns.kdeplot(mmd_null, fill=False, linewidth=3, color=null_color, label=r"$H_0$", bw_adjust=bw_factor) - - plt.xlabel(r"MMD", fontsize=20) - plt.ylabel("") - plt.yticks([]) - plt.xlim(x_min, x_max) - plt.tick_params(axis="both", which="major", labelsize=16) - - plt.legend(fontsize=20) - sns.despine() - - return f diff --git a/bayesflow/experimental/diagnostics/plot_posterior_2d.py 
b/bayesflow/experimental/diagnostics/plot_posterior_2d.py index d2f96da67..e69de29bb 100644 --- a/bayesflow/experimental/diagnostics/plot_posterior_2d.py +++ b/bayesflow/experimental/diagnostics/plot_posterior_2d.py @@ -1,135 +0,0 @@ - -import pandas as pd -import seaborn as sns - -from matplotlib.lines import Line2D -from .plot_distribution_2d import plot_distribution_2d - - -def plot_posterior_2d( - posterior_draws, - prior=None, - prior_draws=None, - param_names: list = None, - height: int = 3, - label_fontsize: int = 14, - legend_fontsize: int = 16, - tick_fontsize: int = 12, - post_color: str | tuple = "#8f2727", - prior_color: str | tuple = "gray", - post_alpha: float = 0.9, - prior_alpha: float = 0.7, - **kwargs -): - """Generates a bivariate pairplot given posterior draws and optional prior or prior draws. - - posterior_draws : np.ndarray of shape (n_post_draws, n_params) - The posterior draws obtained for a SINGLE observed data set. - prior : bayesflow.forward_inference.Prior instance or None, optional, default: None - The optional prior object having an input-output signature as given by ayesflow.forward_inference.Prior - prior_draws : np.ndarray of shape (n_prior_draws, n_params) or None, optonal (default: None) - The optional prior draws obtained from the prior. If both prior and prior_draws are provided, prior_draws - will be used. - param_names : list or None, optional, default: None - The parameter names for nice plot titles. 
Inferred if None - height : float, optional, default: 3 - The height of the pairplot - label_fontsize : int, optional, default: 14 - The font size of the x and y-label texts (parameter names) - legend_fontsize : int, optional, default: 16 - The font size of the legend text - tick_fontsize : int, optional, default: 12 - The font size of the axis ticklabels - post_color : str, optional, default: '#8f2727' - The color for the posterior histograms and KDEs - priors_color : str, optional, default: gray - The color for the optional prior histograms and KDEs - post_alpha : float in [0, 1], optonal, default: 0.9 - The opacity of the posterior plots - prior_alpha : float in [0, 1], optonal, default: 0.7 - The opacity of the prior plots - - Returns - ------- - f : plt.Figure - the figure instance for optional saving - - Raises - ------ - AssertionError - If the shape of posterior_draws is not 2-dimensional. - """ - - # Ensure correct shape - assert ( - len(posterior_draws.shape) - ) == 2, "Shape of `posterior_samples` for a single data set should be 2 dimensional!" 
- - # Plot posterior first - g = plot_distribution_2d( - posterior_draws, - context="\\theta", - param_names=param_names, - render=False, - **kwargs - ) - - # Obtain n_draws and n_params - n_draws, n_params = posterior_draws.shape - - # If prior object is given and no draws, obtain draws - if prior is not None and prior_draws is None: - draws = prior(n_draws) - if type(draws) is dict: - prior_draws = draws["prior_draws"] - else: - prior_draws = draws - - # Attempt to determine parameter names - if param_names is None: - if hasattr(prior, "param_names"): - if prior.param_names is not None: - param_names = prior.param_names - else: - param_names = [f"$\\theta_{{{i}}}$" for i in range(1, n_params + 1)] - else: - param_names = [f"$\\theta_{{{i}}}$" for i in range(1, n_params + 1)] - - # Add prior, if given - if prior_draws is not None: - prior_draws_df = pd.DataFrame(prior_draws, columns=param_names) - g.data = prior_draws_df - g.map_diag(sns.histplot, fill=True, color=prior_color, alpha=prior_alpha, kde=True, zorder=-1) - g.map_lower(sns.kdeplot, fill=True, color=prior_color, alpha=prior_alpha, zorder=-1) - - # Add legend, if prior also given - if prior_draws is not None or prior is not None: - handles = [ - Line2D(xdata=[], ydata=[], color=post_color, lw=3, alpha=post_alpha), - Line2D(xdata=[], ydata=[], color=prior_color, lw=3, alpha=prior_alpha), - ] - g.legend(handles, ["Posterior", "Prior"], fontsize=legend_fontsize, loc="center right") - - n_row, n_col = g.axes.shape - - for i in range(n_row): - # Remove upper axis - for j in range(i+1, n_col): - g.axes[i, j].axis("off") - - # Modify tick sizes - for j in range(i + 1): - g.axes[i, j].tick_params(axis="both", which="major", labelsize=tick_fontsize) - g.axes[i, j].tick_params(axis="both", which="minor", labelsize=tick_fontsize) - - # Add nice labels - for i, param_name in enumerate(param_names): - g.axes[i, 0].set_ylabel(param_name, fontsize=label_fontsize) - g.axes[len(param_names) - 1, i].set_xlabel(param_name, 
fontsize=label_fontsize) - - # Add grids - for i in range(n_params): - for j in range(n_params): - g.axes[i, j].grid(alpha=0.5) - - g.tight_layout() - return g diff --git a/bayesflow/experimental/diagnostics/plot_prior_2d.py b/bayesflow/experimental/diagnostics/plot_prior_2d.py index 6cc2ff7ce..e69de29bb 100644 --- a/bayesflow/experimental/diagnostics/plot_prior_2d.py +++ b/bayesflow/experimental/diagnostics/plot_prior_2d.py @@ -1,50 +0,0 @@ - -from .plot_distribution_2d import plot_distribution_2d - - -def plot_prior_2d( - prior, - param_names: list = None, - n_samples: int = 2000, - height: float = 2.5, - color: str | tuple = "#8f2727", - **kwargs -): - """Creates pair-plots for a given joint prior. - - Parameters - ---------- - prior : callable - The prior object which takes a single integer argument and generates random draws. - param_names : list of str or None, optional, default None - An optional list of strings which - n_samples : int, optional, default: 1000 - The number of random draws from the joint prior - height : float, optional, default: 2.5 - The height of the pair plot - color : str, optional, default : '#8f2727' - The color of the plot - **kwargs : dict, optional - Additional keyword arguments passed to the sns.PairGrid constructor - - Returns - ------- - f : plt.Figure - the figure instance for optional saving - """ - - # Generate prior draws - prior_samples = prior(n_samples) - - # Handle dict type - if type(prior_samples) is dict: - prior_samples = prior_samples["prior_draws"] - - plot_distribution_2d( - prior_samples, - context="Prior", - height=height, - color=color, - param_names=param_names, - render=True, - **kwargs - ) diff --git a/bayesflow/experimental/diagnostics/plot_recovery.py b/bayesflow/experimental/diagnostics/plot_recovery.py index 65258fe2c..e69de29bb 100644 --- a/bayesflow/experimental/diagnostics/plot_recovery.py +++ b/bayesflow/experimental/diagnostics/plot_recovery.py @@ -1,165 +0,0 @@ - -import numpy as np -from 
scipy.stats import median_abs_deviation -from sklearn.metrics import r2_score -import seaborn as sns - -from ..utils.plotutils import preprocess, postprocess - - -def plot_recovery( - post_samples, - prior_samples, - point_agg=np.median, - uncertainty_agg=median_abs_deviation, - param_names: list = None, - fig_size: tuple = None, - label_fontsize: int = 16, - title_fontsize: int = 18, - metric_fontsize: int = 16, - tick_fontsize: int = 12, - add_corr: bool = True, - add_r2: bool = True, - color: str | tuple = "#8f2727", - n_col: int = None, - n_row: int = None, - xlabel: str = "Ground truth", - ylabel: str = "Estimated", - **kwargs, -): - """Creates and plots publication-ready recovery plot with true vs. point estimate + uncertainty. - The point estimate can be controlled with the ``point_agg`` argument, and the uncertainty estimate - can be controlled with the ``uncertainty_agg`` argument. - - This plot yields similar information as the "posterior z-score", but allows for generic - point and uncertainty estimates: - - https://betanalpha.github.io/assets/case_studies/principled_bayesian_workflow.html - - Important: Posterior aggregates play no special role in Bayesian inference and should only - be used heuristically. For instance, in the case of multi-modal posteriors, common point - estimates, such as mean, (geometric) median, or maximum a posteriori (MAP) mean nothing. - - Parameters - ---------- - post_samples : np.ndarray of shape (n_data_sets, n_post_draws, n_params) - The posterior draws obtained from n_data_sets - prior_samples : np.ndarray of shape (n_data_sets, n_params) - The prior draws (true parameters) obtained for generating the n_data_sets - point_agg : callable, optional, default: ``np.median`` - The function to apply to the posterior draws to get a point estimate for each marginal. - The default computes the marginal median for each marginal posterior as a robust - point estimate. 
- uncertainty_agg : callable or None, optional, default: scipy.stats.median_abs_deviation - The function to apply to the posterior draws to get an uncertainty estimate. - If ``None`` provided, a simple scatter using only ``point_agg`` will be plotted. - param_names : list or None, optional, default: None - The parameter names for nice plot titles. Inferred if None - fig_size : tuple or None, optional, default : None - The figure size passed to the matplotlib constructor. Inferred if None. - label_fontsize : int, optional, default: 16 - The font size of the y-label text - title_fontsize : int, optional, default: 18 - The font size of the title text - metric_fontsize : int, optional, default: 16 - The font size of the goodness-of-fit metric (if provided) - tick_fontsize : int, optional, default: 12 - The font size of the axis tick labels - add_corr : bool, optional, default: True - A flag for adding correlation between true and estimates to the plot - add_r2 : bool, optional, default: True - A flag for adding R^2 between true and estimates to the plot - color : str, optional, default: '#8f2727' - The color for the true vs. estimated scatter points and error bars - n_row : int, optional, default: None - The number of rows for the subplots. Dynamically determined if None. - n_col : int, optional, default: None - The number of columns for the subplots. Dynamically determined if None. - xlabel : str, optional, default: 'Ground truth' - The label on the x-axis of the plot - ylabel : str, optional, default: 'Estimated' - The label on the y-axis of the plot - **kwargs : optional - Additional keyword arguments passed to ax.errorbar or ax.scatter. - Example: `rasterized=True` to reduce PDF file size with many dots - - Returns - ------- - f : plt.Figure - the figure instance for optional saving - - Raises - ------ - ShapeError - If there is a deviation from the expected shapes of ``post_samples`` and ``prior_samples``. 
- """ - - # Preprocess - f, axarr, axarr_it, n_row, n_col, n_params, param_names = preprocess( - post_samples, prior_samples, fig_size=fig_size - ) - - # Compute point estimates and uncertainties - est = point_agg(post_samples, axis=1) - if uncertainty_agg is not None: - u = uncertainty_agg(post_samples, axis=1) - - # Loop and plot - for i, ax in enumerate(axarr_it): - if i >= n_params: - break - - # Add scatter and error bars - if uncertainty_agg is not None: - _ = ax.errorbar(prior_samples[:, i], est[:, i], yerr=u[:, i], fmt="o", alpha=0.5, color=color, **kwargs) - else: - _ = ax.scatter(prior_samples[:, i], est[:, i], alpha=0.5, color=color, **kwargs) - - # Make plots quadratic to avoid visual illusions - lower = min(prior_samples[:, i].min(), est[:, i].min()) - upper = max(prior_samples[:, i].max(), est[:, i].max()) - eps = (upper - lower) * 0.1 - ax.set_xlim([lower - eps, upper + eps]) - ax.set_ylim([lower - eps, upper + eps]) - ax.plot( - [ax.get_xlim()[0], ax.get_xlim()[1]], - [ax.get_ylim()[0], ax.get_ylim()[1]], - color="black", - alpha=0.9, - linestyle="dashed", - ) - - # Add optional metrics and title - if add_r2: - r2 = r2_score(prior_samples[:, i], est[:, i]) - ax.text( - 0.1, - 0.9, - "$R^2$ = {:.3f}".format(r2), - horizontalalignment="left", - verticalalignment="center", - transform=ax.transAxes, - size=metric_fontsize, - ) - if add_corr: - corr = np.corrcoef(prior_samples[:, i], est[:, i])[0, 1] - ax.text( - 0.1, - 0.8, - "$r$ = {:.3f}".format(corr), - horizontalalignment="left", - verticalalignment="center", - transform=ax.transAxes, - size=metric_fontsize, - ) - ax.set_title(param_names[i], fontsize=title_fontsize) - - # Prettify - sns.despine(ax=ax) - ax.grid(alpha=0.5) - ax.tick_params(axis="both", which="major", labelsize=tick_fontsize) - ax.tick_params(axis="both", which="minor", labelsize=tick_fontsize) - - postprocess(axarr, axarr_it, n_row, n_col, n_params, xlabel, ylabel, label_fontsize) - - f.tight_layout() - return f diff --git 
a/bayesflow/experimental/diagnostics/plot_sbc_ecdf.py b/bayesflow/experimental/diagnostics/plot_sbc_ecdf.py index d9546db44..e69de29bb 100644 --- a/bayesflow/experimental/diagnostics/plot_sbc_ecdf.py +++ b/bayesflow/experimental/diagnostics/plot_sbc_ecdf.py @@ -1,178 +0,0 @@ - -import seaborn as sns - -from keras import ops -from keras import backend as K -from ..utils.computils import simultaneous_ecdf_bands -from ..utils.plotutils import preprocess, remove_unused_axes - - -def plot_sbc_ecdf( - post_samples, - prior_samples, - difference: bool = False, - stacked: bool = False, - fig_size: tuple = None, - param_names: list = None, - label_fontsize: int = 16, - legend_fontsize: int = 14, - title_fontsize: int = 18, - tick_fontsize: int = 12, - rank_ecdf_color: str | tuple = "#a34f4f", - fill_color: str | tuple = "grey", - n_row: int = None, - n_col: int = None, - **kwargs, -): - """Creates the empirical CDFs for each marginal rank distribution and plots it against - a uniform ECDF. ECDF simultaneous bands are drawn using simulations from the uniform, - as proposed by [1]. - - For models with many parameters, use `stacked=True` to obtain an idea of the overall calibration - of a posterior approximator. - - [1] Säilynoja, T., Bürkner, P. C., & Vehtari, A. (2022). Graphical test for discrete uniformity and - its applications in goodness-of-fit evaluation and multiple sample comparison. Statistics and Computing, - 32(2), 1-21. https://arxiv.org/abs/2103.10522 - - Parameters - ---------- - post_samples : np.ndarray of shape (n_data_sets, n_post_draws, n_params) - The posterior draws obtained from n_data_sets - prior_samples : np.ndarray of shape (n_data_sets, n_params) - The prior draws obtained for generating n_data_sets - difference : bool, optional, default: False - If `True`, plots the ECDF difference. Enables a more dynamic visualization range. - stacked : bool, optional, default: False - If `True`, all ECDFs will be plotted on the same plot. 
If `False`, each ECDF will - have its own subplot, similar to the behavior of `plot_sbc_histograms`. - param_names : list or None, optional, default: None - The parameter names for nice plot titles. Inferred if None. Only relevant if `stacked=False`. - fig_size : tuple or None, optional, default: None - The figure size passed to the matplotlib constructor. Inferred if None. - label_fontsize : int, optional, default: 16 - The font size of the y-label and y-label texts - legend_fontsize : int, optional, default: 14 - The font size of the legend text - title_fontsize : int, optional, default: 18 - The font size of the title text. Only relevant if `stacked=False` - tick_fontsize : int, optional, default: 12 - The font size of the axis ticklabels - rank_ecdf_color : str, optional, default: '#a34f4f' - The color to use for the rank ECDFs - fill_color : str, optional, default: 'grey' - The color of the fill arguments. - n_row : int, optional, default: None - The number of rows for the subplots. Dynamically determined if None. - n_col : int, optional, default: None - The number of columns for the subplots. Dynamically determined if None. - **kwargs : dict, optional, default: {} - Keyword arguments can be passed to control the behavior of ECDF simultaneous band computation - through the ``ecdf_bands_kwargs`` dictionary. See `simultaneous_ecdf_bands` for keyword arguments - - Returns - ------- - f : plt.Figure - the figure instance for optional saving - - Raises - ------ - ShapeError - If there is a deviation form the expected shapes of `post_samples` and `prior_samples`. 
- """ - - f, ax, ax_it, n_row, n_col, n_params, param_names = preprocess( - post_samples, prior_samples, collapse=False, fig_size=fig_size) - - # Compute fractional ranks (using broadcasting) - post_samples = K.constant(post_samples) - prior_samples = K.constant(prior_samples) - - # Adding an extra dimension to prior_samples using K.expand_dims - prior_samples_expanded = K.expand_dims(prior_samples, axis=1) - - # Performing element-wise comparison - comparison = K.less(post_samples, prior_samples_expanded) - - # Summing along the specified axis (axis=1) - sums = K.sum(K.cast(comparison, dtype='float32'), axis=1) - - # Getting the shape of post_samples - post_samples_shape = K.shape(post_samples) - - # Computing the ranks - ranks = sums / K.cast(post_samples_shape[1], dtype='float32') - - # ranks = ops.sum(post_samples < prior_samples[:, np.newaxis, :], axis=1) / post_samples.shape[1] - - - # Plot individual ecdf of parameters - for j in range(ranks.shape[-1]): - ecdf_single = ops.sort(ranks[:, j]) - xx = ecdf_single - yy = ops.arange(1, xx.shape[-1] + 1) / float(xx.shape[-1]) - - # Difference, if specified - if difference: - yy -= xx - - if stacked: - if j == 0: - ax.plot(xx, yy, color=rank_ecdf_color, alpha=0.95, label="Rank ECDFs") - else: - ax.plot(xx, yy, color=rank_ecdf_color, alpha=0.95) - else: - ax.flat[j].plot(xx, yy, color=rank_ecdf_color, alpha=0.95, label="Rank ECDF") - - # Compute uniform ECDF and bands - alpha, z, L, H = simultaneous_ecdf_bands(post_samples.shape[0], **kwargs.pop("ecdf_bands_kwargs", {})) - - # Difference, if specified - if difference: - L -= z - H -= z - ylab = "ECDF difference" - else: - ylab = "ECDF" - - # Add simultaneous bounds - if stacked: - titles = [None] - axes = [ax] - else: - axes = ax.flat - if param_names is None: - titles = [f"$\\theta_{{{i}}}$" for i in range(1, n_params + 1)] - else: - titles = param_names - - for _ax, title in zip(axes, titles): - _ax.fill_between(z, L, H, color=fill_color, alpha=0.2, 
label=rf"{int((1-alpha) * 100)}$\%$ Confidence Bands") - - # Prettify plot - sns.despine(ax=_ax) - _ax.grid(alpha=0.35) - _ax.legend(fontsize=legend_fontsize) - _ax.set_title(title, fontsize=title_fontsize) - _ax.tick_params(axis="both", which="major", labelsize=tick_fontsize) - _ax.tick_params(axis="both", which="minor", labelsize=tick_fontsize) - - # Only add x-labels to the bottom row - if stacked: - bottom_row = [ax] - else: - bottom_row = ax if n_row == 1 else ax[-1, :] - for _ax in bottom_row: - _ax.set_xlabel("Fractional rank statistic", fontsize=label_fontsize) - - # Only add y-labels to right left-most row - if n_row == 1: # if there is only one row, the ax array is 1D - axes[0].set_ylabel(ylab, fontsize=label_fontsize) - else: # if there is more than one row, the ax array is 2D - for _ax in ax[:, 0]: - _ax.set_ylabel(ylab, fontsize=label_fontsize) - - # Remove unused axes entirely - remove_unused_axes(ax) - - f.tight_layout() - return f diff --git a/bayesflow/experimental/diagnostics/plot_sbc_histograms.py b/bayesflow/experimental/diagnostics/plot_sbc_histograms.py index 6e74f5823..e69de29bb 100644 --- a/bayesflow/experimental/diagnostics/plot_sbc_histograms.py +++ b/bayesflow/experimental/diagnostics/plot_sbc_histograms.py @@ -1,138 +0,0 @@ - -import logging -import seaborn as sns - -from scipy.stats import binom -from keras import ops -from keras import backend as K -from ..utils.plotutils import preprocess, remove_unused_axes - - -def plot_sbc_histograms( - post_samples, - prior_samples, - param_names: list = None, - fig_size: tuple = None, - num_bins: int = None, - binomial_interval: float = 0.99, - label_fontsize: int = 16, - title_fontsize: int = 18, - tick_fontsize: int = 12, - hist_color: str | tuple = "#a34f4f", - n_row: int = None, - n_col: int = None, -): - """Creates and plots publication-ready histograms of rank statistics for simulation-based calibration - (SBC) checks according to [1]. 
- - Any deviation from uniformity indicates miscalibration and thus poor convergence - of the networks or poor combination between generative model / networks. - - [1] Talts, S., Betancourt, M., Simpson, D., Vehtari, A., & Gelman, A. (2018). - Validating Bayesian inference algorithms with simulation-based calibration. - arXiv preprint arXiv:1804.06788. - - Parameters - ---------- - post_samples : np.ndarray of shape (n_data_sets, n_post_draws, n_params) - The posterior draws obtained from n_data_sets - prior_samples : np.ndarray of shape (n_data_sets, n_params) - The prior draws obtained for generating n_data_sets - param_names : list or None, optional, default: None - The parameter names for nice plot titles. Inferred if None - fig_size : tuple or None, optional, default : None - The figure size passed to the matplotlib constructor. Inferred if None - num_bins : int, optional, default: 10 - The number of bins to use for each marginal histogram - binomial_interval : float in (0, 1), optional, default: 0.99 - The width of the confidence interval for the binomial distribution - label_fontsize : int, optional, default: 16 - The font size of the y-label text - title_fontsize : int, optional, default: 18 - The font size of the title text - tick_fontsize : int, optional, default: 12 - The font size of the axis ticklabels - hist_color : str, optional, default '#a34f4f' - The color to use for the histogram body - n_row : int, optional, default: None - The number of rows for the subplots. Dynamically determined if None. - n_col : int, optional, default: None - The number of columns for the subplots. Dynamically determined if None. - - Returns - ------- - f : plt.Figure - the figure instance for optional saving - - Raises - ------ - ShapeError - If there is a deviation form the expected shapes of `post_samples` and `prior_samples`. 
- """ - - f, axarr, ax, n_row, n_col, n_params, param_names = preprocess(post_samples, prior_samples, fig_size=fig_size) - - # Determine the ratio of simulations to prior draws - n_sim, n_draws, _ = post_samples.shape - ratio = int(n_sim / n_draws) - - # Log a warning if N/B ratio recommended by Talts et al. (2018) < 20 - if ratio < 20: - logger = logging.getLogger() - logger.setLevel(logging.INFO) - logger.info( - f"The ratio of simulations / posterior draws should be > 20 " - + f"for reliable variance reduction, but your ratio is {ratio}.\ - Confidence intervals might be unreliable!" - ) - - # Set n_bins automatically, if nothing provided - if num_bins is None: - num_bins = int(ratio / 2) - # Attempt a fix if a single bin is determined so plot still makes sense - if num_bins == 1: - num_bins = 5 - - # Compute ranks (using broadcasting) - post_samples = K.constant(post_samples) - prior_samples = K.constant(prior_samples) - - # Adding an extra dimension to prior_samples using K.expand_dims - prior_samples_expanded = K.expand_dims(prior_samples, axis=1) - - # Performing element-wise comparison - comparison = K.less(post_samples, prior_samples_expanded) - - # Summing along the specified axis (axis=1) - ranks = K.sum(K.cast(comparison, dtype='float32'), axis=1) - # ranks = ops.sum(post_samples < prior_samples[:, np.newaxis, :], axis=1) - - # Compute confidence interval and mean - N = int(prior_samples.shape[0]) - # uniform distribution expected -> for all bins: equal probability - # p = 1 / num_bins that a rank lands in that bin - endpoints = binom.interval(binomial_interval, N, 1 / num_bins) - mean = N / num_bins # corresponds to binom.mean(N, 1 / num_bins) - - # Plot marginal histograms in a loop - for j in range(len(param_names)): - ax[j].axhspan(endpoints[0], endpoints[1], facecolor="gray", alpha=0.3) - ax[j].axhline(mean, color="gray", zorder=0, alpha=0.9) - sns.histplot(ranks[:, j], kde=False, ax=ax[j], color=hist_color, bins=num_bins, alpha=0.95) - 
ax[j].set_title(param_names[j], fontsize=title_fontsize) - ax[j].spines["right"].set_visible(False) - ax[j].spines["top"].set_visible(False) - ax[j].get_yaxis().set_ticks([]) - ax[j].set_ylabel("") - ax[j].tick_params(axis="both", which="major", labelsize=tick_fontsize) - ax[j].tick_params(axis="both", which="minor", labelsize=tick_fontsize) - - # Only add x-labels to the bottom row - bottom_row = axarr if n_row == 1 else axarr[0] if n_col == 1 else axarr[n_row - 1, :] - for _ax in bottom_row: - _ax.set_xlabel("Rank statistic", fontsize=label_fontsize) - - # Remove unused axes entirely - remove_unused_axes(axarr, n_params) - - f.tight_layout() - return f diff --git a/bayesflow/experimental/diagnostics/plot_z_score_contraction.py b/bayesflow/experimental/diagnostics/plot_z_score_contraction.py index 5454957ff..e69de29bb 100644 --- a/bayesflow/experimental/diagnostics/plot_z_score_contraction.py +++ b/bayesflow/experimental/diagnostics/plot_z_score_contraction.py @@ -1,116 +0,0 @@ - -import seaborn as sns -from ..utils.plotutils import preprocess, postprocess - - -def plot_z_score_contraction( - post_samples, - prior_samples, - param_names: list = None, - fig_size: tuple = None, - label_fontsize: int = 16, - title_fontsize: int = 18, - tick_fontsize: int = 12, - color: str | tuple = "#8f2727", - x_label: str = "Posterior contraction", - y_label: str = "Posterior z-score", - n_col: int = None, - n_row: int = None, -): - """Implements a graphical check for global model sensitivity by plotting the posterior - z-score over the posterior contraction for each set of posterior samples in ``post_samples`` - according to [1]. 
- - - The definition of the posterior z-score is: - - post_z_score = (posterior_mean - true_parameters) / posterior_std - - And the score is adequate if it centers around zero and spreads roughly in the interval [-3, 3] - - - The definition of posterior contraction is: - - post_contraction = 1 - (posterior_variance / prior_variance) - - In other words, the posterior contraction is a proxy for the reduction in uncertainty gained by - replacing the prior with the posterior. The ideal posterior contraction tends to 1. - Contraction near zero indicates that the posterior variance is almost identical to - the prior variance for the particular marginal parameter distribution. - - Note: Means and variances will be estimated via their sample-based estimators. - - [1] Schad, D. J., Betancourt, M., & Vasishth, S. (2021). - Toward a principled Bayesian workflow in cognitive science. - Psychological methods, 26(1), 103. - - Paper also available at https://arxiv.org/abs/1904.12765 - - Parameters - ---------- - post_samples : np.ndarray of shape (n_data_sets, n_post_draws, n_params) - The posterior draws obtained from n_data_sets - prior_samples : np.ndarray of shape (n_data_sets, n_params) - The prior draws (true parameters) obtained for generating the n_data_sets - param_names : list or None, optional, default: None - The parameter names for nice plot titles. Inferred if None - fig_size : tuple or None, optional, default : None - The figure size passed to the matplotlib constructor. Inferred if None. - label_fontsize : int, optional, default: 16 - The font size of the y-label text - title_fontsize : int, optional, default: 18 - The font size of the title text - tick_fontsize : int, optional, default: 12 - The font size of the axis ticklabels - color : str, optional, default: '#8f2727' - The color for the true vs. 
estimated scatter points and error bars - x_label : str, optional, default: Posterior contraction - The label for the x-axis - y_label : str, optional, default: Posterior z-score - The label for the y-axis - n_row : int, optional, default: None - The number of rows for the subplots. Dynamically determined if None. - n_col : int, optional, default: None - The number of columns for the subplots. Dynamically determined if None. - - Returns - ------- - f : plt.Figure - the figure instance for optional saving - - Raises - ------ - ShapeError - If there is a deviation from the expected shapes of ``post_samples`` and ``prior_samples``. - """ - - f, axarr, axarr_it, n_row, n_col, n_params, param_names = preprocess(post_samples, prior_samples, fig_size=fig_size) - - # Estimate posterior means and stds - post_means = post_samples.mean(axis=1) - post_stds = post_samples.std(axis=1, ddof=1) - post_vars = post_samples.var(axis=1, ddof=1) - - # Estimate prior variance - prior_vars = prior_samples.var(axis=0, keepdims=True, ddof=1) - - # Compute contraction - post_cont = 1 - (post_vars / prior_vars) - - # Compute posterior z score - z_score = (post_means - prior_samples) / post_stds - - # Loop and plot - for i, ax in enumerate(axarr_it): - if i >= n_params: - break - - ax.scatter(post_cont[:, i], z_score[:, i], color=color, alpha=0.5) - ax.set_title(param_names[i], fontsize=title_fontsize) - sns.despine(ax=ax) - ax.grid(alpha=0.5) - ax.tick_params(axis="both", which="major", labelsize=tick_fontsize) - ax.tick_params(axis="both", which="minor", labelsize=tick_fontsize) - ax.set_xlim([-0.05, 1.05]) - - postprocess(axarr, axarr_it, n_row, n_col, n_params, x_label, y_label, label_fontsize) - - f.tight_layout() - return f From e5c7ee8f3aea42eb3f3579eef5480b3dc5396a56 Mon Sep 17 00:00:00 2001 From: Jerry Date: Tue, 24 Sep 2024 18:10:43 -0400 Subject: [PATCH 04/22] Update with dev --- bayesflow/data_adapters/__init__.py | 6 + .../data_adapters/composite_data_adapter.py | 52 ++ 
.../concatenate_keys_data_adapter.py | 105 +++ bayesflow/data_adapters/data_adapter.py | 29 + .../flow_matching_data_adapter.py | 44 ++ .../data_adapters/transforms/__init__.py | 5 + .../transforms/constrain_bounded.py | 140 ++++ .../transforms/lambda_transform.py | 47 ++ .../transforms/numpy_transform.py | 38 + .../data_adapters/transforms/standardize.py | 62 ++ .../data_adapters/transforms/transform.py | 50 ++ .../simulators/composite_lambda_simulator.py | 20 + env_dev.yml | 25 + examples/TwoMoons_FlowMatching.ipynb | 670 ++++++++++++++++++ tests/test_data_adapters/__init__.py | 0 tests/test_data_adapters/conftest.py | 42 ++ .../test_data_adapters/test_data_adapters.py | 22 + 17 files changed, 1357 insertions(+) create mode 100644 bayesflow/data_adapters/__init__.py create mode 100644 bayesflow/data_adapters/composite_data_adapter.py create mode 100644 bayesflow/data_adapters/concatenate_keys_data_adapter.py create mode 100644 bayesflow/data_adapters/data_adapter.py create mode 100644 bayesflow/data_adapters/flow_matching_data_adapter.py create mode 100644 bayesflow/data_adapters/transforms/__init__.py create mode 100644 bayesflow/data_adapters/transforms/constrain_bounded.py create mode 100644 bayesflow/data_adapters/transforms/lambda_transform.py create mode 100644 bayesflow/data_adapters/transforms/numpy_transform.py create mode 100644 bayesflow/data_adapters/transforms/standardize.py create mode 100644 bayesflow/data_adapters/transforms/transform.py create mode 100644 bayesflow/simulators/composite_lambda_simulator.py create mode 100644 env_dev.yml create mode 100644 examples/TwoMoons_FlowMatching.ipynb create mode 100644 tests/test_data_adapters/__init__.py create mode 100644 tests/test_data_adapters/conftest.py create mode 100644 tests/test_data_adapters/test_data_adapters.py diff --git a/bayesflow/data_adapters/__init__.py b/bayesflow/data_adapters/__init__.py new file mode 100644 index 000000000..15a202cfc --- /dev/null +++ 
b/bayesflow/data_adapters/__init__.py @@ -0,0 +1,6 @@ +from . import transforms + +from .composite_data_adapter import CompositeDataAdapter +from .concatenate_keys_data_adapter import ConcatenateKeysDataAdapter +from .data_adapter import DataAdapter +from .flow_matching_data_adapter import FlowMatchingDataAdapter diff --git a/bayesflow/data_adapters/composite_data_adapter.py b/bayesflow/data_adapters/composite_data_adapter.py new file mode 100644 index 000000000..2bad78848 --- /dev/null +++ b/bayesflow/data_adapters/composite_data_adapter.py @@ -0,0 +1,52 @@ +from collections.abc import Mapping +from keras.saving import ( + deserialize_keras_object as deserialize, + register_keras_serializable as serializable, + serialize_keras_object as serialize, +) +import numpy as np + +from .data_adapter import DataAdapter + + +TRaw = Mapping[str, np.ndarray] +TProcessed = Mapping[str, np.ndarray] + + +@serializable(package="bayesflow.data_adapters") +class CompositeDataAdapter(DataAdapter[TRaw, TProcessed]): + """Composes multiple simple data adapters into a single more complex adapter.""" + + def __init__(self, data_adapters: Mapping[str, DataAdapter[TRaw, np.ndarray | None]]): + self.data_adapters = data_adapters + self.variable_counts = None + + def configure(self, raw_data: TRaw) -> TProcessed: + processed_data = {} + for key, data_adapter in self.data_adapters.items(): + data = data_adapter.configure(raw_data) + if data is not None: + processed_data[key] = data + + return processed_data + + def deconfigure(self, processed_data: TProcessed) -> TRaw: + raw_data = {} + for key, data_adapter in self.data_adapters.items(): + data = processed_data.get(key) + if data is not None: + raw_data |= data_adapter.deconfigure(data) + + return raw_data + + @classmethod + def from_config(cls, config: dict, custom_objects=None) -> "CompositeDataAdapter": + return cls( + { + key: deserialize(data_adapter, custom_objects) + for key, data_adapter in config.pop("data_adapters").items() + } + 
) + + def get_config(self) -> dict: + return {"data_adapters": {key: serialize(configurator) for key, configurator in self.data_adapters.items()}} diff --git a/bayesflow/data_adapters/concatenate_keys_data_adapter.py b/bayesflow/data_adapters/concatenate_keys_data_adapter.py new file mode 100644 index 000000000..bb5bd1551 --- /dev/null +++ b/bayesflow/data_adapters/concatenate_keys_data_adapter.py @@ -0,0 +1,105 @@ +from collections.abc import Mapping, Sequence +from keras.saving import ( + deserialize_keras_object as deserialize, + register_keras_serializable as serializable, + serialize_keras_object as serialize, +) +import numpy as np + +from .composite_data_adapter import CompositeDataAdapter +from .data_adapter import DataAdapter +from .transforms import Transform + +TRaw = Mapping[str, np.ndarray] +TProcessed = np.ndarray | None + + +@serializable(package="bayesflow.data_adapters") +class _ConcatenateKeysDataAdapter(DataAdapter[TRaw, TProcessed]): + """Concatenates data from multiple keys into a single tensor.""" + + def __init__(self, keys: Sequence[str]): + if not keys: + raise ValueError("At least one key must be provided.") + + self.keys = keys + self.data_shapes = None + self.is_configured = False + + def configure(self, raw_data: TRaw) -> TProcessed: + if not self.is_configured: + self.data_shapes = {key: value.shape for key, value in raw_data.items()} + self.is_configured = True + + # filter and reorder data + data = {} + for key in self.keys: + if key not in raw_data: + # if a key is missing, we cannot configure, so we return None + return None + + data[key] = raw_data[key] + + # concatenate all tensors + return np.concatenate(list(data.values()), axis=-1) + + def deconfigure(self, processed_data: TProcessed) -> TRaw: + if not self.is_configured: + raise ValueError("You must call `configure` at least once before calling `deconfigure`.") + + data = {} + start = 0 + for key in self.keys: + stop = start + self.data_shapes[key][-1] + data[key] = 
np.take(processed_data, list(range(start, stop)), axis=-1) + start = stop + + return data + + @classmethod + def from_config(cls, config: dict, custom_objects=None) -> "_ConcatenateKeysDataAdapter": + instance = cls(config["keys"]) + instance.data_shapes = config.get("data_shapes") + instance.is_configured = config.get("is_configured", False) + return instance + + def get_config(self) -> dict: + return {"keys": self.keys, "data_shapes": self.data_shapes, "is_configured": self.is_configured} + + +@serializable(package="bayesflow.data_adapters") +class ConcatenateKeysDataAdapter(CompositeDataAdapter): + """Concatenates data from multiple keys into multiple tensors.""" + + def __init__(self, *, transforms: Sequence[Transform] = None, **keys: Sequence[str]): + self.transforms = transforms or [] + self.keys = keys + configurators = {key: _ConcatenateKeysDataAdapter(value) for key, value in keys.items()} + super().__init__(configurators) + + def configure(self, raw_data): + data = raw_data + + for transform in self.transforms: + data = transform(data, inverse=False) + + data = super().configure(data) + + return data + + def deconfigure(self, processed_data): + data = processed_data + + data = super().deconfigure(data) + + for transform in reversed(self.transforms): + data = transform(data, inverse=True) + + return data + + @classmethod + def from_config(cls, config: Mapping[str, any], custom_objects=None) -> "ConcatenateKeysDataAdapter": + return cls(**config["keys"], transforms=deserialize(config.get("transforms"))) + + def get_config(self) -> dict[str, any]: + return {"keys": self.keys, "transforms": serialize(self.transforms)} diff --git a/bayesflow/data_adapters/data_adapter.py b/bayesflow/data_adapters/data_adapter.py new file mode 100644 index 000000000..0acf02ff6 --- /dev/null +++ b/bayesflow/data_adapters/data_adapter.py @@ -0,0 +1,29 @@ +from typing import Generic, TypeVar + + +TRaw = TypeVar("TRaw") +TProcessed = TypeVar("TProcessed") + + +class 
DataAdapter(Generic[TRaw, TProcessed]): + """Construct and deconstruct deep-learning ready data from and into raw data.""" + + def configure(self, raw_data: TRaw) -> TProcessed: + """Construct deep-learning ready data from raw data.""" + raise NotImplementedError + + def deconfigure(self, processed_data: TProcessed) -> TRaw: + """Reconstruct raw data from deep-learning ready processed data. + Note that configuration is not required to be bijective, so this method is only meant to be a 'best effort' + attempt, and may return incomplete or different raw data. + """ + raise NotImplementedError + + @classmethod + def from_config(cls, config: dict, custom_objects=None) -> "DataAdapter": + """Construct a data adapter from a configuration dictionary.""" + raise NotImplementedError + + def get_config(self) -> dict: + """Return a configuration dictionary.""" + raise NotImplementedError diff --git a/bayesflow/data_adapters/flow_matching_data_adapter.py b/bayesflow/data_adapters/flow_matching_data_adapter.py new file mode 100644 index 000000000..993ba6e64 --- /dev/null +++ b/bayesflow/data_adapters/flow_matching_data_adapter.py @@ -0,0 +1,44 @@ +from keras.saving import register_keras_serializable as serializable +import numpy as np +from typing import TypeVar + +from bayesflow.utils import optimal_transport + +from .data_adapter import DataAdapter + + +TRaw = TypeVar("TRaw") +TProcessed = dict[str, np.ndarray | tuple[np.ndarray, ...]] + + +@serializable(package="bayesflow.data_adapters") +class FlowMatchingDataAdapter(DataAdapter[TRaw, TProcessed]): + """Wraps a data adapter, applying all further processing required for Optimal Transport Flow Matching. + Useful to move these operations into a worker process, so as not to slow down training. 
+ """ + + def __init__(self, inner: DataAdapter[TRaw, dict[str, np.ndarray]], key: str = "inference_variables", **kwargs): + self.inner = inner + self.key = key + self.kwargs = kwargs + + def configure(self, raw_data: TRaw) -> TProcessed: + processed_data = self.inner.configure(raw_data) + + x1 = processed_data[self.key] + x0 = np.random.standard_normal(size=x1.shape).astype(x1.dtype) + t = np.random.uniform(size=x1.shape[0]).astype(x1.dtype) + + expand_index = [slice(None)] + [None] * (x1.ndim - 1) + t = t[tuple(expand_index)] + + x0, x1 = optimal_transport(x0, x1, **self.kwargs, numpy=True) + + x = t * x1 + (1 - t) * x0 + + target_velocity = x1 - x0 + + return processed_data | {self.key: (x0, x1, t, x, target_velocity)} + + def deconfigure(self, variables: TProcessed) -> TRaw: + return self.inner.deconfigure(variables) diff --git a/bayesflow/data_adapters/transforms/__init__.py b/bayesflow/data_adapters/transforms/__init__.py new file mode 100644 index 000000000..c2d633ea2 --- /dev/null +++ b/bayesflow/data_adapters/transforms/__init__.py @@ -0,0 +1,5 @@ +from .constrain_bounded import ConstrainBounded +from .lambda_transform import LambdaTransform +from .numpy_transform import NumpyTransform +from .standardize import Standardize +from .transform import Transform diff --git a/bayesflow/data_adapters/transforms/constrain_bounded.py b/bayesflow/data_adapters/transforms/constrain_bounded.py new file mode 100644 index 000000000..59a4b4338 --- /dev/null +++ b/bayesflow/data_adapters/transforms/constrain_bounded.py @@ -0,0 +1,140 @@ +from collections.abc import Sequence +from keras.saving import ( + deserialize_keras_object as deserialize, + register_keras_serializable as serializable, + serialize_keras_object as serialize, +) +import numpy as np + +from bayesflow.utils.numpy_utils import ( + inverse_sigmoid, + inverse_softplus, + sigmoid, + softplus, +) + +from .lambda_transform import LambdaTransform + + +@serializable(package="bayesflow.data_adapters") +class 
ConstrainBounded(LambdaTransform): + """Constrains a parameter with a lower and/or upper bound.""" + + def __init__( + self, + parameters: str | Sequence[str] | None = None, + /, + *, + lower: np.ndarray = None, + upper: np.ndarray = None, + method: str, + ): + self.lower = lower + self.upper = upper + self.method = method + + if lower is None and upper is None: + raise ValueError("At least one of 'lower' or 'upper' must be provided.") + + if lower is not None and upper is not None: + if np.any(lower >= upper): + raise ValueError("The lower bound must be strictly less than the upper bound.") + + # double bounded case + match method: + case "clip": + + def constrain(x): + return np.clip(x, lower, upper) + + def unconstrain(x): + # not bijective + return x + case "sigmoid": + + def constrain(x): + return (upper - lower) * sigmoid(x) + lower + + def unconstrain(x): + return inverse_sigmoid((x - lower) / (upper - lower)) + case str() as name: + raise ValueError(f"Unsupported method name for double bounded constraint: '{name}'.") + case other: + raise TypeError(f"Expected a method name, got {other!r}.") + else: + # single bounded case + if lower is not None: + match method: + case "clip": + + def constrain(x): + return np.clip(x, lower, np.inf) + + def unconstrain(x): + # not bijective + return x + case "softplus": + + def constrain(x): + return softplus(x) + lower + + def unconstrain(x): + return inverse_softplus(x - lower) + case "exp": + + def constrain(x): + return np.exp(x) + lower + + def unconstrain(x): + return np.log(x - lower) + case str() as name: + raise ValueError(f"Unsupported method name for single bounded constraint: '{name}'.") + case other: + raise TypeError(f"Expected a method name, got {other!r}.") + else: + match method: + case "clip": + + def constrain(x): + return np.clip(x, -np.inf, upper) + + def unconstrain(x): + # not bijective + return x + case "softplus": + + def constrain(x): + return -softplus(-x) + upper + + def unconstrain(x): + return 
-inverse_softplus(-(x - upper)) + case "exp": + + def constrain(x): + return -np.exp(-x) + upper + + def unconstrain(x): + return -np.log(-x + upper) + case str() as name: + raise ValueError(f"Unsupported method name for single bounded constraint: '{name}'.") + case other: + raise TypeError(f"Expected a method name, got {other!r}.") + + super().__init__(parameters, forward=unconstrain, inverse=constrain) + + @classmethod + def from_config(cls, config: dict, custom_objects=None) -> "ConstrainBounded": + return cls( + deserialize(config["parameters"], custom_objects), + lower=deserialize(config["lower"], custom_objects), + upper=deserialize(config["upper"], custom_objects), + method=deserialize(config["method"], custom_objects), + ) + + def get_config(self) -> dict: + return { + "parameters": serialize(self.parameters), + "lower": serialize(self.lower), + "upper": serialize(self.upper), + "method": serialize(self.method), + } diff --git a/bayesflow/data_adapters/transforms/lambda_transform.py b/bayesflow/data_adapters/transforms/lambda_transform.py new file mode 100644 index 000000000..9e2daec7e --- /dev/null +++ b/bayesflow/data_adapters/transforms/lambda_transform.py @@ -0,0 +1,47 @@ +from collections.abc import Sequence + +import numpy as np +from keras.saving import ( + deserialize_keras_object as deserialize, + register_keras_serializable as serializable, + serialize_keras_object as serialize, +) +from .transform import ElementwiseTransform + + +@serializable(package="bayesflow.data_adapters") +class LambdaTransform(ElementwiseTransform): + """ + Transforms a parameter using a pair of forward and inverse functions. + + Important note: This class is only serializable if the forward and inverse functions are serializable. + This most likely means you will have to pass the scope that the forward and inverse functions are contained in + to the `custom_objects` argument of the `deserialize` function when deserializing this class. 
+ """ + + def __init__(self, parameters: str | Sequence[str] | None = None, /, *, forward: callable, inverse: callable): + super().__init__(parameters) + + self._forward = forward + self._inverse = inverse + + @classmethod + def from_config(cls, config: dict, custom_objects=None) -> "LambdaTransform": + return cls( + deserialize(config["parameters"], custom_objects), + forward=deserialize(config["forward"], custom_objects), + inverse=deserialize(config["inverse"], custom_objects), + ) + + def forward(self, parameter_name: str, parameter_value: np.ndarray) -> np.ndarray: + return self._forward(parameter_value) + + def get_config(self) -> dict: + return { + "parameters": serialize(self.parameters), + "forward": serialize(self._forward), + "inverse": serialize(self._inverse), + } + + def inverse(self, parameter_name: str, parameter_value: np.ndarray) -> np.ndarray: + return self._inverse(parameter_value) diff --git a/bayesflow/data_adapters/transforms/numpy_transform.py b/bayesflow/data_adapters/transforms/numpy_transform.py new file mode 100644 index 000000000..ecb8f602c --- /dev/null +++ b/bayesflow/data_adapters/transforms/numpy_transform.py @@ -0,0 +1,38 @@ +from collections.abc import Sequence +from keras.saving import ( + deserialize_keras_object as deserialize, + register_keras_serializable as serializable, + serialize_keras_object as serialize, +) + +import numpy as np + +from .lambda_transform import LambdaTransform + + +@serializable(package="bayesflow.data_adapters") +class NumpyTransform(LambdaTransform): + """A LambdaTransform for numpy functions. 
Automatically serializable, unlike LambdaTransform.""" + + def __init__(self, parameters: str | Sequence[str] | None = None, /, *, forward: str, inverse: str): + self.forward_name = forward + self.inverse_name = inverse + + forward = getattr(np, forward) + inverse = getattr(np, inverse) + super().__init__(parameters, forward=forward, inverse=inverse) + + @classmethod + def from_config(cls, config: dict, custom_objects=None) -> "NumpyTransform": + return cls( + deserialize(config["parameters"], custom_objects), + forward=deserialize(config["forward_name"], custom_objects), + inverse=deserialize(config["inverse_name"], custom_objects), + ) + + def get_config(self) -> dict: + return { + "parameters": serialize(self.parameters), + "forward_name": serialize(self.forward_name), + "inverse_name": serialize(self.inverse_name), + } diff --git a/bayesflow/data_adapters/transforms/standardize.py b/bayesflow/data_adapters/transforms/standardize.py new file mode 100644 index 000000000..e19ea96a7 --- /dev/null +++ b/bayesflow/data_adapters/transforms/standardize.py @@ -0,0 +1,62 @@ +from collections.abc import Mapping, Sequence +from keras.saving import ( + deserialize_keras_object as deserialize, + register_keras_serializable as serializable, + serialize_keras_object as serialize, +) +import numpy as np + +from .transform import ElementwiseTransform + + +@serializable(package="bayesflow.data_adapters") +class Standardize(ElementwiseTransform): + """Normalizes a parameter to have zero mean and unit standard deviation. + By default, this is lazily initialized; the mean and standard deviation are computed from the first batch of data. + For eager initialization, pass the mean and standard deviation to the constructor. 
+ """ + + def __init__( + self, + parameters: str | Sequence[str] | None = None, + /, + *, + means: Mapping[str, np.ndarray] = None, + stds: Mapping[str, np.ndarray] = None, + ): + super().__init__(parameters) + self.means = means or {} + self.stds = stds or {} + + @classmethod + def from_config(cls, config: dict, custom_objects=None) -> "Standardize": + return cls( + deserialize(config["parameters"], custom_objects), + means=deserialize(config["means"], custom_objects), + stds=deserialize(config["stds"], custom_objects), + ) + + def forward(self, parameter_name: str, parameter_value: np.ndarray) -> np.ndarray: + if parameter_name not in self.means: + self.means[parameter_name] = np.mean( + parameter_value, axis=tuple(range(parameter_value.ndim)), keepdims=True + ) + if parameter_name not in self.stds: + self.stds[parameter_name] = np.std(parameter_value, axis=tuple(range(parameter_value.ndim)), keepdims=True) + + return (parameter_value - self.means[parameter_name]) / self.stds[parameter_name] + + def get_config(self) -> dict: + return { + "parameters": serialize(self.parameters), + "means": serialize(self.means), + "stds": serialize(self.stds), + } + + def inverse(self, parameter_name: str, parameter_value: np.ndarray) -> np.ndarray: + if not self.means or not self.stds: + raise ValueError( + f"Cannot call `inverse` before calling `forward` at least once for parameter {parameter_name}." 
+ ) + + return parameter_value * self.stds[parameter_name] + self.means[parameter_name] diff --git a/bayesflow/data_adapters/transforms/transform.py b/bayesflow/data_adapters/transforms/transform.py new file mode 100644 index 000000000..6dcaf5ea0 --- /dev/null +++ b/bayesflow/data_adapters/transforms/transform.py @@ -0,0 +1,50 @@ +from collections.abc import Sequence +import numpy as np + + +class Transform: + """Implements typical data transformations that can be applied as part of the adapter pipeline.""" + + def __call__(self, data: dict[str, np.ndarray], inverse: bool = False) -> dict[str, np.ndarray]: + raise NotImplementedError + + +class ElementwiseTransform(Transform): + """Intermediate layer for transforms that are applied on a per-parameter basis.""" + + def __init__(self, parameters: str | Sequence[str] | None = None, /): + self.parameters = parameters + + def __call__(self, data: dict[str, np.ndarray], inverse: bool = False) -> dict[str, np.ndarray]: + """Apply the transform to the data""" + data = data.copy() + + if self.parameters is None: + # apply to all parameters + parameters = list(data.keys()) + elif isinstance(self.parameters, str): + # apply just to this parameter + parameters = [self.parameters] + else: + # apply to all given parameters + parameters = self.parameters + + for parameter in parameters: + if data.get(parameter) is None: + # skip when in partial configuration + continue + + if inverse: + data[parameter] = self.inverse(parameter, data[parameter]) + else: + data[parameter] = self.forward(parameter, data[parameter]) + + return data + + def forward(self, parameter_name: str, parameter_value: np.ndarray) -> np.ndarray: + """Implements the forward direction of the transform, i.e., user/data space -> network/latent space""" + raise NotImplementedError + + def inverse(self, parameter_name: str, parameter_value: np.ndarray) -> np.ndarray: + """Implements the inverse direction of the transform, i.e., network/latent space -> userdata 
space""" + raise NotImplementedError diff --git a/bayesflow/simulators/composite_lambda_simulator.py b/bayesflow/simulators/composite_lambda_simulator.py new file mode 100644 index 000000000..46d886cb7 --- /dev/null +++ b/bayesflow/simulators/composite_lambda_simulator.py @@ -0,0 +1,20 @@ +from collections.abc import Sequence +import numpy as np + +from bayesflow.types import Shape + +from .simulator import Simulator +from .composite_simulator import CompositeSimulator +from .lambda_simulator import LambdaSimulator + + +class CompositeLambdaSimulator(Simulator): + """Combines multiple lambda simulators into one, sequentially.""" + + def __init__(self, sample_fns: Sequence[callable], expand_outputs: bool = True, **kwargs): + self.inner = CompositeSimulator( + [LambdaSimulator(fn, **kwargs) for fn in sample_fns], expand_outputs=expand_outputs + ) + + def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]: + return self.inner.sample(batch_shape, **kwargs) diff --git a/env_dev.yml b/env_dev.yml new file mode 100644 index 000000000..2526d53a7 --- /dev/null +++ b/env_dev.yml @@ -0,0 +1,25 @@ +# "dev" conda envs are to be used by devs in setting their local environments +name: bf-dev +channels: +- conda-forge +- defaults +dependencies: +- aesara +- flake8>=6.1.0 +- ipywidgets +- jupyter +- jupytext +- matplotlib +- minikanren +- mypy>=1.5.1 +- pandas +- pip +- pytest>=7.2.0 +- pytest-xdist>=3.5.0 +- pytest-cov>=4.1.0 +- scikit-learn +- seaborn +- tensorflow<2.16,>=2.10.1 +- tensorflow-probability<0.24,>=0.17 +- tox>=4.10.0 +- watermark diff --git a/examples/TwoMoons_FlowMatching.ipynb b/examples/TwoMoons_FlowMatching.ipynb new file mode 100644 index 000000000..4fcfd655d --- /dev/null +++ b/examples/TwoMoons_FlowMatching.ipynb @@ -0,0 +1,670 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "009b6adf", + "metadata": {}, + "source": [ + "# Two Moons: Tackling Bimodal Posteriors" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": 
"d5f88a59", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:46.551814Z", + "start_time": "2024-09-23T14:39:46.032170Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "CUDA backend failed to initialize: Unable to use CUDA because of the following issues with CUDA components:\n", + "Outdated cuDNN installation found.\n", + "Version JAX was built against: 8907\n", + "Minimum supported: 9100\n", + "Installed version: 8907\n", + "The local installation version must be no lower than 9100..(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" + ] + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import seaborn as sns\n", + "\n", + "# ensure the backend is set\n", + "import os\n", + "if \"KERAS_BACKEND\" not in os.environ:\n", + " # set this to \"torch\", \"tensorflow\", or \"jax\"\n", + " os.environ[\"KERAS_BACKEND\"] = \"jax\"\n", + "\n", + "import keras\n", + "\n", + "# for BayesFlow devs: this ensures that the latest dev version can be found\n", + "import sys\n", + "sys.path.append('../')\n", + "\n", + "import bayesflow as bf" + ] + }, + { + "cell_type": "markdown", + "id": "c63b26ba", + "metadata": {}, + "source": [ + "## Simulator" + ] + }, + { + "cell_type": "markdown", + "id": "9525ffd7", + "metadata": {}, + "source": [ + "This example will demonstrate amortized estimation of a somewhat strange Bayesian model, whose posterior evaluated at the origin $x = (0, 0)$ of the \"data\" will resemble two crescent moons. 
The forward process is a noisy non-linear transformation on a 2D plane:\n", + "\n", + "$$\n", + "\\begin{align}\n", + "x_1 &= -|\\theta_1 + \\theta_2|/\\sqrt{2} + r \\cos(\\alpha) + 0.25\\\\\n", + "x_2 &= (-\\theta_1 + \\theta_2)/\\sqrt{2} + r\\sin{\\alpha}\n", + "\\end{align}\n", + "$$\n", + "\n", + "with $x = (x_1, x_2)$ playing the role of \"observables\" (data to be learned from), $\\alpha \\sim \\text{Uniform}(-\\pi/2, \\pi/2)$, and $r \\sim \\text{Normal}(0.1, 0.01)$ being latent variables creating noise in the data, and $\\theta = (\\theta_1, \\theta_2)$ being the parameters that we will later seek to infer from new $x$. We set their priors to\n", + "\n", + "$$\n", + "\\begin{align}\n", + "\\theta_1, \\theta_2 \\sim \\text{Uniform}(-1, 1).\n", + "\\end{align}\n", + "$$\n", + "\n", + "This model is typically used for benchmarking simulation-based inference (SBI) methods (see https://arxiv.org/pdf/2101.04653) and any method for amortized Bayesian inference should be capable of recovering the two moons posterior *without* using a gazillion of simulations. Note, that this is a considerably harder task than modeling the common unconditional two moons data set used often in the context of normalizing flows." + ] + }, + { + "cell_type": "markdown", + "id": "21bf228e706a010", + "metadata": {}, + "source": [ + "BayesFlow offers many ways to define your data generating process. Here, we use sequential functions to build a simulator object for online training. Within this composite simulator, each function has access to the outputs of the previous functions. This effectively allows you to define any generative graph." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "f761b142a0e1da66", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:46.703381Z", + "start_time": "2024-09-23T14:39:46.700649Z" + } + }, + "outputs": [], + "source": [ + "def alpha_prior():\n", + " alpha = np.random.uniform(-np.pi / 2, np.pi / 2)\n", + " return dict(alpha=alpha)\n", + "\n", + "def r_prior():\n", + " r = np.random.normal(0.1, 0.01)\n", + " return dict(r=r)\n", + "\n", + "def theta_prior():\n", + " theta = np.random.uniform(-1, 1, 2)\n", + " return dict(theta=theta)\n", + "\n", + "def forward_model(theta, alpha, r):\n", + " x1 = -np.abs(theta[0] + theta[1]) / np.sqrt(2) + r * np.cos(alpha) + 0.25\n", + " x2 = (-theta[0] + theta[1]) / np.sqrt(2) + r * np.sin(alpha)\n", + " return dict(x=np.array([x1, x2]))" + ] + }, + { + "cell_type": "markdown", + "id": "722cb773", + "metadata": {}, + "source": [ + "Within the composite simulator, every simulator has access to the outputs of the previous simulators in the list. For example, the last simulator `forward_model` has access to the outputs of the three other simulators." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "4b89c861527c13b8", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:46.747091Z", + "start_time": "2024-09-23T14:39:46.744830Z" + } + }, + "outputs": [], + "source": [ + "simulator = bf.simulators.CompositeLambdaSimulator([alpha_prior, r_prior, theta_prior, forward_model])" + ] + }, + { + "cell_type": "markdown", + "id": "f6e1eb5777c59eba", + "metadata": {}, + "source": [ + "Let's generate some data to see what the simulator does:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "e6218e61d529e357", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:46.798575Z", + "start_time": "2024-09-23T14:39:46.790581Z" + } + }, + "outputs": [], + "source": [ + "# generate 128 random draws from the joint distribution p(r, alpha, theta, x)\n", + "sample_data = simulator.sample((128,))" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "46174ccb0167026c", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:46.854911Z", + "start_time": "2024-09-23T14:39:46.852129Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Type of sample_data:\n", + "\t \n", + "Keys of sample_data:\n", + "\t dict_keys(['alpha', 'r', 'theta', 'x'])\n", + "Types of sample_data values:\n", + "\t {'alpha': , 'r': , 'theta': , 'x': }\n", + "Shapes of sample_data values:\n", + "\t {'alpha': (128, 1), 'r': (128, 1), 'theta': (128, 2), 'x': (128, 2)}\n" + ] + } + ], + "source": [ + "print(\"Type of sample_data:\\n\\t\", type(sample_data))\n", + "print(\"Keys of sample_data:\\n\\t\", sample_data.keys())\n", + "print(\"Types of sample_data values:\\n\\t\", {k: type(v) for k, v in sample_data.items()})\n", + "print(\"Shapes of sample_data values:\\n\\t\", {k: v.shape for k, v in sample_data.items()})" + ] + }, + { + "cell_type": "markdown", + "id": "17f158bd2d7abf75", + "metadata": {}, + "source": [ + "BayesFlow also provides 
this simulator and a collection of others in the `bayesflow.benchmarks` module." + ] + }, + { + "cell_type": "markdown", + "id": "fee88fcfd7a373b0", + "metadata": {}, + "source": [ + "## Data Adapter\n", + "\n", + "The next step is to tell BayesFlow how to deal with all the simulated variables. You may also think of this as informing BayesFlow about the data flow, i.e., which variables go into which network.\n", + "\n", + "For this example, we want to learn the posterior distribution $p(\\theta | x)$, so we **infer** $\\theta$, **conditioning** on $x$." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "c9637c576d4ad4e5", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:46.905081Z", + "start_time": "2024-09-23T14:39:46.903091Z" + } + }, + "outputs": [], + "source": [ + "data_adapter = bf.ContinuousApproximator.build_data_adapter(\n", + " inference_variables=[\"theta\"],\n", + " inference_conditions=[\"x\"],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "254e287b2bccdad", + "metadata": {}, + "source": [ + "## Dataset\n", + "\n", + "For this example, we will sample our training data ahead of time and use offline training with a `bf.datasets.OfflineDataset`.\n", + "\n", + "This makes the training process faster, since we avoid repeated sampling. If you want to use online training, you can use an `OnlineDataset` analogously, or just pass your simulator directly to `approximator.fit()`!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "39cb5a1c9824246f", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:46.950573Z", + "start_time": "2024-09-23T14:39:46.948624Z" + } + }, + "outputs": [], + "source": [ + "num_training_batches = 1024\n", + "num_validation_batches = 256\n", + "batch_size = 128" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "9dee7252ef99affa", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:53.268860Z", + "start_time": "2024-09-23T14:39:46.994697Z" + } + }, + "outputs": [], + "source": [ + "training_samples = simulator.sample((num_training_batches * batch_size,))\n", + "validation_samples = simulator.sample((num_validation_batches * batch_size,))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "51045bbed88cb5c2", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:53.281170Z", + "start_time": "2024-09-23T14:39:53.275921Z" + } + }, + "outputs": [], + "source": [ + "training_dataset = bf.datasets.OfflineDataset(training_samples, batch_size=batch_size, data_adapter=data_adapter)\n", + "validation_dataset = bf.datasets.OfflineDataset(validation_samples, batch_size=batch_size, data_adapter=data_adapter)" + ] + }, + { + "cell_type": "markdown", + "id": "2d4c6eb0", + "metadata": {}, + "source": [ + "## Traing a neural network to approximate all posteriors\n", + "\n", + "The next step is to set up the neural network that will approximate the posterior $p(\\theta|x)$.\n", + "\n", + "We choose Flow Matching as the architecture for this example, as it can deal well with the multimodal nature of the posteriors that some observables imply." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "09206e6f", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:53.339590Z", + "start_time": "2024-09-23T14:39:53.319852Z" + } + }, + "outputs": [], + "source": [ + "inference_network = bf.networks.FlowMatching(\n", + " subnet=\"mlp\",\n", + " subnet_kwargs=dict(\n", + " depth=6,\n", + " width=256,\n", + " ),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "851e522f", + "metadata": {}, + "source": [ + "This inference network is just a general Flow Matching architecture, not yet adapted to the specific inference task at hand (i.e., posterior appproximation). To achieve this adaptation, we combine the network with our data adapter, which together form an `approximator`. In this case, we need a `ContinuousApproximator` since the target we want to approximate is the posterior of the *continuous* parameter vector $\\theta$." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "96ca6ffa", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:53.371691Z", + "start_time": "2024-09-23T14:39:53.369375Z" + } + }, + "outputs": [], + "source": [ + "approximator = bf.ContinuousApproximator(\n", + " inference_network=inference_network,\n", + " data_adapter=data_adapter,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "566264eadc76c2c", + "metadata": {}, + "source": [ + "### Optimizer and Learning Rate\n", + "For this example, it is sufficient to use a static learning rate. In practice, you may want to use a learning rate schedule, like [cosine decay](https://keras.io/api/optimizers/learning_rate_schedules/cosine_decay/)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "e8d7e053", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:53.433012Z", + "start_time": "2024-09-23T14:39:53.415903Z" + } + }, + "outputs": [], + "source": [ + "learning_rate = 1e-4\n", + "optimizer = keras.optimizers.Adam(learning_rate=learning_rate)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "51808fcd560489ac", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:39:53.476089Z", + "start_time": "2024-09-23T14:39:53.466001Z" + } + }, + "outputs": [], + "source": [ + "approximator.compile(optimizer=optimizer)" + ] + }, + { + "cell_type": "markdown", + "id": "708b1303", + "metadata": {}, + "source": [ + "### Training\n", + "\n", + "We are ready to train our deep posterior approximator on the two moons example. We pass the dataset object to the `fit` method and watch as Bayesflow trains." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "0f496bda", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:42:36.067393Z", + "start_time": "2024-09-23T14:39:53.513436Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:bayesflow:Fitting on dataset instance of OfflineDataset.\n", + "INFO:bayesflow:Building on a test batch.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/30\n", + "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 6ms/step - loss: 0.6938 - loss/inference_loss: 0.6938 - val_loss: 0.5508 - val_loss/inference_loss: 0.5508\n", + "Epoch 2/30\n", + "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.6250 - loss/inference_loss: 0.6250 - val_loss: 0.6023 - val_loss/inference_loss: 0.6023\n", + "Epoch 3/30\n", + "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m 
\u001b[1m5s\u001b[0m 5ms/step - loss: 0.6056 - loss/inference_loss: 0.6056 - val_loss: 0.4454 - val_loss/inference_loss: 0.4454\n", + "Epoch 4/30\n", + "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.6006 - loss/inference_loss: 0.6006 - val_loss: 0.5079 - val_loss/inference_loss: 0.5079\n", + "Epoch 5/30\n", + "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.6020 - loss/inference_loss: 0.6020 - val_loss: 0.5414 - val_loss/inference_loss: 0.5414\n", + "Epoch 6/30\n", + "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5973 - loss/inference_loss: 0.5973 - val_loss: 0.6961 - val_loss/inference_loss: 0.6961\n", + "Epoch 7/30\n", + "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5874 - loss/inference_loss: 0.5874 - val_loss: 0.5399 - val_loss/inference_loss: 0.5399\n", + "Epoch 8/30\n", + "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5939 - loss/inference_loss: 0.5939 - val_loss: 0.4877 - val_loss/inference_loss: 0.4877\n", + "Epoch 9/30\n", + "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5841 - loss/inference_loss: 0.5841 - val_loss: 0.5115 - val_loss/inference_loss: 0.5115\n", + "Epoch 10/30\n", + "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5827 - loss/inference_loss: 0.5827 - val_loss: 0.5383 - val_loss/inference_loss: 0.5383\n", + "Epoch 11/30\n", + "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5807 - loss/inference_loss: 0.5807 - 
val_loss: 0.4411 - val_loss/inference_loss: 0.4411\n", + "Epoch 12/30\n", + "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5774 - loss/inference_loss: 0.5774 - val_loss: 0.5844 - val_loss/inference_loss: 0.5844\n", + "Epoch 13/30\n", + "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5813 - loss/inference_loss: 0.5813 - val_loss: 0.8106 - val_loss/inference_loss: 0.8106\n", + "Epoch 14/30\n", + "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 5ms/step - loss: 0.5756 - loss/inference_loss: 0.5756 - val_loss: 0.4150 - val_loss/inference_loss: 0.4150\n", + "Epoch 15/30\n", + "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 5ms/step - loss: 0.5761 - loss/inference_loss: 0.5761 - val_loss: 0.5451 - val_loss/inference_loss: 0.5451\n", + "Epoch 16/30\n", + "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5747 - loss/inference_loss: 0.5747 - val_loss: 0.6248 - val_loss/inference_loss: 0.6248\n", + "Epoch 17/30\n", + "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5666 - loss/inference_loss: 0.5666 - val_loss: 0.4689 - val_loss/inference_loss: 0.4689\n", + "Epoch 18/30\n", + "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5705 - loss/inference_loss: 0.5705 - val_loss: 0.3853 - val_loss/inference_loss: 0.3853\n", + "Epoch 19/30\n", + "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5739 - loss/inference_loss: 0.5739 - val_loss: 0.5055 - val_loss/inference_loss: 0.5055\n", + "Epoch 
20/30\n", + "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5688 - loss/inference_loss: 0.5688 - val_loss: 0.5032 - val_loss/inference_loss: 0.5032\n", + "Epoch 21/30\n", + "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5663 - loss/inference_loss: 0.5663 - val_loss: 0.5237 - val_loss/inference_loss: 0.5237\n", + "Epoch 22/30\n", + "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5650 - loss/inference_loss: 0.5650 - val_loss: 0.3955 - val_loss/inference_loss: 0.3955\n", + "Epoch 23/30\n", + "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5666 - loss/inference_loss: 0.5666 - val_loss: 0.7317 - val_loss/inference_loss: 0.7317\n", + "Epoch 24/30\n", + "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5632 - loss/inference_loss: 0.5632 - val_loss: 0.6094 - val_loss/inference_loss: 0.6094\n", + "Epoch 25/30\n", + "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5701 - loss/inference_loss: 0.5701 - val_loss: 0.5721 - val_loss/inference_loss: 0.5721\n", + "Epoch 26/30\n", + "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5711 - loss/inference_loss: 0.5711 - val_loss: 0.6184 - val_loss/inference_loss: 0.6184\n", + "Epoch 27/30\n", + "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5672 - loss/inference_loss: 0.5672 - val_loss: 0.6326 - val_loss/inference_loss: 0.6326\n", + "Epoch 28/30\n", + "\u001b[1m1024/1024\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5585 - loss/inference_loss: 0.5585 - val_loss: 0.6209 - val_loss/inference_loss: 0.6209\n", + "Epoch 29/30\n", + "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5594 - loss/inference_loss: 0.5594 - val_loss: 0.5672 - val_loss/inference_loss: 0.5672\n", + "Epoch 30/30\n", + "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5597 - loss/inference_loss: 0.5597 - val_loss: 0.4648 - val_loss/inference_loss: 0.4648\n" + ] + } + ], + "source": [ + "history = approximator.fit(\n", + " epochs=30,\n", + " dataset=training_dataset,\n", + " validation_data=validation_dataset,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "b90a6062", + "metadata": {}, + "source": [ + "## Validation" + ] + }, + { + "cell_type": "markdown", + "id": "ca62b21d", + "metadata": {}, + "source": [ + "### Two Moons Posterior\n", + "\n", + "The two moons posterior at point $x = (0, 0)$ should resemble two crescent shapes. Below, we plot the corresponding posterior samples and posterior density. \n", + "These results suggest that our flow matching setup can approximate the expected analytical posterior well. (Note that you can achieve an even better fit if you use online training and more epochs.)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "8562caeb", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:42:38.584554Z", + "start_time": "2024-09-23T14:42:36.076923Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(-0.5, 0.5)" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Set the number of posterior draws you want to get\n", + "num_samples = 5000\n", + "\n", + "# Obtain samples from amortized posterior\n", + "conditions = {\"x\": np.array([[0.0, 0.0]]).astype(\"float32\")}\n", + "samples_at_origin = approximator.sample(conditions=conditions, num_samples=num_samples)[\"theta\"]\n", + "\n", + "# Prepare figure\n", + "f, axes = plt.subplots(1, figsize=(6, 6))\n", + "\n", + "# Plot samples\n", + "axes.scatter(samples_at_origin[0, :, 0], samples_at_origin[0, :, 1], color=\"#153c7a\", alpha=0.75, s=0.5)\n", + "sns.despine(ax=axes)\n", + "axes.set_title(r\"Posterior samples at origin $x=(0, 0)$\")\n", + "axes.grid(alpha=0.3)\n", + "axes.set_aspect(\"equal\", adjustable=\"box\")\n", + "axes.set_xlim([-0.5, 0.5])\n", + "axes.set_ylim([-0.5, 0.5])" + ] + }, + { + "cell_type": "markdown", + "id": "01821d24", + "metadata": {}, + "source": [ + "\n", + "The posterior looks as we have expected in this case. However, in general, we do not know how the posterior is supposed to look like for any specific dataset. As such, we need diagnostics that validate the correctness of the inferred posterior. One such diagnostic is simulation-based calibration(SBC), which we can apply for free due to amortization. For more details on SBC and diagnostic plots, see:\n", + "\n", + "1. Talts, S., Betancourt, M., Simpson, D., Vehtari, A., & Gelman, A. (2018). Validating Bayesian inference algorithms with simulation-based calibration. *arXiv preprint*.\n", + "2. Säilynoja, T., Bürkner, P. C., & Vehtari, A. (2022). Graphical test for discrete uniformity and its applications in goodness-of-fit evaluation and multiple sample comparison. *Statistics and Computing*." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "f76289b3", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:42:38.595234Z", + "start_time": "2024-09-23T14:42:38.593542Z" + } + }, + "outputs": [], + "source": [ + "# Will be added soon." + ] + }, + { + "cell_type": "markdown", + "id": "66248a2f", + "metadata": {}, + "source": [ + "## Further Experimentation " + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "89dcb727", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-23T14:42:38.639240Z", + "start_time": "2024-09-23T14:42:38.637439Z" + } + }, + "outputs": [], + "source": [ + "# Will be added soon." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": true, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": true, + "toc_position": { + "height": "calc(100% - 180px)", + "left": "10px", + "top": "150px", + "width": "165px" + }, + "toc_section_display": true, + "toc_window_display": true + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tests/test_data_adapters/__init__.py b/tests/test_data_adapters/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_data_adapters/conftest.py b/tests/test_data_adapters/conftest.py new file mode 100644 index 000000000..465bbf268 --- /dev/null +++ b/tests/test_data_adapters/conftest.py @@ -0,0 +1,42 @@ +import numpy as np +import pytest + + +def forward_transform(x): + return x + 1 + + +def inverse_transform(x): + return x - 1 + + +@pytest.fixture() 
+def custom_objects(): + return globals() | np.__dict__ + + +@pytest.fixture() +def data_adapter(): + from bayesflow.data_adapters import ConcatenateKeysDataAdapter + from bayesflow.data_adapters.transforms import LambdaTransform, Standardize + + return ConcatenateKeysDataAdapter( + x=["x1", "x2"], + y=["y1", "y2"], + transforms=[ + # normalize all parameters + Standardize(), + # use a lambda transform with global functions + LambdaTransform("x2", forward=forward_transform, inverse=inverse_transform), + ], + ) + + +@pytest.fixture() +def random_data(): + return { + "x1": np.random.standard_normal(size=(32, 1)).astype("float32"), + "x2": np.random.standard_normal(size=(32, 1)).astype("float32"), + "y1": np.random.standard_normal(size=(32, 2)).astype("float32"), + "y2": np.random.standard_normal(size=(32, 2)).astype("float32"), + } diff --git a/tests/test_data_adapters/test_data_adapters.py b/tests/test_data_adapters/test_data_adapters.py new file mode 100644 index 000000000..f5ff155d9 --- /dev/null +++ b/tests/test_data_adapters/test_data_adapters.py @@ -0,0 +1,22 @@ +import keras +from keras.saving import ( + deserialize_keras_object as deserialize, + serialize_keras_object as serialize, +) + + +def test_cycle_consistency(data_adapter, random_data): + processed = data_adapter.configure(random_data) + deprocessed = data_adapter.deconfigure(processed) + + for key, value in random_data.items(): + assert key in deprocessed + assert keras.ops.all(keras.ops.isclose(value, deprocessed[key])) + + +def test_serialize_deserialize(data_adapter, custom_objects): + serialized = serialize(data_adapter) + deserialized = deserialize(serialized, custom_objects) + reserialized = serialize(deserialized) + + assert reserialized == serialized From 0b65cf22a5151cf43b33d6b0f92f8d2db7587b85 Mon Sep 17 00:00:00 2001 From: Jerry Date: Wed, 25 Sep 2024 01:01:14 -0400 Subject: [PATCH 05/22] Temporarily migrate plot_distribution_2d --- .../diagnostics/plot_distribution_2d.py | 0 
.../diagnostics/plot_calibration_curves.py | 0 .../diagnostics/plot_confusion_matrix.py | 0 .../diagnostics/plot_latent_space_2d.py | 0 .../experimental/diagnostics/plot_losses.py | 0 .../diagnostics/plot_mmd_hypothesis_test.py | 0 .../diagnostics/plot_posterior_2d.py | 0 .../experimental/diagnostics/plot_prior_2d.py | 0 .../experimental/diagnostics/plot_recovery.py | 0 .../experimental/diagnostics/plot_sbc_ecdf.py | 0 .../diagnostics/plot_sbc_histograms.py | 0 .../diagnostics/plot_z_score_contraction.py | 0 tests/test_diagnostics/__init__.py | 0 tests/test_diagnostics/test_diagnostics.py | 16 ---------------- 14 files changed, 16 deletions(-) rename bayesflow/{experimental => }/diagnostics/plot_distribution_2d.py (100%) delete mode 100644 bayesflow/experimental/diagnostics/plot_calibration_curves.py delete mode 100644 bayesflow/experimental/diagnostics/plot_confusion_matrix.py delete mode 100644 bayesflow/experimental/diagnostics/plot_latent_space_2d.py delete mode 100644 bayesflow/experimental/diagnostics/plot_losses.py delete mode 100644 bayesflow/experimental/diagnostics/plot_mmd_hypothesis_test.py delete mode 100644 bayesflow/experimental/diagnostics/plot_posterior_2d.py delete mode 100644 bayesflow/experimental/diagnostics/plot_prior_2d.py delete mode 100644 bayesflow/experimental/diagnostics/plot_recovery.py delete mode 100644 bayesflow/experimental/diagnostics/plot_sbc_ecdf.py delete mode 100644 bayesflow/experimental/diagnostics/plot_sbc_histograms.py delete mode 100644 bayesflow/experimental/diagnostics/plot_z_score_contraction.py delete mode 100644 tests/test_diagnostics/__init__.py delete mode 100644 tests/test_diagnostics/test_diagnostics.py diff --git a/bayesflow/experimental/diagnostics/plot_distribution_2d.py b/bayesflow/diagnostics/plot_distribution_2d.py similarity index 100% rename from bayesflow/experimental/diagnostics/plot_distribution_2d.py rename to bayesflow/diagnostics/plot_distribution_2d.py diff --git 
a/bayesflow/experimental/diagnostics/plot_calibration_curves.py b/bayesflow/experimental/diagnostics/plot_calibration_curves.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/bayesflow/experimental/diagnostics/plot_confusion_matrix.py b/bayesflow/experimental/diagnostics/plot_confusion_matrix.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/bayesflow/experimental/diagnostics/plot_latent_space_2d.py b/bayesflow/experimental/diagnostics/plot_latent_space_2d.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/bayesflow/experimental/diagnostics/plot_losses.py b/bayesflow/experimental/diagnostics/plot_losses.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/bayesflow/experimental/diagnostics/plot_mmd_hypothesis_test.py b/bayesflow/experimental/diagnostics/plot_mmd_hypothesis_test.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/bayesflow/experimental/diagnostics/plot_posterior_2d.py b/bayesflow/experimental/diagnostics/plot_posterior_2d.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/bayesflow/experimental/diagnostics/plot_prior_2d.py b/bayesflow/experimental/diagnostics/plot_prior_2d.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/bayesflow/experimental/diagnostics/plot_recovery.py b/bayesflow/experimental/diagnostics/plot_recovery.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/bayesflow/experimental/diagnostics/plot_sbc_ecdf.py b/bayesflow/experimental/diagnostics/plot_sbc_ecdf.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/bayesflow/experimental/diagnostics/plot_sbc_histograms.py b/bayesflow/experimental/diagnostics/plot_sbc_histograms.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/bayesflow/experimental/diagnostics/plot_z_score_contraction.py b/bayesflow/experimental/diagnostics/plot_z_score_contraction.py deleted file mode 100644 index e69de29bb..000000000 diff --git 
a/tests/test_diagnostics/__init__.py b/tests/test_diagnostics/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/test_diagnostics/test_diagnostics.py b/tests/test_diagnostics/test_diagnostics.py deleted file mode 100644 index 4018ee577..000000000 --- a/tests/test_diagnostics/test_diagnostics.py +++ /dev/null @@ -1,16 +0,0 @@ - -import keras -import pytest - -from bayesflow.experimental.diagnostics import ( - plot_distribution_2d, - plot_prior_2d, - plot_posterior_2d -) - - -@pytest.fixture() -def test_plot_distribution_2d(): - pass - -#TODO \ No newline at end of file From 24362dbc611ef5ca78094acb7bfccf0ee2701744 Mon Sep 17 00:00:00 2001 From: Jerry Date: Wed, 9 Oct 2024 01:32:39 -0400 Subject: [PATCH 06/22] Add preliminary plot utils --- bayesflow/utils/exceptions/__init__.py | 0 bayesflow/utils/exceptions/shape_error.py | 5 + bayesflow/utils/plot_utils.py | 300 ++++++++++++++++++++++ 3 files changed, 305 insertions(+) create mode 100644 bayesflow/utils/exceptions/__init__.py create mode 100644 bayesflow/utils/exceptions/shape_error.py create mode 100644 bayesflow/utils/plot_utils.py diff --git a/bayesflow/utils/exceptions/__init__.py b/bayesflow/utils/exceptions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bayesflow/utils/exceptions/shape_error.py b/bayesflow/utils/exceptions/shape_error.py new file mode 100644 index 000000000..c167bdab6 --- /dev/null +++ b/bayesflow/utils/exceptions/shape_error.py @@ -0,0 +1,5 @@ + +class ShapeError(Exception): + """Class for error in expected shapes.""" + + pass diff --git a/bayesflow/utils/plot_utils.py b/bayesflow/utils/plot_utils.py new file mode 100644 index 000000000..402a61e05 --- /dev/null +++ b/bayesflow/utils/plot_utils.py @@ -0,0 +1,300 @@ + +import numpy as np +import matplotlib.pyplot as plt + +from .exceptions.shape_error import ShapeError + + +def check_posterior_prior_shapes(post_samples, prior_samples): + """ + Checks requirements for the shapes of 
posterior and prior draws as + necessitated by most diagnostic functions. + + Parameters + ---------- + post_samples : np.ndarray of shape (n_data_sets, n_post_draws, n_params) + The posterior draws obtained from n_data_sets + prior_samples : np.ndarray of shape (n_data_sets, n_params) + The prior draws obtained for generating n_data_sets + + Raises + ------ + ShapeError + If there is a deviation from the expected shapes of `post_samples` and `prior_samples`. + """ + + if len(post_samples.shape) != 3: + raise ShapeError( + f"post_samples should be a 3-dimensional array, with the " + + f"first dimension being the number of (simulated) data sets, " + + f"the second dimension being the number of posterior draws per data set, " + + f"and the third dimension being the number of parameters (marginal distributions), " + + f"but your input has dimensions {len(post_samples.shape)}" + ) + elif len(prior_samples.shape) != 2: + raise ShapeError( + f"prior_samples should be a 2-dimensional array, with the " + + f"first dimension being the number of (simulated) data sets / prior draws " + + f"and the second dimension being the number of parameters (marginal distributions), " + + f"but your input has dimensions {len(prior_samples.shape)}" + ) + elif post_samples.shape[0] != prior_samples.shape[0]: + raise ShapeError( + f"The number of elements over the first dimension of post_samples and prior_samples" + + f"should match, but post_samples has {post_samples.shape[0]} and prior_samples has " + + f"{prior_samples.shape[0]} elements, respectively." + ) + elif post_samples.shape[-1] != prior_samples.shape[-1]: + raise ShapeError( + f"The number of elements over the last dimension of post_samples and prior_samples" + + f"should match, but post_samples has {post_samples.shape[1]} and prior_samples has " + + f"{prior_samples.shape[-1]} elements, respectively."
+ ) + + +def get_count_and_names( + samples, + names: list = None, + symbol: str = None, + n_objects: int = None +): + """ + Determine the number of objects, such as parameters or models, + and their respective names if None given. + + Parameters + ---------- + samples : np.ndarray of shape(..., n_objects) + The objects of interest + names : list[str], optional, default: None + The names of individual object + symbol : str, optional, default: None + The symbol used for naming the individual object. + If none given, default is associated with a parameter named $\\theta$. + n_objects : int, optional, default: None + The number of individual objects + + Returns + ------- + n_objects : int + Number of individual objects + names : list[str] + List of names for the individual object + """ + if n_objects is None: + n_objects = samples.shape[-1] + if names is None: + if symbol is None: + symbol = "\\theta" + names = [f"${symbol}_{{{i}}}$" for i in range(1, n_objects + 1)] + + return n_objects, names + + +def configure_layout( + n_total: int, + n_row: int = None, + n_col: int = None, + stacked: bool = False +): + """ + Determine the number of rows and columns in diagnostics visualizations. + + Parameters + ---------- + n_total : int + Total number of parameters + n_row : int, default = None + Number of rows for the visualization layout + n_col : int, default = None + Number of columns for the visualization layout + stacked : bool, default = False + Boolean that determines whether to stack the plot or not. 
+ + Returns + ------- + n_row : int + Number of rows for the visualization layout + n_col : int + Number of columns for the visualization layout + """ + if stacked: + n_row, n_col = 1, 1 + else: + if n_row is None and n_col is None: + n_row = int(np.ceil(n_total / 6)) + n_col = int(np.ceil(n_total / n_row)) + elif n_row is None and n_col is not None: + n_row = int(np.ceil(n_total / n_col)) + elif n_row is not None and n_col is None: + n_col = int(np.ceil(n_total / n_row)) + + return n_row, n_col + + +def initialize_figure( + n_row: int = None, + n_col: int = None, + fig_size: tuple = None, +): + """ + Initialize a set of figures + + Parameters + ---------- + n_row : int + Number of rows in a figure + n_col : int + Number of columns in a figure + stacked : bool + Whether subplots in a figure are stacked by rows + fig_size : tuple + Size of the figure adjusting to the display resolution + or the designer's desire + + Returns + ------- + f, axarr + Initialized figures + """ + if n_row == 1 and n_col == 1: + f, axarr = plt.subplots(1, 1, figsize=fig_size) + else: + if fig_size is None: + fig_size = (int(5 * n_col), int(5 * n_row)) + + f, axarr = plt.subplots(n_row, n_col, figsize=fig_size) + + return f, axarr + + +def collapse_axes(axarr, n_row: int = 1, n_col: int = 1): + """ + Collapse a 2D array of subplot Axes into a 1D array + + Parameters + ---------- + axarr : 2D array of Axes + An array of axes for subplots + n_row : int, default: 1 + Number of rows for the axes + n_col : int, default: 1 + Number of columns for the axes + + Returns + ------- + ax : 1D array of Axes + Collapsed axes for subplots + """ + + ax = np.atleast_1d(axarr) + # turn axarr into 1D list + if n_row > 1 or n_col > 1: + ax = axarr.flat + else: + ax = axarr + + return ax + + +def add_xlabels( + axarr, + n_row: int = None, + n_col: int = None, + xlabel: str = None, + label_fontsize: int = None +): + # Only add x-labels to the bottom row + bottom_row = axarr if n_row == 1 else axarr[0] if n_col 
== 1 else axarr[n_row - 1, :] + for _ax in bottom_row: + _ax.set_xlabel(xlabel, fontsize=label_fontsize) + + +def add_ylabels( + axarr, + n_row: int = None, + ylabel: str = None, + label_fontsize: int = None +): + # Only add y-labels to the left-most column + if n_row == 1: # if there is only one row, the ax array is 1D + axarr[0].set_ylabel(ylabel, fontsize=label_fontsize) + # If there is more than one row, the ax array is 2D + else: + for _ax in axarr[:, 0]: + _ax.set_ylabel(ylabel, fontsize=label_fontsize) + + +def add_labels( + axarr, + n_row: int = None, + n_col: int = None, + xlabel: str = None, + ylabel: str = None, + label_fontsize: int = None +): + """ + Wrapper function for configuring labels for both axes. + """ + add_xlabels(axarr, n_row, n_col, xlabel, label_fontsize) + add_ylabels(axarr, n_row, ylabel, label_fontsize) + + +def remove_unused_axes(axarr_it, n_params: int = None): + for _ax in axarr_it[n_params:]: + _ax.remove() + + +def preprocess( + post_samples, + prior_samples, + fig_size: tuple = None, + collapse: bool = True +): + """ + Procedural wrapper that encompasses all preprocessing steps, + including shape-checking, parameter name generation, layout configuration, + figure initialization, and axial collapsing for loop and plot.
+ + Parameters + ---------- + post_samples : np.ndarray of shape (n_data_sets, n_post_draws, n_params) + The posterior draws obtained from n_data_sets + prior_samples : np.ndarray of shape (n_data_sets, n_params) + The prior draws obtained for generating n_data_sets + fig_size : tuple, optional, default: None + Size of the figure adjusting to the display resolution + stacked : bool, optional, default: False + Whether subplots in a figure are stacked by rows + collapse : bool, optional, default: True + Whether subplots in a figure are collapsed into rows + """ + + # Sanity check + check_posterior_prior_shapes(post_samples, prior_samples) + + # Determine parameters and parameter names + n_params, param_names = get_count_and_names(post_samples) + + # Configure layout + n_row, n_col = configure_layout(n_params) + + # Initialize figure + f, axarr = initialize_figure(n_row, n_col, fig_size=fig_size) + + # turn axarr into 1D list + if collapse: + axarr_it = collapse_axes(axarr, n_row, n_col) + else: + axarr_it = axarr + + return f, axarr, axarr_it, n_row, n_col, n_params, param_names + + +def postprocess(*args): + """ + Procedural wrapper for postprocessing steps, including adding labels and removing unused axes. 
+ """ + + add_labels(args) + remove_unused_axes(args) From 2df83141519f45091ae8a8a307a5237c09b6d72b Mon Sep 17 00:00:00 2001 From: Jerry Date: Wed, 9 Oct 2024 01:36:18 -0400 Subject: [PATCH 07/22] Add refactored plot recovery from previous --- bayesflow/diagnostics/plot_recovery.py | 165 +++++++++++++++++++++++++ 1 file changed, 165 insertions(+) create mode 100644 bayesflow/diagnostics/plot_recovery.py diff --git a/bayesflow/diagnostics/plot_recovery.py b/bayesflow/diagnostics/plot_recovery.py new file mode 100644 index 000000000..2001c95e6 --- /dev/null +++ b/bayesflow/diagnostics/plot_recovery.py @@ -0,0 +1,165 @@ + +import numpy as np +from scipy.stats import median_abs_deviation +from sklearn.metrics import r2_score +import seaborn as sns + +from ..utils.plot_utils import preprocess, postprocess + + +def plot_recovery( + post_samples, + prior_samples, + point_agg=np.median, + uncertainty_agg=median_abs_deviation, + param_names: list = None, + fig_size: tuple = None, + label_fontsize: int = 16, + title_fontsize: int = 18, + metric_fontsize: int = 16, + tick_fontsize: int = 12, + add_corr: bool = True, + add_r2: bool = True, + color: str | tuple = "#8f2727", + n_col: int = None, + n_row: int = None, + xlabel: str = "Ground truth", + ylabel: str = "Estimated", + **kwargs, +): + """Creates and plots publication-ready recovery plot with true vs. point estimate + uncertainty. + The point estimate can be controlled with the ``point_agg`` argument, and the uncertainty estimate + can be controlled with the ``uncertainty_agg`` argument. + + This plot yields similar information as the "posterior z-score", but allows for generic + point and uncertainty estimates: + + https://betanalpha.github.io/assets/case_studies/principled_bayesian_workflow.html + + Important: Posterior aggregates play no special role in Bayesian inference and should only + be used heuristically. 
For instance, in the case of multi-modal posteriors, common point + estimates, such as mean, (geometric) median, or maximum a posteriori (MAP) mean nothing. + + Parameters + ---------- + post_samples : np.ndarray of shape (n_data_sets, n_post_draws, n_params) + The posterior draws obtained from n_data_sets + prior_samples : np.ndarray of shape (n_data_sets, n_params) + The prior draws (true parameters) obtained for generating the n_data_sets + point_agg : callable, optional, default: ``np.median`` + The function to apply to the posterior draws to get a point estimate for each marginal. + The default computes the marginal median for each marginal posterior as a robust + point estimate. + uncertainty_agg : callable or None, optional, default: scipy.stats.median_abs_deviation + The function to apply to the posterior draws to get an uncertainty estimate. + If ``None`` provided, a simple scatter using only ``point_agg`` will be plotted. + param_names : list or None, optional, default: None + The parameter names for nice plot titles. Inferred if None + fig_size : tuple or None, optional, default : None + The figure size passed to the matplotlib constructor. Inferred if None. + label_fontsize : int, optional, default: 16 + The font size of the y-label text + title_fontsize : int, optional, default: 18 + The font size of the title text + metric_fontsize : int, optional, default: 16 + The font size of the goodness-of-fit metric (if provided) + tick_fontsize : int, optional, default: 12 + The font size of the axis tick labels + add_corr : bool, optional, default: True + A flag for adding correlation between true and estimates to the plot + add_r2 : bool, optional, default: True + A flag for adding R^2 between true and estimates to the plot + color : str, optional, default: '#8f2727' + The color for the true vs. estimated scatter points and error bars + n_row : int, optional, default: None + The number of rows for the subplots. Dynamically determined if None. 
+ n_col : int, optional, default: None + The number of columns for the subplots. Dynamically determined if None. + xlabel : str, optional, default: 'Ground truth' + The label on the x-axis of the plot + ylabel : str, optional, default: 'Estimated' + The label on the y-axis of the plot + **kwargs : optional + Additional keyword arguments passed to ax.errorbar or ax.scatter. + Example: `rasterized=True` to reduce PDF file size with many dots + + Returns + ------- + f : plt.Figure - the figure instance for optional saving + + Raises + ------ + ShapeError + If there is a deviation from the expected shapes of ``post_samples`` and ``prior_samples``. + """ + + # Preprocess + f, axarr, axarr_it, n_row, n_col, n_params, param_names = preprocess( + post_samples, prior_samples, fig_size=fig_size + ) + + # Compute point estimates and uncertainties + est = point_agg(post_samples, axis=1) + if uncertainty_agg is not None: + u = uncertainty_agg(post_samples, axis=1) + + # Loop and plot + for i, ax in enumerate(axarr_it): + if i >= n_params: + break + + # Add scatter and error bars + if uncertainty_agg is not None: + _ = ax.errorbar(prior_samples[:, i], est[:, i], yerr=u[:, i], fmt="o", alpha=0.5, color=color, **kwargs) + else: + _ = ax.scatter(prior_samples[:, i], est[:, i], alpha=0.5, color=color, **kwargs) + + # Make plots quadratic to avoid visual illusions + lower = min(prior_samples[:, i].min(), est[:, i].min()) + upper = max(prior_samples[:, i].max(), est[:, i].max()) + eps = (upper - lower) * 0.1 + ax.set_xlim([lower - eps, upper + eps]) + ax.set_ylim([lower - eps, upper + eps]) + ax.plot( + [ax.get_xlim()[0], ax.get_xlim()[1]], + [ax.get_ylim()[0], ax.get_ylim()[1]], + color="black", + alpha=0.9, + linestyle="dashed", + ) + + # Add optional metrics and title + if add_r2: + r2 = r2_score(prior_samples[:, i], est[:, i]) + ax.text( + 0.1, + 0.9, + "$R^2$ = {:.3f}".format(r2), + horizontalalignment="left", + verticalalignment="center", + transform=ax.transAxes, + 
size=metric_fontsize, + ) + if add_corr: + corr = np.corrcoef(prior_samples[:, i], est[:, i])[0, 1] + ax.text( + 0.1, + 0.8, + "$r$ = {:.3f}".format(corr), + horizontalalignment="left", + verticalalignment="center", + transform=ax.transAxes, + size=metric_fontsize, + ) + ax.set_title(param_names[i], fontsize=title_fontsize) + + # Prettify + sns.despine(ax=ax) + ax.grid(alpha=0.5) + ax.tick_params(axis="both", which="major", labelsize=tick_fontsize) + ax.tick_params(axis="both", which="minor", labelsize=tick_fontsize) + + postprocess(axarr, axarr_it, n_row, n_col, n_params, xlabel, ylabel, label_fontsize) + + f.tight_layout() + return f From d83fa8160a60e835837aee733457fb8752838bb8 Mon Sep 17 00:00:00 2001 From: Jerry Date: Wed, 9 Oct 2024 01:38:39 -0400 Subject: [PATCH 08/22] Add loss plot from previous --- bayesflow/diagnostics/plot_losses.py | 121 +++++++++++++++++++++++++++ 1 file changed, 121 insertions(+) create mode 100644 bayesflow/diagnostics/plot_losses.py diff --git a/bayesflow/diagnostics/plot_losses.py b/bayesflow/diagnostics/plot_losses.py new file mode 100644 index 000000000..d778e1e8b --- /dev/null +++ b/bayesflow/diagnostics/plot_losses.py @@ -0,0 +1,121 @@ + +import seaborn as sns + +from tensorflow.keras import ops +from ..utils.plot_utils import initialize_figure + + +def plot_losses( + train_losses, + val_losses=None, + moving_average: bool = False, + ma_window_fraction: float = 0.01, + train_color: str = "#8f2727", + val_color: str = "black", + lw_train: int = 2, + lw_val: int = 3, + grid_alpha: float = 0.5, + legend_fontsize: int = 14, + label_fontsize: int = 14, + title_fontsize: int = 16, +): + """A generic helper function to plot the losses of a series of training epochs and runs. + + Parameters + ---------- + + train_losses : pd.DataFrame + The (plottable) history as returned by a train_[...] method of a ``Trainer`` instance. 
+ Alternatively, you can just pass a data frame of validation losses instead of train losses, + if you only want to plot the validation loss. + val_losses : pd.DataFrame or None, optional, default: None + The (plottable) validation history as returned by a train_[...] method of a ``Trainer`` instance. + If left ``None``, only train losses are plotted. Should have the same number of columns + as ``train_losses``. + moving_average : bool, optional, default: False + A flag for adding a moving average line of the train_losses. + ma_window_fraction : float, optional, default: 0.01 + Window size for the moving average as a fraction of total training steps. + fig_size : tuple or None, optional, default: None + The figure size passed to the ``matplotlib`` constructor. Inferred if ``None`` + train_color : str, optional, default: '#8f2727' + The color for the train loss trajectory + val_color : str, optional, default: 'black' + The color for the optional validation loss trajectory + lw_train : int, optional, default: 2 + The linewidth for the training loss curve + lw_val : int, optional, default: 3 + The linewidth for the validation loss curve + grid_alpha : float, optional, default: 0.5 + The opacity factor for the background gridlines + legend_fontsize : int, optional, default: 14 + The font size of the legend text + label_fontsize : int, optional, default: 14 + The font size of the axis label text + title_fontsize : int, optional, default: 16 + The font size of the title text + + Returns + ------- + f : plt.Figure - the figure instance for optional saving + + Raises + ------ + AssertionError + If the number of columns in ``train_losses`` does not match the + number of columns in ``val_losses``.
+ """ + + # Determine the number of rows for plot + n_row = len(train_losses.columns) + + # Initialize figure + f, axarr = initialize_figure(n_row=n_row, n_col=1, fig_size=(16, int(4 * n_row))) + + # if fig_size is None: + # fig_size = (16, int(4 * n_row)) + # f, axarr = plt.subplots(n_row, 1, figsize=fig_size) + + # Get the number of steps as an array + train_step_index = ops.arange(1, len(train_losses) + 1) + if val_losses is not None: + val_step = int(ops.floor(len(train_losses) / len(val_losses))) + val_step_index = train_step_index[(val_step - 1) :: val_step] + + # If unequal length due to some reason, attempt a fix + if val_step_index.shape[0] > val_losses.shape[0]: + val_step_index = val_step_index[: val_losses.shape[0]] + + # Loop through loss entries and populate plot + looper = [axarr] if n_row == 1 else axarr.flat + for i, ax in enumerate(looper): + # Plot train curve + ax.plot(train_step_index, train_losses.iloc[:, i], color=train_color, lw=lw_train, alpha=0.9, label="Training") + if moving_average and train_losses.columns[i] == "Loss": + moving_average_window = int(train_losses.shape[0] * ma_window_fraction) + smoothed_loss = train_losses.iloc[:, i].rolling(window=moving_average_window).mean() + ax.plot(train_step_index, smoothed_loss, color="grey", lw=lw_train, label="Training (Moving Average)") + + # Plot optional val curve + if val_losses is not None: + if i < val_losses.shape[1]: + ax.plot( + val_step_index, + val_losses.iloc[:, i], + linestyle="--", + marker="o", + color=val_color, + lw=lw_val, + label="Validation", + ) + # Prettify + ax.set_xlabel("Training step #", fontsize=label_fontsize) + ax.set_ylabel("Value", fontsize=label_fontsize) + sns.despine(ax=ax) + ax.grid(alpha=grid_alpha) + ax.set_title(train_losses.columns[i], fontsize=title_fontsize) + # Only add legend if there is a validation curve or a moving average + if val_losses is not None or moving_average: + ax.legend(fontsize=legend_fontsize) + f.tight_layout() + return f From
9e75b283e5761ba89ffc01fddd93f7c1366e3a81 Mon Sep 17 00:00:00 2001 From: Jerry Date: Mon, 14 Oct 2024 18:14:56 -0400 Subject: [PATCH 09/22] Add TwoMoons notebook for testing diagnostics --- bayesflow/diagnostics/__init__.py | 2 + examples/TwoMoons_Diagnostics.ipynb | 1637 ++++++++++++++++++++++++++ examples/TwoMoons_FlowMatching.ipynb | 83 +- 3 files changed, 1674 insertions(+), 48 deletions(-) create mode 100644 examples/TwoMoons_Diagnostics.ipynb diff --git a/bayesflow/diagnostics/__init__.py b/bayesflow/diagnostics/__init__.py index e69de29bb..28adaf7ba 100644 --- a/bayesflow/diagnostics/__init__.py +++ b/bayesflow/diagnostics/__init__.py @@ -0,0 +1,2 @@ +from .plot_losses import plot_losses +from .plot_recovery import plot_recovery \ No newline at end of file diff --git a/examples/TwoMoons_Diagnostics.ipynb b/examples/TwoMoons_Diagnostics.ipynb new file mode 100644 index 000000000..43abbdaa4 --- /dev/null +++ b/examples/TwoMoons_Diagnostics.ipynb @@ -0,0 +1,1637 @@ +{ + "cells": [ + { + "cell_type": "code", + "id": "initial_id", + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2024-10-14T21:00:47.608845Z", + "start_time": "2024-10-14T21:00:47.596803Z" + } + }, + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import seaborn as sns\n", + "\n", + "# ensure the backend is set\n", + "import os\n", + "if \"KERAS_BACKEND\" not in os.environ:\n", + " # set this to \"torch\", \"tensorflow\", or \"jax\"\n", + " os.environ[\"KERAS_BACKEND\"] = \"jax\"\n", + "\n", + "import keras\n", + "\n", + "# for BayesFlow devs: this ensures that the latest dev version can be found\n", + "import sys\n", + "sys.path.append('../')\n", + "\n", + "import bayesflow as bf\n", + "from bayesflow.diagnostics.plot_losses import plot_losses\n", + "from bayesflow.diagnostics.plot_recovery import plot_recovery" + ], + "outputs": [], + "execution_count": 29 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-14T20:27:46.593978Z", + 
"start_time": "2024-10-14T20:27:46.585246Z" + } + }, + "cell_type": "code", + "source": [ + "def alpha_prior():\n", + " alpha = np.random.uniform(-np.pi / 2, np.pi / 2)\n", + " return dict(alpha=alpha)\n", + "\n", + "def r_prior():\n", + " r = np.random.normal(0.1, 0.01)\n", + " return dict(r=r)\n", + "\n", + "def theta_prior():\n", + " theta = np.random.uniform(-1, 1, 2)\n", + " return dict(theta=theta)\n", + "\n", + "def forward_model(theta, alpha, r):\n", + " x1 = -np.abs(theta[0] + theta[1]) / np.sqrt(2) + r * np.cos(alpha) + 0.25\n", + " x2 = (-theta[0] + theta[1]) / np.sqrt(2) + r * np.sin(alpha)\n", + " return dict(x=np.array([x1, x2]))" + ], + "id": "2aa9c9710a36b980", + "outputs": [], + "execution_count": 2 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-14T20:27:47.665818Z", + "start_time": "2024-10-14T20:27:47.650596Z" + } + }, + "cell_type": "code", + "source": "simulator = bf.simulators.CompositeLambdaSimulator([alpha_prior, r_prior, theta_prior, forward_model])", + "id": "7db949c6bfecc86d", + "outputs": [], + "execution_count": 3 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-14T20:27:48.154339Z", + "start_time": "2024-10-14T20:27:48.131066Z" + } + }, + "cell_type": "code", + "source": [ + "# generate 128 random draws from the joint distribution p(r, alpha, theta, x)\n", + "sample_data = simulator.sample((128,))" + ], + "id": "b3a0fe5293beec1b", + "outputs": [], + "execution_count": 4 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-14T20:27:48.809696Z", + "start_time": "2024-10-14T20:27:48.792150Z" + } + }, + "cell_type": "code", + "source": [ + "print(\"Type of sample_data:\\n\\t\", type(sample_data))\n", + "print(\"Keys of sample_data:\\n\\t\", sample_data.keys())\n", + "print(\"Types of sample_data values:\\n\\t\", {k: type(v) for k, v in sample_data.items()})\n", + "print(\"Shapes of sample_data values:\\n\\t\", {k: v.shape for k, v in sample_data.items()})" + ], + "id": "7b75698477fdf1f3", 
+ "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Type of sample_data:\n", + "\t \n", + "Keys of sample_data:\n", + "\t dict_keys(['alpha', 'r', 'theta', 'x'])\n", + "Types of sample_data values:\n", + "\t {'alpha': , 'r': , 'theta': , 'x': }\n", + "Shapes of sample_data values:\n", + "\t {'alpha': (128, 1), 'r': (128, 1), 'theta': (128, 2), 'x': (128, 2)}\n" + ] + } + ], + "execution_count": 5 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-14T20:27:50.913110Z", + "start_time": "2024-10-14T20:27:50.887083Z" + } + }, + "cell_type": "code", + "source": "sample_data", + "id": "ad5010214a02d3de", + "outputs": [ + { + "data": { + "text/plain": [ + "{'alpha': array([[-0.32667086],\n", + " [ 0.4273289 ],\n", + " [-1.4049091 ],\n", + " [ 0.6495664 ],\n", + " [-1.2431567 ],\n", + " [ 0.00233715],\n", + " [ 0.8480671 ],\n", + " [ 0.3825317 ],\n", + " [ 0.51094717],\n", + " [-1.089614 ],\n", + " [-1.3354424 ],\n", + " [ 1.521167 ],\n", + " [-0.39602077],\n", + " [ 0.88564837],\n", + " [ 0.9123921 ],\n", + " [-0.8354839 ],\n", + " [ 0.4103786 ],\n", + " [-1.3677335 ],\n", + " [ 1.5597553 ],\n", + " [ 0.857706 ],\n", + " [ 1.3851596 ],\n", + " [ 0.08588123],\n", + " [-1.3278203 ],\n", + " [-1.4882586 ],\n", + " [-0.7284002 ],\n", + " [ 0.13790596],\n", + " [ 1.0051142 ],\n", + " [-1.2707354 ],\n", + " [-1.4755409 ],\n", + " [-0.03580285],\n", + " [-1.5595189 ],\n", + " [-0.67198914],\n", + " [-1.5414019 ],\n", + " [ 0.46680772],\n", + " [-0.49131817],\n", + " [ 1.3081697 ],\n", + " [ 1.3158842 ],\n", + " [ 0.5498999 ],\n", + " [ 0.28842863],\n", + " [ 1.5669599 ],\n", + " [ 1.1178267 ],\n", + " [-1.3773671 ],\n", + " [ 0.21009572],\n", + " [ 1.247244 ],\n", + " [-0.70116574],\n", + " [-1.3302704 ],\n", + " [-1.2457173 ],\n", + " [ 1.5158378 ],\n", + " [ 0.4760814 ],\n", + " [ 0.5130099 ],\n", + " [ 0.38356388],\n", + " [ 0.51667935],\n", + " [-1.3263792 ],\n", + " [ 1.1610556 ],\n", + " [-1.0892584 ],\n", + " [ 1.5635314 ],\n", 
+ " [ 0.6141801 ],\n", + " [ 0.7341371 ],\n", + " [ 0.23875016],\n", + " [ 0.6227148 ],\n", + " [-0.4782562 ],\n", + " [ 0.3656121 ],\n", + " [ 0.6328574 ],\n", + " [-1.3631004 ],\n", + " [ 1.4811654 ],\n", + " [-1.373762 ],\n", + " [-0.24200128],\n", + " [ 0.00168619],\n", + " [ 1.0400674 ],\n", + " [ 0.18901351],\n", + " [ 0.9238482 ],\n", + " [ 0.02520916],\n", + " [ 1.4398099 ],\n", + " [-0.51892877],\n", + " [-1.1703858 ],\n", + " [-0.12021437],\n", + " [ 0.80984557],\n", + " [-0.9726958 ],\n", + " [-0.2244538 ],\n", + " [ 0.30686495],\n", + " [ 0.59431726],\n", + " [-1.322811 ],\n", + " [ 0.8136638 ],\n", + " [-1.5020558 ],\n", + " [-1.4799207 ],\n", + " [-1.3136892 ],\n", + " [ 0.06446969],\n", + " [ 1.3328581 ],\n", + " [ 0.66848814],\n", + " [ 0.7860198 ],\n", + " [ 1.315634 ],\n", + " [-0.23607427],\n", + " [-0.8002341 ],\n", + " [ 1.5251296 ],\n", + " [ 0.15763855],\n", + " [-1.5531527 ],\n", + " [ 0.56706136],\n", + " [ 1.047334 ],\n", + " [ 0.89252347],\n", + " [ 1.2277393 ],\n", + " [ 0.8999341 ],\n", + " [ 1.0635433 ],\n", + " [ 0.04854681],\n", + " [ 0.84339076],\n", + " [ 0.42572305],\n", + " [-0.13823606],\n", + " [ 0.36718416],\n", + " [-1.1577339 ],\n", + " [-0.5522179 ],\n", + " [ 0.7911456 ],\n", + " [ 0.8179233 ],\n", + " [-0.62356246],\n", + " [-0.33656436],\n", + " [ 0.17404567],\n", + " [ 0.4389914 ],\n", + " [-0.9474675 ],\n", + " [-1.1168886 ],\n", + " [-0.09231075],\n", + " [ 1.0462689 ],\n", + " [-0.90480804],\n", + " [-1.4208354 ],\n", + " [-0.16266003],\n", + " [ 0.58943385],\n", + " [ 0.9045791 ],\n", + " [ 0.42233914],\n", + " [-0.9887428 ],\n", + " [ 1.3377244 ],\n", + " [ 1.3765699 ]], dtype=float32),\n", + " 'r': array([[0.08992655],\n", + " [0.11262761],\n", + " [0.09988438],\n", + " [0.10757513],\n", + " [0.09842595],\n", + " [0.08780682],\n", + " [0.09458143],\n", + " [0.08305464],\n", + " [0.10684904],\n", + " [0.09243193],\n", + " [0.09636895],\n", + " [0.10226896],\n", + " [0.10459247],\n", + " [0.10921837],\n", + " 
[0.11026573],\n", + " [0.08566567],\n", + " [0.08520468],\n", + " [0.09628233],\n", + " [0.09432837],\n", + " [0.09125222],\n", + " [0.09802835],\n", + " [0.08590778],\n", + " [0.10220997],\n", + " [0.09208615],\n", + " [0.09563595],\n", + " [0.11774731],\n", + " [0.08680242],\n", + " [0.11012331],\n", + " [0.09633071],\n", + " [0.09756536],\n", + " [0.12378401],\n", + " [0.11334178],\n", + " [0.09062148],\n", + " [0.10854411],\n", + " [0.09699341],\n", + " [0.09652205],\n", + " [0.10683485],\n", + " [0.10969307],\n", + " [0.1108022 ],\n", + " [0.10317604],\n", + " [0.07496227],\n", + " [0.09012368],\n", + " [0.09105562],\n", + " [0.08200298],\n", + " [0.0828375 ],\n", + " [0.10398124],\n", + " [0.1007928 ],\n", + " [0.10111594],\n", + " [0.10167468],\n", + " [0.08223265],\n", + " [0.09048541],\n", + " [0.09253196],\n", + " [0.11334959],\n", + " [0.10842754],\n", + " [0.09764497],\n", + " [0.08462145],\n", + " [0.11413962],\n", + " [0.11527291],\n", + " [0.10542885],\n", + " [0.09038962],\n", + " [0.10374972],\n", + " [0.10187822],\n", + " [0.10547098],\n", + " [0.0985254 ],\n", + " [0.11656975],\n", + " [0.10378908],\n", + " [0.09430881],\n", + " [0.10135388],\n", + " [0.09672231],\n", + " [0.10255771],\n", + " [0.09387974],\n", + " [0.09308615],\n", + " [0.09995835],\n", + " [0.10125452],\n", + " [0.08677949],\n", + " [0.10938775],\n", + " [0.08700917],\n", + " [0.10388696],\n", + " [0.10093628],\n", + " [0.08200264],\n", + " [0.10838373],\n", + " [0.11670296],\n", + " [0.0975048 ],\n", + " [0.10851161],\n", + " [0.11573117],\n", + " [0.08443198],\n", + " [0.11458082],\n", + " [0.09952442],\n", + " [0.09616404],\n", + " [0.10941261],\n", + " [0.10953938],\n", + " [0.10442203],\n", + " [0.10339843],\n", + " [0.11485437],\n", + " [0.10533367],\n", + " [0.09481129],\n", + " [0.09040346],\n", + " [0.10173973],\n", + " [0.10177024],\n", + " [0.11780889],\n", + " [0.09570873],\n", + " [0.11882972],\n", + " [0.08545157],\n", + " [0.09944647],\n", + " [0.08443879],\n", + 
" [0.08220201],\n", + " [0.1000874 ],\n", + " [0.09128848],\n", + " [0.08286219],\n", + " [0.09595444],\n", + " [0.0979057 ],\n", + " [0.10283633],\n", + " [0.09927638],\n", + " [0.08226943],\n", + " [0.08991795],\n", + " [0.08759429],\n", + " [0.09739396],\n", + " [0.09011568],\n", + " [0.1022917 ],\n", + " [0.10698826],\n", + " [0.09153012],\n", + " [0.11289136],\n", + " [0.09984156],\n", + " [0.09630046],\n", + " [0.09598115],\n", + " [0.09263593],\n", + " [0.11081006],\n", + " [0.10720474]], dtype=float32),\n", + " 'theta': array([[-3.05581540e-02, 1.45420015e-01],\n", + " [ 2.11661726e-01, -8.16319406e-01],\n", + " [-2.41767168e-01, -8.90369564e-02],\n", + " [-9.58784044e-01, -6.43721148e-02],\n", + " [-8.37376893e-01, -3.60958546e-01],\n", + " [-5.58282018e-01, 9.12082434e-01],\n", + " [ 4.27646607e-01, 8.86675537e-01],\n", + " [ 1.40667871e-01, -2.22100895e-02],\n", + " [ 1.42964154e-01, -7.41437301e-02],\n", + " [-2.23557115e-01, 6.94955945e-01],\n", + " [ 1.82445496e-01, -8.65079284e-01],\n", + " [ 4.57260549e-01, -5.30316174e-01],\n", + " [-6.21315300e-01, 2.07263768e-01],\n", + " [ 6.10240161e-01, -8.74304950e-01],\n", + " [-2.59324551e-01, 4.74186003e-01],\n", + " [-3.78903113e-02, 9.00502682e-01],\n", + " [ 6.82602823e-01, -7.56820023e-01],\n", + " [ 5.94463050e-01, -1.21822745e-01],\n", + " [ 8.27706277e-01, 3.59144360e-01],\n", + " [-7.00999856e-01, 4.03989136e-01],\n", + " [-4.31024581e-01, -7.00606406e-01],\n", + " [ 9.86329079e-01, -8.04728150e-01],\n", + " [-5.38084447e-01, 3.05770040e-01],\n", + " [ 6.28210962e-01, 3.67884368e-01],\n", + " [ 1.20291792e-01, 5.00622094e-01],\n", + " [-7.59945214e-01, 4.06511813e-01],\n", + " [ 9.54336047e-01, -4.55123216e-01],\n", + " [-4.13251549e-01, 5.47428668e-01],\n", + " [-9.78350341e-01, 4.92099434e-01],\n", + " [ 6.58699691e-01, -5.58192194e-01],\n", + " [ 3.67015302e-01, 3.97362381e-01],\n", + " [-8.03872824e-01, -5.61456561e-01],\n", + " [-7.27509439e-01, -2.47731626e-01],\n", + " [-2.85620838e-01, 
-4.79526430e-01],\n", + " [-1.94825262e-01, -2.55926311e-01],\n", + " [ 1.67872280e-01, -9.53495979e-01],\n", + " [-1.38880566e-01, -6.23547696e-02],\n", + " [ 8.53347182e-01, -2.39163131e-01],\n", + " [ 6.25432789e-01, -7.36523867e-01],\n", + " [ 7.81854391e-01, 2.92491883e-01],\n", + " [-8.66746664e-01, -8.48121166e-01],\n", + " [-1.03085585e-01, -3.94162923e-01],\n", + " [-7.86425531e-01, 8.93066466e-01],\n", + " [ 4.55110759e-01, -3.10456127e-01],\n", + " [ 8.81353259e-01, 5.10756969e-01],\n", + " [-7.21052825e-01, -2.99451917e-01],\n", + " [-7.17009425e-01, 7.83575058e-01],\n", + " [ 3.48350137e-01, 7.37254381e-01],\n", + " [-7.80723929e-01, -8.84120941e-01],\n", + " [-1.95760652e-01, -2.31012523e-01],\n", + " [ 6.60519898e-01, -6.55787408e-01],\n", + " [-7.19074786e-01, -2.03196511e-01],\n", + " [ 6.15297891e-02, -5.67862332e-01],\n", + " [-6.49809659e-01, 5.80364406e-01],\n", + " [ 2.61088669e-01, 5.19194543e-01],\n", + " [ 3.60945612e-01, 3.35697114e-01],\n", + " [ 3.50850195e-01, -7.03177214e-01],\n", + " [ 7.74477005e-01, -2.59503335e-01],\n", + " [-7.42385924e-01, -8.22378099e-01],\n", + " [-2.85933644e-01, 6.74357533e-01],\n", + " [-8.53843629e-01, 4.06876981e-01],\n", + " [ 9.58840549e-01, 9.16405022e-01],\n", + " [ 7.74163187e-01, 7.48323083e-01],\n", + " [ 8.63941729e-01, -3.28039467e-01],\n", + " [-2.40420654e-01, 2.52833039e-01],\n", + " [-5.98286033e-01, 9.57714319e-01],\n", + " [-7.42678821e-01, -3.43333989e-01],\n", + " [ 9.68854368e-01, 3.85829717e-01],\n", + " [-2.75604427e-01, -7.95602053e-02],\n", + " [-2.03438953e-01, 9.11069755e-03],\n", + " [-1.43148184e-01, -1.46573469e-01],\n", + " [-9.49487031e-01, 3.15213263e-01],\n", + " [ 6.84628665e-01, 9.61936653e-01],\n", + " [-2.04393521e-01, 5.88398874e-01],\n", + " [-9.50749099e-01, 7.83227861e-01],\n", + " [ 2.88510352e-01, 9.24940228e-01],\n", + " [ 8.27838302e-01, -9.24422801e-01],\n", + " [ 6.05663717e-01, -7.59422839e-01],\n", + " [-9.75102127e-01, 1.72841713e-01],\n", + " 
[-2.73299128e-01, -2.70784408e-01],\n", + " [-4.23562735e-01, 1.31962135e-01],\n", + " [ 9.45790648e-01, -3.38832617e-01],\n", + " [ 1.88001692e-01, 2.13898122e-01],\n", + " [ 5.90824075e-02, -2.38077283e-01],\n", + " [-9.78379369e-01, -6.26421869e-01],\n", + " [-6.73744559e-01, 4.74431008e-01],\n", + " [ 5.22617698e-01, 8.93913805e-01],\n", + " [ 5.95449269e-01, 9.14583445e-01],\n", + " [ 3.90204303e-02, -2.39531472e-01],\n", + " [ 5.71989954e-01, -4.74963844e-01],\n", + " [ 6.58734083e-01, 5.09142876e-01],\n", + " [-2.53153235e-01, -3.53049845e-01],\n", + " [-4.54396307e-01, -6.79341435e-01],\n", + " [ 9.29932415e-01, 8.80713701e-01],\n", + " [ 2.17661366e-01, 6.54169917e-01],\n", + " [ 2.10923515e-02, -2.40477219e-01],\n", + " [-2.51346976e-01, -2.13967457e-01],\n", + " [ 5.17891407e-01, 2.40374476e-01],\n", + " [ 3.59319031e-01, -3.92084904e-02],\n", + " [-6.95064545e-01, 6.54330254e-01],\n", + " [ 4.39562589e-01, -7.08017647e-01],\n", + " [-2.80083835e-01, -2.79529452e-01],\n", + " [ 4.68703598e-01, -3.61453325e-01],\n", + " [ 1.75413545e-04, -6.19711876e-01],\n", + " [-4.56947744e-01, -5.46697043e-02],\n", + " [ 1.10423014e-01, 3.81866604e-01],\n", + " [ 4.95571673e-01, 6.30076528e-01],\n", + " [ 8.15737665e-01, 1.29877731e-01],\n", + " [-9.83589232e-01, -7.80846715e-01],\n", + " [-1.60895333e-01, -4.13245976e-01],\n", + " [-2.93852985e-01, 7.96879292e-01],\n", + " [-2.34337926e-01, 8.69962096e-01],\n", + " [ 8.72636318e-01, -3.94712389e-02],\n", + " [ 4.87689257e-01, -5.77459276e-01],\n", + " [-1.07371598e-01, -4.61379528e-01],\n", + " [-5.80118716e-01, 2.98643053e-01],\n", + " [ 3.65539849e-01, -8.42200577e-01],\n", + " [-5.28841615e-02, 4.88022923e-01],\n", + " [-9.50672925e-01, 2.79116750e-01],\n", + " [-6.66263402e-01, -8.50575149e-01],\n", + " [-2.95702636e-01, -7.57089794e-01],\n", + " [-4.11702067e-01, 9.68640268e-01],\n", + " [ 6.91891074e-01, 3.68866891e-01],\n", + " [-8.32779333e-03, 9.29424405e-01],\n", + " [ 5.37282348e-01, -9.28587794e-01],\n", 
+ " [-5.35197616e-01, 4.34972018e-01],\n", + " [-8.97157609e-01, -2.00011898e-02],\n", + " [-2.74420083e-01, -9.79378104e-01]], dtype=float32),\n", + " 'x': array([[ 2.53951252e-01, 9.55786705e-02],\n", + " [-7.50578418e-02, -6.80214882e-01],\n", + " [ 3.25798206e-02, 9.48337466e-03],\n", + " [-3.87813628e-01, 6.97510600e-01],\n", + " [-5.65676749e-01, 2.43688509e-01],\n", + " [ 8.76319110e-02, 1.03990984e+00],\n", + " [-6.16806686e-01, 3.95518869e-01],\n", + " [ 2.43289366e-01, -8.41702744e-02],\n", + " [ 2.94539064e-01, -1.01268895e-01],\n", + " [-4.05492671e-02, 5.67550659e-01],\n", + " [-2.10222960e-01, -8.34424138e-01],\n", + " [ 2.03415319e-01, -5.96179128e-01],\n", + " [ 5.37187196e-02, 5.45547307e-01],\n", + " [ 1.32390037e-01, -9.65161324e-01],\n", + " [ 1.65536702e-01, 6.05887115e-01],\n", + " [-3.02492917e-01, 6.00012600e-01],\n", + " [ 2.75650620e-01, -9.83832717e-01],\n", + " [-6.47898912e-02, -6.00794613e-01],\n", + " [-5.88188708e-01, -2.37000689e-01],\n", + " [ 9.96765494e-02, 8.50363314e-01],\n", + " [-5.32090664e-01, -9.42790136e-02],\n", + " [ 2.07179919e-01, -1.25909996e+00],\n", + " [ 1.10319838e-01, 4.97487545e-01],\n", + " [-4.46753800e-01, -2.75851369e-01],\n", + " [-1.17684998e-01, 2.05271512e-01],\n", + " [ 1.16714276e-01, 8.40996325e-01],\n", + " [-5.64713925e-02, -9.23357546e-01],\n", + " [ 1.87672526e-01, 5.74100673e-01],\n", + " [-8.46691579e-02, 9.43871021e-01],\n", + " [ 2.76433289e-01, -8.63964856e-01],\n", + " [-2.89100736e-01, -1.02317505e-01],\n", + " [-6.26734078e-01, 1.00853950e-01],\n", + " [-4.36936170e-01, 2.48671815e-01],\n", + " [-1.94109902e-01, -8.82630050e-02],\n", + " [ 1.67907290e-02, -8.89653489e-02],\n", + " [-2.80460954e-01, -6.99714661e-01],\n", + " [ 1.34644642e-01, 1.57494441e-01],\n", + " [-9.07719210e-02, -7.15195656e-01],\n", + " [ 2.77671933e-01, -9.31531489e-01],\n", + " [-5.09281754e-01, -2.42856264e-01],\n", + " [-9.29788351e-01, 8.05726424e-02],\n", + " [-8.42837393e-02, -2.94265717e-01],\n", + " [ 
2.63646871e-01, 1.20657015e+00],\n", + " [ 1.73785478e-01, -4.63589519e-01],\n", + " [-6.71075225e-01, -3.15490365e-01],\n", + " [-4.46836084e-01, 1.97128952e-01],\n", + " [ 2.35122561e-01, 9.65559661e-01],\n", + " [-5.12083948e-01, 3.75960112e-01],\n", + " [-8.36854875e-01, -2.65152324e-02],\n", + " [ 1.98727194e-02, 1.54331038e-02],\n", + " [ 3.30564082e-01, -8.96907687e-01],\n", + " [-3.21691066e-01, 4.10491407e-01],\n", + " [-8.06015953e-02, -5.55028141e-01],\n", + " [ 2.44089246e-01, 9.69316781e-01],\n", + " [-2.56520003e-01, 9.59672704e-02],\n", + " [-2.41986051e-01, 6.67658299e-02],\n", + " [ 9.41473544e-02, -6.79532588e-01],\n", + " [-2.85616964e-02, -6.53907835e-01],\n", + " [-7.54016936e-01, -3.16303074e-02],\n", + " [ 4.87661436e-02, 7.31747448e-01],\n", + " [ 2.60557346e-02, 8.43715191e-01],\n", + " [-9.80854273e-01, 6.41715247e-03],\n", + " [-7.41514742e-01, 4.41092215e-02],\n", + " [-1.08623601e-01, -9.39265966e-01],\n", + " [ 2.51657397e-01, 4.64884877e-01],\n", + " [ 1.61637682e-02, 9.98477459e-01],\n", + " [-4.26366359e-01, 2.59778708e-01],\n", + " [-6.06552601e-01, -4.12089765e-01],\n", + " [ 4.78178374e-02, 2.22041234e-01],\n", + " [ 2.13320345e-01, 1.69564873e-01],\n", + " [ 1.01722233e-01, 7.24871457e-02],\n", + " [-1.05442718e-01, 8.96624506e-01],\n", + " [-9.01241720e-01, 2.95188427e-01],\n", + " [ 6.63916618e-02, 5.10371685e-01],\n", + " [ 1.65370926e-01, 1.14619160e+00],\n", + " [-4.99440819e-01, 4.36905563e-01],\n", + " [ 2.41706863e-01, -1.17602539e+00],\n", + " [ 1.99771896e-01, -1.05111480e+00],\n", + " [-2.18879402e-01, 7.89253116e-01],\n", + " [-5.65532483e-02, 2.65488382e-02],\n", + " [ 1.33606523e-01, 4.53504145e-01],\n", + " [-1.50539234e-01, -1.02149868e+00],\n", + " [ 3.27841304e-02, 8.91788527e-02],\n", + " [ 1.30884781e-01, -3.18378985e-01],\n", + " [-8.74263108e-01, 1.33617908e-01],\n", + " [ 1.30533725e-01, 7.30226099e-01],\n", + " [-6.37296259e-01, 2.69927859e-01],\n", + " [-7.94296503e-01, 3.22382361e-01],\n", + " [ 
1.83682933e-01, -1.37363449e-01],\n", + " [ 2.58710474e-01, -6.62893653e-01],\n", + " [-5.48165679e-01, 2.15799548e-04],\n", + " [-7.71245658e-02, -9.50605869e-02],\n", + " [-4.79652673e-01, -2.33250588e-01],\n", + " [-1.02507687e+00, 7.99317434e-02],\n", + " [-2.62450218e-01, 3.25194120e-01],\n", + " [ 9.65442061e-02, -2.79754162e-01],\n", + " [-2.77321483e-03, 7.49920383e-02],\n", + " [-2.35317081e-01, -1.08117975e-01],\n", + " [ 8.75033140e-02, -2.02557355e-01],\n", + " [ 2.60823607e-01, 1.06511045e+00],\n", + " [ 1.19672045e-01, -7.36494482e-01],\n", + " [-8.79814923e-02, 1.04258955e-01],\n", + " [ 2.59513497e-01, -5.82862794e-01],\n", + " [-1.21953085e-01, -3.64049733e-01],\n", + " [-3.48663330e-02, 3.19325000e-01],\n", + " [-1.66834649e-02, 1.80612490e-01],\n", + " [-4.52537745e-01, 1.31039545e-01],\n", + " [-3.82006407e-01, -5.68586946e-01],\n", + " [-9.27098870e-01, 9.98930261e-02],\n", + " [-8.85202736e-02, -1.10199966e-01],\n", + " [-3.87514569e-02, 8.42708707e-01],\n", + " [-1.15971282e-01, 7.20808744e-01],\n", + " [-2.45430186e-01, -6.77743077e-01],\n", + " [ 2.67549545e-01, -7.38927364e-01],\n", + " [-7.07757547e-02, -2.12103873e-01],\n", + " [ 1.02099039e-01, 5.50257146e-01],\n", + " [-4.43446413e-02, -9.41533327e-01],\n", + " [ 3.20424363e-02, 3.74172240e-01],\n", + " [-1.73633844e-01, 9.58132327e-01],\n", + " [-7.56465554e-01, -2.14453653e-01],\n", + " [-4.80762124e-01, -4.16752875e-01],\n", + " [-3.24135572e-02, 9.57767427e-01],\n", + " [-4.17075306e-01, -1.72911614e-01],\n", + " [-3.41798395e-01, 7.38798976e-01],\n", + " [ 6.08528070e-02, -9.97184515e-01],\n", + " [ 2.30055511e-01, 6.08631432e-01],\n", + " [-3.72935683e-01, 7.28057206e-01],\n", + " [-6.15877926e-01, -3.93291622e-01]], dtype=float32)}" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 6 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-14T22:06:35.724356Z", + "start_time": 
"2024-10-14T22:06:35.716334Z" + } + }, + "cell_type": "code", + "source": [ + "# data_adapter = configurator\n", + "data_adapter = bf.ContinuousApproximator.build_data_adapter(\n", + " inference_variables=[\"theta\"],\n", + " inference_conditions=[\"x\"],\n", + ")" + ], + "id": "461e6dfcdf6944b", + "outputs": [], + "execution_count": 57 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-14T20:27:53.903783Z", + "start_time": "2024-10-14T20:27:53.883195Z" + } + }, + "cell_type": "code", + "source": [ + "num_training_batches = 1024\n", + "num_validation_batches = 256\n", + "batch_size = 128" + ], + "id": "ed2cec2c3fdedb22", + "outputs": [], + "execution_count": 8 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-14T20:28:10.611306Z", + "start_time": "2024-10-14T20:27:54.318482Z" + } + }, + "cell_type": "code", + "source": [ + "training_samples = simulator.sample((num_training_batches * batch_size,))\n", + "validation_samples = simulator.sample((num_validation_batches * batch_size,))" + ], + "id": "7d1bffc7f17b5aaa", + "outputs": [], + "execution_count": 9 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-14T20:28:22.181678Z", + "start_time": "2024-10-14T20:28:22.155164Z" + } + }, + "cell_type": "code", + "source": [ + "training_dataset = bf.datasets.OfflineDataset(training_samples, batch_size=batch_size, data_adapter=data_adapter)\n", + "validation_dataset = bf.datasets.OfflineDataset(validation_samples, batch_size=batch_size, data_adapter=data_adapter)" + ], + "id": "d7f545fd2ee536d8", + "outputs": [], + "execution_count": 10 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-14T20:28:23.402981Z", + "start_time": "2024-10-14T20:28:23.353636Z" + } + }, + "cell_type": "code", + "source": [ + "inference_network = bf.networks.FlowMatching(\n", + " subnet=\"mlp\",\n", + " subnet_kwargs=dict(\n", + " depth=6,\n", + " width=256,\n", + " ),\n", + ")" + ], + "id": "be6ed75d4d899021", + "outputs": [], + 
"execution_count": 11 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-14T20:28:24.123891Z", + "start_time": "2024-10-14T20:28:24.106889Z" + } + }, + "cell_type": "code", + "source": [ + "# Approximator is equivalent to Amortizer\n", + "approximator = bf.ContinuousApproximator(\n", + " inference_network=inference_network,\n", + " data_adapter=data_adapter,\n", + ")" + ], + "id": "b1dc4f27eb17b270", + "outputs": [], + "execution_count": 12 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-14T20:28:25.265638Z", + "start_time": "2024-10-14T20:28:25.223633Z" + } + }, + "cell_type": "code", + "source": [ + "learning_rate = 1e-4\n", + "optimizer = keras.optimizers.Adam(learning_rate=learning_rate)" + ], + "id": "ad75c807a7617e0a", + "outputs": [], + "execution_count": 13 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-14T20:28:26.705303Z", + "start_time": "2024-10-14T20:28:26.693791Z" + } + }, + "cell_type": "code", + "source": [ + "class BatchLossHistory(keras.callbacks.Callback):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.batch_losses = []\n", + "\n", + " def on_train_batch_end(self, batch, logs=None):\n", + " # 'logs' is a dictionary containing loss and other metrics\n", + " loss = logs.get('loss')\n", + " self.batch_losses.append(loss)" + ], + "id": "9d08447b96c58cf4", + "outputs": [], + "execution_count": 14 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-14T20:28:28.279798Z", + "start_time": "2024-10-14T20:28:28.258494Z" + } + }, + "cell_type": "code", + "source": [ + "approximator.compile(\n", + " optimizer=optimizer,\n", + " loss=\"sparse_categorical_crossentropy\"\n", + ")" + ], + "id": "120d9b0fed8a8a01", + "outputs": [], + "execution_count": 15 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-14T20:28:28.698900Z", + "start_time": "2024-10-14T20:28:28.690911Z" + } + }, + "cell_type": "code", + "source": "batch_loss_history = BatchLossHistory()", + 
"id": "50d2d9f6d6419075", + "outputs": [], + "execution_count": 16 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-14T20:38:09.636742Z", + "start_time": "2024-10-14T20:28:29.157027Z" + } + }, + "cell_type": "code", + "source": [ + "history = approximator.fit(\n", + " epochs=30,\n", + " dataset=training_dataset,\n", + " validation_data=validation_dataset,\n", + " callbacks=[batch_loss_history]\n", + ")" + ], + "id": "b80eda4adc2ad620", + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:bayesflow:Fitting on dataset instance of OfflineDataset.\n", + "INFO:bayesflow:Building on a test batch.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/30\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m17s\u001B[0m 16ms/step - loss: 0.6919 - loss/inference_loss: 0.6919 - val_loss: 0.6134 - val_loss/inference_loss: 0.6134\n", + "Epoch 2/30\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m18s\u001B[0m 18ms/step - loss: 0.6234 - loss/inference_loss: 0.6234 - val_loss: 0.6321 - val_loss/inference_loss: 0.6321\n", + "Epoch 3/30\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m18s\u001B[0m 18ms/step - loss: 0.6018 - loss/inference_loss: 0.6018 - val_loss: 0.4567 - val_loss/inference_loss: 0.4567\n", + "Epoch 4/30\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m21s\u001B[0m 20ms/step - loss: 0.6079 - loss/inference_loss: 0.6079 - val_loss: 0.6692 - val_loss/inference_loss: 0.6692\n", + "Epoch 5/30\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m24s\u001B[0m 24ms/step - loss: 0.5956 - loss/inference_loss: 0.5956 - val_loss: 0.7312 - val_loss/inference_loss: 0.7312\n", + "Epoch 6/30\n", + "\u001B[1m1024/1024\u001B[0m 
\u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m22s\u001B[0m 22ms/step - loss: 0.5911 - loss/inference_loss: 0.5911 - val_loss: 0.5461 - val_loss/inference_loss: 0.5461\n", + "Epoch 7/30\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m22s\u001B[0m 22ms/step - loss: 0.5907 - loss/inference_loss: 0.5907 - val_loss: 0.5829 - val_loss/inference_loss: 0.5829\n", + "Epoch 8/30\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m24s\u001B[0m 24ms/step - loss: 0.5820 - loss/inference_loss: 0.5820 - val_loss: 0.7137 - val_loss/inference_loss: 0.7137\n", + "Epoch 9/30\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m25s\u001B[0m 25ms/step - loss: 0.5801 - loss/inference_loss: 0.5801 - val_loss: 0.5453 - val_loss/inference_loss: 0.5453\n", + "Epoch 10/30\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m26s\u001B[0m 25ms/step - loss: 0.5841 - loss/inference_loss: 0.5841 - val_loss: 0.6155 - val_loss/inference_loss: 0.6155\n", + "Epoch 11/30\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m17s\u001B[0m 17ms/step - loss: 0.5748 - loss/inference_loss: 0.5748 - val_loss: 0.4574 - val_loss/inference_loss: 0.4574\n", + "Epoch 12/30\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m17s\u001B[0m 16ms/step - loss: 0.5714 - loss/inference_loss: 0.5714 - val_loss: 0.9205 - val_loss/inference_loss: 0.9205\n", + "Epoch 13/30\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m23s\u001B[0m 22ms/step - loss: 0.5804 - loss/inference_loss: 0.5804 - val_loss: 0.4696 - val_loss/inference_loss: 0.4696\n", + "Epoch 14/30\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m 
\u001B[1m20s\u001B[0m 19ms/step - loss: 0.5691 - loss/inference_loss: 0.5691 - val_loss: 0.5795 - val_loss/inference_loss: 0.5795\n", + "Epoch 15/30\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m22s\u001B[0m 22ms/step - loss: 0.5663 - loss/inference_loss: 0.5663 - val_loss: 0.7035 - val_loss/inference_loss: 0.7035\n", + "Epoch 16/30\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m23s\u001B[0m 22ms/step - loss: 0.5692 - loss/inference_loss: 0.5692 - val_loss: 0.6051 - val_loss/inference_loss: 0.6051\n", + "Epoch 17/30\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m17s\u001B[0m 17ms/step - loss: 0.5635 - loss/inference_loss: 0.5635 - val_loss: 0.5303 - val_loss/inference_loss: 0.5303\n", + "Epoch 18/30\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m17s\u001B[0m 17ms/step - loss: 0.5730 - loss/inference_loss: 0.5730 - val_loss: 0.4921 - val_loss/inference_loss: 0.4921\n", + "Epoch 19/30\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m24s\u001B[0m 23ms/step - loss: 0.5641 - loss/inference_loss: 0.5641 - val_loss: 0.5474 - val_loss/inference_loss: 0.5474\n", + "Epoch 20/30\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m16s\u001B[0m 15ms/step - loss: 0.5669 - loss/inference_loss: 0.5669 - val_loss: 0.5979 - val_loss/inference_loss: 0.5979\n", + "Epoch 21/30\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m16s\u001B[0m 16ms/step - loss: 0.5698 - loss/inference_loss: 0.5698 - val_loss: 0.6764 - val_loss/inference_loss: 0.6764\n", + "Epoch 22/30\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m15s\u001B[0m 15ms/step - loss: 0.5697 - 
loss/inference_loss: 0.5697 - val_loss: 0.5636 - val_loss/inference_loss: 0.5636\n", + "Epoch 23/30\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m19s\u001B[0m 19ms/step - loss: 0.5697 - loss/inference_loss: 0.5697 - val_loss: 0.5355 - val_loss/inference_loss: 0.5355\n", + "Epoch 24/30\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m18s\u001B[0m 17ms/step - loss: 0.5623 - loss/inference_loss: 0.5623 - val_loss: 0.4090 - val_loss/inference_loss: 0.4090\n", + "Epoch 25/30\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m16s\u001B[0m 16ms/step - loss: 0.5686 - loss/inference_loss: 0.5686 - val_loss: 0.5841 - val_loss/inference_loss: 0.5841\n", + "Epoch 26/30\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m16s\u001B[0m 16ms/step - loss: 0.5650 - loss/inference_loss: 0.5650 - val_loss: 0.5608 - val_loss/inference_loss: 0.5608\n", + "Epoch 27/30\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m17s\u001B[0m 16ms/step - loss: 0.5646 - loss/inference_loss: 0.5646 - val_loss: 0.5898 - val_loss/inference_loss: 0.5898\n", + "Epoch 28/30\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m16s\u001B[0m 16ms/step - loss: 0.5608 - loss/inference_loss: 0.5608 - val_loss: 0.3862 - val_loss/inference_loss: 0.3862\n", + "Epoch 29/30\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m16s\u001B[0m 16ms/step - loss: 0.5666 - loss/inference_loss: 0.5666 - val_loss: 0.5265 - val_loss/inference_loss: 0.5265\n", + "Epoch 30/30\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m16s\u001B[0m 16ms/step - loss: 0.5673 - loss/inference_loss: 0.5673 - val_loss: 0.7562 - 
val_loss/inference_loss: 0.7562\n" + ] + } + ], + "execution_count": 17 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-14T20:38:43.704457Z", + "start_time": "2024-10-14T20:38:42.786411Z" + } + }, + "cell_type": "code", + "source": "plt.plot(batch_loss_history.batch_losses)", + "id": "3bc7cb16f130a630", + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": [ + "
" + ], + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 18 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "# Validation", + "id": "451fc9fda7232b4f" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-14T22:04:54.358194Z", + "start_time": "2024-10-14T22:04:53.528807Z" + } + }, + "cell_type": "code", + "source": [ + "# Set the number of posterior draws you want to get\n", + "num_samples = 500\n", + "\n", + "# Obtain samples from amortized posterior\n", + "conditions = {\"x\": np.array([[0.0, 0.0]]).astype(\"float32\")}\n", + "samples_at_origin = approximator.sample(conditions=conditions, num_samples=num_samples)" + ], + "id": "b72018bbbb9f1fee", + "outputs": [], + "execution_count": 55 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-14T22:04:58.192926Z", + "start_time": "2024-10-14T22:04:58.170406Z" + } + }, + "cell_type": "code", + "source": "samples_at_origin", + "id": "ee223f7bcd8021ee", + "outputs": [ + { + "data": { + "text/plain": [ + "{'theta': array([[[ 0.17654721, 0.27041006],\n", + " [ 0.12645097, 0.24587561],\n", + " [-0.19934806, -0.278135 ],\n", + " [-0.24838944, -0.24566503],\n", + " [-0.23984583, -0.25675145],\n", + " [ 0.15274203, 0.27565002],\n", + " [-0.17649002, -0.28321353],\n", + " [ 0.2847274 , 0.16072123],\n", + " [ 0.26681536, 0.14605851],\n", + " [ 0.17062895, 0.26429826],\n", + " [-0.28589755, -0.17278582],\n", + " [-0.18308169, -0.29656368],\n", + " [ 0.23003347, 0.24603625],\n", + " [ 0.24050273, 0.21084891],\n", + " [ 0.26269466, 0.12824804],\n", + " [ 0.1484585 , 0.26595044],\n", + " [ 0.26594985, 0.2410507 ],\n", + " [ 0.25229037, 0.24877094],\n", + " [-0.21716174, -0.25466734],\n", + " [ 0.22001 , 0.258107 ],\n", + " [-0.26463276, -0.12307405],\n", + " [ 0.25718766, 0.23288907],\n", + " [-0.23486198, -0.09413613],\n", + " [ 0.27260107, 0.22248606],\n", + " [-0.24408428, -0.2348922 ],\n", + " [ 0.28494632, 0.16640714],\n", + " 
[-0.23431693, -0.24937971],\n", + " [-0.21345644, -0.26375318],\n", + " [ 0.27926463, 0.21351193],\n", + " [ 0.21171738, 0.28555256],\n", + " [ 0.20336659, 0.23305316],\n", + " [ 0.2076843 , 0.26461732],\n", + " [ 0.24040912, 0.22672929],\n", + " [ 0.26609492, 0.18270998],\n", + " [ 0.27759737, 0.21346761],\n", + " [ 0.2724777 , 0.18199308],\n", + " [ 0.15100089, 0.29426283],\n", + " [-0.17589754, -0.27801663],\n", + " [-0.2785506 , -0.15468442],\n", + " [ 0.12913461, 0.25736886],\n", + " [ 0.1828006 , 0.26205993],\n", + " [-0.22678165, -0.279141 ],\n", + " [-0.22860202, -0.27190363],\n", + " [-0.14008933, -0.2561372 ],\n", + " [-0.22078572, -0.2765645 ],\n", + " [ 0.1633323 , 0.27531624],\n", + " [ 0.245555 , 0.2657401 ],\n", + " [ 0.23057584, 0.24123977],\n", + " [ 0.2830829 , 0.16790009],\n", + " [ 0.2726102 , 0.17767592],\n", + " [-0.2565915 , -0.20570078],\n", + " [-0.27068108, -0.25210744],\n", + " [ 0.25900364, 0.079817 ],\n", + " [ 0.22121005, 0.25927085],\n", + " [ 0.2580096 , 0.2309487 ],\n", + " [-0.1160552 , -0.2476957 ],\n", + " [-0.18104881, -0.27573648],\n", + " [-0.13058631, -0.27183357],\n", + " [-0.25789464, -0.2562468 ],\n", + " [ 0.1564322 , 0.2645709 ],\n", + " [-0.1561471 , -0.2756685 ],\n", + " [ 0.28238744, 0.16419713],\n", + " [ 0.13610052, 0.2615081 ],\n", + " [-0.18580219, -0.2808326 ],\n", + " [-0.13579899, -0.2644095 ],\n", + " [-0.25453556, -0.21007065],\n", + " [-0.26590583, -0.2072551 ],\n", + " [ 0.26141977, 0.164808 ],\n", + " [ 0.27442408, 0.14368123],\n", + " [ 0.27178353, 0.15285778],\n", + " [-0.16379204, -0.26401907],\n", + " [ 0.2755499 , 0.19109674],\n", + " [ 0.24399526, 0.2382826 ],\n", + " [-0.21451049, -0.27771395],\n", + " [-0.21337557, -0.2809109 ],\n", + " [-0.21846037, -0.263758 ],\n", + " [ 0.26772732, 0.18167527],\n", + " [-0.25085008, -0.22487502],\n", + " [-0.23806724, -0.269768 ],\n", + " [ 0.1680724 , 0.2599398 ],\n", + " [ 0.1615107 , 0.28257108],\n", + " [-0.2491922 , -0.09990174],\n", + " [-0.13779363, 
-0.26620704],\n", + " [-0.1558931 , -0.26355422],\n", + " [-0.21514817, -0.04246794],\n", + " [-0.14641808, -0.26987153],\n", + " [ 0.21003644, 0.26511782],\n", + " [-0.24539289, -0.09259013],\n", + " [-0.14154209, -0.26200366],\n", + " [ 0.2583239 , 0.22858255],\n", + " [-0.19606611, -0.27139133],\n", + " [-0.1492368 , -0.26115084],\n", + " [-0.14322335, -0.25939518],\n", + " [ 0.13081153, 0.25858665],\n", + " [ 0.25838077, 0.12905371],\n", + " [-0.27384835, -0.14071937],\n", + " [-0.2832994 , -0.10836863],\n", + " [-0.24902178, -0.11251126],\n", + " [-0.28539097, -0.1751118 ],\n", + " [ 0.23493315, 0.23637466],\n", + " [ 0.15128584, 0.27237755],\n", + " [-0.22134736, -0.26069474],\n", + " [-0.11475502, -0.24605674],\n", + " [ 0.27256542, 0.16720194],\n", + " [-0.2531284 , -0.10332474],\n", + " [-0.26484224, -0.21712682],\n", + " [-0.1963381 , -0.2724918 ],\n", + " [ 0.28410047, 0.16903692],\n", + " [-0.28119767, -0.18954884],\n", + " [ 0.16959138, 0.2755884 ],\n", + " [-0.20630495, -0.27477667],\n", + " [-0.28206652, -0.17886436],\n", + " [ 0.22695933, 0.2509038 ],\n", + " [ 0.15067881, 0.26440012],\n", + " [ 0.16328944, 0.2806378 ],\n", + " [ 0.26133698, 0.20578645],\n", + " [ 0.2730273 , 0.21775459],\n", + " [-0.26709872, -0.13604781],\n", + " [ 0.27165067, 0.18136705],\n", + " [-0.16354825, -0.2658952 ],\n", + " [ 0.27130795, 0.13281836],\n", + " [-0.26172233, -0.20255604],\n", + " [-0.2656914 , -0.12956285],\n", + " [ 0.26998436, 0.17273536],\n", + " [-0.25801274, -0.24676166],\n", + " [-0.16900623, -0.26649785],\n", + " [ 0.17619887, 0.28156966],\n", + " [-0.16531721, -0.28414413],\n", + " [ 0.26003808, 0.13653943],\n", + " [ 0.26175636, 0.14695114],\n", + " [ 0.13328604, 0.27599776],\n", + " [ 0.23215918, 0.25090832],\n", + " [ 0.26951528, 0.14531177],\n", + " [ 0.27629095, 0.17952204],\n", + " [-0.1899774 , -0.2696848 ],\n", + " [-0.2577222 , -0.2291099 ],\n", + " [ 0.27753985, 0.14961004],\n", + " [-0.26477182, -0.15710792],\n", + " [-0.16838536, 
-0.2772453 ],\n", + " [ 0.14863718, 0.2671488 ],\n", + " [ 0.16921876, 0.30610895],\n", + " [ 0.25700146, 0.20359705],\n", + " [-0.2723776 , -0.13057159],\n", + " [ 0.15679039, 0.27260846],\n", + " [ 0.26662302, 0.2313111 ],\n", + " [-0.1483945 , -0.26958892],\n", + " [-0.28797525, -0.16708171],\n", + " [-0.2701152 , -0.18412311],\n", + " [-0.22698759, -0.26212072],\n", + " [ 0.26134652, 0.17660096],\n", + " [ 0.24770044, 0.25576842],\n", + " [-0.26324385, -0.17584372],\n", + " [-0.24796484, -0.2644273 ],\n", + " [-0.2512376 , -0.2497181 ],\n", + " [-0.14071809, -0.26300108],\n", + " [ 0.25602716, 0.21307798],\n", + " [ 0.15512946, 0.26241523],\n", + " [-0.13459323, -0.2672544 ],\n", + " [-0.27148777, -0.10496917],\n", + " [-0.26881555, -0.14279647],\n", + " [ 0.2439939 , 0.23070608],\n", + " [ 0.24191315, 0.25623626],\n", + " [-0.26984468, -0.21261044],\n", + " [-0.2708213 , -0.14030349],\n", + " [-0.10411157, -0.25560504],\n", + " [ 0.16901971, 0.2944855 ],\n", + " [-0.23794813, -0.26510406],\n", + " [ 0.24874337, 0.2428212 ],\n", + " [-0.27015492, -0.19753729],\n", + " [ 0.16361678, 0.28820413],\n", + " [-0.2720252 , -0.17723957],\n", + " [-0.24740805, -0.2495067 ],\n", + " [-0.2634244 , -0.1747841 ],\n", + " [-0.16585416, -0.21445206],\n", + " [-0.2081011 , -0.2766696 ],\n", + " [-0.26483506, -0.17030197],\n", + " [-0.28507483, -0.16922492],\n", + " [-0.261419 , -0.1225991 ],\n", + " [-0.16929954, -0.26587838],\n", + " [ 0.2638874 , 0.14496247],\n", + " [-0.20467407, -0.2825095 ],\n", + " [ 0.26162702, 0.21156694],\n", + " [-0.14833814, -0.26648057],\n", + " [-0.2305345 , -0.255768 ],\n", + " [ 0.08747022, 0.24928974],\n", + " [-0.14354517, -0.27147245],\n", + " [-0.26059756, -0.23299775],\n", + " [-0.20088947, -0.26393372],\n", + " [ 0.25847495, 0.12191093],\n", + " [ 0.17978095, 0.26116067],\n", + " [ 0.27232087, 0.13389245],\n", + " [ 0.1689069 , 0.27647942],\n", + " [ 0.20815115, 0.26791954],\n", + " [-0.1713421 , -0.26446718],\n", + " [ 0.16414319, 
0.2682889 ],\n", + " [-0.21358459, -0.25886625],\n", + " [-0.21310486, -0.26620245],\n", + " [-0.14630145, -0.27673244],\n", + " [ 0.25637555, 0.10911691],\n", + " [-0.24687341, -0.2515366 ],\n", + " [-0.22562824, -0.2634732 ],\n", + " [-0.24021488, -0.25353155],\n", + " [ 0.27167273, 0.21580078],\n", + " [ 0.23716737, 0.26719362],\n", + " [-0.27758297, -0.19752777],\n", + " [ 0.2810052 , 0.19020875],\n", + " [ 0.26287878, 0.22724666],\n", + " [-0.23488928, -0.250957 ],\n", + " [-0.27531016, -0.17768633],\n", + " [ 0.21583839, 0.27459365],\n", + " [-0.17736377, -0.2707683 ],\n", + " [-0.17810163, -0.27602732],\n", + " [-0.2846307 , -0.13999233],\n", + " [ 0.21645485, 0.26564485],\n", + " [-0.24948609, -0.22531442],\n", + " [-0.25105706, -0.22081058],\n", + " [ 0.15915596, 0.26416653],\n", + " [ 0.11383924, 0.28477448],\n", + " [ 0.21720125, 0.27140385],\n", + " [-0.2634048 , -0.23810792],\n", + " [ 0.250942 , 0.11477937],\n", + " [-0.26577836, -0.18498161],\n", + " [-0.24776752, -0.25986475],\n", + " [-0.26308554, -0.12615162],\n", + " [ 0.25829774, 0.22018121],\n", + " [ 0.27852893, 0.15985395],\n", + " [ 0.13305658, 0.30257642],\n", + " [-0.2717955 , -0.20277436],\n", + " [-0.20948571, -0.2705216 ],\n", + " [ 0.2838514 , 0.30136824],\n", + " [-0.15114638, -0.2609089 ],\n", + " [-0.27054653, -0.2396658 ],\n", + " [ 0.27936798, 0.1702747 ],\n", + " [-0.26180363, -0.23037021],\n", + " [ 0.12675536, 0.29809797],\n", + " [ 0.2860673 , 0.1818126 ],\n", + " [ 0.15882272, 0.2739191 ],\n", + " [-0.2716868 , -0.23455156],\n", + " [ 0.18038103, 0.26399285],\n", + " [-0.19035973, -0.2701561 ],\n", + " [-0.18592411, -0.28720114],\n", + " [-0.28801578, -0.20492856],\n", + " [-0.25421828, -0.252667 ],\n", + " [-0.19853641, -0.27539486],\n", + " [ 0.2658394 , 0.13718288],\n", + " [-0.18788746, -0.26782632],\n", + " [ 0.25122517, 0.22947817],\n", + " [ 0.22315653, 0.2824493 ],\n", + " [ 0.2672273 , 0.19705571],\n", + " [-0.18191774, -0.26465735],\n", + " [-0.16103017, -0.2874695 
],\n", + " [ 0.13860585, 0.26392978],\n", + " [-0.2511997 , -0.25885814],\n", + " [ 0.27468997, 0.16526596],\n", + " [ 0.14169784, 0.2599411 ],\n", + " [ 0.26611072, 0.17375904],\n", + " [-0.2878152 , -0.08115485],\n", + " [-0.28335226, -0.15670936],\n", + " [-0.27376425, -0.178761 ],\n", + " [ 0.27126724, 0.14793816],\n", + " [ 0.2668656 , 0.10891573],\n", + " [-0.15683985, -0.27304047],\n", + " [-0.2727784 , -0.13070625],\n", + " [ 0.2501254 , 0.10839404],\n", + " [-0.15620172, -0.2873354 ],\n", + " [-0.15404768, -0.26731277],\n", + " [ 0.21467187, 0.25746334],\n", + " [ 0.25429416, 0.1797972 ],\n", + " [ 0.20857449, 0.249985 ],\n", + " [ 0.11666797, 0.27871257],\n", + " [ 0.24273743, 0.20371343],\n", + " [ 0.26860923, 0.13919559],\n", + " [ 0.16740836, 0.27165604],\n", + " [ 0.16717246, 0.26184237],\n", + " [-0.21383077, -0.271177 ],\n", + " [-0.2659731 , -0.22285458],\n", + " [ 0.28479463, 0.15989105],\n", + " [ 0.11832219, 0.27264667],\n", + " [-0.2787032 , -0.2067107 ],\n", + " [-0.16163397, -0.27661702],\n", + " [-0.26493195, -0.20964915],\n", + " [-0.21109939, -0.26734334],\n", + " [ 0.26570964, 0.15831824],\n", + " [-0.16366 , -0.27646917],\n", + " [ 0.1530443 , 0.2601412 ],\n", + " [-0.2540462 , -0.24356595],\n", + " [-0.25072587, -0.21606591],\n", + " [-0.25311893, -0.18031111],\n", + " [ 0.28140247, 0.21742497],\n", + " [ 0.13237408, 0.26540524],\n", + " [ 0.2564062 , 0.12839843],\n", + " [ 0.15366873, 0.2635122 ],\n", + " [ 0.22159882, 0.24206953],\n", + " [ 0.26372212, 0.12725428],\n", + " [-0.22098112, -0.27278456],\n", + " [-0.20793442, -0.27573195],\n", + " [ 0.26726925, 0.13613409],\n", + " [ 0.19226794, 0.2606786 ],\n", + " [ 0.13405557, 0.25951523],\n", + " [-0.24498801, -0.10678881],\n", + " [-0.20078732, -0.27810192],\n", + " [ 0.26355267, 0.19066595],\n", + " [-0.24511454, -0.22381541],\n", + " [ 0.20138271, 0.26740623],\n", + " [-0.17140028, -0.27542722],\n", + " [ 0.28308725, 0.13273714],\n", + " [-0.21728876, -0.2650354 ],\n", + " 
[-0.16061038, -0.23381335],\n", + " [-0.26724768, -0.22253583],\n", + " [-0.2655784 , -0.16959837],\n", + " [-0.24018149, -0.26769286],\n", + " [-0.26636603, -0.18842779],\n", + " [ 0.24131311, 0.28197736],\n", + " [-0.26913485, -0.2285188 ],\n", + " [ 0.11729264, 0.25736594],\n", + " [ 0.13231607, 0.22730176],\n", + " [ 0.26941478, 0.21044381],\n", + " [-0.25148863, -0.2359707 ],\n", + " [ 0.23700745, 0.21061157],\n", + " [ 0.22925268, 0.2627837 ],\n", + " [ 0.24546961, 0.16472188],\n", + " [ 0.26913851, 0.14494902],\n", + " [-0.2042441 , -0.27723473],\n", + " [ 0.27372217, 0.18476877],\n", + " [ 0.23441432, 0.25146925],\n", + " [-0.26290894, -0.21708629],\n", + " [-0.2156923 , -0.2754048 ],\n", + " [ 0.24024172, 0.27976716],\n", + " [-0.23979467, -0.2512276 ],\n", + " [-0.25661737, -0.24269213],\n", + " [ 0.19709219, 0.2668875 ],\n", + " [-0.26447535, -0.22487412],\n", + " [-0.26613384, -0.2496682 ],\n", + " [ 0.25823814, 0.22014786],\n", + " [-0.23661795, -0.2597047 ],\n", + " [-0.18679908, -0.26142985],\n", + " [-0.24493581, -0.18804303],\n", + " [ 0.17663766, 0.26372564],\n", + " [-0.2361322 , -0.26485068],\n", + " [ 0.17371084, 0.27549374],\n", + " [ 0.20716886, 0.28001374],\n", + " [-0.27110183, -0.13153674],\n", + " [ 0.27748924, 0.14194696],\n", + " [-0.20262113, -0.27168894],\n", + " [ 0.27442312, 0.12007011],\n", + " [-0.2616744 , -0.19247435],\n", + " [ 0.13171248, 0.2602473 ],\n", + " [-0.16328776, -0.2698179 ],\n", + " [ 0.12095347, 0.25061035],\n", + " [ 0.15479945, 0.26521897],\n", + " [ 0.26058197, 0.12008516],\n", + " [-0.24259163, -0.25619787],\n", + " [-0.25585926, -0.10046624],\n", + " [-0.09253009, -0.24216235],\n", + " [ 0.2025079 , 0.25464582],\n", + " [ 0.26418608, 0.14600448],\n", + " [-0.13562107, -0.2535293 ],\n", + " [ 0.20196442, 0.28894037],\n", + " [ 0.14610024, 0.2617845 ],\n", + " [-0.19805442, -0.24427016],\n", + " [-0.26247138, -0.14815763],\n", + " [-0.25442994, -0.14696085],\n", + " [-0.27216715, -0.1356835 ],\n", + " 
[-0.13450277, -0.24344285],\n", + " [ 0.24729492, 0.26272237],\n", + " [ 0.20515431, 0.27405024],\n", + " [ 0.2608167 , 0.15682302],\n", + " [-0.15053959, -0.26844054],\n", + " [-0.18680488, -0.26017186],\n", + " [ 0.22681056, 0.24950348],\n", + " [-0.2625429 , -0.23297158],\n", + " [ 0.2613256 , 0.22275327],\n", + " [-0.15576485, -0.27131522],\n", + " [ 0.27989298, 0.13113242],\n", + " [-0.27222443, -0.22823197],\n", + " [-0.25886652, -0.2011627 ],\n", + " [ 0.26871747, 0.21118324],\n", + " [ 0.15130182, 0.27510864],\n", + " [-0.14449994, -0.28044653],\n", + " [ 0.25875401, 0.13721596],\n", + " [-0.22085567, -0.25945926],\n", + " [-0.24155052, -0.25671858],\n", + " [ 0.24092378, 0.23994891],\n", + " [ 0.29403216, 0.17666245],\n", + " [-0.26444742, -0.11436895],\n", + " [ 0.26498055, 0.18587114],\n", + " [-0.30489963, -0.18212527],\n", + " [ 0.2606231 , 0.12549354],\n", + " [ 0.17819688, 0.2753808 ],\n", + " [-0.2085124 , -0.24076138],\n", + " [ 0.2783888 , 0.21553223],\n", + " [ 0.24207772, 0.27036875],\n", + " [-0.19318947, -0.28491825],\n", + " [-0.28575206, -0.11007828],\n", + " [ 0.20214315, 0.25610656],\n", + " [ 0.25302982, 0.16020659],\n", + " [-0.25691435, -0.09973457],\n", + " [-0.16929191, -0.2823217 ],\n", + " [ 0.28185576, 0.14506812],\n", + " [-0.25279728, -0.09789051],\n", + " [-0.27485964, -0.15160048],\n", + " [-0.2778846 , -0.16473699],\n", + " [-0.17066869, -0.28102553],\n", + " [ 0.13137233, 0.2776695 ],\n", + " [ 0.27647096, 0.17463823],\n", + " [-0.26363677, -0.19969368],\n", + " [ 0.26302052, 0.14910254],\n", + " [ 0.2524321 , 0.26034302],\n", + " [-0.18632376, -0.2879629 ],\n", + " [-0.27709562, -0.19126996],\n", + " [-0.26679304, -0.11288792],\n", + " [-0.13774812, -0.2649699 ],\n", + " [ 0.28189063, 0.20290558],\n", + " [-0.27441102, -0.2103486 ],\n", + " [ 0.2638747 , 0.20561664],\n", + " [ 0.24635617, 0.15771465],\n", + " [-0.2822139 , -0.1699524 ],\n", + " [-0.20473354, -0.21127829],\n", + " [-0.2742424 , -0.22162642],\n", + " [ 
0.25963825, 0.23948784],\n", + " [-0.26096076, -0.13017972],\n", + " [-0.16121498, -0.28439793],\n", + " [-0.27414775, -0.14413777],\n", + " [ 0.24034937, 0.19255368],\n", + " [-0.18428135, -0.2854242 ],\n", + " [-0.25247365, -0.2194159 ],\n", + " [ 0.16972582, 0.2716142 ],\n", + " [ 0.27258348, 0.16816002],\n", + " [-0.21479341, -0.2591002 ],\n", + " [-0.22572103, -0.2740637 ],\n", + " [-0.25940806, -0.2296557 ],\n", + " [-0.25289565, -0.10597888],\n", + " [ 0.27644205, 0.17131492],\n", + " [-0.2744681 , -0.20247808],\n", + " [ 0.1569854 , 0.26544535],\n", + " [ 0.16878295, 0.2785138 ],\n", + " [ 0.27115518, 0.19420569],\n", + " [-0.19849882, -0.2581566 ],\n", + " [ 0.2501585 , 0.10598603],\n", + " [ 0.27531207, 0.1523792 ],\n", + " [ 0.2688099 , 0.15870939],\n", + " [-0.25474942, -0.15129292],\n", + " [ 0.23721237, 0.24722971],\n", + " [-0.2737707 , -0.1965245 ],\n", + " [ 0.24654536, 0.2790857 ],\n", + " [ 0.258439 , 0.22061013],\n", + " [-0.27449885, -0.17695272],\n", + " [ 0.20442174, 0.27892023],\n", + " [ 0.17934252, 0.27942443],\n", + " [ 0.13801269, 0.2676139 ],\n", + " [ 0.2189254 , 0.21409418],\n", + " [ 0.26100272, 0.15690793],\n", + " [ 0.25874227, 0.19437556],\n", + " [-0.2684424 , -0.18466231],\n", + " [-0.27459067, -0.13723826],\n", + " [-0.26729393, -0.22947192],\n", + " [-0.23811837, -0.25796393],\n", + " [-0.26904866, -0.1451745 ],\n", + " [-0.21287099, -0.26636374],\n", + " [-0.2356898 , -0.25938934],\n", + " [-0.18526974, -0.26750624],\n", + " [ 0.26033378, 0.14162609],\n", + " [ 0.18297893, 0.12170283],\n", + " [ 0.21097694, 0.24945582],\n", + " [-0.24043491, -0.25133517],\n", + " [-0.13337177, -0.25957477],\n", + " [ 0.21417986, 0.2602194 ],\n", + " [-0.20514008, -0.27646336],\n", + " [ 0.26594204, 0.2504294 ],\n", + " [-0.15494907, -0.27878577],\n", + " [-0.24846888, -0.22488065],\n", + " [ 0.27053243, 0.18878628],\n", + " [-0.20375696, -0.26968008],\n", + " [-0.23099579, -0.25355121],\n", + " [ 0.23137142, 0.23804243],\n", + " [ 0.13809101, 
0.26715648],\n", + " [ 0.1641148 , 0.27006269],\n", + " [ 0.27676284, 0.15274183],\n", + " [-0.25447246, -0.09072974],\n", + " [-0.27731633, -0.18285707],\n", + " [-0.1584472 , -0.27314755],\n", + " [-0.23445056, -0.25665188],\n", + " [-0.1045018 , -0.24585408],\n", + " [-0.22868754, -0.14893918],\n", + " [ 0.16116361, 0.27902877],\n", + " [-0.2680272 , -0.13343239],\n", + " [-0.25145337, -0.23630832],\n", + " [-0.27399144, -0.21066816],\n", + " [ 0.16021812, 0.27963275],\n", + " [-0.18627508, -0.25839704],\n", + " [-0.16178118, -0.28439075],\n", + " [ 0.12971437, 0.25599122],\n", + " [ 0.26095128, 0.16551861],\n", + " [-0.13171294, -0.2682826 ],\n", + " [-0.17962739, -0.25391853],\n", + " [ 0.1130407 , 0.25770533],\n", + " [ 0.26060718, 0.13303255],\n", + " [ 0.26106697, 0.11457562],\n", + " [ 0.24399544, 0.22096606],\n", + " [-0.2890758 , -0.15642428]]], dtype=float32)}" + ] + }, + "execution_count": 56, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 56 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-14T21:06:00.343643Z", + "start_time": "2024-10-14T21:06:00.324488Z" + } + }, + "cell_type": "code", + "source": "sample_data[\"theta\"].shape", + "id": "5c47d5541b0a24ee", + "outputs": [ + { + "data": { + "text/plain": [ + "(128, 2)" + ] + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 42 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-14T21:50:10.198765Z", + "start_time": "2024-10-14T21:50:10.176034Z" + } + }, + "cell_type": "code", + "source": "np.reshape(samples_at_origin, (128, 2, 1)).shape", + "id": "51c802b500199423", + "outputs": [ + { + "data": { + "text/plain": [ + "(128, 2, 1)" + ] + }, + "execution_count": 50, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 50 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-14T21:51:42.186086Z", + "start_time": 
"2024-10-14T21:51:40.305203Z" + } + }, + "cell_type": "code", + "source": "f = plot_recovery(post_samples=np.swapaxes(samples_at_origin, 0, 1), prior_samples=np.array(sample_data['theta']))", + "id": "e32b56a7ec84701b", + "outputs": [ + { + "ename": "TypeError", + "evalue": "unsupported operand type(s) for -: 'NoneType' and 'int'", + "output_type": "error", + "traceback": [ + "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m", + "\u001B[1;31mTypeError\u001B[0m Traceback (most recent call last)", + "Cell \u001B[1;32mIn[53], line 1\u001B[0m\n\u001B[1;32m----> 1\u001B[0m f \u001B[38;5;241m=\u001B[39m \u001B[43mplot_recovery\u001B[49m\u001B[43m(\u001B[49m\u001B[43mpost_samples\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mnp\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mswapaxes\u001B[49m\u001B[43m(\u001B[49m\u001B[43msamples_at_origin\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m0\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m1\u001B[39;49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mprior_samples\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mnp\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43marray\u001B[49m\u001B[43m(\u001B[49m\u001B[43msample_data\u001B[49m\u001B[43m[\u001B[49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[38;5;124;43mtheta\u001B[39;49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[43m]\u001B[49m\u001B[43m)\u001B[49m\u001B[43m)\u001B[49m\n", + "File \u001B[1;32m~\\Documents\\Native\\Development\\BayesFlow\\bayesflow\\diagnostics\\plot_recovery.py:162\u001B[0m, in \u001B[0;36mplot_recovery\u001B[1;34m(post_samples, prior_samples, point_agg, uncertainty_agg, param_names, fig_size, label_fontsize, title_fontsize, metric_fontsize, tick_fontsize, add_corr, add_r2, color, n_col, n_row, xlabel, ylabel, **kwargs)\u001B[0m\n\u001B[0;32m 159\u001B[0m 
ax\u001B[38;5;241m.\u001B[39mtick_params(axis\u001B[38;5;241m=\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mboth\u001B[39m\u001B[38;5;124m\"\u001B[39m, which\u001B[38;5;241m=\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mmajor\u001B[39m\u001B[38;5;124m\"\u001B[39m, labelsize\u001B[38;5;241m=\u001B[39mtick_fontsize)\n\u001B[0;32m 160\u001B[0m ax\u001B[38;5;241m.\u001B[39mtick_params(axis\u001B[38;5;241m=\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mboth\u001B[39m\u001B[38;5;124m\"\u001B[39m, which\u001B[38;5;241m=\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mminor\u001B[39m\u001B[38;5;124m\"\u001B[39m, labelsize\u001B[38;5;241m=\u001B[39mtick_fontsize)\n\u001B[1;32m--> 162\u001B[0m \u001B[43mpostprocess\u001B[49m\u001B[43m(\u001B[49m\u001B[43maxarr\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43maxarr_it\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mn_row\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mn_col\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mn_params\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mxlabel\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mylabel\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mlabel_fontsize\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 164\u001B[0m f\u001B[38;5;241m.\u001B[39mtight_layout()\n\u001B[0;32m 165\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m f\n", + "File \u001B[1;32m~\\Documents\\Native\\Development\\BayesFlow\\bayesflow\\utils\\plot_utils.py:299\u001B[0m, in \u001B[0;36mpostprocess\u001B[1;34m(*args)\u001B[0m\n\u001B[0;32m 294\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mpostprocess\u001B[39m(\u001B[38;5;241m*\u001B[39margs):\n\u001B[0;32m 295\u001B[0m \u001B[38;5;250m \u001B[39m\u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[0;32m 296\u001B[0m \u001B[38;5;124;03m Procedural wrapper for postprocessing steps, including adding labels and removing unused 
axes.\u001B[39;00m\n\u001B[0;32m 297\u001B[0m \u001B[38;5;124;03m \"\"\"\u001B[39;00m\n\u001B[1;32m--> 299\u001B[0m \u001B[43madd_labels\u001B[49m\u001B[43m(\u001B[49m\u001B[43margs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 300\u001B[0m remove_unused_axes(args)\n", + "File \u001B[1;32m~\\Documents\\Native\\Development\\BayesFlow\\bayesflow\\utils\\plot_utils.py:239\u001B[0m, in \u001B[0;36madd_labels\u001B[1;34m(axarr, n_row, n_col, xlabel, ylabel, label_fontsize)\u001B[0m\n\u001B[0;32m 228\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21madd_labels\u001B[39m(\n\u001B[0;32m 229\u001B[0m axarr,\n\u001B[0;32m 230\u001B[0m n_row: \u001B[38;5;28mint\u001B[39m \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m,\n\u001B[1;32m (...)\u001B[0m\n\u001B[0;32m 234\u001B[0m label_fontsize: \u001B[38;5;28mint\u001B[39m \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[0;32m 235\u001B[0m ):\n\u001B[0;32m 236\u001B[0m \u001B[38;5;250m \u001B[39m\u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[0;32m 237\u001B[0m \u001B[38;5;124;03m Wrapper function for configuring labels for both axes.\u001B[39;00m\n\u001B[0;32m 238\u001B[0m \u001B[38;5;124;03m \"\"\"\u001B[39;00m\n\u001B[1;32m--> 239\u001B[0m \u001B[43madd_xlabels\u001B[49m\u001B[43m(\u001B[49m\u001B[43maxarr\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mn_row\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mn_col\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mxlabel\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mlabel_fontsize\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 240\u001B[0m add_ylabels(axarr, n_row, ylabel, label_fontsize)\n", + "File \u001B[1;32m~\\Documents\\Native\\Development\\BayesFlow\\bayesflow\\utils\\plot_utils.py:208\u001B[0m, in \u001B[0;36madd_xlabels\u001B[1;34m(axarr, n_row, n_col, xlabel, label_fontsize)\u001B[0m\n\u001B[0;32m 200\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m 
\u001B[38;5;21madd_xlabels\u001B[39m(\n\u001B[0;32m 201\u001B[0m axarr,\n\u001B[0;32m 202\u001B[0m n_row: \u001B[38;5;28mint\u001B[39m \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m,\n\u001B[1;32m (...)\u001B[0m\n\u001B[0;32m 206\u001B[0m ):\n\u001B[0;32m 207\u001B[0m \u001B[38;5;66;03m# Only add x-labels to the bottom row\u001B[39;00m\n\u001B[1;32m--> 208\u001B[0m bottom_row \u001B[38;5;241m=\u001B[39m axarr \u001B[38;5;28;01mif\u001B[39;00m n_row \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m1\u001B[39m \u001B[38;5;28;01melse\u001B[39;00m axarr[\u001B[38;5;241m0\u001B[39m] \u001B[38;5;28;01mif\u001B[39;00m n_col \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m1\u001B[39m \u001B[38;5;28;01melse\u001B[39;00m axarr[\u001B[43mn_row\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m-\u001B[39;49m\u001B[43m \u001B[49m\u001B[38;5;241;43m1\u001B[39;49m, :]\n\u001B[0;32m 209\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m _ax \u001B[38;5;129;01min\u001B[39;00m bottom_row:\n\u001B[0;32m 210\u001B[0m _ax\u001B[38;5;241m.\u001B[39mset_xlabel(xlabel, fontsize\u001B[38;5;241m=\u001B[39mlabel_fontsize)\n", + "\u001B[1;31mTypeError\u001B[0m: unsupported operand type(s) for -: 'NoneType' and 'int'" + ] + }, + { + "data": { + "text/plain": [ + "
" + ], + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 53 + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": "", + "id": "977301bbdf313eeb" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/TwoMoons_FlowMatching.ipynb b/examples/TwoMoons_FlowMatching.ipynb index 4fcfd655d..e80477087 100644 --- a/examples/TwoMoons_FlowMatching.ipynb +++ b/examples/TwoMoons_FlowMatching.ipynb @@ -10,28 +10,13 @@ }, { "cell_type": "code", - "execution_count": 1, "id": "d5f88a59", "metadata": { "ExecuteTime": { - "end_time": "2024-09-23T14:39:46.551814Z", - "start_time": "2024-09-23T14:39:46.032170Z" + "end_time": "2024-10-14T00:04:09.950062Z", + "start_time": "2024-10-14T00:03:55.013171Z" } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "CUDA backend failed to initialize: Unable to use CUDA because of the following issues with CUDA components:\n", - "Outdated cuDNN installation found.\n", - "Version JAX was built against: 8907\n", - "Minimum supported: 9100\n", - "Installed version: 8907\n", - "The local installation version must be no lower than 9100..(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" - ] - } - ], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", @@ -50,7 +35,9 @@ "sys.path.append('../')\n", "\n", "import bayesflow as bf" - ] + ], + "outputs": [], + "execution_count": 1 }, { "cell_type": "markdown", @@ -435,65 +422,65 @@ "output_type": "stream", "text": [ "Epoch 1/30\n", - "\u001b[1m1024/1024\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 6ms/step - loss: 0.6938 - loss/inference_loss: 0.6938 - val_loss: 0.5508 - val_loss/inference_loss: 0.5508\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m6s\u001B[0m 6ms/step - loss: 0.6938 - loss/inference_loss: 0.6938 - val_loss: 0.5508 - val_loss/inference_loss: 0.5508\n", "Epoch 2/30\n", - "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.6250 - loss/inference_loss: 0.6250 - val_loss: 0.6023 - val_loss/inference_loss: 0.6023\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m5s\u001B[0m 5ms/step - loss: 0.6250 - loss/inference_loss: 0.6250 - val_loss: 0.6023 - val_loss/inference_loss: 0.6023\n", "Epoch 3/30\n", - "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.6056 - loss/inference_loss: 0.6056 - val_loss: 0.4454 - val_loss/inference_loss: 0.4454\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m5s\u001B[0m 5ms/step - loss: 0.6056 - loss/inference_loss: 0.6056 - val_loss: 0.4454 - val_loss/inference_loss: 0.4454\n", "Epoch 4/30\n", - "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.6006 - loss/inference_loss: 0.6006 - val_loss: 0.5079 - val_loss/inference_loss: 0.5079\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m5s\u001B[0m 5ms/step - loss: 0.6006 - loss/inference_loss: 0.6006 - val_loss: 0.5079 - val_loss/inference_loss: 0.5079\n", "Epoch 5/30\n", - "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.6020 - loss/inference_loss: 0.6020 - val_loss: 0.5414 - 
val_loss/inference_loss: 0.5414\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m5s\u001B[0m 5ms/step - loss: 0.6020 - loss/inference_loss: 0.6020 - val_loss: 0.5414 - val_loss/inference_loss: 0.5414\n", "Epoch 6/30\n", - "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5973 - loss/inference_loss: 0.5973 - val_loss: 0.6961 - val_loss/inference_loss: 0.6961\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m5s\u001B[0m 5ms/step - loss: 0.5973 - loss/inference_loss: 0.5973 - val_loss: 0.6961 - val_loss/inference_loss: 0.6961\n", "Epoch 7/30\n", - "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5874 - loss/inference_loss: 0.5874 - val_loss: 0.5399 - val_loss/inference_loss: 0.5399\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m5s\u001B[0m 5ms/step - loss: 0.5874 - loss/inference_loss: 0.5874 - val_loss: 0.5399 - val_loss/inference_loss: 0.5399\n", "Epoch 8/30\n", - "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5939 - loss/inference_loss: 0.5939 - val_loss: 0.4877 - val_loss/inference_loss: 0.4877\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m5s\u001B[0m 5ms/step - loss: 0.5939 - loss/inference_loss: 0.5939 - val_loss: 0.4877 - val_loss/inference_loss: 0.4877\n", "Epoch 9/30\n", - "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5841 - loss/inference_loss: 0.5841 - val_loss: 0.5115 - val_loss/inference_loss: 0.5115\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m5s\u001B[0m 5ms/step - 
loss: 0.5841 - loss/inference_loss: 0.5841 - val_loss: 0.5115 - val_loss/inference_loss: 0.5115\n", "Epoch 10/30\n", - "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5827 - loss/inference_loss: 0.5827 - val_loss: 0.5383 - val_loss/inference_loss: 0.5383\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m5s\u001B[0m 5ms/step - loss: 0.5827 - loss/inference_loss: 0.5827 - val_loss: 0.5383 - val_loss/inference_loss: 0.5383\n", "Epoch 11/30\n", - "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5807 - loss/inference_loss: 0.5807 - val_loss: 0.4411 - val_loss/inference_loss: 0.4411\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m5s\u001B[0m 5ms/step - loss: 0.5807 - loss/inference_loss: 0.5807 - val_loss: 0.4411 - val_loss/inference_loss: 0.4411\n", "Epoch 12/30\n", - "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5774 - loss/inference_loss: 0.5774 - val_loss: 0.5844 - val_loss/inference_loss: 0.5844\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m5s\u001B[0m 5ms/step - loss: 0.5774 - loss/inference_loss: 0.5774 - val_loss: 0.5844 - val_loss/inference_loss: 0.5844\n", "Epoch 13/30\n", - "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5813 - loss/inference_loss: 0.5813 - val_loss: 0.8106 - val_loss/inference_loss: 0.8106\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m5s\u001B[0m 5ms/step - loss: 0.5813 - loss/inference_loss: 0.5813 - val_loss: 0.8106 - val_loss/inference_loss: 0.8106\n", "Epoch 14/30\n", - "\u001b[1m1024/1024\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 5ms/step - loss: 0.5756 - loss/inference_loss: 0.5756 - val_loss: 0.4150 - val_loss/inference_loss: 0.4150\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m6s\u001B[0m 5ms/step - loss: 0.5756 - loss/inference_loss: 0.5756 - val_loss: 0.4150 - val_loss/inference_loss: 0.4150\n", "Epoch 15/30\n", - "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 5ms/step - loss: 0.5761 - loss/inference_loss: 0.5761 - val_loss: 0.5451 - val_loss/inference_loss: 0.5451\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m6s\u001B[0m 5ms/step - loss: 0.5761 - loss/inference_loss: 0.5761 - val_loss: 0.5451 - val_loss/inference_loss: 0.5451\n", "Epoch 16/30\n", - "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5747 - loss/inference_loss: 0.5747 - val_loss: 0.6248 - val_loss/inference_loss: 0.6248\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m5s\u001B[0m 5ms/step - loss: 0.5747 - loss/inference_loss: 0.5747 - val_loss: 0.6248 - val_loss/inference_loss: 0.6248\n", "Epoch 17/30\n", - "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5666 - loss/inference_loss: 0.5666 - val_loss: 0.4689 - val_loss/inference_loss: 0.4689\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m5s\u001B[0m 5ms/step - loss: 0.5666 - loss/inference_loss: 0.5666 - val_loss: 0.4689 - val_loss/inference_loss: 0.4689\n", "Epoch 18/30\n", - "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5705 - loss/inference_loss: 0.5705 - val_loss: 0.3853 - 
val_loss/inference_loss: 0.3853\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m5s\u001B[0m 5ms/step - loss: 0.5705 - loss/inference_loss: 0.5705 - val_loss: 0.3853 - val_loss/inference_loss: 0.3853\n", "Epoch 19/30\n", - "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5739 - loss/inference_loss: 0.5739 - val_loss: 0.5055 - val_loss/inference_loss: 0.5055\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m5s\u001B[0m 5ms/step - loss: 0.5739 - loss/inference_loss: 0.5739 - val_loss: 0.5055 - val_loss/inference_loss: 0.5055\n", "Epoch 20/30\n", - "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5688 - loss/inference_loss: 0.5688 - val_loss: 0.5032 - val_loss/inference_loss: 0.5032\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m5s\u001B[0m 5ms/step - loss: 0.5688 - loss/inference_loss: 0.5688 - val_loss: 0.5032 - val_loss/inference_loss: 0.5032\n", "Epoch 21/30\n", - "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5663 - loss/inference_loss: 0.5663 - val_loss: 0.5237 - val_loss/inference_loss: 0.5237\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m5s\u001B[0m 5ms/step - loss: 0.5663 - loss/inference_loss: 0.5663 - val_loss: 0.5237 - val_loss/inference_loss: 0.5237\n", "Epoch 22/30\n", - "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5650 - loss/inference_loss: 0.5650 - val_loss: 0.3955 - val_loss/inference_loss: 0.3955\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m5s\u001B[0m 5ms/step 
- loss: 0.5650 - loss/inference_loss: 0.5650 - val_loss: 0.3955 - val_loss/inference_loss: 0.3955\n", "Epoch 23/30\n", - "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5666 - loss/inference_loss: 0.5666 - val_loss: 0.7317 - val_loss/inference_loss: 0.7317\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m5s\u001B[0m 5ms/step - loss: 0.5666 - loss/inference_loss: 0.5666 - val_loss: 0.7317 - val_loss/inference_loss: 0.7317\n", "Epoch 24/30\n", - "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5632 - loss/inference_loss: 0.5632 - val_loss: 0.6094 - val_loss/inference_loss: 0.6094\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m5s\u001B[0m 5ms/step - loss: 0.5632 - loss/inference_loss: 0.5632 - val_loss: 0.6094 - val_loss/inference_loss: 0.6094\n", "Epoch 25/30\n", - "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5701 - loss/inference_loss: 0.5701 - val_loss: 0.5721 - val_loss/inference_loss: 0.5721\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m5s\u001B[0m 5ms/step - loss: 0.5701 - loss/inference_loss: 0.5701 - val_loss: 0.5721 - val_loss/inference_loss: 0.5721\n", "Epoch 26/30\n", - "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5711 - loss/inference_loss: 0.5711 - val_loss: 0.6184 - val_loss/inference_loss: 0.6184\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m5s\u001B[0m 5ms/step - loss: 0.5711 - loss/inference_loss: 0.5711 - val_loss: 0.6184 - val_loss/inference_loss: 0.6184\n", "Epoch 27/30\n", - "\u001b[1m1024/1024\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5672 - loss/inference_loss: 0.5672 - val_loss: 0.6326 - val_loss/inference_loss: 0.6326\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m5s\u001B[0m 5ms/step - loss: 0.5672 - loss/inference_loss: 0.5672 - val_loss: 0.6326 - val_loss/inference_loss: 0.6326\n", "Epoch 28/30\n", - "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5585 - loss/inference_loss: 0.5585 - val_loss: 0.6209 - val_loss/inference_loss: 0.6209\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m5s\u001B[0m 5ms/step - loss: 0.5585 - loss/inference_loss: 0.5585 - val_loss: 0.6209 - val_loss/inference_loss: 0.6209\n", "Epoch 29/30\n", - "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5594 - loss/inference_loss: 0.5594 - val_loss: 0.5672 - val_loss/inference_loss: 0.5672\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m5s\u001B[0m 5ms/step - loss: 0.5594 - loss/inference_loss: 0.5594 - val_loss: 0.5672 - val_loss/inference_loss: 0.5672\n", "Epoch 30/30\n", - "\u001b[1m1024/1024\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 5ms/step - loss: 0.5597 - loss/inference_loss: 0.5597 - val_loss: 0.4648 - val_loss/inference_loss: 0.4648\n" + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m5s\u001B[0m 5ms/step - loss: 0.5597 - loss/inference_loss: 0.5597 - val_loss: 0.4648 - val_loss/inference_loss: 0.4648\n" ] } ], From 8d38eff9572ec237fa698ec0eca2299683207bfd Mon Sep 17 00:00:00 2001 From: Jerry Date: Tue, 22 Oct 2024 14:10:54 -0400 Subject: [PATCH 10/22] WIP notebooks for Diagnostics testing --- 
bayesflow/diagnostics/__init__.py | 2 +- bayesflow/diagnostics/plot_losses.py | 37 +- bayesflow/diagnostics/plot_recovery.py | 91 +- examples/Quickstart_Diagnostics.ipynb | 695 +++++++++++++ examples/TwoMoons_Diagnostics.ipynb | 1279 ++++++++++++------------ 5 files changed, 1413 insertions(+), 691 deletions(-) create mode 100644 examples/Quickstart_Diagnostics.ipynb diff --git a/bayesflow/diagnostics/__init__.py b/bayesflow/diagnostics/__init__.py index 28adaf7ba..5e4fdbce2 100644 --- a/bayesflow/diagnostics/__init__.py +++ b/bayesflow/diagnostics/__init__.py @@ -1,2 +1,2 @@ from .plot_losses import plot_losses -from .plot_recovery import plot_recovery \ No newline at end of file +from .plot_recovery import plot_recovery diff --git a/bayesflow/diagnostics/plot_losses.py b/bayesflow/diagnostics/plot_losses.py index d778e1e8b..3797b0257 100644 --- a/bayesflow/diagnostics/plot_losses.py +++ b/bayesflow/diagnostics/plot_losses.py @@ -1,5 +1,7 @@ +import numpy as np import seaborn as sns +import matplotlib.pyplot as plt from tensorflow.keras import ops from ..utils.plot_utils import initialize_figure @@ -8,16 +10,17 @@ def plot_losses( train_losses, val_losses=None, - moving_average: bool = False, - ma_window_fraction: float = 0.01, - train_color: str = "#8f2727", - val_color: str = "black", - lw_train: int = 2, - lw_val: int = 3, - grid_alpha: float = 0.5, - legend_fontsize: int = 14, - label_fontsize: int = 14, - title_fontsize: int = 16, + moving_average=False, + ma_window_fraction=0.01, + fig_size=None, + train_color="#8f2727", + val_color="black", + lw_train=2, + lw_val=3, + grid_alpha=0.5, + legend_fontsize=14, + label_fontsize=14, + title_fontsize=16, ): """A generic helper function to plot the losses of a series of training epochs and runs. 
@@ -70,17 +73,15 @@ def plot_losses( n_row = len(train_losses.columns) # Initialize figure - f, axarr = initialize_figure(n_row=n_row, n_col=1, fig_size=(16, int(4 * n_row))) - - # if fig_size is None: - # fig_size = (16, int(4 * n_row)) - # f, axarr = plt.subplots(n_row, 1, figsize=fig_size) + if fig_size is None: + fig_size = (16, int(4 * n_row)) + f, axarr = plt.subplots(n_row, 1, figsize=fig_size) # Get the number of steps as an array - train_step_index = ops.arange(1, len(train_losses) + 1) + train_step_index = np.arange(1, len(train_losses) + 1) if val_losses is not None: - val_step = int(ops.floor(len(train_losses) / len(val_losses))) - val_step_index = train_step_index[(val_step - 1) :: val_step] + val_step = int(np.floor(len(train_losses) / len(val_losses))) + val_step_index = train_step_index[(val_step - 1)::val_step] # If unequal length due to some reason, attempt a fix if val_step_index.shape[0] > val_losses.shape[0]: diff --git a/bayesflow/diagnostics/plot_recovery.py b/bayesflow/diagnostics/plot_recovery.py index 2001c95e6..4031a7542 100644 --- a/bayesflow/diagnostics/plot_recovery.py +++ b/bayesflow/diagnostics/plot_recovery.py @@ -2,30 +2,31 @@ import numpy as np from scipy.stats import median_abs_deviation from sklearn.metrics import r2_score +import matplotlib.pyplot as plt import seaborn as sns from ..utils.plot_utils import preprocess, postprocess - +from ..utils.plot_utils import check_posterior_prior_shapes def plot_recovery( - post_samples, - prior_samples, - point_agg=np.median, - uncertainty_agg=median_abs_deviation, - param_names: list = None, - fig_size: tuple = None, - label_fontsize: int = 16, - title_fontsize: int = 18, - metric_fontsize: int = 16, - tick_fontsize: int = 12, - add_corr: bool = True, - add_r2: bool = True, - color: str | tuple = "#8f2727", - n_col: int = None, - n_row: int = None, - xlabel: str = "Ground truth", - ylabel: str = "Estimated", - **kwargs, + post_samples, + prior_samples, + point_agg=np.median, + 
uncertainty_agg=median_abs_deviation, + param_names=None, + fig_size=None, + label_fontsize=16, + title_fontsize=18, + metric_fontsize=16, + tick_fontsize=12, + add_corr=True, + add_r2=True, + color="#8f2727", + n_col=None, + n_row=None, + xlabel="Ground truth", + ylabel="Estimated", + **kwargs, ): """Creates and plots publication-ready recovery plot with true vs. point estimate + uncertainty. The point estimate can be controlled with the ``point_agg`` argument, and the uncertainty estimate @@ -93,17 +94,40 @@ def plot_recovery( If there is a deviation from the expected shapes of ``post_samples`` and ``prior_samples``. """ - # Preprocess - f, axarr, axarr_it, n_row, n_col, n_params, param_names = preprocess( - post_samples, prior_samples, fig_size=fig_size - ) + # Sanity check + check_posterior_prior_shapes(post_samples, prior_samples) # Compute point estimates and uncertainties est = point_agg(post_samples, axis=1) if uncertainty_agg is not None: u = uncertainty_agg(post_samples, axis=1) - # Loop and plot + # Determine n params and param names if None given + n_params = prior_samples.shape[-1] + if param_names is None: + param_names = [f"$\\theta_{{{i}}}$" for i in range(1, n_params + 1)] + + # Determine number of rows and columns for subplots based on inputs + if n_row is None and n_col is None: + n_row = int(np.ceil(n_params / 6)) + n_col = int(np.ceil(n_params / n_row)) + elif n_row is None and n_col is not None: + n_row = int(np.ceil(n_params / n_col)) + elif n_row is not None and n_col is None: + n_col = int(np.ceil(n_params / n_row)) + + # Initialize figure + if fig_size is None: + fig_size = (int(4 * n_col), int(4 * n_row)) + f, axarr = plt.subplots(n_row, n_col, figsize=fig_size) + + # turn axarr into 1D list + axarr = np.atleast_1d(axarr) + if n_col > 1 or n_row > 1: + axarr_it = axarr.flat + else: + axarr_it = axarr + for i, ax in enumerate(axarr_it): if i >= n_params: break @@ -159,7 +183,22 @@ def plot_recovery( ax.tick_params(axis="both", 
which="major", labelsize=tick_fontsize) ax.tick_params(axis="both", which="minor", labelsize=tick_fontsize) - postprocess(axarr, axarr_it, n_row, n_col, n_params, xlabel, ylabel, label_fontsize) + # Only add x-labels to the bottom row + bottom_row = axarr if n_row == 1 else axarr[0] if n_col == 1 else axarr[n_row - 1, :] + for _ax in bottom_row: + _ax.set_xlabel(xlabel, fontsize=label_fontsize) + + # Only add y-labels to right left-most row + if n_row == 1: # if there is only one row, the ax array is 1D + axarr[0].set_ylabel(ylabel, fontsize=label_fontsize) + # If there is more than one row, the ax array is 2D + else: + for _ax in axarr[:, 0]: + _ax.set_ylabel(ylabel, fontsize=label_fontsize) + + # Remove unused axes entirely + for _ax in axarr_it[n_params:]: + _ax.remove() f.tight_layout() - return f + return f \ No newline at end of file diff --git a/examples/Quickstart_Diagnostics.ipynb b/examples/Quickstart_Diagnostics.ipynb new file mode 100644 index 000000000..6ceddd5bc --- /dev/null +++ b/examples/Quickstart_Diagnostics.ipynb @@ -0,0 +1,695 @@ +{ + "cells": [ + { + "metadata": {}, + "cell_type": "markdown", + "source": "Quickstart Diagnostics", + "id": "ee8e90d08cdb035e" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-22T16:18:05.726356Z", + "start_time": "2024-10-22T16:18:05.710363Z" + } + }, + "cell_type": "code", + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import seaborn as sns\n", + "\n", + "# ensure the backend is set\n", + "import os\n", + "if \"KERAS_BACKEND\" not in os.environ:\n", + " # set this to \"torch\", \"tensorflow\", or \"jax\"\n", + " os.environ[\"KERAS_BACKEND\"] = \"jax\"\n", + "\n", + "import keras\n", + "\n", + "# for BayesFlow devs: this ensures that the latest dev version can be found\n", + "import sys\n", + "sys.path.append('../')\n", + "\n", + "import bayesflow as bf\n", + "from bayesflow.diagnostics.plot_losses import plot_losses\n", + "from 
bayesflow.diagnostics.plot_recovery import plot_recovery" + ], + "id": "56c348ceefe0a66f", + "outputs": [], + "execution_count": 24 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-22T16:24:23.691704Z", + "start_time": "2024-10-22T16:24:23.677134Z" + } + }, + "cell_type": "code", + "source": [ + "def theta_prior():\n", + " theta = np.random.normal(size=4)\n", + " return dict(theta=theta)\n", + "\n", + "def forward_model(theta, n_obs=100):\n", + " x = np.random.normal(loc=theta, size=(n_obs, theta.shape[0]))\n", + " return dict(x=x)" + ], + "id": "214241c510d751f4", + "outputs": [], + "execution_count": 38 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-22T16:24:24.077994Z", + "start_time": "2024-10-22T16:24:24.058994Z" + } + }, + "cell_type": "code", + "source": "simulator = bf.simulators.CompositeLambdaSimulator([theta_prior, forward_model])", + "id": "938dc70eb8ba4a54", + "outputs": [], + "execution_count": 39 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-22T16:24:24.971064Z", + "start_time": "2024-10-22T16:24:24.948063Z" + } + }, + "cell_type": "code", + "source": "sample_data = simulator.sample((50,))", + "id": "931b7f6a77c8401b", + "outputs": [], + "execution_count": 40 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-22T16:24:27.263198Z", + "start_time": "2024-10-22T16:24:27.249771Z" + } + }, + "cell_type": "code", + "source": [ + "print(\"Type of sample_data:\\n\\t\", type(sample_data))\n", + "print(\"Keys of sample_data:\\n\\t\", sample_data.keys())\n", + "print(\"Types of sample_data values:\\n\\t\", {k: type(v) for k, v in sample_data.items()})\n", + "print(\"Shapes of sample_data values:\\n\\t\", {k: v.shape for k, v in sample_data.items()})" + ], + "id": "e9b0a37820825b85", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Type of sample_data:\n", + "\t \n", + "Keys of sample_data:\n", + "\t dict_keys(['theta', 'x'])\n", + "Types of sample_data 
values:\n", + "\t {'theta': , 'x': }\n", + "Shapes of sample_data values:\n", + "\t {'theta': (50, 4), 'x': (50, 100, 4)}\n" + ] + } + ], + "execution_count": 41 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-22T16:25:08.028115Z", + "start_time": "2024-10-22T16:25:08.013014Z" + } + }, + "cell_type": "code", + "source": "sample_data", + "id": "c0caf508bb83962f", + "outputs": [ + { + "data": { + "text/plain": [ + "{'theta': array([[ 0.34066415, -0.5138845 , 1.4528089 , -0.49958685],\n", + " [ 0.86015004, 0.48635587, 0.2364767 , -0.53709507],\n", + " [-0.6582664 , -0.1106401 , -0.5822995 , 0.29959023],\n", + " [ 0.17613287, 0.40979308, -1.1803418 , 0.7906092 ],\n", + " [-2.5815134 , -0.5926008 , -1.442977 , -1.1212678 ],\n", + " [ 0.15085632, 0.8538437 , -0.71999 , -0.6779198 ],\n", + " [-0.04969181, 0.45948943, 0.6696255 , 0.9931811 ],\n", + " [ 1.015672 , 0.28774238, 0.18076487, -0.11111598],\n", + " [-0.7719207 , -1.3176122 , 0.5294132 , 0.4176514 ],\n", + " [ 0.03191099, -0.6768063 , 0.5141813 , -1.592261 ],\n", + " [-0.99575025, -1.8044442 , 0.56740063, -1.9281672 ],\n", + " [ 0.81070745, -0.60243636, -0.10667904, -0.3417887 ],\n", + " [ 1.0685011 , -1.3776896 , 1.8168131 , -0.8139481 ],\n", + " [ 0.01134184, 0.02382061, 1.6661643 , -0.46634912],\n", + " [-1.8478132 , -0.08229433, -0.04664409, -0.11284911],\n", + " [ 1.4040706 , -0.67715555, -0.2592975 , -0.20792411],\n", + " [ 2.399644 , 0.89749336, 2.4230204 , 0.0970002 ],\n", + " [ 0.93040395, 0.25475293, 0.8398071 , 0.29117548],\n", + " [-0.16029291, -0.02478953, 0.29951358, 0.33260188],\n", + " [-0.86853355, -1.1873287 , 1.9413403 , 0.32616952],\n", + " [-0.67677224, 0.02859171, 0.5428518 , -1.5521122 ],\n", + " [ 2.3776774 , -0.6828046 , 0.5556347 , -1.4531173 ],\n", + " [ 0.17357443, -0.45678964, -0.12053017, -0.8963106 ],\n", + " [ 0.20243092, 0.4169088 , 0.4405855 , 0.06946267],\n", + " [-0.4409229 , 0.07481287, -0.82419586, 0.33597344],\n", + " [-1.0189365 , 1.2648267 , -0.84935266, 
0.58711445],\n", + " [ 0.8026323 , -0.73901856, -0.3391541 , 0.5761913 ],\n", + " [-0.9766659 , -1.1858828 , -0.5103791 , 0.73724025],\n", + " [ 1.5424025 , 0.28468883, -0.05811996, -1.0388048 ],\n", + " [ 1.6943154 , -0.36717394, 0.37467596, -0.18305473],\n", + " [-1.4355195 , -1.1271007 , 0.98609 , -2.1474707 ],\n", + " [-0.24416563, -0.88949496, -0.83292514, 0.05820413],\n", + " [ 1.0845547 , 0.97108537, 0.18267912, 0.16928157],\n", + " [ 1.5283549 , -0.9298855 , -1.9587208 , 1.4929713 ],\n", + " [ 0.34513605, 0.904506 , 0.46237883, -1.4228871 ],\n", + " [-0.81769085, 0.7091762 , -0.54571545, 0.5346092 ],\n", + " [-1.221732 , -1.3575743 , -1.3833972 , 1.5352001 ],\n", + " [-1.1201226 , -0.11686669, -0.21259853, -0.01677035],\n", + " [-0.1394734 , -0.3124989 , -0.21038432, -0.2977672 ],\n", + " [ 0.41691035, 0.28065392, -0.38032046, 0.95429885],\n", + " [-1.771154 , -1.1321709 , -1.9100127 , 0.5539506 ],\n", + " [-1.447865 , 1.2216287 , 2.154635 , -0.3226352 ],\n", + " [ 0.68915546, 0.41079593, -0.05922764, -2.326437 ],\n", + " [-0.81387454, -1.0814589 , -0.6311428 , -0.16105291],\n", + " [-0.12934463, 0.26514062, -1.6791768 , -0.20046751],\n", + " [-0.19893628, -0.48227343, 0.38067642, -0.8310641 ],\n", + " [ 1.2004272 , 0.0041292 , -0.02631984, 1.3608695 ],\n", + " [ 2.2010703 , 1.2613806 , -1.1433984 , -0.1893912 ],\n", + " [ 0.38409767, -0.2333284 , 0.67292047, 1.7366157 ],\n", + " [-0.22914144, 1.6965197 , 1.2130772 , 0.39718068]],\n", + " dtype=float32),\n", + " 'x': array([[[ 3.6254582 , -1.2774805 , 1.829676 , -0.32256916],\n", + " [-1.256944 , 0.3428607 , 1.6378316 , -1.9851911 ],\n", + " [ 0.7485846 , 1.0144739 , 2.0964758 , -0.01522239],\n", + " ...,\n", + " [ 0.08928863, -0.51128244, 1.1220746 , 0.7665733 ],\n", + " [ 0.3698297 , -3.188901 , 2.025168 , 0.01316792],\n", + " [-0.0478575 , -1.2997395 , 0.98696446, -1.16682 ]],\n", + " \n", + " [[ 2.8832037 , 0.8261143 , -1.1377803 , -1.1128289 ],\n", + " [ 0.9582984 , 0.79203564, -0.3229012 , -1.6079041 
],\n", + " [ 0.16740797, 1.4642444 , -0.06614214, 0.3791665 ],\n", + " ...,\n", + " [ 1.3970535 , 0.965935 , 0.5211107 , -0.23564771],\n", + " [ 0.92998844, -0.22871257, -1.2391579 , 0.9288718 ],\n", + " [ 0.23330195, -1.122806 , 0.29205167, -0.21602471]],\n", + " \n", + " [[-2.1422992 , 1.1895313 , -1.2813323 , 0.94614196],\n", + " [-0.6536977 , -1.7648559 , -2.6089973 , -0.25280094],\n", + " [-0.9758994 , -0.09433198, 0.5126696 , -0.4680349 ],\n", + " ...,\n", + " [ 0.6029331 , -0.5393456 , -3.4281068 , 1.9625674 ],\n", + " [-2.7362714 , 1.6726367 , -0.20103195, -0.2370804 ],\n", + " [ 1.6226304 , -0.20690091, 0.6907045 , -0.4412011 ]],\n", + " \n", + " ...,\n", + " \n", + " [[ 2.5122259 , 2.7933333 , -1.8370881 , -0.21269593],\n", + " [ 0.08193951, 2.768969 , -1.8215785 , -1.3286033 ],\n", + " [ 1.3069335 , 1.077523 , -3.191183 , -0.19069603],\n", + " ...,\n", + " [ 2.0579758 , 1.8073362 , -0.04059944, -0.1998354 ],\n", + " [ 1.7150037 , 1.5609382 , -1.7395236 , -1.3606325 ],\n", + " [ 2.9268007 , 0.23673572, -0.95533824, -1.3200113 ]],\n", + " \n", + " [[ 1.4436094 , -1.3594241 , 0.415787 , 2.261267 ],\n", + " [ 0.31457698, -0.8279396 , 1.7133617 , 1.7376964 ],\n", + " [-0.72784173, -1.3070168 , 1.0091938 , 2.5164726 ],\n", + " ...,\n", + " [ 0.4194977 , -0.2666566 , 0.8669603 , 3.1416023 ],\n", + " [-0.5483485 , -0.539848 , -0.2195546 , 1.9718776 ],\n", + " [ 0.20176356, 1.069642 , -0.65165126, 2.8493927 ]],\n", + " \n", + " [[-0.6640122 , 1.3560729 , 0.5739129 , -0.85077333],\n", + " [-1.5651604 , 2.4200397 , 1.7220724 , -2.3291683 ],\n", + " [ 0.12706083, 0.7526239 , 0.46398893, 0.5266859 ],\n", + " ...,\n", + " [ 0.28359127, 1.544274 , 1.2267944 , -0.26292163],\n", + " [-0.24361266, 2.2830348 , -0.09784857, -0.17053986],\n", + " [-0.12756453, 0.9381281 , 1.8230177 , 0.8788254 ]]],\n", + " dtype=float32)}" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 43 + }, + { + "metadata": { + 
"ExecuteTime": { + "end_time": "2024-10-22T16:25:14.230395Z", + "start_time": "2024-10-22T16:25:14.219237Z" + } + }, + "cell_type": "code", + "source": [ + "data_adapter = bf.ContinuousApproximator.build_data_adapter(\n", + " inference_variables=[\"theta\"],\n", + " inference_conditions=[\"x\"]\n", + ")" + ], + "id": "b0f547fc9dfec62e", + "outputs": [], + "execution_count": 44 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-22T16:25:14.901880Z", + "start_time": "2024-10-22T16:25:14.887223Z" + } + }, + "cell_type": "code", + "source": [ + "# Define hyperparameters\n", + "num_training_batches = 1024\n", + "num_validation_batches = 256\n", + "batch_size = 128" + ], + "id": "d6a75322b3e87b16", + "outputs": [], + "execution_count": 45 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-22T17:39:40.767333Z", + "start_time": "2024-10-22T17:39:27.893610Z" + } + }, + "cell_type": "code", + "source": [ + "training_samples = simulator.sample((num_training_batches * batch_size, ))\n", + "validation_samples = simulator.sample((num_validation_batches * batch_size, ))\n", + "\n", + "training_dataset = bf.datasets.OnlineDataset(\n", + " simulator=simulator, \n", + " batch_size=batch_size, \n", + " num_batches=num_training_batches, \n", + " data_adapter=data_adapter\n", + ")\n", + "\n", + "validation_dataset = bf.datasets.OnlineDataset(\n", + " simulator=simulator,\n", + " batch_size=batch_size,\n", + " num_batches=num_validation_batches,\n", + " data_adapter=data_adapter\n", + ")" + ], + "id": "f54a245984369b8b", + "outputs": [], + "execution_count": 69 + }, + { + "metadata": {}, + "cell_type": "code", + "source": [ + "summary_network = bf.networks.DeepSet(summary_dim=10)\n", + "summary_network.build(input_shape=(training_samples['x'].shape))" + ], + "id": "6d219a2947a41c39", + "outputs": [], + "execution_count": null + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-22T17:33:25.157778Z", + "start_time": 
"2024-10-22T17:33:25.114152Z" + } + }, + "cell_type": "code", + "source": [ + "inference_network = bf.networks.FlowMatching(\n", + " subnet=\"mlp\",\n", + " subnet_kwargs=dict(\n", + " depth=6,\n", + " width=256,\n", + " ),\n", + ")\n", + "inference_network.build()" + ], + "id": "ecc20e920b0dc330", + "outputs": [], + "execution_count": 61 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-22T16:29:39.431488Z", + "start_time": "2024-10-22T16:29:39.361352Z" + } + }, + "cell_type": "code", + "source": [ + "test_sim = simulator.sample((4,))\n", + "z, log_det_J = summary_network(test_sim['x'])" + ], + "id": "2d182d111fdacf3b", + "outputs": [ + { + "ename": "ValueError", + "evalue": "too many values to unpack (expected 2)", + "output_type": "error", + "traceback": [ + "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m", + "\u001B[1;31mValueError\u001B[0m Traceback (most recent call last)", + "Cell \u001B[1;32mIn[58], line 2\u001B[0m\n\u001B[0;32m 1\u001B[0m test_sim \u001B[38;5;241m=\u001B[39m simulator\u001B[38;5;241m.\u001B[39msample((\u001B[38;5;241m4\u001B[39m,))\n\u001B[1;32m----> 2\u001B[0m z, log_det_J \u001B[38;5;241m=\u001B[39m summary_network(test_sim[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mx\u001B[39m\u001B[38;5;124m'\u001B[39m])\n", + "\u001B[1;31mValueError\u001B[0m: too many values to unpack (expected 2)" + ] + } + ], + "execution_count": 58 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-22T17:33:32.494258Z", + "start_time": "2024-10-22T17:33:32.477712Z" + } + }, + "cell_type": "code", + "source": [ + "approximator = bf.ContinuousApproximator(\n", + " inference_network=inference_network,\n", + " data_adapter=data_adapter,\n", + ")" + ], + "id": "a3b83230f640d6d9", + "outputs": [], + "execution_count": 62 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-22T17:33:33.040002Z", + "start_time": "2024-10-22T17:33:33.017491Z" + } + }, + "cell_type": "code", + 
"source": [ + "learning_rate = 1e-4\n", + "optimizer = keras.optimizers.Adam(learning_rate=learning_rate)" + ], + "id": "f0c0c672f6667945", + "outputs": [], + "execution_count": 63 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-22T17:33:33.566726Z", + "start_time": "2024-10-22T17:33:33.550196Z" + } + }, + "cell_type": "code", + "source": [ + "class BatchLossHistory(keras.callbacks.Callback):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.losses = {\n", + " \"training_loss\": [],\n", + " \"validation_loss\": [],\n", + " }\n", + "\n", + "\n", + " def on_train_batch_end(self, batch, logs=None):\n", + " # 'logs' is a dictionary containing loss and other metrics\n", + " training_loss = logs.get('loss')\n", + " self.losses[\"training_loss\"].append(training_loss)\n", + " \n", + " \n", + " def on_test_batch_end(self, batch, logs=None):\n", + " validation_loss = logs.get('loss')\n", + " self.losses[\"validation_loss\"].append(validation_loss)" + ], + "id": "359d6e9fe112d405", + "outputs": [], + "execution_count": 64 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-22T17:33:34.490293Z", + "start_time": "2024-10-22T17:33:34.464277Z" + } + }, + "cell_type": "code", + "source": "approximator.compile(optimizer=optimizer)", + "id": "7b96a6c3943dcf40", + "outputs": [], + "execution_count": 65 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-22T17:33:34.834571Z", + "start_time": "2024-10-22T17:33:34.821630Z" + } + }, + "cell_type": "code", + "source": "batch_loss_history = BatchLossHistory()", + "id": "e683fe5d365b279e", + "outputs": [], + "execution_count": 66 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-22T17:33:36.209593Z", + "start_time": "2024-10-22T17:33:35.778861Z" + } + }, + "cell_type": "code", + "source": [ + "history = approximator.fit(\n", + " epochs=10,\n", + " dataset=training_dataset,\n", + " validation_data=validation_dataset,\n", + " 
callbacks=[batch_loss_history]\n", + ")" + ], + "id": "768ee6ac6ce0ef37", + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:bayesflow:Fitting on dataset instance of OnlineDataset.\n", + "INFO:bayesflow:Building on a test batch.\n" + ] + }, + { + "ename": "TypeError", + "evalue": "Cannot concatenate arrays with different numbers of dimensions: got (128, 4), (128, 1), (128, 100, 4).", + "output_type": "error", + "traceback": [ + "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m", + "\u001B[1;31mTypeError\u001B[0m Traceback (most recent call last)", + "Cell \u001B[1;32mIn[67], line 1\u001B[0m\n\u001B[1;32m----> 1\u001B[0m history \u001B[38;5;241m=\u001B[39m \u001B[43mapproximator\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mfit\u001B[49m\u001B[43m(\u001B[49m\n\u001B[0;32m 2\u001B[0m \u001B[43m \u001B[49m\u001B[43mepochs\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m10\u001B[39;49m\u001B[43m,\u001B[49m\n\u001B[0;32m 3\u001B[0m \u001B[43m \u001B[49m\u001B[43mdataset\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mtraining_dataset\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 4\u001B[0m \u001B[43m \u001B[49m\u001B[43mvalidation_data\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mvalidation_dataset\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 5\u001B[0m \u001B[43m \u001B[49m\u001B[43mcallbacks\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43m[\u001B[49m\u001B[43mbatch_loss_history\u001B[49m\u001B[43m]\u001B[49m\n\u001B[0;32m 6\u001B[0m \u001B[43m)\u001B[49m\n", + "File \u001B[1;32m~\\Documents\\Native\\Development\\BayesFlow\\bayesflow\\approximators\\continuous_approximator.py:109\u001B[0m, in \u001B[0;36mContinuousApproximator.fit\u001B[1;34m(self, *args, **kwargs)\u001B[0m\n\u001B[0;32m 108\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mfit\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39margs, 
\u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[1;32m--> 109\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28msuper\u001B[39m()\u001B[38;5;241m.\u001B[39mfit(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs, data_adapter\u001B[38;5;241m=\u001B[39m\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mdata_adapter)\n", + "File \u001B[1;32m~\\Documents\\Native\\Development\\BayesFlow\\bayesflow\\approximators\\approximator.py:82\u001B[0m, in \u001B[0;36mApproximator.fit\u001B[1;34m(self, dataset, simulator, **kwargs)\u001B[0m\n\u001B[0;32m 80\u001B[0m mock_data \u001B[38;5;241m=\u001B[39m dataset[\u001B[38;5;241m0\u001B[39m]\n\u001B[0;32m 81\u001B[0m mock_data \u001B[38;5;241m=\u001B[39m keras\u001B[38;5;241m.\u001B[39mtree\u001B[38;5;241m.\u001B[39mmap_structure(keras\u001B[38;5;241m.\u001B[39mops\u001B[38;5;241m.\u001B[39mconvert_to_tensor, mock_data)\n\u001B[1;32m---> 82\u001B[0m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbuild_from_data\u001B[49m\u001B[43m(\u001B[49m\u001B[43mmock_data\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 84\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28msuper\u001B[39m()\u001B[38;5;241m.\u001B[39mfit(dataset\u001B[38;5;241m=\u001B[39mdataset, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n", + "File \u001B[1;32m~\\Documents\\Native\\Development\\BayesFlow\\bayesflow\\approximators\\approximator.py:23\u001B[0m, in \u001B[0;36mApproximator.build_from_data\u001B[1;34m(self, data)\u001B[0m\n\u001B[0;32m 22\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mbuild_from_data\u001B[39m(\u001B[38;5;28mself\u001B[39m, data: \u001B[38;5;28mdict\u001B[39m[\u001B[38;5;28mstr\u001B[39m, \u001B[38;5;28many\u001B[39m]) \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[1;32m---> 23\u001B[0m 
\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mcompute_metrics(\u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mdata, stage\u001B[38;5;241m=\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mtraining\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[0;32m 24\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mbuilt \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mTrue\u001B[39;00m\n", + "File \u001B[1;32m~\\Documents\\Native\\Development\\BayesFlow\\bayesflow\\approximators\\continuous_approximator.py:95\u001B[0m, in \u001B[0;36mContinuousApproximator.compute_metrics\u001B[1;34m(self, inference_variables, inference_conditions, summary_variables, stage)\u001B[0m\n\u001B[0;32m 92\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[0;32m 93\u001B[0m inference_conditions \u001B[38;5;241m=\u001B[39m keras\u001B[38;5;241m.\u001B[39mops\u001B[38;5;241m.\u001B[39mconcatenate([inference_conditions, summary_outputs], axis\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m)\n\u001B[1;32m---> 95\u001B[0m inference_metrics \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43minference_network\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mcompute_metrics\u001B[49m\u001B[43m(\u001B[49m\n\u001B[0;32m 96\u001B[0m \u001B[43m \u001B[49m\u001B[43minference_variables\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mconditions\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43minference_conditions\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mstage\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mstage\u001B[49m\n\u001B[0;32m 97\u001B[0m \u001B[43m\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 99\u001B[0m loss \u001B[38;5;241m=\u001B[39m inference_metrics\u001B[38;5;241m.\u001B[39mget(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mloss\u001B[39m\u001B[38;5;124m\"\u001B[39m, keras\u001B[38;5;241m.\u001B[39mops\u001B[38;5;241m.\u001B[39mzeros(())) 
\u001B[38;5;241m+\u001B[39m summary_metrics\u001B[38;5;241m.\u001B[39mget(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mloss\u001B[39m\u001B[38;5;124m\"\u001B[39m, keras\u001B[38;5;241m.\u001B[39mops\u001B[38;5;241m.\u001B[39mzeros(()))\n\u001B[0;32m 101\u001B[0m inference_metrics \u001B[38;5;241m=\u001B[39m {\u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mkey\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m/inference_\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mkey\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m\"\u001B[39m: value \u001B[38;5;28;01mfor\u001B[39;00m key, value \u001B[38;5;129;01min\u001B[39;00m inference_metrics\u001B[38;5;241m.\u001B[39mitems()}\n", + "File \u001B[1;32m~\\Documents\\Native\\Development\\BayesFlow\\bayesflow\\networks\\flow_matching\\flow_matching.py:122\u001B[0m, in \u001B[0;36mFlowMatching.compute_metrics\u001B[1;34m(self, x, conditions, stage)\u001B[0m\n\u001B[0;32m 118\u001B[0m target_velocity \u001B[38;5;241m=\u001B[39m x1 \u001B[38;5;241m-\u001B[39m x0\n\u001B[0;32m 120\u001B[0m base_metrics \u001B[38;5;241m=\u001B[39m \u001B[38;5;28msuper\u001B[39m()\u001B[38;5;241m.\u001B[39mcompute_metrics(x1, conditions, stage)\n\u001B[1;32m--> 122\u001B[0m predicted_velocity \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mintegrator\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mvelocity\u001B[49m\u001B[43m(\u001B[49m\u001B[43mx\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mt\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mconditions\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 124\u001B[0m loss \u001B[38;5;241m=\u001B[39m keras\u001B[38;5;241m.\u001B[39mlosses\u001B[38;5;241m.\u001B[39mmean_squared_error(target_velocity, predicted_velocity)\n\u001B[0;32m 125\u001B[0m loss \u001B[38;5;241m=\u001B[39m keras\u001B[38;5;241m.\u001B[39mops\u001B[38;5;241m.\u001B[39mmean(loss)\n", + "File 
\u001B[1;32m~\\Documents\\Native\\Development\\BayesFlow\\bayesflow\\networks\\flow_matching\\integrators\\euler.py:45\u001B[0m, in \u001B[0;36mEulerIntegrator.velocity\u001B[1;34m(self, x, t, conditions, **kwargs)\u001B[0m\n\u001B[0;32m 43\u001B[0m xtc \u001B[38;5;241m=\u001B[39m keras\u001B[38;5;241m.\u001B[39mops\u001B[38;5;241m.\u001B[39mconcatenate([x, t], axis\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m)\n\u001B[0;32m 44\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m---> 45\u001B[0m xtc \u001B[38;5;241m=\u001B[39m \u001B[43mkeras\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mops\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mconcatenate\u001B[49m\u001B[43m(\u001B[49m\u001B[43m[\u001B[49m\u001B[43mx\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mt\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mconditions\u001B[49m\u001B[43m]\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43maxis\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m-\u001B[39;49m\u001B[38;5;241;43m1\u001B[39;49m\u001B[43m)\u001B[49m\n\u001B[0;32m 47\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39moutput_projector(\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39msubnet(xtc, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs))\n", + "File \u001B[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\keras\\src\\ops\\numpy.py:1352\u001B[0m, in \u001B[0;36mconcatenate\u001B[1;34m(xs, axis)\u001B[0m\n\u001B[0;32m 1350\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m any_symbolic_tensors(xs):\n\u001B[0;32m 1351\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m Concatenate(axis\u001B[38;5;241m=\u001B[39maxis)\u001B[38;5;241m.\u001B[39msymbolic_call(xs)\n\u001B[1;32m-> 1352\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m 
\u001B[43mbackend\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mnumpy\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mconcatenate\u001B[49m\u001B[43m(\u001B[49m\u001B[43mxs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43maxis\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43maxis\u001B[49m\u001B[43m)\u001B[49m\n", + "File \u001B[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\keras\\src\\backend\\jax\\numpy.py:405\u001B[0m, in \u001B[0;36mconcatenate\u001B[1;34m(xs, axis)\u001B[0m\n\u001B[0;32m 400\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[0;32m 401\u001B[0m xs \u001B[38;5;241m=\u001B[39m [\n\u001B[0;32m 402\u001B[0m x\u001B[38;5;241m.\u001B[39mtodense() \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28misinstance\u001B[39m(x, jax_sparse\u001B[38;5;241m.\u001B[39mJAXSparse) \u001B[38;5;28;01melse\u001B[39;00m x\n\u001B[0;32m 403\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m x \u001B[38;5;129;01min\u001B[39;00m xs\n\u001B[0;32m 404\u001B[0m ]\n\u001B[1;32m--> 405\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mjnp\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mconcatenate\u001B[49m\u001B[43m(\u001B[49m\u001B[43mxs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43maxis\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43maxis\u001B[49m\u001B[43m)\u001B[49m\n", + "File \u001B[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\jax\\_src\\numpy\\lax_numpy.py:4243\u001B[0m, in \u001B[0;36mconcatenate\u001B[1;34m(arrays, axis, dtype)\u001B[0m\n\u001B[0;32m 4241\u001B[0m k \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m16\u001B[39m\n\u001B[0;32m 4242\u001B[0m \u001B[38;5;28;01mwhile\u001B[39;00m \u001B[38;5;28mlen\u001B[39m(arrays_out) \u001B[38;5;241m>\u001B[39m \u001B[38;5;241m1\u001B[39m:\n\u001B[1;32m-> 4243\u001B[0m arrays_out 
\u001B[38;5;241m=\u001B[39m [lax\u001B[38;5;241m.\u001B[39mconcatenate(arrays_out[i:i\u001B[38;5;241m+\u001B[39mk], axis)\n\u001B[0;32m 4244\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m i \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mrange\u001B[39m(\u001B[38;5;241m0\u001B[39m, \u001B[38;5;28mlen\u001B[39m(arrays_out), k)]\n\u001B[0;32m 4245\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m arrays_out[\u001B[38;5;241m0\u001B[39m]\n", + "File \u001B[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\jax\\_src\\numpy\\lax_numpy.py:4243\u001B[0m, in \u001B[0;36m\u001B[1;34m(.0)\u001B[0m\n\u001B[0;32m 4241\u001B[0m k \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m16\u001B[39m\n\u001B[0;32m 4242\u001B[0m \u001B[38;5;28;01mwhile\u001B[39;00m \u001B[38;5;28mlen\u001B[39m(arrays_out) \u001B[38;5;241m>\u001B[39m \u001B[38;5;241m1\u001B[39m:\n\u001B[1;32m-> 4243\u001B[0m arrays_out \u001B[38;5;241m=\u001B[39m [\u001B[43mlax\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mconcatenate\u001B[49m\u001B[43m(\u001B[49m\u001B[43marrays_out\u001B[49m\u001B[43m[\u001B[49m\u001B[43mi\u001B[49m\u001B[43m:\u001B[49m\u001B[43mi\u001B[49m\u001B[38;5;241;43m+\u001B[39;49m\u001B[43mk\u001B[49m\u001B[43m]\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43maxis\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 4244\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m i \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mrange\u001B[39m(\u001B[38;5;241m0\u001B[39m, \u001B[38;5;28mlen\u001B[39m(arrays_out), k)]\n\u001B[0;32m 4245\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m arrays_out[\u001B[38;5;241m0\u001B[39m]\n", + "File \u001B[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\jax\\_src\\lax\\lax.py:650\u001B[0m, in \u001B[0;36mconcatenate\u001B[1;34m(operands, dimension)\u001B[0m\n\u001B[0;32m 648\u001B[0m 
\u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28misinstance\u001B[39m(op, Array):\n\u001B[0;32m 649\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m op\n\u001B[1;32m--> 650\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mconcatenate_p\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbind\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43moperands\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdimension\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mdimension\u001B[49m\u001B[43m)\u001B[49m\n", + "File \u001B[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\jax\\_src\\core.py:438\u001B[0m, in \u001B[0;36mPrimitive.bind\u001B[1;34m(self, *args, **params)\u001B[0m\n\u001B[0;32m 435\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mbind\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mparams):\n\u001B[0;32m 436\u001B[0m \u001B[38;5;28;01massert\u001B[39;00m (\u001B[38;5;129;01mnot\u001B[39;00m config\u001B[38;5;241m.\u001B[39menable_checks\u001B[38;5;241m.\u001B[39mvalue \u001B[38;5;129;01mor\u001B[39;00m\n\u001B[0;32m 437\u001B[0m \u001B[38;5;28mall\u001B[39m(\u001B[38;5;28misinstance\u001B[39m(arg, Tracer) \u001B[38;5;129;01mor\u001B[39;00m valid_jaxtype(arg) \u001B[38;5;28;01mfor\u001B[39;00m arg \u001B[38;5;129;01min\u001B[39;00m args)), args\n\u001B[1;32m--> 438\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbind_with_trace\u001B[49m\u001B[43m(\u001B[49m\u001B[43mfind_top_trace\u001B[49m\u001B[43m(\u001B[49m\u001B[43margs\u001B[49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mparams\u001B[49m\u001B[43m)\u001B[49m\n", + "File 
\u001B[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\jax\\_src\\core.py:442\u001B[0m, in \u001B[0;36mPrimitive.bind_with_trace\u001B[1;34m(self, trace, args, params)\u001B[0m\n\u001B[0;32m 440\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mbind_with_trace\u001B[39m(\u001B[38;5;28mself\u001B[39m, trace, args, params):\n\u001B[0;32m 441\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m pop_level(trace\u001B[38;5;241m.\u001B[39mlevel):\n\u001B[1;32m--> 442\u001B[0m out \u001B[38;5;241m=\u001B[39m \u001B[43mtrace\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mprocess_primitive\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mmap\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43mtrace\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mfull_raise\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43margs\u001B[49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mparams\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 443\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mmap\u001B[39m(full_lower, out) \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mmultiple_results \u001B[38;5;28;01melse\u001B[39;00m full_lower(out)\n", + "File \u001B[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\jax\\_src\\core.py:948\u001B[0m, in \u001B[0;36mEvalTrace.process_primitive\u001B[1;34m(self, primitive, tracers, params)\u001B[0m\n\u001B[0;32m 946\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m call_impl_with_key_reuse_checks(primitive, primitive\u001B[38;5;241m.\u001B[39mimpl, \u001B[38;5;241m*\u001B[39mtracers, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mparams)\n\u001B[0;32m 947\u001B[0m 
\u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m--> 948\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m primitive\u001B[38;5;241m.\u001B[39mimpl(\u001B[38;5;241m*\u001B[39mtracers, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mparams)\n", + "File \u001B[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\jax\\_src\\dispatch.py:90\u001B[0m, in \u001B[0;36mapply_primitive\u001B[1;34m(prim, *args, **params)\u001B[0m\n\u001B[0;32m 88\u001B[0m prev \u001B[38;5;241m=\u001B[39m lib\u001B[38;5;241m.\u001B[39mjax_jit\u001B[38;5;241m.\u001B[39mswap_thread_local_state_disable_jit(\u001B[38;5;28;01mFalse\u001B[39;00m)\n\u001B[0;32m 89\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[1;32m---> 90\u001B[0m outs \u001B[38;5;241m=\u001B[39m \u001B[43mfun\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 91\u001B[0m \u001B[38;5;28;01mfinally\u001B[39;00m:\n\u001B[0;32m 92\u001B[0m lib\u001B[38;5;241m.\u001B[39mjax_jit\u001B[38;5;241m.\u001B[39mswap_thread_local_state_disable_jit(prev)\n", + " \u001B[1;31m[... 
skipping hidden 18 frame]\u001B[0m\n", + "File \u001B[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\jax\\_src\\lax\\lax.py:3904\u001B[0m, in \u001B[0;36m_concatenate_shape_rule\u001B[1;34m(*operands, **kwargs)\u001B[0m\n\u001B[0;32m 3902\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mlen\u001B[39m({operand\u001B[38;5;241m.\u001B[39mndim \u001B[38;5;28;01mfor\u001B[39;00m operand \u001B[38;5;129;01min\u001B[39;00m operands}) \u001B[38;5;241m!=\u001B[39m \u001B[38;5;241m1\u001B[39m:\n\u001B[0;32m 3903\u001B[0m msg \u001B[38;5;241m=\u001B[39m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mCannot concatenate arrays with different numbers of dimensions: got \u001B[39m\u001B[38;5;132;01m{}\u001B[39;00m\u001B[38;5;124m.\u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[1;32m-> 3904\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mTypeError\u001B[39;00m(msg\u001B[38;5;241m.\u001B[39mformat(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m, \u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;241m.\u001B[39mjoin(\u001B[38;5;28mstr\u001B[39m(o\u001B[38;5;241m.\u001B[39mshape) \u001B[38;5;28;01mfor\u001B[39;00m o \u001B[38;5;129;01min\u001B[39;00m operands)))\n\u001B[0;32m 3905\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;241m0\u001B[39m \u001B[38;5;241m<\u001B[39m\u001B[38;5;241m=\u001B[39m dimension \u001B[38;5;241m<\u001B[39m operands[\u001B[38;5;241m0\u001B[39m]\u001B[38;5;241m.\u001B[39mndim:\n\u001B[0;32m 3906\u001B[0m msg \u001B[38;5;241m=\u001B[39m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mconcatenate dimension out of bounds: dimension \u001B[39m\u001B[38;5;132;01m{}\u001B[39;00m\u001B[38;5;124m for shapes \u001B[39m\u001B[38;5;132;01m{}\u001B[39;00m\u001B[38;5;124m.\u001B[39m\u001B[38;5;124m\"\u001B[39m\n", + "\u001B[1;31mTypeError\u001B[0m: Cannot concatenate arrays with different numbers of dimensions: got (128, 
4), (128, 1), (128, 100, 4)." + ] + } + ], + "execution_count": 67 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-20T18:45:43.804570Z", + "start_time": "2024-10-20T18:45:42.197334Z" + } + }, + "cell_type": "code", + "source": [ + "import pandas as pd\n", + "\n", + "f = plot_losses(\n", + " train_losses=pd.DataFrame(batch_loss_history.losses[\"training_loss\"]), \n", + " val_losses=pd.DataFrame(batch_loss_history.losses[\"validation_loss\"])\n", + ")" + ], + "id": "4aa8f4aa440e9925", + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ], + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 36 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-20T19:00:34.909974Z", + "start_time": "2024-10-20T19:00:32.167528Z" + } + }, + "cell_type": "code", + "source": [ + "n_samples = 5000\n", + "\n", + "conditions = {\n", + " \"x\": np.array([[0.0, 0.0, 0.0, 0.0]]).astype(np.float32),\n", + "}\n", + "\n", + "samples = approximator.sample(conditions=conditions, num_samples=n_samples)\n", + "\n", + "theta_samples = samples[\"theta\"]" + ], + "id": "2c9328186f195ecd", + "outputs": [], + "execution_count": 37 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-20T19:00:35.847766Z", + "start_time": "2024-10-20T19:00:35.835217Z" + } + }, + "cell_type": "code", + "source": "theta_samples", + "id": "ad55f949b09a8160", + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[-0.8272288 , -0.49605218, 0.70173615, 0.5919672 ],\n", + " [-0.3977319 , 0.29412413, 0.4869203 , -0.6501034 ],\n", + " [ 0.36272183, -0.808805 , 0.0503429 , -0.8143094 ],\n", + " ...,\n", + " [ 0.44116256, 0.22608528, 0.08759339, 1.1507258 ],\n", + " [-0.14893901, -0.90049297, 0.67143995, -0.46183816],\n", + " [ 0.05667248, 0.3298449 , 0.0929981 , -0.47135326]]],\n", + " dtype=float32)" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 38 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-20T19:00:37.143832Z", + "start_time": "2024-10-20T19:00:36.896189Z" + } + }, + "cell_type": "code", + "source": "prior_samples = simulator.sample(conditions=conditions, batch_shape=(n_samples,))", + "id": "d05dc915057060f9", + "outputs": [], + "execution_count": 39 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-20T19:00:37.768597Z", + "start_time": "2024-10-20T19:00:37.759081Z" + } + }, + "cell_type": "code", + "source": [ + "prior_theta_sample = np.zeros((1, n_samples, 
4))\n", + "prior_theta_sample[0, :] = prior_samples['theta']" + ], + "id": "67dfb4675712b77b", + "outputs": [], + "execution_count": 40 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-20T19:00:40.077651Z", + "start_time": "2024-10-20T19:00:38.274146Z" + } + }, + "cell_type": "code", + "source": [ + "f = plot_recovery(\n", + " post_samples=np.swapaxes(theta_samples, 0, 1), \n", + " prior_samples=prior_samples['theta']\n", + ")" + ], + "id": "aa003fa535917e25", + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ], + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 41 + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": "", + "id": "fee7afdd60d78883" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/TwoMoons_Diagnostics.ipynb b/examples/TwoMoons_Diagnostics.ipynb index 43abbdaa4..c3aeb3e04 100644 --- a/examples/TwoMoons_Diagnostics.ipynb +++ b/examples/TwoMoons_Diagnostics.ipynb @@ -1,13 +1,19 @@ { "cells": [ + { + "metadata": {}, + "cell_type": "markdown", + "source": "TwoMoons Diagnostics", + "id": "33598ea529988920" + }, { "cell_type": "code", "id": "initial_id", "metadata": { "collapsed": true, "ExecuteTime": { - "end_time": "2024-10-14T21:00:47.608845Z", - "start_time": "2024-10-14T21:00:47.596803Z" + "end_time": "2024-10-20T14:56:48.800425Z", + "start_time": "2024-10-20T14:56:03.774871Z" } }, "source": [ @@ -32,13 +38,13 @@ "from bayesflow.diagnostics.plot_recovery import plot_recovery" ], "outputs": [], - "execution_count": 29 + "execution_count": 1 }, { "metadata": { "ExecuteTime": { - "end_time": "2024-10-14T20:27:46.593978Z", - "start_time": "2024-10-14T20:27:46.585246Z" + "end_time": "2024-10-18T20:35:41.543711Z", + "start_time": "2024-10-18T20:35:41.533024Z" } }, "cell_type": "code", @@ -62,26 +68,26 @@ ], "id": "2aa9c9710a36b980", "outputs": [], - "execution_count": 2 + "execution_count": 3 }, { "metadata": { "ExecuteTime": { - "end_time": "2024-10-14T20:27:47.665818Z", - "start_time": "2024-10-14T20:27:47.650596Z" + "end_time": "2024-10-18T20:36:14.245177Z", + 
"start_time": "2024-10-18T20:36:13.012180Z" } }, "cell_type": "code", "source": "simulator = bf.simulators.CompositeLambdaSimulator([alpha_prior, r_prior, theta_prior, forward_model])", "id": "7db949c6bfecc86d", "outputs": [], - "execution_count": 3 + "execution_count": 6 }, { "metadata": { "ExecuteTime": { - "end_time": "2024-10-14T20:27:48.154339Z", - "start_time": "2024-10-14T20:27:48.131066Z" + "end_time": "2024-10-18T20:36:16.748867Z", + "start_time": "2024-10-18T20:36:15.701865Z" } }, "cell_type": "code", @@ -91,13 +97,13 @@ ], "id": "b3a0fe5293beec1b", "outputs": [], - "execution_count": 4 + "execution_count": 7 }, { "metadata": { "ExecuteTime": { - "end_time": "2024-10-14T20:27:48.809696Z", - "start_time": "2024-10-14T20:27:48.792150Z" + "end_time": "2024-10-18T20:36:18.513991Z", + "start_time": "2024-10-18T20:36:17.139756Z" } }, "cell_type": "code", @@ -124,13 +130,13 @@ ] } ], - "execution_count": 5 + "execution_count": 8 }, { "metadata": { "ExecuteTime": { - "end_time": "2024-10-14T20:27:50.913110Z", - "start_time": "2024-10-14T20:27:50.887083Z" + "end_time": "2024-10-18T20:36:19.705829Z", + "start_time": "2024-10-18T20:36:19.677479Z" } }, "cell_type": "code", @@ -140,532 +146,532 @@ { "data": { "text/plain": [ - "{'alpha': array([[-0.32667086],\n", - " [ 0.4273289 ],\n", - " [-1.4049091 ],\n", - " [ 0.6495664 ],\n", - " [-1.2431567 ],\n", - " [ 0.00233715],\n", - " [ 0.8480671 ],\n", - " [ 0.3825317 ],\n", - " [ 0.51094717],\n", - " [-1.089614 ],\n", - " [-1.3354424 ],\n", - " [ 1.521167 ],\n", - " [-0.39602077],\n", - " [ 0.88564837],\n", - " [ 0.9123921 ],\n", - " [-0.8354839 ],\n", - " [ 0.4103786 ],\n", - " [-1.3677335 ],\n", - " [ 1.5597553 ],\n", - " [ 0.857706 ],\n", - " [ 1.3851596 ],\n", - " [ 0.08588123],\n", - " [-1.3278203 ],\n", - " [-1.4882586 ],\n", - " [-0.7284002 ],\n", - " [ 0.13790596],\n", - " [ 1.0051142 ],\n", - " [-1.2707354 ],\n", - " [-1.4755409 ],\n", - " [-0.03580285],\n", - " [-1.5595189 ],\n", - " [-0.67198914],\n", - " 
[-1.5414019 ],\n", - " [ 0.46680772],\n", - " [-0.49131817],\n", - " [ 1.3081697 ],\n", - " [ 1.3158842 ],\n", - " [ 0.5498999 ],\n", - " [ 0.28842863],\n", - " [ 1.5669599 ],\n", - " [ 1.1178267 ],\n", - " [-1.3773671 ],\n", - " [ 0.21009572],\n", - " [ 1.247244 ],\n", - " [-0.70116574],\n", - " [-1.3302704 ],\n", - " [-1.2457173 ],\n", - " [ 1.5158378 ],\n", - " [ 0.4760814 ],\n", - " [ 0.5130099 ],\n", - " [ 0.38356388],\n", - " [ 0.51667935],\n", - " [-1.3263792 ],\n", - " [ 1.1610556 ],\n", - " [-1.0892584 ],\n", - " [ 1.5635314 ],\n", - " [ 0.6141801 ],\n", - " [ 0.7341371 ],\n", - " [ 0.23875016],\n", - " [ 0.6227148 ],\n", - " [-0.4782562 ],\n", - " [ 0.3656121 ],\n", - " [ 0.6328574 ],\n", - " [-1.3631004 ],\n", - " [ 1.4811654 ],\n", - " [-1.373762 ],\n", - " [-0.24200128],\n", - " [ 0.00168619],\n", - " [ 1.0400674 ],\n", - " [ 0.18901351],\n", - " [ 0.9238482 ],\n", - " [ 0.02520916],\n", - " [ 1.4398099 ],\n", - " [-0.51892877],\n", - " [-1.1703858 ],\n", - " [-0.12021437],\n", - " [ 0.80984557],\n", - " [-0.9726958 ],\n", - " [-0.2244538 ],\n", - " [ 0.30686495],\n", - " [ 0.59431726],\n", - " [-1.322811 ],\n", - " [ 0.8136638 ],\n", - " [-1.5020558 ],\n", - " [-1.4799207 ],\n", - " [-1.3136892 ],\n", - " [ 0.06446969],\n", - " [ 1.3328581 ],\n", - " [ 0.66848814],\n", - " [ 0.7860198 ],\n", - " [ 1.315634 ],\n", - " [-0.23607427],\n", - " [-0.8002341 ],\n", - " [ 1.5251296 ],\n", - " [ 0.15763855],\n", - " [-1.5531527 ],\n", - " [ 0.56706136],\n", - " [ 1.047334 ],\n", - " [ 0.89252347],\n", - " [ 1.2277393 ],\n", - " [ 0.8999341 ],\n", - " [ 1.0635433 ],\n", - " [ 0.04854681],\n", - " [ 0.84339076],\n", - " [ 0.42572305],\n", - " [-0.13823606],\n", - " [ 0.36718416],\n", - " [-1.1577339 ],\n", - " [-0.5522179 ],\n", - " [ 0.7911456 ],\n", - " [ 0.8179233 ],\n", - " [-0.62356246],\n", - " [-0.33656436],\n", - " [ 0.17404567],\n", - " [ 0.4389914 ],\n", - " [-0.9474675 ],\n", - " [-1.1168886 ],\n", - " [-0.09231075],\n", - " [ 1.0462689 ],\n", - " 
[-0.90480804],\n", - " [-1.4208354 ],\n", - " [-0.16266003],\n", - " [ 0.58943385],\n", - " [ 0.9045791 ],\n", - " [ 0.42233914],\n", - " [-0.9887428 ],\n", - " [ 1.3377244 ],\n", - " [ 1.3765699 ]], dtype=float32),\n", - " 'r': array([[0.08992655],\n", - " [0.11262761],\n", - " [0.09988438],\n", - " [0.10757513],\n", - " [0.09842595],\n", - " [0.08780682],\n", - " [0.09458143],\n", - " [0.08305464],\n", - " [0.10684904],\n", - " [0.09243193],\n", - " [0.09636895],\n", - " [0.10226896],\n", - " [0.10459247],\n", - " [0.10921837],\n", - " [0.11026573],\n", - " [0.08566567],\n", - " [0.08520468],\n", - " [0.09628233],\n", - " [0.09432837],\n", - " [0.09125222],\n", - " [0.09802835],\n", - " [0.08590778],\n", - " [0.10220997],\n", - " [0.09208615],\n", - " [0.09563595],\n", - " [0.11774731],\n", - " [0.08680242],\n", - " [0.11012331],\n", - " [0.09633071],\n", - " [0.09756536],\n", - " [0.12378401],\n", - " [0.11334178],\n", - " [0.09062148],\n", - " [0.10854411],\n", - " [0.09699341],\n", - " [0.09652205],\n", - " [0.10683485],\n", - " [0.10969307],\n", - " [0.1108022 ],\n", - " [0.10317604],\n", - " [0.07496227],\n", - " [0.09012368],\n", - " [0.09105562],\n", - " [0.08200298],\n", - " [0.0828375 ],\n", - " [0.10398124],\n", - " [0.1007928 ],\n", - " [0.10111594],\n", - " [0.10167468],\n", - " [0.08223265],\n", - " [0.09048541],\n", - " [0.09253196],\n", - " [0.11334959],\n", - " [0.10842754],\n", - " [0.09764497],\n", - " [0.08462145],\n", - " [0.11413962],\n", - " [0.11527291],\n", - " [0.10542885],\n", - " [0.09038962],\n", - " [0.10374972],\n", - " [0.10187822],\n", - " [0.10547098],\n", - " [0.0985254 ],\n", - " [0.11656975],\n", - " [0.10378908],\n", - " [0.09430881],\n", - " [0.10135388],\n", - " [0.09672231],\n", - " [0.10255771],\n", - " [0.09387974],\n", - " [0.09308615],\n", - " [0.09995835],\n", - " [0.10125452],\n", - " [0.08677949],\n", - " [0.10938775],\n", - " [0.08700917],\n", - " [0.10388696],\n", - " [0.10093628],\n", - " [0.08200264],\n", - " 
[0.10838373],\n", - " [0.11670296],\n", - " [0.0975048 ],\n", - " [0.10851161],\n", - " [0.11573117],\n", - " [0.08443198],\n", - " [0.11458082],\n", - " [0.09952442],\n", - " [0.09616404],\n", - " [0.10941261],\n", - " [0.10953938],\n", - " [0.10442203],\n", - " [0.10339843],\n", - " [0.11485437],\n", - " [0.10533367],\n", - " [0.09481129],\n", - " [0.09040346],\n", - " [0.10173973],\n", - " [0.10177024],\n", - " [0.11780889],\n", - " [0.09570873],\n", - " [0.11882972],\n", - " [0.08545157],\n", - " [0.09944647],\n", - " [0.08443879],\n", - " [0.08220201],\n", - " [0.1000874 ],\n", - " [0.09128848],\n", - " [0.08286219],\n", - " [0.09595444],\n", - " [0.0979057 ],\n", - " [0.10283633],\n", - " [0.09927638],\n", - " [0.08226943],\n", - " [0.08991795],\n", - " [0.08759429],\n", - " [0.09739396],\n", - " [0.09011568],\n", - " [0.1022917 ],\n", - " [0.10698826],\n", - " [0.09153012],\n", - " [0.11289136],\n", - " [0.09984156],\n", - " [0.09630046],\n", - " [0.09598115],\n", - " [0.09263593],\n", - " [0.11081006],\n", - " [0.10720474]], dtype=float32),\n", - " 'theta': array([[-3.05581540e-02, 1.45420015e-01],\n", - " [ 2.11661726e-01, -8.16319406e-01],\n", - " [-2.41767168e-01, -8.90369564e-02],\n", - " [-9.58784044e-01, -6.43721148e-02],\n", - " [-8.37376893e-01, -3.60958546e-01],\n", - " [-5.58282018e-01, 9.12082434e-01],\n", - " [ 4.27646607e-01, 8.86675537e-01],\n", - " [ 1.40667871e-01, -2.22100895e-02],\n", - " [ 1.42964154e-01, -7.41437301e-02],\n", - " [-2.23557115e-01, 6.94955945e-01],\n", - " [ 1.82445496e-01, -8.65079284e-01],\n", - " [ 4.57260549e-01, -5.30316174e-01],\n", - " [-6.21315300e-01, 2.07263768e-01],\n", - " [ 6.10240161e-01, -8.74304950e-01],\n", - " [-2.59324551e-01, 4.74186003e-01],\n", - " [-3.78903113e-02, 9.00502682e-01],\n", - " [ 6.82602823e-01, -7.56820023e-01],\n", - " [ 5.94463050e-01, -1.21822745e-01],\n", - " [ 8.27706277e-01, 3.59144360e-01],\n", - " [-7.00999856e-01, 4.03989136e-01],\n", - " [-4.31024581e-01, -7.00606406e-01],\n", 
- " [ 9.86329079e-01, -8.04728150e-01],\n", - " [-5.38084447e-01, 3.05770040e-01],\n", - " [ 6.28210962e-01, 3.67884368e-01],\n", - " [ 1.20291792e-01, 5.00622094e-01],\n", - " [-7.59945214e-01, 4.06511813e-01],\n", - " [ 9.54336047e-01, -4.55123216e-01],\n", - " [-4.13251549e-01, 5.47428668e-01],\n", - " [-9.78350341e-01, 4.92099434e-01],\n", - " [ 6.58699691e-01, -5.58192194e-01],\n", - " [ 3.67015302e-01, 3.97362381e-01],\n", - " [-8.03872824e-01, -5.61456561e-01],\n", - " [-7.27509439e-01, -2.47731626e-01],\n", - " [-2.85620838e-01, -4.79526430e-01],\n", - " [-1.94825262e-01, -2.55926311e-01],\n", - " [ 1.67872280e-01, -9.53495979e-01],\n", - " [-1.38880566e-01, -6.23547696e-02],\n", - " [ 8.53347182e-01, -2.39163131e-01],\n", - " [ 6.25432789e-01, -7.36523867e-01],\n", - " [ 7.81854391e-01, 2.92491883e-01],\n", - " [-8.66746664e-01, -8.48121166e-01],\n", - " [-1.03085585e-01, -3.94162923e-01],\n", - " [-7.86425531e-01, 8.93066466e-01],\n", - " [ 4.55110759e-01, -3.10456127e-01],\n", - " [ 8.81353259e-01, 5.10756969e-01],\n", - " [-7.21052825e-01, -2.99451917e-01],\n", - " [-7.17009425e-01, 7.83575058e-01],\n", - " [ 3.48350137e-01, 7.37254381e-01],\n", - " [-7.80723929e-01, -8.84120941e-01],\n", - " [-1.95760652e-01, -2.31012523e-01],\n", - " [ 6.60519898e-01, -6.55787408e-01],\n", - " [-7.19074786e-01, -2.03196511e-01],\n", - " [ 6.15297891e-02, -5.67862332e-01],\n", - " [-6.49809659e-01, 5.80364406e-01],\n", - " [ 2.61088669e-01, 5.19194543e-01],\n", - " [ 3.60945612e-01, 3.35697114e-01],\n", - " [ 3.50850195e-01, -7.03177214e-01],\n", - " [ 7.74477005e-01, -2.59503335e-01],\n", - " [-7.42385924e-01, -8.22378099e-01],\n", - " [-2.85933644e-01, 6.74357533e-01],\n", - " [-8.53843629e-01, 4.06876981e-01],\n", - " [ 9.58840549e-01, 9.16405022e-01],\n", - " [ 7.74163187e-01, 7.48323083e-01],\n", - " [ 8.63941729e-01, -3.28039467e-01],\n", - " [-2.40420654e-01, 2.52833039e-01],\n", - " [-5.98286033e-01, 9.57714319e-01],\n", - " [-7.42678821e-01, 
-3.43333989e-01],\n", - " [ 9.68854368e-01, 3.85829717e-01],\n", - " [-2.75604427e-01, -7.95602053e-02],\n", - " [-2.03438953e-01, 9.11069755e-03],\n", - " [-1.43148184e-01, -1.46573469e-01],\n", - " [-9.49487031e-01, 3.15213263e-01],\n", - " [ 6.84628665e-01, 9.61936653e-01],\n", - " [-2.04393521e-01, 5.88398874e-01],\n", - " [-9.50749099e-01, 7.83227861e-01],\n", - " [ 2.88510352e-01, 9.24940228e-01],\n", - " [ 8.27838302e-01, -9.24422801e-01],\n", - " [ 6.05663717e-01, -7.59422839e-01],\n", - " [-9.75102127e-01, 1.72841713e-01],\n", - " [-2.73299128e-01, -2.70784408e-01],\n", - " [-4.23562735e-01, 1.31962135e-01],\n", - " [ 9.45790648e-01, -3.38832617e-01],\n", - " [ 1.88001692e-01, 2.13898122e-01],\n", - " [ 5.90824075e-02, -2.38077283e-01],\n", - " [-9.78379369e-01, -6.26421869e-01],\n", - " [-6.73744559e-01, 4.74431008e-01],\n", - " [ 5.22617698e-01, 8.93913805e-01],\n", - " [ 5.95449269e-01, 9.14583445e-01],\n", - " [ 3.90204303e-02, -2.39531472e-01],\n", - " [ 5.71989954e-01, -4.74963844e-01],\n", - " [ 6.58734083e-01, 5.09142876e-01],\n", - " [-2.53153235e-01, -3.53049845e-01],\n", - " [-4.54396307e-01, -6.79341435e-01],\n", - " [ 9.29932415e-01, 8.80713701e-01],\n", - " [ 2.17661366e-01, 6.54169917e-01],\n", - " [ 2.10923515e-02, -2.40477219e-01],\n", - " [-2.51346976e-01, -2.13967457e-01],\n", - " [ 5.17891407e-01, 2.40374476e-01],\n", - " [ 3.59319031e-01, -3.92084904e-02],\n", - " [-6.95064545e-01, 6.54330254e-01],\n", - " [ 4.39562589e-01, -7.08017647e-01],\n", - " [-2.80083835e-01, -2.79529452e-01],\n", - " [ 4.68703598e-01, -3.61453325e-01],\n", - " [ 1.75413545e-04, -6.19711876e-01],\n", - " [-4.56947744e-01, -5.46697043e-02],\n", - " [ 1.10423014e-01, 3.81866604e-01],\n", - " [ 4.95571673e-01, 6.30076528e-01],\n", - " [ 8.15737665e-01, 1.29877731e-01],\n", - " [-9.83589232e-01, -7.80846715e-01],\n", - " [-1.60895333e-01, -4.13245976e-01],\n", - " [-2.93852985e-01, 7.96879292e-01],\n", - " [-2.34337926e-01, 8.69962096e-01],\n", - " [ 
8.72636318e-01, -3.94712389e-02],\n", - " [ 4.87689257e-01, -5.77459276e-01],\n", - " [-1.07371598e-01, -4.61379528e-01],\n", - " [-5.80118716e-01, 2.98643053e-01],\n", - " [ 3.65539849e-01, -8.42200577e-01],\n", - " [-5.28841615e-02, 4.88022923e-01],\n", - " [-9.50672925e-01, 2.79116750e-01],\n", - " [-6.66263402e-01, -8.50575149e-01],\n", - " [-2.95702636e-01, -7.57089794e-01],\n", - " [-4.11702067e-01, 9.68640268e-01],\n", - " [ 6.91891074e-01, 3.68866891e-01],\n", - " [-8.32779333e-03, 9.29424405e-01],\n", - " [ 5.37282348e-01, -9.28587794e-01],\n", - " [-5.35197616e-01, 4.34972018e-01],\n", - " [-8.97157609e-01, -2.00011898e-02],\n", - " [-2.74420083e-01, -9.79378104e-01]], dtype=float32),\n", - " 'x': array([[ 2.53951252e-01, 9.55786705e-02],\n", - " [-7.50578418e-02, -6.80214882e-01],\n", - " [ 3.25798206e-02, 9.48337466e-03],\n", - " [-3.87813628e-01, 6.97510600e-01],\n", - " [-5.65676749e-01, 2.43688509e-01],\n", - " [ 8.76319110e-02, 1.03990984e+00],\n", - " [-6.16806686e-01, 3.95518869e-01],\n", - " [ 2.43289366e-01, -8.41702744e-02],\n", - " [ 2.94539064e-01, -1.01268895e-01],\n", - " [-4.05492671e-02, 5.67550659e-01],\n", - " [-2.10222960e-01, -8.34424138e-01],\n", - " [ 2.03415319e-01, -5.96179128e-01],\n", - " [ 5.37187196e-02, 5.45547307e-01],\n", - " [ 1.32390037e-01, -9.65161324e-01],\n", - " [ 1.65536702e-01, 6.05887115e-01],\n", - " [-3.02492917e-01, 6.00012600e-01],\n", - " [ 2.75650620e-01, -9.83832717e-01],\n", - " [-6.47898912e-02, -6.00794613e-01],\n", - " [-5.88188708e-01, -2.37000689e-01],\n", - " [ 9.96765494e-02, 8.50363314e-01],\n", - " [-5.32090664e-01, -9.42790136e-02],\n", - " [ 2.07179919e-01, -1.25909996e+00],\n", - " [ 1.10319838e-01, 4.97487545e-01],\n", - " [-4.46753800e-01, -2.75851369e-01],\n", - " [-1.17684998e-01, 2.05271512e-01],\n", - " [ 1.16714276e-01, 8.40996325e-01],\n", - " [-5.64713925e-02, -9.23357546e-01],\n", - " [ 1.87672526e-01, 5.74100673e-01],\n", - " [-8.46691579e-02, 9.43871021e-01],\n", - " [ 
2.76433289e-01, -8.63964856e-01],\n", - " [-2.89100736e-01, -1.02317505e-01],\n", - " [-6.26734078e-01, 1.00853950e-01],\n", - " [-4.36936170e-01, 2.48671815e-01],\n", - " [-1.94109902e-01, -8.82630050e-02],\n", - " [ 1.67907290e-02, -8.89653489e-02],\n", - " [-2.80460954e-01, -6.99714661e-01],\n", - " [ 1.34644642e-01, 1.57494441e-01],\n", - " [-9.07719210e-02, -7.15195656e-01],\n", - " [ 2.77671933e-01, -9.31531489e-01],\n", - " [-5.09281754e-01, -2.42856264e-01],\n", - " [-9.29788351e-01, 8.05726424e-02],\n", - " [-8.42837393e-02, -2.94265717e-01],\n", - " [ 2.63646871e-01, 1.20657015e+00],\n", - " [ 1.73785478e-01, -4.63589519e-01],\n", - " [-6.71075225e-01, -3.15490365e-01],\n", - " [-4.46836084e-01, 1.97128952e-01],\n", - " [ 2.35122561e-01, 9.65559661e-01],\n", - " [-5.12083948e-01, 3.75960112e-01],\n", - " [-8.36854875e-01, -2.65152324e-02],\n", - " [ 1.98727194e-02, 1.54331038e-02],\n", - " [ 3.30564082e-01, -8.96907687e-01],\n", - " [-3.21691066e-01, 4.10491407e-01],\n", - " [-8.06015953e-02, -5.55028141e-01],\n", - " [ 2.44089246e-01, 9.69316781e-01],\n", - " [-2.56520003e-01, 9.59672704e-02],\n", - " [-2.41986051e-01, 6.67658299e-02],\n", - " [ 9.41473544e-02, -6.79532588e-01],\n", - " [-2.85616964e-02, -6.53907835e-01],\n", - " [-7.54016936e-01, -3.16303074e-02],\n", - " [ 4.87661436e-02, 7.31747448e-01],\n", - " [ 2.60557346e-02, 8.43715191e-01],\n", - " [-9.80854273e-01, 6.41715247e-03],\n", - " [-7.41514742e-01, 4.41092215e-02],\n", - " [-1.08623601e-01, -9.39265966e-01],\n", - " [ 2.51657397e-01, 4.64884877e-01],\n", - " [ 1.61637682e-02, 9.98477459e-01],\n", - " [-4.26366359e-01, 2.59778708e-01],\n", - " [-6.06552601e-01, -4.12089765e-01],\n", - " [ 4.78178374e-02, 2.22041234e-01],\n", - " [ 2.13320345e-01, 1.69564873e-01],\n", - " [ 1.01722233e-01, 7.24871457e-02],\n", - " [-1.05442718e-01, 8.96624506e-01],\n", - " [-9.01241720e-01, 2.95188427e-01],\n", - " [ 6.63916618e-02, 5.10371685e-01],\n", - " [ 1.65370926e-01, 1.14619160e+00],\n", - " 
[-4.99440819e-01, 4.36905563e-01],\n", - " [ 2.41706863e-01, -1.17602539e+00],\n", - " [ 1.99771896e-01, -1.05111480e+00],\n", - " [-2.18879402e-01, 7.89253116e-01],\n", - " [-5.65532483e-02, 2.65488382e-02],\n", - " [ 1.33606523e-01, 4.53504145e-01],\n", - " [-1.50539234e-01, -1.02149868e+00],\n", - " [ 3.27841304e-02, 8.91788527e-02],\n", - " [ 1.30884781e-01, -3.18378985e-01],\n", - " [-8.74263108e-01, 1.33617908e-01],\n", - " [ 1.30533725e-01, 7.30226099e-01],\n", - " [-6.37296259e-01, 2.69927859e-01],\n", - " [-7.94296503e-01, 3.22382361e-01],\n", - " [ 1.83682933e-01, -1.37363449e-01],\n", - " [ 2.58710474e-01, -6.62893653e-01],\n", - " [-5.48165679e-01, 2.15799548e-04],\n", - " [-7.71245658e-02, -9.50605869e-02],\n", - " [-4.79652673e-01, -2.33250588e-01],\n", - " [-1.02507687e+00, 7.99317434e-02],\n", - " [-2.62450218e-01, 3.25194120e-01],\n", - " [ 9.65442061e-02, -2.79754162e-01],\n", - " [-2.77321483e-03, 7.49920383e-02],\n", - " [-2.35317081e-01, -1.08117975e-01],\n", - " [ 8.75033140e-02, -2.02557355e-01],\n", - " [ 2.60823607e-01, 1.06511045e+00],\n", - " [ 1.19672045e-01, -7.36494482e-01],\n", - " [-8.79814923e-02, 1.04258955e-01],\n", - " [ 2.59513497e-01, -5.82862794e-01],\n", - " [-1.21953085e-01, -3.64049733e-01],\n", - " [-3.48663330e-02, 3.19325000e-01],\n", - " [-1.66834649e-02, 1.80612490e-01],\n", - " [-4.52537745e-01, 1.31039545e-01],\n", - " [-3.82006407e-01, -5.68586946e-01],\n", - " [-9.27098870e-01, 9.98930261e-02],\n", - " [-8.85202736e-02, -1.10199966e-01],\n", - " [-3.87514569e-02, 8.42708707e-01],\n", - " [-1.15971282e-01, 7.20808744e-01],\n", - " [-2.45430186e-01, -6.77743077e-01],\n", - " [ 2.67549545e-01, -7.38927364e-01],\n", - " [-7.07757547e-02, -2.12103873e-01],\n", - " [ 1.02099039e-01, 5.50257146e-01],\n", - " [-4.43446413e-02, -9.41533327e-01],\n", - " [ 3.20424363e-02, 3.74172240e-01],\n", - " [-1.73633844e-01, 9.58132327e-01],\n", - " [-7.56465554e-01, -2.14453653e-01],\n", - " [-4.80762124e-01, -4.16752875e-01],\n", - " 
[-3.24135572e-02, 9.57767427e-01],\n", - " [-4.17075306e-01, -1.72911614e-01],\n", - " [-3.41798395e-01, 7.38798976e-01],\n", - " [ 6.08528070e-02, -9.97184515e-01],\n", - " [ 2.30055511e-01, 6.08631432e-01],\n", - " [-3.72935683e-01, 7.28057206e-01],\n", - " [-6.15877926e-01, -3.93291622e-01]], dtype=float32)}" + "{'alpha': array([[ 0.811846 ],\n", + " [ 1.1655375 ],\n", + " [ 0.690349 ],\n", + " [-1.3843633 ],\n", + " [-0.9577933 ],\n", + " [-1.2079015 ],\n", + " [ 1.5107541 ],\n", + " [ 0.96131116],\n", + " [ 0.09636723],\n", + " [ 0.02404731],\n", + " [-0.9631397 ],\n", + " [-1.5443141 ],\n", + " [ 0.3322162 ],\n", + " [-0.8289226 ],\n", + " [ 1.545839 ],\n", + " [ 1.3099878 ],\n", + " [-1.1293337 ],\n", + " [-1.1112677 ],\n", + " [-1.2859131 ],\n", + " [-0.00848705],\n", + " [-0.826562 ],\n", + " [-0.05401359],\n", + " [-0.5733626 ],\n", + " [ 1.441103 ],\n", + " [ 1.3535103 ],\n", + " [ 0.82197374],\n", + " [ 1.1405069 ],\n", + " [ 1.5078775 ],\n", + " [-0.24578635],\n", + " [-1.2971845 ],\n", + " [ 1.2200094 ],\n", + " [ 1.2412809 ],\n", + " [ 0.8568348 ],\n", + " [-0.14014174],\n", + " [ 1.2381922 ],\n", + " [-0.00458808],\n", + " [-0.29915315],\n", + " [-0.014065 ],\n", + " [-0.36981848],\n", + " [ 0.08939534],\n", + " [-1.298005 ],\n", + " [ 0.08107368],\n", + " [ 1.1527332 ],\n", + " [ 1.1010569 ],\n", + " [-0.97093356],\n", + " [-0.8327428 ],\n", + " [-1.2685784 ],\n", + " [ 0.38197798],\n", + " [ 0.66803664],\n", + " [ 0.32831058],\n", + " [ 0.7471112 ],\n", + " [-1.1261024 ],\n", + " [ 0.976383 ],\n", + " [ 1.2762269 ],\n", + " [-1.1819408 ],\n", + " [ 0.01273822],\n", + " [ 1.3222057 ],\n", + " [ 0.32145002],\n", + " [ 1.0546241 ],\n", + " [ 1.1236954 ],\n", + " [-0.86915904],\n", + " [-1.4725404 ],\n", + " [-0.57851046],\n", + " [-1.3111424 ],\n", + " [-1.5549136 ],\n", + " [-0.6562859 ],\n", + " [ 0.46044338],\n", + " [ 1.4144645 ],\n", + " [-0.80985373],\n", + " [ 0.67673105],\n", + " [ 0.00445595],\n", + " [ 0.9129677 ],\n", + " [ 
0.73264116],\n", + " [ 0.9383897 ],\n", + " [ 1.4690843 ],\n", + " [ 0.44805676],\n", + " [-1.2438159 ],\n", + " [-1.0919781 ],\n", + " [ 0.10005943],\n", + " [-0.652468 ],\n", + " [ 1.3374164 ],\n", + " [ 1.5248041 ],\n", + " [ 0.9711219 ],\n", + " [ 1.5689381 ],\n", + " [ 0.6760715 ],\n", + " [ 0.23398188],\n", + " [-0.21398419],\n", + " [-1.2060331 ],\n", + " [ 0.79313487],\n", + " [ 0.930686 ],\n", + " [ 0.47900388],\n", + " [-0.33682647],\n", + " [ 1.5067643 ],\n", + " [ 1.3799299 ],\n", + " [-1.0610836 ],\n", + " [-0.19346415],\n", + " [ 0.6703556 ],\n", + " [ 0.55801785],\n", + " [-1.302337 ],\n", + " [ 0.8880511 ],\n", + " [ 1.4861475 ],\n", + " [ 1.5566795 ],\n", + " [ 0.6298603 ],\n", + " [-0.4204177 ],\n", + " [-0.9615261 ],\n", + " [ 0.93494344],\n", + " [ 0.03652098],\n", + " [-0.14426196],\n", + " [ 0.5074313 ],\n", + " [-1.4291117 ],\n", + " [-1.3903574 ],\n", + " [ 0.5471139 ],\n", + " [ 0.8714255 ],\n", + " [ 0.40494618],\n", + " [-0.6385578 ],\n", + " [ 0.6857702 ],\n", + " [ 1.2249589 ],\n", + " [-1.2018368 ],\n", + " [-0.99335515],\n", + " [ 0.5213185 ],\n", + " [ 0.653224 ],\n", + " [-1.5564892 ],\n", + " [ 0.29157606],\n", + " [-1.1314583 ],\n", + " [ 0.97327983],\n", + " [-1.3943021 ],\n", + " [-0.16822235],\n", + " [-0.5008124 ]], dtype=float32),\n", + " 'r': array([[0.09191831],\n", + " [0.09392393],\n", + " [0.09252372],\n", + " [0.0982731 ],\n", + " [0.10405302],\n", + " [0.10219509],\n", + " [0.08235917],\n", + " [0.10991579],\n", + " [0.09014853],\n", + " [0.11165598],\n", + " [0.10175227],\n", + " [0.10613288],\n", + " [0.09984934],\n", + " [0.10292789],\n", + " [0.09939721],\n", + " [0.09315164],\n", + " [0.08483255],\n", + " [0.09881789],\n", + " [0.08632301],\n", + " [0.11751007],\n", + " [0.11542765],\n", + " [0.08752786],\n", + " [0.09544735],\n", + " [0.09462399],\n", + " [0.10281887],\n", + " [0.09729108],\n", + " [0.09319801],\n", + " [0.07714061],\n", + " [0.10567886],\n", + " [0.09792958],\n", + " [0.09385688],\n", + " 
[0.09495461],\n", + " [0.1094628 ],\n", + " [0.09255824],\n", + " [0.09803515],\n", + " [0.10114206],\n", + " [0.12053802],\n", + " [0.10272378],\n", + " [0.09209541],\n", + " [0.09367829],\n", + " [0.09309962],\n", + " [0.0895135 ],\n", + " [0.10099211],\n", + " [0.07979899],\n", + " [0.09191797],\n", + " [0.10424455],\n", + " [0.10285137],\n", + " [0.10614532],\n", + " [0.09172467],\n", + " [0.11042022],\n", + " [0.1118973 ],\n", + " [0.10933825],\n", + " [0.0880883 ],\n", + " [0.09732444],\n", + " [0.09554953],\n", + " [0.0945937 ],\n", + " [0.09006199],\n", + " [0.10566922],\n", + " [0.11036118],\n", + " [0.09307823],\n", + " [0.10670979],\n", + " [0.09900924],\n", + " [0.09903299],\n", + " [0.08567982],\n", + " [0.10059226],\n", + " [0.10335171],\n", + " [0.11039949],\n", + " [0.10738601],\n", + " [0.07921619],\n", + " [0.09219779],\n", + " [0.09759328],\n", + " [0.11219482],\n", + " [0.09465504],\n", + " [0.0967709 ],\n", + " [0.09380519],\n", + " [0.07976504],\n", + " [0.10182797],\n", + " [0.09801979],\n", + " [0.09238707],\n", + " [0.10204466],\n", + " [0.09565485],\n", + " [0.09219388],\n", + " [0.0845885 ],\n", + " [0.09161578],\n", + " [0.10767981],\n", + " [0.10286254],\n", + " [0.09314913],\n", + " [0.09183365],\n", + " [0.11092991],\n", + " [0.09702247],\n", + " [0.10227211],\n", + " [0.09580602],\n", + " [0.10044043],\n", + " [0.09645868],\n", + " [0.10505722],\n", + " [0.09277388],\n", + " [0.10051655],\n", + " [0.10390085],\n", + " [0.09256679],\n", + " [0.09127223],\n", + " [0.09056184],\n", + " [0.10505585],\n", + " [0.0941221 ],\n", + " [0.08350186],\n", + " [0.0976114 ],\n", + " [0.09832696],\n", + " [0.08880883],\n", + " [0.11294892],\n", + " [0.11518918],\n", + " [0.0965459 ],\n", + " [0.08664226],\n", + " [0.1084713 ],\n", + " [0.08493465],\n", + " [0.10586594],\n", + " [0.09469504],\n", + " [0.10193761],\n", + " [0.09030203],\n", + " [0.10500195],\n", + " [0.09645825],\n", + " [0.11158349],\n", + " [0.10557099],\n", + " [0.11049099],\n", + 
" [0.10678031],\n", + " [0.08420987],\n", + " [0.11620536],\n", + " [0.10082643],\n", + " [0.09813577],\n", + " [0.11379081]], dtype=float32),\n", + " 'theta': array([[-0.31046662, -0.5192129 ],\n", + " [ 0.25899222, -0.2448592 ],\n", + " [-0.09939136, -0.9546972 ],\n", + " [-0.52233803, -0.00598578],\n", + " [ 0.3325237 , -0.14211619],\n", + " [ 0.31660682, -0.9715865 ],\n", + " [-0.0025572 , 0.01219535],\n", + " [ 0.0804862 , 0.3379917 ],\n", + " [ 0.8448752 , -0.15739872],\n", + " [ 0.67502624, 0.30911726],\n", + " [-0.5518989 , 0.38458768],\n", + " [-0.5097594 , -0.8797034 ],\n", + " [ 0.1299824 , -0.11516342],\n", + " [ 0.54469794, 0.45290363],\n", + " [-0.79390055, -0.5927222 ],\n", + " [-0.30554888, -0.70457566],\n", + " [ 0.12381205, 0.66449183],\n", + " [ 0.9896895 , 0.43069273],\n", + " [-0.5636638 , -0.37905496],\n", + " [-0.47466183, 0.8630436 ],\n", + " [-0.25352567, -0.6127405 ],\n", + " [ 0.8454494 , 0.20636818],\n", + " [ 0.05136211, -0.9957166 ],\n", + " [-0.87917453, -0.43837497],\n", + " [-0.9768608 , -0.45436805],\n", + " [ 0.45989057, 0.1390453 ],\n", + " [ 0.8821207 , -0.6721646 ],\n", + " [ 0.6098006 , 0.02059207],\n", + " [-0.149392 , 0.0668803 ],\n", + " [ 0.816167 , -0.7244783 ],\n", + " [-0.10582127, 0.52367836],\n", + " [ 0.04364904, -0.71592885],\n", + " [-0.661958 , -0.889716 ],\n", + " [-0.75828946, -0.5449933 ],\n", + " [-0.0180878 , 0.93717176],\n", + " [ 0.6192016 , -0.18021217],\n", + " [ 0.6736404 , -0.86073035],\n", + " [ 0.40462115, -0.79911304],\n", + " [-0.60797983, -0.39942083],\n", + " [-0.8760179 , -0.7573692 ],\n", + " [ 0.7658219 , -0.25693515],\n", + " [ 0.6227054 , 0.6554283 ],\n", + " [-0.07259648, 0.8712275 ],\n", + " [ 0.90728253, -0.5650344 ],\n", + " [ 0.79422444, 0.68135595],\n", + " [-0.13548942, 0.82118744],\n", + " [ 0.59628546, 0.31666732],\n", + " [ 0.44013777, 0.09197523],\n", + " [-0.46371552, 0.22121692],\n", + " [-0.09013044, 0.3197752 ],\n", + " [-0.05406016, -0.22707504],\n", + " [ 0.97624415, 
-0.86260617],\n", + " [ 0.8244347 , 0.08089861],\n", + " [-0.06575817, -0.15638828],\n", + " [-0.06798965, -0.5022756 ],\n", + " [ 0.815011 , -0.5965575 ],\n", + " [-0.09590898, -0.521891 ],\n", + " [ 0.7684214 , 0.455757 ],\n", + " [ 0.14870167, 0.07132804],\n", + " [-0.65875846, -0.8905175 ],\n", + " [-0.7458266 , 0.11753188],\n", + " [ 0.810346 , 0.25945064],\n", + " [ 0.9870774 , 0.54629356],\n", + " [-0.35242337, 0.6756984 ],\n", + " [-0.44706544, 0.41715237],\n", + " [-0.9747908 , 0.6067063 ],\n", + " [ 0.31205034, -0.90819144],\n", + " [-0.17396212, -0.7797098 ],\n", + " [-0.24114859, -0.52675587],\n", + " [-0.05354288, -0.21780755],\n", + " [-0.6323093 , 0.58807445],\n", + " [-0.8843161 , -0.5124541 ],\n", + " [-0.2678302 , 0.8521883 ],\n", + " [ 0.5850157 , -0.7819173 ],\n", + " [ 0.6423173 , -0.8578282 ],\n", + " [ 0.14295006, 0.85248995],\n", + " [-0.15930556, 0.8565224 ],\n", + " [ 0.03929539, -0.98305523],\n", + " [ 0.31724995, 0.5836652 ],\n", + " [-0.26811373, 0.73696667],\n", + " [ 0.9598292 , 0.7289081 ],\n", + " [ 0.93546367, -0.15821293],\n", + " [-0.16623883, -0.3217834 ],\n", + " [-0.41705012, 0.47349688],\n", + " [ 0.39159286, -0.37216726],\n", + " [-0.45079473, 0.602687 ],\n", + " [-0.65825534, -0.32123926],\n", + " [-0.44960454, -0.07970022],\n", + " [-0.4009049 , -0.6618951 ],\n", + " [-0.2572403 , -0.3288973 ],\n", + " [ 0.5176336 , 0.7126518 ],\n", + " [-0.760555 , -0.17923199],\n", + " [ 0.08348245, -0.54072154],\n", + " [-0.67218244, -0.81798273],\n", + " [-0.5323433 , -0.36760053],\n", + " [ 0.7281168 , 0.58143365],\n", + " [-0.11570914, -0.80179965],\n", + " [-0.27270266, -0.06209885],\n", + " [-0.08812815, -0.6612023 ],\n", + " [ 0.8177256 , 0.19989038],\n", + " [-0.35036862, 0.83757067],\n", + " [ 0.06185859, 0.42102838],\n", + " [ 0.13955897, -0.4102967 ],\n", + " [ 0.42227042, -0.32889032],\n", + " [-0.6494401 , 0.26140058],\n", + " [-0.78374976, 0.79485214],\n", + " [-0.4989267 , 0.22898036],\n", + " [ 0.5334011 , -0.7786707 
],\n", + " [ 0.772915 , 0.09377074],\n", + " [-0.03455243, -0.7146796 ],\n", + " [-0.31524613, -0.58907634],\n", + " [ 0.27913257, -0.39805576],\n", + " [-0.5170899 , 0.94988567],\n", + " [-0.6617844 , -0.7926895 ],\n", + " [ 0.30463684, -0.65223706],\n", + " [ 0.95110357, 0.01535923],\n", + " [ 0.2822953 , -0.58528644],\n", + " [-0.58930004, -0.2095232 ],\n", + " [ 0.6582694 , 0.2106661 ],\n", + " [-0.71783096, -0.15897161],\n", + " [ 0.5494588 , -0.41081655],\n", + " [ 0.81622076, 0.12805746],\n", + " [ 0.38824794, -0.9268966 ],\n", + " [ 0.8239423 , -0.38376778],\n", + " [ 0.8703679 , -0.9352464 ],\n", + " [ 0.89721346, -0.3582009 ],\n", + " [ 0.02263427, 0.9801298 ],\n", + " [ 0.895235 , -0.8068917 ]], dtype=float32),\n", + " 'x': array([[-2.73417473e-01, -8.09137672e-02],\n", + " [ 2.77036577e-01, -2.69960642e-01],\n", + " [-4.24015194e-01, -5.45872867e-01],\n", + " [-1.05365962e-01, 2.68545985e-01],\n", + " [ 1.75225988e-01, -4.20728564e-01],\n", + " [-1.76863164e-01, -1.00642967e+00],\n", + " [ 2.48126850e-01, 9.26423967e-02],\n", + " [ 1.70122087e-02, 2.72208542e-01],\n", + " [-1.46389037e-01, -7.00040758e-01],\n", + " [-3.34270835e-01, -2.56051958e-01],\n", + " [ 1.89788073e-01, 5.78658581e-01],\n", + " [-7.29688287e-01, -3.67685556e-01],\n", + " [ 3.33911151e-01, -1.40779510e-01],\n", + " [-3.85865510e-01, -1.40787229e-01],\n", + " [-7.28009880e-01, 2.41620854e-01],\n", + " [-4.40245688e-01, -1.92153111e-01],\n", + " [-2.71169245e-01, 3.05618912e-01],\n", + " [-7.10533679e-01, -4.83837098e-01],\n", + " [-3.92342120e-01, 4.76944335e-02],\n", + " [ 9.28784534e-02, 9.44903255e-01],\n", + " [-2.84350991e-01, -3.38912606e-01],\n", + " [-4.06347096e-01, -4.56624061e-01],\n", + " [-3.37575883e-01, -7.92172849e-01],\n", + " [-6.69410408e-01, 4.05521661e-01],\n", + " [-7.39865899e-01, 4.69859362e-01],\n", + " [-1.07278086e-01, -1.55607030e-01],\n", + " [ 1.40414655e-01, -1.01434314e+00],\n", + " [-1.90904528e-01, -3.39645326e-01],\n", + " [ 2.94158250e-01, 
1.27213925e-01],\n", + " [ 2.11627930e-01, -1.18368745e+00],\n", + " [-1.32169081e-02, 5.33264697e-01],\n", + " [-1.94647789e-01, -4.47256684e-01],\n", + " [-7.75519371e-01, -7.83201531e-02],\n", + " [-5.79909265e-01, 1.37894318e-01],\n", + " [-3.67881477e-01, 7.68132865e-01],\n", + " [ 4.07285877e-02, -5.65734982e-01],\n", + " [ 2.32891902e-01, -1.12048781e+00],\n", + " [ 7.37657398e-02, -8.52613330e-01],\n", + " [-3.76470715e-01, 1.14185952e-01],\n", + " [-8.11674833e-01, 9.22605619e-02],\n", + " [-8.47542733e-02, -8.12855482e-01],\n", + " [-5.64557493e-01, 3.03878188e-02],\n", + " [-2.73715496e-01, 7.59678721e-01],\n", + " [ 4.41153459e-02, -9.69929636e-01],\n", + " [-7.41502583e-01, -1.55680373e-01],\n", + " [-1.64720863e-01, 5.99354684e-01],\n", + " [-3.64942580e-01, -2.95909882e-01],\n", + " [-2.77653802e-02, -2.06621721e-01],\n", + " [ 1.50535062e-01, 5.41138887e-01],\n", + " [ 1.92139149e-01, 3.25451434e-01],\n", + " [ 1.33301392e-01, -4.63032946e-02],\n", + " [ 2.16681108e-01, -1.39896774e+00],\n", + " [-3.40835840e-01, -4.52780277e-01],\n", + " [ 1.21174738e-01, 2.90472377e-02],\n", + " [-1.17012754e-01, -3.95502687e-01],\n", + " [ 1.90116048e-01, -9.96924758e-01],\n", + " [-1.64691880e-01, -2.13921264e-01],\n", + " [-5.15368223e-01, -1.87701717e-01],\n", + " [ 1.48884773e-01, 4.12713327e-02],\n", + " [-8.05260777e-01, -7.99493268e-02],\n", + " [-1.25393584e-01, 5.28983176e-01],\n", + " [-4.96747881e-01, -4.88073528e-01],\n", + " [-7.51338840e-01, -3.65830243e-01],\n", + " [ 4.34079915e-02, 6.44184113e-01],\n", + " [ 2.30445877e-01, 5.10514736e-01],\n", + " [ 7.16068000e-02, 1.05522442e+00],\n", + " [-7.26334229e-02, -8.13785732e-01],\n", + " [-4.07628328e-01, -3.22251856e-01],\n", + " [-2.38362625e-01, -2.59322137e-01],\n", + " [ 1.30005881e-01, -5.84139936e-02],\n", + " [ 3.16313535e-01, 8.63376498e-01],\n", + " [-6.69069767e-01, 3.51728320e-01],\n", + " [-9.28360224e-02, 8.55281293e-01],\n", + " [ 1.67969659e-01, -8.88511479e-01],\n", + " [ 
1.07135452e-01, -9.67442632e-01],\n", + " [-3.81990910e-01, 5.36275864e-01],\n", + " [-2.10301131e-01, 6.21866047e-01],\n", + " [-3.72178257e-01, -8.09907556e-01],\n", + " [-2.95118243e-01, 1.97612807e-01],\n", + " [-4.45644546e-04, 6.48742855e-01],\n", + " [-9.21995819e-01, -7.02241957e-02],\n", + " [-2.95360595e-01, -6.81249738e-01],\n", + " [-4.73442897e-02, -4.01571728e-02],\n", + " [ 2.10256353e-01, 7.21327484e-01],\n", + " [ 3.20258260e-01, -4.72681075e-01],\n", + " [ 2.42655575e-01, 7.68773019e-01],\n", + " [-3.51582617e-01, 2.18525678e-01],\n", + " [-9.15153325e-02, 1.75770089e-01],\n", + " [-4.23682958e-01, -1.05504133e-01],\n", + " [-1.06511913e-01, 2.71457620e-02],\n", + " [-5.29181302e-01, 1.85035408e-01],\n", + " [-3.24107260e-01, 3.79394174e-01],\n", + " [-6.68898523e-02, -3.41144264e-01],\n", + " [-7.85406768e-01, -8.38936307e-03],\n", + " [-3.35096180e-01, 2.47879084e-02],\n", + " [-5.84948957e-01, -1.21557318e-01],\n", + " [-3.20011824e-01, -4.22691882e-01],\n", + " [ 1.01399325e-01, 2.03935474e-01],\n", + " [-2.55303651e-01, -4.94475722e-01],\n", + " [-4.11977261e-01, -3.66062492e-01],\n", + " [-8.68470743e-02, 9.30237532e-01],\n", + " [-8.99696425e-02, 3.59016776e-01],\n", + " [ 1.34620503e-01, -3.33365768e-01],\n", + " [ 2.60200709e-01, -5.65231442e-01],\n", + " [ 3.14746089e-02, 5.64013839e-01],\n", + " [ 3.00542265e-01, 1.19535053e+00],\n", + " [ 1.47868738e-01, 5.17950714e-01],\n", + " [ 1.88343808e-01, -9.44012642e-01],\n", + " [-2.62164533e-01, -4.24253196e-01],\n", + " [-2.66153723e-01, -5.76500952e-01],\n", + " [-3.73903632e-01, -2.78862804e-01],\n", + " [ 2.58546293e-01, -4.22415018e-01],\n", + " [-1.35729939e-03, 1.10230434e+00],\n", + " [-6.81164503e-01, -5.08559495e-02],\n", + " [ 8.02454948e-02, -7.33053863e-01],\n", + " [-3.54499668e-01, -5.97117186e-01],\n", + " [ 6.63639084e-02, -5.28517485e-01],\n", + " [-2.76984870e-01, 1.70607135e-01],\n", + " [-3.11775416e-01, -3.97322059e-01],\n", + " [-2.73231924e-01, 4.50744480e-01],\n", + 
" [ 2.35802069e-01, -6.14856482e-01],\n", + " [-4.16124791e-01, -5.97084582e-01],\n", + " [-2.86087506e-02, -8.99252355e-01],\n", + " [-2.54325196e-02, -9.30192709e-01],\n", + " [ 2.69500047e-01, -1.18069100e+00],\n", + " [-1.13436393e-01, -9.86972153e-01],\n", + " [-3.62310737e-01, 6.60620689e-01],\n", + " [ 2.87348330e-01, -1.25822067e+00]], dtype=float32)}" ] }, - "execution_count": 6, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], - "execution_count": 6 + "execution_count": 9 }, { "metadata": { "ExecuteTime": { - "end_time": "2024-10-14T22:06:35.724356Z", - "start_time": "2024-10-14T22:06:35.716334Z" + "end_time": "2024-10-18T20:36:21.177448Z", + "start_time": "2024-10-18T20:36:21.163813Z" } }, "cell_type": "code", @@ -678,13 +684,13 @@ ], "id": "461e6dfcdf6944b", "outputs": [], - "execution_count": 57 + "execution_count": 10 }, { "metadata": { "ExecuteTime": { - "end_time": "2024-10-14T20:27:53.903783Z", - "start_time": "2024-10-14T20:27:53.883195Z" + "end_time": "2024-10-18T20:36:23.287123Z", + "start_time": "2024-10-18T20:36:23.274401Z" } }, "cell_type": "code", @@ -695,13 +701,13 @@ ], "id": "ed2cec2c3fdedb22", "outputs": [], - "execution_count": 8 + "execution_count": 11 }, { "metadata": { "ExecuteTime": { - "end_time": "2024-10-14T20:28:10.611306Z", - "start_time": "2024-10-14T20:27:54.318482Z" + "end_time": "2024-10-18T20:36:37.893513Z", + "start_time": "2024-10-18T20:36:23.887416Z" } }, "cell_type": "code", @@ -711,13 +717,13 @@ ], "id": "7d1bffc7f17b5aaa", "outputs": [], - "execution_count": 9 + "execution_count": 12 }, { "metadata": { "ExecuteTime": { - "end_time": "2024-10-14T20:28:22.181678Z", - "start_time": "2024-10-14T20:28:22.155164Z" + "end_time": "2024-10-18T20:38:23.502875Z", + "start_time": "2024-10-18T20:38:23.474475Z" } }, "cell_type": "code", @@ -727,13 +733,13 @@ ], "id": "d7f545fd2ee536d8", "outputs": [], - "execution_count": 10 + "execution_count": 13 }, { "metadata": { "ExecuteTime": { - "end_time": 
"2024-10-14T20:28:23.402981Z", - "start_time": "2024-10-14T20:28:23.353636Z" + "end_time": "2024-10-18T20:38:25.248748Z", + "start_time": "2024-10-18T20:38:25.153585Z" } }, "cell_type": "code", @@ -748,13 +754,13 @@ ], "id": "be6ed75d4d899021", "outputs": [], - "execution_count": 11 + "execution_count": 14 }, { "metadata": { "ExecuteTime": { - "end_time": "2024-10-14T20:28:24.123891Z", - "start_time": "2024-10-14T20:28:24.106889Z" + "end_time": "2024-10-18T20:38:26.228405Z", + "start_time": "2024-10-18T20:38:26.217686Z" } }, "cell_type": "code", @@ -767,13 +773,13 @@ ], "id": "b1dc4f27eb17b270", "outputs": [], - "execution_count": 12 + "execution_count": 15 }, { "metadata": { "ExecuteTime": { - "end_time": "2024-10-14T20:28:25.265638Z", - "start_time": "2024-10-14T20:28:25.223633Z" + "end_time": "2024-10-18T20:38:28.542989Z", + "start_time": "2024-10-18T20:38:28.513271Z" } }, "cell_type": "code", @@ -783,13 +789,13 @@ ], "id": "ad75c807a7617e0a", "outputs": [], - "execution_count": 13 + "execution_count": 16 }, { "metadata": { "ExecuteTime": { - "end_time": "2024-10-14T20:28:26.705303Z", - "start_time": "2024-10-14T20:28:26.693791Z" + "end_time": "2024-10-18T21:08:38.953498Z", + "start_time": "2024-10-18T21:08:38.926729Z" } }, "cell_type": "code", @@ -797,59 +803,59 @@ "class BatchLossHistory(keras.callbacks.Callback):\n", " def __init__(self):\n", " super().__init__()\n", - " self.batch_losses = []\n", + " self.training_loss = []\n", + " self.validation_loss = []\n", "\n", " def on_train_batch_end(self, batch, logs=None):\n", " # 'logs' is a dictionary containing loss and other metrics\n", - " loss = logs.get('loss')\n", - " self.batch_losses.append(loss)" + " training_loss = logs.get('loss')\n", + " self.training_loss.append(training_loss)\n", + " \n", + " def on_test_batch_end(self, batch, logs=None):\n", + " validation_loss = logs.get('loss')\n", + " self.validation_loss.append(validation_loss)" ], "id": "9d08447b96c58cf4", "outputs": [], - "execution_count": 
14 + "execution_count": 36 }, { "metadata": { "ExecuteTime": { - "end_time": "2024-10-14T20:28:28.279798Z", - "start_time": "2024-10-14T20:28:28.258494Z" + "end_time": "2024-10-18T21:08:39.546893Z", + "start_time": "2024-10-18T21:08:39.522895Z" } }, "cell_type": "code", - "source": [ - "approximator.compile(\n", - " optimizer=optimizer,\n", - " loss=\"sparse_categorical_crossentropy\"\n", - ")" - ], + "source": "approximator.compile(optimizer=optimizer)", "id": "120d9b0fed8a8a01", "outputs": [], - "execution_count": 15 + "execution_count": 37 }, { "metadata": { "ExecuteTime": { - "end_time": "2024-10-14T20:28:28.698900Z", - "start_time": "2024-10-14T20:28:28.690911Z" + "end_time": "2024-10-18T21:08:39.859389Z", + "start_time": "2024-10-18T21:08:39.852730Z" } }, "cell_type": "code", "source": "batch_loss_history = BatchLossHistory()", "id": "50d2d9f6d6419075", "outputs": [], - "execution_count": 16 + "execution_count": 38 }, { "metadata": { "ExecuteTime": { - "end_time": "2024-10-14T20:38:09.636742Z", - "start_time": "2024-10-14T20:28:29.157027Z" + "end_time": "2024-10-18T21:11:12.394144Z", + "start_time": "2024-10-18T21:08:40.381171Z" } }, "cell_type": "code", "source": [ "history = approximator.fit(\n", - " epochs=30,\n", + " epochs=10,\n", " dataset=training_dataset,\n", " validation_data=validation_dataset,\n", " callbacks=[batch_loss_history]\n", @@ -861,112 +867,93 @@ "name": "stderr", "output_type": "stream", "text": [ - "INFO:bayesflow:Fitting on dataset instance of OfflineDataset.\n", - "INFO:bayesflow:Building on a test batch.\n" + "INFO:bayesflow:Fitting on dataset instance of OfflineDataset.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 1/30\n", - "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m17s\u001B[0m 16ms/step - loss: 0.6919 - loss/inference_loss: 0.6919 - val_loss: 0.6134 - val_loss/inference_loss: 0.6134\n", - "Epoch 2/30\n", - "\u001B[1m1024/1024\u001B[0m 
\u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m18s\u001B[0m 18ms/step - loss: 0.6234 - loss/inference_loss: 0.6234 - val_loss: 0.6321 - val_loss/inference_loss: 0.6321\n", - "Epoch 3/30\n", - "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m18s\u001B[0m 18ms/step - loss: 0.6018 - loss/inference_loss: 0.6018 - val_loss: 0.4567 - val_loss/inference_loss: 0.4567\n", - "Epoch 4/30\n", - "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m21s\u001B[0m 20ms/step - loss: 0.6079 - loss/inference_loss: 0.6079 - val_loss: 0.6692 - val_loss/inference_loss: 0.6692\n", - "Epoch 5/30\n", - "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m24s\u001B[0m 24ms/step - loss: 0.5956 - loss/inference_loss: 0.5956 - val_loss: 0.7312 - val_loss/inference_loss: 0.7312\n", - "Epoch 6/30\n", - "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m22s\u001B[0m 22ms/step - loss: 0.5911 - loss/inference_loss: 0.5911 - val_loss: 0.5461 - val_loss/inference_loss: 0.5461\n", - "Epoch 7/30\n", - "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m22s\u001B[0m 22ms/step - loss: 0.5907 - loss/inference_loss: 0.5907 - val_loss: 0.5829 - val_loss/inference_loss: 0.5829\n", - "Epoch 8/30\n", - "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m24s\u001B[0m 24ms/step - loss: 0.5820 - loss/inference_loss: 0.5820 - val_loss: 0.7137 - val_loss/inference_loss: 0.7137\n", - "Epoch 9/30\n", - "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m25s\u001B[0m 25ms/step - loss: 0.5801 - loss/inference_loss: 0.5801 - val_loss: 0.5453 - val_loss/inference_loss: 0.5453\n", - "Epoch 10/30\n", - "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m 
\u001B[1m26s\u001B[0m 25ms/step - loss: 0.5841 - loss/inference_loss: 0.5841 - val_loss: 0.6155 - val_loss/inference_loss: 0.6155\n", - "Epoch 11/30\n", - "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m17s\u001B[0m 17ms/step - loss: 0.5748 - loss/inference_loss: 0.5748 - val_loss: 0.4574 - val_loss/inference_loss: 0.4574\n", - "Epoch 12/30\n", - "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m17s\u001B[0m 16ms/step - loss: 0.5714 - loss/inference_loss: 0.5714 - val_loss: 0.9205 - val_loss/inference_loss: 0.9205\n", - "Epoch 13/30\n", - "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m23s\u001B[0m 22ms/step - loss: 0.5804 - loss/inference_loss: 0.5804 - val_loss: 0.4696 - val_loss/inference_loss: 0.4696\n", - "Epoch 14/30\n", - "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m20s\u001B[0m 19ms/step - loss: 0.5691 - loss/inference_loss: 0.5691 - val_loss: 0.5795 - val_loss/inference_loss: 0.5795\n", - "Epoch 15/30\n", - "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m22s\u001B[0m 22ms/step - loss: 0.5663 - loss/inference_loss: 0.5663 - val_loss: 0.7035 - val_loss/inference_loss: 0.7035\n", - "Epoch 16/30\n", - "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m23s\u001B[0m 22ms/step - loss: 0.5692 - loss/inference_loss: 0.5692 - val_loss: 0.6051 - val_loss/inference_loss: 0.6051\n", - "Epoch 17/30\n", - "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m17s\u001B[0m 17ms/step - loss: 0.5635 - loss/inference_loss: 0.5635 - val_loss: 0.5303 - val_loss/inference_loss: 0.5303\n", - "Epoch 18/30\n", - "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m17s\u001B[0m 17ms/step - loss: 0.5730 - 
loss/inference_loss: 0.5730 - val_loss: 0.4921 - val_loss/inference_loss: 0.4921\n", - "Epoch 19/30\n", - "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m24s\u001B[0m 23ms/step - loss: 0.5641 - loss/inference_loss: 0.5641 - val_loss: 0.5474 - val_loss/inference_loss: 0.5474\n", - "Epoch 20/30\n", - "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m16s\u001B[0m 15ms/step - loss: 0.5669 - loss/inference_loss: 0.5669 - val_loss: 0.5979 - val_loss/inference_loss: 0.5979\n", - "Epoch 21/30\n", - "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m16s\u001B[0m 16ms/step - loss: 0.5698 - loss/inference_loss: 0.5698 - val_loss: 0.6764 - val_loss/inference_loss: 0.6764\n", - "Epoch 22/30\n", - "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m15s\u001B[0m 15ms/step - loss: 0.5697 - loss/inference_loss: 0.5697 - val_loss: 0.5636 - val_loss/inference_loss: 0.5636\n", - "Epoch 23/30\n", - "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m19s\u001B[0m 19ms/step - loss: 0.5697 - loss/inference_loss: 0.5697 - val_loss: 0.5355 - val_loss/inference_loss: 0.5355\n", - "Epoch 24/30\n", - "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m18s\u001B[0m 17ms/step - loss: 0.5623 - loss/inference_loss: 0.5623 - val_loss: 0.4090 - val_loss/inference_loss: 0.4090\n", - "Epoch 25/30\n", - "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m16s\u001B[0m 16ms/step - loss: 0.5686 - loss/inference_loss: 0.5686 - val_loss: 0.5841 - val_loss/inference_loss: 0.5841\n", - "Epoch 26/30\n", - "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m16s\u001B[0m 16ms/step - loss: 0.5650 - loss/inference_loss: 0.5650 - val_loss: 0.5608 - 
val_loss/inference_loss: 0.5608\n", - "Epoch 27/30\n", - "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m17s\u001B[0m 16ms/step - loss: 0.5646 - loss/inference_loss: 0.5646 - val_loss: 0.5898 - val_loss/inference_loss: 0.5898\n", - "Epoch 28/30\n", - "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m16s\u001B[0m 16ms/step - loss: 0.5608 - loss/inference_loss: 0.5608 - val_loss: 0.3862 - val_loss/inference_loss: 0.3862\n", - "Epoch 29/30\n", - "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m16s\u001B[0m 16ms/step - loss: 0.5666 - loss/inference_loss: 0.5666 - val_loss: 0.5265 - val_loss/inference_loss: 0.5265\n", - "Epoch 30/30\n", - "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m16s\u001B[0m 16ms/step - loss: 0.5673 - loss/inference_loss: 0.5673 - val_loss: 0.7562 - val_loss/inference_loss: 0.7562\n" + "Epoch 1/10\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m15s\u001B[0m 14ms/step - loss: 0.5526 - loss/inference_loss: 0.5526 - val_loss: 0.5321 - val_loss/inference_loss: 0.5321\n", + "Epoch 2/10\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m15s\u001B[0m 14ms/step - loss: 0.5567 - loss/inference_loss: 0.5567 - val_loss: 0.7145 - val_loss/inference_loss: 0.7145\n", + "Epoch 3/10\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m15s\u001B[0m 14ms/step - loss: 0.5549 - loss/inference_loss: 0.5549 - val_loss: 0.5232 - val_loss/inference_loss: 0.5232\n", + "Epoch 4/10\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m15s\u001B[0m 15ms/step - loss: 0.5476 - loss/inference_loss: 0.5476 - val_loss: 0.4488 - val_loss/inference_loss: 0.4488\n", + "Epoch 5/10\n", + 
"\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m16s\u001B[0m 15ms/step - loss: 0.5523 - loss/inference_loss: 0.5523 - val_loss: 0.4283 - val_loss/inference_loss: 0.4283\n", + "Epoch 6/10\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m16s\u001B[0m 15ms/step - loss: 0.5580 - loss/inference_loss: 0.5580 - val_loss: 0.6679 - val_loss/inference_loss: 0.6679\n", + "Epoch 7/10\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m16s\u001B[0m 15ms/step - loss: 0.5506 - loss/inference_loss: 0.5506 - val_loss: 0.5052 - val_loss/inference_loss: 0.5052\n", + "Epoch 8/10\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m14s\u001B[0m 14ms/step - loss: 0.5521 - loss/inference_loss: 0.5521 - val_loss: 0.3337 - val_loss/inference_loss: 0.3337\n", + "Epoch 9/10\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m16s\u001B[0m 16ms/step - loss: 0.5533 - loss/inference_loss: 0.5533 - val_loss: 0.6080 - val_loss/inference_loss: 0.6080\n", + "Epoch 10/10\n", + "\u001B[1m1024/1024\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m15s\u001B[0m 15ms/step - loss: 0.5499 - loss/inference_loss: 0.5499 - val_loss: 0.4822 - val_loss/inference_loss: 0.4822\n" ] } ], - "execution_count": 17 + "execution_count": 39 }, { "metadata": { "ExecuteTime": { - "end_time": "2024-10-14T20:38:43.704457Z", - "start_time": "2024-10-14T20:38:42.786411Z" + "end_time": "2024-10-18T21:11:18.353698Z", + "start_time": "2024-10-18T21:11:18.340415Z" } }, "cell_type": "code", - "source": "plt.plot(batch_loss_history.batch_losses)", - "id": "3bc7cb16f130a630", + "source": "type(batch_loss_history.validation_loss)", + "id": "a30166eb4abb1951", "outputs": [ { "data": { "text/plain": [ - "[]" + "list" ] }, - "execution_count": 18, + 
"execution_count": 40, "metadata": {}, "output_type": "execute_result" - }, + } + ], + "execution_count": 40 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-18T21:11:30.079377Z", + "start_time": "2024-10-18T21:11:29.336090Z" + } + }, + "cell_type": "code", + "source": [ + "import pandas as pd\n", + "\n", + "f = plot_losses(\n", + " train_losses=pd.DataFrame(batch_loss_history.training_loss), \n", + " val_losses=pd.DataFrame(batch_loss_history.validation_loss), \n", + " moving_average=True\n", + ")" + ], + "id": "3bc7cb16f130a630", + "outputs": [ { "data": { "text/plain": [ - "
" + "
" ], - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAGeCAYAAAC+dvpwAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABhCklEQVR4nO3dd3gU5doG8HuTkEIKIYVAAIHQQglJSChCEERApChFOII0PYiFYuGIgp+KejxR8YgKilQR4SiC2ABFQVBBDBhIqIEUOqRBei/v90fIkt1smdnMZifL/bsuLt2d2Zl3nszOPvO20QghBIiIiIhUxMHWBSAiIiLSxwSFiIiIVIcJChEREakOExQiIiJSHSYoREREpDpMUIiIiEh1mKAQERGR6jBBISIiItVhgkJERESqIztBSUtLw7x589C7d28MGDAA0dHRKCkpMbjuvn378MADDyA8PByjR4/Gnj17dJZHRkaic+fOOv8KCgosOxIiIiKyG05yVhZCYN68efDy8sKmTZuQk5ODRYsWwcHBAS+88ILOugkJCZgzZw4WLFiAgQMHYv/+/Xj66aexdetWBAcHIy0tDXl5edi9ezdcXV21n2vcuLEyR0ZEREQNlqwEJSUlBXFxcThw4AD8/PwAAPPmzcPbb79dK0HZvn07+vbti2nTpgEA2rRpg19//RU//vgjgoODkZycDH9/f7Ru3bpOB3D9eh6UfJqQRgP4+noqvl17wziZxxhJwziZxxiZxxhJo4Y4VZfBHFkJir+/P9asWaNNTqrl5+fXWnfs2LEoKyur9X5eXh4AICkpCe3atZOze4OEgFWCbK3t2hvGyTzGSBrGyTzGyDzGSJqGECdZCYqXlxcGDBigfV1ZWYmNGzeib9++tdZt3769zuvExEQcPHgQDz30EAAgOTkZRUVFmDp1Ks6dO4cuXbpg0aJFspMWjUbW6pK3p/R27Q3jZB5jJA3jZB5jZB5jJI0a4iR13xohLM+h3n77bWzatAlbt25Fp06djK5348YNTJ48GX5+ftiwYQMcHBwwdepUpKam4rXXXoOHhwdWr16NY8eOYceOHfDw8LC0SERERGQHLE5QlixZgk8//RRLly7Fvffea3S9zMxMPPLIIygtLcUXX3wBHx8fAEBpaSnKysrg7u4OACgpKcHAgQPx0ksvYfTo0ZLLwT4otsE4mccYScM4mccYmccYSaOGOFmlD0q1N954A1988QWWLFliMjlJS0vTdpLdsGGDNjkBAGdnZzg7O2tfu7i4oFWrVkhLS5NVFvZBsS3GyTzGSBrGyTzGyDzGSJqGECfZ86AsX74cX375Jd577z2MHDnS6HqFhYWYOXMmHBwcsHHjRgQEBGiXCSEwZMgQbNu2TWf9CxcuICgoSG6RiIiIyM7IqkFJTk7Gxx9/jFmzZiEiIgIZGRnaZf7+/sjIyICnpydcXV2xcuVKXLx4EZ9//jkAaNd1dXWFp6cnBg0ahGXLlqFly5bw8fHBBx98gObNm2PgwIEKHh4RERE1RLISlD179qCiogIrVqzAihUrdJadOXMGUVFRiI6Oxrhx47Br1y4UFxdjwoQJOuuNHTsWb731Fp5//nk4OTlh/vz5yM/PR9++fbFq1So4OjrW/aiIiIioQavTKB41yMxUvpOsn5+n4tu1N4yTeYyRNIyTeYyReYyRNGqIU3UZzOHDAomIiEh1mKAQERGR6jBBISIiItVhgkJERESqwwRFT2puMT75LRl5xeW2LgoREdFty6KZZO3ZjE1xyCwoRWxnf/xnVBdbF4eIiOi2xBoUPZkFpQCAQxeybFwSIiKi2xcTFGP4yG4iIiKbYYJCREREqsMEhYiIiFSHCYoRGrbxEBER2QwTFCOYnhAREdkOExQiIiJSHSYoREREpDpMUIzQsI2HiIjIZpigGOHWyNHWRSAiIrptMUHRE9G6CQBgSGd/G5eEiIjo9sUERU8rbzcAgLsza1C
IiIhshQmKEcLWBSAiIrqNMUHRw76xREREtscEhYiIiFSHCQoRERGpDhMUPdXznwjBXihERES2wgSFiIiIVIcJChEREakOExQiIiJSHSYoREREpDpMUIiIiEh1mKAYwUE8REREtsMERY9Gw7lkiYiIbI0JChEREakOExQiIiJSHVkJSlpaGubNm4fevXtjwIABiI6ORklJicF1T506hQkTJiA0NBTjx4/HiRMndJZv374dQ4YMQWhoKGbPno0bN25YfhQKqm7gYRcUIiIi25GcoAghMG/ePBQVFWHTpk1YunQp9u7di/fff7/WuoWFhZg1axYiIyOxbds2hIeH4/HHH0dhYSEA4NixY3jppZcwZ84cbN68Gbm5uVi4cKFiB0VEREQNm+QEJSUlBXFxcYiOjkbHjh0RGRmJefPmYfv27bXW3blzJ1xcXLBgwQK0b98eL730Etzd3fHTTz8BADZu3Ij77rsPY8aMQXBwMN555x389ttvuHTpknJHRkRERA2W5ATF398fa9asgZ+fn877+fn5tdaNj49HRESEdkSMRqNBz549ERcXp10eGRmpXb9FixYIDAxEfHy8JcdgFXxYIBERke04SV3Ry8sLAwYM0L6urKzExo0b0bdv31rrZmRkoEOHDjrv+fr6IjExEQCQnp6OZs2a1Vqempoqq/DAracPK6V6exqN8tu2JzXjRIYxRtIwTuYxRuYxRtKoIU5S9y05QdG3ZMkSnDp1Clu3bq21rKioCM7OzjrvOTs7o7S0FABQXFxscrkcvr6esj9jiqtrIwCAm5sL/PyU3bY9Ujr+9ogxkoZxMo8xMo8xkqYhxMmiBGXJkiX47LPPsHTpUnTq1KnWchcXl1rJRmlpKVxdXU0ud3Nzk12W69fzFJ31tbi4DABQWFiCzMw85TZsZzSaqhNc6fjbE8ZIGsbJPMbIPMZIGjXEqboM5shOUN544w188cUXWLJkCe69916D6wQEBCAzM1PnvczMTG2zjrHl/v7+cosDIZSdlr5mzRNPcvOUjr89YoykYZzMY4zMY4ykaQhxkjUPyvLly/Hll1/ivffew8iRI42uFxoaiqNHj2o7mgohcOTIEYSGhmqXx8bGate/du0arl27pl1OREREtzfJCUpycjI+/vhjPPbYY4iIiEBGRob2H1DVMba4uBgAMHz4cOTm5uLNN99EUlIS3nzzTRQVFeG+++4DAEyaNAnfffcdtmzZgoSEBCxYsACDBg1C69atrXCIRERE1NBITlD27NmDiooKrFixAlFRUTr/ACAqKgo7d+4EAHh4eGDlypWIjY3FuHHjEB8fj1WrVqFx48YAgPDwcLz++uv46KOPMGnSJDRp0gTR0dFWODxLVDXyqL3qi4iIyJ5pRAOf8CMzU9mOPu/sScKWuKuY2fcOPN6/rXIbtjMaDeDn56l4/O0JYyQN42QeY2QeYySNGuJUXQZz+LBAIiIiUh0mKHqqJ5BhAk5ERGQ7TFCIiIhIdZigEBERkeowQTFCsJGHiIjIZpig6OFzpoiIiGyPCQoRERGpDhMUIiIiUh0mKHq0w4zZBYWIiMhmmKAQERGR6jBBISIiItVhgqJHw3E8RERENscEhYiIiFSHCQoRERGpDhMUIziKh4iIyHaYoOhjFxQiIiKbY4JCREREqsMEhYiIiFSHCYqe6hYedkEhIiKyHSYoREREpDpMUIiIiEh1mKDo0TbxcJwxERGRzTBBISIiItVhgkJERESqwwRFj0bDmdqIiIhsjQmKEeyBQkREZDtMUIiIiEh1mKAQERGR6jBB0aPtgcI2HiIiIpthgkJERESqwwSFiIiIVIcJir6bbTxs4SEiIrIdJ0s/WFpainHjxuHll19Gnz59ai2fOnUqDh06VOv9cePGITo6Gjk5Oejdu7fOMm9vb8TExFhaJCIiIrITFiUoJSUlmD9/PhITE42us2zZMpSVlWlfx8fH45lnnsHkyZMBAElJSfD29sb27du16zg4sEK
HiIiILEhQkpKSMH/+fLMP0/P29tb+f0VFBZYuXYqZM2ciJCQEAJCSkoJ27drB399fbhGsivPIEhER2Z7sKotDhw6hT58+2Lx5s+TPbNu2DTk5OXjssce07yUlJaFt27Zyd19v+DRjIiIi25Fdg1LdRCOVEAJr1qzBtGnT4O7urn0/OTkZ5eXlePDBB5GWlobIyEgsXLgQzZo1k7V9pR+do92eRvlt25Pq2DBGxjFG0jBO5jFG5jFG0qghTlL3bXEnWaliYmKQmpqKiRMn6ryfkpICHx8fLFy4EEIILF26FE888QS2bNkCR0dHydv39fVUtLyubs4AADdXZ/j5Kbtte6R0/O0RYyQN42QeY2QeYyRNQ4iT1ROUXbt24a677tLpkwIAO3bsgEajgaurKwDgww8/RFRUFOLj49GzZ0/J279+PQ9KtsYUF1V17C0qKkVmZp5yG7YzGk3VCa50/O0JYyQN42QeY2QeYySNGuJUXQZzrJ6g/PHHH5gzZ06t993c3HRe+/r6wtvbG2lpabK2LwSsEmQB62zX3lgr/vaEMZKGcTKPMTKPMZKmIcTJquN6b9y4gUuXLiEiIkLn/fz8fPTq1Qt//fWX9r20tDRkZWUhKCjImkUiIiKiBkDRBCUjIwPFxcXa14mJiXBxcUGrVq101vPw8EBERASio6Nx7NgxnDx5Es8++ywGDBiAzp07K1kk2ar77qg8sSQiIrJriiYoUVFR2Llzp/b19evX4eXlBY2BLrtvv/02unbtilmzZmHq1Klo2bIl3n33XSWLQ0RERA1UnfqgnDlzxuTrESNGYMSIEQY/26RJE0RHR9dl90RERGSnOLe8Ho6hJyIisj0mKMawEwoREZHNMEEhIiIi1WGCQkRERKrDBMUIwTYeIiIim2GCQkRERKrDBIWIiIhUhwmKnuphxmp/RgEREZE9Y4JCREREqsMEhYiIiFSHCYoeDTiVLBERka0xQTGCXVCIiIhshwkKERERqQ4TFD1s4CEiIrI9JihGcJgxERGR7TBBISIiItVhgkJERESqwwRFj0bbCYVtPERERLbCBIWIiIhUhwkKERERqQ4TFCIiIlIdJihGcJgxERGR7TBBISIiItVhgqJHo+FcskRERLbGBMUItvAQERHZDhMUIiIiUh0mKHrYwENERGR7TFCMYBMPERGR7TBBISIiItVhgkJERESqwwRFD0cZExER2R4TFD3XC8oAAH+du2HjkhAREd2+LE5QSktLMWrUKMTExBhd58knn0Tnzp11/u3du1e7fP369RgwYADCw8OxaNEiFBUVWVocxWyJuwoAuJpbYuOSEBER3b6cLPlQSUkJ5s+fj8TERJPrJScnY8mSJbjzzju17zVp0gQAsGvXLixfvhxLliyBr68vFi5ciCVLluCVV16xpEhERERkR2TXoCQlJWHixIm4ePGiyfVKS0tx+fJlhISEwN/fX/vP2dkZALBhwwZMnz4dd999N3r06IHXXnsNX3/9tSpqUYiIiMi2ZCcohw4dQp8+fbB582aT66WkpECj0aB169a1llVUVOD48eOIjIzUvhcWFoaysjIkJCTILRIRERHZGdlNPJMnT5a0XkpKCjw8PLBgwQIcOnQIzZs3x9y5czFw4EDk5uaipKQEzZo1u1UQJyd4e3sjNTVVVnmsOeqGI3qMq44NY2QcYyQN42QeY2QeYySNGuIkdd8W9UGRIiUlBcXFxYiKisKsWbPwyy+/4Mknn8TmzZvh5+cHANrmnmrOzs4oLS2VtR9fX0/FyqzPz89627YX1oy/vWCMpGGczGOMzGOMpGkIcbJagvLUU09h6tSp2k6xwcHBOHnyJL766is8++yzAFArGSktLYWbm5us/Vy/ngdhpXnpMzPzrLNhO6DRVJ3g1ox/Q8cYScM4mccYmccYSaOGOFWXwRyrJSgODg7a5KRaUFAQkpKS4O3tDRcXF2RmZqJ9+/YAgPLycmRnZ8Pf31/WfoSA1YLMk9w8a8b
fXjBG0jBO5jFG5jFG0jSEOFltorYXX3wRCxcu1HkvISEBQUFBcHBwQEhICGJjY7XL4uLi4OTkhODgYGsViYiIiBoIRROUjIwMFBcXAwAGDx6MH374Ad9++y0uXLiA5cuXIzY2FlOmTAFQ1dl27dq12L17N44dO4bFixdj4sSJspt4iIiIyP4o2sQTFRWF6OhojBs3DsOGDcOrr76KFStW4OrVq+jYsSPWrFmDVq1aAQBGjhyJK1eu4JVXXkFpaSmGDRuG559/XsniEBERUQOlEULtrVCmZWYq29Gn139/1/7/4fl3KbdhO6PRVI1yUjr+9oQxkoZxMo8xMo8xkkYNcaougzl8WCARERGpDhMUIiIiUh0mKERERKQ6TFCIiIhIdZigEBERkeowQSEiIiLVYYJCREREqsMEhYiIiFSHCQoRERGpDhMUIiIiUh0mKERERKQ6TFCIiIhIdZigEBERkeowQSEiIiLVYYJCREREqsMEhYiIiFSHCQoRERGpDhMUIiIiUh0mKERERKQ6TFCIiIhIdZigEBERkeowQSEiIiLVYYJCREREqsMEhYiIiFSHCQoRERGpDhMUIiIiUh0mKERERKQ6TFCIiIhIdZigEBERkeowQSEiIiLVYYJCREREqsMEhYiIiFTH4gSltLQUo0aNQkxMjNF19u3bhwceeADh4eEYPXo09uzZo7M8MjISnTt31vlXUFBgaZGIiIjITjhZ8qGSkhLMnz8fiYmJRtdJSEjAnDlzsGDBAgwcOBD79+/H008/ja1btyI4OBhpaWnIy8vD7t274erqqv1c48aNLSkSERER2RHZCUpSUhLmz58PIYTJ9bZv346+ffti2rRpAIA2bdrg119/xY8//ojg4GAkJyfD398frVu3tqzkREREZLdkJyiHDh1Cnz598OyzzyIsLMzoemPHjkVZWVmt9/Py8gBUJTrt2rWTu3siIiK6DchOUCZPnixpvfbt2+u8TkxMxMGDB/HQQw8BAJKTk1FUVISpU6fi3Llz6NKlCxYtWiQ7adFoZK1u9W1XCoEr2cVo3dRN+QKpSHVsrBn/ho4xkoZxMo8xMo8xkkYNcZK6b4v6oMh148YNzJ07Fz179sQ999wDAEhJSUFOTg6ee+45eHh4YPXq1ZgxYwZ27NgBDw8Pydv29fW0VrHh5yd/289tjsO2o1fwxpjumNq3jRVKpS7WjL+9YIykYZzMY4zMY4ykaQhxsnqCkpmZiUceeQRCCHz44YdwcKgaOLR27VqUlZXB3d0dAPDuu+9i4MCB2Lt3L0aPHi15+9ev58FMdxiLZWbmyf7MtqNXAAAf7j6L+zr4KF0k1dBoqk5wa8a/oWOMpGGczGOMzGOMpFFDnKrLYI5VE5S0tDRtJ9kNGzbAx+fWD7azszOcnZ21r11cXNCqVSukpaXJ2ocQsFqQ67rd2+FLYs342wvGSBrGyTzGyDzGSJqGECerTdRWWFiImTNnwsHBARs3bkRAQIB2mRACQ4YMwbZt23TWv3DhAoKCgqxVJCIiImogFK1BycjIgKenJ1xdXbFy5UpcvHgRn3/+uXYZALi6usLT0xODBg3CsmXL0LJlS/j4+OCDDz5A8+bNMXDgQCWLRERERA2QoglKVFQUoqOjMW7cOOzatQvFxcWYMGGCzjpjx47FW2+9heeffx5OTk6YP38+8vPz0bdvX6xatQqOjo5KFslm1F51RkREpGZ1SlDOnDlj9PVPP/1k8rMuLi548cUX8eKLL9alCERERGSH+LBAIiIiUh0mKERERKQ6TFCIiIhIdZigEBERkeowQbESDuIhIiKyHBMUIiIiUh0mKERERKQ6TFBs7GRqHiZ9FouY81m2LgoREZFqMEGxsblbjyMpswBzvj5u66IQERGpBhMUKxES57rPLym3ckmIiIgaHiYoRHbgm2PXMH7dYVzKKrJ1UYiIFMEEpR4Ul1XgzZ/PYn/KdVsXhezUf35JxMWsIrzza5Kti0JEpAgmKPX
giyNX8O3xVDz7zUlbF4XsXFlFpa2LQESkCCYo9SAtr8ToMo2mHgtCRETUQDBBIbIjEvtmExGpHhMUC2QWlOL5705y7hIiIjKqrKISSZkFkkd1ki4mKBb4769J2Jd0XZG5S9jCQ0Rkn5795gQmfRaLH06m2booJn155AruXx2Dy9nqGgXIBEXP6G4B2v8vKa/Eaz+dwa9nM3TWScsrre9iERFRAxNzIRsAsOXoVdsWxIz/7k3GtdwSfPBbiq2LooMJih4/D2ft/3919Aq2n0zDCz+ctt4O2UuWFMSKZCKyVHmluq4gTFBMyCxgTQkREZEtMEEhIiIi1WGCYkJZhbqqu4iIiG4XTFD01OwSsiXO+h2b2AOFiIioNiYoejQSUgYl+7Wyj6x92XkqDWv/umDrYhARNXhOti6A2iiVLyg9L8/PCenw93BBeKsmym6YFPXqj2cAAP3a+aBrc08bl4aIqOFigtIAJGUW4KUdCQCAw/PvsnFpSIqcojJbF4GIqEFjE08DcC2nWLFtZReVIf5KDqdetlf8u5KeD35Lwewtx1Q3x8XtZtnv5zB5QyyKyipsXRSj1Hb5YIJiY/XdBWX8usOY+WU89qfcqOc9E1F9yCwoxXt7k3HueiEAYOPfl3HoYjYOXeCzw2xpw+FLSMwowA8nUm1dlAaDCYoee++0mltcDgD4I+W6yfXOXS/EpzEXUazibJ/kuXCjENM2HsGviZmKbK+orAIVvCtXnVd2JuCLI1cwdeMRnfdZg2I7NSPP2SukY4KicuUVlXju25P1vt+J6//Gx/vPY8WB8/W+b7KcqWvfaz+dxem0fLzw/ak67+e3pOu468MDmLQhts7bImWdSs0DUPUsMbW7nF2EjPwSWxeDVIoJih6NhCoUKZUsUpNkc/s7cM621bLHr+bZdP/Vissq8MOJVFzn4wcsll9Srsh2jlzOxr++q0qaq5sRSLrsojLEXWY/sLzicoxdexgjVsbI/mzMhSxcylLXk3ftSXZRGY5czrb5OcoEReUqKtV/F1Qflv1+Dq/vOovHvoyzdVEkseefntiLObYuQoM2ft1hPLY5Hr8nm25mNaesomFfG67lWtb5/+S1XMzZehzj1h1WuEQkbl65Hlx3GI9vPoa9CjUHW8riBKW0tBSjRo1CTIzx7PfUqVOYMGECQkNDMX78eJw4cUJn+fbt2zFkyBCEhoZi9uzZuHFD3R03hRBWvyhUVAr8347T+F/sZavup6HZl1T1RbmUXXVRO5mah1lfxuHktVyr7lcIgeTMApTqVZfnFStTG6E0O+9CZReq+4HpJygnruXiqsQRe5ezi9Dv/f148+ezipdP7U6m5tu6CHYvp/octfFgCosSlJKSEjz33HNITEw0uk5hYSFmzZqFyMhIbNu2DeHh4Xj88cdRWFhVJXzs2DG89NJLmDNnDjZv3ozc3FwsXLjQsqNQkKkL/Nyvj2PQsgPIk1BVnl1UZlGntD+Sr2NXQgaW7kuR/VmlNIS265lfxOHolVw8+kWcVfez+2wmHvosFk9uOaZ9b2vcVQz+6E9s/Ft9SaQ919zYs/M3CvHI/+LwwJpDktavPve+PS59REh9Jq9CCMWaFOn2JTtBSUpKwsSJE3Hx4kWT6+3cuRMuLi5YsGAB2rdvj5deegnu7u746aefAAAbN27EfffdhzFjxiA4OBjvvPMOfvvtN1y6dMmyI6kHMReyUVohkCKx3f3OpX/gdJq8PhyFKhg1s/NUWo1X6vzJq07+KgXwzLYT+PVshlX2s+3YNQDAsau3amre3pMEoGp+CVKvP5Kv653L6nU23b5qBV798QzuXv4n4q/Uf3NgXnE5fk3MrFXrSQ2P7ATl0KFD6NOnDzZv3mxyvfj4eERERGg7gWo0GvTs2RNxcXHa5ZGRkdr1W7RogcDAQMTHx8stkqKUHmb87q/Jym6wHtj6i30xqwjPfnMCx69Ka745cO4GXvjhtJVLZR0v/nAKc78+Xj+d0Qyc21d
zirHiwHm77Hz83Lcn8eqPZ5BqYV8Ha5Py3C+LtquCdr4fT6cDADYcrv9axnlfH8cL35/Csj/O1fu+1epiVhEKS83f/FZUCiz7XT03XrKnup88ebKk9TIyMtChQwed93x9fbXNQunp6WjWrFmt5amp8iaxqa8vo7H9SNl/zVX01ze5TFN7lE+9PKjQQDmMfdYa8X/+u5NIuV6I/Sk38Pe/ak/tX5e/hVym/j4mywLzMSotr8Ses1V9a67mFqOVt5vF5ZRaJv11HvsyDun5pTh2JQef/CNU4g6k7U8qa55LQFV7egs1PsJKU+OYaxy73O+bqbhp9LZri+RFyvdVTrlqrltQWo7c4nIENnHVvn/8WlWt9Y+n0vCvwe1lltb6al4b6sOZ9Hw8vOEIvN0aYffsO02eSzEXshFzIVv72lpllbpNqz2Lp6ioCM7OzjrvOTs7o7S06k6tuLjY5HKpfH2VfSCbe2MXg+/7+Rnez+bjaZjSpw2aNG5kcHmjRo5wdbu1TH87NRMQPz9PeHrm6Lz20usQZqwccrm6Ohvdlof7rRg4NXI0uU+l4w8A13JvzYtQvW8HB4da7+mrdG6EZl6uipalUSNHk/s1VhZPLzdtbIzFqKT81h1N06bu8PN1r0tRAQBOTsb/Xo6OtWOYnl/1fTtyOUfyudW4se73Vqlz0hrnEgB4ezdWrIxKcnVppC2Xp+etpmAp3zdXV+nXlGpeXm611s0pKkOTm9enH49fwzdHr2DJhFDte3Xl7Oxk8Hg+2puEb45eMXoMpnh43Lo+3b38TwgB/LHgbrT2aayznsZBo5q/u5PTre+eu7tLvZbrf/FVN/3ZRWU6+5XyfXOpcY7agtUSFBcXl1rJRmlpKVxdXU0ud3OTdxd5/Xqeos8PKCwyPGlQZqbhviRLdp3B3ymZWPJAN4PLy8oqUFzjwXG1tyN0luXlFeu8ztWrnjZWDrmKi0uNbiu/4FYMjl7Mxjcx5zGgva/OOhpN1QmudPwB6DR3VJexssZwa2Pl7v2fPXh/XHdEBfkoVpbyGn2CDO3XWFlyc4tw+nwmNh69hge6+iPIQPJRsyktK6sA7kJa09rVnGLcKCxF9xZetZaVlVcYLVNFhekYSj23Cgt1v7d1PSeteS4BQHZ2ITJd1DejQnFJmTZ2ubm35vQwFE/9GJUU615ThBDaxKSysvb3BwDycot0Xv+ckI5F2xPwSJ/WmD2gHZ7cVDXzrO8PJzFfoZqH0tJyg8ezZNcZnddyzqGCGhO7VZ8vvx6/ilHdA3R+dEWlUOx6WVcnazRXFxSU1Gu5CgtvxSszM0/W962kxjmqpOoymGO1b21AQAAyM3XHUGdmZmqbdYwt9/f3l7UfIZT9Z6yvu6k/5KEL2TU+r0sD3W6mtfdXe5nua2FwHcuPz/y29CPw7DcnjW5L6fgbLaPee8Z8fviSsuUx8fcx9fcUAnjh+9PYFHMRUzYcMXusUspSViGw+0wG7l99CDM2xeHCjaLaZVLwWIyW2cLPCQEUlFTgt6TrKC6rtPq5ZC4mZRUCPydkIDW3ROe98gphlfPaVMxrxtTU5wz9LY9dycWQjw7iu+OpBrdfIww623tnT1X/uE9jLumsd6Ow1CrHaGy57L997Y/XOlaLtmvlv7Wxv0N9lsPguWQsoFaOoRRWS1BCQ0Nx9OhRVP/ACiFw5MgRhIaGapfHxt6aJvvatWu4du2adrmtGGsaM/cESjmdHCsqhXZ9a3WUI9s7c3NkRqmRh29IP2OqbI27ihdrdAZOypA38qO+z7SKSoFfEzORnnfrDm7h9lOY/+1JvPtrUr2VY+lvyQZHyWyNu4pF209j/M0Jv8orBUavisGETw/L+j7b2os/nEJOcTne2FU1J4oaOsmSdD8npFt13quaZ3JDe7aaoglKRkYGiourmiSGDx+O3NxcvPnmm0hKSsKbb76JoqIi3HfffQCASZMm4bvvvsOWLVuQkJCABQsWYNCgQWjdurWSRVLMXR8
eMLqsoLQCwz/5C/vNPIAPqDpB7l8do/NDQ+qlpmv9AQmTJln6s2qN4/zuRCpe+P6Uzoyff958dMO3x1OxPuaioonAoQtZeG9vcq1RaLGXcvDw50dqrX/wfFU8S8or8dbuRAz56E9kFpTiUnYxyhrQE92klnT3mQy8sydJkQc8FpdVqDKJK28As+vqx+2lHQlYui8FKdcLTH4uu0ZXAUs1tNl3FU1QoqKisHPnTgCAh4cHVq5cidjYWIwbNw7x8fFYtWoVGjeu6sgUHh6O119/HR999BEmTZqEJk2aIDo6WsniWMTS6ZdvFJbh2W9qP9RP/25mf8oNpOeXSnqi7OYjV7DmL9Pzzej74sgVPLPtBM5fL1R8oqSjl3OQIHNeF1KGEAIFEoYJKrm/uv4A/XW+KhkxNvHfR/vP48hl5ebJmL31OL44cgVfHrlifmU9X8dfq9f42sKOU+nYEncVP59Jr9N2zl0vxIAPD2hrbNQit7gMYa//YutiWCynyPD1urisAmPXHsLQjw9iwyF584RduFGIKzVmJ87Ib1jTCdSpk+yZM2dMvu7Rowe++eYbo58fN24cxo0bV5ciKE5NNwXv7q09h0rK9QKDHS6rvXfzMwfO3YCzowYHnhkga5/ZRWUGT+IbhaWYtblqjprD82sP/S0tr0QjR42khy1aQu7f5WpOMRw0QHOFR/XYyvPfncJxCdP6m4q+1D/N4h8TsONUOrzdGuHrRyPh5arMiA5DMq1wwbwq8SbD1t/1746nYnxoC3QJ8JRd81XzUQtyv3HfHU/FfV0CUFhq+AfRXFw2HK76kfzhZBpeGd5Z5t7rztjx/pyQYZez1y7/4xwu33zEx7I/zmFab2mtDEVlFXjw07+tWTSrU1/XdjsTdyUXvyXV7aFgNf1jfSySMk1XBVYz1vfBmOKyCgz9+CDWG8jSTf2QXMwqQv8P9uN1GXdU5ZV1v0M3RIOqL+YDaw5h9OpDda7OPmrhTJhKH9lvdXywnBw7TlXdYWcXlWFb/LV62+/taNrGoxZ97uczls+cHHup6pyWe32QorxS4EqOup8yLITApawiVNo6Q9UjjFw1LK1lvFHYsGpLDGGCoscap2xmjVk69Z/Po39Xa+wkrSn2YjYKSysQeykbP51OV+yx4zXnH5Fj083ngmw/KW1a8bzicty74iAWbbdOP5yas6LWNUEx1xfh4/3nZD9zKSO/xGqPMje1RZMdso1Ur9SsEcsuKkNmvvRzxBrHl1NUZna7KvvdqVe27IPx98Vs9H//D4xZI7+fQ35JueK1HzlGHui5+ehVjFt3GNG/GH+WXENh6gG29vA9YIJSz749rswd6eytx/DEV8fw8s4Em3V8OnU1F//edRYZMn60AOCXM+nILS7H7rO2fZS3Ej6NuYTvT0ib/Ti7sAxZhaUYsTIGj28+pr2TBRrGyIuhHx/EfStjkG+kacASco7718RMDPn4IN5X6BlI9XX9LlAwXuaMWn3I6iMD9bdeUl6JI5ezsXD7aVh6P3D38j9x9/I/6yXB+uTAeQDyHrSoVk9vO4GByw4gR4EOtIbY+rLEBEWPtf8gZxR6KNiJa7bvrDriwz/w7fFU/GHFR3KfSc9Hel6JSh9ZWOVqjvk+D+WVAkNXHMSwFX9p34u9lK39/4PnsnReq4Wh2gopx6vvvwb6U8n1wb6qbfwv1nQnWCm1kDdXNGpPYgaOKtCB99fETAxa9idW/3nBTFGUOcOlPlNJyojDannF5Zjy+REsNDLy8P92nMbjm48pMspEv9Zj+8lUPPlVvNV+gE9cy8XE9X/jnT1JJvfx0+l0mzz4EKhdE1IpBL47fg3JmQU4eD4LZRUC+5Ia/s2eIUxQ9Fj7hzC/pO4jBax1j/FzgrTe/XW9EBmL8ca/L2PpPt0fsimfH8HIVTHSNy6zKuJMej6+OnrVKu3RNUtiaP6Bz/++NffB23uS8MRXx3TKscuCi6Jaq3XNjay5dKMQGw5d0lbzV1QKHL2
cY7N5G17ZeQazNsfXef8vfH8KALDqoOkEpS4s6ZhuaMRhtUohsGj7ae0w8MEf/Ykz6fnYfTYDlULgJ73rxD4F+9jpe+2ns/j7Ug5W14ifkrWNj/wvDueuF2JL3FXM/fq4wXVOpebh5Z0JmPml7oNsz10vxFu7ExV7GKXU7+76mEv498+JeOizWJPrGdueWh+eaYjVprpvqGw5tv/TmIuIuZBldr33FLgbNUTqkOahHx/E5hkRiu23oLQc359IwwdKVN2b+PslZxYgsIkr3Go8X2fKzfkx3J0dMbJbQN33X0dCANAAZ9Pz8X87EwAYHjVVXz7efx4PhDRHU73n7yjt/uX7kVVYhuTMAiy+LxjrYi5i1Z8X0LdtU8zo3RqHLmbL7utjjpRai6KyCrjWOF8agrrWxhw8n4WC0gr8ciajVu2oBub7ZVlDfYzOOZ1muHb7cvatPn4/nU7Hur8u4p37u2L6piMoKqvE6bR8fPZwuNXLV23FzSYqc/42UiObnFmI7kHKlceamKDoqe+vXs27gY/3n6/nvUunf9eyQ2KHWENq5hDbjl1DzPkss/PCSK26rrWvm/89fDELT205jkAvF3z3WB+Ullfip9O37gTPZuRjJAJulk+gQgBODsrdquUUy6t1umxBM0pdGDtSgaq72PfHdTf62ZjzWejTtmmd9p9VWBWfwxezAVTN8gpUzaVSPZ+KVGqtRZLrw99SEHclF59M7AGXRsYru/X/dnU9/pq1Rseumh/aXh8KSivMTrFQH16+edPw2q4zKCqrqstWqtleaetj5M2hpUZs4rExJWZ1BKo6qj33zYla7/f67+84lVrVX0XpO1AlRP+SKGnSurr65eawzKs3Ryqti7mIN342PCx60fYEDPv4IHJlJBXbT6bqdpbVC7XskQ3WaHKyMN/S/5HSL9ocI1XjaqfWRObPczdwKjUPn/99Gcev5eKuZQcw5fMjRkdr6LPHCef2JV3HP9bHqiZhKi6rv9FSltaIqfT0loUJig3FXMiyuNe7vu9MdFadvqlqroUPazSh1PXiLGXSMGPU8MUxdVe++2wG8krK8XOCtLkmissq8NpPZ3Vm1hQQFiUEl7LVPYcEAGQYqM26rHC563J+So27UudhdmEZJn76t8V3rEkZBTrH+/S2E9rvLFB1E5OQlo8YK3ZGbyh+T75utA+czUbCKZTp7jyVhsMX5dUWmiqCWhNwOZig2NCcrcrdeUoZyviFBVOAG3P0sjruZGrRaMz2ZSkqq8DJVOVGQZUqODRy5hdxim1LnyWjb6rVnLn0lIHYjV17WJGasPT8Uuw5m4EsMx2xYy5kYfKGWINlkXJh/uv8DW1zUl19eugizt0oxEcWNtFO2mC6s2M1Y3fSSvwwG3qYolp9cqBuHY6VrmFSqkvO9yfS8NSWuv0mCCFw/Gqu1UY91TcmKHp6tmpSr/uzVdJffVFb/ecFvLwzQdHREkIIxF/JMdpEUtfM3txkYYZGFdTc58o6XuBMbbuuqodZWuPmx9SF2Vzz39fxV81u/6ujyiTAUh6kOWfrcSRmFGCehc1Lc7+u3RxqyDfHUnVGt63+8wL+b8dpnc705TZ+sKClEyzWZOhhikqSM/iguKwCC384hZ2nLO/nZszJOtT81qTUsHAlXc0pRu/3/sCjX8RhzNpDSM0zfF4s3ZeMGInDzG09PxM7yepp7uVi6yLUm4S0PO3wx8YKjlTYl3QdC74/BV93Z/z0RF+dZf+LvYyl++o2Wue+ldKHHRu6MJ4y8sDDorIKrFBxR2V9db1EvmmkD44h1uq+VNdnN+UZGN2hZFGrR0wMC24G4NZw4QlhgQht2USxicXOpkt7fIVccvpRVVMy39KgqkP9+7+l4P2x3dCthZfJ9V/84RT83F2w+2ymwYkc5d4MrNh/Dr7uzpgY3hIAcMrISJ36JPecT84slLTest9vXVdNTWdx/kYR/rHqL/z9L9uNDpSKCYqN1VeG+p2BGWyvF966eO0
wc7ciZ3bKvTer+g2NvKlrcmJNq/+8YFEzmDV+u2teiPXnhqm2+cgV3NPJ3+LSqGEmTSEErtXziCUllJRXYs3BC1j55wV0a+5Z5+1tir1sfiULjF51yCrblWPxT1UPkZ3z9XF0b+6FocH65+wtcVcsr+EwdI1aF1P1XLHqBEUNrDWVhY0r8qyCCcpt4t8/N/znTkhharisOedvSLtTsTYhBP48d6tDpLGZU4/W4WJuSKUQcLBixmysmWj0ast/RG1ZA73y5uywSvZnUlqhjSa6MyS/pAJ/XcjCXxLmerKEEs0uZ9Pz0amZh8Wfl/sdMjaKM7uwDDnFZWjj01jSduwwNwHABOW2VSmAZ7ZJa4cH1Nnmam11OeK63CQdPJ+FH+owz4wh+h1YDVX937/6EMJamq6Cl0M/Bm/tTqq1Tl2beAzvWPlNUpV4hZNitXn48yMWT4x4Jj0fj2+Ox6x+bTA5opXR9aonnUvOLDA6G+zQFQcBAPeaqG2qK1E9K6SKMUGxMVsNBUuQ2RYrdcitIedvFKK4rALBAXWvCldEHYJuaOp5Y1/xIgvnSpDzDBghhNkf+bjLOdop16uNX/d3rfXS8kqwy8jf2dad5Uwpq6hEI8f67e9/3QaPsl9/4DyCmrqi1IZ1+Y9tjje/Ug31WdLLWeabCpW63hrq4/Hmz2dRUFqBpftSTCYoqw9exOCO/li0w3xncGPfRyX8eDodnx26hMf7tTW6jrUfPGkOExQ99Z0wFJfb5vHocieIW3/oksX7mvBp1Y/hrif7mlmz7oweVY0FxppGpERE/3kcxj4n9bEBdbVoewKeHthO+/r4tTztjJuVQuDTmIsGk1E5z1NSc3JSIYB+7+/H2B7Nrb6v7Bp9tl7ZecbgOkM++hNujRzx+dSe8HZrpOj+9ySkY4+iW6y7S1n1O2/P54eNX4dqPtuqLpIyC9DBz/SMtWl6I2QMDXc3ZcPhS/Uyfb8p1efwAr2bFzXhMGMCUNXpTynGftD+sV7afA/2QO4Fy1K7z2bU6sNRHefdZzLwyYELdX6YmxDSEveM/FKd2U6lPFdKqdznm2PW7/BbXe1uSk5xOVLzSjDVwLDdg+dv4IYNal6sadw60zMkK90h1NKtySnHDyfkn0vTNx212c2mNRWX27YPExMUPWq+W2yIas6vocTj2M0x9OfLKSpDrgV3K0rPjiqVUpf0K/U8OuZiVhEm15h0zB4v2FKl5pXUag6c9/UJPGigac2eGZvdur4peQNmzLnr0jvZ/3g6HRn5yiWr1kp8bf3YBCYoeuxhemApkjKtM+8CoJsk7DljvTZUKcorBYZ8fFA79NkYQ22txkbP6Lth4YMMGxKpifv5G0W4nF2Et3bbbtSYNTp0F1kwGuaMgblNDM3bQnW31kyT6pfa6QPs8wJf19l11Yp9UMiqDpyz7R1UbrG0H4S6/KjVZe4GQ5RIktccvIAyhWZXkxubsWtlPhixnry1OxEjuwYgJFD+SKW7Pjwg+zNLfq09aoms46KZvjDGZlVVo9d/Mty/6XbEBEVPR3/bPs7bUnkSf4jr24+n082vZEX7kqQ9HyYls9Bq8zPIlWBkpls5qufoUIKpWSkbkq/jr+Hr+GsWDyOlhq20vLJemnrqSukpBhoyJih6XJ2Um/K9PinVg93eSJ25Vi3JCQAcUughdkr62EqPAHjPyCy5dXG7NNOSdOWVAsNWHJTUp+J/sVfweL+2KCwtR4mCDwIl+ZigkF1R6gm1cljSP4GqKNlRUKpVf56v932SbRWWVsjq8Dlw2a0mvReHdLBGkRqEZCv2VZSCnWRJUbnFZSgpv71uYc09CZjUZfXB+pmjhtTj2FXL+4kp8bTohsrWx84aFFLUPR+ZnyuioaqPYdJEpDz9idXkSMqQX4uQXVgG78bKTtR3O2INCpFEQz+23+SLiAyzZCTip4dYS6cEJih6OFEbERHVRWl5Ja7k2GaiR3vCBIWI7IqAZdXyREopKK3AmDX1Ox+
QuckoGyImKERkdyZtuH2e+0Tqk55/+3asVRITFCKyKxxTRWQfmKAQkV3Zn1y3pzcT1VXspRzzK5FZsocZl5SU4LXXXsPPP/8MV1dXPProo3j00UdrrTd16lQcOnSo1vvjxo1DdHQ0cnJy0Lt3b51l3t7eiImJkVskIiKtHJU+9oGI5JGdoLzzzjs4ceIEPvvsM1y9ehUvvPACAgMDMXz4cJ31li1bhrKyW/NGxMfH45lnnsHkyZMBAElJSfD29sb27du16zg4sEKHiIiIZCYohYWF2LJlC1avXo1u3bqhW7duSExMxKZNm2olKN7e3tr/r6iowNKlSzFz5kyEhIQAAFJSUtCuXTv4+/vX/SiIiIjIrsiqskhISEB5eTnCw8O170VERCA+Ph6VlcYfqrRt2zbk5OTgscce076XlJSEtm3byi8xERER2T1ZNSgZGRlo2rQpnJ2dte/5+fmhpKQE2dnZ8PHxqfUZIQTWrFmDadOmwd3dXft+cnIyysvL8eCDDyItLQ2RkZFYuHAhmjVrJusAlJ5YjRO1ERERVbHGb6LUbcpKUIqKinSSEwDa16Wlhp9KGhMTg9TUVEycOFHn/ZSUFPj4+GDhwoUQQmDp0qV44oknsGXLFjg6Okouk6+vp5xDMKuYT6YlIiICAPj5KfsbK4esBMXFxaVWIlL92tXV1eBndu3ahbvuukunTwoA7NixAxqNRvu5Dz/8EFFRUYiPj0fPnj0ll+n69TwIBSc+KK0w3lRFRER0O8nMzFN8mxqNtMoFWX1QAgICkJWVhfLyW8P4MjIy4OrqCi8vL4Of+eOPP3DPPffUet/NzU0nqfH19YW3tzfS0tLkFAlCKP+PiIiIgCvZxTb7nZWVoHTp0gVOTk6Ii4vTvhcbG4uQkBCDQ4Rv3LiBS5cuISIiQuf9/Px89OrVC3/99Zf2vbS0NGRlZSEoKEhOkYiIiMhK9pzNsNm+ZSUobm5uGDNmDBYvXoxjx45h9+7dWLduHaZNmwagqjaluLhYu35iYiJcXFzQqlUrne14eHggIiIC0dHROHbsGE6ePIlnn30WAwYMQOfOnRU4LMuxjywREVGVg+ezbLZv2TOjLVy4EN26dcP06dPx2muvYe7cuRg2bBgAICoqCjt37tSue/36dXh5eUFjoMvu22+/ja5du2LWrFmYOnUqWrZsiXfffbcOh0JERERKyi4qM7+SlWiEaNi9LjIzle0kW1ZRiX7v71dug0RERA1UR393/G9ahPkVZdBopI0O4tzyREREpDpMUIiIiEh1mKDo4UyyREREtscEhYiIiFSHCQoRERGpDhMUPQ17TBMREZF9YIJCREREqsMERQ87yRIREdkeExQiIiJSHSYoREREZJAt+2UyQSEiIiKDkjILbLZvJihERESkOkxQ9LCPLBERke0xQSEiIiLVYYJCREREqsMEhYiIiFSHCQoRERGpDhMUfZxKloiIyOaYoBAREZHqMEEhIiIi1WGCQkRERKrDBEUPe6AQERHZHhMUIiIiUh0mKERERKQ6TFCIiIhIdZig6OE0KERERLbHBEWPg0aDR/q3tXUxiIiIbmtMUAwY2jXA1kUgIiK6rTFBISIiItVhgmKAhrOhEBER2RQTFAMCvV1tXQQiIqLbGhMUA9r4utu6CERERLc12QlKSUkJFi1ahMjISERFRWHdunVG133yySfRuXNnnX979+7VLl+/fj0GDBiA8PBwLFq0CEVFRZYdBREREdkVJ7kfeOedd3DixAl89tlnuHr1Kl544QUEBgZi+PDhtdZNTk7GkiVLcOedd2rfa9KkCQBg165dWL58OZYsWQJfX18sXLgQS5YswSuvvFKHwyEiIiJ7IKsGpbCwEFu2bMFLL72Ebt26YejQoZg5cyY2bdpUa93S0lJcvnwZISEh8Pf31/5zdnYGAGzYsAHTp0/H3XffjR49euC1117D119/zVoUIiIikpegJCQkoLy8HOHh4dr3IiIiEB8fj8rKSp1
1U1JSoNFo0Lp161rbqaiowPHjxxEZGal9LywsDGVlZUhISJB7DERERGRnZCUoGRkZaNq0qbYWBAD8/PxQUlKC7OxsnXVTUlLg4eGBBQsWICoqCg8++CB+++03AEBubi5KSkrQrFkz7fpOTk7w9vZGamqqrAPQaJT/R0RERFVs9Tsrqw9KUVGRTnICQPu6tLRU5/2UlBQUFxcjKioKs2bNwi+//IInn3wSmzdvhp+fn85na25Lfzvm+Pp6ylqfiIiIpPPzs83vrKwExcXFpVYCUf3a1VV37pCnnnoKU6dO1XaKDQ4OxsmTJ/HVV1/h2Wef1flszW25ubnJOoDr1/MghKyPmKTRVCU979zfBQu+P63chomIiBqgzMw8RbdX/TtrjqwmnoCAAGRlZaG8vFz7XkZGBlxdXeHl5aW7YQcHbXJSLSgoCGlpafD29oaLiwsyMzO1y8rLy5GdnQ1/f385RYIQyv8DgMGd/LF3Tj9ZZSEiIrI31vqdNUdWgtKlSxc4OTkhLi5O+15sbCxCQkLg4KC7qRdffBELFy7UeS8hIQFBQUFwcHBASEgIYmNjtcvi4uLg5OSE4OBgOUWyKg8X2aOwiYiI7MaobrZ7eK6sBMXNzQ1jxozB4sWLcezYMezevRvr1q3DtGnTAFTVphQXFwMABg8ejB9++AHffvstLly4gOXLlyM2NhZTpkwBAEyePBlr167F7t27cezYMSxevBgTJ06U3cRjbX3aeNu6CERERDbh07iRzfYtu4pg4cKFWLx4MaZPnw4PDw/MnTsXw4YNAwBERUUhOjoa48aNw7Bhw/Dqq69ixYoVuHr1Kjp27Ig1a9agVatWAICRI0fiypUreOWVV1BaWophw4bh+eefV/boFBDg6WLrIhAREdmI7Ya2aoRQsotp/cvMVL6TrJ+fp3a7b+w6g+9PpCm3AyIiogZiWq/WmHtXO0W3Wf07aw4fFmhGt+YcxkxERFTfmKCY8UBICywc0sHWxSAiIrqtMEExw9FBg3GhgWjqZruOQkRERLcbJihERESkOkxQJJrRp/ZDD6sF+Taux5IQERHVj3uD5U2eqiQmKBJN6tnS6LKPJ/Sox5IQERHVjyY27N7ABEUijYnHL7o4yQvjO/d3xQ+P9a5rkSz2qInaICIiIjVggqIADxcnHHpuAD6ZaLgm5Z37u2r/38XJAXd39ENzL1eD69aHJ6OUGdP+yqiuBt9/7b7OimyfiIhuX0xQZDDVFqfRaOBp5Nk9d3f0u7We4qWyHR93Z4Pvh7TwMvg+ERE1LLb8zWKCIsOQTrUTFGPNO01cneCgAe7vrvugJRMtRfXipaEdFduWu5GEzNbHSEREDR8f1yvDwA6+eOGeDugS4IEZ/4sDADg5GP41/vmpO1FZKeDkqJvAaGyYj66Y0AORd3grtr3Bwc0Mvh/YxHbNV0REZB9YgyKDRqPBg2GB6CahCcNBo6mVnFRt49b/T+/dGq5ODpgYFihp/3vn9ENHf3eT6xjqANutuSeaujVC9xZ1n7a/uvlmzUOhcDSSnDmwCkXrtfu72boIqhaiwDlJRNZjy8s5a1AU1M63MZq4OsFb4rCsOQPa4cn+bVEpBE6k5uFUap7J9T1cnAzWv0zr1RobDl9CMw9nPBnVDtN6t8aBlBv44WQa/Nyd8cq9nVAhjNf2mNMj0AvHruYCAD6eEILySgFPV546UowODcSr35+0dTFMWjcpDI9+EWeTfXfwd0e3Fp748shVm+yflDejd2usP3TJ6vt5Kqot2jR1AzQavPD9Kavvj+ofa1AU1MjRAT890RebZ0QaXUe/dsHRQYNGjg747OFwSaNfHghpUeu9AE8X7H7qTnw3s2rosruzE4YFN8Oy8SF4dXhnaDSaWslJ40aOUg4JANC1uSf6tm2KB7o3h2sjR3gY6XsCAC28XExuq7W3+eaftZPCJJfN3P5q6hHoBbdG9XvKG+tIXJ+Mdd6uFhJou07NlcL0HEP67uvSDF0CPIw
uH1yjQ/rtbt2kMJ0RhPVl9oB26FXHpmRnR/M3U4/0uQODO/mjpYkRka4yp4CoqzuautXr/uwdExSFOTk6GG36AExXl43oGoA/5vU3uf0Hw1rU+tI5O2rQxK2RwSYlueV47M47aq8LYNn4EPzfvZ1qLds8I0Ln9RgDCVRNbmYSo7dHd0EjvYuTqYuMECY3p6NPG28sHdtd0roz+9aOg75X7u2E6b2NzynTr21TAIY7Utf1wvn26C6S153aq1Wd9mVNLWX2V+rUzHhyAgAjuhruFyVFgKeLtkzm/j6W1kbWp5BALzR2vvV9q8/k3FQSKUV7P9NN2VLtevJOWeu/JeN7RdbHBKWO5M6yZ+6L59rI0WRHVgeNRueOt0egF+7rGmB0fblm9WuLqCAfnff0E4aa2vu5463RXTDvrnb4YFx3TDPxgy3F4E7+CG7mofNwxl/n9EN4qyYG1+9s5gerJg00iGjtjU8m9sC6SWE4+EwUXh1eO+kCgKm9zB/HiK4BmDPA+JwyY3pUJWvuzrWTsodk1BrU9MX0CHw5PQId/KUf94SwQLRs4orxobWTx/+MMn1BbuvjZvIY62pqpLLJUzNPwzVqQzqZr1lZ9Y9Q/CM8EB9NCMHqh0LRI9ALY3s0r70PD2eDCf4Ave+N2vRp09TWRZDMXA2MT2Np1125/ScaUoxuB0xQLLRsfHd0a+6J/z4grRPkpqk98UBIc/x7RLDZdcd0170oLhzaEb/XqFmpWWmwdlKY7Jlszan5nQ7ybWz2x/qeTv6Y2qs1+rXzMXtn2adNU7M/zhqNRueHs5Gjg9GqElO1VcZEtPZGSKAXnBwdMKqbbqzvaOqG3U/dicbOjtj9lOm7L3P7Fjf/UnJqecxp59NY9t2lh4sTvvlnL7w4pPYQc3NNfVse6YXpvVvjyf5tMf/u9gbXeX9cd0QF+eBfRpYbM++udnBydIBvPTSDjTOQnNX0QEhzBDZxxb8Gd0DLJm4IDvDE2klhiGztXWvdt0Ybbja5v3vtZMbmapx7gzv54f7uAXh5mOGkXE2czdQG1zwXm3lafv4YSlwnR1Rdn/q3k59wGroivD6iM775Zy/8Z1QXo5N5kmFMUCzUt60P1j8cjg5mRtVU69TMA/83rJPROzxjGjdyxLgeLXSbRpT8xTNj84xIyZ1+pZjVrw2eGxRkdj1vvTskY0dsrMamZ6sm6Nu2qW6bsJlcZmyP5tj6SKS2VqyJWyMcnn+X2bJWMzbCalrv2rUEw/Qm/fP3kHaRtbRlwdSjGqR4tO8dRhPL/u18sHRsd9n9baoTPNdGjpLvdPWP/8ObNwrVjG1Gv++XodokQwzFLSTQy2C/mYEdfCVtsy7k9rERNb45jhoNXr63M+4PsV4iteLmc8mq+8qZunl6sn9bo8ucTXyunU9jDKsxxUHTxs5YJ6PfWrUXh3QwWDP77KD2+H1efywd2w0bp/aUvL3OzTww2kCS2trbDa283TC0s79VRsQM6yz9YX592ngbXWasPx8naiMdNU9iYeCn2Vhzh3L7r/sp2d7P8BOeq36MzG+/g587nh0UhOibNSn6OVnftk3x29z+Oj9OAPDjE32x5qFQrPxHKJaND8GSB6R3EnTUaCQf+yIDE961aWr4mKdEtsKmqT3xzxr9Wjr6e+CHx3ojsnUTTIlshXl31U7avPRGSvm5O2vLJ6O7kQ5jHbHbKNC5T0rHxppqHl9rI7Gr9nBEK9zR1K1WLcWdbX10fsw6Gmn60q/ZC/Ry1TnmGUYSXQ8XwzVMT0a1wzS9vj01z52uzY0PnzbXcXViWKDBSSEBIPpmc2p9etBMMjeoRmJWPZXBHU3dsHdOP+yb299oDPWT9GpdAjwwMdz41AtN3Gp3+g4J9MIzg4IwsJO/0WuPvvGhgRgXqrufRje/WG43r1NSm5BX/qMHNk7tiYd6tsTwLs3w5kjzNeW2sPzBWzU4rfQGLHRu5oE1D4VickRLneuqi5P0ARVKY4KicoYqS6b
3vgPPD+6Ab/7Zy+LtWnts+8apEfjlyTvx7gPdDF7gqjN5UyOXJke0whAjdwcaQKcDYDU/d2eEtryVwAX5uut8xhSpycnaSWEY20PaHXj1djs188Dj/dpgSmQrvH+zo25zL1esmBiKpwcG4d5gfzwc0QrhLW/1Lxrc0Q/bZ/XRvnat0ckx0MsVwzr7y65RGdE1AIeeG1DrfTl3isb0D/JF37bS2/CHd7nVd+pNA52X104KQzvfxlj+YAieGRSErx/tZXIEGWC82U1/WHxjZ0ed1L+Vt+EErU+bphhjoMbByUGDuXcFGa0hMNX511z9ZzvfxnAx0qHV0UGDlkbKakjTxrdqtWqe3+sfDsdcM32LPp4QgrdHd8Hz93QwuZ6xmkMPFyc4OWiwd45ux38HTVWNXCtvN4N9oDZM6Ql3ZyfZCe+UyFb47NHeOs1DjWr8//sGzjH9xNXQ39NPQs1g9QScLk4OeGNEsE4Nj4ezU6315HB3rqpFN6alhJGRhrw0rKNODbFGo0FoyyZ4dlB7rH84HIuHd8Yr93ay6ZQSTFBUztDFzMXJARPDA41eVKVoXYfPSuHkoIF340YY2MEXTQ10aFv+YA8cem4ARljYwbenBbVI7Xyl3Vnp0+/8qJ8YVVexTunVCne1N17Nr9Fo8PTAIPQ30JlSo9HgmUFBWPVQmPY9RwcNAjxdtHeok3q20ln/zVFdsNBAv5KanujfxuC+qlU3Ubo2csQzA6tqcRYO7Yg3RwZj/eSwWp81xclBg2XjQySt+9WMSJ0fhwEd/XHgmSiddXoEeuGrGZFmOy6auuT/37COeKJ/G51EVQ4HjQYvDeuETkZ+hNv5WHZO3RvsD2+3RrdiPkQ3CXjAVH8WGU28xmoAujX3NNmh/d5gf/S6oykGd/JXZOLF6lqUeXe1w4Gno7TNO141Es4n+rfBF9NvjQr8Yrrx6RqkcnLQ4OtHe2HzjAj0D/LRjtSSw9IG9Zfv7YSnotqibY3rjiWzbO+Z3Q8La9TYvj+2O5aO7YZPJvbAxLBAPNLH/IhDQ8xNPzCyW4DBJqv6xNm2VEjOyBRLvTW6K5b9noKHI1vhkZvT9lfPEvtUVFscPH8DUxQaYWHselr9Q3lfl2b48XS6yarkJ6Pa4omvjqFfu6a4u4MfRnWTnth8OjkMJ6/lWTxHxpIHuuF6QSlGrooxuPzfI4OxcGhHeLg4oU+bpvg9+ToAoF9bZUZ1RI/qgvNZRWhvIMEy9NtxR1M3XMwqAgA09zR8Qfzx8T7IK6nQuWA/HNkK93dvbvKOafU/QnExuwgXbhRiw+HLso6jg587OjdzR0l5Jdr61E6Qle7sDRieN8gSC4d2xMwv42v1m+jdpikS0vNrjdRq3dQNwc084O7iiNhLObobEwL/HtkFFZUCjg4ajAut6mMWvTtJu4qcZtwWXi64lltidj0lK02fGRiE939LAaBbS2GsFvKrGZE4ejkHgzv569Za1Pjff/bVTaaNzSkyvIvpoeT6Jai5nRfu6YDnvj2pM43AyK7NsONUusltmtPBQMd1Q52mAzxdsPIfPfD45mOSt11dK7h3Tj+UlFfqdCqPMNCJW5+7syMKSiu0rxcN7Yic8koEB3jqXJsNfSdtjQmKCrWx8K5MjsAmrojWG41Q3bbd3s8dv8+Lqre5Hl4Z3hkTwgLRxUS7fURrb+yb2w/uzvJP2e4tvNBdwuMJPI20lTs6aHQ7N+slXBqNRtv0MD60BZo2boSQFp5wM9AEJUf1HY6To4PBC6AUxm5+/Txc4GcgDzZXnRvWqgnCWjXB9ydSJZfBp3EjjA9tgTEhLcx2EnfQVE3eJofUG3xXJwcUl1cCANyN/K2N6d7CC/ufrv2dmNWvDVo2cUE/vREfjhpgw5RwAEDv9/7QWVbdT6b6h8fc3ED6+rXzQWtvV3QJ8MSbN5tIev33dwBVNzdn0vMNfs7VQLNRMw9npOeXytr/yG4BOklJgKcL/hEeCNdGjkaTTH8PF51
mj2pyrjCLh3fGlZwis82rA9r74lRavsHRYQPa++L3ef11Yj7nriBczSnGWIkdp4GqGq7vanwH5DSD9GzlLXndmjxcnOAhvwII9wY3Q25xOSLvqEp6x4W2gJ+fJzIzdWcuv8dIvydbYoKicqIeR+zUvNDX50RUTg4aSbOZWpKcSPHyvZ3wy5kMSXOfmOPooMFQGb3qDXntvs7YlZBuchI4U/q188HFrCt1KoM5LhJ76T7SpzUm9Wyp0xfCFHdnJ+SVlMsqi9QzdfVDoZi68ShaebtiaOdmWHPwoqz9GPpOuDg51OpoCeh2aK42omsz/CO8JVobqRmorvnSn4dIn2sjR2x9tJfBppc7mrph7oB28HG/1aw6d0A7nE7Lx50GavQ+GBeCSRtiAVTNSfP53+Zrxf5vaEccvaJbK/Svwab7qShhpMRa0xl9WqO1txsijMylop8Q+rk76zStmrJ5RgTOXS/EoA5+2gRFyVGOQNV35n+xV1ByM5muKycHDaJNTED385N9kZ5fWi8193IxQSGrC1JoVkhrub97c1lzWBgaWaWkEV0DJPfNMXSXGBrohS+PVCUo1uoMPbiTH/qc9NbpkGzIqG7NJScn1hYc4KnTKfDpgUGY/+1JxZoygaoROjEXsgyeT829XE2O7tk8PQL5pRXaH7xuzT1xMjUPve7wxuGL2TrrGusX0tjZEX30Oiqb6mtSc5oEKedKMw9nODk6oNcdykxoZo3zs5GjA+410wwkx8SwQKw4cB592zZFkK+7tj/TQz1b4ssjVzBXgVFVwc080K2FJ8aGtEDnAA+M6BKACev/RmRraU19Td0aIauozKJ9N23srJrvqD4mKCrl6+6M6wWlVh9SXJMlk55JMaSTH7IGt681JJjqrn87HzzSpzWCm3mgRRNXnLyWh3tqTD7VNcA6MW/k6KAzZLE+1UrKLPyVu6u9L/bO6Wd2ZJAcd3f0w90W9nVycnSAt9utmqnVD4UiNbcEuSXlmLHpqMnPvjysE749fs3k3CJSTI5oia+OXsXMO2v3B7mYVWTw2Ooy0Z65US2z7myDVQcv4OEI2z2uYXrv1oho3aRWDcNzg4IwvXdrSaN8zIke3UVn0ENb38b4dXY/yU2Rpr4CDfnh8kxQVGrtpFB8fyIN/zAxH4BSJoQF4lpuscm7u7rQaDSYGG7Z1O5kmkajwVNRt+7gutxMSHbM6oMbhaX10p/JFEtGLZizcEhHvPnLWUwMqzqn6nL9VTI5UVojRwdtc9D6h8PQta0vUGL4Lvn+kOaKTMD27KD2mHtXUK3mrNUPheKv81k6Hc3ffaAbkjML0LuODwY0Zeadd+DeLs0kPWTUWhwdNAZrCjUajSLJSTMPZ4MjMuX0a6nvhyLWF/V+O29zLZu41fluSKoFZuY5IF3W6gujpGaeLrJnLVbKsvHdserPi1g0rKNV+jI183TBB+NuDWdW+w1iYBNXXM0prtOTlru38IKfpysyjSQoSqh+2Kihv5lPY+dazY4DO/jWefZcc3N4aDQau39C8B0K3ES8c383LNx+CgLA5eziuhdKJdR/pSUyo3sLT5y4lqcdJm0trw7vhNzicqvUCtiTvm190NfCIdaP9WuD9/YmyxpGrnZfzYjEjcJStPBS53kzO6otfj6TgYcj6r+Ws4WXK5aPDzE4O6w9e7J/W6w4cB4A8Npw45NVStU5wAPb/tkbxWUV2HbsGvq388GDn/5d5+3a2u11VpBdeveBbthxMg2julv3R03/wYKkvIfCA9GvbVOjI10MMdbGLncmUmtxcXJQbXICADP63IEZFk72pQT9Tr23g0f73oEHw1rAy1XZEUCujRwx2Yb9dZRmnw1XdFvxdXfGtN6t4aPSnugknUajQRufxrJmL9XvaPnRgyFo79cYn0wMVbp4RIpROjmxR6xBISK70rtNU3ypwDTpRGRbsmtQSkpKsGjRIkRGRiIqKgrr1q0zuu6+ffvwwAMPIDw8HKNHj8aePXt0lkd
GRqJz5846/woKCuQfBRHdvtTRkkOkSvpPRW9IZJf8nXfewYkTJ/DZZ5/h6tWreOGFFxAYGIjhw4frrJeQkIA5c+ZgwYIFGDhwIPbv34+nn34aW7duRXBwMNLS0pCXl4fdu3fD1fVW+2zjxrYdFklEDUvPVk3w1/ksWxeDSFVeu68zdp/JwJTIus+QbSuyEpTCwkJs2bIFq1evRrdu3dCtWzckJiZi06ZNtRKU7du3o2/fvpg2bRoAoE2bNvj111/x448/Ijg4GMnJyfD390fr1g03eERke1MiW6GJqxN6m3nqMdHtRM6M1GolK0FJSEhAeXk5wsPDte9FRETgk08+QWVlJRwcbrUYjR07FmVltcfs5+VVPaAoKSkJ7drVfYpgIrq9NXI0/DwcImrYZCUoGRkZaNq0KZydb42W8PPzQ0lJCbKzs+Hjc2vug/bt2+t8NjExEQcPHsRDDz0EAEhOTkZRURGmTp2Kc+fOoUuXLli0aJHspEXpaXyrt9eQpweuD4yTeYyRNIyTeYyReYyRNGqIk9R9y0pQioqKdJITANrXpaXGH9l948YNzJ07Fz179sQ999wDAEhJSUFOTg6ee+45eHh4YPXq1ZgxYwZ27NgBDw/pT1X09bXO9OzW2q69YZzMY4ykYZzMY4zMY4ykaQhxkpWguLi41EpEql/X7OhaU2ZmJh555BEIIfDhhx9qm4HWrl2LsrIyuLtXPRny3XffxcCBA7F3716MHj1acpmuX8+DUPDhshpN1R9O6e3aG8bJPMZIGsbJPMbIPMZIGjXEqboM5shKUAICApCVlYXy8nI4OVV9NCMjA66urvDyqj3NeFpamraT7IYNG3SagJydnXVqY1xcXNCqVSukpaXJKRKEgFWCbK3t2hvGyTzGSBrGyTzGyDzGSJqGECdZ86B06dIFTk5OiIuL074XGxuLkJAQnQ6yQNWIn5kzZ8LBwQEbN25EQMCt3sRCCAwZMgTbtm3TWf/ChQsICgqy8FCIiIjIXsiqQXFzc8OYMWOwePFi/Oc//0F6ejrWrVuH6OhoAFW1KZ6ennB1dcXKlStx8eJFfP7559plQFVTkKenJwYNGoRly5ahZcuW8PHxwQcffIDmzZtj4MCBCh8iERERNTSyJ2pbuHAhFi9ejOnTp8PDwwNz587FsGHDAABRUVGIjo7GuHHjsGvXLhQXF2PChAk6nx87dizeeustPP/883BycsL8+fORn5+Pvn37YtWqVXB0dFTmyIiIiKjB0gih9lYo0zIzle8k6+fnqfh27Q3jZB5jJA3jZB5jZB5jJI0a4lRdBnP4NGMiIiJSHSYoREREpDpMUIiIiEh1mKAQERGR6jBBISIiItWRPcxYbfiwQNtgnMxjjKRhnMxjjMxjjKRRQ5yk7rvBDzMmIiIi+8MmHiIiIlIdJihERESkOkxQiIiISHWYoBAREZHqMEEhIiIi1WGCQkRERKrDBIWIiIhUhwkKERERqQ4TFCIiIlIdJig1lJSUYNGiRYiMjERUVBTWrVtn6yLVi19++QWdO3fW+Tdv3jwAwKlTpzBhwgSEhoZi/PjxOHHihM5nt2/fjiFDhiA0NBSzZ8/GjRs3tMuEEHj33XfRt29f9O7dG++88w4qKyvr9djqqrS0FKNGjUJMTIz2vUuXLmHGjBkICwvDiBEjsH//fp3P/Pnnnxg1ahRCQ0Mxbdo0XLp0SWf5+vXrMWDAAISHh2PRokUoKirSLmuo56ChOP373/+udV5t3LhRu7wu505WVhbmzp2L8PBwDB48GN999139HKgF0tLSMG/ePPTu3RsDBgxAdHQ0SkpKAPBcqmYqRjyPbrlw4QL++c9/Ijw8HIMGDcKaNWu0y+zyXBKk9frrr4vRo0eLEydOiJ9//lmEh4eLH3/80dbFsrqPP/5YPP744yI9PV37LycnRxQUFIj+/fuLt956SyQlJYk33nhD9OvXTxQUFAghhIiPjxc9evQQ33zzjTh9+rSYMmWKmDV
rlna7a9euFQMHDhSHDx8WBw8eFFFRUWLNmjW2OkzZiouLxezZs0WnTp3EX3/9JYQQorKyUowePVrMnz9fJCUliU8++USEhoaKK1euCCGEuHLliggLCxNr164VZ8+eFU8//bQYNWqUqKysFEII8dNPP4mIiAjx66+/ivj4eDFixAjx2muvaffZEM9BQ3ESQogZM2aIlStX6pxXhYWFQoi6nzuPP/64mD59ujhz5oz46quvRPfu3UV8fHz9HbRElZWVYuLEiWLmzJni7Nmz4vDhw2Lo0KHirbfe4rl0k6kYCcHzqFpFRYUYNmyYmD9/vjh37pzYt2+f6Nmzp/j+++/t9lxignJTQUGBCAkJ0bnAfvTRR2LKlCk2LFX9mD9/vvjvf/9b6/0tW7aIwYMHa0/iyspKMXToUPH1118LIYR4/vnnxQsvvKBd/+rVq6Jz587i4sWLQgghBg4cqF1XCCG+/fZbcffdd1vzUBSTmJgo7r//fjF69GidH94///xThIWFaZM0IYSYPn26+PDDD4UQQrz//vs650xhYaEIDw/Xfn7y5MnadYUQ4vDhw6JHjx6isLCwQZ6DxuIkhBADBgwQf/zxh8HP1eXcuXDhgujUqZO4dOmSdvmiRYt0tqcWSUlJolOnTiIjI0P73g8//CCioqJ4Lt1kKkZC8DyqlpaWJp5++mmRl5enfW/27Nni1VdftdtziU08NyUkJKC8vBzh4eHa9yIiIhAfH9/gmiXkSk5ORtu2bWu9Hx8fj4iICGhuPnpSo9GgZ8+eiIuL0y6PjIzUrt+iRQsEBgYiPj4eaWlpuHbtGnr16qVdHhERgStXriA9Pd2qx6OEQ4cOoU+fPti8ebPO+/Hx8ejatSsaN26sfS8iIsJoTNzc3NCtWzfExcWhoqICx48f11keFhaGsrIyJCQkNMhz0Fic8vPzkZaWZvC8Aup27sTHx6NFixZo1aqVzvKjR48qe3AK8Pf3x5o1a+Dn56fzfn5+Ps+lm0zFiOfRLc2aNcP7778PDw8PCCEQGxuLw4cPo3fv3nZ7LjlZdesNSEZGBpo2bQpnZ2fte35+figpKUF2djZ8fHxsWDrrEULg3Llz2L9/P1auXImKigoMHz4c8+bNQ0ZGBjp06KCzvq+vLxITEwEA6enpaNasWa3lqampyMjIAACd5dUXoNTU1FqfU5vJkycbfD8jI8PoMZtbnpubi5KSEp3lTk5O8Pb2RmpqKhwcHBrcOWgsTsnJydBoNPjkk0/w+++/w9vbG4888gjGjh0LoG7njrEYp6WlKXZcSvHy8sKAAQO0rysrK7Fx40b07duX59JNpmLE88iwwYMH4+rVq7j77rtx77334j//+Y9dnktMUG4qKirS+QMA0L4uLS21RZHqxdWrV7XH/v777+Py5cv497//jeLiYqMxqY5HcXGx0eXFxcXa1zWXAQ07nuZiYmq5oZjUXC6EsJtzMCUlBRqNBkFBQZgyZQoOHz6Ml19+GR4eHhg6dGidzh1zfwM1W7JkCU6dOoWtW7di/fr1PJcMqBmjkydP8jwy4MMPP0RmZiYWL16M6Ohou70uMUG5ycXFpVawq1+7urraokj1omXLloiJiUGTJk2g0WjQpUsXVFZW4vnnn0fv3r0NxqQ6HsZi5ubmpnMCu7i4aP8fqKpebKhcXFyQnZ2t856UmHh5edWKQ83lbm5uqKiosJtzcMyYMbj77rvh7e0NAAgODsb58+fxxRdfYOjQoXU6d4x9Vu0xWrJkCT777DMsXboUnTp14rlkgH6MOnbsyPPIgJCQEABVo2v+9a9/Yfz48TqjbgD7OJfYB+WmgIAAZGVloby8XPteRkYGXF1d4eXlZcOSWZ+3t7e2nwkAtG/fHiUlJfD390dmZqbOupmZmdqqwICAAIPL/f39ERAQAADaataa/+/v72+V46gPxo5ZSky8vb3h4uKis7y8vBzZ2dnamNnLOajRaLQ/KtWCgoK01ed1OXdMfVat3njjDXz
66adYsmQJ7r33XgA8l/QZihHPo1syMzOxe/dunfc6dOiAsrKyOl2r1XwuMUG5qUuXLnByctJ2KgKA2NhYhISEwMHBfsP0xx9/oE+fPjrZ9+nTp+Ht7a3tMCaEAFDVX+XIkSMIDQ0FAISGhiI2Nlb7uWvXruHatWsIDQ1FQEAAAgMDdZbHxsYiMDBQ9f1PTAkNDcXJkye11aJA1XEZi0lRURFOnTqF0NBQODg4ICQkRGd5XFwcnJycEBwcbFfn4AcffIAZM2bovJeQkICgoCAAdTt3wsLCcOXKFW37evXysLAwqx6TpZYvX44vv/wS7733HkaOHKl9n+fSLcZixPPolsuXL2POnDk6fWROnDgBHx8fRERE2Oe5ZPVxQg3Iyy+/LEaOHCni4+PFL7/8Inr27Cl27dpl62JZVV5enhgwYIB47rnnRHJysti3b5+IiooSq1atEnl5eaJv377ijTfeEImJieKNN94Q/fv31w5lO3LkiOjWrZv46quvtHMQPP7449ptr1y5UkRFRYm//vpL/PXXXyIqKkqsW7fOVodqsZrDZ8vLy8WIESPEM888I86ePStWrlwpwsLCtPMNXLp0SYSEhIiVK1dq5xsYPXq0dqj29u3bRc+ePcUvv/wi4uPjxciRI8Ubb7yh3VdDPgdrxik+Pl507dpVrFmzRly4cEFs2rRJdO/eXRw5ckQIUfdz59FHHxVTpkwRp0+fFl999ZUICQlR5fwVSUlJokuXLmLp0qU683ikp6fzXLrJVIx4Ht1SXl4uxo0bJx599FGRmJgo9u3bJ/r16yfWr19vt+cSE5QaCgsLxYIFC0RYWJiIiooSn376qa2LVC/Onj0rZsyYIcLCwkT//v3FsmXLtCdufHy8GDNmjAgJCREPPvigOHnypM5nv/76azFw4EARFhYmZs+eLW7cuKFdVl5eLv7zn/+IyMhI0adPH7FkyRLtdhsS/fk9zp8/Lx5++GHRvXt3MXLkSHHgwAGd9fft2yeGDRsmevToIaZPn66dk6HaypUrxZ133ikiIiLEwoULRXFxsXZZQz4H9eP0yy+/iNGjR4uQkBAxfPjwWhe0upw7mZmZ4vHHHxchISFi8ODB4ocffrD+AVpg5cqVolOnTgb/CcFzSQjzMeJ5dEtqaqqYPXu26Nmzp+jfv79YsWKF9njs8VzSCHGz/p6IiIhIJdTXGElERES3PSYoREREpDpMUIiIiEh1mKAQERGR6jBBISIiItVhgkJERESqwwSFiIiIVIcJChEREakOExQiIiJSHSYoREREpDpMUIiIiEh1mKAQERGR6vw/iiA9jMBC3MsAAAAASUVORK5CYII=" + "image/png": "" }, "metadata": {}, "output_type": "display_data" } ], - "execution_count": 18 + "execution_count": 42 }, { "metadata": {}, From 24d45cd261344cb7b6cfe6f20a181c31c8aa46ad Mon Sep 17 00:00:00 2001 From: Jerry Date: Sun, 27 Oct 2024 17:45:03 -0400 Subject: [PATCH 11/22] Update Quickstart notebook --- examples/Quickstart_Diagnostics.ipynb | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/examples/Quickstart_Diagnostics.ipynb b/examples/Quickstart_Diagnostics.ipynb index 6ceddd5bc..d2c92bcaf 100644 --- a/examples/Quickstart_Diagnostics.ipynb +++ b/examples/Quickstart_Diagnostics.ipynb @@ -3,7 
+3,7 @@ { "metadata": {}, "cell_type": "markdown", - "source": "Quickstart Diagnostics", + "source": "# Quickstart: Amortized Posterior Estimation", "id": "ee8e90d08cdb035e" }, { @@ -39,6 +39,28 @@ "outputs": [], "execution_count": 24 }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "## Generative Model Definition\n", + "\n", + "The toy Gaussian model we will use for this tutorial takes a particularly simple form:\n", + "\n", + "\\begin{align}\n", + " \\mathbf{\\mu} &\\sim \\mathcal{N}_D(\\mathbf{0}, \\sigma_0 \\mathbb{I}),\\\\\n", + " \\mathbf{x}_n &\\sim \\mathcal{N}_D(\\mathbf{\\mu}, \\sigma_1 \\mathbb{I}) \\quad \\text{for } n = 1, ..., N,\n", + "\\end{align}\n", + "\n", + "where $\\mathcal{N}_D$\n", + "denotes a multivariate Gaussian (normal) density with $D$\n", + "dimensions, which we set at $D = 4$\n", + "for the current example. For simplicity, we will also set $\\sigma_0 = 1$\n", + "and $\\sigma_1 = 1$\n", + ". We will now implement this model using the latest numpy interface." 
+ ], + "id": "90baf0445b69eb7d" + }, { "metadata": { "ExecuteTime": { From fa3d90c681a02c2c16c5e98007bc95a749341048 Mon Sep 17 00:00:00 2001 From: Jerry Date: Tue, 29 Oct 2024 13:40:21 -0400 Subject: [PATCH 12/22] Add SBC-related diagnostics and z-score --- bayesflow/diagnostics/__init__.py | 4 + bayesflow/diagnostics/plot_sbc_ecdf.py | 184 ++++++++ bayesflow/diagnostics/plot_sbc_histogram.py | 151 ++++++ .../diagnostics/plot_z_score_contraction.py | 154 ++++++ examples/Quickstart_Diagnostics.ipynb | 443 ++++++++++-------- 5 files changed, 735 insertions(+), 201 deletions(-) create mode 100644 bayesflow/diagnostics/plot_sbc_ecdf.py create mode 100644 bayesflow/diagnostics/plot_sbc_histogram.py create mode 100644 bayesflow/diagnostics/plot_z_score_contraction.py diff --git a/bayesflow/diagnostics/__init__.py b/bayesflow/diagnostics/__init__.py index 5e4fdbce2..0d77a37aa 100644 --- a/bayesflow/diagnostics/__init__.py +++ b/bayesflow/diagnostics/__init__.py @@ -1,2 +1,6 @@ from .plot_losses import plot_losses from .plot_recovery import plot_recovery +from .plot_sbc_ecdf import plot_sbc_ecdf +from .plot_sbc_histogram import plot_sbc_histograms +from .plot_distribution_2d import plot_distribution_2d +from .plot_z_score_contraction import plot_z_score_contraction \ No newline at end of file diff --git a/bayesflow/diagnostics/plot_sbc_ecdf.py b/bayesflow/diagnostics/plot_sbc_ecdf.py new file mode 100644 index 000000000..e4c860a5f --- /dev/null +++ b/bayesflow/diagnostics/plot_sbc_ecdf.py @@ -0,0 +1,184 @@ + +import matplotlib.pyplot as plt +import numpy as np +import seaborn as sns + +from ..utils.plot_utils import check_posterior_prior_shapes +from ..utils.ecdf import simultaneous_ecdf_bands + +def plot_sbc_ecdf( + post_samples, + prior_samples, + difference=False, + stacked=False, + fig_size=None, + param_names=None, + label_fontsize=16, + legend_fontsize=14, + title_fontsize=18, + tick_fontsize=12, + rank_ecdf_color="#a34f4f", + fill_color="grey", + n_row=None, + 
n_col=None, + **kwargs, +): + """Creates the empirical CDFs for each marginal rank distribution and plots them against + a uniform ECDF. ECDF simultaneous bands are drawn using simulations from the uniform, + as proposed by [1]. + + For models with many parameters, use `stacked=True` to obtain an idea of the overall calibration + of a posterior approximator. + + [1] Säilynoja, T., Bürkner, P. C., & Vehtari, A. (2022). Graphical test for discrete uniformity and + its applications in goodness-of-fit evaluation and multiple sample comparison. Statistics and Computing, + 32(2), 1-21. https://arxiv.org/abs/2103.10522 + + Parameters + ---------- + post_samples : np.ndarray of shape (n_data_sets, n_post_draws, n_params) + The posterior draws obtained from n_data_sets + prior_samples : np.ndarray of shape (n_data_sets, n_params) + The prior draws obtained for generating n_data_sets + difference : bool, optional, default: False + If `True`, plots the ECDF difference. Enables a more dynamic visualization range. + stacked : bool, optional, default: False + If `True`, all ECDFs will be plotted on the same plot. If `False`, each ECDF will + have its own subplot, similar to the behavior of `plot_sbc_histograms`. + param_names : list or None, optional, default: None + The parameter names for nice plot titles. Inferred if None. Only relevant if `stacked=False`. + fig_size : tuple or None, optional, default: None + The figure size passed to the matplotlib constructor. Inferred if None. + label_fontsize : int, optional, default: 16 + The font size of the x-label and y-label texts + legend_fontsize : int, optional, default: 14 + The font size of the legend text + title_fontsize : int, optional, default: 18 + The font size of the title text. 
Only relevant if `stacked=False` + tick_fontsize : int, optional, default: 12 + The font size of the axis ticklabels + rank_ecdf_color : str, optional, default: '#a34f4f' + The color to use for the rank ECDFs + fill_color : str, optional, default: 'grey' + The fill color of the simultaneous confidence bands. + n_row : int, optional, default: None + The number of rows for the subplots. Dynamically determined if None. + n_col : int, optional, default: None + The number of columns for the subplots. Dynamically determined if None. + **kwargs : dict, optional, default: {} + Keyword arguments can be passed to control the behavior of ECDF simultaneous band computation + through the ``ecdf_bands_kwargs`` dictionary. See `simultaneous_ecdf_bands` for keyword arguments + + Returns + ------- + f : plt.Figure - the figure instance for optional saving + + Raises + ------ + ShapeError + If there is a deviation from the expected shapes of `post_samples` and `prior_samples`. + """ + + # Sanity checks + check_posterior_prior_shapes(post_samples, prior_samples) + + # Store reference to number of parameters + n_params = post_samples.shape[-1] + + # Compute fractional ranks (using broadcasting) + ranks = np.sum(post_samples < prior_samples[:, np.newaxis, :], axis=1) / post_samples.shape[1] + + # Prepare figure + if stacked: + n_row, n_col = 1, 1 + f, ax = plt.subplots(1, 1, figsize=fig_size) + else: + # Determine number of rows and columns for subplots based on inputs + if n_row is None and n_col is None: + n_row = int(np.ceil(n_params / 6)) + n_col = int(np.ceil(n_params / n_row)) + elif n_row is None and n_col is not None: + n_row = int(np.ceil(n_params / n_col)) + elif n_row is not None and n_col is None: + n_col = int(np.ceil(n_params / n_row)) + + # Determine fig_size dynamically, if None + if fig_size is None: + fig_size = (int(5 * n_col), int(5 * n_row)) + + # Initialize figure + f, ax = plt.subplots(n_row, n_col, figsize=fig_size) + ax = np.atleast_1d(ax) + + # Plot individual ecdf of parameters 
+ for j in range(ranks.shape[-1]): + ecdf_single = np.sort(ranks[:, j]) + xx = ecdf_single + yy = np.arange(1, xx.shape[-1] + 1) / float(xx.shape[-1]) + + # Difference, if specified + if difference: + yy -= xx + + if stacked: + if j == 0: + ax.plot(xx, yy, color=rank_ecdf_color, alpha=0.95, label="Rank ECDFs") + else: + ax.plot(xx, yy, color=rank_ecdf_color, alpha=0.95) + else: + ax.flat[j].plot(xx, yy, color=rank_ecdf_color, alpha=0.95, label="Rank ECDF") + + # Compute uniform ECDF and bands + alpha, z, L, H = simultaneous_ecdf_bands(post_samples.shape[0], **kwargs.pop("ecdf_bands_kwargs", {})) + + # Difference, if specified + if difference: + L -= z + H -= z + ylab = "ECDF difference" + else: + ylab = "ECDF" + + # Add simultaneous bounds + if stacked: + titles = [None] + axes = [ax] + else: + axes = ax.flat + if param_names is None: + titles = [f"$\\theta_{{{i}}}$" for i in range(1, n_params + 1)] + else: + titles = param_names + + for _ax, title in zip(axes, titles): + _ax.fill_between(z, L, H, color=fill_color, alpha=0.2, label=rf"{int((1-alpha) * 100)}$\%$ Confidence Bands") + + # Prettify plot + sns.despine(ax=_ax) + _ax.grid(alpha=0.35) + _ax.legend(fontsize=legend_fontsize) + _ax.set_title(title, fontsize=title_fontsize) + _ax.tick_params(axis="both", which="major", labelsize=tick_fontsize) + _ax.tick_params(axis="both", which="minor", labelsize=tick_fontsize) + + # Only add x-labels to the bottom row + if stacked: + bottom_row = [ax] + else: + bottom_row = ax if n_row == 1 else ax[-1, :] + for _ax in bottom_row: + _ax.set_xlabel("Fractional rank statistic", fontsize=label_fontsize) + + # Only add y-labels to right left-most row + if n_row == 1: # if there is only one row, the ax array is 1D + axes[0].set_ylabel(ylab, fontsize=label_fontsize) + else: # if there is more than one row, the ax array is 2D + for _ax in ax[:, 0]: + _ax.set_ylabel(ylab, fontsize=label_fontsize) + + # Remove unused axes entirely + for _ax in axes[n_params:]: + _ax.remove() + + 
f.tight_layout() + return f \ No newline at end of file diff --git a/bayesflow/diagnostics/plot_sbc_histogram.py b/bayesflow/diagnostics/plot_sbc_histogram.py new file mode 100644 index 000000000..100ec9704 --- /dev/null +++ b/bayesflow/diagnostics/plot_sbc_histogram.py @@ -0,0 +1,151 @@ + +import logging +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns + +from scipy.stats import binom +from ..utils.plot_utils import check_posterior_prior_shapes + +def plot_sbc_histograms( + post_samples, + prior_samples, + param_names=None, + fig_size=None, + num_bins=None, + binomial_interval=0.99, + label_fontsize=16, + title_fontsize=18, + tick_fontsize=12, + hist_color="#a34f4f", + n_row=None, + n_col=None, +): + """Creates and plots publication-ready histograms of rank statistics for simulation-based calibration + (SBC) checks according to [1]. + + Any deviation from uniformity indicates miscalibration and thus poor convergence + of the networks or poor combination between generative model / networks. + + [1] Talts, S., Betancourt, M., Simpson, D., Vehtari, A., & Gelman, A. (2018). + Validating Bayesian inference algorithms with simulation-based calibration. + arXiv preprint arXiv:1804.06788. + + Parameters + ---------- + post_samples : np.ndarray of shape (n_data_sets, n_post_draws, n_params) + The posterior draws obtained from n_data_sets + prior_samples : np.ndarray of shape (n_data_sets, n_params) + The prior draws obtained for generating n_data_sets + param_names : list or None, optional, default: None + The parameter names for nice plot titles. Inferred if None + fig_size : tuple or None, optional, default : None + The figure size passed to the matplotlib constructor. 
Inferred if None + num_bins : int or None, optional, default: None + The number of bins to use for each marginal histogram. Inferred if None. + binomial_interval : float in (0, 1), optional, default: 0.99 + The width of the confidence interval for the binomial distribution + label_fontsize : int, optional, default: 16 + The font size of the x-label text + title_fontsize : int, optional, default: 18 + The font size of the title text + tick_fontsize : int, optional, default: 12 + The font size of the axis ticklabels + hist_color : str, optional, default: '#a34f4f' + The color to use for the histogram body + n_row : int, optional, default: None + The number of rows for the subplots. Dynamically determined if None. + n_col : int, optional, default: None + The number of columns for the subplots. Dynamically determined if None. + + Returns + ------- + f : plt.Figure - the figure instance for optional saving + + Raises + ------ + ShapeError + If there is a deviation from the expected shapes of `post_samples` and `prior_samples`. + """ + + # Sanity check + check_posterior_prior_shapes(post_samples, prior_samples) + + # Determine the ratio of simulations to posterior draws + n_sim, n_draws, n_params = post_samples.shape + ratio = int(n_sim / n_draws) + + # Log a warning if N/B ratio recommended by Talts et al. (2018) < 20 + if ratio < 20: + logger = logging.getLogger() + logger.setLevel(logging.INFO) + logger.info( + f"The ratio of simulations / posterior draws should be > 20 " + + f"for reliable variance reduction, but your ratio is {ratio}.\ + Confidence intervals might be unreliable!" 
+ ) + + # Set n_bins automatically, if nothing provided + if num_bins is None: + num_bins = int(ratio / 2) + # Attempt a fix if a single bin is determined so plot still makes sense + if num_bins == 1: + num_bins = 5 + + # Determine n params and param names if None given + if param_names is None: + param_names = [f"$\\theta_{{{i}}}$" for i in range(1, n_params + 1)] + + # Determine number of rows and columns for subplots based on inputs + if n_row is None and n_col is None: + n_row = int(np.ceil(n_params / 6)) + n_col = int(np.ceil(n_params / n_row)) + elif n_row is None and n_col is not None: + n_row = int(np.ceil(n_params / n_col)) + elif n_row is not None and n_col is None: + n_col = int(np.ceil(n_params / n_row)) + + # Initialize figure + if fig_size is None: + fig_size = (int(5 * n_col), int(5 * n_row)) + f, axarr = plt.subplots(n_row, n_col, figsize=fig_size) + axarr = np.atleast_1d(axarr) + + # Compute ranks (using broadcasting) + ranks = np.sum(post_samples < prior_samples[:, np.newaxis, :], axis=1) + + # Compute confidence interval and mean + N = int(prior_samples.shape[0]) + # uniform distribution expected -> for all bins: equal probability + # p = 1 / num_bins that a rank lands in that bin + endpoints = binom.interval(binomial_interval, N, 1 / num_bins) + mean = N / num_bins # corresponds to binom.mean(N, 1 / num_bins) + + # Plot marginal histograms in a loop + if n_row > 1: + ax = axarr.flat + else: + ax = axarr + for j in range(len(param_names)): + ax[j].axhspan(endpoints[0], endpoints[1], facecolor="gray", alpha=0.3) + ax[j].axhline(mean, color="gray", zorder=0, alpha=0.9) + sns.histplot(ranks[:, j], kde=False, ax=ax[j], color=hist_color, bins=num_bins, alpha=0.95) + ax[j].set_title(param_names[j], fontsize=title_fontsize) + ax[j].spines["right"].set_visible(False) + ax[j].spines["top"].set_visible(False) + ax[j].get_yaxis().set_ticks([]) + ax[j].set_ylabel("") + ax[j].tick_params(axis="both", which="major", labelsize=tick_fontsize) + 
ax[j].tick_params(axis="both", which="minor", labelsize=tick_fontsize) + + # Only add x-labels to the bottom row + bottom_row = axarr if n_row == 1 else axarr[0] if n_col == 1 else axarr[n_row - 1, :] + for _ax in bottom_row: + _ax.set_xlabel("Rank statistic", fontsize=label_fontsize) + + # Remove unused axes entirely + for _ax in axarr[n_params:]: + _ax.remove() + + f.tight_layout() + return f diff --git a/bayesflow/diagnostics/plot_z_score_contraction.py b/bayesflow/diagnostics/plot_z_score_contraction.py new file mode 100644 index 000000000..77cc33922 --- /dev/null +++ b/bayesflow/diagnostics/plot_z_score_contraction.py @@ -0,0 +1,154 @@ + +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns + +from ..utils.plot_utils import check_posterior_prior_shapes + +def plot_z_score_contraction( + post_samples, + prior_samples, + param_names=None, + fig_size=None, + label_fontsize=16, + title_fontsize=18, + tick_fontsize=12, + color="#8f2727", + n_col=None, + n_row=None, +): + """Implements a graphical check for global model sensitivity by plotting the posterior + z-score over the posterior contraction for each set of posterior samples in ``post_samples`` + according to [1]. + + - The definition of the posterior z-score is: + + post_z_score = (posterior_mean - true_parameters) / posterior_std + + And the score is adequate if it centers around zero and spreads roughly in the interval [-3, 3] + + - The definition of posterior contraction is: + + post_contraction = 1 - (posterior_variance / prior_variance) + + In other words, the posterior contraction is a proxy for the reduction in uncertainty gained by + replacing the prior with the posterior. The ideal posterior contraction tends to 1. + Contraction near zero indicates that the posterior variance is almost identical to + the prior variance for the particular marginal parameter distribution. + + Note: Means and variances will be estimated via their sample-based estimators. + + [1] Schad, D. 
J., Betancourt, M., & Vasishth, S. (2021). + Toward a principled Bayesian workflow in cognitive science. + Psychological methods, 26(1), 103. + + Paper also available at https://arxiv.org/abs/1904.12765 + + Parameters + ---------- + post_samples : np.ndarray of shape (n_data_sets, n_post_draws, n_params) + The posterior draws obtained from n_data_sets + prior_samples : np.ndarray of shape (n_data_sets, n_params) + The prior draws (true parameters) obtained for generating the n_data_sets + param_names : list or None, optional, default: None + The parameter names for nice plot titles. Inferred if None + fig_size : tuple or None, optional, default : None + The figure size passed to the matplotlib constructor. Inferred if None. + label_fontsize : int, optional, default: 16 + The font size of the x-label and y-label texts + title_fontsize : int, optional, default: 18 + The font size of the title text + tick_fontsize : int, optional, default: 12 + The font size of the axis ticklabels + color : str, optional, default: '#8f2727' + The color for the contraction vs. z-score scatter points + n_row : int, optional, default: None + The number of rows for the subplots. Dynamically determined if None. + n_col : int, optional, default: None + The number of columns for the subplots. Dynamically determined if None. + + Returns + ------- + f : plt.Figure - the figure instance for optional saving + + Raises + ------ + ShapeError + If there is a deviation from the expected shapes of ``post_samples`` and ``prior_samples``. 
+ """ + + # Sanity check for shape integrity + check_posterior_prior_shapes(post_samples, prior_samples) + + # Estimate posterior means and stds + post_means = post_samples.mean(axis=1) + post_stds = post_samples.std(axis=1, ddof=1) + post_vars = post_samples.var(axis=1, ddof=1) + + # Estimate prior variance + prior_vars = prior_samples.var(axis=0, keepdims=True, ddof=1) + + # Compute contraction + post_cont = 1 - (post_vars / prior_vars) + + # Compute posterior z score + z_score = (post_means - prior_samples) / post_stds + + # Determine number of params and param names if None given + n_params = prior_samples.shape[-1] + if param_names is None: + param_names = [f"$\\theta_{{{i}}}$" for i in range(1, n_params + 1)] + + # Determine number of rows and columns for subplots based on inputs + if n_row is None and n_col is None: + n_row = int(np.ceil(n_params / 6)) + n_col = int(np.ceil(n_params / n_row)) + elif n_row is None and n_col is not None: + n_row = int(np.ceil(n_params / n_col)) + elif n_row is not None and n_col is None: + n_col = int(np.ceil(n_params / n_row)) + + # Initialize figure + if fig_size is None: + fig_size = (int(4 * n_col), int(4 * n_row)) + f, axarr = plt.subplots(n_row, n_col, figsize=fig_size) + + # turn axarr into 1D list + axarr = np.atleast_1d(axarr) + if n_col > 1 or n_row > 1: + axarr_it = axarr.flat + else: + axarr_it = axarr + + # Loop and plot + for i, ax in enumerate(axarr_it): + if i >= n_params: + break + + ax.scatter(post_cont[:, i], z_score[:, i], color=color, alpha=0.5) + ax.set_title(param_names[i], fontsize=title_fontsize) + sns.despine(ax=ax) + ax.grid(alpha=0.5) + ax.tick_params(axis="both", which="major", labelsize=tick_fontsize) + ax.tick_params(axis="both", which="minor", labelsize=tick_fontsize) + ax.set_xlim([-0.05, 1.05]) + + # Only add x-labels to the bottom row + bottom_row = axarr if n_row == 1 else axarr[0] if n_col == 1 else axarr[n_row - 1, :] + for _ax in bottom_row: + _ax.set_xlabel("Posterior contraction", 
fontsize=label_fontsize) + + # Only add y-labels to right left-most row + if n_row == 1: # if there is only one row, the ax array is 1D + axarr[0].set_ylabel("Posterior z-score", fontsize=label_fontsize) + # If there is more than one row, the ax array is 2D + else: + for _ax in axarr[:, 0]: + _ax.set_ylabel("Posterior z-score", fontsize=label_fontsize) + + # Remove unused axes entirely + for _ax in axarr_it[n_params:]: + _ax.remove() + + f.tight_layout() + return f diff --git a/examples/Quickstart_Diagnostics.ipynb b/examples/Quickstart_Diagnostics.ipynb index d2c92bcaf..223af9849 100644 --- a/examples/Quickstart_Diagnostics.ipynb +++ b/examples/Quickstart_Diagnostics.ipynb @@ -9,8 +9,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-10-22T16:18:05.726356Z", - "start_time": "2024-10-22T16:18:05.710363Z" + "end_time": "2024-10-29T17:19:36.903584Z", + "start_time": "2024-10-29T17:19:36.887584Z" } }, "cell_type": "code", @@ -37,7 +37,7 @@ ], "id": "56c348ceefe0a66f", "outputs": [], - "execution_count": 24 + "execution_count": 69 }, { "metadata": {}, @@ -64,8 +64,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-10-22T16:24:23.691704Z", - "start_time": "2024-10-22T16:24:23.677134Z" + "end_time": "2024-10-29T17:19:41.043173Z", + "start_time": "2024-10-29T17:19:41.029441Z" } }, "cell_type": "code", @@ -80,39 +80,39 @@ ], "id": "214241c510d751f4", "outputs": [], - "execution_count": 38 + "execution_count": 70 }, { "metadata": { "ExecuteTime": { - "end_time": "2024-10-22T16:24:24.077994Z", - "start_time": "2024-10-22T16:24:24.058994Z" + "end_time": "2024-10-29T17:19:44.441846Z", + "start_time": "2024-10-29T17:19:44.424298Z" } }, "cell_type": "code", - "source": "simulator = bf.simulators.CompositeLambdaSimulator([theta_prior, forward_model])", + "source": "simulator = bf.make_simulator([theta_prior, forward_model])", "id": "938dc70eb8ba4a54", "outputs": [], - "execution_count": 39 + "execution_count": 71 }, { "metadata": { "ExecuteTime": { - "end_time": 
"2024-10-22T16:24:24.971064Z", - "start_time": "2024-10-22T16:24:24.948063Z" + "end_time": "2024-10-29T17:19:44.829362Z", + "start_time": "2024-10-29T17:19:44.810435Z" } }, "cell_type": "code", "source": "sample_data = simulator.sample((50,))", "id": "931b7f6a77c8401b", "outputs": [], - "execution_count": 40 + "execution_count": 72 }, { "metadata": { "ExecuteTime": { - "end_time": "2024-10-22T16:24:27.263198Z", - "start_time": "2024-10-22T16:24:27.249771Z" + "end_time": "2024-10-29T17:19:45.641792Z", + "start_time": "2024-10-29T17:19:45.631779Z" } }, "cell_type": "code", @@ -139,13 +139,13 @@ ] } ], - "execution_count": 41 + "execution_count": 73 }, { "metadata": { "ExecuteTime": { - "end_time": "2024-10-22T16:25:08.028115Z", - "start_time": "2024-10-22T16:25:08.013014Z" + "end_time": "2024-10-29T17:19:56.765935Z", + "start_time": "2024-10-29T17:19:56.742426Z" } }, "cell_type": "code", @@ -155,139 +155,174 @@ { "data": { "text/plain": [ - "{'theta': array([[ 0.34066415, -0.5138845 , 1.4528089 , -0.49958685],\n", - " [ 0.86015004, 0.48635587, 0.2364767 , -0.53709507],\n", - " [-0.6582664 , -0.1106401 , -0.5822995 , 0.29959023],\n", - " [ 0.17613287, 0.40979308, -1.1803418 , 0.7906092 ],\n", - " [-2.5815134 , -0.5926008 , -1.442977 , -1.1212678 ],\n", - " [ 0.15085632, 0.8538437 , -0.71999 , -0.6779198 ],\n", - " [-0.04969181, 0.45948943, 0.6696255 , 0.9931811 ],\n", - " [ 1.015672 , 0.28774238, 0.18076487, -0.11111598],\n", - " [-0.7719207 , -1.3176122 , 0.5294132 , 0.4176514 ],\n", - " [ 0.03191099, -0.6768063 , 0.5141813 , -1.592261 ],\n", - " [-0.99575025, -1.8044442 , 0.56740063, -1.9281672 ],\n", - " [ 0.81070745, -0.60243636, -0.10667904, -0.3417887 ],\n", - " [ 1.0685011 , -1.3776896 , 1.8168131 , -0.8139481 ],\n", - " [ 0.01134184, 0.02382061, 1.6661643 , -0.46634912],\n", - " [-1.8478132 , -0.08229433, -0.04664409, -0.11284911],\n", - " [ 1.4040706 , -0.67715555, -0.2592975 , -0.20792411],\n", - " [ 2.399644 , 0.89749336, 2.4230204 , 0.0970002 ],\n", - " [ 
0.93040395, 0.25475293, 0.8398071 , 0.29117548],\n", - " [-0.16029291, -0.02478953, 0.29951358, 0.33260188],\n", - " [-0.86853355, -1.1873287 , 1.9413403 , 0.32616952],\n", - " [-0.67677224, 0.02859171, 0.5428518 , -1.5521122 ],\n", - " [ 2.3776774 , -0.6828046 , 0.5556347 , -1.4531173 ],\n", - " [ 0.17357443, -0.45678964, -0.12053017, -0.8963106 ],\n", - " [ 0.20243092, 0.4169088 , 0.4405855 , 0.06946267],\n", - " [-0.4409229 , 0.07481287, -0.82419586, 0.33597344],\n", - " [-1.0189365 , 1.2648267 , -0.84935266, 0.58711445],\n", - " [ 0.8026323 , -0.73901856, -0.3391541 , 0.5761913 ],\n", - " [-0.9766659 , -1.1858828 , -0.5103791 , 0.73724025],\n", - " [ 1.5424025 , 0.28468883, -0.05811996, -1.0388048 ],\n", - " [ 1.6943154 , -0.36717394, 0.37467596, -0.18305473],\n", - " [-1.4355195 , -1.1271007 , 0.98609 , -2.1474707 ],\n", - " [-0.24416563, -0.88949496, -0.83292514, 0.05820413],\n", - " [ 1.0845547 , 0.97108537, 0.18267912, 0.16928157],\n", - " [ 1.5283549 , -0.9298855 , -1.9587208 , 1.4929713 ],\n", - " [ 0.34513605, 0.904506 , 0.46237883, -1.4228871 ],\n", - " [-0.81769085, 0.7091762 , -0.54571545, 0.5346092 ],\n", - " [-1.221732 , -1.3575743 , -1.3833972 , 1.5352001 ],\n", - " [-1.1201226 , -0.11686669, -0.21259853, -0.01677035],\n", - " [-0.1394734 , -0.3124989 , -0.21038432, -0.2977672 ],\n", - " [ 0.41691035, 0.28065392, -0.38032046, 0.95429885],\n", - " [-1.771154 , -1.1321709 , -1.9100127 , 0.5539506 ],\n", - " [-1.447865 , 1.2216287 , 2.154635 , -0.3226352 ],\n", - " [ 0.68915546, 0.41079593, -0.05922764, -2.326437 ],\n", - " [-0.81387454, -1.0814589 , -0.6311428 , -0.16105291],\n", - " [-0.12934463, 0.26514062, -1.6791768 , -0.20046751],\n", - " [-0.19893628, -0.48227343, 0.38067642, -0.8310641 ],\n", - " [ 1.2004272 , 0.0041292 , -0.02631984, 1.3608695 ],\n", - " [ 2.2010703 , 1.2613806 , -1.1433984 , -0.1893912 ],\n", - " [ 0.38409767, -0.2333284 , 0.67292047, 1.7366157 ],\n", - " [-0.22914144, 1.6965197 , 1.2130772 , 0.39718068]],\n", - " 
dtype=float32),\n", - " 'x': array([[[ 3.6254582 , -1.2774805 , 1.829676 , -0.32256916],\n", - " [-1.256944 , 0.3428607 , 1.6378316 , -1.9851911 ],\n", - " [ 0.7485846 , 1.0144739 , 2.0964758 , -0.01522239],\n", + "{'theta': array([[-0.25933109, -0.29311586, -0.43468539, -0.75589758],\n", + " [-0.62746534, 0.26783042, -0.54751926, -0.33454591],\n", + " [ 0.48858301, -0.48960682, 0.41613236, 0.38619082],\n", + " [-1.26044397, 0.28186682, 0.16251423, 0.03970525],\n", + " [-0.81672093, -0.33752783, 0.59320318, 0.06854435],\n", + " [-0.8297982 , -0.89704741, 0.82610742, 0.79129585],\n", + " [-0.27320508, -1.24441094, 1.73280939, 2.74006395],\n", + " [-1.25845926, -0.67339622, -0.51586832, -0.64329048],\n", + " [-1.13386794, -0.46101155, 0.20951863, -0.00562751],\n", + " [-0.22513122, -1.57013183, -0.16388857, -0.43789601],\n", + " [-0.81567616, 1.14158432, -0.46095738, -1.04537264],\n", + " [ 0.27219349, 0.48780616, -0.61371472, -1.40612089],\n", + " [-1.27484604, 0.72096922, -0.08216321, 0.06403711],\n", + " [ 1.01779488, 0.971248 , -0.19348285, -0.5629501 ],\n", + " [ 0.11858625, -0.38555645, -0.35270227, -0.9507192 ],\n", + " [ 0.49935472, 0.52885169, -1.24047814, 0.25604984],\n", + " [ 0.04351305, 1.33752058, 1.34091618, -0.53722738],\n", + " [-1.33391022, -2.78228719, -0.58246235, -0.34395886],\n", + " [-0.85785128, 0.180416 , 1.22235207, 0.33605605],\n", + " [-1.38918174, -1.46010187, -0.09209822, -1.64827012],\n", + " [ 0.01536091, 0.1219928 , -2.29736111, 0.65614956],\n", + " [ 1.80411329, 2.02312911, -0.3190316 , 1.50478783],\n", + " [-0.32357835, -0.65384842, -1.10786669, -0.70311434],\n", + " [ 0.53677133, 0.27758978, -1.93671838, -0.54012384],\n", + " [-2.13069546, -0.12857099, -0.32827032, 0.63181261],\n", + " [ 1.1965332 , -1.15250211, -0.46760947, 2.16295808],\n", + " [ 0.47006833, -0.39309495, 1.53650972, -0.09418776],\n", + " [-0.52887427, 1.55835402, -0.47106973, -1.17590053],\n", + " [-0.38151978, 0.43464099, -0.00713016, -0.05979383],\n", + " [ 
1.71986058, 1.38505429, -0.93926711, -2.06344947],\n", + " [-1.48315744, 0.44258625, -0.62879085, 0.92203648],\n", + " [ 1.19590277, -0.88942841, -0.95436464, 0.58319199],\n", + " [ 0.23751304, -0.33223209, 1.62014829, -1.00444741],\n", + " [ 1.02197763, 0.04611375, 1.60963096, 0.43521423],\n", + " [-0.42665146, 0.2669926 , -1.0226335 , -1.04550126],\n", + " [-0.23524817, -0.51315255, -0.88112403, -1.26276976],\n", + " [ 0.11218564, -0.53560355, -0.28893432, -0.55060254],\n", + " [-0.02889448, 0.87763176, -1.05474197, -1.8619115 ],\n", + " [-1.43372355, -0.96116922, -0.66400162, -1.89558178],\n", + " [-0.01243617, -1.14342978, -1.69764533, -0.62080292],\n", + " [-0.57908435, 1.94704634, -1.35626736, -0.48912263],\n", + " [ 0.33043446, -0.46925831, 0.3484858 , 1.06964219],\n", + " [ 0.44160495, -1.29482916, -2.42269311, -2.02783931],\n", + " [ 0.12393234, -0.35137249, -1.1411141 , 0.40853396],\n", + " [-0.71259773, 0.95532399, 1.26653677, -1.11252735],\n", + " [-0.68099989, 2.12920842, -0.76729667, -1.22271607],\n", + " [-0.09915802, -0.67623208, 0.44114305, -0.03945971],\n", + " [ 0.19181804, 0.65245839, 1.07749326, 0.86936139],\n", + " [-1.73115914, 0.10496796, 2.21345244, -1.98460671],\n", + " [ 0.52023042, 0.93434087, 0.65236837, 0.03921144]]),\n", + " 'x': array([[[-1.26781729e+00, -1.56848527e+00, -1.75007674e+00,\n", + " -1.93042184e-01],\n", + " [-9.37775092e-01, -2.20027558e+00, -7.95896880e-01,\n", + " -1.47081401e+00],\n", + " [-4.61856770e-01, -3.39117216e-01, -7.30478779e-01,\n", + " -3.47666937e-01],\n", " ...,\n", - " [ 0.08928863, -0.51128244, 1.1220746 , 0.7665733 ],\n", - " [ 0.3698297 , -3.188901 , 2.025168 , 0.01316792],\n", - " [-0.0478575 , -1.2997395 , 0.98696446, -1.16682 ]],\n", + " [ 5.80608803e-01, 1.39517252e+00, -1.22429003e+00,\n", + " -1.29332305e+00],\n", + " [ 8.34692166e-01, -7.89882484e-01, -3.93390701e-01,\n", + " 3.24383636e-01],\n", + " [ 1.24818014e+00, -1.05635872e-01, -1.17490469e+00,\n", + " -2.62252489e+00]],\n", " \n", - " 
[[ 2.8832037 , 0.8261143 , -1.1377803 , -1.1128289 ],\n", - " [ 0.9582984 , 0.79203564, -0.3229012 , -1.6079041 ],\n", - " [ 0.16740797, 1.4642444 , -0.06614214, 0.3791665 ],\n", + " [[-1.14491834e+00, 7.14058585e-01, -2.08920723e-01,\n", + " -1.05861866e+00],\n", + " [-1.03280203e-01, 7.35892999e-02, -1.32098640e+00,\n", + " -4.51334284e-01],\n", + " [ 5.55593604e-01, -2.41702128e+00, -1.49961133e+00,\n", + " 6.74848497e-01],\n", " ...,\n", - " [ 1.3970535 , 0.965935 , 0.5211107 , -0.23564771],\n", - " [ 0.92998844, -0.22871257, -1.2391579 , 0.9288718 ],\n", - " [ 0.23330195, -1.122806 , 0.29205167, -0.21602471]],\n", + " [ 2.72805116e-02, -2.43346329e-01, -1.62073085e-01,\n", + " -3.46627566e-01],\n", + " [-1.09300949e+00, -9.60072721e-01, -2.52253395e+00,\n", + " -5.34212639e-01],\n", + " [-3.27940871e-01, 5.38197735e-02, -1.64006570e+00,\n", + " -1.25524962e+00]],\n", " \n", - " [[-2.1422992 , 1.1895313 , -1.2813323 , 0.94614196],\n", - " [-0.6536977 , -1.7648559 , -2.6089973 , -0.25280094],\n", - " [-0.9758994 , -0.09433198, 0.5126696 , -0.4680349 ],\n", + " [[ 4.45699278e-02, -9.64632138e-01, 8.05963387e-01,\n", + " 1.32152877e+00],\n", + " [ 1.68230274e+00, -1.14377222e+00, 4.93747135e-01,\n", + " 6.33770351e-01],\n", + " [ 2.46336846e+00, -7.59719410e-01, 6.19234586e-01,\n", + " -1.06262060e+00],\n", " ...,\n", - " [ 0.6029331 , -0.5393456 , -3.4281068 , 1.9625674 ],\n", - " [-2.7362714 , 1.6726367 , -0.20103195, -0.2370804 ],\n", - " [ 1.6226304 , -0.20690091, 0.6907045 , -0.4412011 ]],\n", + " [-2.21330491e-01, -1.63664977e+00, -4.24202685e-01,\n", + " 2.21401046e+00],\n", + " [-5.92631178e-01, -1.18871354e+00, 1.05690079e+00,\n", + " 1.82679143e-02],\n", + " [-6.52295715e-01, -1.74524166e+00, -1.13048986e-01,\n", + " 2.76003369e-01]],\n", " \n", " ...,\n", " \n", - " [[ 2.5122259 , 2.7933333 , -1.8370881 , -0.21269593],\n", - " [ 0.08193951, 2.768969 , -1.8215785 , -1.3286033 ],\n", - " [ 1.3069335 , 1.077523 , -3.191183 , -0.19069603],\n", + " [[ 
8.90274556e-01, 2.34647339e+00, -3.12386828e-01,\n", + " 5.98836989e-01],\n", + " [ 1.08533152e+00, -6.57631048e-01, 1.08294434e+00,\n", + " 1.43612378e+00],\n", + " [-5.44868476e-01, 5.08563546e-01, 1.72220536e+00,\n", + " 2.73708834e+00],\n", " ...,\n", - " [ 2.0579758 , 1.8073362 , -0.04059944, -0.1998354 ],\n", - " [ 1.7150037 , 1.5609382 , -1.7395236 , -1.3606325 ],\n", - " [ 2.9268007 , 0.23673572, -0.95533824, -1.3200113 ]],\n", + " [-1.26782912e+00, 1.03425503e+00, -7.61075824e-01,\n", + " 2.78565962e-01],\n", + " [-8.92495473e-01, 1.83013938e+00, 1.13817595e+00,\n", + " 1.29987144e+00],\n", + " [ 2.15594110e-01, 1.58127540e+00, 1.19747991e+00,\n", + " 7.24400265e-01]],\n", " \n", - " [[ 1.4436094 , -1.3594241 , 0.415787 , 2.261267 ],\n", - " [ 0.31457698, -0.8279396 , 1.7133617 , 1.7376964 ],\n", - " [-0.72784173, -1.3070168 , 1.0091938 , 2.5164726 ],\n", + " [[-1.14824916e+00, 1.89313212e+00, 8.44448110e-01,\n", + " -1.91060527e+00],\n", + " [-9.58832969e-01, 3.88623976e-01, 3.21674153e+00,\n", + " -2.67463490e+00],\n", + " [-2.29650363e+00, -1.91242140e+00, 1.58249382e+00,\n", + " -1.27508674e+00],\n", " ...,\n", - " [ 0.4194977 , -0.2666566 , 0.8669603 , 3.1416023 ],\n", - " [-0.5483485 , -0.539848 , -0.2195546 , 1.9718776 ],\n", - " [ 0.20176356, 1.069642 , -0.65165126, 2.8493927 ]],\n", + " [-1.73705189e+00, 1.63838962e-01, 1.43653118e+00,\n", + " -2.34070132e+00],\n", + " [-9.01991105e-01, 8.94754497e-01, 1.82770411e+00,\n", + " -2.53140932e-01],\n", + " [-2.30099046e+00, 2.96103008e-01, 9.80372984e-01,\n", + " -5.14524347e-01]],\n", " \n", - " [[-0.6640122 , 1.3560729 , 0.5739129 , -0.85077333],\n", - " [-1.5651604 , 2.4200397 , 1.7220724 , -2.3291683 ],\n", - " [ 0.12706083, 0.7526239 , 0.46398893, 0.5266859 ],\n", + " [[ 2.46427849e-01, 3.94032144e-01, 1.05723707e+00,\n", + " 1.06481307e+00],\n", + " [ 8.04469283e-01, 1.51746135e+00, 1.05348254e+00,\n", + " 1.54013789e+00],\n", + " [-2.61156347e-01, -5.98895418e-01, 1.47485751e+00,\n", + " 
2.36642360e+00],\n", " ...,\n", - " [ 0.28359127, 1.544274 , 1.2267944 , -0.26292163],\n", - " [-0.24361266, 2.2830348 , -0.09784857, -0.17053986],\n", - " [-0.12756453, 0.9381281 , 1.8230177 , 0.8788254 ]]],\n", - " dtype=float32)}" + " [-1.18418617e-01, 1.20249872e-03, 2.39667469e+00,\n", + " 1.57961813e-01],\n", + " [ 2.27457411e+00, 8.33884636e-02, 7.01592201e-01,\n", + " 6.07449893e-01],\n", + " [ 1.51785989e+00, 8.85566442e-01, -1.75853019e-01,\n", + " -7.83206578e-01]]])}" ] }, - "execution_count": 43, + "execution_count": 74, "metadata": {}, "output_type": "execute_result" } ], - "execution_count": 43 + "execution_count": 74 }, { "metadata": { "ExecuteTime": { - "end_time": "2024-10-22T16:25:14.230395Z", - "start_time": "2024-10-22T16:25:14.219237Z" + "end_time": "2024-10-29T17:38:01.874530Z", + "start_time": "2024-10-29T17:38:01.864368Z" } }, "cell_type": "code", "source": [ - "data_adapter = bf.ContinuousApproximator.build_data_adapter(\n", + "data_adapter = bf.ContinuousApproximator.build_adapter(\n", " inference_variables=[\"theta\"],\n", - " inference_conditions=[\"x\"]\n", + " inference_conditions=[\"x\"],\n", + " summary_variables=[\"x\"]\n", ")" ], "id": "b0f547fc9dfec62e", "outputs": [], - "execution_count": 44 + "execution_count": 98 }, { "metadata": { "ExecuteTime": { - "end_time": "2024-10-22T16:25:14.901880Z", - "start_time": "2024-10-22T16:25:14.887223Z" + "end_time": "2024-10-29T17:38:02.624841Z", + "start_time": "2024-10-29T17:38:02.613736Z" } }, "cell_type": "code", @@ -299,13 +334,13 @@ ], "id": "d6a75322b3e87b16", "outputs": [], - "execution_count": 45 + "execution_count": 99 }, { "metadata": { "ExecuteTime": { - "end_time": "2024-10-22T17:39:40.767333Z", - "start_time": "2024-10-22T17:39:27.893610Z" + "end_time": "2024-10-29T17:38:16.384568Z", + "start_time": "2024-10-29T17:38:03.076637Z" } }, "cell_type": "code", @@ -317,104 +352,105 @@ " simulator=simulator, \n", " batch_size=batch_size, \n", " num_batches=num_training_batches, \n", - 
" data_adapter=data_adapter\n", + " adapter=data_adapter\n", ")\n", "\n", "validation_dataset = bf.datasets.OnlineDataset(\n", " simulator=simulator,\n", " batch_size=batch_size,\n", " num_batches=num_validation_batches,\n", - " data_adapter=data_adapter\n", + " adapter=data_adapter\n", ")" ], "id": "f54a245984369b8b", "outputs": [], - "execution_count": 69 + "execution_count": 100 }, { - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-29T17:38:17.553107Z", + "start_time": "2024-10-29T17:38:17.521782Z" + } + }, "cell_type": "code", - "source": [ - "summary_network = bf.networks.DeepSet(summary_dim=10)\n", - "summary_network.build(input_shape=(training_samples['x'].shape))" - ], + "source": "summary_network = bf.networks.DeepSet(summary_dim=10)", "id": "6d219a2947a41c39", "outputs": [], - "execution_count": null + "execution_count": 101 }, { "metadata": { "ExecuteTime": { - "end_time": "2024-10-22T17:33:25.157778Z", - "start_time": "2024-10-22T17:33:25.114152Z" + "end_time": "2024-10-29T17:38:18.051306Z", + "start_time": "2024-10-29T17:38:17.902383Z" + } + }, + "cell_type": "code", + "source": "summary_network.build(input_shape=(50, 100, 4))", + "id": "7969c4bd99a55111", + "outputs": [], + "execution_count": 102 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-29T17:38:18.238719Z", + "start_time": "2024-10-29T17:38:18.224166Z" } }, "cell_type": "code", "source": [ "inference_network = bf.networks.FlowMatching(\n", " subnet=\"mlp\",\n", - " subnet_kwargs=dict(\n", + " optimal_transport_kwargs=dict(\n", " depth=6,\n", " width=256,\n", " ),\n", ")\n", - "inference_network.build()" + "# inference_network.build()" ], "id": "ecc20e920b0dc330", "outputs": [], - "execution_count": 61 + "execution_count": 103 }, { "metadata": { "ExecuteTime": { - "end_time": "2024-10-22T16:29:39.431488Z", - "start_time": "2024-10-22T16:29:39.361352Z" + "end_time": "2024-10-29T17:38:18.994085Z", + "start_time": "2024-10-29T17:38:18.975450Z" } }, 
"cell_type": "code", - "source": [ - "test_sim = simulator.sample((4,))\n", - "z, log_det_J = summary_network(test_sim['x'])" - ], + "source": "test_sim = simulator.sample((4,))", "id": "2d182d111fdacf3b", - "outputs": [ - { - "ename": "ValueError", - "evalue": "too many values to unpack (expected 2)", - "output_type": "error", - "traceback": [ - "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m", - "\u001B[1;31mValueError\u001B[0m Traceback (most recent call last)", - "Cell \u001B[1;32mIn[58], line 2\u001B[0m\n\u001B[0;32m 1\u001B[0m test_sim \u001B[38;5;241m=\u001B[39m simulator\u001B[38;5;241m.\u001B[39msample((\u001B[38;5;241m4\u001B[39m,))\n\u001B[1;32m----> 2\u001B[0m z, log_det_J \u001B[38;5;241m=\u001B[39m summary_network(test_sim[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mx\u001B[39m\u001B[38;5;124m'\u001B[39m])\n", - "\u001B[1;31mValueError\u001B[0m: too many values to unpack (expected 2)" - ] - } - ], - "execution_count": 58 + "outputs": [], + "execution_count": 104 }, { "metadata": { "ExecuteTime": { - "end_time": "2024-10-22T17:33:32.494258Z", - "start_time": "2024-10-22T17:33:32.477712Z" + "end_time": "2024-10-29T17:38:19.521845Z", + "start_time": "2024-10-29T17:38:19.507238Z" } }, "cell_type": "code", "source": [ "approximator = bf.ContinuousApproximator(\n", " inference_network=inference_network,\n", - " data_adapter=data_adapter,\n", + " summary_network=summary_network,\n", + " adapter=data_adapter,\n", ")" ], "id": "a3b83230f640d6d9", "outputs": [], - "execution_count": 62 + "execution_count": 105 }, { "metadata": { "ExecuteTime": { - "end_time": "2024-10-22T17:33:33.040002Z", - "start_time": "2024-10-22T17:33:33.017491Z" + "end_time": "2024-10-29T17:38:20.772249Z", + "start_time": "2024-10-29T17:38:20.752704Z" } }, "cell_type": "code", @@ -424,13 +460,13 @@ ], "id": "f0c0c672f6667945", "outputs": [], - "execution_count": 63 + "execution_count": 106 }, { "metadata": { "ExecuteTime": { - "end_time": 
"2024-10-22T17:33:33.566726Z", - "start_time": "2024-10-22T17:33:33.550196Z" + "end_time": "2024-10-29T17:38:21.069937Z", + "start_time": "2024-10-29T17:38:21.056929Z" } }, "cell_type": "code", @@ -443,7 +479,6 @@ " \"validation_loss\": [],\n", " }\n", "\n", - "\n", " def on_train_batch_end(self, batch, logs=None):\n", " # 'logs' is a dictionary containing loss and other metrics\n", " training_loss = logs.get('loss')\n", @@ -456,39 +491,39 @@ ], "id": "359d6e9fe112d405", "outputs": [], - "execution_count": 64 + "execution_count": 107 }, { "metadata": { "ExecuteTime": { - "end_time": "2024-10-22T17:33:34.490293Z", - "start_time": "2024-10-22T17:33:34.464277Z" + "end_time": "2024-10-29T17:38:21.975173Z", + "start_time": "2024-10-29T17:38:21.967175Z" } }, "cell_type": "code", "source": "approximator.compile(optimizer=optimizer)", "id": "7b96a6c3943dcf40", "outputs": [], - "execution_count": 65 + "execution_count": 108 }, { "metadata": { "ExecuteTime": { - "end_time": "2024-10-22T17:33:34.834571Z", - "start_time": "2024-10-22T17:33:34.821630Z" + "end_time": "2024-10-29T17:38:22.848536Z", + "start_time": "2024-10-29T17:38:22.841418Z" } }, "cell_type": "code", "source": "batch_loss_history = BatchLossHistory()", "id": "e683fe5d365b279e", "outputs": [], - "execution_count": 66 + "execution_count": 109 }, { "metadata": { "ExecuteTime": { - "end_time": "2024-10-22T17:33:36.209593Z", - "start_time": "2024-10-22T17:33:35.778861Z" + "end_time": "2024-10-29T17:38:23.600932Z", + "start_time": "2024-10-29T17:38:23.482465Z" } }, "cell_type": "code", @@ -497,7 +532,7 @@ " epochs=10,\n", " dataset=training_dataset,\n", " validation_data=validation_dataset,\n", - " callbacks=[batch_loss_history]\n", + " callbacks=[batch_loss_history],\n", ")" ], "id": "768ee6ac6ce0ef37", @@ -511,41 +546,31 @@ ] }, { - "ename": "TypeError", - "evalue": "Cannot concatenate arrays with different numbers of dimensions: got (128, 4), (128, 1), (128, 100, 4).", + "ename": "KeyError", + "evalue": "\"Missing 
keys: {'x'}\"", "output_type": "error", "traceback": [ "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m", - "\u001B[1;31mTypeError\u001B[0m Traceback (most recent call last)", - "Cell \u001B[1;32mIn[67], line 1\u001B[0m\n\u001B[1;32m----> 1\u001B[0m history \u001B[38;5;241m=\u001B[39m \u001B[43mapproximator\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mfit\u001B[49m\u001B[43m(\u001B[49m\n\u001B[0;32m 2\u001B[0m \u001B[43m \u001B[49m\u001B[43mepochs\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m10\u001B[39;49m\u001B[43m,\u001B[49m\n\u001B[0;32m 3\u001B[0m \u001B[43m \u001B[49m\u001B[43mdataset\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mtraining_dataset\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 4\u001B[0m \u001B[43m \u001B[49m\u001B[43mvalidation_data\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mvalidation_dataset\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 5\u001B[0m \u001B[43m \u001B[49m\u001B[43mcallbacks\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43m[\u001B[49m\u001B[43mbatch_loss_history\u001B[49m\u001B[43m]\u001B[49m\n\u001B[0;32m 6\u001B[0m \u001B[43m)\u001B[49m\n", - "File \u001B[1;32m~\\Documents\\Native\\Development\\BayesFlow\\bayesflow\\approximators\\continuous_approximator.py:109\u001B[0m, in \u001B[0;36mContinuousApproximator.fit\u001B[1;34m(self, *args, **kwargs)\u001B[0m\n\u001B[0;32m 108\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mfit\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[1;32m--> 109\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28msuper\u001B[39m()\u001B[38;5;241m.\u001B[39mfit(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs, data_adapter\u001B[38;5;241m=\u001B[39m\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mdata_adapter)\n", - "File 
\u001B[1;32m~\\Documents\\Native\\Development\\BayesFlow\\bayesflow\\approximators\\approximator.py:82\u001B[0m, in \u001B[0;36mApproximator.fit\u001B[1;34m(self, dataset, simulator, **kwargs)\u001B[0m\n\u001B[0;32m 80\u001B[0m mock_data \u001B[38;5;241m=\u001B[39m dataset[\u001B[38;5;241m0\u001B[39m]\n\u001B[0;32m 81\u001B[0m mock_data \u001B[38;5;241m=\u001B[39m keras\u001B[38;5;241m.\u001B[39mtree\u001B[38;5;241m.\u001B[39mmap_structure(keras\u001B[38;5;241m.\u001B[39mops\u001B[38;5;241m.\u001B[39mconvert_to_tensor, mock_data)\n\u001B[1;32m---> 82\u001B[0m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbuild_from_data\u001B[49m\u001B[43m(\u001B[49m\u001B[43mmock_data\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 84\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28msuper\u001B[39m()\u001B[38;5;241m.\u001B[39mfit(dataset\u001B[38;5;241m=\u001B[39mdataset, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n", - "File \u001B[1;32m~\\Documents\\Native\\Development\\BayesFlow\\bayesflow\\approximators\\approximator.py:23\u001B[0m, in \u001B[0;36mApproximator.build_from_data\u001B[1;34m(self, data)\u001B[0m\n\u001B[0;32m 22\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mbuild_from_data\u001B[39m(\u001B[38;5;28mself\u001B[39m, data: \u001B[38;5;28mdict\u001B[39m[\u001B[38;5;28mstr\u001B[39m, \u001B[38;5;28many\u001B[39m]) \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[1;32m---> 23\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mcompute_metrics(\u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mdata, stage\u001B[38;5;241m=\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mtraining\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[0;32m 24\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mbuilt \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mTrue\u001B[39;00m\n", - "File 
\u001B[1;32m~\\Documents\\Native\\Development\\BayesFlow\\bayesflow\\approximators\\continuous_approximator.py:95\u001B[0m, in \u001B[0;36mContinuousApproximator.compute_metrics\u001B[1;34m(self, inference_variables, inference_conditions, summary_variables, stage)\u001B[0m\n\u001B[0;32m 92\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[0;32m 93\u001B[0m inference_conditions \u001B[38;5;241m=\u001B[39m keras\u001B[38;5;241m.\u001B[39mops\u001B[38;5;241m.\u001B[39mconcatenate([inference_conditions, summary_outputs], axis\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m)\n\u001B[1;32m---> 95\u001B[0m inference_metrics \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43minference_network\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mcompute_metrics\u001B[49m\u001B[43m(\u001B[49m\n\u001B[0;32m 96\u001B[0m \u001B[43m \u001B[49m\u001B[43minference_variables\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mconditions\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43minference_conditions\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mstage\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mstage\u001B[49m\n\u001B[0;32m 97\u001B[0m \u001B[43m\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 99\u001B[0m loss \u001B[38;5;241m=\u001B[39m inference_metrics\u001B[38;5;241m.\u001B[39mget(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mloss\u001B[39m\u001B[38;5;124m\"\u001B[39m, keras\u001B[38;5;241m.\u001B[39mops\u001B[38;5;241m.\u001B[39mzeros(())) \u001B[38;5;241m+\u001B[39m summary_metrics\u001B[38;5;241m.\u001B[39mget(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mloss\u001B[39m\u001B[38;5;124m\"\u001B[39m, keras\u001B[38;5;241m.\u001B[39mops\u001B[38;5;241m.\u001B[39mzeros(()))\n\u001B[0;32m 101\u001B[0m inference_metrics \u001B[38;5;241m=\u001B[39m 
{\u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mkey\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m/inference_\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mkey\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m\"\u001B[39m: value \u001B[38;5;28;01mfor\u001B[39;00m key, value \u001B[38;5;129;01min\u001B[39;00m inference_metrics\u001B[38;5;241m.\u001B[39mitems()}\n", - "File \u001B[1;32m~\\Documents\\Native\\Development\\BayesFlow\\bayesflow\\networks\\flow_matching\\flow_matching.py:122\u001B[0m, in \u001B[0;36mFlowMatching.compute_metrics\u001B[1;34m(self, x, conditions, stage)\u001B[0m\n\u001B[0;32m 118\u001B[0m target_velocity \u001B[38;5;241m=\u001B[39m x1 \u001B[38;5;241m-\u001B[39m x0\n\u001B[0;32m 120\u001B[0m base_metrics \u001B[38;5;241m=\u001B[39m \u001B[38;5;28msuper\u001B[39m()\u001B[38;5;241m.\u001B[39mcompute_metrics(x1, conditions, stage)\n\u001B[1;32m--> 122\u001B[0m predicted_velocity \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mintegrator\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mvelocity\u001B[49m\u001B[43m(\u001B[49m\u001B[43mx\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mt\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mconditions\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 124\u001B[0m loss \u001B[38;5;241m=\u001B[39m keras\u001B[38;5;241m.\u001B[39mlosses\u001B[38;5;241m.\u001B[39mmean_squared_error(target_velocity, predicted_velocity)\n\u001B[0;32m 125\u001B[0m loss \u001B[38;5;241m=\u001B[39m keras\u001B[38;5;241m.\u001B[39mops\u001B[38;5;241m.\u001B[39mmean(loss)\n", - "File \u001B[1;32m~\\Documents\\Native\\Development\\BayesFlow\\bayesflow\\networks\\flow_matching\\integrators\\euler.py:45\u001B[0m, in \u001B[0;36mEulerIntegrator.velocity\u001B[1;34m(self, x, t, conditions, **kwargs)\u001B[0m\n\u001B[0;32m 43\u001B[0m xtc \u001B[38;5;241m=\u001B[39m 
keras\u001B[38;5;241m.\u001B[39mops\u001B[38;5;241m.\u001B[39mconcatenate([x, t], axis\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m)\n\u001B[0;32m 44\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m---> 45\u001B[0m xtc \u001B[38;5;241m=\u001B[39m \u001B[43mkeras\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mops\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mconcatenate\u001B[49m\u001B[43m(\u001B[49m\u001B[43m[\u001B[49m\u001B[43mx\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mt\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mconditions\u001B[49m\u001B[43m]\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43maxis\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m-\u001B[39;49m\u001B[38;5;241;43m1\u001B[39;49m\u001B[43m)\u001B[49m\n\u001B[0;32m 47\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39moutput_projector(\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39msubnet(xtc, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs))\n", - "File \u001B[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\keras\\src\\ops\\numpy.py:1352\u001B[0m, in \u001B[0;36mconcatenate\u001B[1;34m(xs, axis)\u001B[0m\n\u001B[0;32m 1350\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m any_symbolic_tensors(xs):\n\u001B[0;32m 1351\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m Concatenate(axis\u001B[38;5;241m=\u001B[39maxis)\u001B[38;5;241m.\u001B[39msymbolic_call(xs)\n\u001B[1;32m-> 1352\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mbackend\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mnumpy\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mconcatenate\u001B[49m\u001B[43m(\u001B[49m\u001B[43mxs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m 
\u001B[49m\u001B[43maxis\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43maxis\u001B[49m\u001B[43m)\u001B[49m\n", - "File \u001B[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\keras\\src\\backend\\jax\\numpy.py:405\u001B[0m, in \u001B[0;36mconcatenate\u001B[1;34m(xs, axis)\u001B[0m\n\u001B[0;32m 400\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[0;32m 401\u001B[0m xs \u001B[38;5;241m=\u001B[39m [\n\u001B[0;32m 402\u001B[0m x\u001B[38;5;241m.\u001B[39mtodense() \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28misinstance\u001B[39m(x, jax_sparse\u001B[38;5;241m.\u001B[39mJAXSparse) \u001B[38;5;28;01melse\u001B[39;00m x\n\u001B[0;32m 403\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m x \u001B[38;5;129;01min\u001B[39;00m xs\n\u001B[0;32m 404\u001B[0m ]\n\u001B[1;32m--> 405\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mjnp\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mconcatenate\u001B[49m\u001B[43m(\u001B[49m\u001B[43mxs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43maxis\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43maxis\u001B[49m\u001B[43m)\u001B[49m\n", - "File \u001B[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\jax\\_src\\numpy\\lax_numpy.py:4243\u001B[0m, in \u001B[0;36mconcatenate\u001B[1;34m(arrays, axis, dtype)\u001B[0m\n\u001B[0;32m 4241\u001B[0m k \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m16\u001B[39m\n\u001B[0;32m 4242\u001B[0m \u001B[38;5;28;01mwhile\u001B[39;00m \u001B[38;5;28mlen\u001B[39m(arrays_out) \u001B[38;5;241m>\u001B[39m \u001B[38;5;241m1\u001B[39m:\n\u001B[1;32m-> 4243\u001B[0m arrays_out \u001B[38;5;241m=\u001B[39m [lax\u001B[38;5;241m.\u001B[39mconcatenate(arrays_out[i:i\u001B[38;5;241m+\u001B[39mk], axis)\n\u001B[0;32m 4244\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m i \u001B[38;5;129;01min\u001B[39;00m 
\u001B[38;5;28mrange\u001B[39m(\u001B[38;5;241m0\u001B[39m, \u001B[38;5;28mlen\u001B[39m(arrays_out), k)]\n\u001B[0;32m 4245\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m arrays_out[\u001B[38;5;241m0\u001B[39m]\n", - "File \u001B[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\jax\\_src\\numpy\\lax_numpy.py:4243\u001B[0m, in \u001B[0;36m\u001B[1;34m(.0)\u001B[0m\n\u001B[0;32m 4241\u001B[0m k \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m16\u001B[39m\n\u001B[0;32m 4242\u001B[0m \u001B[38;5;28;01mwhile\u001B[39;00m \u001B[38;5;28mlen\u001B[39m(arrays_out) \u001B[38;5;241m>\u001B[39m \u001B[38;5;241m1\u001B[39m:\n\u001B[1;32m-> 4243\u001B[0m arrays_out \u001B[38;5;241m=\u001B[39m [\u001B[43mlax\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mconcatenate\u001B[49m\u001B[43m(\u001B[49m\u001B[43marrays_out\u001B[49m\u001B[43m[\u001B[49m\u001B[43mi\u001B[49m\u001B[43m:\u001B[49m\u001B[43mi\u001B[49m\u001B[38;5;241;43m+\u001B[39;49m\u001B[43mk\u001B[49m\u001B[43m]\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43maxis\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 4244\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m i \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mrange\u001B[39m(\u001B[38;5;241m0\u001B[39m, \u001B[38;5;28mlen\u001B[39m(arrays_out), k)]\n\u001B[0;32m 4245\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m arrays_out[\u001B[38;5;241m0\u001B[39m]\n", - "File \u001B[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\jax\\_src\\lax\\lax.py:650\u001B[0m, in \u001B[0;36mconcatenate\u001B[1;34m(operands, dimension)\u001B[0m\n\u001B[0;32m 648\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28misinstance\u001B[39m(op, Array):\n\u001B[0;32m 649\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m op\n\u001B[1;32m--> 650\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m 
\u001B[43mconcatenate_p\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbind\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43moperands\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdimension\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mdimension\u001B[49m\u001B[43m)\u001B[49m\n", - "File \u001B[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\jax\\_src\\core.py:438\u001B[0m, in \u001B[0;36mPrimitive.bind\u001B[1;34m(self, *args, **params)\u001B[0m\n\u001B[0;32m 435\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mbind\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mparams):\n\u001B[0;32m 436\u001B[0m \u001B[38;5;28;01massert\u001B[39;00m (\u001B[38;5;129;01mnot\u001B[39;00m config\u001B[38;5;241m.\u001B[39menable_checks\u001B[38;5;241m.\u001B[39mvalue \u001B[38;5;129;01mor\u001B[39;00m\n\u001B[0;32m 437\u001B[0m \u001B[38;5;28mall\u001B[39m(\u001B[38;5;28misinstance\u001B[39m(arg, Tracer) \u001B[38;5;129;01mor\u001B[39;00m valid_jaxtype(arg) \u001B[38;5;28;01mfor\u001B[39;00m arg \u001B[38;5;129;01min\u001B[39;00m args)), args\n\u001B[1;32m--> 438\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbind_with_trace\u001B[49m\u001B[43m(\u001B[49m\u001B[43mfind_top_trace\u001B[49m\u001B[43m(\u001B[49m\u001B[43margs\u001B[49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mparams\u001B[49m\u001B[43m)\u001B[49m\n", - "File \u001B[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\jax\\_src\\core.py:442\u001B[0m, in \u001B[0;36mPrimitive.bind_with_trace\u001B[1;34m(self, trace, args, 
params)\u001B[0m\n\u001B[0;32m 440\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mbind_with_trace\u001B[39m(\u001B[38;5;28mself\u001B[39m, trace, args, params):\n\u001B[0;32m 441\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m pop_level(trace\u001B[38;5;241m.\u001B[39mlevel):\n\u001B[1;32m--> 442\u001B[0m out \u001B[38;5;241m=\u001B[39m \u001B[43mtrace\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mprocess_primitive\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mmap\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43mtrace\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mfull_raise\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43margs\u001B[49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mparams\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 443\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mmap\u001B[39m(full_lower, out) \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mmultiple_results \u001B[38;5;28;01melse\u001B[39;00m full_lower(out)\n", - "File \u001B[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\jax\\_src\\core.py:948\u001B[0m, in \u001B[0;36mEvalTrace.process_primitive\u001B[1;34m(self, primitive, tracers, params)\u001B[0m\n\u001B[0;32m 946\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m call_impl_with_key_reuse_checks(primitive, primitive\u001B[38;5;241m.\u001B[39mimpl, \u001B[38;5;241m*\u001B[39mtracers, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mparams)\n\u001B[0;32m 947\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m--> 948\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m primitive\u001B[38;5;241m.\u001B[39mimpl(\u001B[38;5;241m*\u001B[39mtracers, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mparams)\n", - "File 
\u001B[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\jax\\_src\\dispatch.py:90\u001B[0m, in \u001B[0;36mapply_primitive\u001B[1;34m(prim, *args, **params)\u001B[0m\n\u001B[0;32m 88\u001B[0m prev \u001B[38;5;241m=\u001B[39m lib\u001B[38;5;241m.\u001B[39mjax_jit\u001B[38;5;241m.\u001B[39mswap_thread_local_state_disable_jit(\u001B[38;5;28;01mFalse\u001B[39;00m)\n\u001B[0;32m 89\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[1;32m---> 90\u001B[0m outs \u001B[38;5;241m=\u001B[39m \u001B[43mfun\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 91\u001B[0m \u001B[38;5;28;01mfinally\u001B[39;00m:\n\u001B[0;32m 92\u001B[0m lib\u001B[38;5;241m.\u001B[39mjax_jit\u001B[38;5;241m.\u001B[39mswap_thread_local_state_disable_jit(prev)\n", - " \u001B[1;31m[... skipping hidden 18 frame]\u001B[0m\n", - "File \u001B[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\jax\\_src\\lax\\lax.py:3904\u001B[0m, in \u001B[0;36m_concatenate_shape_rule\u001B[1;34m(*operands, **kwargs)\u001B[0m\n\u001B[0;32m 3902\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mlen\u001B[39m({operand\u001B[38;5;241m.\u001B[39mndim \u001B[38;5;28;01mfor\u001B[39;00m operand \u001B[38;5;129;01min\u001B[39;00m operands}) \u001B[38;5;241m!=\u001B[39m \u001B[38;5;241m1\u001B[39m:\n\u001B[0;32m 3903\u001B[0m msg \u001B[38;5;241m=\u001B[39m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mCannot concatenate arrays with different numbers of dimensions: got \u001B[39m\u001B[38;5;132;01m{}\u001B[39;00m\u001B[38;5;124m.\u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[1;32m-> 3904\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mTypeError\u001B[39;00m(msg\u001B[38;5;241m.\u001B[39mformat(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m, 
\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;241m.\u001B[39mjoin(\u001B[38;5;28mstr\u001B[39m(o\u001B[38;5;241m.\u001B[39mshape) \u001B[38;5;28;01mfor\u001B[39;00m o \u001B[38;5;129;01min\u001B[39;00m operands)))\n\u001B[0;32m 3905\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;241m0\u001B[39m \u001B[38;5;241m<\u001B[39m\u001B[38;5;241m=\u001B[39m dimension \u001B[38;5;241m<\u001B[39m operands[\u001B[38;5;241m0\u001B[39m]\u001B[38;5;241m.\u001B[39mndim:\n\u001B[0;32m 3906\u001B[0m msg \u001B[38;5;241m=\u001B[39m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mconcatenate dimension out of bounds: dimension \u001B[39m\u001B[38;5;132;01m{}\u001B[39;00m\u001B[38;5;124m for shapes \u001B[39m\u001B[38;5;132;01m{}\u001B[39;00m\u001B[38;5;124m.\u001B[39m\u001B[38;5;124m\"\u001B[39m\n", - "\u001B[1;31mTypeError\u001B[0m: Cannot concatenate arrays with different numbers of dimensions: got (128, 4), (128, 1), (128, 100, 4)." + "\u001B[1;31mKeyError\u001B[0m Traceback (most recent call last)", + "Cell \u001B[1;32mIn[110], line 1\u001B[0m\n\u001B[1;32m----> 1\u001B[0m history \u001B[38;5;241m=\u001B[39m \u001B[43mapproximator\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mfit\u001B[49m\u001B[43m(\u001B[49m\n\u001B[0;32m 2\u001B[0m \u001B[43m \u001B[49m\u001B[43mepochs\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m10\u001B[39;49m\u001B[43m,\u001B[49m\n\u001B[0;32m 3\u001B[0m \u001B[43m \u001B[49m\u001B[43mdataset\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mtraining_dataset\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 4\u001B[0m \u001B[43m \u001B[49m\u001B[43mvalidation_data\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mvalidation_dataset\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 5\u001B[0m \u001B[43m \u001B[49m\u001B[43mcallbacks\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43m[\u001B[49m\u001B[43mbatch_loss_history\u001B[49m\u001B[43m]\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 6\u001B[0m 
\u001B[43m)\u001B[49m\n", + "File \u001B[1;32m~\\Documents\\Native\\Development\\BayesFlow\\bayesflow\\approximators\\continuous_approximator.py:114\u001B[0m, in \u001B[0;36mContinuousApproximator.fit\u001B[1;34m(self, *args, **kwargs)\u001B[0m\n\u001B[0;32m 113\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mfit\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[1;32m--> 114\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28msuper\u001B[39m()\u001B[38;5;241m.\u001B[39mfit(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs, adapter\u001B[38;5;241m=\u001B[39m\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39madapter)\n", + "File \u001B[1;32m~\\Documents\\Native\\Development\\BayesFlow\\bayesflow\\approximators\\approximator.py:80\u001B[0m, in \u001B[0;36mApproximator.fit\u001B[1;34m(self, dataset, simulator, **kwargs)\u001B[0m\n\u001B[0;32m 78\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mbuilt:\n\u001B[0;32m 79\u001B[0m logging\u001B[38;5;241m.\u001B[39minfo(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mBuilding on a test batch.\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[1;32m---> 80\u001B[0m mock_data \u001B[38;5;241m=\u001B[39m \u001B[43mdataset\u001B[49m\u001B[43m[\u001B[49m\u001B[38;5;241;43m0\u001B[39;49m\u001B[43m]\u001B[49m\n\u001B[0;32m 81\u001B[0m mock_data \u001B[38;5;241m=\u001B[39m keras\u001B[38;5;241m.\u001B[39mtree\u001B[38;5;241m.\u001B[39mmap_structure(keras\u001B[38;5;241m.\u001B[39mops\u001B[38;5;241m.\u001B[39mconvert_to_tensor, mock_data)\n\u001B[0;32m 82\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mbuild_from_data(mock_data)\n", + "File \u001B[1;32m~\\Documents\\Native\\Development\\BayesFlow\\bayesflow\\datasets\\online_dataset.py:38\u001B[0m, in 
\u001B[0;36mOnlineDataset.__getitem__\u001B[1;34m(self, item)\u001B[0m\n\u001B[0;32m 35\u001B[0m batch \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39msimulator\u001B[38;5;241m.\u001B[39msample((\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mbatch_size,))\n\u001B[0;32m 37\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39madapter \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[1;32m---> 38\u001B[0m batch \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43madapter\u001B[49m\u001B[43m(\u001B[49m\u001B[43mbatch\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mbatch_size\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbatch_size\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 40\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m batch\n", + "File \u001B[1;32m~\\Documents\\Native\\Development\\BayesFlow\\bayesflow\\adapters\\adapter.py:73\u001B[0m, in \u001B[0;36mAdapter.__call__\u001B[1;34m(self, data, inverse, **kwargs)\u001B[0m\n\u001B[0;32m 70\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m inverse:\n\u001B[0;32m 71\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39minverse(data, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n\u001B[1;32m---> 73\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mforward(data, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n", + "File \u001B[1;32m~\\Documents\\Native\\Development\\BayesFlow\\bayesflow\\adapters\\adapter.py:57\u001B[0m, in \u001B[0;36mAdapter.forward\u001B[1;34m(self, data, **kwargs)\u001B[0m\n\u001B[0;32m 54\u001B[0m data \u001B[38;5;241m=\u001B[39m data\u001B[38;5;241m.\u001B[39mcopy()\n\u001B[0;32m 
56\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m transform \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mtransforms:\n\u001B[1;32m---> 57\u001B[0m data \u001B[38;5;241m=\u001B[39m transform(data, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n\u001B[0;32m 59\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m data\n", + "File \u001B[1;32m~\\Documents\\Native\\Development\\BayesFlow\\bayesflow\\adapters\\transforms\\transform.py:11\u001B[0m, in \u001B[0;36mTransform.__call__\u001B[1;34m(self, data, inverse, **kwargs)\u001B[0m\n\u001B[0;32m 8\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m inverse:\n\u001B[0;32m 9\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39minverse(data, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n\u001B[1;32m---> 11\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mforward(data, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n", + "File \u001B[1;32m~\\Documents\\Native\\Development\\BayesFlow\\bayesflow\\adapters\\transforms\\map_transform.py:51\u001B[0m, in \u001B[0;36mMapTransform.forward\u001B[1;34m(self, data, strict, **kwargs)\u001B[0m\n\u001B[0;32m 48\u001B[0m missing_keys \u001B[38;5;241m=\u001B[39m required_keys \u001B[38;5;241m-\u001B[39m available_keys\n\u001B[0;32m 50\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m strict \u001B[38;5;129;01mand\u001B[39;00m missing_keys:\n\u001B[1;32m---> 51\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mKeyError\u001B[39;00m(\u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mMissing keys: \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mmissing_keys\u001B[38;5;132;01m!r}\u001B[39;00m\u001B[38;5;124m\"\u001B[39m)\n\u001B[0;32m 53\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m key, transform \u001B[38;5;129;01min\u001B[39;00m 
\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mtransform_map\u001B[38;5;241m.\u001B[39mitems():\n\u001B[0;32m 54\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m key \u001B[38;5;129;01min\u001B[39;00m data:\n", + "\u001B[1;31mKeyError\u001B[0m: \"Missing keys: {'x'}\"" ] } ], - "execution_count": 67 + "execution_count": 110 }, { "metadata": { "ExecuteTime": { - "end_time": "2024-10-20T18:45:43.804570Z", - "start_time": "2024-10-20T18:45:42.197334Z" + "end_time": "2024-10-29T17:37:22.721463Z", + "start_time": "2024-10-29T17:37:22.075196Z" } }, "cell_type": "code", @@ -559,18 +584,34 @@ ], "id": "4aa8f4aa440e9925", "outputs": [ + { + "ename": "ValueError", + "evalue": "Number of rows must be a positive integer, not 0", + "output_type": "error", + "traceback": [ + "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m", + "\u001B[1;31mValueError\u001B[0m Traceback (most recent call last)", + "Cell \u001B[1;32mIn[97], line 3\u001B[0m\n\u001B[0;32m 1\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mpandas\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m \u001B[38;5;21;01mpd\u001B[39;00m\n\u001B[1;32m----> 3\u001B[0m f \u001B[38;5;241m=\u001B[39m \u001B[43mplot_losses\u001B[49m\u001B[43m(\u001B[49m\n\u001B[0;32m 4\u001B[0m \u001B[43m \u001B[49m\u001B[43mtrain_losses\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mpd\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mDataFrame\u001B[49m\u001B[43m(\u001B[49m\u001B[43mbatch_loss_history\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mlosses\u001B[49m\u001B[43m[\u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mtraining_loss\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m]\u001B[49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\n\u001B[0;32m 5\u001B[0m \u001B[43m 
\u001B[49m\u001B[43mval_losses\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mpd\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mDataFrame\u001B[49m\u001B[43m(\u001B[49m\u001B[43mbatch_loss_history\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mlosses\u001B[49m\u001B[43m[\u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mvalidation_loss\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m]\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 6\u001B[0m \u001B[43m)\u001B[49m\n", + "File \u001B[1;32m~\\Documents\\Native\\Development\\BayesFlow\\bayesflow\\diagnostics\\plot_losses.py:78\u001B[0m, in \u001B[0;36mplot_losses\u001B[1;34m(train_losses, val_losses, moving_average, ma_window_fraction, fig_size, train_color, val_color, lw_train, lw_val, grid_alpha, legend_fontsize, label_fontsize, title_fontsize)\u001B[0m\n\u001B[0;32m 76\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m fig_size \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[0;32m 77\u001B[0m fig_size \u001B[38;5;241m=\u001B[39m (\u001B[38;5;241m16\u001B[39m, \u001B[38;5;28mint\u001B[39m(\u001B[38;5;241m4\u001B[39m \u001B[38;5;241m*\u001B[39m n_row))\n\u001B[1;32m---> 78\u001B[0m f, axarr \u001B[38;5;241m=\u001B[39m \u001B[43mplt\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43msubplots\u001B[49m\u001B[43m(\u001B[49m\u001B[43mn_row\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m1\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mfigsize\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mfig_size\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 80\u001B[0m \u001B[38;5;66;03m# Get the number of steps as an array\u001B[39;00m\n\u001B[0;32m 81\u001B[0m train_step_index \u001B[38;5;241m=\u001B[39m np\u001B[38;5;241m.\u001B[39marange(\u001B[38;5;241m1\u001B[39m, \u001B[38;5;28mlen\u001B[39m(train_losses) \u001B[38;5;241m+\u001B[39m \u001B[38;5;241m1\u001B[39m)\n", + "File 
\u001B[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\matplotlib\\pyplot.py:1502\u001B[0m, in \u001B[0;36msubplots\u001B[1;34m(nrows, ncols, sharex, sharey, squeeze, width_ratios, height_ratios, subplot_kw, gridspec_kw, **fig_kw)\u001B[0m\n\u001B[0;32m 1358\u001B[0m \u001B[38;5;250m\u001B[39m\u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[0;32m 1359\u001B[0m \u001B[38;5;124;03mCreate a figure and a set of subplots.\u001B[39;00m\n\u001B[0;32m 1360\u001B[0m \n\u001B[1;32m (...)\u001B[0m\n\u001B[0;32m 1499\u001B[0m \n\u001B[0;32m 1500\u001B[0m \u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[0;32m 1501\u001B[0m fig \u001B[38;5;241m=\u001B[39m figure(\u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mfig_kw)\n\u001B[1;32m-> 1502\u001B[0m axs \u001B[38;5;241m=\u001B[39m \u001B[43mfig\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43msubplots\u001B[49m\u001B[43m(\u001B[49m\u001B[43mnrows\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mnrows\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mncols\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mncols\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43msharex\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43msharex\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43msharey\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43msharey\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 1503\u001B[0m \u001B[43m \u001B[49m\u001B[43msqueeze\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43msqueeze\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43msubplot_kw\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43msubplot_kw\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 1504\u001B[0m \u001B[43m \u001B[49m\u001B[43mgridspec_kw\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mgridspec_kw\u001B[49m\u001B[43m,\u001B[49m\u001B[43m 
\u001B[49m\u001B[43mheight_ratios\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mheight_ratios\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 1505\u001B[0m \u001B[43m \u001B[49m\u001B[43mwidth_ratios\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mwidth_ratios\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 1506\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m fig, axs\n", + "File \u001B[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\matplotlib\\figure.py:905\u001B[0m, in \u001B[0;36mFigureBase.subplots\u001B[1;34m(self, nrows, ncols, sharex, sharey, squeeze, width_ratios, height_ratios, subplot_kw, gridspec_kw)\u001B[0m\n\u001B[0;32m 901\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mwidth_ratios\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124m must not be defined both as \u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m 902\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mparameter and as key in \u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mgridspec_kw\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[0;32m 903\u001B[0m gridspec_kw[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mwidth_ratios\u001B[39m\u001B[38;5;124m'\u001B[39m] \u001B[38;5;241m=\u001B[39m width_ratios\n\u001B[1;32m--> 905\u001B[0m gs \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39madd_gridspec(nrows, ncols, figure\u001B[38;5;241m=\u001B[39m\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mgridspec_kw)\n\u001B[0;32m 906\u001B[0m axs \u001B[38;5;241m=\u001B[39m gs\u001B[38;5;241m.\u001B[39msubplots(sharex\u001B[38;5;241m=\u001B[39msharex, sharey\u001B[38;5;241m=\u001B[39msharey, squeeze\u001B[38;5;241m=\u001B[39msqueeze,\n\u001B[0;32m 907\u001B[0m 
subplot_kw\u001B[38;5;241m=\u001B[39msubplot_kw)\n\u001B[0;32m 908\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m axs\n", + "File \u001B[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\matplotlib\\figure.py:1527\u001B[0m, in \u001B[0;36mFigureBase.add_gridspec\u001B[1;34m(self, nrows, ncols, **kwargs)\u001B[0m\n\u001B[0;32m 1488\u001B[0m \u001B[38;5;250m\u001B[39m\u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[0;32m 1489\u001B[0m \u001B[38;5;124;03mReturn a `.GridSpec` that has this figure as a parent. This allows\u001B[39;00m\n\u001B[0;32m 1490\u001B[0m \u001B[38;5;124;03mcomplex layout of Axes in the figure.\u001B[39;00m\n\u001B[1;32m (...)\u001B[0m\n\u001B[0;32m 1523\u001B[0m \n\u001B[0;32m 1524\u001B[0m \u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[0;32m 1526\u001B[0m _ \u001B[38;5;241m=\u001B[39m kwargs\u001B[38;5;241m.\u001B[39mpop(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mfigure\u001B[39m\u001B[38;5;124m'\u001B[39m, \u001B[38;5;28;01mNone\u001B[39;00m) \u001B[38;5;66;03m# pop in case user has added this...\u001B[39;00m\n\u001B[1;32m-> 1527\u001B[0m gs \u001B[38;5;241m=\u001B[39m GridSpec(nrows\u001B[38;5;241m=\u001B[39mnrows, ncols\u001B[38;5;241m=\u001B[39mncols, figure\u001B[38;5;241m=\u001B[39m\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n\u001B[0;32m 1528\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m gs\n", + "File \u001B[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\matplotlib\\gridspec.py:379\u001B[0m, in \u001B[0;36mGridSpec.__init__\u001B[1;34m(self, nrows, ncols, figure, left, bottom, right, top, wspace, hspace, width_ratios, height_ratios)\u001B[0m\n\u001B[0;32m 376\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mhspace \u001B[38;5;241m=\u001B[39m hspace\n\u001B[0;32m 377\u001B[0m 
\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mfigure \u001B[38;5;241m=\u001B[39m figure\n\u001B[1;32m--> 379\u001B[0m \u001B[38;5;28;43msuper\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[38;5;21;43m__init__\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43mnrows\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mncols\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 380\u001B[0m \u001B[43m \u001B[49m\u001B[43mwidth_ratios\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mwidth_ratios\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 381\u001B[0m \u001B[43m \u001B[49m\u001B[43mheight_ratios\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mheight_ratios\u001B[49m\u001B[43m)\u001B[49m\n", + "File \u001B[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\matplotlib\\gridspec.py:49\u001B[0m, in \u001B[0;36mGridSpecBase.__init__\u001B[1;34m(self, nrows, ncols, height_ratios, width_ratios)\u001B[0m\n\u001B[0;32m 34\u001B[0m \u001B[38;5;250m\u001B[39m\u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[0;32m 35\u001B[0m \u001B[38;5;124;03mParameters\u001B[39;00m\n\u001B[0;32m 36\u001B[0m \u001B[38;5;124;03m----------\u001B[39;00m\n\u001B[1;32m (...)\u001B[0m\n\u001B[0;32m 46\u001B[0m \u001B[38;5;124;03m If not given, all rows will have the same height.\u001B[39;00m\n\u001B[0;32m 47\u001B[0m \u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[0;32m 48\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28misinstance\u001B[39m(nrows, Integral) \u001B[38;5;129;01mor\u001B[39;00m nrows \u001B[38;5;241m<\u001B[39m\u001B[38;5;241m=\u001B[39m \u001B[38;5;241m0\u001B[39m:\n\u001B[1;32m---> 49\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\n\u001B[0;32m 50\u001B[0m \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mNumber of rows must be a positive 
integer, not \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mnrows\u001B[38;5;132;01m!r}\u001B[39;00m\u001B[38;5;124m\"\u001B[39m)\n\u001B[0;32m 51\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28misinstance\u001B[39m(ncols, Integral) \u001B[38;5;129;01mor\u001B[39;00m ncols \u001B[38;5;241m<\u001B[39m\u001B[38;5;241m=\u001B[39m \u001B[38;5;241m0\u001B[39m:\n\u001B[0;32m 52\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\n\u001B[0;32m 53\u001B[0m \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mNumber of columns must be a positive integer, not \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mncols\u001B[38;5;132;01m!r}\u001B[39;00m\u001B[38;5;124m\"\u001B[39m)\n", + "\u001B[1;31mValueError\u001B[0m: Number of rows must be a positive integer, not 0" + ] + }, { "data": { "text/plain": [ - "
" - ], - "image/png": "" + "
" + ] }, "metadata": {}, "output_type": "display_data" } ], - "execution_count": 36 + "execution_count": 97 }, { "metadata": { From 35267fedee2392874d7158a554b16b9c1eb6d62b Mon Sep 17 00:00:00 2001 From: Jerry Date: Wed, 30 Oct 2024 21:02:04 -0400 Subject: [PATCH 13/22] Fixed linting from plot_losses and plot_recovery --- bayesflow/diagnostics/plot_losses.py | 3 --- bayesflow/diagnostics/plot_recovery.py | 1 - 2 files changed, 4 deletions(-) diff --git a/bayesflow/diagnostics/plot_losses.py b/bayesflow/diagnostics/plot_losses.py index 3797b0257..d43815771 100644 --- a/bayesflow/diagnostics/plot_losses.py +++ b/bayesflow/diagnostics/plot_losses.py @@ -3,9 +3,6 @@ import seaborn as sns import matplotlib.pyplot as plt -from tensorflow.keras import ops -from ..utils.plot_utils import initialize_figure - def plot_losses( train_losses, diff --git a/bayesflow/diagnostics/plot_recovery.py b/bayesflow/diagnostics/plot_recovery.py index 4031a7542..ac7d32dcd 100644 --- a/bayesflow/diagnostics/plot_recovery.py +++ b/bayesflow/diagnostics/plot_recovery.py @@ -5,7 +5,6 @@ import matplotlib.pyplot as plt import seaborn as sns -from ..utils.plot_utils import preprocess, postprocess from ..utils.plot_utils import check_posterior_prior_shapes def plot_recovery( From e290506f141147f04d2cc74f3c0c6bb8a3e270a8 Mon Sep 17 00:00:00 2001 From: Jerry Date: Wed, 30 Oct 2024 21:05:52 -0400 Subject: [PATCH 14/22] Fixed linting from plot_utils --- bayesflow/utils/plot_utils.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/bayesflow/utils/plot_utils.py b/bayesflow/utils/plot_utils.py index 402a61e05..08e75078f 100644 --- a/bayesflow/utils/plot_utils.py +++ b/bayesflow/utils/plot_utils.py @@ -25,28 +25,28 @@ def check_posterior_prior_shapes(post_samples, prior_samples): if len(post_samples.shape) != 3: raise ShapeError( - f"post_samples should be a 3-dimensional array, with the " - + f"first dimension being the number of (simulated) data sets, " - + f"the 
second dimension being the number of posterior draws per data set, " - + f"and the third dimension being the number of parameters (marginal distributions), " + "post_samples should be a 3-dimensional array, with the " + + "first dimension being the number of (simulated) data sets, " + + "the second dimension being the number of posterior draws per data set, " + + "and the third dimension being the number of parameters (marginal distributions), " + f"but your input has dimensions {len(post_samples.shape)}" ) elif len(prior_samples.shape) != 2: raise ShapeError( - f"prior_samples should be a 2-dimensional array, with the " - + f"first dimension being the number of (simulated) data sets / prior draws " - + f"and the second dimension being the number of parameters (marginal distributions), " + "prior_samples should be a 2-dimensional array, with the " + + "first dimension being the number of (simulated) data sets / prior draws " + + "and the second dimension being the number of parameters (marginal distributions), " + f"but your input has dimensions {len(prior_samples.shape)}" ) elif post_samples.shape[0] != prior_samples.shape[0]: raise ShapeError( - f"The number of elements over the first dimension of post_samples and prior_samples" + "The number of elements over the first dimension of post_samples and prior_samples" + f"should match, but post_samples has {post_samples.shape[0]} and prior_samples has " + f"{prior_samples.shape[0]} elements, respectively." ) elif post_samples.shape[-1] != prior_samples.shape[-1]: raise ShapeError( - f"The number of elements over the last dimension of post_samples and prior_samples" + "The number of elements over the last dimension of post_samples and prior_samples" + f"should match, but post_samples has {post_samples.shape[1]} and prior_samples has " + f"{prior_samples.shape[-1]} elements, respectively." 
) From 95881af958151e6a1576b3f8cb652da7d0c2d5a4 Mon Sep 17 00:00:00 2001 From: Jerry Date: Wed, 30 Oct 2024 21:08:41 -0400 Subject: [PATCH 15/22] Fixed linting from plot_sbc_histogram --- bayesflow/diagnostics/plot_sbc_histogram.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/diagnostics/plot_sbc_histogram.py b/bayesflow/diagnostics/plot_sbc_histogram.py index 100ec9704..41542a046 100644 --- a/bayesflow/diagnostics/plot_sbc_histogram.py +++ b/bayesflow/diagnostics/plot_sbc_histogram.py @@ -80,7 +80,7 @@ def plot_sbc_histograms( logger = logging.getLogger() logger.setLevel(logging.INFO) logger.info( - f"The ratio of simulations / posterior draws should be > 20 " + "The ratio of simulations / posterior draws should be > 20 " + f"for reliable variance reduction, but your ratio is {ratio}.\ Confidence intervals might be unreliable!" ) From 6404b873e942da62c72bbda07e9d5a5037f8d0ba Mon Sep 17 00:00:00 2001 From: Jerry Date: Wed, 30 Oct 2024 21:54:05 -0400 Subject: [PATCH 16/22] Reformatting --- bayesflow/diagnostics/__init__.py | 2 +- bayesflow/diagnostics/plot_distribution_2d.py | 28 +++++-- bayesflow/diagnostics/plot_losses.py | 35 ++++++--- bayesflow/diagnostics/plot_recovery.py | 44 +++++++---- bayesflow/diagnostics/plot_sbc_ecdf.py | 78 +++++++++++++------ .../diagnostics/plot_z_score_contraction.py | 28 ++++--- 6 files changed, 148 insertions(+), 67 deletions(-) diff --git a/bayesflow/diagnostics/__init__.py b/bayesflow/diagnostics/__init__.py index 0d77a37aa..d75daa681 100644 --- a/bayesflow/diagnostics/__init__.py +++ b/bayesflow/diagnostics/__init__.py @@ -3,4 +3,4 @@ from .plot_sbc_ecdf import plot_sbc_ecdf from .plot_sbc_histogram import plot_sbc_histograms from .plot_distribution_2d import plot_distribution_2d -from .plot_z_score_contraction import plot_z_score_contraction \ No newline at end of file +from .plot_z_score_contraction import plot_z_score_contraction diff --git a/bayesflow/diagnostics/plot_distribution_2d.py 
b/bayesflow/diagnostics/plot_distribution_2d.py index f03ba7907..e0a0b722b 100644 --- a/bayesflow/diagnostics/plot_distribution_2d.py +++ b/bayesflow/diagnostics/plot_distribution_2d.py @@ -5,6 +5,7 @@ from bayesflow.types import Tensor + def plot_distribution_2d( samples: dict[str, Tensor] = None, parameters: str = None, @@ -17,7 +18,8 @@ def plot_distribution_2d( **kwargs ): """ - A more flexible pair plot function for multiple distributions based upon collected samples. + A more flexible pair plot function for multiple distributions based upon + collected samples. Parameters ---------- @@ -43,7 +45,7 @@ def plot_distribution_2d( Additional keyword arguments passed to the sns.PairGrid constructor """ # Get latent dimensions - dim = samples.shape[-1] + dim = samples.values().shape[-1] # Get number of params if n_params is None: @@ -65,16 +67,28 @@ def plot_distribution_2d( # Generate plots artist = sns.PairGrid(data_to_plot, height=height, **kwargs) - artist.map_diag(sns.histplot, fill=True, color=color, alpha=alpha, kde=True) + artist.map_diag( + sns.histplot, fill=True, color=color, alpha=alpha, kde=True + ) # Incorporate exceptions for generating KDE plots try: - artist.map_lower(sns.kdeplot, fill=True, color=color, alpha=alpha) + artist.map_lower( + sns.kdeplot, fill=True, color=color, alpha=alpha + ) except Exception as e: - logging.warning("KDE failed due to the following exception:\n" + repr(e) + "\nSubstituting scatter plot.") - artist.map_lower(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color) + logging.warning( + "KDE failed due to the following exception:\n" + + repr(e) + + "\nSubstituting scatter plot." 
+ ) + artist.map_lower( + sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color + ) - artist.map_upper(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color) + artist.map_upper( + sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color + ) if render: # Generate grids diff --git a/bayesflow/diagnostics/plot_losses.py b/bayesflow/diagnostics/plot_losses.py index d43815771..111ef63d6 100644 --- a/bayesflow/diagnostics/plot_losses.py +++ b/bayesflow/diagnostics/plot_losses.py @@ -19,25 +19,31 @@ def plot_losses( label_fontsize=14, title_fontsize=16, ): - """A generic helper function to plot the losses of a series of training epochs and runs. + """ + A generic helper function to plot the losses of a series of training epochs + and runs. Parameters ---------- train_losses : pd.DataFrame - The (plottable) history as returned by a train_[...] method of a ``Trainer`` instance. - Alternatively, you can just pass a data frame of validation losses instead of train losses, - if you only want to plot the validation loss. + The (plottable) history as returned by a train_[...] method of a + ``Trainer`` instance. + Alternatively, you can just pass a data frame of validation losses + instead of train losses, if you only want to plot the validation loss. val_losses : pd.DataFrame or None, optional, default: None - The (plottable) validation history as returned by a train_[...] method of a ``Trainer`` instance. - If left ``None``, only train losses are plotted. Should have the same number of columns - as ``train_losses``. + The (plottable) validation history as returned by a train_[...] method + of a ``Trainer`` instance. + If left ``None``, only train losses are plotted. Should have the same + number of columns as ``train_losses``. moving_average : bool, optional, default: False A flag for adding a moving average line of the train_losses. ma_window_fraction : int, optional, default: 0.01 - Window size for the moving average as a fraction of total training steps. 
+ Window size for the moving average as a fraction of total + training steps. fig_size : tuple or None, optional, default: None - The figure size passed to the ``matplotlib`` constructor. Inferred if ``None`` + The figure size passed to the ``matplotlib`` constructor. + Inferred if ``None`` train_color : str, optional, default: '#8f2727' The color for the train loss trajectory val_color : str, optional, default: black @@ -88,11 +94,18 @@ def plot_losses( looper = [axarr] if n_row == 1 else axarr.flat for i, ax in enumerate(looper): # Plot train curve - ax.plot(train_step_index, train_losses.iloc[:, i], color=train_color, lw=lw_train, alpha=0.9, label="Training") + ax.plot( + train_step_index, train_losses.iloc[:, i], + color=train_color, lw=lw_train, alpha=0.9, label="Training" + ) if moving_average and train_losses.columns[i] == "Loss": moving_average_window = int(train_losses.shape[0] * ma_window_fraction) smoothed_loss = train_losses.iloc[:, i].rolling(window=moving_average_window).mean() - ax.plot(train_step_index, smoothed_loss, color="grey", lw=lw_train, label="Training (Moving Average)") + ax.plot( + train_step_index, smoothed_loss, + color="grey", lw=lw_train, + label="Training (Moving Average)" + ) # Plot optional val curve if val_losses is not None: diff --git a/bayesflow/diagnostics/plot_recovery.py b/bayesflow/diagnostics/plot_recovery.py index ac7d32dcd..b97eee1c8 100644 --- a/bayesflow/diagnostics/plot_recovery.py +++ b/bayesflow/diagnostics/plot_recovery.py @@ -27,18 +27,24 @@ def plot_recovery( ylabel="Estimated", **kwargs, ): - """Creates and plots publication-ready recovery plot with true vs. point estimate + uncertainty. - The point estimate can be controlled with the ``point_agg`` argument, and the uncertainty estimate - can be controlled with the ``uncertainty_agg`` argument. + """ + Creates and plots publication-ready recovery plot with true estimate + vs. point estimate + uncertainty. 
+ The point estimate can be controlled with the ``point_agg`` argument, + and the uncertainty estimate can be controlled with the + ``uncertainty_agg`` argument. - This plot yields similar information as the "posterior z-score", but allows for generic - point and uncertainty estimates: + This plot yields similar information as the "posterior z-score", + but allows for generic point and uncertainty estimates: https://betanalpha.github.io/assets/case_studies/principled_bayesian_workflow.html - Important: Posterior aggregates play no special role in Bayesian inference and should only - be used heuristically. For instance, in the case of multi-modal posteriors, common point - estimates, such as mean, (geometric) median, or maximum a posteriori (MAP) mean nothing. + Important: + Posterior aggregates play no special role in Bayesian inference and + should only be used heuristically. + For instance, in the case of multi-modal posteriors, common point + estimates, such as mean, (geometric) median, or maximum a posteriori + (MAP) mean nothing. 
Parameters ---------- @@ -133,9 +139,13 @@ def plot_recovery( # Add scatter and error bars if uncertainty_agg is not None: - _ = ax.errorbar(prior_samples[:, i], est[:, i], yerr=u[:, i], fmt="o", alpha=0.5, color=color, **kwargs) + _ = ax.errorbar( + prior_samples[:, i], est[:, i], yerr=u[:, i], fmt="o", alpha=0.5, color=color, **kwargs + ) else: - _ = ax.scatter(prior_samples[:, i], est[:, i], alpha=0.5, color=color, **kwargs) + _ = ax.scatter( + prior_samples[:, i], est[:, i], alpha=0.5, color=color, **kwargs + ) # Make plots quadratic to avoid visual illusions lower = min(prior_samples[:, i].min(), est[:, i].min()) @@ -179,11 +189,17 @@ def plot_recovery( # Prettify sns.despine(ax=ax) ax.grid(alpha=0.5) - ax.tick_params(axis="both", which="major", labelsize=tick_fontsize) - ax.tick_params(axis="both", which="minor", labelsize=tick_fontsize) + ax.tick_params( + axis="both", which="major", labelsize=tick_fontsize + ) + ax.tick_params( + axis="both", which="minor", labelsize=tick_fontsize + ) # Only add x-labels to the bottom row - bottom_row = axarr if n_row == 1 else axarr[0] if n_col == 1 else axarr[n_row - 1, :] + bottom_row = axarr if n_row == 1 else ( + axarr[0] if n_col == 1 else axarr[n_row - 1, :] + ) for _ax in bottom_row: _ax.set_xlabel(xlabel, fontsize=label_fontsize) @@ -200,4 +216,4 @@ def plot_recovery( _ax.remove() f.tight_layout() - return f \ No newline at end of file + return f diff --git a/bayesflow/diagnostics/plot_sbc_ecdf.py b/bayesflow/diagnostics/plot_sbc_ecdf.py index e4c860a5f..cc8614436 100644 --- a/bayesflow/diagnostics/plot_sbc_ecdf.py +++ b/bayesflow/diagnostics/plot_sbc_ecdf.py @@ -6,6 +6,7 @@ from ..utils.plot_utils import check_posterior_prior_shapes from ..utils.ecdf import simultaneous_ecdf_bands + def plot_sbc_ecdf( post_samples, prior_samples, @@ -23,16 +24,19 @@ def plot_sbc_ecdf( n_col=None, **kwargs, ): - """Creates the empirical CDFs for each marginal rank distribution and plots it against - a uniform ECDF. 
ECDF simultaneous bands are drawn using simulations from the uniform, + """ + Creates the empirical CDFs for each marginal rank distribution + and plots it against a uniform ECDF. + ECDF simultaneous bands are drawn using simulations from the uniform, as proposed by [1]. - For models with many parameters, use `stacked=True` to obtain an idea of the overall calibration - of a posterior approximator. + For models with many parameters, use `stacked=True` to obtain an idea + of the overall calibration of a posterior approximator. - [1] Säilynoja, T., Bürkner, P. C., & Vehtari, A. (2022). Graphical test for discrete uniformity and - its applications in goodness-of-fit evaluation and multiple sample comparison. Statistics and Computing, - 32(2), 1-21. https://arxiv.org/abs/2103.10522 + [1] Säilynoja, T., Bürkner, P. C., & Vehtari, A. (2022). Graphical test + for discrete uniformity and its applications in goodness-of-fit evaluation + and multiple sample comparison. Statistics and Computing, 32(2), 1-21. + https://arxiv.org/abs/2103.10522 Parameters ---------- @@ -41,20 +45,25 @@ def plot_sbc_ecdf( prior_samples : np.ndarray of shape (n_data_sets, n_params) The prior draws obtained for generating n_data_sets difference : bool, optional, default: False - If `True`, plots the ECDF difference. Enables a more dynamic visualization range. + If `True`, plots the ECDF difference. + Enables a more dynamic visualization range. stacked : bool, optional, default: False - If `True`, all ECDFs will be plotted on the same plot. If `False`, each ECDF will - have its own subplot, similar to the behavior of `plot_sbc_histograms`. + If `True`, all ECDFs will be plotted on the same plot. + If `False`, each ECDF will have its own subplot, + similar to the behavior of `plot_sbc_histograms`. param_names : list or None, optional, default: None - The parameter names for nice plot titles. Inferred if None. Only relevant if `stacked=False`. + The parameter names for nice plot titles. 
+ Inferred if None. Only relevant if `stacked=False`. fig_size : tuple or None, optional, default: None - The figure size passed to the matplotlib constructor. Inferred if None. + The figure size passed to the matplotlib constructor. + Inferred if None. label_fontsize : int, optional, default: 16 The font size of the y-label and y-label texts legend_fontsize : int, optional, default: 14 The font size of the legend text title_fontsize : int, optional, default: 18 - The font size of the title text. Only relevant if `stacked=False` + The font size of the title text. + Only relevant if `stacked=False` tick_fontsize : int, optional, default: 12 The font size of the axis ticklabels rank_ecdf_color : str, optional, default: '#a34f4f' @@ -62,12 +71,15 @@ def plot_sbc_ecdf( fill_color : str, optional, default: 'grey' The color of the fill arguments. n_row : int, optional, default: None - The number of rows for the subplots. Dynamically determined if None. + The number of rows for the subplots. + Dynamically determined if None. n_col : int, optional, default: None - The number of columns for the subplots. Dynamically determined if None. + The number of columns for the subplots. + Dynamically determined if None. **kwargs : dict, optional, default: {} - Keyword arguments can be passed to control the behavior of ECDF simultaneous band computation - through the ``ecdf_bands_kwargs`` dictionary. See `simultaneous_ecdf_bands` for keyword arguments + Keyword arguments can be passed to control the behavior of + ECDF simultaneous band computation through the ``ecdf_bands_kwargs`` + dictionary. See `simultaneous_ecdf_bands` for keyword arguments Returns ------- @@ -76,7 +88,8 @@ def plot_sbc_ecdf( Raises ------ ShapeError - If there is a deviation form the expected shapes of `post_samples` and `prior_samples`. + If there is a deviation form the expected shapes of `post_samples` + and `prior_samples`. 
""" # Sanity checks @@ -86,7 +99,9 @@ def plot_sbc_ecdf( n_params = post_samples.shape[-1] # Compute fractional ranks (using broadcasting) - ranks = np.sum(post_samples < prior_samples[:, np.newaxis, :], axis=1) / post_samples.shape[1] + ranks = np.sum( + post_samples < prior_samples[:, np.newaxis, :], axis=1 + ) / post_samples.shape[1] # Prepare figure if stacked: @@ -122,14 +137,25 @@ def plot_sbc_ecdf( if stacked: if j == 0: - ax.plot(xx, yy, color=rank_ecdf_color, alpha=0.95, label="Rank ECDFs") + ax.plot( + xx, yy, + color=rank_ecdf_color, alpha=0.95, + label="Rank ECDFs" + ) else: ax.plot(xx, yy, color=rank_ecdf_color, alpha=0.95) else: - ax.flat[j].plot(xx, yy, color=rank_ecdf_color, alpha=0.95, label="Rank ECDF") + ax.flat[j].plot( + xx, yy, + color=rank_ecdf_color, alpha=0.95, + label="Rank ECDF" + ) # Compute uniform ECDF and bands - alpha, z, L, H = simultaneous_ecdf_bands(post_samples.shape[0], **kwargs.pop("ecdf_bands_kwargs", {})) + alpha, z, L, H = simultaneous_ecdf_bands( + post_samples.shape[0], + **kwargs.pop("ecdf_bands_kwargs", {}) + ) # Difference, if specified if difference: @@ -151,7 +177,11 @@ def plot_sbc_ecdf( titles = param_names for _ax, title in zip(axes, titles): - _ax.fill_between(z, L, H, color=fill_color, alpha=0.2, label=rf"{int((1-alpha) * 100)}$\%$ Confidence Bands") + _ax.fill_between( + z, L, H, + color=fill_color, alpha=0.2, + label=rf"{int((1-alpha) * 100)}$\%$ Confidence Bands" + ) # Prettify plot sns.despine(ax=_ax) @@ -181,4 +211,4 @@ def plot_sbc_ecdf( _ax.remove() f.tight_layout() - return f \ No newline at end of file + return f diff --git a/bayesflow/diagnostics/plot_z_score_contraction.py b/bayesflow/diagnostics/plot_z_score_contraction.py index 77cc33922..d28ccfc5a 100644 --- a/bayesflow/diagnostics/plot_z_score_contraction.py +++ b/bayesflow/diagnostics/plot_z_score_contraction.py @@ -5,6 +5,7 @@ from ..utils.plot_utils import check_posterior_prior_shapes + def plot_z_score_contraction( post_samples, prior_samples, 
@@ -17,26 +18,31 @@ def plot_z_score_contraction( n_col=None, n_row=None, ): - """Implements a graphical check for global model sensitivity by plotting the posterior - z-score over the posterior contraction for each set of posterior samples in ``post_samples`` - according to [1]. + """ + Implements a graphical check for global model sensitivity by plotting the + posterior z-score over the posterior contraction for each set of posterior + samples in ``post_samples`` according to [1]. - The definition of the posterior z-score is: post_z_score = (posterior_mean - true_parameters) / posterior_std - And the score is adequate if it centers around zero and spreads roughly in the interval [-3, 3] + And the score is adequate if it centers around zero and spreads roughly + in the interval [-3, 3] - The definition of posterior contraction is: post_contraction = 1 - (posterior_variance / prior_variance) - In other words, the posterior contraction is a proxy for the reduction in uncertainty gained by - replacing the prior with the posterior. The ideal posterior contraction tends to 1. - Contraction near zero indicates that the posterior variance is almost identical to - the prior variance for the particular marginal parameter distribution. + In other words, the posterior contraction is a proxy for the reduction in + uncertainty gained by replacing the prior with the posterior. + The ideal posterior contraction tends to 1. + Contraction near zero indicates that the posterior variance is almost + identical to the prior variance for the particular marginal parameter + distribution. - Note: Means and variances will be estimated via their sample-based estimators. + Note: + Means and variances will be estimated via their sample-based estimators. [1] Schad, D. J., Betancourt, M., & Vasishth, S. (2021). Toward a principled Bayesian workflow in cognitive science. 
@@ -134,7 +140,9 @@ def plot_z_score_contraction( ax.set_xlim([-0.05, 1.05]) # Only add x-labels to the bottom row - bottom_row = axarr if n_row == 1 else axarr[0] if n_col == 1 else axarr[n_row - 1, :] + bottom_row = axarr if n_row == 1 else ( + axarr[0] if n_col == 1 else axarr[n_row - 1, :] + ) for _ax in bottom_row: _ax.set_xlabel("Posterior contraction", fontsize=label_fontsize) From 1281ccfa33c03162fdf7b619f2ecb709e4e62137 Mon Sep 17 00:00:00 2001 From: Jerry Date: Wed, 30 Oct 2024 22:04:49 -0400 Subject: [PATCH 17/22] Reformatting --- bayesflow/diagnostics/plot_losses.py | 27 +++++++++---------- bayesflow/diagnostics/plot_recovery.py | 2 +- .../diagnostics/plot_z_score_contraction.py | 21 +++++++-------- 3 files changed, 24 insertions(+), 26 deletions(-) diff --git a/bayesflow/diagnostics/plot_losses.py b/bayesflow/diagnostics/plot_losses.py index 111ef63d6..7af710591 100644 --- a/bayesflow/diagnostics/plot_losses.py +++ b/bayesflow/diagnostics/plot_losses.py @@ -1,23 +1,22 @@ - import numpy as np import seaborn as sns import matplotlib.pyplot as plt def plot_losses( - train_losses, - val_losses=None, - moving_average=False, - ma_window_fraction=0.01, - fig_size=None, - train_color="#8f2727", - val_color="black", - lw_train=2, - lw_val=3, - grid_alpha=0.5, - legend_fontsize=14, - label_fontsize=14, - title_fontsize=16, + train_losses, + val_losses=None, + moving_average=False, + ma_window_fraction=0.01, + fig_size=None, + train_color="#8f2727", + val_color="black", + lw_train=2, + lw_val=3, + grid_alpha=0.5, + legend_fontsize=14, + label_fontsize=14, + title_fontsize=16, ): """ A generic helper function to plot the losses of a series of training epochs diff --git a/bayesflow/diagnostics/plot_recovery.py b/bayesflow/diagnostics/plot_recovery.py index b97eee1c8..c0dba065e 100644 --- a/bayesflow/diagnostics/plot_recovery.py +++ b/bayesflow/diagnostics/plot_recovery.py @@ -1,4 +1,3 @@ - import numpy as np from scipy.stats import median_abs_deviation from 
sklearn.metrics import r2_score @@ -7,6 +6,7 @@ from ..utils.plot_utils import check_posterior_prior_shapes + def plot_recovery( post_samples, prior_samples, diff --git a/bayesflow/diagnostics/plot_z_score_contraction.py b/bayesflow/diagnostics/plot_z_score_contraction.py index d28ccfc5a..40d3b7105 100644 --- a/bayesflow/diagnostics/plot_z_score_contraction.py +++ b/bayesflow/diagnostics/plot_z_score_contraction.py @@ -1,4 +1,3 @@ - import numpy as np import matplotlib.pyplot as plt import seaborn as sns @@ -7,16 +6,16 @@ def plot_z_score_contraction( - post_samples, - prior_samples, - param_names=None, - fig_size=None, - label_fontsize=16, - title_fontsize=18, - tick_fontsize=12, - color="#8f2727", - n_col=None, - n_row=None, + post_samples, + prior_samples, + param_names=None, + fig_size=None, + label_fontsize=16, + title_fontsize=18, + tick_fontsize=12, + color="#8f2727", + n_col=None, + n_row=None, ): """ Implements a graphical check for global model sensitivity by plotting the From 974d9d4d0bfbd75cf8f47ffffcb3646412de2324 Mon Sep 17 00:00:00 2001 From: Jerry Date: Thu, 31 Oct 2024 11:50:07 -0400 Subject: [PATCH 18/22] Testing pre-commit --- bayesflow/utils/plot_utils.py | 54 ++++++----------------------------- 1 file changed, 9 insertions(+), 45 deletions(-) diff --git a/bayesflow/utils/plot_utils.py b/bayesflow/utils/plot_utils.py index 08e75078f..027e8ae1f 100644 --- a/bayesflow/utils/plot_utils.py +++ b/bayesflow/utils/plot_utils.py @@ -1,4 +1,3 @@ - import numpy as np import matplotlib.pyplot as plt @@ -52,12 +51,7 @@ def check_posterior_prior_shapes(post_samples, prior_samples): ) -def get_count_and_names( - samples, - names: list = None, - symbol: str = None, - n_objects: int = None -): +def get_count_and_names(samples, names: list = None, symbol: str = None, n_objects: int = None): """ Determine the number of objects, such as parameters or models, and their respective names if None given. 
@@ -91,12 +85,7 @@ def get_count_and_names( return n_objects, names -def configure_layout( - n_total: int, - n_row: int = None, - n_col: int = None, - stacked: bool = False -): +def configure_layout(n_total: int, n_row: int = None, n_col: int = None, stacked: bool = False): """ Determine the number of rows and columns in diagnostics visualizations. @@ -133,9 +122,9 @@ def configure_layout( def initialize_figure( - n_row: int = None, - n_col: int = None, - fig_size: tuple = None, + n_row: int = None, + n_col: int = None, + fig_size: tuple = None, ): """ Initialize a set of figures @@ -146,8 +135,6 @@ def initialize_figure( Number of rows in a figure n_col : int Number of columns in a figure - stacked : bool - Whether subplots in a figure are stacked by rows fig_size : tuple Size of the figure adjusting to the display resolution or the designer's desire @@ -197,25 +184,14 @@ def collapse_axes(axarr, n_row: int = 1, n_col: int = 1): return ax -def add_xlabels( - axarr, - n_row: int = None, - n_col: int = None, - xlabel: str = None, - label_fontsize: int = None -): +def add_xlabels(axarr, n_row: int = None, n_col: int = None, xlabel: str = None, label_fontsize: int = None): # Only add x-labels to the bottom row bottom_row = axarr if n_row == 1 else axarr[0] if n_col == 1 else axarr[n_row - 1, :] for _ax in bottom_row: _ax.set_xlabel(xlabel, fontsize=label_fontsize) -def add_ylabels( - axarr, - n_row: int = None, - ylabel: str = None, - label_fontsize: int = None -): +def add_ylabels(axarr, n_row: int = None, ylabel: str = None, label_fontsize: int = None): # Only add y-labels to right left-most row if n_row == 1: # if there is only one row, the ax array is 1D axarr[0].set_ylabel(ylabel, fontsize=label_fontsize) @@ -226,12 +202,7 @@ def add_ylabels( def add_labels( - axarr, - n_row: int = None, - n_col: int = None, - xlabel: str = None, - ylabel: str = None, - label_fontsize: int = None + axarr, n_row: int = None, n_col: int = None, xlabel: str = None, ylabel: str = 
None, label_fontsize: int = None ): """ Wrapper function for configuring labels for both axes. @@ -245,12 +216,7 @@ def remove_unused_axes(axarr_it, n_params: int = None): _ax.remove() -def preprocess( - post_samples, - prior_samples, - fig_size: tuple = None, - collapse: bool = True -): +def preprocess(post_samples, prior_samples, fig_size: tuple = None, collapse: bool = True): """ Procedural wrapper that encompasses all preprocessing steps, including shape-checking, parameter name generation, layout configuration, @@ -264,8 +230,6 @@ def preprocess( The prior draws obtained for generating n_data_sets fig_size : tuple, optional, default: None Size of the figure adjusting to the display resolution - stacked : bool, optional, default: False - Whether subplots in a figure are stacked by rows collapse : bool, optional, default: True Whether subplots in a figure are collapsed into rows """ From 4eabc60475917a9491ae7f4f6eca2cfdbe3d0423 Mon Sep 17 00:00:00 2001 From: Jerry Date: Fri, 1 Nov 2024 07:52:10 -0400 Subject: [PATCH 19/22] Reformat using pre-commit --- bayesflow/diagnostics/plot_distribution_2d.py | 41 +++++--------- bayesflow/diagnostics/plot_losses.py | 39 ++++++------- bayesflow/diagnostics/plot_recovery.py | 56 ++++++++----------- bayesflow/diagnostics/plot_sbc_ecdf.py | 28 ++-------- bayesflow/diagnostics/plot_sbc_histogram.py | 26 ++++----- .../diagnostics/plot_z_score_contraction.py | 24 ++++---- bayesflow/utils/exceptions/shape_error.py | 1 - examples/mm_gsn.stan | 4 +- 8 files changed, 84 insertions(+), 135 deletions(-) diff --git a/bayesflow/diagnostics/plot_distribution_2d.py b/bayesflow/diagnostics/plot_distribution_2d.py index e0a0b722b..6d24e19a7 100644 --- a/bayesflow/diagnostics/plot_distribution_2d.py +++ b/bayesflow/diagnostics/plot_distribution_2d.py @@ -1,4 +1,3 @@ - import logging import seaborn as sns import pandas as pd @@ -7,15 +6,15 @@ def plot_distribution_2d( - samples: dict[str, Tensor] = None, - parameters: str = None, - n_params: 
int = None, - param_names: list = None, - height: float = 2.5, - color: str | tuple = "#8f2727", - alpha: float = 0.9, - render: bool = True, - **kwargs + samples: dict[str, Tensor] = None, + parameters: str = None, + n_params: int = None, + param_names: list = None, + height: float = 2.5, + color: str | tuple = "#8f2727", + alpha: float = 0.9, + render: bool = True, + **kwargs, ): """ A more flexible pair plot function for multiple distributions based upon @@ -67,28 +66,16 @@ def plot_distribution_2d( # Generate plots artist = sns.PairGrid(data_to_plot, height=height, **kwargs) - artist.map_diag( - sns.histplot, fill=True, color=color, alpha=alpha, kde=True - ) + artist.map_diag(sns.histplot, fill=True, color=color, alpha=alpha, kde=True) # Incorporate exceptions for generating KDE plots try: - artist.map_lower( - sns.kdeplot, fill=True, color=color, alpha=alpha - ) + artist.map_lower(sns.kdeplot, fill=True, color=color, alpha=alpha) except Exception as e: - logging.warning( - "KDE failed due to the following exception:\n" - + repr(e) - + "\nSubstituting scatter plot." 
- ) - artist.map_lower( - sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color - ) + logging.warning("KDE failed due to the following exception:\n" + repr(e) + "\nSubstituting scatter plot.") + artist.map_lower(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color) - artist.map_upper( - sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color - ) + artist.map_upper(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color) if render: # Generate grids diff --git a/bayesflow/diagnostics/plot_losses.py b/bayesflow/diagnostics/plot_losses.py index 7af710591..215e53179 100644 --- a/bayesflow/diagnostics/plot_losses.py +++ b/bayesflow/diagnostics/plot_losses.py @@ -4,19 +4,19 @@ def plot_losses( - train_losses, - val_losses=None, - moving_average=False, - ma_window_fraction=0.01, - fig_size=None, - train_color="#8f2727", - val_color="black", - lw_train=2, - lw_val=3, - grid_alpha=0.5, - legend_fontsize=14, - label_fontsize=14, - title_fontsize=16, + train_losses, + val_losses=None, + moving_average=False, + ma_window_fraction=0.01, + fig_size=None, + train_color="#8f2727", + val_color="black", + lw_train=2, + lw_val=3, + grid_alpha=0.5, + legend_fontsize=14, + label_fontsize=14, + title_fontsize=16, ): """ A generic helper function to plot the losses of a series of training epochs @@ -83,7 +83,7 @@ def plot_losses( train_step_index = np.arange(1, len(train_losses) + 1) if val_losses is not None: val_step = int(np.floor(len(train_losses) / len(val_losses))) - val_step_index = train_step_index[(val_step - 1)::val_step] + val_step_index = train_step_index[(val_step - 1) :: val_step] # If unequal length due to some reason, attempt a fix if val_step_index.shape[0] > val_losses.shape[0]: @@ -93,18 +93,11 @@ def plot_losses( looper = [axarr] if n_row == 1 else axarr.flat for i, ax in enumerate(looper): # Plot train curve - ax.plot( - train_step_index, train_losses.iloc[:, i], - color=train_color, lw=lw_train, alpha=0.9, label="Training" - ) + 
ax.plot(train_step_index, train_losses.iloc[:, i], color=train_color, lw=lw_train, alpha=0.9, label="Training") if moving_average and train_losses.columns[i] == "Loss": moving_average_window = int(train_losses.shape[0] * ma_window_fraction) smoothed_loss = train_losses.iloc[:, i].rolling(window=moving_average_window).mean() - ax.plot( - train_step_index, smoothed_loss, - color="grey", lw=lw_train, - label="Training (Moving Average)" - ) + ax.plot(train_step_index, smoothed_loss, color="grey", lw=lw_train, label="Training (Moving Average)") # Plot optional val curve if val_losses is not None: diff --git a/bayesflow/diagnostics/plot_recovery.py b/bayesflow/diagnostics/plot_recovery.py index c0dba065e..98d082b24 100644 --- a/bayesflow/diagnostics/plot_recovery.py +++ b/bayesflow/diagnostics/plot_recovery.py @@ -8,24 +8,24 @@ def plot_recovery( - post_samples, - prior_samples, - point_agg=np.median, - uncertainty_agg=median_abs_deviation, - param_names=None, - fig_size=None, - label_fontsize=16, - title_fontsize=18, - metric_fontsize=16, - tick_fontsize=12, - add_corr=True, - add_r2=True, - color="#8f2727", - n_col=None, - n_row=None, - xlabel="Ground truth", - ylabel="Estimated", - **kwargs, + post_samples, + prior_samples, + point_agg=np.median, + uncertainty_agg=median_abs_deviation, + param_names=None, + fig_size=None, + label_fontsize=16, + title_fontsize=18, + metric_fontsize=16, + tick_fontsize=12, + add_corr=True, + add_r2=True, + color="#8f2727", + n_col=None, + n_row=None, + xlabel="Ground truth", + ylabel="Estimated", + **kwargs, ): """ Creates and plots publication-ready recovery plot with true estimate @@ -139,13 +139,9 @@ def plot_recovery( # Add scatter and error bars if uncertainty_agg is not None: - _ = ax.errorbar( - prior_samples[:, i], est[:, i], yerr=u[:, i], fmt="o", alpha=0.5, color=color, **kwargs - ) + _ = ax.errorbar(prior_samples[:, i], est[:, i], yerr=u[:, i], fmt="o", alpha=0.5, color=color, **kwargs) else: - _ = ax.scatter( - 
prior_samples[:, i], est[:, i], alpha=0.5, color=color, **kwargs - ) + _ = ax.scatter(prior_samples[:, i], est[:, i], alpha=0.5, color=color, **kwargs) # Make plots quadratic to avoid visual illusions lower = min(prior_samples[:, i].min(), est[:, i].min()) @@ -189,17 +185,11 @@ def plot_recovery( # Prettify sns.despine(ax=ax) ax.grid(alpha=0.5) - ax.tick_params( - axis="both", which="major", labelsize=tick_fontsize - ) - ax.tick_params( - axis="both", which="minor", labelsize=tick_fontsize - ) + ax.tick_params(axis="both", which="major", labelsize=tick_fontsize) + ax.tick_params(axis="both", which="minor", labelsize=tick_fontsize) # Only add x-labels to the bottom row - bottom_row = axarr if n_row == 1 else ( - axarr[0] if n_col == 1 else axarr[n_row - 1, :] - ) + bottom_row = axarr if n_row == 1 else (axarr[0] if n_col == 1 else axarr[n_row - 1, :]) for _ax in bottom_row: _ax.set_xlabel(xlabel, fontsize=label_fontsize) diff --git a/bayesflow/diagnostics/plot_sbc_ecdf.py b/bayesflow/diagnostics/plot_sbc_ecdf.py index cc8614436..cbe680b4c 100644 --- a/bayesflow/diagnostics/plot_sbc_ecdf.py +++ b/bayesflow/diagnostics/plot_sbc_ecdf.py @@ -1,4 +1,3 @@ - import matplotlib.pyplot as plt import numpy as np import seaborn as sns @@ -99,9 +98,7 @@ def plot_sbc_ecdf( n_params = post_samples.shape[-1] # Compute fractional ranks (using broadcasting) - ranks = np.sum( - post_samples < prior_samples[:, np.newaxis, :], axis=1 - ) / post_samples.shape[1] + ranks = np.sum(post_samples < prior_samples[:, np.newaxis, :], axis=1) / post_samples.shape[1] # Prepare figure if stacked: @@ -137,25 +134,14 @@ def plot_sbc_ecdf( if stacked: if j == 0: - ax.plot( - xx, yy, - color=rank_ecdf_color, alpha=0.95, - label="Rank ECDFs" - ) + ax.plot(xx, yy, color=rank_ecdf_color, alpha=0.95, label="Rank ECDFs") else: ax.plot(xx, yy, color=rank_ecdf_color, alpha=0.95) else: - ax.flat[j].plot( - xx, yy, - color=rank_ecdf_color, alpha=0.95, - label="Rank ECDF" - ) + ax.flat[j].plot(xx, yy, 
color=rank_ecdf_color, alpha=0.95, label="Rank ECDF") # Compute uniform ECDF and bands - alpha, z, L, H = simultaneous_ecdf_bands( - post_samples.shape[0], - **kwargs.pop("ecdf_bands_kwargs", {}) - ) + alpha, z, L, H = simultaneous_ecdf_bands(post_samples.shape[0], **kwargs.pop("ecdf_bands_kwargs", {})) # Difference, if specified if difference: @@ -177,11 +163,7 @@ def plot_sbc_ecdf( titles = param_names for _ax, title in zip(axes, titles): - _ax.fill_between( - z, L, H, - color=fill_color, alpha=0.2, - label=rf"{int((1-alpha) * 100)}$\%$ Confidence Bands" - ) + _ax.fill_between(z, L, H, color=fill_color, alpha=0.2, label=rf"{int((1-alpha) * 100)}$\%$ Confidence Bands") # Prettify plot sns.despine(ax=_ax) diff --git a/bayesflow/diagnostics/plot_sbc_histogram.py b/bayesflow/diagnostics/plot_sbc_histogram.py index 41542a046..82e58342c 100644 --- a/bayesflow/diagnostics/plot_sbc_histogram.py +++ b/bayesflow/diagnostics/plot_sbc_histogram.py @@ -1,4 +1,3 @@ - import logging import numpy as np import matplotlib.pyplot as plt @@ -7,19 +6,20 @@ from scipy.stats import binom from ..utils.plot_utils import check_posterior_prior_shapes + def plot_sbc_histograms( - post_samples, - prior_samples, - param_names=None, - fig_size=None, - num_bins=None, - binomial_interval=0.99, - label_fontsize=16, - title_fontsize=18, - tick_fontsize=12, - hist_color="#a34f4f", - n_row=None, - n_col=None, + post_samples, + prior_samples, + param_names=None, + fig_size=None, + num_bins=None, + binomial_interval=0.99, + label_fontsize=16, + title_fontsize=18, + tick_fontsize=12, + hist_color="#a34f4f", + n_row=None, + n_col=None, ): """Creates and plots publication-ready histograms of rank statistics for simulation-based calibration (SBC) checks according to [1]. 
diff --git a/bayesflow/diagnostics/plot_z_score_contraction.py b/bayesflow/diagnostics/plot_z_score_contraction.py index 40d3b7105..151dddea7 100644 --- a/bayesflow/diagnostics/plot_z_score_contraction.py +++ b/bayesflow/diagnostics/plot_z_score_contraction.py @@ -6,16 +6,16 @@ def plot_z_score_contraction( - post_samples, - prior_samples, - param_names=None, - fig_size=None, - label_fontsize=16, - title_fontsize=18, - tick_fontsize=12, - color="#8f2727", - n_col=None, - n_row=None, + post_samples, + prior_samples, + param_names=None, + fig_size=None, + label_fontsize=16, + title_fontsize=18, + tick_fontsize=12, + color="#8f2727", + n_col=None, + n_row=None, ): """ Implements a graphical check for global model sensitivity by plotting the @@ -139,9 +139,7 @@ def plot_z_score_contraction( ax.set_xlim([-0.05, 1.05]) # Only add x-labels to the bottom row - bottom_row = axarr if n_row == 1 else ( - axarr[0] if n_col == 1 else axarr[n_row - 1, :] - ) + bottom_row = axarr if n_row == 1 else (axarr[0] if n_col == 1 else axarr[n_row - 1, :]) for _ax in bottom_row: _ax.set_xlabel("Posterior contraction", fontsize=label_fontsize) diff --git a/bayesflow/utils/exceptions/shape_error.py b/bayesflow/utils/exceptions/shape_error.py index c167bdab6..ebb28d156 100644 --- a/bayesflow/utils/exceptions/shape_error.py +++ b/bayesflow/utils/exceptions/shape_error.py @@ -1,4 +1,3 @@ - class ShapeError(Exception): """Class for error in expected shapes.""" diff --git a/examples/mm_gsn.stan b/examples/mm_gsn.stan index 0d0702ca1..7b33a8645 100644 --- a/examples/mm_gsn.stan +++ b/examples/mm_gsn.stan @@ -27,11 +27,11 @@ transformed parameters { model { vector[D] mu; - + // Priors theta1_centered ~ normal(0.0, 0.1); theta2_centered ~ normal(0.0, 0.1); - + // Likelihood for (i in 1:D) { mu[i] = theta1_scaled * design_s[i] / (theta2_scaled + design_s[i]); From afa8b9afb59b27584a71904d132b6eef63c1ad17 Mon Sep 17 00:00:00 2001 From: Jerry Date: Sat, 2 Nov 2024 16:42:25 -0400 Subject: [PATCH 20/22] 
Add plot_calibration_curves --- .../diagnostics/plot_calibration_curves.py | 136 ++++++++++++++++++ bayesflow/utils/comp_utils.py | 0 2 files changed, 136 insertions(+) create mode 100644 bayesflow/diagnostics/plot_calibration_curves.py create mode 100644 bayesflow/utils/comp_utils.py diff --git a/bayesflow/diagnostics/plot_calibration_curves.py b/bayesflow/diagnostics/plot_calibration_curves.py new file mode 100644 index 000000000..c5071ae1f --- /dev/null +++ b/bayesflow/diagnostics/plot_calibration_curves.py @@ -0,0 +1,136 @@ +import numpy as np +import matplotlib.pyplot as plt + +from ..utils.comp_utils import expected_calibration_error + + +def plot_calibration_curves( + true_models, + pred_models, + model_names=None, + num_bins=10, + label_fontsize=16, + legend_fontsize=14, + title_fontsize=18, + tick_fontsize=12, + epsilon=0.02, + fig_size=None, + color="#8f2727", + n_row=None, + n_col=None, +): + """Plots the calibration curves, the ECEs and the marginal histograms of predicted posterior model probabilities + for a model comparison problem. The marginal histograms inform about the fraction of predictions in each bin. + Depends on the ``expected_calibration_error`` function for computing the ECE. + + Parameters + ---------- + true_models : np.ndarray of shape (num_data_sets, num_models) + The one-hot-encoded true model indices per data set. + pred_models : np.ndarray of shape (num_data_sets, num_models) + The predicted posterior model probabilities (PMPs) per data set. + model_names : list or None, optional, default: None + The model names for nice plot titles. Inferred if None. + num_bins : int, optional, default: 10 + The number of bins to use for the calibration curves (and marginal histograms). 
+ label_fontsize : int, optional, default: 16 + The font size of the y-label and y-label texts + legend_fontsize : int, optional, default: 14 + The font size of the legend text (ECE value) + title_fontsize : int, optional, default: 18 + The font size of the title text. Only relevant if `stacked=False` + tick_fontsize : int, optional, default: 12 + The font size of the axis ticklabels + epsilon : float, optional, default: 0.02 + A small amount to pad the [0, 1]-bounded axes from both side. + fig_size : tuple or None, optional, default: None + The figure size passed to the ``matplotlib`` constructor. Inferred if ``None`` + color : str, optional, default: '#8f2727' + The color of the calibration curves + n_row : int, optional, default: None + The number of rows for the subplots. Dynamically determined if None. + n_col : int, optional, default: None + The number of columns for the subplots. Dynamically determined if None. + + Returns + ------- + fig : plt.Figure - the figure instance for optional saving + """ + + num_models = true_models.shape[-1] + if model_names is None: + model_names = [rf"$M_{{{m}}}$" for m in range(1, num_models + 1)] + + # Determine number of rows and columns for subplots based on inputs + if n_row is None and n_col is None: + n_row = int(np.ceil(num_models / 6)) + n_col = int(np.ceil(num_models / n_row)) + elif n_row is None and n_col is not None: + n_row = int(np.ceil(num_models / n_col)) + elif n_row is not None and n_col is None: + n_col = int(np.ceil(num_models / n_row)) + + # Compute calibration + cal_errs, probs_true, probs_pred = expected_calibration_error(true_models, pred_models, num_bins) + + # Initialize figure + if fig_size is None: + fig_size = (int(5 * n_col), int(5 * n_row)) + fig, ax_array = plt.subplots(n_row, n_col, figsize=fig_size) + if n_row > 1: + ax = ax_array.flat + + # Plot marginal calibration curves in a loop + if n_row > 1: + ax = ax_array.flat + else: + ax = ax_array + for j in range(num_models): + # Plot calibration 
curve + ax[j].plot(probs_pred[j], probs_true[j], "o-", color=color) + + # Plot PMP distribution over bins + uniform_bins = np.linspace(0.0, 1.0, num_bins + 1) + norm_weights = np.ones_like(pred_models) / len(pred_models) + ax[j].hist(pred_models[:, j], bins=uniform_bins, weights=norm_weights[:, j], color="grey", alpha=0.3) + + # Plot AB line + ax[j].plot((0, 1), (0, 1), "--", color="black", alpha=0.9) + + # Tweak plot + ax[j].tick_params(axis="both", which="major", labelsize=tick_fontsize) + ax[j].tick_params(axis="both", which="minor", labelsize=tick_fontsize) + ax[j].set_title(model_names[j], fontsize=title_fontsize) + ax[j].spines["right"].set_visible(False) + ax[j].spines["top"].set_visible(False) + ax[j].set_xlim([0 - epsilon, 1 + epsilon]) + ax[j].set_ylim([0 - epsilon, 1 + epsilon]) + ax[j].set_xticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) + ax[j].set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) + ax[j].grid(alpha=0.5) + + # Add ECE label + ax[j].text( + 0.1, + 0.9, + r"$\widehat{{\mathrm{{ECE}}}}$ = {0:.3f}".format(cal_errs[j]), + horizontalalignment="left", + verticalalignment="center", + transform=ax[j].transAxes, + size=legend_fontsize, + ) + + # Only add x-labels to the bottom row + bottom_row = ax_array if n_row == 1 else ax_array[0] if n_col == 1 else ax_array[n_row - 1, :] + for _ax in bottom_row: + _ax.set_xlabel("Predicted probability", fontsize=label_fontsize) + + # Only add y-labels to left-most row + if n_row == 1: # if there is only one row, the ax array is 1D + ax[0].set_ylabel("True probability", fontsize=label_fontsize) + else: # if there is more than one row, the ax array is 2D + for _ax in ax_array[:, 0]: + _ax.set_ylabel("True probability", fontsize=label_fontsize) + + fig.tight_layout() + return fig diff --git a/bayesflow/utils/comp_utils.py b/bayesflow/utils/comp_utils.py new file mode 100644 index 000000000..e69de29bb From 7e90e61a94ebc257b3785847fa5b8959b0fe7d76 Mon Sep 17 00:00:00 2001 From: Jerry Date: Sat, 2 Nov 2024 17:00:16 -0400 Subject: 
[PATCH 21/22] Fixed E721 for plot_posterior_2d and plot_prior_2d --- bayesflow/diagnostics/__init__.py | 4 + bayesflow/diagnostics/plot_latent_space_2d.py | 31 +++++ bayesflow/diagnostics/plot_posterior_2d.py | 128 ++++++++++++++++++ bayesflow/diagnostics/plot_prior_2d.py | 43 ++++++ bayesflow/utils/__init__.py | 2 + bayesflow/utils/comp_utils.py | 63 +++++++++ bayesflow/utils/plot_utils.py | 68 +++++----- 7 files changed, 305 insertions(+), 34 deletions(-) create mode 100644 bayesflow/diagnostics/plot_latent_space_2d.py create mode 100644 bayesflow/diagnostics/plot_posterior_2d.py create mode 100644 bayesflow/diagnostics/plot_prior_2d.py diff --git a/bayesflow/diagnostics/__init__.py b/bayesflow/diagnostics/__init__.py index d75daa681..a4484dd06 100644 --- a/bayesflow/diagnostics/__init__.py +++ b/bayesflow/diagnostics/__init__.py @@ -4,3 +4,7 @@ from .plot_sbc_histogram import plot_sbc_histograms from .plot_distribution_2d import plot_distribution_2d from .plot_z_score_contraction import plot_z_score_contraction +from .plot_prior_2d import plot_prior_2d +from .plot_posterior_2d import plot_posterior_2d +from .plot_latent_space_2d import plot_latent_space_2d +from .plot_calibration_curves import plot_calibration_curves diff --git a/bayesflow/diagnostics/plot_latent_space_2d.py b/bayesflow/diagnostics/plot_latent_space_2d.py new file mode 100644 index 000000000..9fb56fbc4 --- /dev/null +++ b/bayesflow/diagnostics/plot_latent_space_2d.py @@ -0,0 +1,31 @@ +from .plot_distribution_2d import plot_distribution_2d + +from keras import backend as K + + +def plot_latent_space_2d(z_samples, height: float = 2.5, color="#8f2727", **kwargs): + """Creates pair plots for the latent space learned by the inference network. Enables + visual inspection of the latent space and whether its structure corresponds to the + one enforced by the optimization criterion. 
+ + Parameters + ---------- + z_samples : np.ndarray or tf.Tensor of shape (n_sim, n_params) + The latent samples computed through a forward pass of the inference network. + height : float, optional, default: 2.5 + The height of the pair plot. + color : str, optional, default : '#8f2727' + The color of the plot + **kwargs : dict, optional + Additional keyword arguments passed to the sns.PairGrid constructor + + Returns + ------- + f : plt.Figure - the figure instance for optional saving + """ + + # Try to convert z_samples, if eventually tf.Tensor is passed + if not isinstance(z_samples, K.tf.Tensor): + z_samples = K.constant(z_samples) + + plot_distribution_2d(z_samples, context="Latent Dim", height=height, color=color, render=True, **kwargs) diff --git a/bayesflow/diagnostics/plot_posterior_2d.py b/bayesflow/diagnostics/plot_posterior_2d.py new file mode 100644 index 000000000..a676539d2 --- /dev/null +++ b/bayesflow/diagnostics/plot_posterior_2d.py @@ -0,0 +1,128 @@ +import pandas as pd +import seaborn as sns + +from matplotlib.lines import Line2D +from .plot_distribution_2d import plot_distribution_2d + + +def plot_posterior_2d( + posterior_draws, + prior=None, + prior_draws=None, + param_names: list = None, + height: int = 3, + label_fontsize: int = 14, + legend_fontsize: int = 16, + tick_fontsize: int = 12, + post_color: str | tuple = "#8f2727", + prior_color: str | tuple = "gray", + post_alpha: float = 0.9, + prior_alpha: float = 0.7, + **kwargs, +): + """Generates a bivariate pairplot given posterior draws and optional prior or prior draws. + + posterior_draws : np.ndarray of shape (n_post_draws, n_params) + The posterior draws obtained for a SINGLE observed data set. 
+ prior : bayesflow.forward_inference.Prior instance or None, optional, default: None + The optional prior object having an input-output signature as given by bayesflow.forward_inference.Prior + prior_draws : np.ndarray of shape (n_prior_draws, n_params) or None, optional, default: None + The optional prior draws obtained from the prior. If both prior and prior_draws are provided, prior_draws + will be used. + param_names : list or None, optional, default: None + The parameter names for nice plot titles. Inferred if None + height : float, optional, default: 3 + The height of the pairplot + label_fontsize : int, optional, default: 14 + The font size of the x and y-label texts (parameter names) + legend_fontsize : int, optional, default: 16 + The font size of the legend text + tick_fontsize : int, optional, default: 12 + The font size of the axis ticklabels + post_color : str, optional, default: '#8f2727' + The color for the posterior histograms and KDEs + prior_color : str, optional, default: gray + The color for the optional prior histograms and KDEs + post_alpha : float in [0, 1], optional, default: 0.9 + The opacity of the posterior plots + prior_alpha : float in [0, 1], optional, default: 0.7 + The opacity of the prior plots + + Returns + ------- + f : plt.Figure - the figure instance for optional saving + + Raises + ------ + AssertionError + If the shape of posterior_draws is not 2-dimensional. + """ + + # Ensure correct shape + assert ( + len(posterior_draws.shape) + ) == 2, "Shape of `posterior_samples` for a single data set should be 2 dimensional!"
+ + # Plot posterior first + g = plot_distribution_2d(posterior_draws, context="\\theta", param_names=param_names, render=False, **kwargs) + + # Obtain n_draws and n_params + n_draws, n_params = posterior_draws.shape + + # If prior object is given and no draws, obtain draws + if prior is not None and prior_draws is None: + draws = prior(n_draws) + if isinstance(draws, dict): + prior_draws = draws["prior_draws"] + else: + prior_draws = draws + + # Attempt to determine parameter names + if param_names is None: + if hasattr(prior, "param_names"): + if prior.param_names is not None: + param_names = prior.param_names + else: + param_names = [f"$\\theta_{{{i}}}$" for i in range(1, n_params + 1)] + else: + param_names = [f"$\\theta_{{{i}}}$" for i in range(1, n_params + 1)] + + # Add prior, if given + if prior_draws is not None: + prior_draws_df = pd.DataFrame(prior_draws, columns=param_names) + g.data = prior_draws_df + g.map_diag(sns.histplot, fill=True, color=prior_color, alpha=prior_alpha, kde=True, zorder=-1) + g.map_lower(sns.kdeplot, fill=True, color=prior_color, alpha=prior_alpha, zorder=-1) + + # Add legend, if prior also given + if prior_draws is not None or prior is not None: + handles = [ + Line2D(xdata=[], ydata=[], color=post_color, lw=3, alpha=post_alpha), + Line2D(xdata=[], ydata=[], color=prior_color, lw=3, alpha=prior_alpha), + ] + g.legend(handles, ["Posterior", "Prior"], fontsize=legend_fontsize, loc="center right") + + n_row, n_col = g.axes.shape + + for i in range(n_row): + # Remove upper axis + for j in range(i + 1, n_col): + g.axes[i, j].axis("off") + + # Modify tick sizes + for j in range(i + 1): + g.axes[i, j].tick_params(axis="both", which="major", labelsize=tick_fontsize) + g.axes[i, j].tick_params(axis="both", which="minor", labelsize=tick_fontsize) + + # Add nice labels + for i, param_name in enumerate(param_names): + g.axes[i, 0].set_ylabel(param_name, fontsize=label_fontsize) + g.axes[len(param_names) - 1, i].set_xlabel(param_name, 
fontsize=label_fontsize) + + # Add grids + for i in range(n_params): + for j in range(n_params): + g.axes[i, j].grid(alpha=0.5) + + g.tight_layout() + return g diff --git a/bayesflow/diagnostics/plot_prior_2d.py b/bayesflow/diagnostics/plot_prior_2d.py new file mode 100644 index 000000000..f695e508b --- /dev/null +++ b/bayesflow/diagnostics/plot_prior_2d.py @@ -0,0 +1,43 @@ +from .plot_distribution_2d import plot_distribution_2d + + +def plot_prior_2d( + prior, + param_names: list = None, + n_samples: int = 2000, + height: float = 2.5, + color: str | tuple = "#8f2727", + **kwargs, +): + """Creates pair-plots for a given joint prior. + + Parameters + ---------- + prior : callable + The prior object which takes a single integer argument and generates random draws. + param_names : list of str or None, optional, default: None + An optional list of parameter names, passed on to ``plot_distribution_2d`` + n_samples : int, optional, default: 2000 + The number of random draws from the joint prior + height : float, optional, default: 2.5 + The height of the pair plot + color : str, optional, default: '#8f2727' + The color of the plot + **kwargs : dict, optional + Additional keyword arguments passed to the sns.PairGrid constructor + + Returns + ------- + f : plt.Figure - the figure instance for optional saving + """ + + # Generate prior draws + prior_samples = prior(n_samples) + + # Handle dict type + if isinstance(prior_samples, dict): + prior_samples = prior_samples["prior_draws"] + + plot_distribution_2d( + prior_samples, context="Prior", height=height, color=color, param_names=param_names, render=True, **kwargs + ) diff --git a/bayesflow/utils/__init__.py b/bayesflow/utils/__init__.py index 169a13128..68c9de275 100644 --- a/bayesflow/utils/__init__.py +++ b/bayesflow/utils/__init__.py @@ -39,3 +39,5 @@ tree_concatenate, tree_stack, ) +from .comp_utils import expected_calibration_error +from .plot_utils import check_posterior_prior_shapes diff --git a/bayesflow/utils/comp_utils.py b/bayesflow/utils/comp_utils.py
index e69de29bb..b21b3f03c 100644 --- a/bayesflow/utils/comp_utils.py +++ b/bayesflow/utils/comp_utils.py @@ -0,0 +1,63 @@ +import numpy as np + +from sklearn.calibration import calibration_curve + + +def expected_calibration_error(m_true, m_pred, num_bins=10): + """Estimates the expected calibration error (ECE) of a model comparison network according to [1]. + + [1] Naeini, M. P., Cooper, G., & Hauskrecht, M. (2015). + Obtaining well calibrated probabilities using bayesian binning. + In Proceedings of the AAAI conference on artificial intelligence (Vol. 29, No. 1). + + Notes + ----- + Make sure that ``m_true`` are **one-hot encoded** classes! + + Parameters + ---------- + m_true : np.ndarray of shape (num_sim, num_models) + The one-hot-encoded true model indices. + m_pred : tf.tensor of shape (num_sim, num_models) + The predicted posterior model probabilities. + num_bins : int, optional, default: 10 + The number of bins to use for the calibration curves (and marginal histograms). + + Returns + ------- + cal_errs : list of length (num_models) + The ECEs for each model. + probs : list of length (num_models) + The bin information for constructing the calibration curves. + Each list contains two arrays of length (num_bins) with the predicted and true probabilities for each bin. 
+ """ + + # Convert tf.Tensors to numpy, if passed + if type(m_true) is not np.ndarray: + m_true = m_true.numpy() + if type(m_pred) is not np.ndarray: + m_pred = m_pred.numpy() + + # Extract number of models and prepare containers + n_models = m_true.shape[1] + cal_errs = [] + probs_true = [] + probs_pred = [] + + # Loop for each model and compute calibration errs per bin + for k in range(n_models): + y_true = (m_true.argmax(axis=1) == k).astype(np.float32) + y_prob = m_pred[:, k] + prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=num_bins) + + # Compute ECE by weighting bin errors by bin size + bins = np.linspace(0.0, 1.0, num_bins + 1) + binids = np.searchsorted(bins[1:-1], y_prob) + bin_total = np.bincount(binids, minlength=len(bins)) + nonzero = bin_total != 0 + cal_err = np.sum(np.abs(prob_true - prob_pred) * (bin_total[nonzero] / len(y_true))) + + cal_errs.append(cal_err) + probs_true.append(prob_true) + probs_pred.append(prob_pred) + return cal_errs, probs_true, probs_pred diff --git a/bayesflow/utils/plot_utils.py b/bayesflow/utils/plot_utils.py index 027e8ae1f..19170f7e7 100644 --- a/bayesflow/utils/plot_utils.py +++ b/bayesflow/utils/plot_utils.py @@ -85,7 +85,7 @@ def get_count_and_names(samples, names: list = None, symbol: str = None, n_objec return n_objects, names -def configure_layout(n_total: int, n_row: int = None, n_col: int = None, stacked: bool = False): +def set_layout(n_total: int, n_row: int = None, n_col: int = None, stacked: bool = False): """ Determine the number of rows and columns in diagnostics visualizations. 
@@ -121,7 +121,7 @@ def configure_layout(n_total: int, n_row: int = None, n_col: int = None, stacked return n_row, n_col -def initialize_figure( +def make_figure( n_row: int = None, n_col: int = None, fig_size: tuple = None, @@ -141,27 +141,27 @@ def initialize_figure( Returns ------- - f, axarr + f, ax_array Initialized figures """ if n_row == 1 and n_col == 1: - f, axarr = plt.subplots(1, 1, figsize=fig_size) + f, ax_array = plt.subplots(1, 1, figsize=fig_size) else: if fig_size is None: fig_size = (int(5 * n_col), int(5 * n_row)) - f, axarr = plt.subplots(n_row, n_col, figsize=fig_size) + f, ax_array = plt.subplots(n_row, n_col, figsize=fig_size) - return f, axarr + return f, ax_array -def collapse_axes(axarr, n_row: int = 1, n_col: int = 1): +def flatten_axes(ax_array, n_row: int = 1, n_col: int = 1): """ Collapse a 2D array of subplot Axes into a 1D array Parameters ---------- - axarr : 2D array of Axes + ax_array : 2D array of Axes An array of axes for subplots n_row : int, default: 1 Number of rows for the axes @@ -174,49 +174,49 @@ def collapse_axes(axarr, n_row: int = 1, n_col: int = 1): Collapsed axes for subplots """ - ax = np.atleast_1d(axarr) - # turn axarr into 1D list + ax = np.atleast_1d(ax_array) + # turn ax_array into 1D list if n_row > 1 or n_col > 1: - ax = axarr.flat + ax = ax_array.flat else: - ax = axarr + ax = ax_array return ax -def add_xlabels(axarr, n_row: int = None, n_col: int = None, xlabel: str = None, label_fontsize: int = None): +def add_x_labels(ax_array, n_row: int = None, n_col: int = None, x_label: str = None, label_fontsize: int = None): # Only add x-labels to the bottom row - bottom_row = axarr if n_row == 1 else axarr[0] if n_col == 1 else axarr[n_row - 1, :] + bottom_row = ax_array if n_row == 1 else ax_array[0] if n_col == 1 else ax_array[n_row - 1, :] for _ax in bottom_row: - _ax.set_xlabel(xlabel, fontsize=label_fontsize) + _ax.set_xlabel(x_label, fontsize=label_fontsize) -def add_ylabels(axarr, n_row: int = None, ylabel: 
str = None, label_fontsize: int = None): +def add_y_labels(ax_array, n_row: int = None, y_label: str = None, label_fontsize: int = None): # Only add y-labels to right left-most row if n_row == 1: # if there is only one row, the ax array is 1D - axarr[0].set_ylabel(ylabel, fontsize=label_fontsize) + ax_array[0].set_ylabel(y_label, fontsize=label_fontsize) # If there is more than one row, the ax array is 2D else: - for _ax in axarr[:, 0]: - _ax.set_ylabel(ylabel, fontsize=label_fontsize) + for _ax in ax_array[:, 0]: + _ax.set_ylabel(y_label, fontsize=label_fontsize) def add_labels( - axarr, n_row: int = None, n_col: int = None, xlabel: str = None, ylabel: str = None, label_fontsize: int = None + ax_array, n_row: int = None, n_col: int = None, x_label: str = None, y_label: str = None, label_fontsize: int = None ): """ Wrapper function for configuring labels for both axes. """ - add_xlabels(axarr, n_row, n_col, xlabel, label_fontsize) - add_ylabels(axarr, n_row, ylabel, label_fontsize) + add_x_labels(ax_array, n_row, n_col, x_label, label_fontsize) + add_y_labels(ax_array, n_row, y_label, label_fontsize) -def remove_unused_axes(axarr_it, n_params: int = None): - for _ax in axarr_it[n_params:]: - _ax.remove() +def remove_unused_axes(ax_array_it, n_params: int = None): + for ax in ax_array_it[n_params:]: + ax.remove() -def preprocess(post_samples, prior_samples, fig_size: tuple = None, collapse: bool = True): +def preprocess(post_samples, prior_samples, fig_size: tuple = None, flatten: bool = True): """ Procedural wrapper that encompasses all preprocessing steps, including shape-checking, parameter name generation, layout configuration, @@ -230,7 +230,7 @@ def preprocess(post_samples, prior_samples, fig_size: tuple = None, collapse: bo The prior draws obtained for generating n_data_sets fig_size : tuple, optional, default: None Size of the figure adjusting to the display resolution - collapse : bool, optional, default: True + flatten : bool, optional, default: True 
Whether subplots in a figure are collapsed into rows """ @@ -241,18 +241,18 @@ def preprocess(post_samples, prior_samples, fig_size: tuple = None, collapse: bo n_params, param_names = get_count_and_names(post_samples) # Configure layout - n_row, n_col = configure_layout(n_params) + n_row, n_col = set_layout(n_params) # Initialize figure - f, axarr = initialize_figure(n_row, n_col, fig_size=fig_size) + f, ax_array = make_figure(n_row, n_col, fig_size=fig_size) - # turn axarr into 1D list - if collapse: - axarr_it = collapse_axes(axarr, n_row, n_col) + # turn ax_array into 1D list + if flatten: + ax_array_it = flatten_axes(ax_array, n_row, n_col) else: - axarr_it = axarr + ax_array_it = ax_array - return f, axarr, axarr_it, n_row, n_col, n_params, param_names + return f, ax_array, ax_array_it, n_row, n_col, n_params, param_names def postprocess(*args): From 84a49d32febbf2d4a0701ab4606da14c9d14d6a9 Mon Sep 17 00:00:00 2001 From: Jerry Date: Sat, 2 Nov 2024 17:02:27 -0400 Subject: [PATCH 22/22] Add plot_mmd_hypothesis_test --- .../diagnostics/plot_mmd_hypothesis_test.py | 100 ++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 bayesflow/diagnostics/plot_mmd_hypothesis_test.py diff --git a/bayesflow/diagnostics/plot_mmd_hypothesis_test.py b/bayesflow/diagnostics/plot_mmd_hypothesis_test.py new file mode 100644 index 000000000..01b336094 --- /dev/null +++ b/bayesflow/diagnostics/plot_mmd_hypothesis_test.py @@ -0,0 +1,100 @@ +import matplotlib.pyplot as plt +import seaborn as sns + +from keras import ops + + +def plot_mmd_hypothesis_test( + mmd_null, + mmd_observed: float = None, + alpha_level: float = 0.05, + null_color: str | tuple = (0.16407, 0.020171, 0.577478), + observed_color: str | tuple = "red", + alpha_color: str | tuple = "orange", + truncate_v_lines_at_kde: bool = False, + x_min: float = None, + x_max: float = None, + bw_factor: float = 1.5, +): + """ + + Parameters + ---------- + mmd_null : np.ndarray + The samples from the MMD sampling 
distribution under the null hypothesis "the model is well-specified" + mmd_observed : float + The observed MMD value + alpha_level : float, optional, default: 0.05 + The rejection probability (type I error) + null_color : str or tuple, optional, default: (0.16407, 0.020171, 0.577478) + The color of the H0 sampling distribution + observed_color : str or tuple, optional, default: "red" + The color of the observed MMD + alpha_color : str or tuple, optional, default: "orange" + The color of the rejection area + truncate_v_lines_at_kde : bool, optional, default: False + True: cut off the vertical lines at the KDE curve + False: extend the vertical lines to the top of the plot + x_min : float, optional, default: None + The lower x-axis limit + x_max : float, optional, default: None + The upper x-axis limit + bw_factor : float, optional, default: 1.5 + Bandwidth adjustment factor (smoothing parameter) of the kernel density estimate, passed to seaborn as ``bw_adjust`` + + Returns + ------- + f : plt.Figure - the figure instance for optional saving + """ + + def draw_v_line_to_kde(x, kde_object, color, label=None, **kwargs): + kde_x, kde_y = kde_object.lines[0].get_data() + idx = ops.argmin(ops.abs(kde_x - x)) + plt.vlines(x=x, ymin=0, ymax=kde_y[idx], color=color, linewidth=3, label=label, **kwargs) + + def fill_area_under_kde(kde_object, x_start, x_end=None, **kwargs): + kde_x, kde_y = kde_object.lines[0].get_data() + if x_end is not None: + plt.fill_between(kde_x, kde_y, where=(kde_x >= x_start) & (kde_x <= x_end), interpolate=True, **kwargs) + else: + plt.fill_between(kde_x, kde_y, where=(kde_x >= x_start), interpolate=True, **kwargs) + + f = plt.figure(figsize=(8, 4)) + + kde = sns.kdeplot(mmd_null, fill=False, linewidth=0, bw_adjust=bw_factor) + sns.kdeplot(mmd_null, fill=True, alpha=0.12, color=null_color, bw_adjust=bw_factor) + + if truncate_v_lines_at_kde: + draw_v_line_to_kde(x=mmd_observed, kde_object=kde, color=observed_color, label=r"Observed data") + else: + plt.vlines( + x=mmd_observed, + ymin=0, + ymax=plt.gca().get_ylim()[1], +
color=observed_color, + linewidth=3, + label=r"Observed data", + ) + + mmd_critical = ops.quantile(mmd_null, 1 - alpha_level) + fill_area_under_kde( + kde, mmd_critical, color=alpha_color, alpha=0.5, label=rf"{int(alpha_level*100)}% rejection area" + ) + + if truncate_v_lines_at_kde: + draw_v_line_to_kde(x=mmd_critical, kde_object=kde, color=alpha_color) + else: + plt.vlines(x=mmd_critical, color=alpha_color, linewidth=3, ymin=0, ymax=plt.gca().get_ylim()[1]) + + sns.kdeplot(mmd_null, fill=False, linewidth=3, color=null_color, label=r"$H_0$", bw_adjust=bw_factor) + + plt.xlabel(r"MMD", fontsize=20) + plt.ylabel("") + plt.yticks([]) + plt.xlim(x_min, x_max) + plt.tick_params(axis="both", which="major", labelsize=16) + + plt.legend(fontsize=20) + sns.despine() + + return f