Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

重构Context #83

Merged
merged 1 commit into from
Nov 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/nonebot_plugin_pixivbot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,6 @@
)

# ================= provide beans =================
from importlib import import_module

from nonebot import logger

from .global_context import context
from .nb_providers import provide

Expand All @@ -48,6 +44,9 @@
from . import service

# ============== load custom protocol =============
from importlib import import_module
from nonebot import logger

supported_modules = ["nonebot_plugin_pixivbot_onebot_v11", "nonebot_plugin_pixivbot_kook"]

for p in supported_modules:
Expand Down
195 changes: 67 additions & 128 deletions src/nonebot_plugin_pixivbot/context.py
Original file line number Diff line number Diff line change
@@ -1,186 +1,140 @@
import functools
import inspect
import sys
import types
from threading import Lock
from typing import TypeVar, Type, Callable, Dict, Any
from abc import ABC, abstractmethod
from typing import TypeVar, Type, Callable, Union, Generic

from nonebot.log import logger

T = TypeVar("T")
T2 = TypeVar("T2")

# copied from inspect.py (py3.10)
def get_annotations(obj, *, globals=None, locals=None, eval_str=False) -> Dict[str, Any]:
if sys.version_info >= (3, 10):
return inspect.get_annotations(obj, globals=globals, locals=locals, eval_str=eval_str)

if isinstance(obj, type):
# class
obj_dict = getattr(obj, '__dict__', None)
if obj_dict and hasattr(obj_dict, 'get'):
ann = obj_dict.get('__annotations__', None)
if isinstance(ann, types.GetSetDescriptorType):
ann = None
else:
ann = None

obj_globals = None
module_name = getattr(obj, '__module__', None)
if module_name:
module = sys.modules.get(module_name, None)
if module:
obj_globals = getattr(module, '__dict__', None)
obj_locals = dict(vars(obj))
unwrap = obj
elif isinstance(obj, types.ModuleType):
# module
ann = getattr(obj, '__annotations__', None)
obj_globals = getattr(obj, '__dict__')
obj_locals = None
unwrap = None
elif callable(obj):
# this includes types.Function, types.BuiltinFunctionType,
# types.BuiltinMethodType, functools.partial, functools.singledispatch,
# "class funclike" from Lib/test/test_inspect... on and on it goes.
ann = getattr(obj, '__annotations__', None)
obj_globals = getattr(obj, '__globals__', None)
obj_locals = None
unwrap = obj
else:
raise TypeError(f"{obj!r} is not a module, class, or callable.")

if ann is None:
return {}

if not isinstance(ann, dict):
raise ValueError(f"{obj!r}.__annotations__ is neither a dict nor None")

if not ann:
return {}

if not eval_str:
return dict(ann)

if unwrap is not None:
while True:
if hasattr(unwrap, '__wrapped__'):
unwrap = unwrap.__wrapped__
continue
if isinstance(unwrap, functools.partial):
unwrap = unwrap.func
continue
break
if hasattr(unwrap, "__globals__"):
obj_globals = unwrap.__globals__

if globals is None:
globals = obj_globals
if locals is None:
locals = obj_locals

return_value = {key: value if not isinstance(value, str) else eval(value, globals, locals)
for key, value in ann.items()}
return return_value
class Inject:
def __init__(self, key):
self._key = key

def __get__(self, instance, owner):
context = getattr(instance, "__context__")
return context.require(self._key)

T = TypeVar("T")

class Provider(ABC, Generic[T]):
@abstractmethod
def provide(self) -> T:
raise NotImplementedError()


class InstanceProvider(Provider[T], Generic[T]):
def __init__(self, instance: T):
self._instance = instance

def provide(self) -> T:
return self._instance


class DynamicProvider(Provider[T], Generic[T]):
def __init__(self, func: Callable[[], T], use_cache: bool = True):
self._func = func
self._use_cache = use_cache

self._cache = None
self._cached = False

def provide(self) -> T:
if not self._use_cache:
return self._func()

if not self._cached:
self._cache = self._func() # just let it throw
self._cached = True
return self._cache


class Context:
def __init__(self, parent=None):
def __init__(self, parent: "Context" = None):
self._parent = parent
self._container = {}
self._lazy_container = {}
self._binding = {}
self._lock = Lock()

@property
def parent(self):
def parent(self) -> "Context":
return self._parent

@property
def root(self):
def root(self) -> "Context":
if self._parent is None:
return self
else:
return self._parent.root

def register(self, key: Type[T], bean: T):
def register(self, key: Type[T], bean: Union[T, Provider[T]]):
"""
register a bean
"""
if not isinstance(bean, Provider):
bean = InstanceProvider(bean)

self._container[key] = bean
logger.trace(f"registered bean {key}")
logger.trace(f"registered bean {key}, provider type: {type(bean)}")

def register_lazy(self, key: Type[T], bean_initializer: Callable[[], T]):
"""
register a bean lazily
"""
if key in self._container:
del self._container[key]
self._lazy_container[key] = bean_initializer
self._container[key] = DynamicProvider(bean_initializer)
logger.trace(f"lazily registered bean {key}")

def register_singleton(self, *args, **kwargs):
def register_singleton(self, *args, **kwargs) -> Callable[[Type[T]], Type[T]]:
"""
decorator for a class to register a bean lazily
"""

def decorator(cls):
def decorator(cls: Type[T]) -> Type[T]:
self.register_lazy(cls, lambda: cls(*args, **kwargs))
return cls

return decorator

def register_eager_singleton(self, *args, **kwargs):
def register_eager_singleton(self, *args, **kwargs) -> Callable[[Type[T]], Type[T]]:
"""
decorator for a class to register a bean lazily
"""

