Skip to content

Commit

Permalink
Use siphash for persistent hashing
Browse files (browse the repository at this point in the history)
  • Loading branch information
inducer committed Jul 4, 2024
1 parent 444c4b6 commit 738b005
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 27 deletions.
1 change: 1 addition & 0 deletions .pylintrc-local.yml
Original file line number | Diff line number | Diff line change
Expand Up @@ -4,3 +4,4 @@
- arg: ignored-modules
val:
- matplotlib
- siphash24
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number | Diff line number | Diff line change
Expand Up @@ -33,6 +33,7 @@ classifiers = [
dependencies = [
"platformdirs>=2.2",
"typing-extensions>=4; python_version<'3.11'",
"siphash24>=1.6",
]

[project.optional-dependencies]
Expand Down
11 changes: 7 additions & 4 deletions pytools/persistent_dict.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -30,7 +30,6 @@
"""


import hashlib
import logging
import os
import pickle
Expand All @@ -52,6 +51,8 @@
cast,
)

from siphash24 import siphash13


if TYPE_CHECKING:
from _typeshed import ReadableBuffer
Expand Down Expand Up @@ -160,7 +161,7 @@ class KeyBuilder:

# this exists so that we can (conceivably) switch algorithms at some point
# down the road
new_hash: Callable[..., Hash] = hashlib.sha256
new_hash: Callable[..., Hash] = siphash13

def rec(self, key_hash: Hash, key: Any) -> Hash:
"""
Expand Down Expand Up @@ -301,7 +302,8 @@ def update_for_frozenset(self, key_hash: Hash, key: FrozenSet[Any]) -> None:

unordered_hash(
key_hash,
(self.rec(self.new_hash(), key_i).digest() for key_i in key))
(self.rec(self.new_hash(), key_i).digest() for key_i in key),
hash_constructor=self.new_hash)

update_for_FrozenOrderedSet = update_for_frozenset # noqa: N815

Expand Down Expand Up @@ -351,7 +353,8 @@ def update_for_frozendict(self, key_hash: Hash, key: Mapping[Any, Any]) -> None:

unordered_hash(
key_hash,
(self.rec(self.new_hash(), (k, v)).digest() for k, v in key.items()))
(self.rec(self.new_hash(), (k, v)).digest() for k, v in key.items()),
hash_constructor=self.new_hash)

update_for_immutabledict = update_for_frozendict
update_for_constantdict = update_for_frozendict
Expand Down
38 changes: 15 additions & 23 deletions pytools/test/test_persistent_dict.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -548,17 +548,14 @@ class TagClass2(Tag):
assert keyb(TagClass) != keyb(TagClass2)
assert keyb(TagClass()) != keyb(TagClass2())

assert keyb(TagClass()) == \
"f5697a96dde0083e31a290b54ee7a5640b2bb8eb6d18e9c7ee89228b015a6131"
assert keyb(TagClass2) == \
"0833645765e32e7fb4a586614d0e345878eba50199ed2d8e963b28f797fd6e29"
assert keyb(TagClass()) == "7b3e4e66503438f6"
assert keyb(TagClass2) == "690b86bbf51aad83"

@tag_dataclass
class TagClass3(Tag):
s: str

assert (keyb(TagClass3("foo")) # type: ignore[call-arg]
== "c6521f4157ed530d04e956b7046db85e038c120b047cd1b848340d81f9fd8b4a")
assert (keyb(TagClass3("foo")) == "cf1a33652cc75b9c") # type: ignore[call-arg]


def test_dataclass_hashing() -> None:
Expand All @@ -569,8 +566,7 @@ class MyDC:
name: str
value: int

assert keyb(MyDC("hi", 1)) == \
"2ba6363c3b98f1cc2209bd57388368b3efe3074e3764eee30fbcf15946efb802"
assert keyb(MyDC("hi", 1)) == "d1a1079f1c10aa4f"

assert keyb(MyDC("hi", 1)) == keyb(MyDC("hi", 1))
assert keyb(MyDC("hi", 1)) != keyb(MyDC("hi", 2))
Expand All @@ -594,8 +590,7 @@ class MyAttrs:
name: str
value: int

assert (keyb(MyAttrs("hi", 1)) # type: ignore[call-arg]
== "17f272d114d22c1dc0117354777f2d506b303d90e10840d39fb0eef007252f68")
assert (keyb(MyAttrs("hi", 1)) == "5b6c5da60eb2bd0f") # type: ignore[call-arg]

