Skip to content

Commit

Permalink
Merge pull request #20 from matthewwardrop/instance_state
Browse files Browse the repository at this point in the history
Store instance state separately for better invalidation logic during instantiation and thread-safety.
  • Loading branch information
matthewwardrop committed Mar 26, 2024
2 parents 1f49f88 + 6284366 commit 200cd15
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 11 deletions.
14 changes: 6 additions & 8 deletions spec_classes/methods/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,6 @@ class InitMethod(MethodDescriptor):
def init(spec_cls, self, **kwargs):
instance_metadata = self.__spec_class__

# Unlock the class for mutation during initialization.
is_frozen = instance_metadata.frozen
if instance_metadata.owner is spec_cls and instance_metadata.frozen:
instance_metadata.frozen = False

# Initialise any non-local spec attributes via parent constructors
if instance_metadata.owner is spec_cls:
for parent in reversed(spec_cls.mro()[1:]):
Expand Down Expand Up @@ -123,8 +118,8 @@ def init(spec_cls, self, **kwargs):
if instance_metadata.post_init:
instance_metadata.post_init(self)

if is_frozen:
instance_metadata.frozen = True
self.__spec_class_state__.initialized = True
self.__spec_class_state__.frozen = instance_metadata.frozen

def build_method(self) -> Callable:
spec_class_key = self.spec_cls.__spec_class__.key
Expand Down Expand Up @@ -237,7 +232,7 @@ class DelAttrMethod(MethodDescriptor):

def build_method(self) -> Callable:
def __delattr__(self, attr, force=False):
if self.__spec_class__.frozen:
if self.__spec_class_state__.frozen:
raise FrozenInstanceError(
f"Cannot mutate attribute `{attr}` of frozen spec class `{self.__class__.__name__}`."
)
Expand Down Expand Up @@ -463,6 +458,9 @@ def deepcopy(self, memo):
new.__dict__[attr] = value
else:
new.__dict__[attr] = protect_via_deepcopy(value, memo)
self.__spec_class__.instance_state[new] = self.__spec_class__.instance_state[
self
]
return new

def build_method(self) -> Callable:
Expand Down
38 changes: 38 additions & 0 deletions spec_classes/spec_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from cached_property import cached_property

from spec_classes.methods import core as core_methods, SCALAR_METHODS, TOPLEVEL_METHODS
from spec_classes.utils.weakref_cache import WeakRefCache

from .types import Attr, MISSING

Expand Down Expand Up @@ -352,7 +353,14 @@ def bootstrap(self, spec_cls: type):
spec_cls.__annotations__[attr] = attr_spec.type

# Finalize metadata and remove bootstrapper from class.
@property
def __spec_class_state__(self):
if self not in self.__spec_class__.instance_state:
self.__spec_class__.instance_state[self] = SpecClassState()
return self.__spec_class__.instance_state[self]

spec_cls.__spec_class__ = metadata
spec_cls.__spec_class_state__ = __spec_class_state__
spec_cls.__dataclass_fields__ = metadata.attrs

# Register class-level methods and validate constructor/etc.
Expand Down Expand Up @@ -535,6 +543,10 @@ class SpecClassMetadata:
post_init: An optional callable to call post __init__. Lifted from
spec-class `__post_init__`. It should take only a single argument
(self).
instance_state: A mapping from object id to instance state. We
attach it here so that this state does not appear in
`spec_cls.__dict__`, where it could be confused attribute
values.
Generated (and cached) properties:
annotations: A mapping from attribute name to type for all of the
Expand Down Expand Up @@ -591,6 +603,7 @@ def for_class(cls, spec_cls: Type) -> SpecClassMetadata:
do_not_copy: bool = False
attrs: Dict[str, Attr] = dataclasses.field(default_factory=dict)
post_init: Optional[Callable[[Any], None]] = None
instance_state: WeakRefCache = dataclasses.field(default_factory=WeakRefCache)

