Skip to content

Commit

Permalink
Merge pull request #304 from ihmeuw/develop
Browse files Browse the repository at this point in the history
Release v1.2.0
  • Loading branch information
stevebachmeier committed Jun 1, 2023
2 parents 53b7d1f + a78d223 commit 4f14dad
Show file tree
Hide file tree
Showing 18 changed files with 136 additions and 64 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10"]
python-version: ["3.8", "3.9", "3.10", "3.11"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.11"
- name: Install dependencies
run: |
python --version
Expand Down
8 changes: 8 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
**1.2.0 - 06/01/23**

- Stop supporting Python 3.7 and start supporting 3.11
- Bugfix to allow for zero stratifications
- Removes ignore filters for known FutureWarnings
- Refactor location of default stratification definition
- Bugfix to stop shuffling simulants when drawing common random numbers

**1.1.0 - 05/03/23**

- Clean up randomness system
Expand Down
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Vivarium
Vivarium is a simulation framework written using standard scientific Python
tools.

**Vivarium requires Python 3.7-3.10 to run**
**Vivarium requires Python 3.8-3.11 to run**

You can install ``vivarium`` from PyPI with pip:

Expand Down
19 changes: 11 additions & 8 deletions docs/source/tutorials/exploration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,9 @@ configuration by simply printing it.
component_configs: True
extrapolate:
component_configs: True
stratification:
default:
component_configs: []


What do we see here? The configuration is *hierarchical*. There are a set of
Expand Down Expand Up @@ -300,20 +303,20 @@ the population as a whole.
alive 100000
Name: alive, dtype: int64
count 100000.000000
mean 0.500602
std 0.288434
min 0.000022
25% 0.251288
50% 0.499957
75% 0.749816
mean 0.499756
std 0.288412
min 0.000015
25% 0.251550
50% 0.497587
75% 0.749215
max 0.999978
Name: child_wasting_propensity, dtype: float64
susceptible_to_lower_respiratory_infections 100000
Name: lower_respiratory_infections, dtype: int64
2021-12-31 12:00:00 100000
Name: entrance_time, dtype: int64
Male 50162
Female 49838
Male 50185
Female 49815
Name: sex, dtype: int64
True 100000
Name: tracked, dtype: int64
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import sys

min_version, max_version = ((3, 6), "3.6"), ((3, 10), "3.10")
min_version, max_version = ((3, 8), "3.8"), ((3, 11), "3.11")

if not (min_version[0] <= sys.version_info[:2] <= max_version[0]):
# Python 3.5 does not support f-strings
Expand Down
2 changes: 1 addition & 1 deletion src/vivarium/__about__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
__summary__ = "vivarium is a microsimulation framework built on top of the standard scientific python stack."
__uri__ = "https://github.com/ihmeuw/vivarium"

__version__ = "1.1.0"
__version__ = "1.2.0"

__author__ = "The vivarium developers"
__email__ = "vivarium.dev@gmail.com"
Expand Down
21 changes: 16 additions & 5 deletions src/vivarium/framework/randomness/index_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,17 @@ def __init__(self, key_columns: List[str] = None, size: int = 1_000_000):
self._map = None
self._size = size

def update(self, new_keys: pd.DataFrame) -> None:
def update(self, new_keys: pd.DataFrame, clock_time: pd.Timestamp) -> None:
"""Adds the new keys to the mapping.
Parameters
----------
new_keys
A pandas DataFrame indexed by the simulant index and columns corresponding to
the randomness system key columns.
clock_time
The simulation clock time. Used as the salt during hashing to
minimize inter-simulation collisions.
"""
if new_keys.empty or not self._use_crn:
Expand All @@ -49,10 +52,13 @@ def update(self, new_keys: pd.DataFrame) -> None:
if len(final_keys) != len(final_keys.unique()):
raise RandomnessError("Non-unique keys in index")

final_mapping = self._build_final_mapping(new_mapping_index)
final_mapping = self._build_final_mapping(new_mapping_index, clock_time)

# Tack on the simulant index to the front of the map.
final_mapping.index = final_mapping_index
final_mapping.index = final_mapping.index.join(final_mapping_index).reorder_levels(
[self.SIM_INDEX_COLUMN] + self._key_columns
)
final_mapping = final_mapping.sort_index(level=self.SIM_INDEX_COLUMN)
self._map = final_mapping

def _parse_new_keys(self, new_keys: pd.DataFrame) -> Tuple[pd.MultiIndex, pd.MultiIndex]:
Expand Down Expand Up @@ -84,7 +90,9 @@ def _parse_new_keys(self, new_keys: pd.DataFrame) -> Tuple[pd.MultiIndex, pd.Mul
final_mapping_index = self._map.index.append(new_mapping_index)
return new_mapping_index, final_mapping_index

def _build_final_mapping(self, new_mapping_index: pd.Index) -> pd.Series:
def _build_final_mapping(
self, new_mapping_index: pd.Index, clock_time: pd.Timestamp
) -> pd.Series:
"""Builds a new mapping between key columns and the randomness index from the
new mapping index and the existing map.
Expand All @@ -93,6 +101,9 @@ def _build_final_mapping(self, new_mapping_index: pd.Index) -> pd.Series:
new_mapping_index
An index with a level for the index assigned by the population system and
additional levels for the key columns associated with the simulant index.
clock_time
The simulation clock time. Used as the salt during hashing to
minimize inter-simulation collisions.
Returns
-------
Expand All @@ -102,7 +113,7 @@ def _build_final_mapping(self, new_mapping_index: pd.Index) -> pd.Series:
"""
new_key_index = new_mapping_index.droplevel(self.SIM_INDEX_COLUMN)
mapping_update = self._hash(new_key_index)
mapping_update = self._hash(new_key_index, salt=clock_time)
if self._map is None:
current_map = mapping_update
else:
Expand Down
2 changes: 1 addition & 1 deletion src/vivarium/framework/randomness/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def register_simulants(self, simulants: pd.DataFrame):
raise RandomnessError(
"The simulants dataframe does not have all specified key_columns."
)
self._key_mapping.update(simulants.loc[:, self._key_columns])
self._key_mapping.update(simulants.loc[:, self._key_columns], self._clock())

def __str__(self):
return "RandomnessManager()"
Expand Down
23 changes: 16 additions & 7 deletions src/vivarium/framework/results/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,6 @@ def set_default_stratifications(self, default_grouping_columns: List[str]):
"Multiple calls are being made to set default grouping columns "
"for results production."
)
if not default_grouping_columns:
raise ResultsConfigurationError(
"Attempting to set an empty list as the default grouping columns "
"for results production."
)
self.default_stratifications = default_grouping_columns

def add_stratification(
Expand Down Expand Up @@ -125,7 +120,11 @@ def gather_results(self, population: pd.DataFrame, event_name: str) -> Dict[str,
if filtered_pop.empty:
yield {}
else:
pop_groups = filtered_pop.groupby(list(stratifications))
if not len(list(stratifications)): # Handle situation of no stratifications
pop_groups = filtered_pop.groupby(lambda _: True)
else:
pop_groups = filtered_pop.groupby(list(stratifications))

for measure, aggregator_sources, aggregator, additional_keys in observations:
if aggregator_sources:
aggregates = (
Expand All @@ -144,7 +143,12 @@ def gather_results(self, population: pd.DataFrame, event_name: str) -> Dict[str,
)

# Keep formatting all in one place.
yield self._format_results(measure, aggregates, **additional_keys)
yield self._format_results(
measure,
aggregates,
bool(len(list(stratifications))),
**additional_keys,
)

def _get_stratifications(
self,
Expand All @@ -162,9 +166,14 @@ def _get_stratifications(
def _format_results(
measure: str,
aggregates: pd.Series,
has_stratifications: bool,
**additional_keys: str,
) -> Dict[str, float]:
results = {}
# Simpler formatting if we don't have stratifications
if not has_stratifications:
return {measure: aggregates.squeeze()}

# First we expand the categorical index over unobserved pairs.
# This ensures that the produced results are always the same length.
if isinstance(aggregates.index, pd.MultiIndex):
Expand Down
11 changes: 10 additions & 1 deletion src/vivarium/framework/results/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ class ResultsManager:
`collect_metrics`).
"""

configuration_defaults = {
    "stratification": {
        # Default list of stratification columns applied to observations.
        # An empty list means results are gathered without any default
        # stratification (zero-stratification is supported downstream).
        "default": [],
    }
}

def __init__(self):
self._metrics = Counter()
self._results_context = ResultsContext()
Expand Down Expand Up @@ -59,6 +65,8 @@ def setup(self, builder: "Builder"):

self.get_value = builder.value.get_value

self.set_default_stratifications(builder)

builder.value.register_value_modifier("metrics", self.get_results)

def on_time_step_prepare(self, event: Event):
Expand All @@ -78,7 +86,8 @@ def gather_results(self, event_name: str, event: Event):
for results_group in self._results_context.gather_results(population, event_name):
self._metrics.update(results_group)

def set_default_stratifications(self, default_stratifications: List[str]):
def set_default_stratifications(self, builder):
    """Read the default stratification columns from the simulation
    configuration and register them with the results context.

    Parameters
    ----------
    builder
        The simulation builder; its ``configuration.stratification.default``
        entry supplies the default stratification column list.
    """
    default_stratifications = builder.configuration.stratification.default
    self._results_context.set_default_stratifications(default_stratifications)

def register_stratification(
Expand Down
11 changes: 6 additions & 5 deletions src/vivarium/framework/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,7 @@ def _groupby_new_state(
"""
output_map = {o: i for i, o in enumerate(outputs)}
# TODO: fix rather than suppress this FutureWarning
with warnings.catch_warnings():
warnings.simplefilter(action="ignore", category=FutureWarning)
groups = pd.Series(index).groupby([output_map[d] for d in decisions])
groups = pd.Series(index).groupby([output_map[d] for d in decisions])
results = [(outputs[i], pd.Index(sub_group.values)) for i, sub_group in groups]
selected_outputs = [o for o, _ in results]
for output in outputs:
Expand Down Expand Up @@ -212,6 +209,10 @@ def sub_components(self) -> List:
def setup(self, builder: "Builder") -> None:
pass

def set_model(self, model_name: str) -> None:
    """Record the name of the model (state column) this state belongs to.

    Parameters
    ----------
    model_name
        The state column name of the owning machine; the machine passes
        its ``state_column`` here when the state is added to it.
    """
    self._model = model_name

def next_state(
self, index: pd.Index, event_time: "Time", population_view: "PopulationView"
) -> None:
Expand Down Expand Up @@ -488,7 +489,7 @@ def setup(self, builder: "Builder") -> None:
def add_states(self, states: Iterable[State]) -> None:
for state in states:
self.states.append(state)
state._model = self.state_column
state.set_model(self.state_column)

def transition(self, index: pd.Index, event_time: "Time") -> None:
"""Finds the population in each state and moves them to the next state.
Expand Down
30 changes: 6 additions & 24 deletions src/vivarium/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,14 @@ def __init__(
if self.key_columns:
# Since there are key_columns we need to group the table by those
# columns to get the sub-tables to fit
# TODO: fix rather than suppress this FutureWarning
with warnings.catch_warnings():
warnings.simplefilter(action="ignore", category=FutureWarning)
sub_tables = self.data.groupby(list(self.key_columns))
sub_tables = self.data.groupby(list(self.key_columns))
else:
# There are no key columns so we will fit the whole table
sub_tables = {None: self.data}.items()

self.interpolations = {}

# TODO: fix rather than suppress this FutureWarning
with warnings.catch_warnings():
warnings.simplefilter(action="ignore", category=FutureWarning)
sub_tables = list(sub_tables)
sub_tables = list(sub_tables)
for key, base_table in sub_tables:
if (
base_table.empty
Expand Down Expand Up @@ -114,21 +108,15 @@ def __call__(self, interpolants: pd.DataFrame) -> pd.DataFrame:
validate_call_data(interpolants, self.key_columns, self.parameter_columns)

if self.key_columns:
# TODO: fix rather than suppress this FutureWarning
with warnings.catch_warnings():
warnings.simplefilter(action="ignore", category=FutureWarning)
sub_tables = interpolants.groupby(list(self.key_columns))
sub_tables = interpolants.groupby(list(self.key_columns))
else:
sub_tables = [(None, interpolants)]
# specify some numeric type for columns so they won't be objects but will updated with whatever
# column type actually is
result = pd.DataFrame(
index=interpolants.index, columns=self.value_columns, dtype=np.float64
)
# TODO: fix rather than suppress this FutureWarning
with warnings.catch_warnings():
warnings.simplefilter(action="ignore", category=FutureWarning)
sub_tables = list(sub_tables)
sub_tables = list(sub_tables)
for key, sub_table in sub_tables:
if sub_table.empty:
continue
Expand Down Expand Up @@ -226,19 +214,13 @@ def check_data_complete(data, parameter_columns):
for p in param_edges:
other_params = [p_ed[0] for p_ed in param_edges if p_ed != p]
if other_params:
# TODO: fix rather than suppress this FutureWarning
with warnings.catch_warnings():
warnings.simplefilter(action="ignore", category=FutureWarning)
sub_tables = data.groupby(list(other_params))
sub_tables = data.groupby(list(other_params))
else:
sub_tables = {None: data}.items()

n_p_total = len(set(data[p[0]]))

# TODO: fix rather than suppress this FutureWarning
with warnings.catch_warnings():
warnings.simplefilter(action="ignore", category=FutureWarning)
sub_tables = list(sub_tables)
sub_tables = list(sub_tables)
for _, table in sub_tables:
param_data = table[[p[0], p[1]]].copy().sort_values(by=p[0])
start, end = param_data[p[0]].reset_index(drop=True), param_data[
Expand Down

0 comments on commit 4f14dad

Please sign in to comment.