diff --git a/src/nonebot_plugin_pixivbot/__init__.py b/src/nonebot_plugin_pixivbot/__init__.py index fdace7e..f5662a5 100644 --- a/src/nonebot_plugin_pixivbot/__init__.py +++ b/src/nonebot_plugin_pixivbot/__init__.py @@ -34,10 +34,6 @@ ) # ================= provide beans ================= -from importlib import import_module - -from nonebot import logger - from .global_context import context from .nb_providers import provide @@ -48,6 +44,9 @@ from . import service # ============== load custom protocol ============= +from importlib import import_module +from nonebot import logger + supported_modules = ["nonebot_plugin_pixivbot_onebot_v11", "nonebot_plugin_pixivbot_kook"] for p in supported_modules: diff --git a/src/nonebot_plugin_pixivbot/context.py b/src/nonebot_plugin_pixivbot/context.py index 1aa5932..5d28385 100644 --- a/src/nonebot_plugin_pixivbot/context.py +++ b/src/nonebot_plugin_pixivbot/context.py @@ -1,186 +1,140 @@ -import functools -import inspect -import sys -import types -from threading import Lock -from typing import TypeVar, Type, Callable, Dict, Any +from abc import ABC, abstractmethod +from typing import TypeVar, Type, Callable, Union, Generic from nonebot.log import logger +T = TypeVar("T") +T2 = TypeVar("T2") -# copied from inspect.py (py3.10) -def get_annotations(obj, *, globals=None, locals=None, eval_str=False) -> Dict[str, Any]: - if sys.version_info >= (3, 10): - return inspect.get_annotations(obj, globals=globals, locals=locals, eval_str=eval_str) - if isinstance(obj, type): - # class - obj_dict = getattr(obj, '__dict__', None) - if obj_dict and hasattr(obj_dict, 'get'): - ann = obj_dict.get('__annotations__', None) - if isinstance(ann, types.GetSetDescriptorType): - ann = None - else: - ann = None - - obj_globals = None - module_name = getattr(obj, '__module__', None) - if module_name: - module = sys.modules.get(module_name, None) - if module: - obj_globals = getattr(module, '__dict__', None) - obj_locals = dict(vars(obj)) - unwrap = obj - elif isinstance(obj, types.ModuleType): - # module - ann = getattr(obj, '__annotations__', None) - obj_globals = getattr(obj, '__dict__') - obj_locals = None - unwrap = None - elif callable(obj): - # this includes types.Function, types.BuiltinFunctionType, - # types.BuiltinMethodType, functools.partial, functools.singledispatch, - # "class funclike" from Lib/test/test_inspect... on and on it goes. - ann = getattr(obj, '__annotations__', None) - obj_globals = getattr(obj, '__globals__', None) - obj_locals = None - unwrap = obj - else: - raise TypeError(f"{obj!r} is not a module, class, or callable.") - - if ann is None: - return {} - - if not isinstance(ann, dict): - raise ValueError(f"{obj!r}.__annotations__ is neither a dict nor None") - - if not ann: - return {} - - if not eval_str: - return dict(ann) - - if unwrap is not None: - while True: - if hasattr(unwrap, '__wrapped__'): - unwrap = unwrap.__wrapped__ - continue - if isinstance(unwrap, functools.partial): - unwrap = unwrap.func - continue - break - if hasattr(unwrap, "__globals__"): - obj_globals = unwrap.__globals__ - - if globals is None: - globals = obj_globals - if locals is None: - locals = obj_locals - - return_value = {key: value if not isinstance(value, str) else eval(value, globals, locals) - for key, value in ann.items()} - return return_value +class Inject: + def __init__(self, key): + self._key = key + def __get__(self, instance, owner): + context = getattr(instance, "__context__") + return context.require(self._key) -T = TypeVar("T") + +class Provider(ABC, Generic[T]): + @abstractmethod + def provide(self) -> T: + raise NotImplementedError() + + +class InstanceProvider(Provider[T], Generic[T]): + def __init__(self, instance: T): + self._instance = instance + + def provide(self) -> T: + return self._instance + + +class DynamicProvider(Provider[T], Generic[T]): + def __init__(self, func: Callable[[], T], use_cache: bool = True): + self._func = func + self._use_cache = use_cache + + self._cache = None + self._cached = False + + def provide(self) -> T: + if not self._use_cache: + return self._func() + + if not self._cached: + self._cache = self._func() # just let it throw + self._cached = True + return self._cache class Context: - def __init__(self, parent=None): + def __init__(self, parent: "Context" = None): self._parent = parent self._container = {} - self._lazy_container = {} - self._binding = {} - self._lock = Lock() @property - def parent(self): + def parent(self) -> "Context": return self._parent @property - def root(self): + def root(self) -> "Context": if self._parent is None: return self else: return self._parent.root - def register(self, key: Type[T], bean: T): + def register(self, key: Type[T], bean: Union[T, Provider[T]]): """ register a bean """ + if not isinstance(bean, Provider): + bean = InstanceProvider(bean) + self._container[key] = bean - logger.trace(f"registered bean {key}") + logger.trace(f"registered bean {key}, provider type: {type(bean)}") def register_lazy(self, key: Type[T], bean_initializer: Callable[[], T]): """ register a bean lazily """ - if key in self._container: - del self._container[key] - self._lazy_container[key] = bean_initializer + self._container[key] = DynamicProvider(bean_initializer) logger.trace(f"lazily registered bean {key}") - def register_singleton(self, *args, **kwargs): + def register_singleton(self, *args, **kwargs) -> Callable[[Type[T]], Type[T]]: """ decorator for a class to register a bean lazily """ - def decorator(cls): + def decorator(cls: Type[T]) -> Type[T]: self.register_lazy(cls, lambda: cls(*args, **kwargs)) return cls return decorator - def register_eager_singleton(self, *args, **kwargs): + def register_eager_singleton(self, *args, **kwargs) -> Callable[[Type[T]], Type[T]]: """ decorator for a class to register a bean lazily """ - def decorator(cls): + def decorator(cls: Type[T]) -> Type[T]: bean = cls(*args, **kwargs) self.register(cls, bean) return cls return decorator - def unregister(self, key: Type[T]): + def unregister(self, key: Type[T]) -> bool: """ unregister the bean of key """ if key in self._container: del self._container[key] - if key in self._lazy_container: - del self._container[key] + return True + return False - def bind_to(self, key, src_key): + def bind(self, key: Type[T], src_key: Type[T2]): """ bind key (usually the implementation class) to src_key (usually the base class) """ - self._binding[key] = src_key + self._container[key] = self._container[src_key] logger.trace(f"bind bean {key} to {src_key}") - def bind_singleton_to(self, key, *args, **kwargs): + def bind_singleton_to(self, key: Type[T], *args, **kwargs) -> Callable[[Type[T2]], Type[T2]]: """ decorator for a class (usually the implementation class) to bind to another class (usually the base class) """ - def decorator(cls): + def decorator(cls: Type[T2]) -> Type[T2]: self.register_singleton(*args, **kwargs)(cls) - self.bind_to(key, cls) + self.bind(key, cls) return cls return decorator def require(self, key: Type[T]) -> T: - if key in self._binding: - return self.require(self._binding[key]) - elif key in self._container: - return self._container[key] - elif key in self._lazy_container: - # TODO: Lock - self.register(key, self._lazy_container[key]()) - del self._lazy_container[key] - return self._container[key] + if key in self._container: + return self._container[key].provide() elif self._parent is not None: return self._parent.require(key) else: @@ -190,7 +144,7 @@ def __getitem__(self, key: Type[T]): return self.require(key) def __contains__(self, key: Type[T]) -> bool: - if key in self._binding or key in self._container or key in self._lazy_container: + if key in self._container: return True elif self._parent is not None: return self._parent.__contains__(key) @@ -198,23 +152,8 @@ def __contains__(self, key: Type[T]) -> bool: return False def inject(self, cls: Type[T]): - old_getattr = getattr(cls, "__getattr__", None) - - def __getattr__(obj: T, name: str): - ann = get_annotations(cls, eval_str=True) - if name in ann and ann[name] in self: - return self[ann[name]] - - if old_getattr: - return old_getattr(obj, name) - else: - for c in cls.mro()[1:]: - c_getattr = getattr(c, "__getattr__", None) - if c_getattr: - return c_getattr(obj, name) - - setattr(cls, "__getattr__", __getattr__) + setattr(cls, "__context__", self) return cls -__all__ = ("Context",) +__all__ = ("Context", "Inject") diff --git a/src/nonebot_plugin_pixivbot/data/errors.py b/src/nonebot_plugin_pixivbot/data/errors.py index 287c0b8..671105a 100644 --- a/src/nonebot_plugin_pixivbot/data/errors.py +++ b/src/nonebot_plugin_pixivbot/data/errors.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from nonebot_plugin_pixivbot.data.pixiv_repo.abstract_repo import PixivRepoMetadata + pass class DataSourceNotReadyError(RuntimeError): diff --git a/src/nonebot_plugin_pixivbot/data/local_tag_repo.py b/src/nonebot_plugin_pixivbot/data/local_tag_repo.py index c61fbe6..0dd5d27 100644 --- a/src/nonebot_plugin_pixivbot/data/local_tag_repo.py +++ b/src/nonebot_plugin_pixivbot/data/local_tag_repo.py @@ -7,6 +7,7 @@ from nonebot_plugin_pixivbot.global_context import context from nonebot_plugin_pixivbot.model import Tag from .source import MongoDataSource +from ..context import Inject class LocalTag(Tag, Document): @@ -24,7 +25,7 @@ class Settings: @context.inject @context.register_singleton() class LocalTagRepo: - mongo: MongoDataSource + mongo = Inject(MongoDataSource) @classmethod async def find_by_name(cls, name: str) -> Optional[Tag]: diff --git a/src/nonebot_plugin_pixivbot/data/pixiv_repo/compressor.py b/src/nonebot_plugin_pixivbot/data/pixiv_repo/compressor.py index 7114f00..50adc90 100644 --- a/src/nonebot_plugin_pixivbot/data/pixiv_repo/compressor.py +++ b/src/nonebot_plugin_pixivbot/data/pixiv_repo/compressor.py @@ -8,12 +8,13 @@ from nonebot_plugin_pixivbot.config import Config from .pkg_context import context +from ...context import Inject @context.inject @context.register_singleton() class Compressor: - _conf: Config + _conf = Inject(Config) def __init__(self) -> None: self.enabled = self._conf.pixiv_compression_enabled diff --git a/src/nonebot_plugin_pixivbot/data/pixiv_repo/local_repo.py b/src/nonebot_plugin_pixivbot/data/pixiv_repo/local_repo.py index 955d386..98eda3e 100644 --- a/src/nonebot_plugin_pixivbot/data/pixiv_repo/local_repo.py +++ b/src/nonebot_plugin_pixivbot/data/pixiv_repo/local_repo.py @@ -2,7 +2,6 @@ from typing import List, Union, Sequence, Any, AsyncGenerator, Optional, Type, Mapping import bson -from beanie import BulkWriter from beanie.odm.operators.find.comparison import In from beanie.odm.operators.find.logical import And from beanie.odm.operators.update.array import AddToSet @@ -22,6 +21,7 @@ from .pkg_context import context from ..local_tag_repo import LocalTagRepo from ..source import MongoDataSource +from ...context import Inject def _handle_expires_in(metadata: PixivRepoMetadata, expires_in: int): @@ -32,9 +32,9 @@ def _handle_expires_in(metadata: PixivRepoMetadata, expires_in: int): @context.inject @context.register_singleton() class LocalPixivRepo(AbstractPixivRepo): - conf: Config - mongo: MongoDataSource - local_tag_repo: LocalTagRepo + conf = Inject(Config) + mongo = Inject(MongoDataSource) + local_tag_repo = Inject(LocalTagRepo) async def _add_to_local_tags(self, illusts: List[Union[LazyIllust, Illust]]): tags = {} diff --git a/src/nonebot_plugin_pixivbot/data/pixiv_repo/remote_repo.py b/src/nonebot_plugin_pixivbot/data/pixiv_repo/remote_repo.py index 87c3c75..bf56209 100644 --- a/src/nonebot_plugin_pixivbot/data/pixiv_repo/remote_repo.py +++ b/src/nonebot_plugin_pixivbot/data/pixiv_repo/remote_repo.py @@ -18,6 +18,7 @@ from .compressor import Compressor from .lazy_illust import LazyIllust from .pkg_context import context +from ...context import Inject T = TypeVar("T") @@ -46,8 +47,8 @@ async def wrapped(*args, **kwargs): @context.inject @context.register_eager_singleton() class RemotePixivRepo(AbstractPixivRepo): - _conf: Config - _compressor: Compressor + _conf = Inject(Config) + _compressor = Inject(Compressor) # noinspection PyTypeChecker def __init__(self): diff --git a/src/nonebot_plugin_pixivbot/data/pixiv_repo/repo.py b/src/nonebot_plugin_pixivbot/data/pixiv_repo/repo.py index b959fd6..87accf6 100644 --- a/src/nonebot_plugin_pixivbot/data/pixiv_repo/repo.py +++ b/src/nonebot_plugin_pixivbot/data/pixiv_repo/repo.py @@ -17,6 +17,7 @@ from .models import PixivRepoMetadata from .pkg_context import context from .remote_repo import RemotePixivRepo +from ...context import Inject class SharedAgenIdentifier(BaseModel): @@ -38,9 +39,9 @@ class Config: class PixivSharedAsyncGeneratorManager(SharedAsyncGeneratorManager[SharedAgenIdentifier, Any]): log_tag = "pixiv_shared_agen" - conf: Config - local: LocalPixivRepo - remote: RemotePixivRepo + conf = Inject(Config) + local = Inject(LocalPixivRepo) + remote = Inject(RemotePixivRepo) def illust_detail_factory(self, illust_id: int, cache_strategy: CacheStrategy) -> AsyncGenerator[Illust, None]: @@ -169,7 +170,8 @@ def image_factory(self, illust_id: int, illust: Illust, PixivResType.IMAGE: image_factory, } - def agen(self, identifier: SharedAgenIdentifier, cache_strategy: CacheStrategy, **kwargs) -> AsyncGenerator[Any, None]: + def agen(self, identifier: SharedAgenIdentifier, cache_strategy: CacheStrategy, **kwargs) -> AsyncGenerator[ + Any, None]: if identifier.type in self.factories: merged_kwargs = identifier.kwargs | kwargs # noinspection PyTypeChecker @@ -205,9 +207,9 @@ async def on_agen_next(self, identifier: SharedAgenIdentifier, item: Any): @context.inject @context.root.register_singleton() class PixivRepo(AbstractPixivRepo): - _shared_agen_mgr: PixivSharedAsyncGeneratorManager - _local: LocalPixivRepo - _remote: RemotePixivRepo + _shared_agen_mgr = Inject(PixivSharedAsyncGeneratorManager) + _local = Inject(LocalPixivRepo) + _remote = Inject(RemotePixivRepo) async def invalidate_cache(self): self._shared_agen_mgr.invalidate_all() diff --git a/src/nonebot_plugin_pixivbot/data/source/mongo/mongo.py b/src/nonebot_plugin_pixivbot/data/source/mongo/mongo.py index 8b7f7e3..57fe347 100644 --- a/src/nonebot_plugin_pixivbot/data/source/mongo/mongo.py +++ b/src/nonebot_plugin_pixivbot/data/source/mongo/mongo.py @@ -9,6 +9,7 @@ from pymongo.errors import OperationFailure from nonebot_plugin_pixivbot.config import Config +from nonebot_plugin_pixivbot.context import Inject from nonebot_plugin_pixivbot.data.errors import DataSourceNotReadyError from nonebot_plugin_pixivbot.data.source.mongo.migration import MongoMigrationManager from nonebot_plugin_pixivbot.global_context import context @@ -18,8 +19,8 @@ @context.inject @context.register_eager_singleton() class MongoDataSource: - conf: Config - mongo_migration_mgr: MongoMigrationManager + conf = Inject(Config) + mongo_migration_mgr = Inject(MongoMigrationManager) app_db_version = 4 def __init__(self): diff --git a/src/nonebot_plugin_pixivbot/data/subscription_repo.py b/src/nonebot_plugin_pixivbot/data/subscription_repo.py index b359ddd..7911403 100644 --- a/src/nonebot_plugin_pixivbot/data/subscription_repo.py +++ b/src/nonebot_plugin_pixivbot/data/subscription_repo.py @@ -7,6 +7,7 @@ from nonebot_plugin_pixivbot.model import Subscription, PostIdentifier, ScheduleType from .source import MongoDataSource from .utils.process_subscriber import process_subscriber +from ..context import Inject UID = TypeVar("UID") GID = TypeVar("GID") @@ -29,7 +30,7 @@ class Settings: @context.inject @context.register_singleton() class SubscriptionRepo: - mongo: MongoDataSource + mongo = Inject(MongoDataSource) @classmethod async def get_by_subscriber(cls, subscriber: ID) -> AsyncGenerator[Subscription, None]: diff --git a/src/nonebot_plugin_pixivbot/data/watch_task_repo.py b/src/nonebot_plugin_pixivbot/data/watch_task_repo.py index 57ef6e1..28926cc 100644 --- a/src/nonebot_plugin_pixivbot/data/watch_task_repo.py +++ b/src/nonebot_plugin_pixivbot/data/watch_task_repo.py @@ -7,6 +7,7 @@ from nonebot_plugin_pixivbot.model import WatchTask, WatchType, PostIdentifier from .source import MongoDataSource from .utils.process_subscriber import process_subscriber +from ..context import Inject UID = TypeVar("UID") GID = TypeVar("GID") @@ -29,7 +30,7 @@ class Settings: @context.inject @context.register_singleton() class WatchTaskRepo: - mongo: MongoDataSource + mongo = Inject(MongoDataSource) @classmethod async def get_by_subscriber(cls, subscriber: ID) -> AsyncGenerator[WatchTask, None]: diff --git a/src/nonebot_plugin_pixivbot/handler/__init__.py b/src/nonebot_plugin_pixivbot/handler/__init__.py index 1c14b5c..1531c72 100644 --- a/src/nonebot_plugin_pixivbot/handler/__init__.py +++ b/src/nonebot_plugin_pixivbot/handler/__init__.py @@ -1 +1,7 @@ +from . import command +from . import common +from . import sniffer +from .entry_handler import EntryHandler, DelegationEntryHandler from .handler import Handler + +__all__ = ("Handler", "EntryHandler", "DelegationEntryHandler") diff --git a/src/nonebot_plugin_pixivbot/handler/command/bind.py b/src/nonebot_plugin_pixivbot/handler/command/bind.py index 0ab11db..a0b74da 100644 --- a/src/nonebot_plugin_pixivbot/handler/command/bind.py +++ b/src/nonebot_plugin_pixivbot/handler/command/bind.py @@ -1,5 +1,6 @@ from typing import TypeVar, Sequence +from nonebot_plugin_pixivbot.context import Inject from nonebot_plugin_pixivbot.global_context import context from nonebot_plugin_pixivbot.protocol_dep.post_dest import PostDestination from nonebot_plugin_pixivbot.service.pixiv_account_binder import PixivAccountBinder @@ -13,7 +14,7 @@ @context.inject @context.require(CommandHandler).sub_command("bind") class BindHandler(SubCommandHandler): - binder: PixivAccountBinder + binder = Inject(PixivAccountBinder) @classmethod def type(cls) -> str: @@ -59,7 +60,7 @@ async def actual_handle_bad_request(self, err: BadRequestError, @context.inject @context.require(CommandHandler).sub_command("unbind") class UnbindHandler(SubCommandHandler): - binder: PixivAccountBinder + binder = Inject(PixivAccountBinder) @classmethod def type(cls) -> str: diff --git a/src/nonebot_plugin_pixivbot/handler/command/command.py b/src/nonebot_plugin_pixivbot/handler/command/command.py index 24dbcd0..e7b5ae4 100644 --- a/src/nonebot_plugin_pixivbot/handler/command/command.py +++ b/src/nonebot_plugin_pixivbot/handler/command/command.py @@ -4,7 +4,7 @@ from lazy import lazy from nonebot import Bot, on_command -from nonebot import logger, on_regex +from nonebot import logger from nonebot.internal.adapter import Event from nonebot.internal.matcher import Matcher from nonebot.internal.params import Depends @@ -16,7 +16,7 @@ from ..entry_handler import EntryHandler from ..entry_handler import post_destination from ..handler import Handler -from ..utils import get_common_query_rule, get_command_rule +from ..utils import get_command_rule UID = TypeVar("UID") GID = TypeVar("GID") @@ -54,7 +54,7 @@ async def actual_handle_bad_request(self, err: BadRequestError, *, await self.post_plain_text(err.message, post_dest=post_dest) -@context.register_singleton() +@context.register_eager_singleton() class CommandHandler(EntryHandler): def __init__(self): super().__init__() diff --git a/src/nonebot_plugin_pixivbot/handler/command/help.py b/src/nonebot_plugin_pixivbot/handler/command/help.py index c3a371a..771efd8 100644 --- a/src/nonebot_plugin_pixivbot/handler/command/help.py +++ b/src/nonebot_plugin_pixivbot/handler/command/help.py @@ -23,5 +23,5 @@ def parse_args(self, args: Sequence[str], post_dest: PostDestination[UID, GID]) return {} async def actual_handle(self, *, post_dest: PostDestination[UID, GID], - silently: bool = False): + silently: bool = False, **kwargs): await self.post_plain_text(help_text, post_dest=post_dest) diff --git a/src/nonebot_plugin_pixivbot/handler/command/invalidate_cache.py b/src/nonebot_plugin_pixivbot/handler/command/invalidate_cache.py index bd6efe1..184f99e 100644 --- a/src/nonebot_plugin_pixivbot/handler/command/invalidate_cache.py +++ b/src/nonebot_plugin_pixivbot/handler/command/invalidate_cache.py @@ -5,6 +5,7 @@ from nonebot_plugin_pixivbot.handler.interceptor.permission_interceptor import SuperuserInterceptor from nonebot_plugin_pixivbot.protocol_dep.post_dest import PostDestination from .command import SubCommandHandler, CommandHandler +from ...context import Inject UID = TypeVar("UID") GID = TypeVar("GID") @@ -13,7 +14,7 @@ @context.inject @context.require(CommandHandler).sub_command("invalidate_cache") class InvalidateCacheHandler(SubCommandHandler): - repo: PixivRepo + repo = Inject(PixivRepo) def __init__(self): super().__init__() @@ -29,8 +30,7 @@ def enabled(self) -> bool: def parse_args(self, args: Sequence[str], post_dest: PostDestination[UID, GID]) -> dict: return {} - # noinspection PyMethodOverriding async def actual_handle(self, *, post_dest: PostDestination[UID, GID], - silently: bool = False): + silently: bool = False, **kwargs): await self.repo.invalidate_cache() await self.post_plain_text(message="ok", post_dest=post_dest) diff --git a/src/nonebot_plugin_pixivbot/handler/command/schedule.py b/src/nonebot_plugin_pixivbot/handler/command/schedule.py index d98a19a..f3e4b58 100644 --- a/src/nonebot_plugin_pixivbot/handler/command/schedule.py +++ b/src/nonebot_plugin_pixivbot/handler/command/schedule.py @@ -8,6 +8,7 @@ from nonebot_plugin_pixivbot.service.scheduler import Scheduler from nonebot_plugin_pixivbot.utils.errors import BadRequestError from .command import CommandHandler, SubCommandHandler +from ...context import Inject UID = TypeVar("UID") GID = TypeVar("GID") @@ -30,7 +31,7 @@ async def build_subscriptions_msg(subscriber: PostIdentifier[UID, GID]): @context.inject @context.require(CommandHandler).sub_command("schedule") class ScheduleHandler(SubCommandHandler): - scheduler: Scheduler + scheduler = Inject(Scheduler) def __init__(self): super().__init__() @@ -61,7 +62,7 @@ async def actual_handle(self, *, type: ScheduleType, schedule: str, args: List, post_dest: PostDestination[UID, GID], - silently: bool = False): + silently: bool = False, **kwargs): await self.scheduler.schedule(type, schedule, args, post_dest=post_dest) await self.post_plain_text(message="订阅成功", post_dest=post_dest) @@ -88,7 +89,7 @@ async def actual_handle_bad_request(self, err: BadRequestError, @context.inject @context.require(CommandHandler).sub_command("unschedule") class UnscheduleHandler(SubCommandHandler): - scheduler: Scheduler + scheduler = Inject(Scheduler) def __init__(self): super().__init__() diff --git a/src/nonebot_plugin_pixivbot/handler/command/watch.py b/src/nonebot_plugin_pixivbot/handler/command/watch.py index 9168278..2d8bbf1 100644 --- a/src/nonebot_plugin_pixivbot/handler/command/watch.py +++ b/src/nonebot_plugin_pixivbot/handler/command/watch.py @@ -9,6 +9,7 @@ from nonebot_plugin_pixivbot.service.watchman import Watchman from nonebot_plugin_pixivbot.utils.errors import BadRequestError from .command import CommandHandler, SubCommandHandler +from ...context import Inject UID = TypeVar("UID") GID = TypeVar("GID") @@ -67,7 +68,7 @@ async def parse_following_illusts_args(args: Sequence[str], post_dest: PostDesti @context.inject @context.require(CommandHandler).sub_command("watch") class WatchHandler(SubCommandHandler): - watchman: Watchman + watchman = Inject(Watchman) def __init__(self): super().__init__() @@ -105,7 +106,7 @@ async def actual_handle(self, *, type: WatchType, watch_kwargs: Dict[str, Any], success_message: str, post_dest: PostDestination[UID, GID], - silently: bool = False): + silently: bool = False, **kwargs): await self.watchman.watch(type, watch_kwargs, post_dest) await self.post_plain_text(success_message, post_dest) @@ -130,7 +131,7 @@ async def actual_handle_bad_request(self, err: BadRequestError, @context.inject @context.require(CommandHandler).sub_command("unwatch") class UnwatchHandler(SubCommandHandler): - watchman: Watchman + watchman = Inject(Watchman) def __init__(self): super().__init__() @@ -167,7 +168,7 @@ async def parse_args(self, args: Sequence[str], post_dest: PostDestination[UID, async def actual_handle(self, *, type: WatchType, watch_kwargs: Dict[str, Any], post_dest: PostDestination[UID, GID], - silently: bool = False): + silently: bool = False, **kwargs): if await self.watchman.unwatch(type, watch_kwargs, post_dest.identifier): await self.post_plain_text(message="成功取消订阅", post_dest=post_dest) else: diff --git a/src/nonebot_plugin_pixivbot/handler/common/common.py b/src/nonebot_plugin_pixivbot/handler/common/common.py index c626d29..f0d59ba 100644 --- a/src/nonebot_plugin_pixivbot/handler/common/common.py +++ b/src/nonebot_plugin_pixivbot/handler/common/common.py @@ -8,6 +8,7 @@ from ..interceptor.cooldown_interceptor import CooldownInterceptor from ..interceptor.loading_prompt_interceptor import LoadingPromptInterceptor from ..interceptor.timeout_interceptor import TimeoutInterceptor +from ...context import Inject from ...model import Illust from ...model.message import IllustMessageModel, IllustMessagesModel from ...protocol_dep.postman import PostmanManager @@ -21,7 +22,7 @@ @context.inject class RecordPostmanManager: - recorder: Recorder + recorder = Inject(Recorder) def __init__(self, delegation: PostmanManager): self.delegation = delegation @@ -43,7 +44,7 @@ def __getattr__(self, name: str): @context.inject class CommonHandler(EntryHandler, ABC): - service: PixivService + service = Inject(PixivService) def __init__(self): super().__init__() diff --git a/src/nonebot_plugin_pixivbot/handler/common/illust.py b/src/nonebot_plugin_pixivbot/handler/common/illust.py index a48746b..6e1249d 100644 --- a/src/nonebot_plugin_pixivbot/handler/common/illust.py +++ b/src/nonebot_plugin_pixivbot/handler/common/illust.py @@ -45,6 +45,6 @@ def parse_args(self, args: Sequence[str], post_dest: PostDestination[UID, GID]) # noinspection PyMethodOverriding async def actual_handle(self, *, illust_id: int, post_dest: PostDestination[UID, GID], - silently: bool = False): + silently: bool = False, **kwargs): illust = await self.service.illust_detail(illust_id) await self.post_illust(illust, post_dest=post_dest) diff --git a/src/nonebot_plugin_pixivbot/handler/common/more.py b/src/nonebot_plugin_pixivbot/handler/common/more.py index d7e5016..8d93655 100644 --- a/src/nonebot_plugin_pixivbot/handler/common/more.py +++ b/src/nonebot_plugin_pixivbot/handler/common/more.py @@ -14,6 +14,7 @@ from .recorder import Recorder from ..entry_handler import post_destination from ..utils import get_common_query_rule +from ...context import Inject UID = TypeVar("UID") GID = TypeVar("GID") @@ -22,7 +23,7 @@ @context.inject @context.root.register_eager_singleton() class MoreHandler(CommonHandler): - recorder: Recorder + recorder = Inject(Recorder) @classmethod def type(cls) -> str: @@ -42,7 +43,8 @@ async def on_match(self, bot: Bot, event: Event, state: T_State, matcher: Matche async def actual_handle(self, *, count: int = 1, post_dest: PostDestination[UID, GID], - silently: bool = False): + silently: bool = False, + **kwargs): req = self.recorder.get_req(post_dest.identifier) if not req: raise BadRequestError("你还没有发送过请求") diff --git a/src/nonebot_plugin_pixivbot/handler/common/random_bookmark.py b/src/nonebot_plugin_pixivbot/handler/common/random_bookmark.py index 7c877ba..6a1c576 100644 --- a/src/nonebot_plugin_pixivbot/handler/common/random_bookmark.py +++ b/src/nonebot_plugin_pixivbot/handler/common/random_bookmark.py @@ -15,6 +15,7 @@ from ..entry_handler import post_destination from ..interceptor.record_req_interceptor import RecordReqInterceptor from ..utils import get_common_query_rule, get_count, get_post_dest +from ...context import Inject UID = TypeVar("UID") GID = TypeVar("GID") @@ -23,7 +24,7 @@ @context.inject @context.root.register_eager_singleton() class RandomBookmarkHandler(CommonHandler): - binder: PixivAccountBinder + binder = Inject(PixivAccountBinder) def __init__(self): super().__init__() diff --git a/src/nonebot_plugin_pixivbot/handler/common/random_related_illust.py b/src/nonebot_plugin_pixivbot/handler/common/random_related_illust.py index 30ea85b..30598ae 100644 --- a/src/nonebot_plugin_pixivbot/handler/common/random_related_illust.py +++ b/src/nonebot_plugin_pixivbot/handler/common/random_related_illust.py @@ -15,6 +15,7 @@ from ..entry_handler import post_destination from ..interceptor.record_req_interceptor import RecordReqInterceptor from ..utils import get_common_query_rule +from ...context import Inject UID = TypeVar("UID") GID = TypeVar("GID") @@ -23,7 +24,7 @@ @context.inject @context.root.register_eager_singleton() class RandomRelatedIllustHandler(CommonHandler): - recorder: Recorder + recorder = Inject(Recorder) def __init__(self): super().__init__() diff --git a/src/nonebot_plugin_pixivbot/handler/common/recorder.py b/src/nonebot_plugin_pixivbot/handler/common/recorder.py index 04234dd..4911123 100644 --- a/src/nonebot_plugin_pixivbot/handler/common/recorder.py +++ b/src/nonebot_plugin_pixivbot/handler/common/recorder.py @@ -7,6 +7,7 @@ from nonebot import logger from nonebot_plugin_pixivbot.config import Config +from nonebot_plugin_pixivbot.context import Inject from nonebot_plugin_pixivbot.global_context import context from nonebot_plugin_pixivbot.model import PostIdentifier from nonebot_plugin_pixivbot.protocol_dep.post_dest import PostDestination @@ -54,7 +55,7 @@ def refresh(self): @context.inject @context.register_singleton() class Recorder: - conf: Config + conf = Inject(Config) def __init__(self, max_req_size: int = 65535, max_resp_size: int = 65535): diff --git a/src/nonebot_plugin_pixivbot/handler/handler.py b/src/nonebot_plugin_pixivbot/handler/handler.py index 50c7de6..623ecf0 100644 --- a/src/nonebot_plugin_pixivbot/handler/handler.py +++ b/src/nonebot_plugin_pixivbot/handler/handler.py @@ -7,6 +7,7 @@ from nonebot_plugin_pixivbot.protocol_dep.post_dest import PostDestination from .interceptor.combined_interceptor import CombinedInterceptor from .interceptor.interceptor import Interceptor +from ..context import Inject from ..protocol_dep.postman import PostmanManager UID = TypeVar("UID") @@ -17,8 +18,8 @@ @context.inject class Handler(ABC): - conf: Config - postman_manager: PostmanManager + conf = Inject(Config) + postman_manager = Inject(PostmanManager) def __init__(self): self.interceptor = None diff --git a/src/nonebot_plugin_pixivbot/handler/interceptor/cooldown_interceptor.py b/src/nonebot_plugin_pixivbot/handler/interceptor/cooldown_interceptor.py index 98c1154..401f82f 100644 --- a/src/nonebot_plugin_pixivbot/handler/interceptor/cooldown_interceptor.py +++ b/src/nonebot_plugin_pixivbot/handler/interceptor/cooldown_interceptor.py @@ -9,6 +9,7 @@ from nonebot_plugin_pixivbot.model import UserIdentifier from nonebot_plugin_pixivbot.protocol_dep.post_dest import PostDestination from .permission_interceptor import PermissionInterceptor +from ...context import Inject UID = TypeVar("UID") GID = TypeVar("GID") @@ -17,7 +18,7 @@ @context.inject @context.register_singleton() class CooldownInterceptor(PermissionInterceptor): - conf: Config + conf = Inject(Config) def __init__(self): super().__init__() diff --git a/src/nonebot_plugin_pixivbot/handler/interceptor/default_error_interceptor.py b/src/nonebot_plugin_pixivbot/handler/interceptor/default_error_interceptor.py index 149004a..927a0b3 100644 --- a/src/nonebot_plugin_pixivbot/handler/interceptor/default_error_interceptor.py +++ b/src/nonebot_plugin_pixivbot/handler/interceptor/default_error_interceptor.py @@ -13,10 +13,8 @@ GID = TypeVar("GID") -@context.inject @context.register_singleton() class DefaultErrorInterceptor(Interceptor): - async def intercept(self, wrapped_func: Callable, *args, post_dest: PostDestination[UID, GID], silently: bool, diff --git a/src/nonebot_plugin_pixivbot/handler/interceptor/interceptor.py b/src/nonebot_plugin_pixivbot/handler/interceptor/interceptor.py index faaf60d..98d0140 100644 --- a/src/nonebot_plugin_pixivbot/handler/interceptor/interceptor.py +++ b/src/nonebot_plugin_pixivbot/handler/interceptor/interceptor.py @@ -3,6 +3,7 @@ from typing import Callable, TypeVar from nonebot_plugin_pixivbot import context +from nonebot_plugin_pixivbot.context import Inject from nonebot_plugin_pixivbot.protocol_dep.post_dest import PostDestination from nonebot_plugin_pixivbot.protocol_dep.postman import PostmanManager @@ -12,7 +13,7 @@ @context.inject class Interceptor(ABC): - postman_manager: PostmanManager + postman_manager = Inject(PostmanManager) async def post_plain_text(self, message: str, post_dest: PostDestination): diff --git a/src/nonebot_plugin_pixivbot/handler/interceptor/loading_prompt_interceptor.py b/src/nonebot_plugin_pixivbot/handler/interceptor/loading_prompt_interceptor.py index 3f5738c..f563c16 100644 --- a/src/nonebot_plugin_pixivbot/handler/interceptor/loading_prompt_interceptor.py +++ b/src/nonebot_plugin_pixivbot/handler/interceptor/loading_prompt_interceptor.py @@ -7,12 +7,13 @@ from nonebot_plugin_pixivbot.config import Config from nonebot_plugin_pixivbot.protocol_dep.post_dest import PostDestination from .interceptor import Interceptor, UID, GID +from ...context import Inject @context.inject @context.register_singleton() class LoadingPromptInterceptor(Interceptor): - conf: Config + conf = Inject(Config) async def send_delayed_loading_prompt(self, post_dest: PostDestination[UID, GID]): await sleep(self.conf.pixiv_loading_prompt_delayed_time) diff --git a/src/nonebot_plugin_pixivbot/handler/interceptor/permission_interceptor.py b/src/nonebot_plugin_pixivbot/handler/interceptor/permission_interceptor.py index 0792d3c..d4dd6d7 100644 --- a/src/nonebot_plugin_pixivbot/handler/interceptor/permission_interceptor.py +++ b/src/nonebot_plugin_pixivbot/handler/interceptor/permission_interceptor.py @@ -9,12 +9,12 @@ from nonebot_plugin_pixivbot.protocol_dep.authenticator import AuthenticatorManager from nonebot_plugin_pixivbot.protocol_dep.post_dest import PostDestination from .interceptor import Interceptor +from ...context import Inject UID = TypeVar("UID") GID = TypeVar("GID") -@context.inject class PermissionInterceptor(Interceptor, ABC): @abstractmethod def has_permission(self, post_dest: PostDestination[UID, GID]) -> Union[bool, Awaitable[bool]]: @@ -78,7 +78,7 @@ def has_permission(self, post_dest: PostDestination[UID, GID]) -> bool: @context.inject @context.register_singleton() class GroupAdminInterceptor(PermissionInterceptor): - auth: AuthenticatorManager + auth = Inject(AuthenticatorManager) def has_permission(self, post_dest: PostDestination[UID, GID]) -> Union[bool, Awaitable[bool]]: if not post_dest.group_id: @@ -89,7 +89,7 @@ def has_permission(self, post_dest: PostDestination[UID, GID]) -> Union[bool, Aw @context.inject @context.register_singleton() class BlacklistInterceptor(PermissionInterceptor): - conf: Config + conf = Inject(Config) def __init__(self): super().__init__() diff --git a/src/nonebot_plugin_pixivbot/handler/interceptor/record_req_interceptor.py b/src/nonebot_plugin_pixivbot/handler/interceptor/record_req_interceptor.py index 8ba6b20..3630b36 100644 --- a/src/nonebot_plugin_pixivbot/handler/interceptor/record_req_interceptor.py +++ b/src/nonebot_plugin_pixivbot/handler/interceptor/record_req_interceptor.py @@ -1,5 +1,6 @@ from typing import TypeVar, Callable +from nonebot_plugin_pixivbot.context import Inject from nonebot_plugin_pixivbot.global_context import context from nonebot_plugin_pixivbot.handler.common.recorder import Recorder, Req from nonebot_plugin_pixivbot.handler.interceptor.interceptor import Interceptor @@ -12,7 +13,7 @@ @context.inject @context.register_singleton() class RecordReqInterceptor(Interceptor): - recorder: Recorder + recorder = Inject(Recorder) async def intercept(self, wrapped_func: Callable, *args, post_dest: PostDestination[UID, GID], diff --git a/src/nonebot_plugin_pixivbot/handler/interceptor/timeout_interceptor.py b/src/nonebot_plugin_pixivbot/handler/interceptor/timeout_interceptor.py index 388a5ac..5436e8a 100644 --- a/src/nonebot_plugin_pixivbot/handler/interceptor/timeout_interceptor.py +++ b/src/nonebot_plugin_pixivbot/handler/interceptor/timeout_interceptor.py @@ -2,6 +2,7 @@ from typing import TypeVar, Callable from nonebot_plugin_pixivbot.config import Config +from nonebot_plugin_pixivbot.context import Inject from nonebot_plugin_pixivbot.global_context import context from nonebot_plugin_pixivbot.handler.interceptor.interceptor import Interceptor from nonebot_plugin_pixivbot.protocol_dep.post_dest import PostDestination @@ -13,7 +14,7 @@ @context.inject @context.register_singleton() class TimeoutInterceptor(Interceptor): - conf: Config + conf = Inject(Config) async def intercept(self, wrapped_func: Callable, *args, post_dest: PostDestination[UID, GID], diff --git a/src/nonebot_plugin_pixivbot/handler/sniffer/illust_link.py b/src/nonebot_plugin_pixivbot/handler/sniffer/illust_link.py index a06229e..072b94c 100644 --- a/src/nonebot_plugin_pixivbot/handler/sniffer/illust_link.py +++ b/src/nonebot_plugin_pixivbot/handler/sniffer/illust_link.py @@ -14,6 +14,7 @@ from nonebot_plugin_pixivbot.protocol_dep.post_dest import PostDestination +@context.register_eager_singleton() class IllustLinkHandler(DelegationEntryHandler): @classmethod def type(cls) -> str: diff --git a/src/nonebot_plugin_pixivbot/model/watch_task.py b/src/nonebot_plugin_pixivbot/model/watch_task.py index 9ad1e7c..649e219 100644 --- a/src/nonebot_plugin_pixivbot/model/watch_task.py +++ b/src/nonebot_plugin_pixivbot/model/watch_task.py @@ -2,7 +2,6 @@ from enum import Enum from typing import Any, TypeVar, Generic, Dict -from pydantic import BaseModel from pydantic.generics import GenericModel from nonebot_plugin_pixivbot.model import PostIdentifier diff --git a/src/nonebot_plugin_pixivbot/nb_providers.py b/src/nonebot_plugin_pixivbot/nb_providers.py index f1077b5..0a5bb0e 100644 --- a/src/nonebot_plugin_pixivbot/nb_providers.py +++ b/src/nonebot_plugin_pixivbot/nb_providers.py @@ -16,7 +16,7 @@ def asyncio_scheduler_provider(context: Context): def base_scheduler_provider(context: Context): from apscheduler.schedulers.base import BaseScheduler from apscheduler.schedulers.asyncio import AsyncIOScheduler - context.bind_to(BaseScheduler, AsyncIOScheduler) + context.bind(BaseScheduler, AsyncIOScheduler) providers = [asyncio_scheduler_provider, base_scheduler_provider] diff --git a/src/nonebot_plugin_pixivbot/service/pixiv_account_binder.py b/src/nonebot_plugin_pixivbot/service/pixiv_account_binder.py index fade945..544abe8 100644 --- a/src/nonebot_plugin_pixivbot/service/pixiv_account_binder.py +++ b/src/nonebot_plugin_pixivbot/service/pixiv_account_binder.py @@ -1,5 +1,6 @@ from typing import TypeVar, Optional +from nonebot_plugin_pixivbot.context import Inject from nonebot_plugin_pixivbot.data.pixiv_binding_repo import PixivBindingRepo from nonebot_plugin_pixivbot.global_context import context from nonebot_plugin_pixivbot.model.pixiv_binding import PixivBinding @@ -10,7 +11,7 @@ @context.inject @context.register_singleton() class PixivAccountBinder: - repo: PixivBindingRepo + repo = Inject(PixivBindingRepo) async def bind(self, adapter: str, user_id: UID, pixiv_user_id: int): binding = PixivBinding(adapter=adapter, user_id=user_id, pixiv_user_id=pixiv_user_id) diff --git a/src/nonebot_plugin_pixivbot/service/pixiv_service.py b/src/nonebot_plugin_pixivbot/service/pixiv_service.py index e8837a8..a5ab042 100644 --- a/src/nonebot_plugin_pixivbot/service/pixiv_service.py +++ b/src/nonebot_plugin_pixivbot/service/pixiv_service.py @@ -3,7 +3,8 @@ from nonebot import logger from nonebot_plugin_pixivbot.config import Config -from nonebot_plugin_pixivbot.data.local_tag_repo import LocalTag, LocalTagRepo +from nonebot_plugin_pixivbot.context import Inject +from nonebot_plugin_pixivbot.data.local_tag_repo import LocalTagRepo from nonebot_plugin_pixivbot.data.pixiv_repo import LazyIllust, PixivRepo from nonebot_plugin_pixivbot.enums import RandomIllustMethod, RankingMode from nonebot_plugin_pixivbot.global_context import context @@ -15,9 +16,9 @@ @context.inject @context.register_singleton() class PixivService: - conf: Config - repo: PixivRepo - local_tag_repo: LocalTagRepo + conf = Inject(Config) + repo = Inject(PixivRepo) + local_tag_repo = Inject(LocalTagRepo) async def _choice_and_load(self, illusts: List[LazyIllust], random_method: RandomIllustMethod, count: int) \ -> List[Illust]: diff --git a/src/nonebot_plugin_pixivbot/service/scheduler.py b/src/nonebot_plugin_pixivbot/service/scheduler.py index c3fefae..ac8b718 100644 --- a/src/nonebot_plugin_pixivbot/service/scheduler.py +++ b/src/nonebot_plugin_pixivbot/service/scheduler.py @@ -12,6 +12,7 @@ from nonebot import logger, Bot from nonebot.exception import ActionFailed +from nonebot_plugin_pixivbot.context import Inject from nonebot_plugin_pixivbot.data.subscription_repo import SubscriptionRepo from nonebot_plugin_pixivbot.data.utils.process_subscriber import process_subscriber from nonebot_plugin_pixivbot.global_context import context @@ -64,10 +65,10 @@ def parse_schedule(raw_schedule: str) -> Sequence[int]: @context.inject @context.register_eager_singleton() class Scheduler: - apscheduler: AsyncIOScheduler - repo: SubscriptionRepo - pd_factory_mgr: PostDestinationFactoryManager - auth_mgr: AuthenticatorManager + apscheduler = Inject(AsyncIOScheduler) + repo = Inject(SubscriptionRepo) + pd_factory_mgr = Inject(PostDestinationFactoryManager) + auth_mgr = Inject(AuthenticatorManager) def __init__(self): on_bot_connect(self.on_bot_connect, replay=True) diff --git a/src/nonebot_plugin_pixivbot/service/watchman/shared_agen.py b/src/nonebot_plugin_pixivbot/service/watchman/shared_agen.py index 8e4979e..a4dc9e8 100644 --- a/src/nonebot_plugin_pixivbot/service/watchman/shared_agen.py +++ b/src/nonebot_plugin_pixivbot/service/watchman/shared_agen.py @@ -12,6 +12,7 @@ from nonebot_plugin_pixivbot.utils.shared_agen import SharedAsyncGeneratorManager from .pkg_context import context from .user_following_illusts import user_following_illusts +from ...context import Inject UID = TypeVar("UID") GID = TypeVar("GID") @@ -38,8 +39,8 @@ class Config: class WatchmanSharedAsyncGeneratorManager(SharedAsyncGeneratorManager[WatchmanSharedAgenIdentifier, Illust]): log_tag = "watchman_shared_agen" - pixiv: PixivRepo - remote_pixiv: RemotePixivRepo + pixiv = Inject(PixivRepo) + remote_pixiv = Inject(RemotePixivRepo) async def agen(self, identifier: WatchmanSharedAgenIdentifier, cache_strategy: CacheStrategy, **kwargs) -> AsyncGenerator[Illust, None]: self.set_expires_time(identifier, datetime.now(timezone.utc) + timedelta(seconds=30)) # 保证每分钟的所有task都能共享 diff --git a/src/nonebot_plugin_pixivbot/service/watchman/watchman.py b/src/nonebot_plugin_pixivbot/service/watchman/watchman.py index eaa3d76..338d16f 100644 --- a/src/nonebot_plugin_pixivbot/service/watchman/watchman.py +++ b/src/nonebot_plugin_pixivbot/service/watchman/watchman.py @@ -19,6 +19,7 @@ from nonebot_plugin_pixivbot.utils.nonebot import get_adapter_name from .pkg_context import context from .shared_agen import WatchmanSharedAsyncGeneratorManager, WatchmanSharedAgenIdentifier +from ...context import Inject from ...protocol_dep.authenticator import AuthenticatorManager UID = TypeVar("UID") @@ -41,14 +42,14 @@ @context.inject @context.root.register_eager_singleton() class Watchman: - conf: Config - apscheduler: AsyncIOScheduler - repo: WatchTaskRepo - binder: PixivAccountBinder - postman_mgr: PostmanManager - pd_factory_mgr: PostDestinationFactoryManager - shared_agen: WatchmanSharedAsyncGeneratorManager - auth_mgr: AuthenticatorManager + conf = Inject(Config) + apscheduler = Inject(AsyncIOScheduler) + repo = Inject(WatchTaskRepo) + binder = Inject(PixivAccountBinder) + postman_mgr = Inject(PostmanManager) + pd_factory_mgr = Inject(PostDestinationFactoryManager) + shared_agen = Inject(WatchmanSharedAsyncGeneratorManager) + auth_mgr = Inject(AuthenticatorManager) def __init__(self): on_bot_connect(self.on_bot_connect, replay=True) diff --git a/src/tests/test_context.py b/src/tests/test_context.py index d55ecc7..86ac144 100644 --- a/src/tests/test_context.py +++ b/src/tests/test_context.py @@ -1,10 +1,16 @@ +from unittest.mock import MagicMock + import pytest from tests import MyTest +cnt = 0 + class A: def __init__(self, data): + global cnt + cnt += 1 self.data = data def hello(self): @@ -33,25 +39,29 @@ def test_register_require_contains(self, context): assert A in context def test_register_lazy(self, context): - context.register_lazy(A, lambda: A(2)) - assert A not in context._container + initializer = MagicMock(side_effect=lambda: A(2)) + context.register_lazy(A, initializer) + initializer.assert_not_called() assert context.require(A).hello() == 2 - assert A in context._container + initializer.assert_called_once() def test_register_singleton(self, context): + old_cnt = cnt context.register_singleton(3)(A) - assert A not in context._container + assert cnt == old_cnt assert context.require(A).hello() == 3 - assert A in context._container + assert cnt == old_cnt + 1 def test_register_eager_singleton(self, context): + old_cnt = cnt context.register_eager_singleton(4)(A) - assert A in context._container + assert cnt == old_cnt + 1 assert context.require(A).hello() == 4 + assert cnt == old_cnt + 1 def test_bind_to(self, context): context.register(B, B("world")) - context.bind_to(A, B) + context.bind(A, B) assert context.require(A).hello() == "Hello world" @@ -73,26 +83,30 @@ def test_parent(self, context): assert A in third def test_inject(self, context): + from nonebot_plugin_pixivbot.context import Inject + context.register(A, A(5)) @context.inject class X: - a: A + a = Inject(A) x = X() assert x.a.hello() == 5 def test_inherited_inject(self, context): + from nonebot_plugin_pixivbot.context import Inject + context.register(A, A(6)) context.register(B, A(7)) @context.inject class X: - a: A + a = Inject(A) @context.inject class Y(X): - b: B + b = Inject(B) y = Y()