@cached_property
def annotations(self):
Expand Down Expand Up @@ -632,6 +645,31 @@ def invalidation_map(self):
return invalidation_map


@dataclasses.dataclass
class SpecClassState:
"""
A container for the instance state of a spec-class. It is used to control
certain runtime behaviors, like whether a spec-class should be treated as
frozen and/or whether invalidation should be applied.
Attributes:
spec_class: The spec-class instance.
initialized: Whether the spec-class has finished initialization.
frozen: Whether the spec-class should be treated as frozen.
"""

initialized: bool = False
frozen: bool = False

@property
def invalidation_enabled(self) -> bool:
"""
Whether invalidation logic should be applied at this stage in the
spec-class instance's life-cycle.
"""
return self.initialized and not self.frozen


@dataclasses.dataclass
class _SpecClassMetadataPlaceholder:
"""
Expand Down
10 changes: 8 additions & 2 deletions spec_classes/utils/mutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def mutate_attr(

if metadata:
# Abort if class is frozen.
if not force and inplace and metadata.frozen:
if not force and inplace and obj.__spec_class_state__.frozen:
raise FrozenInstanceError(
f"Cannot mutate attribute `{attr}` of frozen spec class `{obj.__class__.__name__}`."
)
Expand Down Expand Up @@ -135,13 +135,19 @@ def mutate_attr(
raise

# Invalidate any caches depending on this attribute
if metadata and metadata.invalidation_map and not metadata.frozen:
if (
metadata
and obj.__spec_class_state__.invalidation_enabled
and metadata.invalidation_map
):
invalidate_attrs(obj, attr, metadata.invalidation_map)

return obj


def invalidate_attrs(obj: Any, attr: str, invalidation_map: Dict[str, Set[str]] = None):
if not obj.__spec_class_state__.invalidation_enabled:
return
if invalidation_map is None:
invalidation_map = obj.__spec_class__.invalidation_map
if not invalidation_map:
Expand Down
39 changes: 39 additions & 0 deletions spec_classes/utils/weakref_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import weakref
from collections.abc import MutableMapping
from typing import Any


class WeakRef:
def __init__(self, obj: Any):
self._ref = weakref.ref(obj)

def __hash__(self):
return id(self._ref)

def __eq__(self, other):
return self._ref is other._ref


class WeakRefCache(MutableMapping):
def __init__(self):
self.index = weakref.WeakValueDictionary()
self.values = weakref.WeakKeyDictionary()

def __getitem__(self, obj):
return self.values[WeakRef(obj)]

def __setitem__(self, obj, value):
ref = WeakRef(obj)
self.index[ref] = obj
self.values[ref] = value

def __delitem__(self, obj):
ref = WeakRef(obj)
del self.values[ref]
del self.index[ref]

def __iter__(self):
return self.index.values()

def __len__(self):
return len(self.index)
2 changes: 1 addition & 1 deletion tests/test_spec_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,14 +771,14 @@ class Spec:
def test_invalidation(self):
@spec_class
class Spec:
always_invalidated_property: str
attr: str = Attr(default="Hello World")
unmanaged_attr = "Hi"
invalidated_attr: str = Attr(
default="Invalidated",
repr=False,
invalidated_by=["attr", "unmanaged_attr"],
)
always_invalidated_property: str

@spec_property(cache=True, invalidated_by=["unmanaged_attr"])
def invalidated_property(self):
Expand Down
30 changes: 30 additions & 0 deletions tests/utils/test_weakref_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from spec_classes.utils.weakref_cache import WeakRefCache


class Object:
pass


def test_weak_ref_cache():
cache = WeakRefCache()
a = Object()
b = Object()
c = Object()

cache[a] = 1
cache[b] = 2

assert cache[a] == 1
assert cache[b] == 2

assert a in cache
assert b in cache
assert c not in cache
assert len(cache) == 2
assert list(cache) == [a, b]

del cache[b]
assert len(cache) == 1

del a
assert len(cache) == 0

0 comments on commit 200cd15

Please sign in to comment.