Skip to content

Commit

Permalink
fix: bug fixes and new single record type cache
Browse files Browse the repository at this point in the history
  • Loading branch information
bpiwowar committed Mar 5, 2024
1 parent 082c1e7 commit b3be147
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 27 deletions.
3 changes: 3 additions & 0 deletions docs/source/api/records.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,6 @@ information can be easily added to a record.

.. autoclass:: datamaestro.record.RecordTypesCache
:members: __init__, update

.. autoclass:: datamaestro.record.SingleRecordTypeCache
:members: __init__, update
93 changes: 70 additions & 23 deletions src/datamaestro/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,17 +145,18 @@ def has_type(cls, itemtype: Type[T]):

@classmethod
def _subclass(cls, *itemtypes: Type[T]):
cls_itemtypes = set((x for x in getattr(cls, "itemtypes", [])))
cls_itemtypes = [x for x in getattr(cls, "itemtypes", [])]
mapping = {
ix: itemtype.__get_base__() for ix, itemtype in enumerate(cls_itemtypes)
itemtype.__get_base__(): ix for ix, itemtype in enumerate(cls_itemtypes)
}

for itemtype in itemtypes:
if ix := mapping.get(itemtype.__get_base__(), None):
if (ix := mapping.get(itemtype.__get_base__(), -1)) >= 0:
cls_itemtypes[ix] = itemtype
else:
cls_itemtypes.add(itemtype)
return cls_itemtypes
cls_itemtypes.append(itemtype)

return frozenset(cls_itemtypes)

@classmethod
def from_types(cls, name: str, *itemtypes: Type[T], module: str = None):
Expand Down Expand Up @@ -205,7 +206,22 @@ def decorate(cls: Type[Record]):
return decorate


class RecordTypesCache:
class RecordTypesCacheBase:
def __init__(self, name: str, *itemtypes: Type[T], module: str = None):
self._module = module
self._name = name
self._itemtypes = itemtypes

def _compute(self, record_type: Type[Record]):
updated_type = record_type.from_types(
f"{self._name}_{record_type.__name__}",
*self._itemtypes,
module=self._module,
)
return updated_type


class RecordTypesCache(RecordTypesCacheBase):
"""Class to use when new record types need to be created on the fly by
adding new items"""

Expand All @@ -215,22 +231,13 @@ def __init__(self, name: str, *itemtypes: Type[T], module: str = None):
:param name: Base name for new record types
:param module: The module name for new types, defaults to None
"""
self._module = module
self._name = name
self._itemtypes = itemtypes
super().__init__(name, *itemtypes, module=module)
self._cache: Dict[Type[Record], Type[Record]] = {}
self._warning = False

def get(self, record_type: Type[Record]):
if updated_type := self._cache.get(record_type, None):
return updated_type

updated_type = record_type.from_types(
f"{self._name}_{record_type.__name__}",
*(itemtype.__get_base__() for itemtype in self._itemtypes),
module=self._module,
)
self._cache[record_type] = updated_type
def __call__(self, record_type: Type[Record]):
if (updated_type := self._cache.get(record_type, None)) is None:
self._cache[record_type] = updated_type = self._compute(record_type)
return updated_type

def update(self, record: Record, *items: Item, cls=None):
Expand All @@ -248,13 +255,53 @@ def update(self, record: Record, *items: Item, cls=None):
"Updating unpickled records is not recommended"
" (speed issues): use the pickle record class as the cls input"
)
itemtypes = frozenset(
itemtype.__get_base__() for itemtype in record.items.keys()
)
itemtypes = frozenset(type(item) for item in record.items.values())
cls = Record.fromitemtypes(itemtypes)
else:
assert (
record.is_pickled()
), "cls can be used only when the record as been pickled"

return self.get(cls)(*record.items.values(), *items, override=True)
return self(cls)(*record.items.values(), *items, override=True)


class SingleRecordTypeCache(RecordTypesCacheBase):
"""Class to use when new record types need to be created on the fly by
adding new items
This class supposes that the input record type is always the same (no check
is done to ensure this)"""

def __init__(self, name: str, *itemtypes: Type[T], module: str = None):
"""Creates a new cache
:param name: Base name for new record types
:param module: The module name for new types, defaults to None
"""
super().__init__(name, *itemtypes, module=module)
self._cache: Optional[Type[Record]] = None

def __call__(self, record_type: Type[Record]):
if self._cache is None:
self._cache = self._compute(record_type)
return self._cache

def update(self, record: Record, *items: Item, cls=None):
"""Update the record with the given items
:param record: The record to which we add items
:param cls: The class of the record, useful if the record has been
pickled, defaults to None
:return: A new record with the extra items
"""
if self._cache is None:
if cls is None:
cls = record.__class__
itemtypes = frozenset(type(item) for item in record.items.values())
cls = Record.fromitemtypes(itemtypes)
else:
assert (
record.is_pickled()
), "cls can be used only when the record as been pickled"

return self(cls)(*record.items.values(), *items, override=True)
34 changes: 30 additions & 4 deletions src/datamaestro/test/test_record.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import pickle
from datamaestro.record import Record, Item, RecordTypesCache, recordtypes
from datamaestro.record import (
Record,
Item,
RecordTypesCache,
recordtypes,
SingleRecordTypeCache,
)
from attrs import define
import pytest

Expand Down Expand Up @@ -76,16 +82,21 @@ def test_record_decorator():
MyRecord2(A1Item(1, 2), BItem(2), CItem(3))


def test_record_type_update():
itemtypes = MyRecord2.from_types("Test", B1Item).itemtypes
assert itemtypes == frozenset((A1Item, B1Item, CItem))


def test_record_onthefly():
cache = RecordTypesCache("OnTheFly", CItem)

MyRecord2 = cache.get(MyRecord)
MyRecord2 = cache(MyRecord)
MyRecord2(A1Item(1, 2), BItem(2), CItem(3))

assert cache.get(MyRecord) is MyRecord2
assert cache(MyRecord) is MyRecord2

r = MyRecord(A1Item(1, 2), BItem(2))
assert cache.get(r.__class__) is MyRecord2
assert cache(r.__class__) is MyRecord2

r = cache.update(r, CItem(3))

Expand Down Expand Up @@ -119,3 +130,18 @@ def test_record_pickled():
# --- Test when we update a pickled record with an of a sub-class
cache = RecordTypesCache("OnTheFly", B1Item)
r2 = cache.update(r, B1Item(1, 2))


def test_record_pickled_single():
MyRecord2 = BaseRecord.from_types("MyRecordBis", BItem)
r = MyRecord2(A1Item(1, 2), BItem(2))
r = pickle.loads(pickle.dumps(r))

cache = SingleRecordTypeCache("OnTheFly", CItem)

updated = cache.update(r, CItem(4))

assert updated.itemtypes == frozenset((A1Item, BItem, CItem))

# Even with the wrong record, no change now
assert cache(BaseRecord).itemtypes == frozenset((A1Item, BItem, CItem))

0 comments on commit b3be147

Please sign in to comment.