Skip to content

Commit

Permalink
perf(common): improve Concrete construction performance
Browse files (browse the repository at this point in the history)
  • Loading branch information
kszucs committed Aug 19, 2023
1 parent 41df410 commit 2cb1a55
Show file tree
Hide file tree
Showing 6 changed files with 207 additions and 36 deletions.
15 changes: 12 additions & 3 deletions ibis/common/annotations.py
Expand Up @@ -103,8 +103,18 @@ def default(self, fn):
"""Annotation to mark a field with a default value computed by a callable."""
return Attribute(default=fn)

def initialize(self, this):
"""Compute the default value of the field."""
def initialize(self, this: AnyType) -> AnyType:
"""Compute the default value of the field.
Parameters
----------
this
The instance of the class the attribute is defined on.
Returns
-------
The default value for the field.
"""
if self._default is EMPTY:
return EMPTY
elif callable(self._default):
Expand Down Expand Up @@ -429,7 +439,6 @@ def validate_return(self, value, context):


# TODO(kszucs): try to cache pattern objects
# TODO(kszucs): try a quicker curry implementation


def annotated(_1=None, _2=None, _3=None, **kwargs):
Expand Down
73 changes: 48 additions & 25 deletions ibis/common/grounds.py
Expand Up @@ -17,7 +17,6 @@
Argument,
Attribute,
Signature,
attribute,
)
from ibis.common.bases import ( # noqa: F401
Base,
Expand Down Expand Up @@ -67,11 +66,10 @@ def __new__(metacls, clsname, bases, dct, **kwargs):
namespace, arguments = {}, {}
for name, attrib in dct.items():
if isinstance(attrib, Pattern):
attrib = Argument.required(attrib)

if isinstance(attrib, Argument):
arguments[name] = Argument.required(attrib)
slots.append(name)
elif isinstance(attrib, Argument):
arguments[name] = attrib
attributes[name] = attrib
slots.append(name)
elif isinstance(attrib, Attribute):
attributes[name] = attrib
Expand Down Expand Up @@ -103,10 +101,17 @@ def __or__(self, other):
class Annotable(Base, metaclass=AnnotableMeta):
"""Base class for objects with custom validation rules."""

__argnames__: ClassVar[tuple[str, ...]]
__signature__: ClassVar[Signature]
"""Signature of the class, containing the Argument annotations."""

__attributes__: ClassVar[FrozenDict[str, Annotation]]
"""Mapping of the Attribute annotations."""

__argnames__: ClassVar[tuple[str, ...]]
"""Names of the arguments."""

__match_args__: ClassVar[tuple[str, ...]]
__signature__: ClassVar[Signature]
"""Names of the arguments to be used for pattern matching."""

@classmethod
def __create__(cls, *args: Any, **kwargs: Any) -> Self:
Expand All @@ -120,19 +125,20 @@ def __recreate__(cls, kwargs: Any) -> Self:
kwargs = cls.__signature__.validate_nobind(**kwargs)
return super().__create__(**kwargs)

def __init__(self, *args: Any, **kwargs: Any) -> None:
def __init__(self, **kwargs: Any) -> None:
# set the already validated arguments
for name, value in kwargs.items():
object.__setattr__(self, name, value)

# post-initialize the remaining attributes
# initialize the remaining attributes
for name, field in self.__attributes__.items():
if isinstance(field, Attribute):
if (value := field.initialize(self)) is not EMPTY:
object.__setattr__(self, name, value)
if (default := field.initialize(self)) is not EMPTY:
object.__setattr__(self, name, default)

def __setattr__(self, name, value) -> None:
if field := self.__attributes__.get(name):
# first try to look up the argument then the attribute
if param := self.__signature__.parameters.get(name):
value = param.annotation.validate(value, self)
elif field := self.__attributes__.get(name):
value = field.validate(value, self)
super().__setattr__(name, value)

Expand All @@ -142,13 +148,17 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}({argstring})"

def __eq__(self, other) -> bool:
# compare types
if type(self) is not type(other):
return NotImplemented

return all(
getattr(self, name, None) == getattr(other, name, None)
for name in self.__attributes__
)
# compare arguments
if self.__args__ != other.__args__:
return False
# compare attributes
for name in self.__attributes__:
if getattr(self, name, None) != getattr(other, name, None):
return False
return True

@property
def __args__(self) -> tuple[Any, ...]:
Expand Down Expand Up @@ -176,13 +186,26 @@ def copy(self, **overrides: Any) -> Annotable:
class Concrete(Immutable, Comparable, Annotable):
"""Opinionated base class for immutable data classes."""

@attribute.default
def __args__(self):
return tuple(getattr(self, name) for name in self.__argnames__)
__slots__ = ("__args__", "__precomputed_hash__")

