Skip to content

Commit

Permalink
Merge pull request #275 from ihmeuw/collijk/refactor/reduce-argument-chaining-in-randomness
Browse files Browse the repository at this point in the history

Remove filter_for_probability and reduce argument chaining in choice
  • Loading branch information
collijk committed Feb 22, 2023
2 parents adcca78 + c0b0b49 commit 89bcc32
Showing 1 changed file with 25 additions and 74 deletions.
99 changes: 25 additions & 74 deletions src/vivarium/framework/randomness/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def filter_for_rate(
"""Decide an event outcome for each individual from rates.
Given a population or its index and an array of associated rates for
some event to happen, we create and return the sub-population for whom
some event to happen, we create and return the subpopulation for whom
the event occurred.
Parameters
Expand All @@ -172,7 +172,7 @@ def filter_for_rate(
Returns
-------
pandas.core.generic.PandasObject
The sub-population of the simulants for whom the event occurred.
The subpopulation of the simulants for whom the event occurred.
The return type will be the same as type(population)
"""
Expand All @@ -189,7 +189,7 @@ def filter_for_probability(
"""Decide an outcome for each individual from probabilities.
Given a population or its index and an array of associated probabilities
for some event to happen, we create and return the sub-population for
for some event to happen, we create and return the subpopulation for
whom the event occurred.
Parameters
Expand All @@ -208,13 +208,17 @@ def filter_for_probability(
Returns
-------
pandas.core.generic.PandasObject
The sub-population of the simulants for whom the event occurred.
The subpopulation of the simulants for whom the event occurred.
The return type will be the same as type(population)
"""
return filter_for_probability(
self._key(additional_key), population, probability, self.index_map
)
if population.empty:
return population

index = population if isinstance(population, pd.Index) else population.index
draws = self.get_draw(index, additional_key)
mask = np.array(draws < probability)
return population[mask]

def choice(
self,
Expand Down Expand Up @@ -259,7 +263,8 @@ def choice(
more than one reference to `RESIDUAL_CHOICE`.
"""
return choice(self._key(additional_key), index, choices, p, self.index_map)
draws = self.get_draw(index, additional_key)
return _choice(draws, choices, p)

def __repr__(self) -> str:
return "RandomnessStream(key={!r}, clock={!r}, seed={!r})".format(
Expand Down Expand Up @@ -316,12 +321,10 @@ def random(
return pd.Series(index=index, dtype=float) # Structured null value


def choice(
key: str,
index: Union[pd.Index, pd.MultiIndex],
def _choice(
draws: pd.Series,
choices: Union[List, Tuple, np.ndarray, pd.Series],
p: Union[List, Tuple, np.ndarray, pd.Series] = None,
index_map: IndexMap = None,
) -> pd.Series:
"""Decides between a weighted or unweighted set of choices.
Expand All @@ -331,24 +334,19 @@ def choice(
Parameters
----------
key
A string used to create a seed for the random number generation.
index
An index whose length is the number of random draws made
and which indexes the returned `pandas.Series`.
draws
A uniformly distributed random number for every simulant to make
a choice for.
choices
A set of options to choose from.
A set of options to choose from. Choices must be the same for every
simulant.
p
The relative weights of the choices. Can be either a 1-d array of
the same length as `choices` or a 2-d array with `len(index)` rows
the same length as `choices` or a 2-d array with `len(draws)` rows
and `len(choices)` columns. In the 1-d case, the same set of weights
are used to decide among the choices for every item in the `index`.
In the 2-d case, each row in `p` contains a separate set of weights
for every item in the `index`.
index_map
A mapping between the provided index (which may contain ints, floats,
datetimes or any arbitrary combination of them) and an integer index
into the random number array.
Returns
-------
Expand All @@ -365,19 +363,17 @@ def choice(
"""
# Convert p to normalized probabilities broadcasted over index.
p = (
_set_residual_probability(_normalize_shape(p, index))
_set_residual_probability(_normalize_shape(p, draws.index))
if p is not None
else np.ones((len(index), len(choices)))
else np.ones((len(draws.index), len(choices)))
)
p = p / p.sum(axis=1, keepdims=True)

draw = random(key, index, index_map)

p_bins = np.cumsum(p, axis=1)
# Use the random draw to make a choice for every row in index.
choice_index = (draw.values[np.newaxis].T > p_bins).sum(axis=1)
choice_index = (draws.values[np.newaxis].T > p_bins).sum(axis=1)

return pd.Series(np.array(choices)[choice_index], index=index)
return pd.Series(np.array(choices)[choice_index], index=draws.index)


def _normalize_shape(
Expand Down Expand Up @@ -428,48 +424,3 @@ def _set_residual_probability(p: np.ndarray) -> np.ndarray:

p[residual_mask] = residual_p
return p


def filter_for_probability(
    key: str,
    population: Union[pd.DataFrame, pd.Series, pd.Index],
    probability: Union[List, Tuple, np.ndarray, pd.Series],
    index_map: IndexMap = None,
) -> Union[pd.DataFrame, pd.Series, pd.Index]:
    """Select the members of a population for whom an event occurs.

    The values returned by ``random`` for the population's index are
    compared against ``probability``; the entries whose value is strictly
    less than the corresponding probability are kept.

    Parameters
    ----------
    key
        A string used to create a seed for the random number generation.
    population
        A view on the simulants for which we are determining the
        outcome of an event. May be the simulants themselves or just
        their index.
    probability
        A 1d list of probabilities of the event under consideration
        occurring which corresponds (i.e. `len(population) == len(probability)`)
        to the population array passed in.
    index_map
        A mapping between the provided index (which may contain ints, floats,
        datetimes or any arbitrary combination of them) and an integer index
        into the random number array.

    Returns
    -------
    Union[pandas.DataFrame, pandas.Series, pandas.Index]
        The subpopulation of the simulants for whom the event occurred.
        The return type will be the same as type(population). An empty
        population is returned unchanged, without requesting any draws.

    """
    if not population.empty:
        # A bare index is used directly; otherwise use the object's index.
        if isinstance(population, pd.Index):
            simulants = population
        else:
            simulants = population.index

        draws = random(key, simulants, index_map)
        # Convert to a plain boolean array so the selection below is positional.
        occurred = np.array(draws < probability)
        return population[occurred]

    return population

0 comments on commit 89bcc32

Please sign in to comment.