Skip to content

Commit

Permalink
Revert get draws kwargs (#326)
Browse files Browse the repository at this point in the history
    Category: bugfix
    JIRA issue: MIC-4615

Changes and notes

The get_draws_kwargs argument was erroneously merged into main in #312 but should be removed because it's never used.
  • Loading branch information
hussain-jafari committed Jan 11, 2024
1 parent 21a7e8d commit 74a5fd3
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 14 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
**4.1.4 - 01/11/24**

- Remove erroneously merged get_draws_kwargs argument

**4.1.3 - 01/09/24**

- Update PyPI to 2FA with trusted publisher
Expand Down
3 changes: 1 addition & 2 deletions docs/source/tutorials/pulling_data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -379,8 +379,7 @@ To use:
'parameters': {
'entity': ModelableEntity,
'measure': str,
'location': str,
'get_draws_kwargs': inspect._empty,
'location': str,
},
'return': pd.DataFrame, },
get_population_structure: {
Expand Down
8 changes: 4 additions & 4 deletions src/vivarium_inputs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from vivarium_inputs.mapping_extension import AlternativeRiskFactor, HealthcareEntity


def get_data(entity, measure: str, location: Union[str, int], **get_draws_kwargs):
def get_data(entity, measure: str, location: Union[str, int]):
measure_handlers = {
# Cause-like measures
"incidence_rate": (get_incidence_rate, ("cause", "sequela")),
Expand Down Expand Up @@ -78,7 +78,7 @@ def get_data(entity, measure: str, location: Union[str, int], **get_draws_kwargs
location_id = (
utility_data.get_location_id(location) if isinstance(location, str) else location
)
data = handler(entity, location_id, **get_draws_kwargs)
data = handler(entity, location_id)

if measure in [
"structure",
Expand Down Expand Up @@ -217,9 +217,9 @@ def get_deaths(entity: Cause, location_id: int) -> pd.DataFrame:


def get_exposure(
entity: Union[RiskFactor, AlternativeRiskFactor], location_id: int, **get_draws_kwargs
entity: Union[RiskFactor, AlternativeRiskFactor], location_id: int
) -> pd.DataFrame:
data = extract.extract_data(entity, "exposure", location_id, **get_draws_kwargs)
data = extract.extract_data(entity, "exposure", location_id)
data = data.drop("modelable_entity_id", "columns")

if entity.name in EXTRA_RESIDUAL_CATEGORY:
Expand Down
8 changes: 4 additions & 4 deletions src/vivarium_inputs/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


def extract_data(
entity, measure: str, location_id: int, validate: bool = True, **get_draws_kwargs
entity, measure: str, location_id: int, validate: bool = True
) -> Union[pd.Series, pd.DataFrame]:
"""Check metadata for the requested entity-measure pair. Pull raw data from
GBD. The only filtering that occurs is by applicable measure id, metric id,
Expand Down Expand Up @@ -94,7 +94,7 @@ def extract_data(

try:
main_extractor, additional_extractors = extractors[measure]
data = main_extractor(entity, location_id, **get_draws_kwargs)
data = main_extractor(entity, location_id)
except (
ValueError,
AssertionError,
Expand Down Expand Up @@ -176,10 +176,10 @@ def extract_deaths(entity: Cause, location_id: int) -> pd.DataFrame:


def extract_exposure(
entity: Union[RiskFactor, AlternativeRiskFactor], location_id: int, **get_draws_kwargs
entity: Union[RiskFactor, AlternativeRiskFactor], location_id: int
) -> pd.DataFrame:
if entity.kind == "risk_factor":
data = gbd.get_exposure(entity.gbd_id, location_id, **get_draws_kwargs)
data = gbd.get_exposure(entity.gbd_id, location_id)
allowable_measures = [
MEASURES["Proportion"],
MEASURES["Continuous"],
Expand Down
6 changes: 2 additions & 4 deletions src/vivarium_inputs/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@
from vivarium_inputs.globals import Population


def get_measure(
entity: ModelableEntity, measure: str, location: str, **get_draws_kwargs
) -> pd.DataFrame:
def get_measure(entity: ModelableEntity, measure: str, location: str) -> pd.DataFrame:
"""Pull GBD data for measure and entity and prep for simulation input,
including scrubbing all GBD conventions to replace IDs with meaningful
values or ranges and expanding over all demographic dimensions. To pull data
Expand Down Expand Up @@ -55,7 +53,7 @@ def get_measure(
Dataframe standardized to the format expected by `vivarium` simulations.
"""
data = core.get_data(entity, measure, location, **get_draws_kwargs)
data = core.get_data(entity, measure, location)
data = utilities.scrub_gbd_conventions(data, location)
validation.validate_for_simulation(data, entity, measure, location)
data = utilities.split_interval(data, interval_column="age", split_column_prefix="age")
Expand Down

0 comments on commit 74a5fd3

Please sign in to comment.