Skip to content

Commit

Permalink
add logging for registering stratifications and observations (#305)
Browse files Browse the repository at this point in the history
  • Loading branch information
rmudambi committed Jul 11, 2023
1 parent a78d223 commit ccbf97b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/vivarium/framework/results/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def register_stratification(
------
None
"""
self.logger.debug(f"Registering stratification {name}")
target_columns = list(requires_columns) + list(requires_values)
self._results_context.add_stratification(
name, target_columns, categories, mapper, is_vectorized
Expand Down Expand Up @@ -189,6 +190,7 @@ def register_observation(
excluded_stratifications: List[str] = (),
when: str = "collect_metrics",
) -> None:
self.logger.debug(f"Registering observation {name}")
self._warn_check_stratifications(additional_stratifications, excluded_stratifications)
self._results_context.add_observation(
name,
Expand Down
15 changes: 13 additions & 2 deletions tests/framework/results/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,13 @@ def test_register_observation(
assert len(interface._manager._results_context.observations) == 1


def test_register_observations():
def test_register_observations(mocker):
mgr = ResultsManager()
interface = ResultsInterface(mgr)
builder = mocker.Mock()
builder.configuration.stratification.default = []
mgr.setup(builder)

assert len(interface._manager._results_context.observations) == 0
interface.register_observation(
"living_person_time",
Expand Down Expand Up @@ -121,9 +125,13 @@ def test_register_observations():
assert len(interface._manager._results_context.observations) == 2


def test_unhashable_pipeline():
def test_unhashable_pipeline(mocker):
mgr = ResultsManager()
interface = ResultsInterface(mgr)
builder = mocker.Mock()
builder.configuration.stratification.default = []
mgr.setup(builder)

assert len(interface._manager._results_context.observations) == 0
with pytest.raises(TypeError, match="unhashable"):
interface.register_observation(
Expand Down Expand Up @@ -158,6 +166,9 @@ def test_integration_full_observation(mocker):
# Create interface
mgr = ResultsManager()
results_interface = ResultsInterface(mgr)
builder = mocker.Mock()
builder.configuration.stratification.default = []
mgr.setup(builder)

# register stratifications
results_interface.register_stratification("house", HOUSES, None, True, ["house"], [])
Expand Down

0 comments on commit ccbf97b

Please sign in to comment.