Skip to content

Commit

Permalink
Bugfix/sbachmei/mic 4091 crn shuffling (#299)
Browse files Browse the repository at this point in the history
* join simulant index to index map
* use clock time for index map salt
* update doctest expected outputs
  • Loading branch information
stevebachmeier committed May 31, 2023
1 parent f55de30 commit afb9c76
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 21 deletions.
16 changes: 8 additions & 8 deletions docs/source/tutorials/exploration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -304,12 +304,12 @@ the population as a whole.
alive 100000
Name: count, dtype: int64
count 100000.000000
mean 0.500602
std 0.288434
min 0.000022
25% 0.251288
50% 0.499957
75% 0.749816
mean 0.499756
std 0.288412
min 0.000015
25% 0.251550
50% 0.497587
75% 0.749215
max 0.999978
Name: child_wasting_propensity, dtype: float64
lower_respiratory_infections
Expand All @@ -319,8 +319,8 @@ the population as a whole.
2021-12-31 12:00:00 100000
Name: count, dtype: int64
sex
Male 50162
Female 49838
Male 50185
Female 49815
Name: count, dtype: int64
tracked
True 100000
Expand Down
21 changes: 16 additions & 5 deletions src/vivarium/framework/randomness/index_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,17 @@ def __init__(self, key_columns: List[str] = None, size: int = 1_000_000):
self._map = None
self._size = size

def update(self, new_keys: pd.DataFrame) -> None:
def update(self, new_keys: pd.DataFrame, clock_time: pd.Timestamp) -> None:
"""Adds the new keys to the mapping.
Parameters
----------
new_keys
A pandas DataFrame indexed by the simulant index and columns corresponding to
the randomness system key columns.
clock_time
The simulation clock time. Used as the salt during hashing to
minimize inter-simulation collisions.
"""
if new_keys.empty or not self._use_crn:
Expand All @@ -49,10 +52,13 @@ def update(self, new_keys: pd.DataFrame) -> None:
if len(final_keys) != len(final_keys.unique()):
raise RandomnessError("Non-unique keys in index")

final_mapping = self._build_final_mapping(new_mapping_index)
final_mapping = self._build_final_mapping(new_mapping_index, clock_time)

# Tack on the simulant index to the front of the map.
final_mapping.index = final_mapping_index
final_mapping.index = final_mapping.index.join(final_mapping_index).reorder_levels(
[self.SIM_INDEX_COLUMN] + self._key_columns
)
final_mapping = final_mapping.sort_index(level=self.SIM_INDEX_COLUMN)
self._map = final_mapping

def _parse_new_keys(self, new_keys: pd.DataFrame) -> Tuple[pd.MultiIndex, pd.MultiIndex]:
Expand Down Expand Up @@ -84,7 +90,9 @@ def _parse_new_keys(self, new_keys: pd.DataFrame) -> Tuple[pd.MultiIndex, pd.Mul
final_mapping_index = self._map.index.append(new_mapping_index)
return new_mapping_index, final_mapping_index

def _build_final_mapping(self, new_mapping_index: pd.Index) -> pd.Series:
def _build_final_mapping(
self, new_mapping_index: pd.Index, clock_time: pd.Timestamp
) -> pd.Series:
"""Builds a new mapping between key columns and the randomness index from the
new mapping index and the existing map.
Expand All @@ -93,6 +101,9 @@ def _build_final_mapping(self, new_mapping_index: pd.Index) -> pd.Series:
new_mapping_index
An index with a level for the index assigned by the population system and
additional levels for the key columns associated with the simulant index.
clock_time
The simulation clock time. Used as the salt during hashing to
minimize inter-simulation collisions.
Returns
-------
Expand All @@ -102,7 +113,7 @@ def _build_final_mapping(self, new_mapping_index: pd.Index) -> pd.Series:
"""
new_key_index = new_mapping_index.droplevel(self.SIM_INDEX_COLUMN)
mapping_update = self._hash(new_key_index)
mapping_update = self._hash(new_key_index, salt=clock_time)
if self._map is None:
current_map = mapping_update
else:
Expand Down
2 changes: 1 addition & 1 deletion src/vivarium/framework/randomness/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def register_simulants(self, simulants: pd.DataFrame):
raise RandomnessError(
"The simulants dataframe does not have all specified key_columns."
)
self._key_mapping.update(simulants.loc[:, self._key_columns])
self._key_mapping.update(simulants.loc[:, self._key_columns], self._clock())

def __str__(self):
return "RandomnessManager()"
Expand Down
15 changes: 8 additions & 7 deletions tests/framework/randomness/test_index_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,9 @@ def test_hash_uniformity(map_size_and_hashed_values):
def index_map(mocker):
mock_index_map = IndexMap

def hash_mock(k, salt=0):
def hash_mock(k, salt):
seed = 123456
salt = IndexMap()._convert_to_ten_digit_int(pd.Series(salt, index=k))
rs = np.random.RandomState(seed=seed + salt)
return pd.Series(rs.randint(0, len(k) * 10, size=len(k)), index=k)

Expand All @@ -183,21 +184,21 @@ def test_update_empty_bad_keys(index_map):
keys = pd.DataFrame({"A": ["a"] * 10}, index=range(10))
m = index_map(key_columns=list(keys.columns))
with pytest.raises(RandomnessError):
m.update(keys)
m.update(keys, pd.to_datetime("2023-01-01"))


def test_update_nonempty_bad_keys(index_map):
keys = generate_keys(1000)
m = index_map(key_columns=list(keys.columns))
m.update(keys)
m.update(keys, pd.to_datetime("2023-01-01"))
with pytest.raises(RandomnessError):
m.update(keys)
m.update(keys, pd.to_datetime("2023-01-01"))


def test_update_empty_good_keys(index_map):
keys = generate_keys(1000)
m = index_map(key_columns=list(keys.columns))
m.update(keys)
m.update(keys, pd.to_datetime("2023-01-01"))
key_index = keys.set_index(list(keys.columns)).index
assert len(m._map) == len(keys), "All keys not in mapping"
assert (
Expand All @@ -211,8 +212,8 @@ def test_update_nonempty_good_keys(index_map):
m = index_map(key_columns=list(keys.columns))
keys1, keys2 = keys[:1000], keys[1000:]

m.update(keys1)
m.update(keys2)
m.update(keys1, pd.to_datetime("2023-01-01"))
m.update(keys2, pd.to_datetime("2023-01-01"))

key_index = keys.set_index(list(keys.columns)).index
assert len(m._map) == len(keys), "All keys not in mapping"
Expand Down

0 comments on commit afb9c76

Please sign in to comment.