Skip to content

Commit

Permalink
fix: record with a sub-class item
Browse files Browse the repository at this point in the history
  • Loading branch information
bpiwowar committed Mar 5, 2024
1 parent 2f35592 commit 082c1e7
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/datamaestro/record.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import ClassVar, Type, TypeVar, Dict, List, Union, Optional, Set
from typing import ClassVar, Type, TypeVar, Dict, List, Union, Optional, FrozenSet


class Item:
Expand Down Expand Up @@ -131,7 +131,7 @@ def update(self, *items: T) -> "Record":

# --- Class methods and variables

itemtypes: ClassVar[Optional[Set[Type[T]]]] = []
itemtypes: ClassVar[Optional[FrozenSet[Type[T]]]] = []
"""For specific records, this is the list of types. The value is null when
no validation is used (e.g. pickled records created on the fly)"""

Expand Down Expand Up @@ -182,7 +182,7 @@ def from_types(cls, name: str, *itemtypes: Type[T], module: str = None):
__RECORD_TYPES_CACHE__: Dict[frozenset, Type["Record"]] = {}

@staticmethod
def fromitemtypes(itemtypes: Set[T]):
def fromitemtypes(itemtypes: FrozenSet[T]):
if recordtype := Record.__RECORD_TYPES_CACHE__.get(itemtypes, None):
return recordtype

Expand Down Expand Up @@ -227,7 +227,7 @@ def get(self, record_type: Type[Record]):

updated_type = record_type.from_types(
f"{self._name}_{record_type.__name__}",
*self._itemtypes,
*(itemtype.__get_base__() for itemtype in self._itemtypes),
module=self._module,
)
self._cache[record_type] = updated_type
Expand Down
9 changes: 9 additions & 0 deletions src/datamaestro/test/test_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ class BItem(Item):
b: int


@define
class B1Item(BItem):
b1: int


@define
class CItem(Item):
c: int
Expand Down Expand Up @@ -110,3 +115,7 @@ def test_record_pickled():

# This is OK
cache.update(r, CItem(4), cls=MyRecord)

# --- Test when we update a pickled record with an of a sub-class
cache = RecordTypesCache("OnTheFly", B1Item)
r2 = cache.update(r, B1Item(1, 2))

0 comments on commit 082c1e7

Please sign in to comment.