diff --git a/bayesflow/utils/dict_utils.py b/bayesflow/utils/dict_utils.py index 3f755ba34..db80f971a 100644 --- a/bayesflow/utils/dict_utils.py +++ b/bayesflow/utils/dict_utils.py @@ -124,7 +124,7 @@ def split_arrays(data: Mapping[str, np.ndarray], axis: int = -1) -> Mapping[str, splits = [np.squeeze(split, axis=axis) for split in splits] for i, split in enumerate(splits): - result[f"{key}_{i + 1}"] = split + result[f"{key}_{i}"] = split return result @@ -214,7 +214,7 @@ def make_variable_array( # use default names if not otherwise specified if variable_names is None: - variable_names = [f"${default_name}_{{{i}}}$" for i in range(x.shape[-1])] + variable_names = [f"{default_name}_{i}" for i in range(x.shape[-1])] if dataset_ids is not None: x = x[dataset_ids] diff --git a/tests/conftest.py b/tests/conftest.py index 54a067d5d..5e6598ba8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,7 +22,7 @@ def pytest_make_parametrize_id(config, val, argname): return f"{argname}={repr(val)}" -@pytest.fixture(params=[2, 3], scope="session", autouse=True) +@pytest.fixture(params=[2, 3], scope="session") def batch_size(request): return request.param diff --git a/tests/test_diagnostics/conftest.py b/tests/test_diagnostics/conftest.py index 51405e1e9..52e3e343a 100644 --- a/tests/test_diagnostics/conftest.py +++ b/tests/test_diagnostics/conftest.py @@ -1,6 +1,33 @@ +import numpy as np import pytest @pytest.fixture() -def num_samples(): - return 1000 +def var_names(): + return [r"$\beta_0$", r"$\beta_1$", r"$\sigma$"] + + +@pytest.fixture() +def random_estimates(): + return { + "beta": np.random.standard_normal(size=(32, 10, 2)), + "sigma": np.random.standard_normal(size=(32, 10, 1)), + } + + +@pytest.fixture() +def random_targets(): + return { + "beta": np.random.standard_normal(size=(32, 2)), + "sigma": np.random.standard_normal(size=(32, 1)), + "y": np.random.standard_normal(size=(32, 3, 1)), + } + + +@pytest.fixture() +def random_priors(): + return { + "beta": 
np.random.standard_normal(size=(64, 2)), + "sigma": np.random.standard_normal(size=(64, 1)), + "y": np.random.standard_normal(size=(64, 3, 1)), + } diff --git a/tests/test_diagnostics/test_diagnostics.py b/tests/test_diagnostics/test_diagnostics.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/test_diagnostics/test_diagnostics_metrics.py b/tests/test_diagnostics/test_diagnostics_metrics.py new file mode 100644 index 000000000..62b0b4ca1 --- /dev/null +++ b/tests/test_diagnostics/test_diagnostics_metrics.py @@ -0,0 +1,49 @@ +import bayesflow as bf + + +def num_variables(x: dict): + return sum(arr.shape[-1] for arr in x.values()) + + +def test_metric_calibration_error(random_estimates, random_targets, var_names): + # basic functionality: automatic variable names + out = bf.diagnostics.metrics.calibration_error(random_estimates, random_targets) + assert list(out.keys()) == ["values", "metric_name", "variable_names"] + assert out["values"].shape == (num_variables(random_estimates),) + assert out["metric_name"] == "Calibration Error" + assert out["variable_names"] == ["beta_0", "beta_1", "sigma"] + + # user-specified variable names + out = bf.diagnostics.metrics.calibration_error( + estimates=random_estimates, + targets=random_targets, + variable_names=var_names, + ) + assert out["variable_names"] == var_names + + # user-specified keys and scalar variable + out = bf.diagnostics.metrics.calibration_error( + estimates=random_estimates, + targets=random_targets, + variable_keys="sigma", + ) + assert out["values"].shape == (random_estimates["sigma"].shape[-1],) + assert out["variable_names"] == ["sigma"] + + +def test_posterior_contraction(random_estimates, random_targets): + # basic functionality: automatic variable names + out = bf.diagnostics.metrics.posterior_contraction(random_estimates, random_targets) + assert list(out.keys()) == ["values", "metric_name", "variable_names"] + assert out["values"].shape == (num_variables(random_estimates),) +
assert out["metric_name"] == "Posterior Contraction" + assert out["variable_names"] == ["beta_0", "beta_1", "sigma"] + + +def test_root_mean_squared_error(random_estimates, random_targets): + # basic functionality: automatic variable names + out = bf.diagnostics.metrics.root_mean_squared_error(random_estimates, random_targets) + assert list(out.keys()) == ["values", "metric_name", "variable_names"] + assert out["values"].shape == (num_variables(random_estimates),) + assert out["metric_name"] == "NRMSE" + assert out["variable_names"] == ["beta_0", "beta_1", "sigma"] diff --git a/tests/test_diagnostics/test_diagnostics_plots.py b/tests/test_diagnostics/test_diagnostics_plots.py new file mode 100644 index 000000000..97e21b2a2 --- /dev/null +++ b/tests/test_diagnostics/test_diagnostics_plots.py @@ -0,0 +1,104 @@ +import bayesflow as bf +import pytest + + +def num_variables(x: dict): + return sum(arr.shape[-1] for arr in x.values()) + + +def test_calibration_ecdf(random_estimates, random_targets, var_names): + # basic functionality: automatic variable names + out = bf.diagnostics.plots.calibration_ecdf(random_estimates, random_targets) + assert len(out.axes) == num_variables(random_estimates) + assert out.axes[1].title._text == "beta_1" + + # custom variable names + out = bf.diagnostics.plots.calibration_ecdf( + estimates=random_estimates, + targets=random_targets, + variable_names=var_names, + ) + assert len(out.axes) == num_variables(random_estimates) + assert out.axes[1].title._text == "$\\beta_1$" + + # subset of keys with a single scalar key + out = bf.diagnostics.plots.calibration_ecdf( + estimates=random_estimates, targets=random_targets, variable_keys="sigma" + ) + assert len(out.axes) == random_estimates["sigma"].shape[-1] + assert out.axes[0].title._text == "sigma" + + # use single array instead of dict of arrays as input + out = bf.diagnostics.plots.calibration_ecdf( + estimates=random_estimates["beta"], + targets=random_targets["beta"], + ) + assert 
len(out.axes) == random_estimates["beta"].shape[-1] + # cannot infer the variable names from an array so default names are used + assert out.axes[1].title._text == "v_1" + + +def test_calibration_histogram(random_estimates, random_targets): + # basic functionality: automatic variable names + out = bf.diagnostics.plots.calibration_histogram(random_estimates, random_targets) + assert len(out.axes) == num_variables(random_estimates) + assert out.axes[0].title._text == "beta_0" + + +def test_recovery(random_estimates, random_targets): + # basic functionality: automatic variable names + out = bf.diagnostics.plots.recovery(random_estimates, random_targets) + assert len(out.axes) == num_variables(random_estimates) + assert out.axes[2].title._text == "sigma" + + +def test_z_score_contraction(random_estimates, random_targets): + # basic functionality: automatic variable names + out = bf.diagnostics.plots.z_score_contraction(random_estimates, random_targets) + assert len(out.axes) == num_variables(random_estimates) + assert out.axes[1].title._text == "beta_1" + + +def test_pairs_samples(random_priors): + out = bf.diagnostics.plots.pairs_samples( + samples=random_priors, + variable_keys=["beta", "sigma"], + ) + num_vars = random_priors["sigma"].shape[-1] + random_priors["beta"].shape[-1] + assert out.axes.shape == (num_vars, num_vars) + assert out.axes[0, 0].get_ylabel() == "beta_0" + assert out.axes[2, 2].get_xlabel() == "sigma" + + +def test_pairs_posterior(random_estimates, random_targets, random_priors): + # basic functionality: automatic variable names + out = bf.diagnostics.plots.pairs_posterior( + random_estimates, + random_targets, + dataset_id=1, + ) + num_vars = num_variables(random_estimates) + assert out.axes.shape == (num_vars, num_vars) + assert out.axes[0, 0].get_ylabel() == "beta_0" + assert out.axes[2, 2].get_xlabel() == "sigma" + + # also plot priors + out = bf.diagnostics.plots.pairs_posterior( + estimates=random_estimates, + targets=random_targets, + 
priors=random_priors, + dataset_id=1, + ) + num_vars = num_variables(random_estimates) + assert out.axes.shape == (num_vars, num_vars) + assert out.axes[0, 0].get_ylabel() == "beta_0" + assert out.axes[2, 2].get_xlabel() == "sigma" + assert out.figure.legends[0].get_texts()[0]._text == "Prior" + + with pytest.raises(ValueError): + bf.diagnostics.plots.pairs_posterior( + estimates=random_estimates, + targets=random_targets, + priors=random_priors, + dataset_id=[1, 3], + )