Skip to content

Commit

Permalink
refactor: pickled support with __reduce__
Browse files Browse the repository at this point in the history
  • Loading branch information
bpiwowar committed Mar 4, 2024
1 parent bd7593c commit c621ec6
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 47 deletions.
61 changes: 28 additions & 33 deletions src/datamaestro/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,7 @@ class Record:

items: Items

unpickled: bool = False
"""Flags unpickled records"""

def __init__(
self, *items: Union[Items, T], override=False, unpickled=False, cls=None
):
def __init__(self, *items: Union[Items, T], override=False, pickled=False):
self.items = {}

if len(items) == 1 and isinstance(items[0], dict):
Expand All @@ -53,19 +48,10 @@ def __init__(
)
self.items[base] = entry

if unpickled:
self.unpickled = True
if pickled:
self.itemtypes = None
else:
self.validate(cls or self.__class__)

def __new__(cls, *items: Union[Items, T], override=False, unpickled=False):
# Without this, impossible to pickle objects
if cls.__trueclass__ is not None:
record = object.__new__(cls.__trueclass__)
record.__init__(*items, cls=cls, override=override, unpickled=unpickled)
return record

return object.__new__(cls)
self.validate()

def __str__(self):
return (
Expand All @@ -74,12 +60,20 @@ def __str__(self):
+ "}"
)

def __getstate__(self):
return self.items
def __reduce__(self):
cls = self.__class__
if cls.__trueclass__ is None:
return (cls.__new__, (cls.__trueclass__ or cls,), {"items": self.items})

return (
cls.__new__,
(cls.__trueclass__ or cls,),
{"items": self.items, "itemtypes": self.itemtypes},
)

def __setstate__(self, state):
self.unpickled = True
self.items = state
self.items = state["items"]
self.itemtypes = None

def validate(self, cls: Type["Record"] = None):
"""Validate the record"""
Expand Down Expand Up @@ -124,6 +118,9 @@ def __getitem__(self, key: Type[T]) -> T:
raise KeyError(f"No entry with type {key}")
return entry

def is_pickled(self):
return self.itemtypes is None

def update(self, *items: T) -> "Record":
"""Update some items"""
# Create our new dictionary
Expand All @@ -135,13 +132,13 @@ def update(self, *items: T) -> "Record":

# --- Class methods and variables

itemtypes: ClassVar[Set[Type[T]]] = []
itemtypes: ClassVar[Optional[Set[Type[T]]]] = []
"""For specific records, this is the list of types. The list is empty when
no validation is used; pickled records created on the fly have it set to None instead"""

__trueclass__: ClassVar[Optional[Type["Record"]]] = None
"""The last class in the type hierarchy corresponding to an actual type,
i.e. not created on the fly"""
i.e. not created on the fly (only defined when the record is pickled)"""

@classmethod
def has_type(cls, itemtype: Type[T]):
Expand Down Expand Up @@ -172,8 +169,8 @@ def from_types(cls, name: str, *itemtypes: Type[T], module: str = None):
(cls,),
{
**extra_dict,
"__trueclass__": cls.__trueclass__ or cls,
"itemtypes": cls._subclass(*itemtypes),
"__trueclass__": cls.__trueclass__ or cls,
},
)

Expand All @@ -199,9 +196,9 @@ def __init__(self, name: str, *itemtypes: Type[T], module: str = None):
self._name = name
self._itemtypes = itemtypes
self._cache: Dict[Type[Record], Type[Record]] = {}
self._unpickled_warnings = False
self._warning = False

def get(self, record_type: Type[Record], unpickled=False):
def get(self, record_type: Type[Record]):
if updated_type := self._cache.get(record_type, None):
return updated_type

Expand All @@ -214,14 +211,12 @@ def get(self, record_type: Type[Record], unpickled=False):
return updated_type

def update(self, record: Record, *items: Item):
if record.unpickled and not self._unpickled_warnings:
# In that case, impossible to recover the hierarchy
if record.is_pickled() and not self._warning:
logging.warning(
"Updating a pickled record is not recommended:"
" prefer using well defined record types"
"Updating unpickled records is not recommended"
" (no more record checking, and potential speed issues)"
)
self._unpickled_warnings = True

return self.get(record.__class__)(
*record.items.values(), *items, override=True, unpickled=record.unpickled
*record.items.values(), *items, override=True, pickled=record.is_pickled()
)
26 changes: 12 additions & 14 deletions src/datamaestro/test/test_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@ class CItem(Item):
c: int


@recordtypes(A1Item)
class BaseRecord(Record):
itemtypes = [A1Item]
...


@recordtypes(BItem)
class MyRecord(BaseRecord):
itemtypes = [A1Item, BItem]
...


@recordtypes(CItem)
Expand All @@ -48,6 +50,7 @@ def test_record_simple():

def test_record_missing_init():
with pytest.raises(KeyError):
# A1Item is missing
MyRecord(AItem(1), BItem(2))

with pytest.raises(KeyError):
Expand All @@ -68,20 +71,11 @@ def test_record_decorator():
MyRecord2(A1Item(1, 2), BItem(2), CItem(3))


def test_record_newtype():
MyRecord2 = MyRecord.from_types("MyRecord2", CItem)
r = MyRecord2(A1Item(1, 2), BItem(2), CItem(3))

# For a dynamic class, we should have the same MyRecord type
assert r.__class__ is MyRecord


def test_record_onthefly():
cache = RecordTypesCache("OnTheFly", CItem)

MyRecord2 = cache.get(MyRecord)
r2 = MyRecord2(A1Item(1, 2), BItem(2), CItem(3))
assert r2.__class__ is MyRecord
MyRecord2(A1Item(1, 2), BItem(2), CItem(3))

assert cache.get(MyRecord) is MyRecord2

Expand All @@ -103,7 +97,11 @@ def test_record_pickled():
r = pickle.loads(pickle.dumps(r))

assert isinstance(r, BaseRecord) and not isinstance(r, MyRecord2)
assert r.unpickled
cache = RecordTypesCache("OnTheFly", CItem)

cache.update(r, CItem(4))
assert r.is_pickled()

r = cache.update(r, CItem(4))

# The updated record should still be flagged as pickled
assert r.is_pickled()

0 comments on commit c621ec6

Please sign in to comment.