Skip to content

Commit

Permalink
feat(common): make frozendict truly immutable
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs authored and cpcloud committed Jan 11, 2023
1 parent dd352b1 commit 1c25213
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 9 deletions.
23 changes: 23 additions & 0 deletions ibis/tests/test_util.py
Expand Up @@ -4,6 +4,7 @@
import pytest

from ibis import util
from ibis.tests.util import assert_pickle_roundtrip


@pytest.mark.parametrize(
Expand Down Expand Up @@ -58,6 +59,28 @@ def test_dotdict():
assert d.x


def test_frozendict():
d = util.frozendict({"a": 1, "b": 2, "c": 3})
e = util.frozendict(a=1, b=2, c=3)
assert d == e
assert d["a"] == 1
assert d["b"] == 2

msg = "'frozendict' object does not support item assignment"
with pytest.raises(TypeError, match=msg):
d["a"] = 2
with pytest.raises(TypeError, match=msg):
d["d"] = 4

with pytest.raises(TypeError):
d.__view__["a"] = 2
with pytest.raises(TypeError):
d.__view__ = {"a": 2}

assert hash(d)
assert_pickle_roundtrip(d)


def test_import_object():
import collections

Expand Down
26 changes: 17 additions & 9 deletions ibis/util.py
Expand Up @@ -48,29 +48,37 @@


class frozendict(Mapping, Hashable):
__slots__ = ("_dict", "_hash")
__slots__ = ("__view__", "__precomputed_hash__")

def __init__(self, *args, **kwargs):
self._dict = dict(*args, **kwargs)
self._hash = hash(tuple(self._dict.items()))
dictview = types.MappingProxyType(dict(*args, **kwargs))
dicthash = hash(tuple(dictview.items()))
object.__setattr__(self, "__view__", dictview)
object.__setattr__(self, "__precomputed_hash__", dicthash)

def __str__(self):
return str(self._dict)
return str(self.__view__)

def __repr__(self):
return f"{self.__class__.__name__}({self._dict!r})"
return f"{self.__class__.__name__}({self.__view__!r})"

def __setattr__(self, name: str, _: Any) -> None:
raise TypeError(f"Attribute {name!r} cannot be assigned to frozendict")

def __reduce__(self):
return frozendict, (dict(self.__view__),)

def __iter__(self):
return iter(self._dict)
return iter(self.__view__)

def __len__(self):
return len(self._dict)
return len(self.__view__)

def __getitem__(self, key):
return self._dict[key]
return self.__view__[key]

def __hash__(self):
return self._hash
return self.__precomputed_hash__


class DotDict(dict):
Expand Down

0 comments on commit 1c25213

Please sign in to comment.