Skip to content

Commit

Permalink
refactor: use sets
Browse files Browse the repository at this point in the history
  • Loading branch information
bpiwowar committed Mar 4, 2024
1 parent 97d8b79 commit bd7593c
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 17 deletions.
60 changes: 47 additions & 13 deletions src/datamaestro/record.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import ClassVar, Type, TypeVar, Dict, List, Union, Optional
import logging
from typing import ClassVar, Type, TypeVar, Dict, List, Union, Optional, Set


class Item:
"""Base class for all item types"""

@classmethod
def __get_base__(cls: Type) -> Type:
"""Get the most generic superclass for this type of item"""
if base := getattr(cls, "__base__cache__", None):
return base

Expand All @@ -22,11 +24,19 @@ def __get_base__(cls: Type) -> Type:


class Record:
"""Associate types with entries"""
"""Associate types with entries
A record is a composition of items; each item base class is unique.
"""

items: Items

def __init__(self, *items: Union[Items, T], override=False, cls=None):
unpickled: bool = False
"""Flags unpickled records"""

def __init__(
self, *items: Union[Items, T], override=False, unpickled=False, cls=None
):
self.items = {}

if len(items) == 1 and isinstance(items[0], dict):
Expand All @@ -43,13 +53,16 @@ def __init__(self, *items: Union[Items, T], override=False, cls=None):
)
self.items[base] = entry

self.validate(cls or self.__class__)
if unpickled:
self.unpickled = True
else:
self.validate(cls or self.__class__)

def __new__(cls, *items: Union[Items, T], override=False):
def __new__(cls, *items: Union[Items, T], override=False, unpickled=False):
# Without this, impossible to pickle objects
if cls.__trueclass__ is not None:
record = object.__new__(cls.__trueclass__)
record.__init__(*items, cls=cls, override=override)
record.__init__(*items, cls=cls, override=override, unpickled=unpickled)
return record

return object.__new__(cls)
Expand All @@ -61,6 +74,13 @@ def __str__(self):
+ "}"
)

def __getstate__(self):
return self.items

def __setstate__(self, state):
self.unpickled = True
self.items = state

def validate(self, cls: Type["Record"] = None):
"""Validate the record"""
cls = cls if cls is not None else self.__class__
Expand Down Expand Up @@ -115,19 +135,21 @@ def update(self, *items: T) -> "Record":

# --- Class methods and variables

itemtypes: ClassVar[List[Type[T]]] = []
"""For specific records, this is the list of types"""
itemtypes: ClassVar[Set[Type[T]]] = []
"""For specific records, this is the list of types. The list is empty when
no validation is used (e.g. pickled records created on the fly)"""

__trueclass__: ClassVar[Optional[Type["Record"]]] = None
"""True when the class is defined in a module"""
"""The last class in the type hierarchy corresponding to an actual type,
i.e. not created on the fly"""

@classmethod
def has_type(cls, itemtype: Type[T]):
return any(issubclass(cls_itemtype, itemtype) for cls_itemtype in cls.itemtypes)

@classmethod
def _subclass(cls, *itemtypes: Type[T]):
cls_itemtypes = [x for x in getattr(cls, "itemtypes", [])]
cls_itemtypes = set((x for x in getattr(cls, "itemtypes", [])))
mapping = {
ix: itemtype.__get_base__() for ix, itemtype in enumerate(cls_itemtypes)
}
Expand All @@ -136,14 +158,15 @@ def _subclass(cls, *itemtypes: Type[T]):
if ix := mapping.get(itemtype.__get_base__(), None):
cls_itemtypes[ix] = itemtype
else:
cls_itemtypes.append(itemtype)
cls_itemtypes.add(itemtype)
return cls_itemtypes

@classmethod
def from_types(cls, name: str, *itemtypes: Type[T], module: str = None):
extra_dict = {}
if module:
extra_dict["__module__"] = module

return type(
name,
(cls,),
Expand Down Expand Up @@ -176,8 +199,9 @@ def __init__(self, name: str, *itemtypes: Type[T], module: str = None):
self._name = name
self._itemtypes = itemtypes
self._cache: Dict[Type[Record], Type[Record]] = {}
self._unpickled_warnings = False

def __getitem__(self, record_type: Type[Record]):
def get(self, record_type: Type[Record], unpickled=False):
if updated_type := self._cache.get(record_type, None):
return updated_type

Expand All @@ -190,4 +214,14 @@ def __getitem__(self, record_type: Type[Record]):
return updated_type

def update(self, record: Record, *items: Item):
return self[record.__class__](*record.items.values(), *items, override=True)
if record.unpickled and not self._unpickled_warnings:
# In that case, impossible to recover the hierarchy
logging.warning(
"Updating a pickled record is not recommended:"
" prefer using well defined record types"
)
self._unpickled_warnings = True

return self.get(record.__class__)(
*record.items.values(), *items, override=True, unpickled=record.unpickled
)
26 changes: 22 additions & 4 deletions src/datamaestro/test/test_record.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pickle
from datamaestro.record import Record, Item, RecordTypesCache, recordtypes
from attrs import define
import pytest
Expand All @@ -23,7 +24,11 @@ class CItem(Item):
c: int


class MyRecord(Record):
class BaseRecord(Record):
itemtypes = [A1Item]


class MyRecord(BaseRecord):
itemtypes = [A1Item, BItem]


Expand Down Expand Up @@ -74,18 +79,31 @@ def test_record_newtype():
def test_record_onthefly():
cache = RecordTypesCache("OnTheFly", CItem)

MyRecord2 = cache[MyRecord]
MyRecord2 = cache.get(MyRecord)
r2 = MyRecord2(A1Item(1, 2), BItem(2), CItem(3))
assert r2.__class__ is MyRecord

assert cache[MyRecord] is MyRecord2
assert cache.get(MyRecord) is MyRecord2

r = MyRecord(A1Item(1, 2), BItem(2))
assert cache[r.__class__] is MyRecord2
assert cache.get(r.__class__) is MyRecord2

r = cache.update(r, CItem(3))

# Same record type
cache2 = RecordTypesCache("OnTheFly", CItem)

cache2.update(r, CItem(4))


def test_record_pickled():
# First,
MyRecord2 = BaseRecord.from_types("MyRecordBis", BItem)
r = MyRecord2(A1Item(1, 2), BItem(2))
r = pickle.loads(pickle.dumps(r))

assert isinstance(r, BaseRecord) and not isinstance(r, MyRecord2)
assert r.unpickled
cache = RecordTypesCache("OnTheFly", CItem)

cache.update(r, CItem(4))

0 comments on commit bd7593c

Please sign in to comment.