Skip to content

Commit

Permalink
Bugfix/randomness typing (#289)
Browse files Browse the repository at this point in the history
* support float inputs to filter_for_rate/probability

* pin pandas below 2.0.0
  • Loading branch information
rmudambi committed Apr 3, 2023
1 parent 7e9788f commit 80cc034
Show file tree
Hide file tree
Showing 22 changed files with 63 additions and 39 deletions.
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from setuptools import find_packages, setup

if __name__ == "__main__":

base_dir = Path(__file__).parent
src_dir = base_dir / "src"

Expand All @@ -33,7 +32,7 @@

install_requirements = [
"numpy",
"pandas",
"pandas<2.0.0",
"pyyaml>=5.1",
"scipy",
"click",
Expand Down
1 change: 0 additions & 1 deletion src/vivarium/examples/boids/location.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@


class Location:

configuration_defaults = {
"location": {
"width": 1000, # Width of our field
Expand Down
1 change: 0 additions & 1 deletion src/vivarium/examples/boids/neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@


class Neighbors:

configuration_defaults = {"neighbors": {"radius": 10}}

def __init__(self):
Expand Down
1 change: 0 additions & 1 deletion src/vivarium/examples/boids/population.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@


class Population:

configuration_defaults = {
"population": {
"colors": ["red", "blue"],
Expand Down
1 change: 0 additions & 1 deletion src/vivarium/examples/disease_model/disease.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ def metrics(self, index, metrics):


class SISDiseaseModel:

configuration_defaults = {
"disease": {
"incidence_rate": 0.005,
Expand Down
1 change: 0 additions & 1 deletion src/vivarium/examples/disease_model/intervention.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@


class TreatmentIntervention:

configuration_defaults = {
"intervention": {
"effect_size": 0.5,
Expand Down
2 changes: 0 additions & 2 deletions src/vivarium/examples/disease_model/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@


class Observer:

configuration_defaults = {
"mortality": {
"life_expectancy": 80,
Expand All @@ -24,7 +23,6 @@ def setup(self, builder: Builder):
builder.value.register_value_modifier("metrics", self.metrics)

def metrics(self, index: pd.Index, metrics: Dict):

pop = self.population_view.get(index)
metrics["total_population_alive"] = len(pop[pop.alive == "alive"])
metrics["total_population_dead"] = len(pop[pop.alive == "dead"])
Expand Down
2 changes: 0 additions & 2 deletions src/vivarium/examples/disease_model/risk.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@


class Risk:

configuration_defaults = {
"risk": {
"proportion_exposed": 0.3,
Expand Down Expand Up @@ -50,7 +49,6 @@ def __repr__(self):


class RiskEffect:

configuration_defaults = {
"risk_effect": {
"relative_risk": 2,
Expand Down
1 change: 0 additions & 1 deletion src/vivarium/framework/artifact/hdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,6 @@ def _write_pandas_data(path: Path, entity_key: EntityKey, data: Union[PandasObj]
def _write_json_blob(path: Path, entity_key: EntityKey, data: Any):
"""Writes a Python object as json to the HDF file at the given path."""
with tables.open_file(str(path), "a") as store:

if entity_key.group_prefix not in store:
store.create_group("/", entity_key.type)

Expand Down
1 change: 0 additions & 1 deletion src/vivarium/framework/artifact/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,6 @@ def parse_artifact_path_config(config: ConfigTree) -> str:
path = Path(config.input_data.artifact_path)

if not path.is_absolute():

path_config = config.input_data.metadata("artifact_path")[-1]
if path_config["source"] is None:
raise ValueError("Insufficient information provided to find artifact.")
Expand Down
1 change: 0 additions & 1 deletion src/vivarium/framework/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@


class SimulationContext:

_created_simulation_contexts: Set[str] = set()

@staticmethod
Expand Down
2 changes: 0 additions & 2 deletions src/vivarium/framework/lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def __init__(
extrapolate: bool,
validate: bool,
):

self.data = data
self.population_view = population_view
self.key_columns = key_columns
Expand Down Expand Up @@ -138,7 +137,6 @@ def __init__(
values: Union[List[ScalarValue], Tuple[ScalarValue]],
value_columns: Union[List[str], Tuple[str]],
):

self.values = values
self.value_columns = value_columns

Expand Down
22 changes: 12 additions & 10 deletions src/vivarium/framework/randomness/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def get_draw(self, index: pd.Index, additional_key: Any = None) -> pd.Series:
def filter_for_rate(
self,
population: Union[pd.DataFrame, pd.Series, pd.Index],
rate: Union[List, Tuple, np.ndarray, pd.Series],
rate: Union[float, List, Tuple, np.ndarray, pd.Series],
additional_key: Any = None,
) -> Union[pd.DataFrame, pd.Series, pd.Index]:
"""Decide an event outcome for each individual from rates.
Expand All @@ -199,11 +199,12 @@ def filter_for_rate(
A view on the simulants for which we are determining the
outcome of an event.
rate
A 1d list of rates of the event under consideration occurring which
corresponds (i.e. `len(population) == len(probability))` to the
population view passed in. The rates must be scaled to the
simulation time-step size either manually or as a post-processing
step in a rate pipeline.
A scalar float value or a 1d list of rates of the event under
consideration occurring which corresponds (i.e.
`len(population) == len(rate)`) to the population view passed
in. The rates must be scaled to the simulation time-step size either
manually or as a post-processing step in a rate pipeline. If a
scalar is provided, it is applied to every row in the population.
additional_key
Any additional information used to create the seed.
Expand All @@ -221,7 +222,7 @@ def filter_for_rate(
def filter_for_probability(
self,
population: Union[pd.DataFrame, pd.Series, pd.Index],
probability: Union[List, Tuple, np.ndarray, pd.Series],
probability: Union[float, List, Tuple, np.ndarray, pd.Series],
additional_key: Any = None,
) -> Union[pd.DataFrame, pd.Series, pd.Index]:
"""Decide an outcome for each individual from probabilities.
Expand All @@ -236,10 +237,11 @@ def filter_for_probability(
A view on the simulants for which we are determining the
outcome of an event.
probability
A 1d list of probabilities of the event under consideration
occurring which corresponds (i.e.
A scalar float value or a 1d list of probabilities of the event
under consideration occurring which corresponds (i.e.
`len(population) == len(probability)`) to the population view
passed in. If a scalar is provided, it is applied to every row in
the population.
additional_key
Any additional information used to create the seed.
Expand Down
1 change: 0 additions & 1 deletion src/vivarium/framework/results/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ def add_observation(
when: str = "collect_metrics",
**additional_keys: str,
):

stratifications = self._get_stratifications(
additional_stratifications, excluded_stratifications
)
Expand Down
7 changes: 5 additions & 2 deletions src/vivarium/framework/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
import functools
from bdb import BdbQuit
from importlib import import_module
from typing import Any, Callable
from typing import Any, Callable, List, Tuple, Union

import numpy as np
import pandas as pd


def from_yearly(value, time_step):
Expand All @@ -22,10 +23,12 @@ def to_yearly(value, time_step):
return value / (time_step.total_seconds() / (60 * 60 * 24 * 365.0))


def rate_to_probability(
    rate: Union[float, List, Tuple, np.ndarray, pd.Series]
) -> np.ndarray:
    """Convert a rate into a probability using ``1 - exp(-rate)``.

    Parameters
    ----------
    rate
        A scalar or array-like of rates. The caller's object is never
        modified; the cap described below is applied to a new array.

    Returns
    -------
        The elementwise probabilities. Array-like input yields a numpy
        array of the same shape; scalar input yields a numpy scalar.

    """
    # encountered underflow from rate > 30k
    # for rates greater than 250, exp(-rate) evaluates to 1e-109
    # beware machine-specific floating point issues
    #
    # Capping with np.minimum (rather than masked assignment on a copy) builds
    # the capped values in one pass and promotes integer input to float before
    # it is negated.
    capped_rate = np.minimum(np.asarray(rate), 250.0)
    return 1 - np.exp(-capped_rate)

Expand Down
2 changes: 0 additions & 2 deletions src/vivarium/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def __init__(
extrapolate: bool,
validate: bool,
):

# TODO: allow for order 1 interpolation with binned edges
if order != 0:
raise NotImplementedError(
Expand Down Expand Up @@ -241,7 +240,6 @@ def check_data_complete(data, parameter_columns):
warnings.simplefilter(action="ignore", category=FutureWarning)
sub_tables = list(sub_tables)
for _, table in sub_tables:

param_data = table[[p[0], p[1]]].copy().sort_values(by=p[0])
start, end = param_data[p[0]].reset_index(drop=True), param_data[
p[1]
Expand Down
1 change: 0 additions & 1 deletion src/vivarium/testing_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@


class NonCRNTestPopulation:

configuration_defaults = {
"population": {
"age_start": 0,
Expand Down
1 change: 0 additions & 1 deletion tests/framework/artifact/test_hdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,6 @@ def test_EntityKey_with_name():


def test_entity_key_equality():

type_ = "cause"
name = "diarrheal_diseases"
measure = "incidence"
Expand Down
1 change: 0 additions & 1 deletion tests/framework/components/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ def __eq__(self, other):


class MockGenericComponent:

configuration_defaults = {
"component": {
"key1": "val",
Expand Down
1 change: 0 additions & 1 deletion tests/framework/components/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ def import_and_instantiate_mock(mocker):


def test_parse_component_config(components):

source = yaml.full_load(components)["components"]
component_list = parse_component_config_to_list(source)
assert set(TEST_COMPONENTS_PARSED) == set(component_list)
Expand Down
43 changes: 39 additions & 4 deletions tests/framework/randomness/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,48 @@ def test__set_residual_probability(weights_with_residuals, index):
assert np.isclose(p_total, len(index), atol=0.0001)


def test_filter_for_probability_single_probability(randomness_stream, index):
    """Filtering with a scalar probability of 0.5 keeps about half the rows.

    The filter is applied twice in a row, the second time to the result of
    the first, and each pass is expected to keep roughly half of its input
    (within a 10% relative tolerance).
    """
    sub_index = randomness_stream.filter_for_probability(index, 0.5)
    assert np.isclose(len(sub_index) / len(index), 0.5, rtol=0.1)

    sub_sub_index = randomness_stream.filter_for_probability(sub_index, 0.5)
    assert np.isclose(len(sub_sub_index) / len(sub_index), 0.5, rtol=0.1)


def test_filter_for_probability_multiple_probabilities(randomness_stream, index):
    """Filtering with per-row probabilities honours each row's own value.

    Rows are assigned a repeating pattern of 0.3 and 0.6; the share of rows
    kept within each group must be close to that group's probability
    (within a 10% relative tolerance).
    """
    pattern = [0.3, 0.3, 0.3, 0.6, 0.6]
    probabilities = pd.Series(pattern * (index.size // 5), index=index)
    low_group = probabilities.index[probabilities == 0.3]
    high_group = probabilities.index.difference(low_group)

    selected = randomness_stream.filter_for_probability(index, probabilities)

    for group, expected in ((low_group, 0.3), (high_group, 0.6)):
        kept_fraction = len(selected.intersection(group)) / len(group)
        assert np.isclose(kept_fraction, expected, rtol=0.1)


def test_filter_for_rate_single_probability(randomness_stream, index):
    """Filtering with a scalar rate of 0.5 keeps about ``1 - exp(-0.5)`` of rows.

    The filter is applied twice in a row, the second time to the result of
    the first; each pass is checked against the same expected fraction
    (within a 10% relative tolerance).
    """
    expected_fraction = 1 - np.exp(-0.5)

    first_pass = randomness_stream.filter_for_rate(index, 0.5)
    first_fraction = len(first_pass) / len(index)
    assert np.isclose(first_fraction, expected_fraction, rtol=0.1)

    second_pass = randomness_stream.filter_for_rate(first_pass, 0.5)
    second_fraction = len(second_pass) / len(first_pass)
    assert np.isclose(second_fraction, expected_fraction, rtol=0.1)


def test_filter_for_rate_multiple_probabilities(randomness_stream, index):
    """Filtering with per-row rates honours each row's own value.

    Rows are assigned a repeating pattern of rates 0.3 and 0.6; the share of
    rows kept within each group must be close to ``1 - exp(-rate)`` for that
    group's rate (within a 10% relative tolerance).
    """
    pattern = [0.3, 0.3, 0.3, 0.6, 0.6]
    rates = pd.Series(pattern * (index.size // 5), index=index)
    low_group = rates.index[rates == 0.3]
    high_group = rates.index.difference(low_group)

    selected = randomness_stream.filter_for_rate(index, rates)

    for group, group_rate in ((low_group, 0.3), (high_group, 0.6)):
        kept_fraction = len(selected.intersection(group)) / len(group)
        assert np.isclose(kept_fraction, 1 - np.exp(-group_rate), rtol=0.1)


def test_choice(randomness_stream, index, choices, weights):
Expand Down
6 changes: 6 additions & 0 deletions tests/framework/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ def test_rate_to_probability():
assert np.isclose(prob, 0.00099950016662497809)


def test_very_high_rate_to_probability():
    """A rate of 10,000 converts to a probability that is close to 1."""
    very_high_rate = np.array([10_000])
    assert np.isclose(rate_to_probability(very_high_rate), 1.0)


def test_probability_to_rate():
prob = np.array([0.00099950016662497809])
rate = probability_to_rate(prob)
Expand Down

0 comments on commit 80cc034

Please sign in to comment.