Skip to content

Commit

Permalink
Mend get_value unhashable argument for Results Manger add_observation()
Browse files Browse the repository at this point in the history
#260

- *Category*: bugfix
- *JIRA issue*: [MIC-3761](https://jira.ihme.washington.edu/browse/MIC-3761)

Changes
- Mends the call to `get_value` to go per item in the list of targets.
- Adds a test case to `test_interface.py` to check for this bug.
- Adds a new test that passes in arguments for the `requires_values` that trigger the error case.
- Refactors a mocked `get_value` call to be in the top-level for results framework tests.
- Mends `add_observation` call with incorrect number of arguments (effect on the actual test was negligible).

Testing
All tests pass.
  • Loading branch information
mattkappel committed Jan 11, 2023
1 parent 3cb1c9e commit 8b8b3e4
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 19 deletions.
4 changes: 2 additions & 2 deletions src/vivarium/framework/results/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ def register_observation(
pop_filter: str = "",
aggregator_sources: List[str] = None,
aggregator: Callable[[pd.DataFrame], float] = len,
requires_columns: List[str] = None,
requires_values: List[str] = None,
requires_columns: List[str] = (),
requires_values: List[str] = (),
additional_stratifications: List[str] = (),
excluded_stratifications: List[str] = (),
when: str = "collect_metrics",
Expand Down
6 changes: 3 additions & 3 deletions src/vivarium/framework/results/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,8 @@ def register_observation(
pop_filter: str,
aggregator_sources: List[str],
aggregator: Callable,
requires_columns: List[str] = None,
requires_values: List[str] = None,
requires_columns: List[str] = (),
requires_values: List[str] = (),
additional_stratifications: List[str] = (),
excluded_stratifications: List[str] = (),
when: str = "collect_metrics",
Expand All @@ -198,7 +198,7 @@ def _add_resources(self, target: List[str], target_type: SourceType):
if target_type == SourceType.COLUMN:
self._required_columns.update(target)
elif target_type == SourceType.VALUE:
self._required_values.update(self.get_value(target))
self._required_values.update([self.get_value(t) for t in target])

def _prepare_population(self, event: Event):
population = self.population_view.subview(list(self._required_columns)).get(
Expand Down
7 changes: 7 additions & 0 deletions tests/framework/results/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,10 @@ def verify_stratification_added(
matching_stratification_found = True
break
return matching_stratification_found


# Mock for get_value call for Pipelines, returns a str instead of a Pipeline
def mock_get_value(self, name: str):
if not isinstance(name, str):
raise TypeError("Passed a non-string type to mock get_value(), check your pipelines.")
return name
51 changes: 43 additions & 8 deletions tests/framework/results/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from .mocks import BASE_POPULATION
from .mocks import CATEGORIES as HOUSES
from .mocks import FAMILIARS
from .mocks import FAMILIARS, mock_get_value


def _silly_aggregator(_: pd.DataFrame) -> float:
Expand All @@ -26,8 +26,8 @@ def _silly_aggregator(_: pd.DataFrame) -> float:
'alive == "alive" and undead == False',
[],
_silly_aggregator,
None,
None,
[],
[],
[],
[],
"collect_metrics",
Expand All @@ -37,16 +37,28 @@ def _silly_aggregator(_: pd.DataFrame) -> float:
"undead == True",
[],
_silly_aggregator,
None,
None,
[],
[],
[],
[],
"time_step__prepare",
),
(
"undead_person_time",
"undead == True",
[],
_silly_aggregator,
[],
["fake_pipeline", "another_fake_pipeline"],
[],
[],
"time_step__prepare",
),
],
ids=["valid_on_collect_metrics", "valid_on_time_step__prepare"],
ids=["valid_on_collect_metrics", "valid_on_time_step__prepare", "valid_pipelines"],
)
def test_register_observation(
mocker,
name,
pop_filter,
aggregator_columns,
Expand All @@ -59,13 +71,19 @@ def test_register_observation(
):
mgr = ResultsManager()
interface = ResultsInterface(mgr)
# interface.set_default_stratifications(["age", "sex"])
builder = mocker.Mock()
# Set up mock builder with mocked get_value call for Pipelines
mocker.patch.object(builder, "value.get_value")
builder.value.get_value = MethodType(mock_get_value, builder)
mgr.setup(builder)
assert len(interface._manager._results_context._observations) == 0
interface.register_observation(
name,
pop_filter,
aggregator_columns,
aggregator,
requires_columns,
requires_values,
additional_stratifications,
excluded_stratifications,
)
Expand All @@ -75,7 +93,6 @@ def test_register_observation(
def test_register_observations():
mgr = ResultsManager()
interface = ResultsInterface(mgr)
# interface.set_default_stratifications(["age", "sex"])
assert len(interface._manager._results_context._observations) == 0
interface.register_observation(
"living_person_time",
Expand Down Expand Up @@ -103,6 +120,24 @@ def test_register_observations():
assert len(interface._manager._results_context._observations) == 2


def test_unhashable_pipeline():
mgr = ResultsManager()
interface = ResultsInterface(mgr)
assert len(interface._manager._results_context._observations) == 0
with pytest.raises(TypeError, match="unhashable"):
interface.register_observation(
"living_person_time",
'alive == "alive" and undead == False',
[],
_silly_aggregator,
[],
[["bad", "unhashable", "thing"]], # unhashable first element
[],
[],
"collect_metrics",
)


def mock__prepare_population(self, event):
"""Return a mock population in the vein of ResultsManager._prepare_population"""
# Generate population DataFrame
Expand Down
7 changes: 1 addition & 6 deletions tests/framework/results/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,12 @@
CATEGORIES,
NAME,
SOURCES,
mock_get_value,
sorting_hat_serial,
sorting_hat_vector,
verify_stratification_added,
)


# Mock for get_value call for Pipelines, returns a str instead of a Pipeline
def mock_get_value(self, name: str):
return name


#######################################
# Tests for `register_stratification` #
#######################################
Expand Down

0 comments on commit 8b8b3e4

Please sign in to comment.