assert keyb(MyAttrs("hi", 1)) == keyb(MyAttrs("hi", 1)) # type: ignore[call-arg]
assert keyb(MyAttrs("hi", 1)) != keyb(MyAttrs("hi", 2)) # type: ignore[call-arg]
Expand Down Expand Up @@ -626,7 +621,7 @@ def test_datetime_hashing() -> None:
# No timezone info; date is always naive
assert (keyb(datetime.date(2020, 1, 1))
== keyb(datetime.date(2020, 1, 1))
== "9fb97d7faabc3603f3e334ca5eb1eb0fe0c92665e5611cb1b5aa77fa0f70f5e3")
== "1c866ff10ff0d997")
assert keyb(datetime.date(2020, 1, 1)) != keyb(datetime.date(2020, 1, 2))

# }}}
Expand All @@ -640,7 +635,7 @@ def test_datetime_hashing() -> None:
== keyb(datetime.time(12, 0))
== keyb(datetime.time(12, 0, 0))
== keyb(datetime.time(12, 0, 0, 0))
== "288ec82f6a00ac15968d4d257d4aca1089b863c61ef2ee200e64351238397705")
== "e523be74ebc6b227")
assert keyb(datetime.time(12, 0)) != keyb(datetime.time(12, 1))

# Aware time
Expand All @@ -653,7 +648,7 @@ def test_datetime_hashing() -> None:
assert t1 == t2
assert (keyb(t1)
== keyb(t2)
== "3587427ca9d581779d532b397df206ddeadfcf4e38b1ee69c19174e8e1268cc4")
== "2041e7cd5b17b8eb")

assert t1 != t3
assert keyb(t1) != keyb(t3)
Expand All @@ -672,7 +667,7 @@ def test_datetime_hashing() -> None:
assert dt1 == dt2
assert (keyb(dt1)
== keyb(dt2)
== "cd35722af47e42cb3bc81c389b87eb2e78ee8e20298bb1d8a193b30940d1c142")
== "8be96b9e739c7d8c")

dt3 = datetime.datetime(2020, 1, 1, 7,
tzinfo=datetime.timezone(datetime.timedelta(hours=-4)))
Expand All @@ -688,7 +683,7 @@ def test_datetime_hashing() -> None:
assert (keyb(datetime.datetime(2020, 1, 1))
== keyb(datetime.datetime(2020, 1, 1))
== keyb(datetime.datetime(2020, 1, 1, 0, 0, 0, 0))
== "8f3b843d7b9176afd8e2ce97ebc19789098a1c7774c4ec00d4054ec954ce2b88"
== "215dbe82add7a55c"
)
assert keyb(datetime.datetime(2020, 1, 1)) != keyb(datetime.datetime(2020, 1, 2))
assert (keyb(datetime.datetime(2020, 1, 1))
Expand All @@ -711,7 +706,7 @@ def test_datetime_hashing() -> None:
assert tz2 == tz3
assert (keyb(tz2)
== keyb(tz3)
== "89bd615f32c1f209b0853b1fc7d06ddb6fda7f367a00a8621d60337d52cb8d10")
== "5e1d46ab778c7ccf")

# }}}

Expand Down Expand Up @@ -771,7 +766,7 @@ def test_size():

size = pdict.nbytes()
print("sqlite size: ", size/1024/1024, " MByte")
assert 1*1024*1024 < size < 2*1024*1024
assert 1024*1024//2 < size < 2*1024*1024
finally:
shutil.rmtree(tmpdir)

Expand Down Expand Up @@ -841,8 +836,7 @@ def test_hash_function() -> None:

# {{{ global functions

assert keyb(global_fun) == keyb(global_fun) == \
"51b5980dd3a8aa13f6e83869e4a04c22973d7aaf96cb22899abdfdc55e15c9b2"
assert keyb(global_fun) == keyb(global_fun) == "79efd03f9a38ed77"
assert keyb(global_fun) != keyb(global_fun2)

# }}}
Expand Down Expand Up @@ -882,8 +876,7 @@ def local_fun():
def local_fun2():
pass

assert keyb(local_fun) == keyb(local_fun) == \
"fc58f5b0130df821913c848749eb03f5dcd4da7a568c6130f1c0cfb96ed0d12d"
assert keyb(local_fun) == keyb(local_fun) == "adc92e690b62dc2b"
assert keyb(local_fun) != keyb(local_fun2)

# }}}
Expand All @@ -898,8 +891,7 @@ class C2:
def method(self):
pass

assert keyb(C1.method) == keyb(C1.method) == \
"3013eb424dac133a57bd70cb6084d2a2f349a247714efc508fe3b10b99b6f717"
assert keyb(C1.method) == keyb(C1.method) == "af19e056ad7749c4"
assert keyb(C1.method) != keyb(C2.method)

# }}}
Expand Down

0 comments on commit 738b005

Please sign in to comment.