Skip to content

Commit

Permalink
feat(common): better inheritance support for Slotted and FrozenSlotted
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Nov 6, 2023
1 parent 2e3a5a0 commit 9165d41
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 29 deletions.
12 changes: 4 additions & 8 deletions ibis/common/annotations.py
Expand Up @@ -100,7 +100,9 @@ class Annotation(Slotted, Immutable):
Annotations are used to mark fields in a class and to validate them.
"""

__slots__ = ()
__slots__ = ("pattern", "default")
pattern: Pattern
default: AnyType

def validate(self, name: str, value: AnyType, this: AnyType) -> AnyType:
"""Validate the field.
Expand Down Expand Up @@ -142,10 +144,6 @@ class Attribute(Annotation):
Callable to compute the default value of the field.
"""

__slots__ = ("pattern", "default")
pattern: Pattern
default: AnyType

def __init__(self, pattern: Pattern = _any, default: AnyType = EMPTY):
super().__init__(pattern=ensure_pattern(pattern), default=default)

Expand Down Expand Up @@ -199,9 +197,7 @@ class Argument(Annotation):
Defaults to positional or keyword.
"""

__slots__ = ("pattern", "default", "typehint", "kind")
pattern: Pattern
default: AnyType
__slots__ = ("typehint", "kind")
typehint: AnyType
kind: int

Expand Down
34 changes: 23 additions & 11 deletions ibis/common/bases.py
Expand Up @@ -116,6 +116,7 @@ class Final(Abstract):
"""Prohibit subclassing."""

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls.__init_subclass__ = cls.__prohibit_inheritance__

@classmethod
Expand Down Expand Up @@ -178,37 +179,45 @@ def __cached_equals__(self, other) -> bool:
return result


class Slotted(Abstract):
class SlottedMeta(AbstractMeta):
def __new__(metacls, clsname, bases, dct, **kwargs):
fields = dct.get("__fields__", dct.get("__slots__", ()))
inherited = (getattr(base, "__fields__", ()) for base in bases)
dct["__fields__"] = sum(inherited, ()) + fields
return super().__new__(metacls, clsname, bases, dct, **kwargs)


class Slotted(Abstract, metaclass=SlottedMeta):
"""A lightweight alternative to `ibis.common.grounds.Annotable`.
The class is mostly used to reduce boilerplate code.
"""

def __init__(self, **kwargs) -> None:
for name, value in kwargs.items():
object.__setattr__(self, name, value)
for field in self.__fields__:
object.__setattr__(self, field, kwargs[field])

def __eq__(self, other) -> bool:
if self is other:
return True
if type(self) is not type(other):
return NotImplemented
return all(getattr(self, n) == getattr(other, n) for n in self.__slots__)
return all(getattr(self, n) == getattr(other, n) for n in self.__fields__)

def __getstate__(self):
return {k: getattr(self, k) for k in self.__slots__}
return {k: getattr(self, k) for k in self.__fields__}

def __setstate__(self, state):
for name, value in state.items():
object.__setattr__(self, name, value)

def __repr__(self):
fields = {k: getattr(self, k) for k in self.__slots__}
fields = {k: getattr(self, k) for k in self.__fields__}
fieldstring = ", ".join(f"{k}={v!r}" for k, v in fields.items())
return f"{self.__class__.__name__}({fieldstring})"

def __rich_repr__(self):
for name in self.__slots__:
for name in self.__fields__:
yield name, getattr(self, name)


Expand All @@ -220,18 +229,21 @@ class FrozenSlotted(Slotted, Immutable, Hashable):
"""

__slots__ = ("__precomputed_hash__",)
__fields__ = ()
__precomputed_hash__: int

def __init__(self, **kwargs) -> None:
for name, value in kwargs.items():
object.__setattr__(self, name, value)
hashvalue = hash(tuple(kwargs.values()))
values = []
for field in self.__fields__:
values.append(value := kwargs[field])
object.__setattr__(self, field, value)
hashvalue = hash((self.__class__, tuple(values)))
object.__setattr__(self, "__precomputed_hash__", hashvalue)

def __setstate__(self, state):
for name, value in state.items():
object.__setattr__(self, name, value)
hashvalue = hash(tuple(state.values()))
hashvalue = hash((self.__class__, tuple(state.values())))
object.__setattr__(self, "__precomputed_hash__", hashvalue)

def __hash__(self) -> int:
Expand Down
60 changes: 50 additions & 10 deletions ibis/common/tests/test_bases.py
Expand Up @@ -266,50 +266,90 @@ class B(A):
class MyObj(Slotted):
__slots__ = ("a", "b")

def __init__(self, a, b):
super().__init__(a=a, b=b)


def test_slotted():
obj = MyObj(1, 2)
obj = MyObj(a=1, b=2)
assert obj.a == 1
assert obj.b == 2
assert obj.__fields__ == ("a", "b")
assert obj.__slots__ == ("a", "b")
with pytest.raises(AttributeError):
obj.c = 3

obj2 = MyObj(1, 2)
obj2 = MyObj(a=1, b=2)
assert obj == obj2
assert obj is not obj2

obj3 = MyObj(1, 3)
obj3 = MyObj(a=1, b=3)
assert obj != obj3

assert pickle.loads(pickle.dumps(obj)) == obj

with pytest.raises(KeyError):
MyObj(a=1)


class MyObj2(MyObj):
__slots__ = ("c",)


def test_slotted_inheritance():
obj = MyObj2(a=1, b=2, c=3)
assert obj.a == 1
assert obj.b == 2
assert obj.c == 3
assert obj.__fields__ == ("a", "b", "c")
assert obj.__slots__ == ("c",)
with pytest.raises(AttributeError):
obj.d = 4

obj2 = MyObj2(a=1, b=2, c=3)
assert obj == obj2
assert obj is not obj2

obj3 = MyObj2(a=1, b=2, c=4)
assert obj != obj3
assert pickle.loads(pickle.dumps(obj)) == obj

with pytest.raises(KeyError):
MyObj2(a=1, b=2)


class MyFrozenObj(FrozenSlotted):
__slots__ = ("a", "b")

def __init__(self, a, b):
super().__init__(a=a, b=b)

class MyFrozenObj2(MyFrozenObj):
__slots__ = ("c", "d")


def test_frozen_slotted():
obj = MyFrozenObj(1, 2)
obj = MyFrozenObj(a=1, b=2)

assert obj.a == 1
assert obj.b == 2
assert obj.__fields__ == ("a", "b")
assert obj.__slots__ == ("a", "b")
with pytest.raises(AttributeError):
obj.b = 3
with pytest.raises(AttributeError):
obj.c = 3

obj2 = MyFrozenObj(1, 2)
obj2 = MyFrozenObj(a=1, b=2)
assert obj == obj2
assert obj is not obj2
assert hash(obj) == hash(obj2)

restored = pickle.loads(pickle.dumps(obj))
assert restored == obj
assert hash(restored) == hash(obj)

with pytest.raises(KeyError):
MyFrozenObj(a=1)


def test_frozen_slotted_inheritance():
obj3 = MyFrozenObj2(a=1, b=2, c=3, d=4)
assert obj3.__slots__ == ("c", "d")
assert obj3.__fields__ == ("a", "b", "c", "d")
assert pickle.loads(pickle.dumps(obj3)) == obj3

0 comments on commit 9165d41

Please sign in to comment.