Skip to content

Commit

Permalink
Allow zero stratifications (#296)
Browse files Browse the repository at this point in the history
    Category: feature
    JIRA issue: MIC-3774

Allow zero stratifications when running gather_results. Add test.

Testing
Ran CIFF SAM for a few time steps and checked that the categorical risk observer recorded child stunting person-time properly without any stratifications. All tests pass, and a new test was written.
  • Loading branch information
hussain-jafari committed May 24, 2023
1 parent a347678 commit 85dc1b0
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 8 deletions.
23 changes: 16 additions & 7 deletions src/vivarium/framework/results/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,6 @@ def set_default_stratifications(self, default_grouping_columns: List[str]):
"Multiple calls are being made to set default grouping columns "
"for results production."
)
if not default_grouping_columns:
raise ResultsConfigurationError(
"Attempting to set an empty list as the default grouping columns "
"for results production."
)
self.default_stratifications = default_grouping_columns

def add_stratification(
Expand Down Expand Up @@ -125,7 +120,11 @@ def gather_results(self, population: pd.DataFrame, event_name: str) -> Dict[str,
if filtered_pop.empty:
yield {}
else:
pop_groups = filtered_pop.groupby(list(stratifications))
if not len(list(stratifications)): # Handle situation of no stratifications
pop_groups = filtered_pop.groupby(lambda _: True)
else:
pop_groups = filtered_pop.groupby(list(stratifications))

for measure, aggregator_sources, aggregator, additional_keys in observations:
if aggregator_sources:
aggregates = (
Expand All @@ -144,7 +143,12 @@ def gather_results(self, population: pd.DataFrame, event_name: str) -> Dict[str,
)

# Keep formatting all in one place.
yield self._format_results(measure, aggregates, **additional_keys)
yield self._format_results(
measure,
aggregates,
bool(len(list(stratifications))),
**additional_keys,
)

def _get_stratifications(
self,
Expand All @@ -162,9 +166,14 @@ def _get_stratifications(
def _format_results(
measure: str,
aggregates: pd.Series,
has_stratifications: bool,
**additional_keys: str,
) -> Dict[str, float]:
results = {}
# Simpler formatting if we don't have stratifications
if not has_stratifications:
return {measure: aggregates.squeeze()}

# First we expand the categorical index over unobserved pairs.
# This ensures that the produced results are always the same length.
if isinstance(aggregates.index, pd.MultiIndex):
Expand Down
22 changes: 21 additions & 1 deletion tests/framework/results/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,12 +366,32 @@ def test_gather_results_with_empty_pop_filter():
assert len(result) == 0


def test_gather_results_with_no_stratifications():
    """With no stratifications registered, gather_results should yield exactly one result."""
    context = ResultsContext()

    # Work on a copy of BASE_POPULATION rather than the shared object itself.
    pop = BASE_POPULATION.copy()

    metrics_event = "collect_metrics"
    context.add_observation(
        name="wizard_count",
        pop_filter="",
        aggregator_sources=None,
        aggregator=len,
        event_name=metrics_event,
    )

    # Precondition: nothing has been registered as a stratification.
    assert len(context.stratifications) == 0

    # Drain the generator so the yielded results can be counted.
    gathered = list(context.gather_results(pop, metrics_event))
    assert len(gathered) == 1


def test__format_results():
"""Test that format results produces the expected number of keys and a specific expected key"""
ctx = ResultsContext()
aggregates = BASE_POPULATION.groupby(["house", "familiar"]).apply(len)
measure = "wizard_count"
rv = ctx._format_results(measure, aggregates)
rv = ctx._format_results(measure, aggregates, has_stratifications=True)

# Check that the number of expected data column names are there
expected_keys_len = len(CATEGORIES) * len(FAMILIARS)
Expand Down

0 comments on commit 85dc1b0

Please sign in to comment.