From 71164972af31936cb56dee872136700432c2e556 Mon Sep 17 00:00:00 2001 From: Jace Browning Date: Sun, 18 Apr 2021 14:28:08 -0400 Subject: [PATCH] Loop through all converter subclasses when resolving string annotations --- CHANGELOG.md | 4 ++++ datafiles/converters/__init__.py | 4 ++-- datafiles/tests/test_converters.py | 5 +++++ datafiles/utils.py | 6 ++++++ 4 files changed, 17 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 39ff7c33..80bdb1ee 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 0.13.1 (2021-04-18) + +- Fixed handling of string annotations for extended types. + # 0.13 (2021-04-17) - Added support for generic types. diff --git a/datafiles/converters/__init__.py b/datafiles/converters/__init__.py index 312476d9..ea7a7329 100644 --- a/datafiles/converters/__init__.py +++ b/datafiles/converters/__init__.py @@ -6,7 +6,7 @@ import log from ruamel.yaml.scalarfloat import ScalarFloat -from ..utils import cached +from ..utils import cached, subclasses from ._bases import Converter from .builtins import Boolean, Float, Integer, String from .containers import Dataclass, Dictionary, List, Set @@ -119,7 +119,7 @@ def map_type(cls, *, name: str = '', item_cls: Optional[type] = None): if isinstance(cls, str): log.debug(f'Searching for class matching {cls!r} annotation') - for cls2 in Converter.__subclasses__(): + for cls2 in subclasses(Converter): if cls2.__name__ == cls: register(cls, cls2) log.debug(f'Registered {cls2} as new converter') diff --git a/datafiles/tests/test_converters.py b/datafiles/tests/test_converters.py index 575e3584..f5fbdaf4 100644 --- a/datafiles/tests/test_converters.py +++ b/datafiles/tests/test_converters.py @@ -112,6 +112,11 @@ def it_handles_string_type_annotations(expect): converter = converters.map_type('float') expect(converter.TYPE) == float + def it_handles_string_type_annotations_for_extensions(expect): + converter = converters.map_type('Number') + expect(converter.TYPE) == float + expect(converter.__name__) == 'Number' + def it_rejects_unknown_types(expect): with expect.raises( TypeError, diff --git a/datafiles/utils.py b/datafiles/utils.py index e65d88fe..a8d47386 100644 --- a/datafiles/utils.py +++ b/datafiles/utils.py @@ -18,6 +18,12 @@ cached = lru_cache() +def subclasses(cls): + return set(cls.__subclasses__()).union( + [s for c in cls.__subclasses__() for s in subclasses(c)] + ) + + def get_default_field_value(instance, name): for field in dataclasses.fields(instance): if field.name == name: