8 changes: 8 additions & 0 deletions bayesflow/simulators/__init__.py
@@ -12,8 +12,16 @@
 from .simulator import Simulator
 
 from .benchmark_simulators import (
+    BernoulliGLM,
+    BernoulliGLMRaw,
+    GaussianLinear,
+    GaussianLinearUniform,
+    GaussianMixture,
+    InverseKinematics,
     LotkaVolterra,
     SIR,
+    SLCP,
+    SLCPDistractors,
     TwoMoons,
 )
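With the expanded exports, every benchmark simulator becomes importable from bayesflow.simulators directly, which the new test fixtures below rely on. A brief sketch against the common Simulator interface (the (4, 2) shapes follow from the TwoMoons assertions in the updated tests at the bottom of this diff):

from bayesflow.simulators import TwoMoons

# sample() returns a dict of numpy arrays with a leading batch axis
samples = TwoMoons().sample((4,))
print(samples["parameters"].shape, samples["observables"].shape)  # (4, 2) (4, 2)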

8 changes: 8 additions & 0 deletions bayesflow/simulators/benchmark_simulators/__init__.py
@@ -1,3 +1,11 @@
+from .bernoulli_glm import BernoulliGLM
+from .bernoulli_glm_raw import BernoulliGLMRaw
+from .gaussian_linear import GaussianLinear
+from .gaussian_linear_uniform import GaussianLinearUniform
+from .gaussian_mixture import GaussianMixture
+from .inverse_kinematics import InverseKinematics
 from .lotka_volterra import LotkaVolterra
 from .sir import SIR
+from .slcp import SLCP
+from .slcp_distractors import SLCPDistractors
 from .two_moons import TwoMoons
9 changes: 7 additions & 2 deletions bayesflow/simulators/benchmark_simulators/gaussian_linear.py
@@ -75,5 +75,10 @@
         # Generate prior predictive samples, possibly a single if n_obs is None
         if self.n_obs is None:
             return self.rng.normal(loc=params, scale=self.obs_scale)
-        x = self.rng.normal(loc=params, scale=self.obs_scale, size=(self.n_obs, params.shape[0], params.shape[1]))
-        return np.transpose(x, (1, 0, 2))
+        if params.ndim == 2:
+            # batched sampling with n_obs
+            x = self.rng.normal(loc=params, scale=self.obs_scale, size=(self.n_obs, params.shape[0], params.shape[1]))
+            return np.transpose(x, (1, 0, 2))
+        elif params.ndim == 1:
+            # non-batched sampling with n_obs
+            return self.rng.normal(loc=params, scale=self.obs_scale, size=(self.n_obs, params.shape[0]))
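The new branches separate batched parameter matrices (params.ndim == 2) from a single parameter vector (params.ndim == 1). A minimal standalone numpy sketch of the shape logic, with illustrative sizes that are not taken from the code:

import numpy as np

rng = np.random.default_rng(0)
n_obs, batch, dim = 5, 16, 10
obs_scale = 1.0

# batched case (params.ndim == 2): draw (n_obs, batch, dim), then move the
# batch axis to the front, yielding (batch, n_obs, dim)
params = rng.normal(size=(batch, dim))
x = rng.normal(loc=params, scale=obs_scale, size=(n_obs, batch, dim))
x = np.transpose(x, (1, 0, 2))
assert x.shape == (batch, n_obs, dim)

# non-batched case (params.ndim == 1): broadcasting gives (n_obs, dim) directly
theta = rng.normal(size=(dim,))
y = rng.normal(loc=theta, scale=obs_scale, size=(n_obs, dim))
assert y.shape == (n_obs, dim)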
bayesflow/simulators/benchmark_simulators/gaussian_linear_uniform.py
@@ -79,5 +79,10 @@
         # Generate prior predictive samples, possibly a single if n_obs is None
         if self.n_obs is None:
             return self.rng.normal(loc=params, scale=self.obs_scale)
-        x = self.rng.normal(loc=params, scale=self.obs_scale, size=(self.n_obs, params.shape[0], params.shape[1]))
-        return np.transpose(x, (1, 0, 2))
+        if params.ndim == 2:
+            # batched sampling with n_obs
+            x = self.rng.normal(loc=params, scale=self.obs_scale, size=(self.n_obs, params.shape[0], params.shape[1]))
+            return np.transpose(x, (1, 0, 2))
+        elif params.ndim == 1:
+            # non-batched sampling with n_obs
+            return self.rng.normal(loc=params, scale=self.obs_scale, size=(self.n_obs, params.shape[0]))
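Through the public sample() interface, both patched simulators now add an observation axis immediately after the batch axis whenever n_obs is set, which the new test_gaussian_linear at the bottom of this diff asserts. A minimal sketch, assuming the n_obs=5 value used by the fixtures below:

from bayesflow.simulators import GaussianLinear

sim = GaussianLinear(n_obs=5)
samples = sim.sample((16,))
# batch axis first, then the n_obs axis introduced by the batched branch above
assert samples["observables"].shape[:2] == (16, 5)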
128 changes: 123 additions & 5 deletions tests/test_simulators/conftest.py
@@ -22,6 +22,97 @@ def use_squeezed(request):
     return request.param
 
 
+@pytest.fixture()
+def bernoulli_glm():
+    from bayesflow.simulators import BernoulliGLM
+
+    return BernoulliGLM()
+
+
+@pytest.fixture()
+def bernoulli_glm_raw():
+    from bayesflow.simulators import BernoulliGLMRaw
+
+    return BernoulliGLMRaw()
+
+
+@pytest.fixture()
+def gaussian_linear():
+    from bayesflow.simulators import GaussianLinear
+
+    return GaussianLinear()
+
+
+@pytest.fixture()
+def gaussian_linear_n_obs():
+    from bayesflow.simulators import GaussianLinear
+
+    return GaussianLinear(n_obs=5)
+
+
+@pytest.fixture()
+def gaussian_linear_uniform():
+    from bayesflow.simulators import GaussianLinearUniform
+
+    return GaussianLinearUniform()
+
+
+@pytest.fixture()
+def gaussian_linear_uniform_n_obs():
+    from bayesflow.simulators import GaussianLinearUniform
+
+    return GaussianLinearUniform(n_obs=5)
+
+
+@pytest.fixture(
+    params=["gaussian_linear", "gaussian_linear_n_obs", "gaussian_linear_uniform", "gaussian_linear_uniform_n_obs"]
+)
+def gaussian_linear_simulator(request):
+    return request.getfixturevalue(request.param)
+
+
+@pytest.fixture()
+def gaussian_mixture():
+    from bayesflow.simulators import GaussianMixture
+
+    return GaussianMixture()
+
+
+@pytest.fixture()
+def inverse_kinematics():
+    from bayesflow.simulators import InverseKinematics
+
+    return InverseKinematics()
+
+
+@pytest.fixture()
+def lotka_volterra():
+    from bayesflow.simulators import LotkaVolterra
+
+    return LotkaVolterra()
+
+
+@pytest.fixture()
+def sir():
+    from bayesflow.simulators import SIR
+
+    return SIR()
+
+
+@pytest.fixture()
+def slcp():
+    from bayesflow.simulators import SLCP
+
+    return SLCP()
+
+
+@pytest.fixture()
+def slcp_distractors():
+    from bayesflow.simulators import SLCPDistractors
+
+    return SLCPDistractors()
+
+
 @pytest.fixture()
 def composite_two_moons():
     from bayesflow.simulators import make_simulator
@@ -40,13 +131,40 @@ def observables(parameters):
     return make_simulator([parameters, observables])
 
 
-@pytest.fixture(params=["composite_two_moons", "two_moons"])
-def simulator(request):
-    return request.getfixturevalue(request.param)
-
-
 @pytest.fixture()
 def two_moons():
     from bayesflow.simulators import TwoMoons
 
     return TwoMoons()
 
 
+@pytest.fixture(
+    params=[
+        "composite_two_moons",
+        "two_moons",
+    ]
+)
+def two_moons_simulator(request):
+    return request.getfixturevalue(request.param)
+
+
+@pytest.fixture(
+    params=[
+        "bernoulli_glm",
+        "bernoulli_glm_raw",
+        "gaussian_linear",
+        "gaussian_linear_n_obs",
+        "gaussian_linear_uniform",
+        "gaussian_linear_uniform_n_obs",
+        "gaussian_mixture",
+        "inverse_kinematics",
+        "lotka_volterra",
+        "sir",
+        "slcp",
+        "slcp_distractors",
+        "composite_two_moons",
+        "two_moons",
+    ]
+)
+def simulator(request):
+    return request.getfixturevalue(request.param)
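The aggregating fixtures above use pytest's request.getfixturevalue to resolve each named fixture lazily, so adding a simulator to the suite only requires a new name in the params list. A self-contained illustration of the pattern (toy fixtures, not from the test suite):

import pytest

@pytest.fixture()
def small():
    return 1

@pytest.fixture()
def large():
    return 100

@pytest.fixture(params=["small", "large"])
def number(request):
    # each param value names a fixture; resolve it per parametrized test case
    return request.getfixturevalue(request.param)

def test_number(number):
    assert number in (1, 100)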
12 changes: 10 additions & 2 deletions tests/test_simulators/test_simulators.py
@@ -2,8 +2,8 @@
 import numpy as np
 
 
-def test_two_moons(simulator, batch_size):
-    samples = simulator.sample((batch_size,))
+def test_two_moons(two_moons_simulator, batch_size):
+    samples = two_moons_simulator.sample((batch_size,))
 
     assert isinstance(samples, dict)
     assert list(samples.keys()) == ["parameters", "observables"]
@@ -13,6 +13,14 @@ def test_two_moons(simulator, batch_size):
     assert samples["observables"].shape == (batch_size, 2)
 
 
+def test_gaussian_linear(gaussian_linear_simulator, batch_size):
+    samples = gaussian_linear_simulator.sample((batch_size,))
+
+    # test n_obs respected if applicable
+    if hasattr(gaussian_linear_simulator, "n_obs") and isinstance(gaussian_linear_simulator.n_obs, int):
+        assert samples["observables"].shape[1] == gaussian_linear_simulator.n_obs
+
+
 def test_sample(simulator, batch_size):
     samples = simulator.sample((batch_size,))
 
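Assuming the repository's standard pytest setup, the updated cases can be run selectively, e.g.:

pytest tests/test_simulators -k "gaussian_linear or two_moons"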