Skip to content

Commit

Permalink
Fix bug in gather_results, fix typing notation (#269)
Browse files Browse the repository at this point in the history
- *Category*: bugfix
- *JIRA issue*: [MIC-3684](https://jira.ihme.washington.edu/browse/MIC-3684)

Mends two bugs found in the development of the `DiseaseObserver` where the population was being improperly refined during `gather_results` and the `ResultsContext` default stratifications were improperly set when used by `ResultsStratifier`. Also updates return type annotation that was incomplete for the Stratification default mapper.

Testing
`DiseaseObserver` works with this code.
  • Loading branch information
mattkappel committed Jan 25, 2023
1 parent 59dfecd commit b495fa1
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
8 changes: 5 additions & 3 deletions src/vivarium/framework/results/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def set_default_stratifications(self, default_grouping_columns: List[str]):
"Attempting to set an empty list as the default grouping columns "
"for results production."
)
self._default_grouping_columns = default_grouping_columns
self._default_stratifications = default_grouping_columns

def add_stratification(
self,
Expand Down Expand Up @@ -106,8 +106,10 @@ def gather_results(self, population: pd.DataFrame, event_name: str) -> Dict[str,
# Results production can be simplified to
# filter -> groupby -> aggregate in all situations we've seen.
if pop_filter:
population = population.query(pop_filter)
pop_groups = population.groupby(list(stratifications))
filtered_pop = population.query(pop_filter)
else:
filtered_pop = population
pop_groups = filtered_pop.groupby(list(stratifications))
for measure, aggregator_sources, aggregator, additional_keys in observations:
if aggregator_sources:
aggregates = pop_groups[aggregator_sources].apply(aggregator).fillna(0.0)
Expand Down
2 changes: 1 addition & 1 deletion src/vivarium/framework/results/stratification.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,5 @@ def __call__(self, population: pd.DataFrame) -> pd.DataFrame:
return population

@staticmethod
def _default_mapper(pop: pd.DataFrame) -> str:
def _default_mapper(pop: pd.DataFrame) -> pd.Series:
return pop.squeeze(axis=1)

0 comments on commit b495fa1

Please sign in to comment.