@attribute.default
def __precomputed_hash__(self) -> int:
return hash((self.__class__, self.__args__))
def __init__(self, **kwargs: Any) -> None:
    """Set the arguments, cache ``__args__`` and the hash, then apply defaults.

    Parameters
    ----------
    kwargs
        Argument values keyed by name. Every name listed in
        ``__argnames__`` is looked up in this mapping.
    """
    # attributes are written through object.__setattr__ directly rather than
    # through the class's own __setattr__
    assign = object.__setattr__

    # one pass over the argument names: remember each value and set it
    collected = []
    for argname in self.__argnames__:
        argvalue = kwargs[argname]
        collected.append(argvalue)
        assign(self, argname, argvalue)

    # precompute the hash value since the instance is immutable; the hash is
    # computed before anything derived is stored on the instance
    frozen_args = tuple(collected)
    precomputed = hash((self.__class__, frozen_args))
    assign(self, "__args__", frozen_args)
    assign(self, "__precomputed_hash__", precomputed)

    # initialize the remaining attributes which provide a default value
    for attrname, field in self.__attributes__.items():
        default = field.initialize(self)
        if default is not EMPTY:
            assign(self, attrname, default)

def __reduce__(self):
# assuming immutability and idempotency of the __init__ method, we can
Expand Down
95 changes: 88 additions & 7 deletions ibis/common/patterns.py
Expand Up @@ -1183,6 +1183,64 @@ def match(self, value, context):
class SequenceOf(Slotted, Pattern):
    """Pattern that matches if all of the items in a sequence match a given pattern.

    Specialization of the more flexible GenericSequenceOf pattern which uses two
    additional patterns to possibly coerce the sequence type and to match on
    the length of the sequence.

    Parameters
    ----------
    item
        The pattern to match against each item in the sequence.
    type
        The type to coerce the sequence to. Defaults to tuple.
    """

    __slots__ = ("item", "type")
    item: Pattern
    type: type

    def __new__(
        cls,
        item,
        type: type = tuple,
        exactly: Optional[int] = None,
        at_least: Optional[int] = None,
        at_most: Optional[int] = None,
    ):
        """Create the pattern, delegating to GenericSequenceOf when required.

        A GenericSequenceOf is returned instead of a SequenceOf when any of the
        length constraints is given or when ``type`` is a Coercible subclass.
        """
        has_length_constraint = (
            exactly is not None or at_least is not None or at_most is not None
        )
        # issubclass() is only consulted when no length constraint was given
        if has_length_constraint or issubclass(type, Coercible):
            return GenericSequenceOf(
                item, type=type, exactly=exactly, at_least=at_least, at_most=at_most
            )
        return super().__new__(cls)

    def __init__(self, item, type=tuple):
        """Store the item pattern (passed through ``pattern()``) and the type."""
        item_pattern = pattern(item)
        super().__init__(item=item_pattern, type=type)

    def match(self, values, context):
        """Match every item of ``values`` against the item pattern.

        Returns NoMatch when ``values`` cannot be iterated or when any item
        fails to match, otherwise the matched items wrapped in ``self.type``.
        """
        try:
            elements = iter(values)
        except TypeError:
            # not an iterable at all
            return NoMatch

        matched = []
        for element in elements:
            outcome = self.item.match(element, context)
            if outcome is NoMatch:
                return NoMatch
            matched.append(outcome)

        return self.type(matched)


class GenericSequenceOf(Slotted, Pattern):
"""Pattern that matches if all of the items in a sequence match a given pattern.
Parameters
----------
item
Expand All @@ -1199,9 +1257,27 @@ class SequenceOf(Slotted, Pattern):

__slots__ = ("item", "type", "length")
item: Pattern
type: CoercedTo
type: Pattern
length: Length

def __new__(
    cls,
    item: Pattern,
    type: type = tuple,
    exactly: Optional[int] = None,
    at_least: Optional[int] = None,
    at_most: Optional[int] = None,
):
    """Create the pattern, preferring the simpler SequenceOf when possible.

    A SequenceOf is returned when no length constraint is given and ``type``
    is not a Coercible subclass; otherwise a new instance of ``cls`` is
    allocated.
    """
    unconstrained = exactly is None and at_least is None and at_most is None
    # issubclass() is only consulted when no length constraint was given
    if unconstrained and not issubclass(type, Coercible):
        return SequenceOf(item, type=type)
    return super().__new__(cls)

def __init__(
self,
item: Pattern,
Expand All @@ -1216,11 +1292,13 @@ def __init__(
super().__init__(item=item, type=type, length=length)

def match(self, values, context):
if not is_iterable(values):
try:
iterable = iter(values)
except TypeError:
return NoMatch

result = []
for value in values:
for value in iterable:
value = self.item.match(value, context)
if value is NoMatch:
return NoMatch
Expand Down Expand Up @@ -1272,7 +1350,7 @@ def match(self, values, context):
return tuple(result)


class MappingOf(Slotted, Pattern):
class GenericMappingOf(Slotted, Pattern):
"""Pattern that matches if all of the keys and values match the given patterns.
Parameters
Expand All @@ -1288,7 +1366,7 @@ class MappingOf(Slotted, Pattern):
__slots__ = ("key", "value", "type")
key: Pattern
value: Pattern
type: CoercedTo
type: Pattern

