Skip to content

Commit

Permalink
Merge pull request #280 from ihmeuw/collijk/refactor/reorganize-crn-code
Browse files Browse the repository at this point in the history
Clean up CRN implementation so it's clearer what's going on
  • Loading branch information
collijk committed Feb 25, 2023
2 parents 612f4aa + 5a6284f commit 7b71c27
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 60 deletions.
10 changes: 4 additions & 6 deletions src/vivarium/framework/randomness/index_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,13 +240,11 @@ def _shift(self, m: Union[float, pd.Series]) -> Union[int, pd.Series]:
return out.astype("int64")
return int(out)

def __getitem__(self, index: pd.Index) -> pd.Series:
if not self._use_crn:
return pd.Series(index, index=index)
if isinstance(index, pd.Index):
return self._map.loc[index]
def __getitem__(self, index: pd.Index) -> np.ndarray:
if self._use_crn:
return self._map.loc[index].values
else:
raise IndexError(index)
return index.values

def __len__(self) -> int:
    """Return the size of this object, as stored in ``self._size``."""
    return self._size
Expand Down
98 changes: 44 additions & 54 deletions src/vivarium/framework/randomness/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,15 +132,54 @@ def get_draw(self, index: pd.Index, additional_key: Any = None) -> pd.Series:
-------
pandas.Series
A series of random numbers indexed by the provided `pandas.Index`.
Note
----
This is the core of the CRN implementation, allowing for consistent use of random
numbers across simulations with multiple scenarios.
See Also
--------
https://en.wikipedia.org/wiki/Variance_reduction and
"Untangling Uncertainty with Common Random Numbers:
A Simulation Study; A. Flaxman et al., Summersim 2017"
"""
key = self._key(additional_key)
# Return a structured null value if an empty index is passed
if index.empty:
return pd.Series(index=index, dtype=float)

# Initialize a random_state with a seed based on the simulation clock, the
# decision_point this stream represents, and any additional user-supplied information.
# This is one pre-condition to reproducibility.
seed = get_hash(self._key(additional_key))
random_state = np.random.RandomState(seed=seed)

# Second, we need to sample a very large chunk of random numbers. The size of the
# index map is set at the simulation start and is at least 10x the size of the
# initial population. This gives us a consistently sampled block of uniformly
# distributed random numbers, irrespective of the size of the simulation population,
# which is important if there are scenarios that result in different population sizes
# through time.
sample_size = len(self.index_map)
raw_draws = random_state.random_sample(sample_size)

if self.initializes_crn_attributes:
draw = random(key, pd.Index(range(len(index))))
draw.index = index
# If we're initializing CRN attributes (i.e. attributes used to identify a
# simulant across multiple simulations), we can't use the index map yet since
# these people aren't registered in the index map yet. Instead, we use the draws
# from our random sample in order. This couples the initialization of CRN
# attributes to the time step on which they are initialized, meaning interventions
# that alter the entrance time of a simulant will break the CRN guarantees.
draws = pd.Series(raw_draws[: len(index)], index=index)
else:
draw = random(key, index, self.index_map)
# If we're not initializing CRN attributes, we can use the index map to get the
# correct draws for each simulant. This allows us to use the same CRN attributes
# across multiple simulations, even if the population size changes.
draw_index = self.index_map[index]
draws = pd.Series(raw_draws[draw_index], index=index)

return draw
return draws

def filter_for_rate(
self,
Expand Down Expand Up @@ -271,55 +310,6 @@ def __repr__(self) -> str:
)


def random(
    key: str,
    index: Union[pd.Index, pd.MultiIndex],
    index_map: IndexMap = None,
) -> pd.Series:
    """Draw uniformly distributed random numbers aligned to ``index``.

    The index passed in typically corresponds to a subset of rows in a
    `pandas.DataFrame` for which a probabilistic draw needs to be made.

    Parameters
    ----------
    key :
        A string that is hashed to seed the random number generator.
    index :
        The index of the returned series.
    index_map :
        Optional mapping from the provided index to integer positions in
        the array of raw draws. When it is omitted, or when the lookup
        raises ``IndexError`` or ``TypeError``, the values of ``index``
        are used directly as positions.

    Returns
    -------
    pandas.Series
        A series of random numbers indexed by the provided index. An empty
        float series is returned when ``index`` is empty.

    Notes
    -----
    A full pool of draws is generated on every call, even when only a few
    are needed, so that outcomes stay consistent across simulations.

    See Also
    --------
    https://en.wikipedia.org/wiki/Variance_reduction and
    "Untangling Uncertainty with Common Random Numbers: A Simulation Study",
    A. Flaxman et al., Summersim 2017.
    """
    if len(index) == 0:
        # Structured null value
        return pd.Series(index=index, dtype=float)

    rng = np.random.RandomState(seed=get_hash(key))

    if index_map is None:
        # Without a map, the pool must be large enough to be addressed by
        # the largest index value, and the index itself gives the positions.
        pool_size = index.max() + 1
        positions = index
    else:
        pool_size = len(index_map)
        try:
            positions = index_map[index]
        except (IndexError, TypeError):
            # The map could not translate this index; use it as-is.
            positions = index

    pool = rng.random_sample(pool_size)
    return pd.Series(pool[positions], index=index)


def _choice(
draws: pd.Series,
choices: Union[List, Tuple, np.ndarray, pd.Series],
Expand Down

0 comments on commit 7b71c27

Please sign in to comment.