Skip to content

Commit

Permalink
Merge pull request #9 from bojiang/spawn-issue
Browse files Browse the repository at this point in the history
refactor(container)!: add sync_container, container decorator
  • Loading branch information
bojiang committed Jul 7, 2021
2 parents 4c42dca + b05744e commit b42adfe
Show file tree
Hide file tree
Showing 6 changed files with 258 additions and 100 deletions.
9 changes: 7 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,17 @@ A simple, strictly typed dependency injection library.
Examples:

```python
from simple_di import inject, Provide, Provider
from simple_di import inject, Provide, Provider, container
from simple_di.providers import Static, Factory, Configuration


class Options(Container):
@container
class OptionsClass(container):
cpu: Provider[int] = Static(2)
worker: Provider[int] = Factory(lambda c: 2 * int(c) + 1, c=cpu)

Options = OptionsClass()

@inject
def func(worker: int = Provide[Options.worker]):
return worker
Expand All @@ -52,6 +55,8 @@ Examples:

## API

- [container](#container)
- [sync_container](#sync_container)
- [inject](#inject)
- [Provide](#Provide)
- [providers](#providers)
Expand Down
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,9 @@
"Programming Language :: Python :: Implementation :: CPython",
],
python_requires=">=3.6.1",
install_requires=[
'dataclasses; python_version < "3.7.0"',
'types-dataclasses; python_version < "3.7.0"',
],
extras_require={"test": ["pytest", "mypy"]},
)
68 changes: 44 additions & 24 deletions simple_di/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
'''
"""
A simple dependency injection framework
'''
"""
import dataclasses
import functools
import inspect
from typing import (
Expand All @@ -15,6 +16,7 @@
cast,
overload,
)
from typing_extensions import Protocol


class _SentinelClass:
Expand All @@ -28,10 +30,10 @@ class _SentinelClass:


class Provider(Generic[VT]):
'''
"""
The base class for Provider implementations. Could be used as the type annotations
of all the implementations.
'''
"""

STATE_FIELDS: Tuple[str, ...] = ("_override",)

Expand All @@ -42,25 +44,25 @@ def _provide(self) -> VT:
raise NotImplementedError

def set(self, value: Union[_SentinelClass, VT]) -> None:
'''
"""
set the value to this provider, overriding the original values
'''
"""
if isinstance(value, _SentinelClass):
return
self._override = value

def get(self) -> VT:
'''
"""
get the value of this provider
'''
"""
if not isinstance(self._override, _SentinelClass):
return self._override
return self._provide()

def reset(self) -> None:
'''
"""
remove the overriding and restore the original value
'''
"""
self._override = sentinel

def __getstate__(self) -> Dict[str, Any]:
Expand All @@ -72,10 +74,10 @@ def __setstate__(self, state: Dict[str, Any]) -> None:


class _ProvideClass:
'''
"""
Used as the default value of a injected functool/method. Would be replaced by the
final value of the provider when this function/method gets called.
'''
"""

def __getitem__(self, provider: Provider[VT]) -> VT:
return provider # type: ignore
Expand All @@ -100,6 +102,9 @@ def _inject_kwargs(


def _inject(func: WrappedCallable, squeeze_none: bool) -> WrappedCallable:
if getattr(func, "_is_injected", False):
return func

sig = inspect.signature(func)

@functools.wraps(func)
Expand All @@ -121,6 +126,7 @@ def _(

return func(*_inject_args(bind.args), **_inject_kwargs(bind.kwargs))

setattr(_, "_is_injected", True)
return cast(WrappedCallable, _)


Expand All @@ -139,29 +145,43 @@ def inject(
def inject(
func: Optional[WrappedCallable] = None, squeeze_none: bool = False
) -> Union[WrappedCallable, Callable[[WrappedCallable], WrappedCallable]]:
'''
"""
Used with `Provide`, inject values to provided defaults of the decorated
function/method when gets called.
'''
"""
if func is None:
wrapped = functools.partial(_inject, squeeze_none=squeeze_none)
return cast(Callable[[WrappedCallable], WrappedCallable], wrapped)
wrapper = functools.partial(_inject, squeeze_none=squeeze_none)
return cast(Callable[[WrappedCallable], WrappedCallable], wrapper)

if callable(func):
return _inject(func, squeeze_none=squeeze_none)

raise ValueError('You must pass either None or Callable')
raise ValueError("You must pass either None or Callable")


class Container:
'''
The base class of containers
'''
def sync_container(from_: Any, to_: Any) -> None:
for f in dataclasses.fields(to_):
src = f.default
target = getattr(from_, f.name, None)
if target is None:
continue
if isinstance(src, Provider):
src.__setstate__(target.__getstate__())
elif dataclasses.is_dataclass(src):
sync_container(src, target)

def __init__(self) -> None:
raise TypeError('Container should not be instantiated')

container = dataclasses.dataclass


skip = not_passed = sentinel

__all__ = ["Container", "Provider", "Provide", "inject", "not_passed", "skip"]
__all__ = [
"container",
"Provider",
"Provide",
"inject",
"not_passed",
"skip",
"sync_container",
]
91 changes: 63 additions & 28 deletions simple_di/providers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
'''
"""
Provider implementations
'''
from typing import Any, Callable as CallableType, Dict, Optional, Tuple, Union
"""
import importlib
from types import LambdaType, ModuleType
from typing import Any, Callable as CallableType, Dict, Tuple, Union

from simple_di import (
Provider,
Expand All @@ -25,11 +27,11 @@


class Static(Provider[VT]):
'''
"""
provider that returns static values
'''
"""

STATE_FIELDS = Provider.STATE_FIELDS + ('_value',)
STATE_FIELDS = Provider.STATE_FIELDS + ("_value",)

def __init__(self, value: VT):
super().__init__()
Expand All @@ -39,34 +41,67 @@ def _provide(self) -> VT:
return self._value


class Callable(Provider[VT]):
'''
def _probe_unique_name(module: ModuleType, origin_name: str) -> str:
name = "__simple_di_" + origin_name.replace(".", "_").replace("<lambda>", "lambda")
num = 0
while hasattr(module, f"{name}{num or ''}"):
num += 1
return f"{name}{num or ''}"


def _patch_anonymous(func: Any) -> None:
module_name = func.__module__
origin_name = func.__qualname__

module = importlib.import_module(module_name)
name = _probe_unique_name(module, origin_name)
func.__qualname__ = name
func.__name__ = name
setattr(module, name, func)


class Factory(Provider[VT]):
"""
provider that returns the result of a callable
'''
"""

STATE_FIELDS = Provider.STATE_FIELDS + ('_args', "_kwargs", "_func")
STATE_FIELDS = Provider.STATE_FIELDS + (
"_args",
"_kwargs",
"_func",
"_chain_inject",
)

def __init__(self, func: CallableType[..., VT], *args: Any, **kwargs: Any) -> None:
super().__init__()
self._args = args
self._kwargs = kwargs
self._func: CallableType[..., VT] = func
self._chain_inject = False
if isinstance(func, classmethod):
raise TypeError("Factory as decorator only supports static methods")
if isinstance(func, LambdaType):
_patch_anonymous(func)
if isinstance(func, staticmethod):
self._chain_inject = True
func = func.__func__
_patch_anonymous(func)
self._func = func

def _provide(self) -> VT:
return self._func(*_inject_args(self._args), **_inject_kwargs(self._kwargs))

def __get__(self, obj: Any, objtype: Any = None) -> "Callable[VT]":
if isinstance(self._func, (classmethod, staticmethod)):
self._func = inject(self._func.__get__(obj, objtype))
return self
if self._chain_inject:
return inject(self._func)(
*_inject_args(self._args), **_inject_kwargs(self._kwargs)
)
else:
return self._func(*_inject_args(self._args), **_inject_kwargs(self._kwargs))


class MemoizedCallable(Callable[VT]):
'''
class SingletonFactory(Factory[VT]):
"""
provider that returns the result of a callable, but memorize the returns.
'''
"""

STATE_FIELDS = Callable.STATE_FIELDS + ("_cache",)
STATE_FIELDS = Factory.STATE_FIELDS + ("_cache",)

def __init__(self, func: CallableType[..., VT], *args: Any, **kwargs: Any) -> None:
super().__init__(func, *args, **kwargs)
Expand All @@ -75,21 +110,21 @@ def __init__(self, func: CallableType[..., VT], *args: Any, **kwargs: Any) -> No
def _provide(self) -> VT:
if not isinstance(self._cache, _SentinelClass):
return self._cache
value = self._func(*_inject_args(self._args), **_inject_kwargs(self._kwargs))
value = super()._provide()
self._cache = value
return value


Factory = Callable
SingletonFactory = MemoizedCallable
Callable = Factory
MemoizedCallable = SingletonFactory


class Configuration(Provider[Dict[str, Any]]):
'''
"""
special provider that reflects the structure of a configuration dictionary.
'''
"""

STATE_FIELDS = Provider.STATE_FIELDS + ('_data', "fallback")
STATE_FIELDS = Provider.STATE_FIELDS + ("_data", "fallback")

def __init__(
self,
Expand Down Expand Up @@ -126,7 +161,7 @@ def __repr__(self) -> str:

class _ConfigurationItem(Provider[Any]):

STATE_FIELDS = Provider.STATE_FIELDS + ('_config', "_path")
STATE_FIELDS = Provider.STATE_FIELDS + ("_config", "_path")

def __init__(self, config: Configuration, path: Tuple[str, ...],) -> None:
super().__init__()
Expand Down
Loading

0 comments on commit b42adfe

Please sign in to comment.