Skip to content

Commit

Permalink
Merge pull request #278 from ihmeuw/collijk/bugfix/index-map-loc-bug
Browse files Browse the repository at this point in the history
Rework index map to add introspectability and fix outstanding bug
  • Loading branch information
collijk committed Feb 23, 2023
2 parents 32cb407 + 4c59742 commit ba47f86
Show file tree
Hide file tree
Showing 9 changed files with 168 additions and 65 deletions.
119 changes: 102 additions & 17 deletions src/vivarium/framework/randomness/index_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"""
import datetime
from typing import Union
from typing import List, Tuple, Union

import numpy as np
import pandas as pd
Expand All @@ -21,42 +21,127 @@
class IndexMap:
"""A key-index mapping with a vectorized hash and vectorized lookups."""

SIM_INDEX_COLUMN = "simulant_index"
TEN_DIGIT_MODULUS = 10_000_000_000

def __init__(self, use_crn: bool = True, size: int = 1_000_000):
self._use_crn = use_crn
self._map = pd.Series(dtype=int)
def __init__(self, key_columns: List[str] = None, size: int = 1_000_000):
self._use_crn = bool(key_columns)
self._key_columns = key_columns
self._map = None
self._size = size

def update(self, new_keys: pd.Index) -> None:
def update(self, new_keys: pd.DataFrame) -> None:
"""Adds the new keys to the mapping.
Parameters
----------
new_keys
The new index to hash.
A pandas DataFrame indexed by the simulant index and columns corresponding to
the randomness system key columns.
"""
if new_keys.empty or not self._use_crn:
return # Nothing to do

new_index = self._map.index.append(new_keys)
if len(new_index) != len(new_index.unique()):
new_mapping_index, final_mapping_index = self._parse_new_keys(new_keys)

final_keys = final_mapping_index.droplevel(self.SIM_INDEX_COLUMN)
if len(final_keys) != len(final_keys.unique()):
raise RandomnessError("Non-unique keys in index")

mapping_update = self._hash(new_keys)
if self._map.empty:
self._map = mapping_update.drop_duplicates()
final_mapping = self._build_final_mapping(new_mapping_index)

# Tack on the simulant index to the front of the map.
final_mapping.index = final_mapping_index
self._map = final_mapping

def _parse_new_keys(self, new_keys: pd.DataFrame) -> Tuple[pd.MultiIndex, pd.MultiIndex]:
"""Parses raw new keys into the mapping index.
Parameters
----------
new_keys
A pandas DataFrame indexed by the simulant index and columns corresponding to
the randomness system key columns.
Returns
-------
Tuple[pd.MultiIndex, pd.MultiIndex]
A tuple of the new mapping index and the final mapping index. Both are pandas
indices with a level for the index assigned by the population system and
additional levels for the key columns associated with the simulant index. The
new mapping index contains only the values for the new keys and the final mapping
combines the existing mapping and the new mapping index.
"""
keys = new_keys.copy()
keys.index.name = self.SIM_INDEX_COLUMN
new_mapping_index = keys.set_index(self._key_columns, append=True).index

if self._map is None:
final_mapping_index = new_mapping_index
else:
self._map = pd.concat([self._map, mapping_update]).drop_duplicates()
final_mapping_index = self._map.index.append(new_mapping_index)
return new_mapping_index, final_mapping_index

def _build_final_mapping(self, new_mapping_index: pd.Index) -> pd.Series:
"""Builds a new mapping between key columns and the randomness index from the
new mapping index and the existing map.
collisions = mapping_update.index.difference(self._map.index)
Parameters
----------
new_mapping_index
An index with a level for the index assigned by the population system and
additional levels for the key columns associated with the simulant index.
Returns
-------
pd.Series
The new mapping incorporating the updates from the new mapping index and
resolving collisions.
"""
new_key_index = new_mapping_index.droplevel(self.SIM_INDEX_COLUMN)
mapping_update = self._hash(new_key_index)
if self._map is None:
current_map = mapping_update
else:
old_map = self._map.droplevel(self.SIM_INDEX_COLUMN)
current_map = pd.concat([old_map, mapping_update])

return self._resolve_collisions(new_key_index, current_map)

def _resolve_collisions(
self,
new_key_index: pd.MultiIndex,
current_mapping: pd.Series,
) -> pd.Series:
"""Resolves collisions in the new mapping by perturbing the hash.
Parameters
----------
new_key_index
The index of new key attributes to hash.
current_mapping
The new mapping incorporating the updates from the new mapping index with
collisions unresolved.
Returns
-------
pd.Series
The new mapping incorporating the updates from the new mapping index and
resolving collisions.
"""
current_mapping = current_mapping.drop_duplicates()
collisions = new_key_index.difference(current_mapping.index)
salt = 1
while not collisions.empty:
mapping_update = self._hash(collisions, salt)
self._map = pd.concat([self._map, mapping_update]).drop_duplicates()
collisions = mapping_update.index.difference(self._map.index)
current_mapping = pd.concat([current_mapping, mapping_update]).drop_duplicates()
collisions = mapping_update.index.difference(current_mapping.index)
salt += 1
return current_mapping

def _hash(self, keys: pd.Index, salt: int = 0) -> pd.Series:
"""Hashes the index into an integer index in the range [0, self.stride]
Expand Down Expand Up @@ -158,8 +243,8 @@ def _shift(self, m: Union[float, pd.Series]) -> Union[int, pd.Series]:
def __getitem__(self, index: pd.Index) -> pd.Series:
if not self._use_crn:
return pd.Series(index, index=index)
if isinstance(index, (pd.Index, pd.MultiIndex)):
return self._map[index]
if isinstance(index, pd.Index):
return self._map.loc[index]
else:
raise IndexError(index)

