From 29714bb258c1e876aa34569ce749c5f5127ec579 Mon Sep 17 00:00:00 2001 From: Dima Burmistrov Date: Tue, 19 Aug 2025 01:30:31 +0400 Subject: [PATCH] Implement class member type auto-discovery for inject.attr --- src/inject/__init__.py | 21 ++++++++++++++++++++- test/test_attr.py | 38 +++++++++++++++++++++----------------- 2 files changed, 41 insertions(+), 18 deletions(-) diff --git a/src/inject/__init__.py b/src/inject/__init__.py index 88db0d1..8517a16 100644 --- a/src/inject/__init__.py +++ b/src/inject/__init__.py @@ -94,6 +94,7 @@ def my_config(binder): else: _HAS_PEP560_SUPPORT = sys.version_info[:3] >= (3, 7, 0) # PEP 560 _RETURN = 'return' +_MISSING = object() if _HAS_PEP604_SUPPORT: from types import UnionType @@ -325,6 +326,10 @@ def __init__(self, cls: Type[T] | Hashable) -> None: doc="Return an attribute injection", ) + def __set_name__(self, owner: Type[T], name: str) -> None: + if self._cls is _MISSING: + self._cls = _unwrap_cls_annotation(owner, name) + class _ParameterInjection(Generic[T]): __slots__ = ('_name', '_cls') @@ -522,13 +527,16 @@ def instance(cls: Binding) -> Injectable: """Inject an instance of a class.""" return get_injector_or_die().get_instance(cls) +@overload +def attr() -> Injectable: ... + @overload def attr(cls: Hashable) -> Injectable: ... @overload def attr(cls: Type[T]) -> T: ... -def attr(cls): +def attr(cls=_MISSING): """Return an attribute injection (descriptor).""" return _AttributeInjection(cls) @@ -653,3 +661,14 @@ def _is_union_type(typ): return (typ is Union or isinstance(typ, _GenericAlias) and typ.__origin__ is Union) return type(typ) is _Union + + +def _unwrap_cls_annotation(cls: Type, attr_name: str): + types = get_type_hints(cls) + try: + attr_type = types[attr_name] + except KeyError: + msg = f"Couldn't find type annotation for {attr_name}" + raise InjectorException(msg) + + return _unwrap_union_arg(attr_type) diff --git a/test/test_attr.py b/test/test_attr.py index 93cbbe3..6f4fb53 100644 --- a/test/test_attr.py +++ b/test/test_attr.py @@ -12,19 +12,20 @@ class MyDataClass: class MyClass: field = inject.attr(int) field2: int = inject.attr(int) + auto_typed_field: int = inject.attr() inject.configure(lambda binder: binder.bind(int, 123)) my = MyClass() my_dc = MyDataClass() - value0 = my.field - value1 = my.field - value2 = my_dc.field - value3 = my_dc.field - assert value0 == 123 - assert value1 == 123 - assert value2 == 123 - assert value3 == 123 + assert my.field == 123 + assert my.field == 123 + assert my.field2 == 123 + assert my.field2 == 123 + assert my.auto_typed_field == 123 + assert my.auto_typed_field == 123 + assert my_dc.field == 123 + assert my_dc.field == 123 def test_invalid_attachment_to_dataclass(self): @dataclass @@ -36,6 +37,7 @@ class MyDataClass: def test_class_attr(self): descriptor = inject.attr(int) + auto_descriptor = inject.attr() @dataclass class MyDataClass: @@ -43,14 +45,16 @@ class MyDataClass: class MyClass(object): field = descriptor + field2: int = descriptor + auto_typed_field: int = auto_descriptor inject.configure(lambda binder: binder.bind(int, 123)) - value0 = MyClass.field - value1 = MyClass.field - value2 = MyDataClass.field - value3 = MyDataClass.field - - assert value0 is descriptor - assert value1 is descriptor - assert value2 is descriptor - assert value3 is descriptor + + assert MyClass.field is descriptor + assert MyClass.field is descriptor + assert MyClass.field2 is descriptor + assert MyClass.field2 is descriptor + assert MyClass.auto_typed_field is auto_descriptor + assert MyClass.auto_typed_field is auto_descriptor + assert MyDataClass.field is descriptor + assert MyDataClass.field is descriptor