From 5bb21056443274623e58b55351d7d3a14037d402 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Tue, 18 Feb 2025 12:19:38 +0100 Subject: [PATCH 1/7] introduce VariableArray class --- .../diagnostics/metrics/calibration_error.py | 1 + .../metrics/posterior_contraction.py | 3 +- .../metrics/root_mean_squared_error.py | 3 +- .../diagnostics/plots/pairs_posterior.py | 6 +- bayesflow/diagnostics/plots/pairs_samples.py | 7 ++- bayesflow/utils/dict_utils.py | 57 +++++++++++++++---- bayesflow/utils/plot_utils.py | 6 +- 7 files changed, 65 insertions(+), 18 deletions(-) diff --git a/bayesflow/diagnostics/metrics/calibration_error.py b/bayesflow/diagnostics/metrics/calibration_error.py index 3fa808b8c..af7a7c030 100644 --- a/bayesflow/diagnostics/metrics/calibration_error.py +++ b/bayesflow/diagnostics/metrics/calibration_error.py @@ -88,4 +88,5 @@ def calibration_error( # Aggregate errors across alpha error = aggregation(absolute_errors, axis=0) + variable_names = samples["estimates"].variable_names return {"values": error, "metric_name": "Calibration Error", "variable_names": variable_names} diff --git a/bayesflow/diagnostics/metrics/posterior_contraction.py b/bayesflow/diagnostics/metrics/posterior_contraction.py index 523d3e767..eb7f898c6 100644 --- a/bayesflow/diagnostics/metrics/posterior_contraction.py +++ b/bayesflow/diagnostics/metrics/posterior_contraction.py @@ -58,4 +58,5 @@ def posterior_contraction( prior_vars = samples["targets"].var(axis=0, keepdims=True, ddof=1) contraction = 1 - (post_vars / prior_vars) contraction = aggregation(contraction, axis=0) - return {"values": contraction, "metric_name": "Posterior Contraction", "variable_names": samples["variable_names"]} + variable_names = samples["estimates"].variable_names + return {"values": contraction, "metric_name": "Posterior Contraction", "variable_names": variable_names} diff --git a/bayesflow/diagnostics/metrics/root_mean_squared_error.py b/bayesflow/diagnostics/metrics/root_mean_squared_error.py index 91ef38ce6..3289cd0e8 100644 --- a/bayesflow/diagnostics/metrics/root_mean_squared_error.py +++ b/bayesflow/diagnostics/metrics/root_mean_squared_error.py @@ -65,4 +65,5 @@ def root_mean_squared_error( metric_name = "RMSE" rmse = aggregation(rmse, axis=0) - return {"values": rmse, "metric_name": metric_name, "variable_names": samples["variable_names"]} + variable_names = samples["estimates"].variable_names + return {"values": rmse, "metric_name": metric_name, "variable_names": variable_names} diff --git a/bayesflow/diagnostics/plots/pairs_posterior.py b/bayesflow/diagnostics/plots/pairs_posterior.py index a77fb61b3..ea91a6a0a 100644 --- a/bayesflow/diagnostics/plots/pairs_posterior.py +++ b/bayesflow/diagnostics/plots/pairs_posterior.py @@ -135,7 +135,11 @@ def plot_true_params(x, **kwargs): plt.axvline(param, color="black", linestyle="--") # Add vertical line # only plot on the diagonal a vertical line for the true parameter - g.data = pd.DataFrame(plot_data["targets"][np.newaxis], columns=plot_data["variable_names"]) + + g.data = pd.DataFrame( + plot_data["targets"][np.newaxis], + columns=plot_data["targets"].variable_names, + ) g.map_diag(plot_true_params) return g diff --git a/bayesflow/diagnostics/plots/pairs_samples.py b/bayesflow/diagnostics/plots/pairs_samples.py index 2bdd363bd..8ab81111b 100644 --- a/bayesflow/diagnostics/plots/pairs_samples.py +++ b/bayesflow/diagnostics/plots/pairs_samples.py @@ -88,7 +88,8 @@ def _pairs_samples( ) # Convert samples to pd.DataFrame - data_to_plot = pd.DataFrame(plot_data["estimates"], columns=plot_data["variable_names"]) + variable_names = plot_data["estimates"].variable_names + data_to_plot = pd.DataFrame(plot_data["estimates"], columns=variable_names) # initialize plot artist = sns.PairGrid(data_to_plot, height=height, **kwargs) @@ -122,8 +123,8 @@ def _pairs_samples( # adjust font size of labels # the labels themselves remain the same as before, i.e., variable_names - artist.axes[i, 0].set_ylabel(plot_data["variable_names"][i], fontsize=label_fontsize) - artist.axes[dim - 1, i].set_xlabel(plot_data["variable_names"][i], fontsize=label_fontsize) + artist.axes[i, 0].set_ylabel(variable_names[i], fontsize=label_fontsize) + artist.axes[dim - 1, i].set_xlabel(variable_names[i], fontsize=label_fontsize) # Return figure artist.tight_layout() diff --git a/bayesflow/utils/dict_utils.py b/bayesflow/utils/dict_utils.py index 180f0c59d..e76b72466 100644 --- a/bayesflow/utils/dict_utils.py +++ b/bayesflow/utils/dict_utils.py @@ -129,13 +129,36 @@ def split_arrays(data: Mapping[str, np.ndarray], axis: int = -1) -> Mapping[str, return result -def validate_variable_array( +class VariableArray(np.ndarray): + """ + An enriched numpy array with information on variable keys and names + to be used in post-processing, specifically the diagnostics module. + + The current implemention is very basic and we may want to extend it + in the future should this general structure prove useful. + + Design according to + https://numpy.org/doc/stable/user/basics.subclassing.html#simple-example-adding-an-extra-attribute-to-ndarray + """ + def __new__(cls, input_array, variable_keys=None, variable_names=None): + obj = np.asarray(input_array).view(cls) + obj.variable_keys = variable_keys + obj.variable_names = variable_names + return obj + + def __array_finalize__(self, obj): + if obj is None: return + self.variable_keys = getattr(obj, 'variable_keys', None) + self.variable_names = getattr(obj, 'variable_names', None) + + +def make_variable_array( x: Mapping[str, np.ndarray] | np.ndarray, dataset_ids: Sequence[int] | int = None, variable_keys: Sequence[str] | str = None, variable_names: Sequence[str] | str = None, default_name: str = "v", -): +) -> VariableArray: """ Helper function to validate arrays for use in the diagnostics module. @@ -180,7 +203,15 @@ def validate_variable_array( # Case arrays provided elif isinstance(x, np.ndarray): - if variable_names is None: + if isinstance(x, VariableArray): + # reuse existing variable keys and names if contained in x + if variable_names is None: + variable_names = x.variable_names + if variable_keys in None: + variable_keys = x.variable_keys + + # use default names if not otherwise specified + if variable_names is None: variable_names = [f"${default_name}_{{{i}}}$" for i in range(x.shape[-1])] if dataset_ids is not None: @@ -192,8 +223,14 @@ def validate_variable_array( if len(variable_names) is not x.shape[-1]: raise ValueError("Length of 'variable_names' should be the same as the number of variables.") + + if variable_keys is None: + # every variable will count as its own key if not otherwise specified + variable_keys = variable_names + + x = VariableArray(x, variable_keys=variable_keys, variable_names=variable_names) - return x, variable_keys, variable_names + return x def dicts_to_arrays( @@ -245,7 +282,7 @@ def dicts_to_arrays( # other to be validated arrays (see below) will take use # the variable_keys and variable_names implied by estimates - estimates, variable_keys, variable_names = validate_variable_array( + estimates = make_variable_array( estimates, dataset_ids=dataset_ids, variable_keys=variable_keys, @@ -254,16 +291,14 @@ def dicts_to_arrays( ) if targets is not None: - targets, _, _ = validate_variable_array( + targets = make_variable_array( targets, dataset_ids=dataset_ids, - variable_keys=variable_keys, - variable_names=variable_names, + variable_keys=estimates.variable_keys, + variable_names=estimates.variable_names, ) - + return dict( estimates=estimates, targets=targets, - variable_names=variable_names, - num_variables=len(variable_names), ) diff --git a/bayesflow/utils/plot_utils.py b/bayesflow/utils/plot_utils.py index fd8512d0e..7ff6004ce 100644 --- a/bayesflow/utils/plot_utils.py +++ b/bayesflow/utils/plot_utils.py @@ -60,11 +60,15 @@ def prepare_plot_data( check_estimates_prior_shapes(plot_data["estimates"], plot_data["targets"]) # Configure layout - num_row, num_col = set_layout(plot_data["num_variables"], num_row, num_col, stacked) + variable_names = plot_data["estimates"].variable_names + num_variables = len(variable_names) + num_row, num_col = set_layout(num_variables, num_row, num_col, stacked) # Initialize figure fig, axes = make_figure(num_row, num_col, figsize=figsize) + plot_data["variable_names"] = variable_names + plot_data["num_variables"] = num_variables plot_data["fig"] = fig plot_data["axes"] = axes plot_data["num_row"] = num_row From cf9df95cc9771b954e7f2b9f05b4f074a73cea5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Tue, 18 Feb 2025 16:04:07 +0100 Subject: [PATCH 2/7] enable also plotting the prior via pairs_posterior --- .../diagnostics/plots/pairs_posterior.py | 49 ++++-------- bayesflow/diagnostics/plots/pairs_samples.py | 75 +++++++++++++------ bayesflow/utils/dict_utils.py | 14 ++++ bayesflow/utils/plot_utils.py | 10 ++- 4 files changed, 85 insertions(+), 63 deletions(-) diff --git a/bayesflow/diagnostics/plots/pairs_posterior.py b/bayesflow/diagnostics/plots/pairs_posterior.py index ea91a6a0a..cdee3f566 100644 --- a/bayesflow/diagnostics/plots/pairs_posterior.py +++ b/bayesflow/diagnostics/plots/pairs_posterior.py @@ -20,14 +20,12 @@ def pairs_posterior( variable_keys: Sequence[str] = None, variable_names: Sequence[str] = None, height: int = 3, + post_color: str | tuple = "#132a70", + prior_color: str | tuple = "gray", + alpha = 0.9, label_fontsize: int = 14, tick_fontsize: int = 12, - # arguments related to priors which is currently unused - # legend_fontsize: int = 16, - # post_color: str | tuple = "#132a70", - # prior_color: str | tuple = "gray", - # post_alpha: float = 0.9, - # prior_alpha: float = 0.7, + legend_fontsize: int = 16, **kwargs, ) -> sns.PairGrid: """Generates a bivariate pair plot given posterior draws and optional prior or prior draws. @@ -75,6 +73,7 @@ def pairs_posterior( plot_data = dicts_to_arrays( estimates=estimates, targets=targets, + priors=priors, dataset_ids=dataset_id, variable_keys=variable_keys, variable_names=variable_names, @@ -90,56 +89,34 @@ def pairs_posterior( g = _pairs_samples( plot_data=plot_data, height=height, + color=post_color, + color2=prior_color, + alpha=alpha, label_fontsize=label_fontsize, tick_fontsize=tick_fontsize, + legend_fontsize=legend_fontsize, **kwargs, ) - # add priors - if priors is not None: - # TODO: integrate priors into plot_data and then use - # proper coloring of posterior vs. prior using the hue argument in PairGrid - raise ValueError("Plotting prior samples is not yet implemented.") - - """ - # this is currently not working as expected as it doesn't show the off diagonal plots - prior_samples_df = pd.DataFrame(priors, columns=plot_data["variable_names"]) - g.data = prior_samples_df - g.map_diag(sns.histplot, fill=True, color=prior_color, alpha=prior_alpha, kde=True, zorder=-1) - g.map_lower(sns.kdeplot, fill=True, color=prior_color, alpha=prior_alpha, zorder=-1) - - # Add legend to differentiate between prior and posterior - handles = [ - Line2D(xdata=[], ydata=[], color=post_color, lw=3, alpha=post_alpha), - Line2D(xdata=[], ydata=[], color=prior_color, lw=3, alpha=prior_alpha), - ] - handles_names = ["Posterior", "Prior"] - if targets is not None: - handles.append(Line2D(xdata=[], ydata=[], color="black", lw=3, linestyle="--")) - handles_names.append("True Parameter") - plt.legend(handles=handles, labels=handles_names, fontsize=legend_fontsize, loc="center right") - """ - # add true parameters if plot_data["targets"] is not None: - # TODO: also add true parameters to the off diagonal plots? - # drop dataset axis if it is still present but of length 1 targets_shape = plot_data["targets"].shape if len(targets_shape) == 2 and targets_shape[0] == 1: plot_data["targets"] = np.squeeze(plot_data["targets"], axis=0) # Custom function to plot true parameters on the diagonal - def plot_true_params(x, **kwargs): + def plot_true_params(x, hue=None, **kwargs): + # hue needs to be added to handle the case of plotting both posterior and prior param = x.iloc[0] # Get the single true value for the diagonal + # only plot on the diagonal a vertical line for the true parameter plt.axvline(param, color="black", linestyle="--") # Add vertical line - # only plot on the diagonal a vertical line for the true parameter - g.data = pd.DataFrame( plot_data["targets"][np.newaxis], columns=plot_data["targets"].variable_names, ) + g.data["Source"] = "True Parameter" g.map_diag(plot_true_params) return g diff --git a/bayesflow/diagnostics/plots/pairs_samples.py b/bayesflow/diagnostics/plots/pairs_samples.py index 8ab81111b..2357e06ae 100644 --- a/bayesflow/diagnostics/plots/pairs_samples.py +++ b/bayesflow/diagnostics/plots/pairs_samples.py @@ -68,9 +68,11 @@ def _pairs_samples( plot_data: dict, height: float = 2.5, color: str | tuple = "#132a70", + color2: str | tuple = "gray", alpha: float = 0.9, label_fontsize: int = 14, tick_fontsize: int = 12, + legend_fontsize: int = 14, **kwargs, ) -> sns.PairGrid: # internal version of pairs_samples creating the seaborn plot @@ -87,46 +89,71 @@ def _pairs_samples( f"your samples array has a shape of {estimates_shape}." ) - # Convert samples to pd.DataFrame variable_names = plot_data["estimates"].variable_names - data_to_plot = pd.DataFrame(plot_data["estimates"], columns=variable_names) - - # initialize plot - artist = sns.PairGrid(data_to_plot, height=height, **kwargs) - # Generate grids - # in the off diagonal plots, the grids appears in front of the points/densities - # TODO: can we put the grid in the background somehow? - dim = artist.axes.shape[0] - for i in range(dim): - for j in range(dim): - artist.axes[i, j].grid(alpha=0.5) + # Convert samples to pd.DataFrame + if plot_data["priors"] is not None: + # differentiate posterior from prior draws + # row bind posterior and prior draws + samples = np.vstack((plot_data["priors"], plot_data["estimates"])) + data_to_plot = pd.DataFrame(samples, columns=variable_names) + + # ensure that the source of the samples is stored + source_prior = np.repeat("Prior", plot_data["priors"].shape[0]) + source_post = np.repeat("Posterior", plot_data["estimates"].shape[0]) + data_to_plot["Source"] = np.concatenate((source_prior, source_post)) + data_to_plot["Source"] = pd.Categorical(data_to_plot["Source"], categories=["Prior", "Posterior"]) + + color = [color, color2] + + # initialize plot + g = sns.PairGrid(data_to_plot, height=height, hue="Source", **kwargs) + + else: + # plot just the one set of distributions + data_to_plot = pd.DataFrame(plot_data["estimates"], columns=variable_names) + + # initialize plot + g = sns.PairGrid(data_to_plot, height=height, **kwargs) # add histograms + KDEs to the diagonal - artist.map_diag(sns.histplot, fill=True, color=color, alpha=alpha, kde=True) + g.map_diag(sns.histplot, fill=True, color=color, alpha=alpha, kde=True) + + # add scatterplots to the upper diagonal + g.map_upper(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color, lw=0) - # Incorporate exceptions for generating KDE plots + # add KDEs to the lower diagonal try: - artist.map_lower(sns.kdeplot, fill=True, color=color, alpha=alpha) + g.map_lower(sns.kdeplot, fill=True, color=color, alpha=alpha) except Exception as e: logging.exception("KDE failed due to the following exception:\n" + repr(e) + "\nSubstituting scatter plot.") - artist.map_lower(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color, lw=0) + g.map_lower(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color, lw=0) - artist.map_upper(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color, lw=0) + # need to add legend here such that colors are recognized + if plot_data["priors"] is not None: + g.add_legend(fontsize=legend_fontsize, loc="center right") + g._legend.set_title(None) - dim = artist.axes.shape[0] + # Generate grids + dim = g.axes.shape[0] + for i in range(dim): + for j in range(dim): + g.axes[i, j].grid(alpha=0.5) + g.axes[i, j].set_axisbelow(True) + + dim = g.axes.shape[0] for i in range(dim): # Modify tick sizes for j in range(i + 1): - artist.axes[i, j].tick_params(axis="both", which="major", labelsize=tick_fontsize) - artist.axes[i, j].tick_params(axis="both", which="minor", labelsize=tick_fontsize) + g.axes[i, j].tick_params(axis="both", which="major", labelsize=tick_fontsize) + g.axes[i, j].tick_params(axis="both", which="minor", labelsize=tick_fontsize) # adjust font size of labels # the labels themselves remain the same as before, i.e., variable_names - artist.axes[i, 0].set_ylabel(variable_names[i], fontsize=label_fontsize) - artist.axes[dim - 1, i].set_xlabel(variable_names[i], fontsize=label_fontsize) + g.axes[i, 0].set_ylabel(variable_names[i], fontsize=label_fontsize) + g.axes[dim - 1, i].set_xlabel(variable_names[i], fontsize=label_fontsize) # Return figure - artist.tight_layout() + g.tight_layout() - return artist + return g diff --git a/bayesflow/utils/dict_utils.py b/bayesflow/utils/dict_utils.py index e76b72466..965809a81 100644 --- a/bayesflow/utils/dict_utils.py +++ b/bayesflow/utils/dict_utils.py @@ -236,6 +236,7 @@ def make_variable_array( def dicts_to_arrays( estimates: Mapping[str, np.ndarray] | np.ndarray, targets: Mapping[str, np.ndarray] | np.ndarray = None, + priors: Mapping[str, np.ndarray] | np.ndarray = None, dataset_ids: Sequence[int] | int = None, variable_keys: Sequence[str] | str = None, variable_names: Sequence[str] | str = None, @@ -272,6 +273,10 @@ def dicts_to_arrays( dataset_ids : Sequence of integers indexing the datasets to select (default = None). By default, use all datasets. + variable_keys : list or None, optional, default: None + Select keys from the dictionary provided in samples. + By default, select all keys. + variable_names : Sequence[str], optional (default = None) Optional variable names to act as a filter if dicts provided or actual variable names in case of array inputs. @@ -297,8 +302,17 @@ def dicts_to_arrays( variable_keys=estimates.variable_keys, variable_names=estimates.variable_names, ) + + if priors is not None: + priors = make_variable_array( + priors, + # priors are data independent so datasets_ids is not passed here + variable_keys=estimates.variable_keys, + variable_names=estimates.variable_names, + ) return dict( estimates=estimates, targets=targets, + priors=priors, ) diff --git a/bayesflow/utils/plot_utils.py b/bayesflow/utils/plot_utils.py index 7ff6004ce..9b5008669 100644 --- a/bayesflow/utils/plot_utils.py +++ b/bayesflow/utils/plot_utils.py @@ -36,6 +36,8 @@ def prepare_plot_data( targets : dict[str, ndarray] or ndarray, optional (default = None) Ground truth values corresponding to the estimates. Must match the structure and dimensionality of `estimates` in terms of first and last axis. + variable_keys : list or None, optional, default: None + Select keys from the dictionary provided in samples. By default, select all keys. variable_names : Sequence[str], optional (default = None) Optional variable names to act as a filter if dicts provided or actual variable names in case of array args num_col : int @@ -59,16 +61,18 @@ def prepare_plot_data( ) check_estimates_prior_shapes(plot_data["estimates"], plot_data["targets"]) - # Configure layout + # store variable information at top level for easy access variable_names = plot_data["estimates"].variable_names num_variables = len(variable_names) + plot_data["variable_names"] = variable_names + plot_data["num_variables"] = num_variables + + # Configure layout num_row, num_col = set_layout(num_variables, num_row, num_col, stacked) # Initialize figure fig, axes = make_figure(num_row, num_col, figsize=figsize) - plot_data["variable_names"] = variable_names - plot_data["num_variables"] = num_variables plot_data["fig"] = fig plot_data["axes"] = axes plot_data["num_row"] = num_row From 400ff911b17660560dc3920aa0c53c4c6d2da8e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Wed, 19 Feb 2025 11:01:38 +0100 Subject: [PATCH 3/7] further fixes to pairs_posterior --- .../diagnostics/plots/pairs_posterior.py | 8 +++--- bayesflow/diagnostics/plots/pairs_samples.py | 26 ++++++++++++++----- bayesflow/utils/dict_utils.py | 20 +++++++------- bayesflow/utils/plot_utils.py | 2 +- 4 files changed, 35 insertions(+), 21 deletions(-) diff --git a/bayesflow/diagnostics/plots/pairs_posterior.py b/bayesflow/diagnostics/plots/pairs_posterior.py index cdee3f566..859611175 100644 --- a/bayesflow/diagnostics/plots/pairs_posterior.py +++ b/bayesflow/diagnostics/plots/pairs_posterior.py @@ -22,10 +22,10 @@ def pairs_posterior( height: int = 3, post_color: str | tuple = "#132a70", prior_color: str | tuple = "gray", - alpha = 0.9, + alpha=0.9, label_fontsize: int = 14, tick_fontsize: int = 12, - legend_fontsize: int = 16, + legend_fontsize: int = 14, **kwargs, ) -> sns.PairGrid: """Generates a bivariate pair plot given posterior draws and optional prior or prior draws. @@ -113,10 +113,10 @@ def plot_true_params(x, hue=None, **kwargs): plt.axvline(param, color="black", linestyle="--") # Add vertical line g.data = pd.DataFrame( - plot_data["targets"][np.newaxis], + plot_data["targets"][np.newaxis], columns=plot_data["targets"].variable_names, ) - g.data["Source"] = "True Parameter" + g.data["_source"] = "True Parameter" g.map_diag(plot_true_params) return g diff --git a/bayesflow/diagnostics/plots/pairs_samples.py b/bayesflow/diagnostics/plots/pairs_samples.py index 2357e06ae..b1d045e2d 100644 --- a/bayesflow/diagnostics/plots/pairs_samples.py +++ b/bayesflow/diagnostics/plots/pairs_samples.py @@ -101,13 +101,17 @@ def _pairs_samples( # ensure that the source of the samples is stored source_prior = np.repeat("Prior", plot_data["priors"].shape[0]) source_post = np.repeat("Posterior", plot_data["estimates"].shape[0]) - data_to_plot["Source"] = np.concatenate((source_prior, source_post)) - data_to_plot["Source"] = pd.Categorical(data_to_plot["Source"], categories=["Prior", "Posterior"]) + data_to_plot["_source"] = np.concatenate((source_prior, source_post)) + data_to_plot["_source"] = pd.Categorical(data_to_plot["_source"], categories=["Prior", "Posterior"]) - color = [color, color2] - # initialize plot - g = sns.PairGrid(data_to_plot, height=height, hue="Source", **kwargs) + g = sns.PairGrid( + data_to_plot, + height=height, + hue="_source", + palette=[color2, color], + **kwargs, + ) else: # plot just the one set of distributions @@ -117,8 +121,16 @@ def _pairs_samples( g = sns.PairGrid(data_to_plot, height=height, **kwargs) # add histograms + KDEs to the diagonal - g.map_diag(sns.histplot, fill=True, color=color, alpha=alpha, kde=True) - + g.map_diag( + sns.histplot, + fill=True, + kde=True, + color=color, + alpha=alpha, + stat="density", + common_norm=False, + ) + # add scatterplots to the upper diagonal g.map_upper(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color, lw=0) diff --git a/bayesflow/utils/dict_utils.py b/bayesflow/utils/dict_utils.py index 965809a81..3f755ba34 100644 --- a/bayesflow/utils/dict_utils.py +++ b/bayesflow/utils/dict_utils.py @@ -134,22 +134,24 @@ class VariableArray(np.ndarray): An enriched numpy array with information on variable keys and names to be used in post-processing, specifically the diagnostics module. - The current implemention is very basic and we may want to extend it + The current implemention is very basic and we may want to extend it in the future should this general structure prove useful. - Design according to + Design according to https://numpy.org/doc/stable/user/basics.subclassing.html#simple-example-adding-an-extra-attribute-to-ndarray """ - def __new__(cls, input_array, variable_keys=None, variable_names=None): + + def __new__(cls, input_array, variable_keys=None, variable_names=None): obj = np.asarray(input_array).view(cls) obj.variable_keys = variable_keys obj.variable_names = variable_names return obj def __array_finalize__(self, obj): - if obj is None: return - self.variable_keys = getattr(obj, 'variable_keys', None) - self.variable_names = getattr(obj, 'variable_names', None) + if obj is None: + return + self.variable_keys = getattr(obj, "variable_keys", None) + self.variable_names = getattr(obj, "variable_names", None) def make_variable_array( @@ -211,7 +213,7 @@ def make_variable_array( variable_keys = x.variable_keys # use default names if not otherwise specified - if variable_names is None: + if variable_names is None: variable_names = [f"${default_name}_{{{i}}}$" for i in range(x.shape[-1])] if dataset_ids is not None: @@ -223,7 +225,7 @@ def make_variable_array( if len(variable_names) is not x.shape[-1]: raise ValueError("Length of 'variable_names' should be the same as the number of variables.") - + if variable_keys is None: # every variable will count as its own key if not otherwise specified variable_keys = variable_names @@ -310,7 +312,7 @@ def dicts_to_arrays( variable_keys=estimates.variable_keys, variable_names=estimates.variable_names, ) - + return dict( estimates=estimates, targets=targets, diff --git a/bayesflow/utils/plot_utils.py b/bayesflow/utils/plot_utils.py index 9b5008669..2a1ccb2c8 100644 --- a/bayesflow/utils/plot_utils.py +++ b/bayesflow/utils/plot_utils.py @@ -63,7 +63,7 @@ def prepare_plot_data( # store variable information at top level for easy access variable_names = plot_data["estimates"].variable_names - num_variables = len(variable_names) + num_variables = len(variable_names) plot_data["variable_names"] = variable_names plot_data["num_variables"] = num_variables From 9d09934a3da89704e6d7d995e802c08a4097d2db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Wed, 19 Feb 2025 14:00:24 +0100 Subject: [PATCH 4/7] fix issue #324 --- bayesflow/diagnostics/plots/calibration_ecdf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/diagnostics/plots/calibration_ecdf.py b/bayesflow/diagnostics/plots/calibration_ecdf.py index 90062d1bc..adf153642 100644 --- a/bayesflow/diagnostics/plots/calibration_ecdf.py +++ b/bayesflow/diagnostics/plots/calibration_ecdf.py @@ -143,7 +143,7 @@ def calibration_ecdf( # Plot individual ecdf of parameters for j in range(ranks.shape[-1]): - ecdf_single = np.sort(ranks[:, j]) + ecdf_single = np.pad(np.sort(ranks[:, j]), (1, 1), constant_values=(0, 1)) xx = ecdf_single yy = np.arange(1, xx.shape[-1] + 1) / float(xx.shape[-1]) From 28346888f3570dfbdd86d7c8e6597b7095240526 Mon Sep 17 00:00:00 2001 From: stefanradev93 Date: Wed, 19 Feb 2025 09:20:20 -0500 Subject: [PATCH 5/7] Small refactor, squeezing still unclear --- .../diagnostics/plots/pairs_posterior.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/bayesflow/diagnostics/plots/pairs_posterior.py b/bayesflow/diagnostics/plots/pairs_posterior.py index 859611175..f924f2c12 100644 --- a/bayesflow/diagnostics/plots/pairs_posterior.py +++ b/bayesflow/diagnostics/plots/pairs_posterior.py @@ -55,10 +55,12 @@ def pairs_posterior( The color for the posterior histograms and KDEs priors_color : str, optional, default: gray The color for the optional prior histograms and KDEs - post_alpha : float in [0, 1], optonal, default: 0.9 + post_alpha : float in [0, 1], optional, default: 0.9 The opacity of the posterior plots - prior_alpha : float in [0, 1], optonal, default: 0.7 + prior_alpha : float in [0, 1], optional, default: 0.7 The opacity of the prior plots + **kwargs : dict, optional, default: {} + Further optional keyword arguments propagated to `_pairs_samples` Returns ------- @@ -105,13 +107,6 @@ def pairs_posterior( if len(targets_shape) == 2 and targets_shape[0] == 1: plot_data["targets"] = np.squeeze(plot_data["targets"], axis=0) - # Custom function to plot true parameters on the diagonal - def plot_true_params(x, hue=None, **kwargs): - # hue needs to be added to handle the case of plotting both posterior and prior - param = x.iloc[0] # Get the single true value for the diagonal - # only plot on the diagonal a vertical line for the true parameter - plt.axvline(param, color="black", linestyle="--") # Add vertical line - g.data = pd.DataFrame( plot_data["targets"][np.newaxis], columns=plot_data["targets"].variable_names, @@ -120,3 +115,12 @@ def plot_true_params(x, hue=None, **kwargs): g.map_diag(plot_true_params) return g + + +def plot_true_params(x, **kwargs): + """Custom function to plot true parameters on the diagonal.""" + + # hue needs to be added to handle the case of plotting both posterior and prior + param = x.iloc[0] # Get the single true value for the diagonal + # only plot on the diagonal a vertical line for the true parameter + plt.axvline(param, color="black", linestyle="--") From dbfc0d5d7896f455a1806a4dce1f7efb50e922ca Mon Sep 17 00:00:00 2001 From: stefanradev93 Date: Wed, 19 Feb 2025 09:21:34 -0500 Subject: [PATCH 6/7] Sneak in small change in tutorial name --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index b883eab7b..a6aadc9e1 100644 --- a/README.md +++ b/README.md @@ -94,9 +94,9 @@ conda env create --file environment.yaml --name bayesflow Check out some of our walk-through notebooks below. We are actively working on porting all notebooks to the new interface so more will be available soon! 1. [Linear regression starter example](examples/Linear_Regression_Starter.ipynb) -2. [Two moons starter example](examples/Two_Moons_Starter.ipynb) -3. [SIR model with custom summary network](examples/SIR_Posterior_Estimation.ipynb) -4. [SBML model using an external simulator](examples/From_ABC_to_BayesFlow.ipynb) +2. [From ABC to BayesFlow](examples/From_ABC_to_BayesFlow.ipynb) +3. [Two moons starter example](examples/Two_Moons_Starter.ipynb) +4. [SIR model with custom summary network](examples/SIR_Posterior_Estimation.ipynb) 5. [Hyperparameter optimization](examples/Hyperparameter_Optimization.ipynb) 6. [Bayesian experimental design](examples/Bayesian_Experimental_Design.ipynb) 7. [Simple model comparison example (One-Sample T-Test)](examples/One_Sample_TTest.ipynb) From 0461991cab14064158b3972add8da18e2e8ae01c Mon Sep 17 00:00:00 2001 From: stefanradev93 Date: Wed, 19 Feb 2025 10:03:36 -0500 Subject: [PATCH 7/7] Cleanup function --- .../diagnostics/plots/pairs_posterior.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/bayesflow/diagnostics/plots/pairs_posterior.py b/bayesflow/diagnostics/plots/pairs_posterior.py index f924f2c12..40a16ee23 100644 --- a/bayesflow/diagnostics/plots/pairs_posterior.py +++ b/bayesflow/diagnostics/plots/pairs_posterior.py @@ -100,17 +100,14 @@ def pairs_posterior( **kwargs, ) - # add true parameters - if plot_data["targets"] is not None: - # drop dataset axis if it is still present but of length 1 - targets_shape = plot_data["targets"].shape - if len(targets_shape) == 2 and targets_shape[0] == 1: - plot_data["targets"] = np.squeeze(plot_data["targets"], axis=0) - - g.data = pd.DataFrame( - plot_data["targets"][np.newaxis], - columns=plot_data["targets"].variable_names, - ) + targets = plot_data.get("targets") + if targets is not None: + # Ensure targets is at least 2D + if targets.ndim == 1: + targets = np.atleast_2d(targets) + + # Create DataFrame with variable names as columns + g.data = pd.DataFrame(targets, columns=targets.variable_names) g.data["_source"] = "True Parameter" g.map_diag(plot_true_params)