diff --git a/README.md b/README.md index b883eab7b..a6aadc9e1 100644 --- a/README.md +++ b/README.md @@ -94,9 +94,9 @@ conda env create --file environment.yaml --name bayesflow Check out some of our walk-through notebooks below. We are actively working on porting all notebooks to the new interface so more will be available soon! 1. [Linear regression starter example](examples/Linear_Regression_Starter.ipynb) -2. [Two moons starter example](examples/Two_Moons_Starter.ipynb) -3. [SIR model with custom summary network](examples/SIR_Posterior_Estimation.ipynb) -4. [SBML model using an external simulator](examples/From_ABC_to_BayesFlow.ipynb) +2. [From ABC to BayesFlow](examples/From_ABC_to_BayesFlow.ipynb) +3. [Two moons starter example](examples/Two_Moons_Starter.ipynb) +4. [SIR model with custom summary network](examples/SIR_Posterior_Estimation.ipynb) 5. [Hyperparameter optimization](examples/Hyperparameter_Optimization.ipynb) 6. [Bayesian experimental design](examples/Bayesian_Experimental_Design.ipynb) 7. [Simple model comparison example (One-Sample T-Test)](examples/One_Sample_TTest.ipynb) diff --git a/bayesflow/diagnostics/metrics/calibration_error.py b/bayesflow/diagnostics/metrics/calibration_error.py index 3fa808b8c..af7a7c030 100644 --- a/bayesflow/diagnostics/metrics/calibration_error.py +++ b/bayesflow/diagnostics/metrics/calibration_error.py @@ -88,4 +88,5 @@ def calibration_error( # Aggregate errors across alpha error = aggregation(absolute_errors, axis=0) + variable_names = samples["estimates"].variable_names return {"values": error, "metric_name": "Calibration Error", "variable_names": variable_names} diff --git a/bayesflow/diagnostics/metrics/posterior_contraction.py b/bayesflow/diagnostics/metrics/posterior_contraction.py index 523d3e767..eb7f898c6 100644 --- a/bayesflow/diagnostics/metrics/posterior_contraction.py +++ b/bayesflow/diagnostics/metrics/posterior_contraction.py @@ -58,4 +58,5 @@ def posterior_contraction( prior_vars = samples["targets"].var(axis=0, keepdims=True, ddof=1) contraction = 1 - (post_vars / prior_vars) contraction = aggregation(contraction, axis=0) - return {"values": contraction, "metric_name": "Posterior Contraction", "variable_names": samples["variable_names"]} + variable_names = samples["estimates"].variable_names + return {"values": contraction, "metric_name": "Posterior Contraction", "variable_names": variable_names} diff --git a/bayesflow/diagnostics/metrics/root_mean_squared_error.py b/bayesflow/diagnostics/metrics/root_mean_squared_error.py index 91ef38ce6..3289cd0e8 100644 --- a/bayesflow/diagnostics/metrics/root_mean_squared_error.py +++ b/bayesflow/diagnostics/metrics/root_mean_squared_error.py @@ -65,4 +65,5 @@ def root_mean_squared_error( metric_name = "RMSE" rmse = aggregation(rmse, axis=0) - return {"values": rmse, "metric_name": metric_name, "variable_names": samples["variable_names"]} + variable_names = samples["estimates"].variable_names + return {"values": rmse, "metric_name": metric_name, "variable_names": variable_names} diff --git a/bayesflow/diagnostics/plots/calibration_ecdf.py b/bayesflow/diagnostics/plots/calibration_ecdf.py index 90062d1bc..adf153642 100644 --- a/bayesflow/diagnostics/plots/calibration_ecdf.py +++ b/bayesflow/diagnostics/plots/calibration_ecdf.py @@ -143,7 +143,7 @@ def calibration_ecdf( # Plot individual ecdf of parameters for j in range(ranks.shape[-1]): - ecdf_single = np.sort(ranks[:, j]) + ecdf_single = np.pad(np.sort(ranks[:, j]), (1, 1), constant_values=(0, 1)) xx = ecdf_single yy = np.arange(1, 
xx.shape[-1] + 1) / float(xx.shape[-1]) diff --git a/bayesflow/diagnostics/plots/pairs_posterior.py b/bayesflow/diagnostics/plots/pairs_posterior.py index a77fb61b3..40a16ee23 100644 --- a/bayesflow/diagnostics/plots/pairs_posterior.py +++ b/bayesflow/diagnostics/plots/pairs_posterior.py @@ -20,14 +20,12 @@ def pairs_posterior( variable_keys: Sequence[str] = None, variable_names: Sequence[str] = None, height: int = 3, + post_color: str | tuple = "#132a70", + prior_color: str | tuple = "gray", + alpha=0.9, label_fontsize: int = 14, tick_fontsize: int = 12, - # arguments related to priors which is currently unused - # legend_fontsize: int = 16, - # post_color: str | tuple = "#132a70", - # prior_color: str | tuple = "gray", - # post_alpha: float = 0.9, - # prior_alpha: float = 0.7, + legend_fontsize: int = 14, **kwargs, ) -> sns.PairGrid: """Generates a bivariate pair plot given posterior draws and optional prior or prior draws. @@ -57,10 +55,12 @@ def pairs_posterior( The color for the posterior histograms and KDEs priors_color : str, optional, default: gray The color for the optional prior histograms and KDEs - post_alpha : float in [0, 1], optonal, default: 0.9 + post_alpha : float in [0, 1], optional, default: 0.9 The opacity of the posterior plots - prior_alpha : float in [0, 1], optonal, default: 0.7 + prior_alpha : float in [0, 1], optional, default: 0.7 The opacity of the prior plots + **kwargs : dict, optional, default: {} + Further optional keyword arguments propagated to `_pairs_samples` Returns ------- @@ -75,6 +75,7 @@ def pairs_posterior( plot_data = dicts_to_arrays( estimates=estimates, targets=targets, + priors=priors, dataset_ids=dataset_id, variable_keys=variable_keys, variable_names=variable_names, @@ -90,52 +91,33 @@ def pairs_posterior( g = _pairs_samples( plot_data=plot_data, height=height, + color=post_color, + color2=prior_color, + alpha=alpha, label_fontsize=label_fontsize, tick_fontsize=tick_fontsize, + legend_fontsize=legend_fontsize, **kwargs, ) - # add priors - if priors is not None: - # TODO: integrate priors into plot_data and then use - # proper coloring of posterior vs. prior using the hue argument in PairGrid - raise ValueError("Plotting prior samples is not yet implemented.") - - """ - # this is currently not working as expected as it doesn't show the off diagonal plots - prior_samples_df = pd.DataFrame(priors, columns=plot_data["variable_names"]) - g.data = prior_samples_df - g.map_diag(sns.histplot, fill=True, color=prior_color, alpha=prior_alpha, kde=True, zorder=-1) - g.map_lower(sns.kdeplot, fill=True, color=prior_color, alpha=prior_alpha, zorder=-1) - - # Add legend to differentiate between prior and posterior - handles = [ - Line2D(xdata=[], ydata=[], color=post_color, lw=3, alpha=post_alpha), - Line2D(xdata=[], ydata=[], color=prior_color, lw=3, alpha=prior_alpha), - ] - handles_names = ["Posterior", "Prior"] - if targets is not None: - handles.append(Line2D(xdata=[], ydata=[], color="black", lw=3, linestyle="--")) - handles_names.append("True Parameter") - plt.legend(handles=handles, labels=handles_names, fontsize=legend_fontsize, loc="center right") - """ - - # add true parameters - if plot_data["targets"] is not None: - # TODO: also add true parameters to the off diagonal plots? 
- - # drop dataset axis if it is still present but of length 1 - targets_shape = plot_data["targets"].shape - if len(targets_shape) == 2 and targets_shape[0] == 1: - plot_data["targets"] = np.squeeze(plot_data["targets"], axis=0) - - # Custom function to plot true parameters on the diagonal - def plot_true_params(x, **kwargs): - param = x.iloc[0] # Get the single true value for the diagonal - plt.axvline(param, color="black", linestyle="--") # Add vertical line - - # only plot on the diagonal a vertical line for the true parameter - g.data = pd.DataFrame(plot_data["targets"][np.newaxis], columns=plot_data["variable_names"]) + targets = plot_data.get("targets") + if targets is not None: + # Ensure targets is at least 2D + if targets.ndim == 1: + targets = np.atleast_2d(targets) + + # Create DataFrame with variable names as columns + g.data = pd.DataFrame(targets, columns=targets.variable_names) + g.data["_source"] = "True Parameter" g.map_diag(plot_true_params) return g + + +def plot_true_params(x, **kwargs): + """Custom function to plot true parameters on the diagonal.""" + + # hue needs to be added to handle the case of plotting both posterior and prior + param = x.iloc[0] # Get the single true value for the diagonal + # only plot on the diagonal a vertical line for the true parameter + plt.axvline(param, color="black", linestyle="--") diff --git a/bayesflow/diagnostics/plots/pairs_samples.py b/bayesflow/diagnostics/plots/pairs_samples.py index 2bdd363bd..b1d045e2d 100644 --- a/bayesflow/diagnostics/plots/pairs_samples.py +++ b/bayesflow/diagnostics/plots/pairs_samples.py @@ -68,9 +68,11 @@ def _pairs_samples( plot_data: dict, height: float = 2.5, color: str | tuple = "#132a70", + color2: str | tuple = "gray", alpha: float = 0.9, label_fontsize: int = 14, tick_fontsize: int = 12, + legend_fontsize: int = 14, **kwargs, ) -> sns.PairGrid: # internal version of pairs_samples creating the seaborn plot @@ -87,45 +89,83 @@ def _pairs_samples( f"your samples array has a shape of {estimates_shape}." ) + variable_names = plot_data["estimates"].variable_names + # Convert samples to pd.DataFrame - data_to_plot = pd.DataFrame(plot_data["estimates"], columns=plot_data["variable_names"]) + if plot_data["priors"] is not None: + # differentiate posterior from prior draws + # row bind posterior and prior draws + samples = np.vstack((plot_data["priors"], plot_data["estimates"])) + data_to_plot = pd.DataFrame(samples, columns=variable_names) + + # ensure that the source of the samples is stored + source_prior = np.repeat("Prior", plot_data["priors"].shape[0]) + source_post = np.repeat("Posterior", plot_data["estimates"].shape[0]) + data_to_plot["_source"] = np.concatenate((source_prior, source_post)) + data_to_plot["_source"] = pd.Categorical(data_to_plot["_source"], categories=["Prior", "Posterior"]) + + # initialize plot + g = sns.PairGrid( + data_to_plot, + height=height, + hue="_source", + palette=[color2, color], + **kwargs, + ) - # initialize plot - artist = sns.PairGrid(data_to_plot, height=height, **kwargs) + else: + # plot just the one set of distributions + data_to_plot = pd.DataFrame(plot_data["estimates"], columns=variable_names) - # Generate grids - # in the off diagonal plots, the grids appears in front of the points/densities - # TODO: can we put the grid in the background somehow? 
-    dim = artist.axes.shape[0]
-    for i in range(dim):
-        for j in range(dim):
-            artist.axes[i, j].grid(alpha=0.5)
+        # initialize plot
+        g = sns.PairGrid(data_to_plot, height=height, **kwargs)

     # add histograms + KDEs to the diagonal
-    artist.map_diag(sns.histplot, fill=True, color=color, alpha=alpha, kde=True)
+    g.map_diag(
+        sns.histplot,
+        fill=True,
+        kde=True,
+        color=color,
+        alpha=alpha,
+        stat="density",
+        common_norm=False,
+    )
+
+    # add scatterplots to the upper diagonal
+    g.map_upper(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color, lw=0)

-    # Incorporate exceptions for generating KDE plots
+    # add KDEs to the lower diagonal
     try:
-        artist.map_lower(sns.kdeplot, fill=True, color=color, alpha=alpha)
+        g.map_lower(sns.kdeplot, fill=True, color=color, alpha=alpha)
     except Exception as e:
         logging.exception("KDE failed due to the following exception:\n" + repr(e) + "\nSubstituting scatter plot.")
-        artist.map_lower(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color, lw=0)
+        g.map_lower(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color, lw=0)

-    artist.map_upper(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color, lw=0)
+    # need to add legend here such that colors are recognized
+    if plot_data["priors"] is not None:
+        g.add_legend(fontsize=legend_fontsize, loc="center right")
+        g._legend.set_title(None)

-    dim = artist.axes.shape[0]
+    # Generate grids
+    dim = g.axes.shape[0]
+    for i in range(dim):
+        for j in range(dim):
+            g.axes[i, j].grid(alpha=0.5)
+            g.axes[i, j].set_axisbelow(True)
+
+    dim = g.axes.shape[0]
     for i in range(dim):
         # Modify tick sizes
         for j in range(i + 1):
-            artist.axes[i, j].tick_params(axis="both", which="major", labelsize=tick_fontsize)
-            artist.axes[i, j].tick_params(axis="both", which="minor", labelsize=tick_fontsize)
+            g.axes[i, j].tick_params(axis="both", which="major", labelsize=tick_fontsize)
+            g.axes[i, j].tick_params(axis="both", which="minor", labelsize=tick_fontsize)

         # adjust font size of labels
         # the labels themselves remain the same as before, i.e., variable_names
-        artist.axes[i, 0].set_ylabel(plot_data["variable_names"][i], fontsize=label_fontsize)
-        artist.axes[dim - 1, i].set_xlabel(plot_data["variable_names"][i], fontsize=label_fontsize)
+        g.axes[i, 0].set_ylabel(variable_names[i], fontsize=label_fontsize)
+        g.axes[dim - 1, i].set_xlabel(variable_names[i], fontsize=label_fontsize)

     # Return figure
-    artist.tight_layout()
+    g.tight_layout()

-    return artist
+    return g
diff --git a/bayesflow/utils/dict_utils.py b/bayesflow/utils/dict_utils.py
index 180f0c59d..3f755ba34 100644
--- a/bayesflow/utils/dict_utils.py
+++ b/bayesflow/utils/dict_utils.py
@@ -129,13 +129,38 @@ def split_arrays(data: Mapping[str, np.ndarray], axis: int = -1) -> Mapping[str,
     return result


-def validate_variable_array(
+class VariableArray(np.ndarray):
+    """
+    An enriched numpy array with information on variable keys and names
+    to be used in post-processing, specifically the diagnostics module.
+
+    The current implementation is very basic and we may want to extend it
+    in the future should this general structure prove useful.
+
+    Design according to
+    https://numpy.org/doc/stable/user/basics.subclassing.html#simple-example-adding-an-extra-attribute-to-ndarray
+    """
+
+    def __new__(cls, input_array, variable_keys=None, variable_names=None):
+        obj = np.asarray(input_array).view(cls)
+        obj.variable_keys = variable_keys
+        obj.variable_names = variable_names
+        return obj
+
+    def __array_finalize__(self, obj):
+        if obj is None:
+            return
+        self.variable_keys = getattr(obj, "variable_keys", None)
+        self.variable_names = getattr(obj, "variable_names", None)
+
+
+def make_variable_array(
     x: Mapping[str, np.ndarray] | np.ndarray,
     dataset_ids: Sequence[int] | int = None,
     variable_keys: Sequence[str] | str = None,
     variable_names: Sequence[str] | str = None,
     default_name: str = "v",
-):
+) -> VariableArray:
     """
     Helper function to validate arrays for use in the diagnostics module.

@@ -180,6 +205,14 @@ def validate_variable_array(

     # Case arrays provided
     elif isinstance(x, np.ndarray):
+        if isinstance(x, VariableArray):
+            # reuse existing variable keys and names if contained in x
+            if variable_names is None:
+                variable_names = x.variable_names
+            if variable_keys is None:
+                variable_keys = x.variable_keys
+
+        # use default names if not otherwise specified
         if variable_names is None:
             variable_names = [f"${default_name}_{{{i}}}$" for i in range(x.shape[-1])]

@@ -193,12 +226,19 @@ def validate_variable_array(
     if len(variable_names) is not x.shape[-1]:
         raise ValueError("Length of 'variable_names' should be the same as the number of variables.")

-    return x, variable_keys, variable_names
+    if variable_keys is None:
+        # every variable will count as its own key if not otherwise specified
+        variable_keys = variable_names
+
+    x = VariableArray(x, variable_keys=variable_keys, variable_names=variable_names)
+
+    return x


 def dicts_to_arrays(
     estimates: Mapping[str, np.ndarray] | np.ndarray,
     targets: Mapping[str, np.ndarray] | np.ndarray = None,
+    priors: Mapping[str, np.ndarray] | np.ndarray = None,
     dataset_ids: Sequence[int] | int = None,
     variable_keys: Sequence[str] | str = None,
     variable_names: Sequence[str] | str = None,
@@ -235,6 +275,10 @@ def dicts_to_arrays(
     dataset_ids : Sequence of integers indexing the datasets to select (default = None).
         By default, use all datasets.

+    variable_keys : list or None, optional, default: None
+        Select keys from the dictionary provided in samples.
+        By default, select all keys.
+
     variable_names : Sequence[str], optional (default = None)
         Optional variable names to act as a filter if dicts provided or actual variable names in case of array
         inputs.
@@ -245,7 +289,7 @@

     # other to be validated arrays (see below) will take use
     # the variable_keys and variable_names implied by estimates
-    estimates, variable_keys, variable_names = validate_variable_array(
+    estimates = make_variable_array(
         estimates,
         dataset_ids=dataset_ids,
         variable_keys=variable_keys,
@@ -254,16 +298,23 @@
     )

     if targets is not None:
-        targets, _, _ = validate_variable_array(
+        targets = make_variable_array(
             targets,
             dataset_ids=dataset_ids,
-            variable_keys=variable_keys,
-            variable_names=variable_names,
+            variable_keys=estimates.variable_keys,
+            variable_names=estimates.variable_names,
+        )
+
+    if priors is not None:
+        priors = make_variable_array(
+            priors,
+            # priors are data-independent so dataset_ids is not passed here
+            variable_keys=estimates.variable_keys,
+            variable_names=estimates.variable_names,
         )

     return dict(
         estimates=estimates,
         targets=targets,
-        variable_names=variable_names,
-        num_variables=len(variable_names),
+        priors=priors,
     )
diff --git a/bayesflow/utils/plot_utils.py b/bayesflow/utils/plot_utils.py
index fd8512d0e..2a1ccb2c8 100644
--- a/bayesflow/utils/plot_utils.py
+++ b/bayesflow/utils/plot_utils.py
@@ -36,6 +36,8 @@ def prepare_plot_data(
     targets : dict[str, ndarray] or ndarray, optional (default = None)
         Ground truth values corresponding to the estimates. Must match the structure and dimensionality of
         `estimates` in terms of first and last axis.
+    variable_keys : list or None, optional, default: None
+        Select keys from the dictionary provided in samples. By default, select all keys.
     variable_names : Sequence[str], optional (default = None)
         Optional variable names to act as a filter if dicts provided or actual variable names in case of array args
     num_col : int
@@ -59,8 +61,14 @@ def prepare_plot_data(
     )
     check_estimates_prior_shapes(plot_data["estimates"], plot_data["targets"])

+    # store variable information at top level for easy access
+    variable_names = plot_data["estimates"].variable_names
+    num_variables = len(variable_names)
+    plot_data["variable_names"] = variable_names
+    plot_data["num_variables"] = num_variables
+
     # Configure layout
-    num_row, num_col = set_layout(plot_data["num_variables"], num_row, num_col, stacked)
+    num_row, num_col = set_layout(num_variables, num_row, num_col, stacked)

     # Initialize figure
     fig, axes = make_figure(num_row, num_col, figsize=figsize)
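For orientation, a minimal usage sketch of the `pairs_posterior` interface as it appears in this patch. The synthetic Gaussian draws, the import path `bayesflow.diagnostics.plots`, and the output filename are illustrative assumptions rather than part of the change; the keyword arguments (`estimates`, `targets`, `priors`, `variable_names`, `post_color`, `prior_color`) follow the signature shown in the diff above.

```python
# Illustrative only: synthetic draws standing in for real posterior/prior samples.
import numpy as np
from bayesflow.diagnostics.plots import pairs_posterior  # assumed import path

rng = np.random.default_rng(2024)
num_draws, num_params = 500, 3

priors = rng.normal(0.0, 1.0, size=(num_draws, num_params))     # broad prior draws
estimates = rng.normal(0.2, 0.3, size=(num_draws, num_params))  # narrower posterior draws
targets = np.full(num_params, 0.2)                               # "true" parameter values

g = pairs_posterior(
    estimates=estimates,
    targets=targets,
    priors=priors,  # new in this patch: prior draws are shown in a second color with a legend
    variable_names=[r"$\theta_1$", r"$\theta_2$", r"$\theta_3$"],
    post_color="#132a70",
    prior_color="gray",
)
g.savefig("pairs_posterior.png")
```

Passing `priors` routes the prior draws through `dicts_to_arrays` and into `_pairs_samples`, where they are distinguished from the posterior via the `_source` hue column and the `palette=[prior_color, post_color]` mapping introduced above.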