Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions bayesflow/utils/dict_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def split_arrays(data: Mapping[str, np.ndarray], axis: int = -1) -> Mapping[str,
splits = [np.squeeze(split, axis=axis) for split in splits]

for i, split in enumerate(splits):
result[f"{key}_{i + 1}"] = split
result[f"{key}_{i}"] = split

return result

Expand Down Expand Up @@ -214,7 +214,7 @@ def make_variable_array(

# use default names if not otherwise specified
if variable_names is None:
variable_names = [f"${default_name}_{{{i}}}$" for i in range(x.shape[-1])]
variable_names = [f"{default_name}_{i}" for i in range(x.shape[-1])]

if dataset_ids is not None:
x = x[dataset_ids]
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def pytest_make_parametrize_id(config, val, argname):
return f"{argname}={repr(val)}"


@pytest.fixture(params=[2, 3], scope="session", autouse=True)
@pytest.fixture(params=[2, 3], scope="session")
def batch_size(request):
return request.param

Expand Down
31 changes: 29 additions & 2 deletions tests/test_diagnostics/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,33 @@
import numpy as np
import pytest


@pytest.fixture()
def num_samples():
return 1000
def var_names():
return [r"$\beta_0$", r"$\beta_1$", r"$\sigma$"]


@pytest.fixture()
def random_estimates():
return {
"beta": np.random.standard_normal(size=(32, 10, 2)),
"sigma": np.random.standard_normal(size=(32, 10, 1)),
}


@pytest.fixture()
def random_targets():
return {
"beta": np.random.standard_normal(size=(32, 2)),
"sigma": np.random.standard_normal(size=(32, 1)),
"y": np.random.standard_normal(size=(32, 3, 1)),
}


@pytest.fixture()
def random_priors():
return {
"beta": np.random.standard_normal(size=(64, 2)),
"sigma": np.random.standard_normal(size=(64, 1)),
"y": np.random.standard_normal(size=(64, 3, 1)),
}
Empty file.
49 changes: 49 additions & 0 deletions tests/test_diagnostics/test_diagnostics_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import bayesflow as bf


def num_variables(x: dict):
return sum(arr.shape[-1] for arr in x.values())


def test_metric_calibration_error(random_estimates, random_targets, var_names):
# basic functionality: automatic variable names
out = bf.diagnostics.metrics.calibration_error(random_estimates, random_targets)
assert list(out.keys()) == ["values", "metric_name", "variable_names"]
assert out["values"].shape == (num_variables(random_estimates),)
assert out["metric_name"] == "Calibration Error"
assert out["variable_names"] == ["beta_0", "beta_1", "sigma"]

# user specified variable names
out = bf.diagnostics.metrics.calibration_error(
estimates=random_estimates,
targets=random_targets,
variable_names=var_names,
)
assert out["variable_names"] == var_names

# user-specifed keys and scalar variable
out = bf.diagnostics.metrics.calibration_error(
estimates=random_estimates,
targets=random_targets,
variable_keys="sigma",
)
assert out["values"].shape == (random_estimates["sigma"].shape[-1],)
assert out["variable_names"] == ["sigma"]


def test_posterior_contraction(random_estimates, random_targets):
# basic functionality: automatic variable names
out = bf.diagnostics.metrics.posterior_contraction(random_estimates, random_targets)
assert list(out.keys()) == ["values", "metric_name", "variable_names"]
assert out["values"].shape == (num_variables(random_estimates),)
assert out["metric_name"] == "Posterior Contraction"
assert out["variable_names"] == ["beta_0", "beta_1", "sigma"]


def test_root_mean_squared_error(random_estimates, random_targets):
# basic functionality: automatic variable names
out = bf.diagnostics.metrics.root_mean_squared_error(random_estimates, random_targets)
assert list(out.keys()) == ["values", "metric_name", "variable_names"]
assert out["values"].shape == (num_variables(random_estimates),)
assert out["metric_name"] == "NRMSE"
assert out["variable_names"] == ["beta_0", "beta_1", "sigma"]
104 changes: 104 additions & 0 deletions tests/test_diagnostics/test_diagnostics_plots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import bayesflow as bf
import pytest


def num_variables(x: dict):
return sum(arr.shape[-1] for arr in x.values())


def test_calibration_ecdf(random_estimates, random_targets, var_names):
# basic functionality: automatic variable names
out = bf.diagnostics.plots.calibration_ecdf(random_estimates, random_targets)
assert len(out.axes) == num_variables(random_estimates)
assert out.axes[1].title._text == "beta_1"

# custom variable names
out = bf.diagnostics.plots.calibration_ecdf(
estimates=random_estimates,
targets=random_targets,
variable_names=var_names,
)
assert len(out.axes) == num_variables(random_estimates)
assert out.axes[1].title._text == "$\\beta_1$"

# subset of keys with a single scalar key
out = bf.diagnostics.plots.calibration_ecdf(
estimates=random_estimates, targets=random_targets, variable_keys="sigma"
)
assert len(out.axes) == random_estimates["sigma"].shape[-1]
assert out.axes[0].title._text == "sigma"

# use single array instead of dict of arrays as input
out = bf.diagnostics.plots.calibration_ecdf(
estimates=random_estimates["beta"],
targets=random_targets["beta"],
)
assert len(out.axes) == random_estimates["beta"].shape[-1]
# cannot infer the variable names from an array so default names are used
assert out.axes[1].title._text == "v_1"


def test_calibration_histogram(random_estimates, random_targets):
# basic functionality: automatic variable names
out = bf.diagnostics.plots.calibration_histogram(random_estimates, random_targets)
assert len(out.axes) == num_variables(random_estimates)
assert out.axes[0].title._text == "beta_0"


def test_recovery(random_estimates, random_targets):
# basic functionality: automatic variable names
out = bf.diagnostics.plots.recovery(random_estimates, random_targets)
assert len(out.axes) == num_variables(random_estimates)
assert out.axes[2].title._text == "sigma"


def test_z_score_contraction(random_estimates, random_targets):
# basic functionality: automatic variable names
out = bf.diagnostics.plots.z_score_contraction(random_estimates, random_targets)
assert len(out.axes) == num_variables(random_estimates)
assert out.axes[1].title._text == "beta_1"


def test_pairs_samples(random_priors):
out = bf.diagnostics.plots.pairs_samples(
samples=random_priors,
variable_keys=["beta", "sigma"],
)
num_vars = random_priors["sigma"].shape[-1] + random_priors["beta"].shape[-1]
assert out.axes.shape == (num_vars, num_vars)
assert out.axes[0, 0].get_ylabel() == "beta_0"
assert out.axes[2, 2].get_xlabel() == "sigma"


def test_pairs_posterior(random_estimates, random_targets, random_priors):
# basic functionality: automatic variable names
out = bf.diagnostics.plots.pairs_posterior(
random_estimates,
random_targets,
dataset_id=1,
)
num_vars = num_variables(random_estimates)
assert out.axes.shape == (num_vars, num_vars)
assert out.axes[0, 0].get_ylabel() == "beta_0"
assert out.axes[2, 2].get_xlabel() == "sigma"

# also plot priors
out = bf.diagnostics.plots.pairs_posterior(
estimates=random_estimates,
targets=random_targets,
priors=random_priors,
dataset_id=1,
)
num_vars = num_variables(random_estimates)
assert out.axes.shape == (num_vars, num_vars)
assert out.axes[0, 0].get_ylabel() == "beta_0"
assert out.axes[2, 2].get_xlabel() == "sigma"
assert out.figure.legends[0].get_texts()[0]._text == "Prior"

with pytest.raises(ValueError):
bf.diagnostics.plots.pairs_posterior(
estimates=random_estimates,
targets=random_targets,
priors=random_priors,
dataset_id=[1, 3],
)