diff --git a/src/datamaestro/record.py b/src/datamaestro/record.py index 7f85f6a..23d7324 100644 --- a/src/datamaestro/record.py +++ b/src/datamaestro/record.py @@ -1,5 +1,5 @@ import logging -from typing import ClassVar, Type, TypeVar, Dict, List, Union, Optional, Set +from typing import ClassVar, Type, TypeVar, Dict, List, Union, Optional, FrozenSet class Item: @@ -131,7 +131,7 @@ def update(self, *items: T) -> "Record": # --- Class methods and variables - itemtypes: ClassVar[Optional[Set[Type[T]]]] = [] + itemtypes: ClassVar[Optional[FrozenSet[Type[T]]]] = [] """For specific records, this is the list of types. The value is null when no validation is used (e.g. pickled records created on the fly)""" @@ -182,7 +182,7 @@ def from_types(cls, name: str, *itemtypes: Type[T], module: str = None): __RECORD_TYPES_CACHE__: Dict[frozenset, Type["Record"]] = {} @staticmethod - def fromitemtypes(itemtypes: Set[T]): + def fromitemtypes(itemtypes: FrozenSet[T]): if recordtype := Record.__RECORD_TYPES_CACHE__.get(itemtypes, None): return recordtype @@ -227,7 +227,7 @@ def get(self, record_type: Type[Record]): updated_type = record_type.from_types( f"{self._name}_{record_type.__name__}", - *self._itemtypes, + *(itemtype.__get_base__() for itemtype in self._itemtypes), module=self._module, ) self._cache[record_type] = updated_type diff --git a/src/datamaestro/test/test_record.py b/src/datamaestro/test/test_record.py index e564fa5..8e5725d 100644 --- a/src/datamaestro/test/test_record.py +++ b/src/datamaestro/test/test_record.py @@ -19,6 +19,11 @@ class BItem(Item): b: int +@define +class B1Item(BItem): + b1: int + + @define class CItem(Item): c: int @@ -110,3 +115,7 @@ def test_record_pickled(): # This is OK cache.update(r, CItem(4), cls=MyRecord) + + # --- Test when we update a pickled record with an of a sub-class + cache = RecordTypesCache("OnTheFly", B1Item) + r2 = cache.update(r, B1Item(1, 2))