def __init__(self, key: Pattern, value: Pattern, type: type = dict):
super().__init__(key=pattern(key), value=pattern(value), type=CoercedTo(type))
Expand All @@ -1312,6 +1390,9 @@ def match(self, value, context):
return result


MappingOf = GenericMappingOf


class Attrs(Slotted, Pattern):
__slots__ = ("fields",)
fields: FrozenDict[str, Pattern]
Expand Down Expand Up @@ -1500,8 +1581,8 @@ def match(self, value, context):
if isinstance(following, Capture):
following = following.pattern

if isinstance(current, (SequenceOf, PatternSequence)):
if isinstance(following, SequenceOf):
if isinstance(current, (SequenceOf, GenericSequenceOf, PatternSequence)):
if isinstance(following, (SequenceOf, GenericSequenceOf)):
following = following.item
elif isinstance(following, PatternSequence):
# first pattern to match from the pattern window
Expand Down
2 changes: 1 addition & 1 deletion ibis/common/tests/test_grounds.py
Expand Up @@ -865,7 +865,7 @@ def double_a(self):
op = Value(1)
assert op.a == 1
assert op.double_a == 2
assert len(Value.__attributes__) == 2
assert len(Value.__attributes__) == 1
assert "double_a" in Value.__slots__


Expand Down
32 changes: 32 additions & 0 deletions ibis/common/tests/test_grounds_benchmarks.py
@@ -0,0 +1,32 @@
from __future__ import annotations

import pytest

from ibis.common.annotations import attribute
from ibis.common.collections import frozendict
from ibis.common.grounds import Concrete

pytestmark = pytest.mark.benchmark


class MyObject(Concrete):
    """Benchmark fixture with four annotated fields and three default attributes."""

    # annotated fields supplied at construction time
    a: int
    b: str
    c: tuple[int, ...]
    d: frozendict[str, int]

    # attributes below are computed from the fields above

    @attribute.default
    def e(self):
        """Return ``a`` multiplied by two."""
        return self.a * 2

    @attribute.default
    def f(self):
        """Return ``b`` multiplied by two."""
        return self.b * 2

    @attribute.default
    def g(self):
        """Return ``c`` multiplied by two."""
        return self.c * 2


def test_concrete_construction(benchmark):
    """Benchmark constructing a MyObject from positional and keyword arguments."""
    mapping = frozendict(e=5, f=6)
    benchmark(MyObject, 1, "2", c=(3, 4), d=mapping)
26 changes: 26 additions & 0 deletions ibis/common/tests/test_patterns.py
Expand Up @@ -42,6 +42,7 @@
FrozenDictOf,
Function,
GenericInstanceOf,
GenericSequenceOf,
Innermost,
InstanceOf,
IsIn,
Expand Down Expand Up @@ -409,13 +410,38 @@ def test_isin():

def test_sequence_of():
    """SequenceOf with a plain list type stays a SequenceOf and matches str items."""
    seq_pattern = SequenceOf(InstanceOf(str), list)
    assert isinstance(seq_pattern, SequenceOf)

    # all items are strings, so the values are returned as a list
    matched = seq_pattern.match(["foo", "bar"], context={})
    assert matched == ["foo", "bar"]

    # items of the wrong type do not match
    assert seq_pattern.match([1, 2], context={}) is NoMatch
    # neither does a value that is not iterable
    assert seq_pattern.match(1, context={}) is NoMatch


def test_generic_sequence_of():
    """SequenceOf and GenericSequenceOf dispatch to each other by their arguments."""

    class MyList(list, Coercible):
        @classmethod
        def __coerce__(cls, value, T=...):
            return cls(value)

    # a Coercible sequence type makes SequenceOf return the generic pattern
    coercing = SequenceOf(InstanceOf(str), MyList)
    assert isinstance(coercing, GenericSequenceOf)
    assert coercing == GenericSequenceOf(InstanceOf(str), MyList)
    assert coercing.match(["foo", "bar"], context={}) == MyList(["foo", "bar"])

    # a length constraint also makes SequenceOf return the generic pattern
    bounded = SequenceOf(InstanceOf(str), tuple, at_least=1)
    assert isinstance(bounded, GenericSequenceOf)
    assert bounded == GenericSequenceOf(InstanceOf(str), tuple, at_least=1)
    assert bounded.match(("foo", "bar"), context={}) == ("foo", "bar")
    assert bounded.match([], context={}) is NoMatch

    # without constraints or coercion GenericSequenceOf returns a SequenceOf
    simple = GenericSequenceOf(InstanceOf(str), list)
    assert isinstance(simple, SequenceOf)
    assert simple == SequenceOf(InstanceOf(str), list)
    assert simple.match(("foo", "bar"), context={}) == ["foo", "bar"]


def test_list_of():
p = ListOf(InstanceOf(str))
assert isinstance(p, SequenceOf)
assert p.match(["foo", "bar"], context={}) == ["foo", "bar"]
assert p.match([1, 2], context={}) is NoMatch
assert p.match(1, context={}) is NoMatch
Expand Down

0 comments on commit 2cb1a55

Please sign in to comment.