Skip to content

Commit

Permalink
Fix broken raise KeyError for non-scalar keys.
Browse files Browse the repository at this point in the history
  • Loading branch information
bwoodsend committed Apr 13, 2021
1 parent 5b08445 commit 4f556fe
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 2 deletions.
5 changes: 3 additions & 2 deletions hoatzin/_hash_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,12 +242,13 @@ def get(self, keys, default=-1) -> np.ndarray:

return out if shape else out.item()

@staticmethod
def _blame_key(index, keys, shape) -> Tuple[str, str]:
def _blame_key(self, index, keys, shape) -> Tuple[str, str]:
"""Get a key and its location from a ravelled index. Used to prettify
key errors."""
assert index >= 0
if len(shape) == 0:
if self._dtype_shape:
return "key", repr(keys)
return "key", repr(keys.item())
if len(shape) == 1:
return f"keys[{index}]", repr(keys[index])
Expand Down
31 changes: 31 additions & 0 deletions tests/test_hash_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,37 @@ def test_getting():
self["toad"]


def test_blame_key_multidimensional():
"""Test that the custom KeyErrors work for non scalar keys. """

# Create a hash table for float triplets.
self = HashTable(10, dtype=(float, 3))
keys = np.arange(24, dtype=float).reshape((-1, 3))
# Add all but the last key.
self.add(keys[:-1])

# Try getting the last key. The resultant key errors should always point to
# the correct one being missing.
with pytest.raises(KeyError, match=r"key = array\(\[21., 22., 23.\]\) is"):
self[keys[-1]]
with pytest.raises(KeyError, match=r"keys\[7\] = array\(\[21"):
self[keys]
with pytest.raises(KeyError, match=r"keys\[3, 1\] = array\(\[21"):
self[keys.reshape((4, 2, 3))]


def test_blame_key_structured():
"""Similar to test_blame_key_multidimensional() but for struct dtypes."""
self = HashTable(10, dtype=[("name", str, 10), ("age", int)])
keys = np.array([("bill", 10), ("bob", 12), ("ben", 13)], self.dtype)
self.add(keys[:-1])

with pytest.raises(KeyError, match=r"key = \('ben', 13\) is"):
self[keys[-1]]
with pytest.raises(KeyError, match=r"keys\[2\] = \('ben', 13\) is"):
self[keys]


def test_destroy():
self = HashTable(10, float)
self.add([.3, .5, .8])
Expand Down

0 comments on commit 4f556fe

Please sign in to comment.