Expand Down
5 changes: 2 additions & 3 deletions src/vivarium/framework/randomness/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,10 @@ def setup(self, builder):
self._clock = builder.time.clock()
self._key_columns = builder.configuration.randomness.key_columns

use_crn = bool(self._key_columns)
map_size = builder.configuration.randomness.map_size
pop_size = builder.configuration.population.population_size
map_size = max(map_size, 10 * pop_size)
self._key_mapping = IndexMap(use_crn, map_size)
self._key_mapping = IndexMap(self._key_columns, map_size)

self.resources = builder.resources
self._add_constraint = builder.lifecycle.add_constraint
Expand Down Expand Up @@ -168,7 +167,7 @@ def register_simulants(self, simulants: pd.DataFrame):
raise RandomnessError(
"The simulants dataframe does not have all specified key_columns."
)
self._key_mapping.update(simulants.set_index(self._key_columns).index)
self._key_mapping.update(simulants.loc[:, self._key_columns])

def __str__(self):
return "RandomnessManager()"
Expand Down
7 changes: 3 additions & 4 deletions src/vivarium/framework/randomness/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,12 @@ def get_draw(self, index: pd.Index, additional_key: Any = None) -> pd.Series:
pandas.Series
A series of random numbers indexed by the provided `pandas.Index`.
"""
key = self._key(additional_key)
if self.initializes_crn_attributes:
draw = random(
self._key(additional_key), pd.Index(range(len(index))), self.index_map
)
draw = random(key, pd.Index(range(len(index))))
draw.index = index
else:
draw = random(self._key(additional_key), index, self.index_map)
draw = random(key, index, self.index_map)

return draw

Expand Down
4 changes: 2 additions & 2 deletions src/vivarium/testing_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,5 +215,5 @@ def reset_mocks(mocks):
mock.reset_mock()


def metadata(file_path):
return {"layer": "override", "source": str(Path(file_path).resolve())}
def metadata(file_path, layer="override"):
    """Build a configuration-source metadata dict for *file_path*.

    Parameters
    ----------
    file_path
        Path-like pointing at the configuration source file.
    layer
        Name of the configuration layer the source applies to.

    Returns
    -------
    dict
        Mapping with the layer name and the absolute source path as a string.
    """
    resolved_source = Path(file_path).resolve()
    return {"layer": layer, "source": str(resolved_source)}
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def base_config():
},
"randomness": {"key_columns": ["entrance_time", "age"]},
},
**metadata(__file__),
**metadata(__file__, layer="model_override"),
)
return config

Expand Down
16 changes: 10 additions & 6 deletions tests/framework/randomness/test_crn.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,12 @@ def test_multi_sim_reproducibility_with_different_pop_growth(with_crn, pop_class
assert_frame_equal(pop1.loc[overlap], pop2.loc[overlap])


class BrokenPopulation(BasePopulation):
"""CRN system falls over if the first CRN attribute is an int or float."""
class UnBrokenPopulation(BasePopulation):
"""CRN system used to fall over if the first CRN attribute is an int or float.
This is now a regression testing class.
"""

def on_initialize_simulants(self, pop_data: SimulantData):
crn_attr = (1_000_000 * self.randomness_init.get_draw(index=pop_data.index)).astype(
Expand All @@ -280,20 +284,20 @@ def on_initialize_simulants(self, pop_data: SimulantData):
@pytest.mark.parametrize(
"with_crn, sims_to_add",
[
pytest.param(True, cycle([0]), marks=pytest.mark.xfail),
pytest.param(True, cycle([1]), marks=pytest.mark.xfail),
pytest.param(True, cycle([0])),
pytest.param(True, cycle([1])),
pytest.param(False, cycle([0])),
pytest.param(False, cycle([1])),
],
)
def test_failure_path_when_first_crn_attribute_not_datelike(with_crn, sims_to_add):
def test_prior_failure_path_when_first_crn_attribute_not_datelike(with_crn, sims_to_add):
if with_crn:
configuration = {"randomness": {"key_columns": ["crn_attr1", "crn_attr2"]}}
else:
configuration = {}

sim = InteractiveContext(
components=[BrokenPopulation(with_crn=with_crn, sims_to_add=sims_to_add)],
components=[UnBrokenPopulation(with_crn=with_crn, sims_to_add=sims_to_add)],
configuration=configuration,
)

Expand Down
60 changes: 35 additions & 25 deletions tests/framework/randomness/test_index_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,29 +18,29 @@ def almost_powerset(iterable):
def generate_keys(number, types=("int", "float", "datetime"), seed=123456):
rs = np.random.RandomState(seed=seed)

index = []
keys = {}
if "datetime" in types:
year = rs.choice(np.arange(1980, 2018))
day = rs.choice(pd.date_range(f"01/01/{year}", periods=365))
start_time = rs.choice(pd.date_range(day, periods=86400, freq="s"))
freq = rs.choice(["ms", "s", "min", "h"])
index.append(pd.date_range(start_time, periods=number, freq=freq))
keys["datetime"] = pd.date_range(start_time, periods=number, freq=freq)

if "int" in types:
kind = rs.choice(["random", "sequential"])
if kind == "random":
ints = np.unique(rs.randint(0, 1000 * number, size=100 * number))
assert len(ints) > number
rs.shuffle(ints)
index.append(ints[:number])
keys["int"] = ints[:number]
else:
start = rs.randint(0, 100 * number)
index.append(np.arange(start, start + number, dtype=int))
keys["int"] = np.arange(start, start + number, dtype=int)

if "float" in types:
index.append(rs.random_sample(size=number))
keys["float"] = rs.random_sample(size=number)

return pd.Series(0, index=index).index
return pd.DataFrame(keys, index=pd.RangeIndex(number))


rs = np.random.RandomState(seed=456789)
Expand All @@ -55,8 +55,9 @@ def id_fun(param):

@pytest.fixture(scope="module", params=list(product(index_sizes, types, seeds)), ids=id_fun)
def map_size_and_hashed_values(request):
keys = generate_keys(*request.param)
m = IndexMap()
index_size, types_, seed = request.param
keys = generate_keys(*request.param).set_index(types_).index
m = IndexMap(key_columns=types_)
return len(m), m._hash(keys)


Expand Down Expand Up @@ -166,47 +167,56 @@ def test_hash_uniformity(map_size_and_hashed_values):

@pytest.fixture(scope="function")
def index_map(mocker):
m = IndexMap()
mock_index_map = IndexMap

def hash_mock(k, salt=0):
seed = 123456
rs = np.random.RandomState(seed=seed + salt)
return pd.Series(rs.randint(0, len(k) * 10, size=len(k)), index=k)

mocker.patch.object(m, "_hash", side_effect=hash_mock)
mocker.patch.object(mock_index_map, "_hash", side_effect=hash_mock)

return m
return mock_index_map


def test_update_empty_bad_keys(index_map):
keys = pd.Index(["a"] * 10)
keys = pd.DataFrame({"A": ["a"] * 10}, index=range(10))
m = index_map(key_columns=list(keys.columns))
with pytest.raises(RandomnessError):
index_map.update(keys)
m.update(keys)


def test_update_nonempty_bad_keys(index_map):
keys = generate_keys(1000)

index_map.update(keys)
m = index_map(key_columns=list(keys.columns))
m.update(keys)
with pytest.raises(RandomnessError):
index_map.update(keys)
m.update(keys)


def test_update_empty_good_keys(index_map):
keys = generate_keys(1000)
index_map.update(keys)
assert len(index_map._map) == len(keys), "All keys not in mapping"
assert index_map._map.index.difference(keys).empty, "All keys not in mapping"
assert len(index_map._map.unique()) == len(keys), "Duplicate values in mapping"
m = index_map(key_columns=list(keys.columns))
m.update(keys)
key_index = keys.set_index(list(keys.columns)).index
assert len(m._map) == len(keys), "All keys not in mapping"
assert (
m._map.index.droplevel(m.SIM_INDEX_COLUMN).difference(key_index).empty
), "Extra keys in mapping"
assert len(m._map.unique()) == len(keys), "Duplicate values in mapping"


def test_update_nonempty_good_keys(index_map):
keys = generate_keys(2000)
m = index_map(key_columns=list(keys.columns))
keys1, keys2 = keys[:1000], keys[1000:]

index_map.update(keys1)
index_map.update(keys2)
m.update(keys1)
m.update(keys2)

assert len(index_map._map) == len(keys), "All keys not in mapping"
assert index_map._map.index.difference(keys).empty, "All keys not in mapping"
assert len(index_map._map.unique()) == len(keys), "Duplicate values in mapping"
key_index = keys.set_index(list(keys.columns)).index
assert len(m._map) == len(keys), "All keys not in mapping"
assert (
m._map.index.droplevel(m.SIM_INDEX_COLUMN).difference(key_index).empty
), "Extra keys in mapping"
assert len(m._map.unique()) == len(keys), "Duplicate values in mapping"

0 comments on commit ba47f86

Please sign in to comment.