Merge pull request #273 from ihmeuw/collijk/refactor/index-map-interface
Be clearer about the index map interface
collijk committed Feb 22, 2023
2 parents cfc1243 + 46aa312 commit 24ff7c6
Showing 6 changed files with 76 additions and 69 deletions.
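
The refactor narrows the IndexMap public surface to the constructor, update(), __getitem__(), and __len__(), renames the hashing helpers to private methods, and moves the decision about common random numbers (CRN) into the map itself. A minimal usage sketch, assuming only the signatures visible in the diff below; the keys here are invented for illustration:

import pandas as pd

from vivarium.framework.randomness.index_map import IndexMap

# Keys must be built from hashable column types: datetimes, integers, or floats.
keys = pd.MultiIndex.from_tuples(
    [(27.5, pd.Timestamp("1995-06-01")), (31.2, pd.Timestamp("1991-11-15"))],
    names=["age", "entrance_time"],
)

crn_map = IndexMap(use_crn=True, size=1_000_000)
crn_map.update(keys)       # hashes each key to a stable position in the map
positions = crn_map[keys]  # the same keys always come back with the same positions

# With CRN disabled the map is a passthrough: update() is a no-op and
# __getitem__ simply echoes the index back as a Series.
no_crn_map = IndexMap(use_crn=False)
no_crn_map.update(keys)
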
43 changes: 22 additions & 21 deletions src/vivarium/framework/randomness/index_map.py
@@ -9,7 +9,6 @@
positional index within a stream of seeded random numbers.
"""

import datetime
from typing import Union

@@ -24,9 +23,10 @@ class IndexMap:

TEN_DIGIT_MODULUS = 10_000_000_000

def __init__(self, map_size=1_000_000):
def __init__(self, use_crn: bool = True, size: int = 1_000_000):
self._use_crn = use_crn
self._map = pd.Series(dtype=float)
self.map_size = map_size
self._size = size

def update(self, new_keys: pd.Index) -> None:
"""Adds the new keys to the mapping.
@@ -37,13 +37,12 @@ def update(self, new_keys: pd.Index) -> None:
The new index to hash.
"""

if new_keys.empty:
if new_keys.empty or not self._use_crn:
return # Nothing to do
elif not self._map.index.intersection(new_keys).empty:
raise KeyError("Non-unique keys in index")

mapping_update = self.hash_(new_keys)
mapping_update = self._hash(new_keys)
if self._map.empty:
self._map = mapping_update.drop_duplicates()
else:
@@ -52,12 +51,12 @@
collisions = mapping_update.index.difference(self._map.index)
salt = 1
while not collisions.empty:
mapping_update = self.hash_(collisions, salt)
mapping_update = self._hash(collisions, salt)
self._map = pd.concat([self._map, mapping_update]).drop_duplicates()
collisions = mapping_update.index.difference(self._map.index)
salt += 1

def hash_(self, keys: pd.Index, salt: int = 0) -> pd.Series:
def _hash(self, keys: pd.Index, salt: int = 0) -> pd.Series:
"""Hashes the index into an integer index in the range [0, self.stride]
Parameters
@@ -78,10 +77,10 @@ def hash_(self, keys: pd.Index, salt: int = 0) -> pd.Series:
"""
key_frame = keys.to_frame()
new_map = pd.Series(0, index=keys)
salt = self.convert_to_ten_digit_int(pd.Series(salt, index=keys))
salt = self._convert_to_ten_digit_int(pd.Series(salt, index=keys))

for i, column_name in enumerate(key_frame.columns):
column = self.convert_to_ten_digit_int(key_frame[column_name])
column = self._convert_to_ten_digit_int(key_frame[column_name])

primes = [2, 3, 5, 7, 11, 13, 17, 19, 23, 27]
out = pd.Series(1, index=column.index)
@@ -90,12 +89,12 @@
# to modding out by 2**64. Since it's much much larger than
# our map size the amount of additional periodicity this
# introduces is pretty trivial.
out *= np.power(p, self.digit(column, idx))
out *= np.power(p, self._digit(column, idx))
new_map += out + salt

return new_map % self.map_size
return new_map % len(self)

def convert_to_ten_digit_int(self, column: pd.Series) -> pd.Series:
def _convert_to_ten_digit_int(self, column: pd.Series) -> pd.Series:
"""Converts a column of datetimes, integers, or floats into a column
of 10 digit integers.
@@ -117,15 +116,15 @@ def convert_to_ten_digit_int(self, column: pd.Series) -> pd.Series:
"""
if isinstance(column.iloc[0], datetime.datetime):
column = self.clip_to_seconds(column.view(np.int64))
column = self._clip_to_seconds(column.view(np.int64))
elif np.issubdtype(column.iloc[0], np.integer):
if not len(column >= 0) == len(column):
raise RandomnessError(
"Values in integer columns must be greater than or equal to zero."
)
column = self.spread(column)
column = self._spread(column)
elif np.issubdtype(column.iloc[0], np.floating):
column = self.shift(column)
column = self._shift(column)
else:
raise RandomnessError(
f"Unhashable column type {type(column.iloc[0])}. "
@@ -134,34 +133,36 @@ def convert_to_ten_digit_int(self, column: pd.Series) -> pd.Series:
return column

@staticmethod
def digit(m: Union[int, pd.Series], n: int) -> Union[int, pd.Series]:
def _digit(m: Union[int, pd.Series], n: int) -> Union[int, pd.Series]:
"""Returns the nth digit of each number in m."""
return (m // (10**n)) % 10

@staticmethod
def clip_to_seconds(m: Union[int, pd.Series]) -> Union[int, pd.Series]:
def _clip_to_seconds(m: Union[int, pd.Series]) -> Union[int, pd.Series]:
"""Clips UTC datetime in nanoseconds to seconds."""
return m // pd.Timedelta(1, unit="s").value

def spread(self, m: Union[int, pd.Series]) -> Union[int, pd.Series]:
def _spread(self, m: Union[int, pd.Series]) -> Union[int, pd.Series]:
"""Spreads out integer values to give smaller values more weight."""
return (m * 111_111) % self.TEN_DIGIT_MODULUS

def shift(self, m: Union[float, pd.Series]) -> Union[int, pd.Series]:
def _shift(self, m: Union[float, pd.Series]) -> Union[int, pd.Series]:
"""Shifts floats so that the first 10 decimal digits are significant."""
out = m % 1 * self.TEN_DIGIT_MODULUS // 1
if isinstance(out, pd.Series):
return out.astype("int64")
return int(out)

def __getitem__(self, index: pd.Index) -> pd.Series:
if not self._use_crn:
return pd.Series(index, index=index)
if isinstance(index, (pd.Index, pd.MultiIndex)):
return self._map[index]
else:
raise IndexError(index)

def __len__(self) -> int:
return len(self._map)
return self._size

def __repr__(self) -> str:
return "IndexMap({})".format("\n ".join(repr(self._map).split("\n")))
7 changes: 5 additions & 2 deletions src/vivarium/framework/randomness/manager.py
@@ -26,7 +26,7 @@ def __init__(self):
self._seed = None
self._clock = None
self._key_columns = None
self._key_mapping = IndexMap()
self._key_mapping = None
self._decision_points = dict()

@property
@@ -39,9 +39,12 @@ def setup(self, builder):
self._seed += str(builder.configuration.randomness.additional_seed)
self._clock = builder.time.clock()
self._key_columns = builder.configuration.randomness.key_columns

use_crn = bool(self._key_columns)
map_size = builder.configuration.randomness.map_size
pop_size = builder.configuration.population.population_size
self._key_mapping.map_size = max(map_size, 10 * pop_size)
map_size = max(map_size, 10 * pop_size)
self._key_mapping = IndexMap(use_crn, map_size)

self.resources = builder.resources
self._add_constraint = builder.lifecycle.add_constraint
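
In short, the manager now decides at setup time whether common random numbers apply (only when key_columns are configured) and builds an appropriately sized IndexMap once, instead of mutating map_size on an existing map. A sketch of that wiring with invented configuration values; only the attribute names come from the diff above:

from vivarium.framework.randomness.index_map import IndexMap

key_columns = ["entrance_time", "age", "sex"]  # randomness.key_columns; an empty list disables CRN
configured_map_size = 1_000_000                # randomness.map_size
population_size = 100_000                      # population.population_size

use_crn = bool(key_columns)
map_size = max(configured_map_size, 10 * population_size)
key_mapping = IndexMap(use_crn, map_size)
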
20 changes: 10 additions & 10 deletions src/vivarium/framework/randomness/stream.py
@@ -88,7 +88,7 @@ def __init__(
key: str,
clock: Callable,
seed: Any,
index_map: IndexMap = None,
index_map: IndexMap,
for_initialization: bool = False,
):
self.key = key
@@ -305,7 +305,7 @@ def random(
# See Also:
# 1. https://en.wikipedia.org/wiki/Variance_reduction
# 2. Untangling Uncertainty with Common Random Numbers: A Simulation Study; A.Flaxman, et. al., Summersim 2017
sample_size = index_map.map_size if index_map is not None else index.max() + 1
sample_size = len(index_map) if index_map is not None else index.max() + 1
try:
draw_index = index_map[index]
except (IndexError, TypeError):
@@ -410,21 +410,21 @@ def _set_residual_probability(p: np.ndarray) -> np.ndarray:
residual_mask = p == RESIDUAL_CHOICE
if residual_mask.any(): # I.E. if we have any placeholders.
if np.any(np.sum(residual_mask, axis=1) - 1):
raise RandomnessError(
"More than one residual choice supplied for a single set of weights. Weights: {}.".format(
p
)
msg = (
"More than one residual choice supplied for a single "
f"set of weights. Weights: {p}."
)
raise RandomnessError(msg)

p[residual_mask] = 0
residual_p = 1 - np.sum(p, axis=1) # Probabilities sum to 1.

if np.any(residual_p < 0): # We got un-normalized probability weights.
raise RandomnessError(
"Residual choice supplied with weights that summed to more than 1. Weights: {}.".format(
p
)
msg = (
"Residual choice supplied with weights that summed to more than 1. "
f"Weights: {p}."
)
raise RandomnessError(msg)

p[residual_mask] = residual_p
return p
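
The second stream.py hunk only rewraps the error messages, but the surrounding logic is worth a gloss: exactly one weight per row may be the RESIDUAL_CHOICE placeholder, and it absorbs whatever probability the explicit weights leave over. A standalone numpy illustration of that residual-weight bookkeeping, not the library code; the sentinel value here is invented:

import numpy as np

RESIDUAL = -1.0  # stand-in sentinel for this sketch only

p = np.array([
    [0.2, 0.3, RESIDUAL],  # residual weight becomes 0.5
    [0.6, 0.1, RESIDUAL],  # residual weight becomes 0.3
])

residual_mask = p == RESIDUAL
p[residual_mask] = 0.0
residual_p = 1 - np.sum(p, axis=1)
if np.any(residual_p < 0):  # the explicit weights already exceed 1
    raise ValueError("Residual choice supplied with weights that summed to more than 1.")
p[residual_mask] = residual_p
print(p)  # every row now sums to 1
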
40 changes: 20 additions & 20 deletions tests/framework/randomness/test_index_map.py
@@ -57,28 +57,28 @@ def id_fun(param):
def map_size_and_hashed_values(request):
keys = generate_keys(*request.param)
m = IndexMap()
return m.map_size, m.hash_(keys)
return len(m), m._hash(keys)


def test_digit_scalar():
m = IndexMap()
k = 123456789
for i in range(10):
assert m.digit(k, i) == 10 - (i + 1)
assert m._digit(k, i) == 10 - (i + 1)


def test_digit_series():
m = IndexMap()
k = pd.Series(123456789, index=range(10000))
for i in range(10):
assert len(m.digit(k, i).unique()) == 1
assert m.digit(k, i)[0] == 10 - (i + 1)
assert len(m._digit(k, i).unique()) == 1
assert m._digit(k, i)[0] == 10 - (i + 1)


def test_clip_to_seconds_scalar():
m = IndexMap()
k = pd.to_datetime("2010-01-25 06:25:31.123456789")
assert m.clip_to_seconds(k.value) == int(str(k.value)[:10])
assert m._clip_to_seconds(k.value) == int(str(k.value)[:10])


def test_clip_to_seconds_series():
@@ -89,32 +89,32 @@
.to_series()
.view(np.int64)
)
assert len(m.clip_to_seconds(k).unique()) == 1
assert m.clip_to_seconds(k).unique()[0] == stamp
assert len(m._clip_to_seconds(k).unique()) == 1
assert m._clip_to_seconds(k).unique()[0] == stamp


def test_spread_scalar():
m = IndexMap()
assert m.spread(1234567890) == 4072825790
assert m._spread(1234567890) == 4072825790


def test_spread_series():
m = IndexMap()
s = pd.Series(1234567890, index=range(10000))
assert len(m.spread(s).unique()) == 1
assert m.spread(s).unique()[0] == 4072825790
assert len(m._spread(s).unique()) == 1
assert m._spread(s).unique()[0] == 4072825790


def test_shift_scalar():
m = IndexMap()
assert m.shift(1.1234567890) == 1234567890
assert m._shift(1.1234567890) == 1234567890


def test_shift_series():
m = IndexMap()
s = pd.Series(1.1234567890, index=range(10000))
assert len(m.shift(s).unique()) == 1
assert m.shift(s).unique()[0] == 1234567890
assert len(m._shift(s).unique()) == 1
assert m._shift(s).unique()[0] == 1234567890


def test_convert_to_ten_digit_int():
@@ -127,11 +127,11 @@ def test_convert_to_ten_digit_int():
float_col = pd.Series(1.1234567890, index=range(10000))
bad_col = pd.Series("a", index=range(10000))

assert m.convert_to_ten_digit_int(datetime_col).unique()[0] == v
assert m.convert_to_ten_digit_int(int_col).unique()[0] == 4072825790
assert m.convert_to_ten_digit_int(float_col).unique()[0] == v
assert m._convert_to_ten_digit_int(datetime_col).unique()[0] == v
assert m._convert_to_ten_digit_int(int_col).unique()[0] == 4072825790
assert m._convert_to_ten_digit_int(float_col).unique()[0] == v
with pytest.raises(RandomnessError):
m.convert_to_ten_digit_int(bad_col)
m._convert_to_ten_digit_int(bad_col)


@pytest.mark.skip("This fails because the hash needs work")
@@ -173,9 +173,9 @@ def hash_mock(k, salt=0):
rs = np.random.RandomState(seed=seed + salt)
return pd.Series(rs.randint(0, len(k) * 10, size=len(k)), index=k)

mocker.patch.object(m, "hash_", side_effect=hash_mock)
mocker.patch.object(m, "_hash", side_effect=hash_mock)
m.update(keys)
assert len(m) == len(keys), "All keys not in mapping"
assert len(m._map) == len(keys), "All keys not in mapping"
assert m._map.index.difference(keys).empty, "All keys not in mapping"
assert len(m._map.unique()) == len(keys), "Duplicate values in mapping"

@@ -185,7 +185,7 @@

new_unique_keys = generate_keys(1000).difference(keys)
m.update(new_unique_keys)
assert len(m) == len(keys) + len(new_unique_keys), "All keys not in mapping"
assert len(m._map) == len(keys) + len(new_unique_keys), "All keys not in mapping"
assert m._map.index.difference(
keys.union(new_unique_keys)
).empty, "All keys not in mapping"
2 changes: 2 additions & 0 deletions tests/framework/randomness/test_manager.py
@@ -1,6 +1,7 @@
import pandas as pd
import pytest

from vivarium.framework.randomness.index_map import IndexMap
from vivarium.framework.randomness.manager import RandomnessError, RandomnessManager
from vivarium.framework.randomness.stream import get_hash

@@ -34,6 +35,7 @@ def test_RandomnessManager_register_simulants():
rm._seed = seed
rm._clock = mock_clock
rm._key_columns = ["age", "sex"]
rm._key_mapping = IndexMap()

bad_df = pd.DataFrame({"age": range(10), "not_sex": [1] * 5 + [2] * 5})
with pytest.raises(RandomnessError):