def decorator(cls):
def decorator(cls: Type[T]) -> Type[T]:
bean = cls(*args, **kwargs)
self.register(cls, bean)
return cls

return decorator

def unregister(self, key: Type[T]):
def unregister(self, key: Type[T]) -> bool:
"""
unregister the bean of key
"""
if key in self._container:
del self._container[key]
if key in self._lazy_container:
del self._container[key]
return True
return False

def bind_to(self, key, src_key):
def bind(self, key: Type[T], src_key: Type[T2]):
"""
bind key (usually the implementation class) to src_key (usually the base class)
"""
self._binding[key] = src_key
self._container[key] = self._container[src_key]
logger.trace(f"bind bean {key} to {src_key}")

def bind_singleton_to(self, key, *args, **kwargs):
def bind_singleton_to(self, key: Type[T], *args, **kwargs) -> Callable[[Type[T2]], Type[T2]]:
"""
decorator for a class (usually the implementation class) to bind to another class (usually the base class)
"""

def decorator(cls):
def decorator(cls: Type[T2]) -> Type[T2]:
self.register_singleton(*args, **kwargs)(cls)
self.bind_to(key, cls)
self.bind(key, cls)
return cls

return decorator

def require(self, key: Type[T]) -> T:
if key in self._binding:
return self.require(self._binding[key])
elif key in self._container:
return self._container[key]
elif key in self._lazy_container:
# TODO: Lock
self.register(key, self._lazy_container[key]())
del self._lazy_container[key]
return self._container[key]
if key in self._container:
return self._container[key].provide()
elif self._parent is not None:
return self._parent.require(key)
else:
Expand All @@ -190,31 +144,16 @@ def __getitem__(self, key: Type[T]):
return self.require(key)

def __contains__(self, key: Type[T]) -> bool:
if key in self._binding or key in self._container or key in self._lazy_container:
if key in self._container:
return True
elif self._parent is not None:
return self._parent.__contains__(key)
else:
return False

def inject(self, cls: Type[T]):
old_getattr = getattr(cls, "__getattr__", None)

def __getattr__(obj: T, name: str):
ann = get_annotations(cls, eval_str=True)
if name in ann and ann[name] in self:
return self[ann[name]]

if old_getattr:
return old_getattr(obj, name)
else:
for c in cls.mro()[1:]:
c_getattr = getattr(c, "__getattr__", None)
if c_getattr:
return c_getattr(obj, name)

setattr(cls, "__getattr__", __getattr__)
setattr(cls, "__context__", self)
return cls


__all__ = ("Context",)
__all__ = ("Context", "Inject")
2 changes: 1 addition & 1 deletion src/nonebot_plugin_pixivbot/data/errors.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from nonebot_plugin_pixivbot.data.pixiv_repo.abstract_repo import PixivRepoMetadata
pass


class DataSourceNotReadyError(RuntimeError):
Expand Down
3 changes: 2 additions & 1 deletion src/nonebot_plugin_pixivbot/data/local_tag_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from nonebot_plugin_pixivbot.global_context import context
from nonebot_plugin_pixivbot.model import Tag
from .source import MongoDataSource
from ..context import Inject


class LocalTag(Tag, Document):
Expand All @@ -24,7 +25,7 @@ class Settings:
@context.inject
@context.register_singleton()
class LocalTagRepo:
mongo: MongoDataSource
mongo = Inject(MongoDataSource)

@classmethod
async def find_by_name(cls, name: str) -> Optional[Tag]:
Expand Down
3 changes: 2 additions & 1 deletion src/nonebot_plugin_pixivbot/data/pixiv_repo/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@

from nonebot_plugin_pixivbot.config import Config
from .pkg_context import context
from ...context import Inject


@context.inject
@context.register_singleton()
class Compressor:
_conf: Config
_conf = Inject(Config)

def __init__(self) -> None:
self.enabled = self._conf.pixiv_compression_enabled
Expand Down
8 changes: 4 additions & 4 deletions src/nonebot_plugin_pixivbot/data/pixiv_repo/local_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import List, Union, Sequence, Any, AsyncGenerator, Optional, Type, Mapping

import bson
from beanie import BulkWriter
from beanie.odm.operators.find.comparison import In
from beanie.odm.operators.find.logical import And
from beanie.odm.operators.update.array import AddToSet
Expand All @@ -22,6 +21,7 @@
from .pkg_context import context
from ..local_tag_repo import LocalTagRepo
from ..source import MongoDataSource
from ...context import Inject


def _handle_expires_in(metadata: PixivRepoMetadata, expires_in: int):
Expand All @@ -32,9 +32,9 @@ def _handle_expires_in(metadata: PixivRepoMetadata, expires_in: int):
@context.inject
@context.register_singleton()
class LocalPixivRepo(AbstractPixivRepo):
conf: Config
mongo: MongoDataSource
local_tag_repo: LocalTagRepo
conf = Inject(Config)
mongo = Inject(MongoDataSource)
local_tag_repo = Inject(LocalTagRepo)

async def _add_to_local_tags(self, illusts: List[Union[LazyIllust, Illust]]):
tags = {}
Expand Down
5 changes: 3 additions & 2 deletions src/nonebot_plugin_pixivbot/data/pixiv_repo/remote_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .compressor import Compressor
from .lazy_illust import LazyIllust
from .pkg_context import context
from ...context import Inject

T = TypeVar("T")

Expand Down Expand Up @@ -46,8 +47,8 @@ async def wrapped(*args, **kwargs):
@context.inject
@context.register_eager_singleton()
class RemotePixivRepo(AbstractPixivRepo):
_conf: Config
_compressor: Compressor
_conf = Inject(Config)
_compressor = Inject(Compressor)

# noinspection PyTypeChecker
def __init__(self):
Expand Down
Loading