Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/inject/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def my_config(binder):
else:
_HAS_PEP560_SUPPORT = sys.version_info[:3] >= (3, 7, 0) # PEP 560
_RETURN = 'return'
_MISSING = object()

if _HAS_PEP604_SUPPORT:
from types import UnionType
Expand Down Expand Up @@ -325,6 +326,10 @@ def __init__(self, cls: Type[T] | Hashable) -> None:
doc="Return an attribute injection",
)

def __set_name__(self, owner: Type[T], name: str) -> None:
if self._cls is _MISSING:
self._cls = _unwrap_cls_annotation(owner, name)


class _ParameterInjection(Generic[T]):
__slots__ = ('_name', '_cls')
Expand Down Expand Up @@ -522,13 +527,16 @@ def instance(cls: Binding) -> Injectable:
"""Inject an instance of a class."""
return get_injector_or_die().get_instance(cls)

@overload
def attr() -> Injectable: ...

@overload
def attr(cls: Hashable) -> Injectable: ...

@overload
def attr(cls: Type[T]) -> T: ...

def attr(cls):
def attr(cls=_MISSING):
"""Return an attribute injection (descriptor)."""
return _AttributeInjection(cls)

Expand Down Expand Up @@ -653,3 +661,14 @@ def _is_union_type(typ):
return (typ is Union or
isinstance(typ, _GenericAlias) and typ.__origin__ is Union)
return type(typ) is _Union


def _unwrap_cls_annotation(cls: Type, attr_name: str):
types = get_type_hints(cls)
try:
attr_type = types[attr_name]
except KeyError:
msg = f"Couldn't find type annotation for {attr_name}"
raise InjectorException(msg)

return _unwrap_union_arg(attr_type)
38 changes: 21 additions & 17 deletions test/test_attr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,20 @@ class MyDataClass:
class MyClass:
field = inject.attr(int)
field2: int = inject.attr(int)
auto_typed_field: int = inject.attr()

inject.configure(lambda binder: binder.bind(int, 123))
my = MyClass()
my_dc = MyDataClass()
value0 = my.field
value1 = my.field
value2 = my_dc.field
value3 = my_dc.field

assert value0 == 123
assert value1 == 123
assert value2 == 123
assert value3 == 123
assert my.field == 123
assert my.field == 123
assert my.field2 == 123
assert my.field2 == 123
assert my.auto_typed_field == 123
assert my.auto_typed_field == 123
assert my_dc.field == 123
assert my_dc.field == 123

def test_invalid_attachment_to_dataclass(self):
@dataclass
Expand All @@ -36,21 +37,24 @@ class MyDataClass:

def test_class_attr(self):
descriptor = inject.attr(int)
auto_descriptor = inject.attr()

@dataclass
class MyDataClass:
field = descriptor

class MyClass(object):
field = descriptor
field2: int = descriptor
auto_typed_field: int = auto_descriptor

inject.configure(lambda binder: binder.bind(int, 123))
value0 = MyClass.field
value1 = MyClass.field
value2 = MyDataClass.field
value3 = MyDataClass.field

assert value0 is descriptor
assert value1 is descriptor
assert value2 is descriptor
assert value3 is descriptor

assert MyClass.field is descriptor
assert MyClass.field is descriptor
assert MyClass.field2 is descriptor
assert MyClass.field2 is descriptor
assert MyClass.auto_typed_field is auto_descriptor
assert MyClass.auto_typed_field is auto_descriptor
assert MyDataClass.field is descriptor
assert MyDataClass.field is descriptor