From bd7593cfc1aa406726b863d744228cb7d9ff2a04 Mon Sep 17 00:00:00 2001 From: Benjamin Piwowarski Date: Mon, 4 Mar 2024 17:11:44 +0100 Subject: [PATCH] refactor: use sets --- src/datamaestro/record.py | 60 ++++++++++++++++++++++------- src/datamaestro/test/test_record.py | 26 +++++++++++-- 2 files changed, 69 insertions(+), 17 deletions(-) diff --git a/src/datamaestro/record.py b/src/datamaestro/record.py index 5ce4df1..a904ae1 100644 --- a/src/datamaestro/record.py +++ b/src/datamaestro/record.py @@ -1,4 +1,5 @@ -from typing import ClassVar, Type, TypeVar, Dict, List, Union, Optional +import logging +from typing import ClassVar, Type, TypeVar, Dict, List, Union, Optional, Set class Item: @@ -6,6 +7,7 @@ class Item: @classmethod def __get_base__(cls: Type) -> Type: + """Get the most generic superclass for this type of item""" if base := getattr(cls, "__base__cache__", None): return base @@ -22,11 +24,19 @@ def __get_base__(cls: Type) -> Type: class Record: - """Associate types with entries""" + """Associate types with entries + + A record is a composition of items; each item base class is unique. + """ items: Items - def __init__(self, *items: Union[Items, T], override=False, cls=None): + unpickled: bool = False + """Flags unpickled records""" + + def __init__( + self, *items: Union[Items, T], override=False, unpickled=False, cls=None + ): self.items = {} if len(items) == 1 and isinstance(items[0], dict): @@ -43,13 +53,16 @@ def __init__(self, *items: Union[Items, T], override=False, cls=None): ) self.items[base] = entry - self.validate(cls or self.__class__) + if unpickled: + self.unpickled = True + else: + self.validate(cls or self.__class__) - def __new__(cls, *items: Union[Items, T], override=False): + def __new__(cls, *items: Union[Items, T], override=False, unpickled=False): # Without this, impossible to pickle objects if cls.__trueclass__ is not None: record = object.__new__(cls.__trueclass__) - record.__init__(*items, cls=cls, override=override) + record.__init__(*items, cls=cls, override=override, unpickled=unpickled) return record return object.__new__(cls) @@ -61,6 +74,13 @@ def __str__(self): + "}" ) + def __getstate__(self): + return self.items + + def __setstate__(self, state): + self.unpickled = True + self.items = state + def validate(self, cls: Type["Record"] = None): """Validate the record""" cls = cls if cls is not None else self.__class__ @@ -115,11 +135,13 @@ def update(self, *items: T) -> "Record": # --- Class methods and variables - itemtypes: ClassVar[List[Type[T]]] = [] - """For specific records, this is the list of types""" + itemtypes: ClassVar[Set[Type[T]]] = [] + """For specific records, this is the list of types. The list is empty when + no validation is used (e.g. pickled records created on the fly)""" __trueclass__: ClassVar[Optional[Type["Record"]]] = None - """True when the class is defined in a module""" + """The last class in the type hierarchy corresponding to an actual type, + i.e. not created on the fly""" @classmethod def has_type(cls, itemtype: Type[T]): @@ -127,7 +149,7 @@ def has_type(cls, itemtype: Type[T]): @classmethod def _subclass(cls, *itemtypes: Type[T]): - cls_itemtypes = [x for x in getattr(cls, "itemtypes", [])] + cls_itemtypes = set((x for x in getattr(cls, "itemtypes", []))) mapping = { ix: itemtype.__get_base__() for ix, itemtype in enumerate(cls_itemtypes) } @@ -136,7 +158,7 @@ def _subclass(cls, *itemtypes: Type[T]): if ix := mapping.get(itemtype.__get_base__(), None): cls_itemtypes[ix] = itemtype else: - cls_itemtypes.append(itemtype) + cls_itemtypes.add(itemtype) return cls_itemtypes @classmethod @@ -144,6 +166,7 @@ def from_types(cls, name: str, *itemtypes: Type[T], module: str = None): extra_dict = {} if module: extra_dict["__module__"] = module + return type( name, (cls,), @@ -176,8 +199,9 @@ def __init__(self, name: str, *itemtypes: Type[T], module: str = None): self._name = name self._itemtypes = itemtypes self._cache: Dict[Type[Record], Type[Record]] = {} + self._unpickled_warnings = False - def __getitem__(self, record_type: Type[Record]): + def get(self, record_type: Type[Record], unpickled=False): if updated_type := self._cache.get(record_type, None): return updated_type @@ -190,4 +214,14 @@ def __getitem__(self, record_type: Type[Record]): return updated_type def update(self, record: Record, *items: Item): - return self[record.__class__](*record.items.values(), *items, override=True) + if record.unpickled and not self._unpickled_warnings: + # In that case, impossible to recover the hierarchy + logging.warning( + "Updating a pickled record is not recommended:" + " prefer using well defined record types" + ) + self._unpickled_warnings = True + + return self.get(record.__class__)( + *record.items.values(), *items, override=True, unpickled=record.unpickled + ) diff --git a/src/datamaestro/test/test_record.py b/src/datamaestro/test/test_record.py index abd3c88..ef84aea 100644 --- a/src/datamaestro/test/test_record.py +++ b/src/datamaestro/test/test_record.py @@ -1,3 +1,4 @@ +import pickle from datamaestro.record import Record, Item, RecordTypesCache, recordtypes from attrs import define import pytest @@ -23,7 +24,11 @@ class CItem(Item): c: int -class MyRecord(Record): +class BaseRecord(Record): + itemtypes = [A1Item] + + +class MyRecord(BaseRecord): itemtypes = [A1Item, BItem] @@ -74,14 +79,14 @@ def test_record_newtype(): def test_record_onthefly(): cache = RecordTypesCache("OnTheFly", CItem) - MyRecord2 = cache[MyRecord] + MyRecord2 = cache.get(MyRecord) r2 = MyRecord2(A1Item(1, 2), BItem(2), CItem(3)) assert r2.__class__ is MyRecord - assert cache[MyRecord] is MyRecord2 + assert cache.get(MyRecord) is MyRecord2 r = MyRecord(A1Item(1, 2), BItem(2)) - assert cache[r.__class__] is MyRecord2 + assert cache.get(r.__class__) is MyRecord2 r = cache.update(r, CItem(3)) @@ -89,3 +94,16 @@ def test_record_onthefly(): cache2 = RecordTypesCache("OnTheFly", CItem) cache2.update(r, CItem(4)) + + +def test_record_pickled(): + # First, + MyRecord2 = BaseRecord.from_types("MyRecordBis", BItem) + r = MyRecord2(A1Item(1, 2), BItem(2)) + r = pickle.loads(pickle.dumps(r)) + + assert isinstance(r, BaseRecord) and not isinstance(r, MyRecord2) + assert r.unpickled + cache = RecordTypesCache("OnTheFly", CItem) + + cache.update(r, CItem(4))