From 6c648c0f8971bae248676983b903c28c7f2e093f Mon Sep 17 00:00:00 2001 From: Jongwook Choi Date: Sat, 23 Dec 2023 22:26:37 -0500 Subject: [PATCH 1/3] refactor: add type annotations for `pynvim.plugin.host` --- pynvim/api/nvim.py | 6 +- pynvim/plugin/__init__.py | 2 +- pynvim/plugin/decorators.py | 20 ++++-- pynvim/plugin/host.py | 140 +++++++++++++++++++++++------------- test/test_host.py | 15 ++-- 5 files changed, 118 insertions(+), 65 deletions(-) diff --git a/pynvim/api/nvim.py b/pynvim/api/nvim.py index f0c33fdc..bfecb7c5 100644 --- a/pynvim/api/nvim.py +++ b/pynvim/api/nvim.py @@ -24,10 +24,6 @@ if TYPE_CHECKING: from pynvim.msgpack_rpc import Session -if sys.version_info < (3, 8): - from typing_extensions import Literal -else: - from typing import Literal __all__ = ['Nvim'] @@ -281,7 +277,7 @@ def __exit__(self, *exc_info: Any) -> None: """ self.close() - def with_decode(self, decode: Literal[True] = True) -> Nvim: + def with_decode(self, decode: TDecodeMode = True) -> Nvim: """Initialize a new Nvim instance.""" return Nvim(self._session, self.channel_id, self.metadata, self.types, decode, self._err_cb) diff --git a/pynvim/plugin/__init__.py b/pynvim/plugin/__init__.py index cb4ba41e..9365438b 100644 --- a/pynvim/plugin/__init__.py +++ b/pynvim/plugin/__init__.py @@ -2,7 +2,7 @@ from pynvim.plugin.decorators import (autocmd, command, decode, encoding, function, plugin, rpc_export, shutdown_hook) -from pynvim.plugin.host import Host # type: ignore[attr-defined] +from pynvim.plugin.host import Host __all__ = ('Host', 'plugin', 'rpc_export', 'command', 'autocmd', diff --git a/pynvim/plugin/decorators.py b/pynvim/plugin/decorators.py index 675fc4cc..6e090722 100644 --- a/pynvim/plugin/decorators.py +++ b/pynvim/plugin/decorators.py @@ -3,14 +3,14 @@ import inspect import logging import sys -from typing import Any, Callable, Dict, Optional, TypeVar, Union - -from pynvim.compat import unicode_errors_default +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, TypeVar, Union if sys.version_info < (3, 8): - from typing_extensions import Literal + from typing_extensions import Literal, TypedDict else: - from typing import Literal + from typing import Literal, TypedDict + +from pynvim.compat import unicode_errors_default logger = logging.getLogger(__name__) debug, info, warn = (logger.debug, logger.info, logger.warning,) @@ -21,6 +21,16 @@ F = TypeVar('F', bound=Callable[..., Any]) +if TYPE_CHECKING: + class RpcSpec(TypedDict): + type: Literal['command', 'autocmd', 'function'] + name: str + sync: Union[bool, Literal['urgent']] + opts: Any +else: + RpcSpec = dict + + def plugin(cls: T) -> T: """Tag a class as a plugin. diff --git a/pynvim/plugin/host.py b/pynvim/plugin/host.py index ea4c1df6..8aae15ed 100644 --- a/pynvim/plugin/host.py +++ b/pynvim/plugin/host.py @@ -1,6 +1,7 @@ -# type: ignore """Implements a Nvim host for python plugins.""" +from __future__ import annotations + import importlib import inspect import logging @@ -10,11 +11,14 @@ import sys from functools import partial from traceback import format_exc -from typing import Any, Sequence +from types import ModuleType +from typing import (Any, Callable, Dict, List, Optional, Sequence, Type, + TypeVar, Union) from pynvim.api import Nvim, decode_if_bytes, walk +from pynvim.api.common import TDecodeMode from pynvim.msgpack_rpc import ErrorResponse -from pynvim.plugin import script_host +from pynvim.plugin import decorators, script_host from pynvim.util import format_exc_skip, get_client_info __all__ = ('Host',) @@ -26,7 +30,10 @@ host_method_spec = {"poll": {}, "specs": {"nargs": 1}, "shutdown": {}} -def _handle_import(path: str, name: str): +T = TypeVar('T') + + +def _handle_import(path: str, name: str) -> ModuleType: """Import python module `name` from a known file path or module directory. The path should be the base directory from which the module can be imported. @@ -40,13 +47,15 @@ def _handle_import(path: str, name: str): return importlib.import_module(name) -class Host(object): - +class Host: """Nvim host for python plugins. Takes care of loading/unloading plugins and routing msgpack-rpc requests/notifications to the appropriate handlers. """ + _specs: Dict[str, list[decorators.RpcSpec]] # path -> rpc spec + _loaded: Dict[str, dict] # path -> {handlers: ..., modules: ...} + _load_errors: Dict[str, str] # path -> error message def __init__(self, nvim: Nvim): """Set handlers for plugin_load/plugin_unload.""" @@ -54,10 +63,10 @@ def __init__(self, nvim: Nvim): self._specs = {} self._loaded = {} self._load_errors = {} - self._notification_handlers = { + self._notification_handlers: Dict[str, Callable] = { 'nvim_error_event': self._on_error_event } - self._request_handlers = { + self._request_handlers: Dict[str, Callable] = { 'poll': lambda: 'ok', 'specs': self._on_specs_request, 'shutdown': self.shutdown @@ -75,9 +84,8 @@ def _on_error_event(self, kind: Any, msg: str) -> None: errmsg = "{}: Async request caused an error:\n{}\n".format( self.name, decode_if_bytes(msg)) self.nvim.err_write(errmsg, async_=True) - return errmsg - def start(self, plugins): + def start(self, plugins: Sequence[str]) -> None: """Start listening for msgpack-rpc requests and notifications.""" self.nvim.run_loop(self._on_request, self._on_notification, @@ -89,31 +97,48 @@ def shutdown(self) -> None: self._unload() self.nvim.stop_loop() - def _wrap_delayed_function(self, cls, delayed_handlers, name, sync, - module_handlers, path, *args): + def _wrap_delayed_function( + self, + cls: Type[T], # a class type + delayed_handlers: List[Callable], + name: str, + sync: bool, + module_handlers: List[Callable], + path: str, + *args: Any, + ) -> Any: # delete the delayed handlers to be sure for handler in delayed_handlers: - method_name = handler._nvim_registered_name - if handler._nvim_rpc_sync: + method_name = handler._nvim_registered_name # type: ignore[attr-defined] + if handler._nvim_rpc_sync: # type: ignore[attr-defined] del self._request_handlers[method_name] else: del self._notification_handlers[method_name] # create an instance of the plugin and pass the nvim object - plugin = cls(self._configure_nvim_for(cls)) + plugin: T = cls(self._configure_nvim_for(cls)) # type: ignore[call-arg] # discover handlers in the plugin instance - self._discover_functions(plugin, module_handlers, path, False) + self._discover_functions(plugin, module_handlers, + plugin_path=path, delay=False) if sync: return self._request_handlers[name](*args) else: return self._notification_handlers[name](*args) - def _wrap_function(self, fn, sync, decode, nvim_bind, name, *args): + def _wrap_function( + self, + fn: Callable, + sync: bool, + decode: TDecodeMode, + nvim_bind: Optional[Nvim], + name: str, + *args: Any, + ) -> Any: if decode: args = walk(decode_if_bytes, args, decode) if nvim_bind is not None: - args.insert(0, nvim_bind) + args = (nvim_bind, *args) try: return fn(*args) except Exception: @@ -126,12 +151,12 @@ def _wrap_function(self, fn, sync, decode, nvim_bind, name, *args): .format(name, args, format_exc_skip(1))) self._on_async_err(msg + "\n") - def _on_request(self, name: str, args: Sequence[Any]) -> None: + def _on_request(self, name: str, args: Sequence[Any]) -> Any: """Handle a msgpack-rpc request.""" name = decode_if_bytes(name) handler = self._request_handlers.get(name, None) if not handler: - msg = self._missing_handler_error(name, 'request') + msg = self._missing_handler_error(name, kind='request') error(msg) raise ErrorResponse(msg) @@ -145,7 +170,7 @@ def _on_notification(self, name: str, args: Sequence[Any]) -> None: name = decode_if_bytes(name) handler = self._notification_handlers.get(name, None) if not handler: - msg = self._missing_handler_error(name, 'notification') + msg = self._missing_handler_error(name, kind='notification') error(msg) self._on_async_err(msg + "\n") return @@ -153,7 +178,7 @@ def _on_notification(self, name: str, args: Sequence[Any]) -> None: debug('calling notification handler for "%s", args: "%s"', name, args) handler(*args) - def _missing_handler_error(self, name, kind): + def _missing_handler_error(self, name: str, *, kind: str) -> str: msg = 'no {} handler registered for "{}"'.format(kind, name) pathmatch = re.match(r'(.+):[^:]+:[^:]+', name) if pathmatch: @@ -168,7 +193,7 @@ def _load(self, plugins: Sequence[str]) -> None: Args: plugins: List of plugin paths to rplugin python modules registered by remote#host#RegisterPlugin('python3', ...) - (see the generated rplugin.vim manifest) + (see the generated ~/.local/share/nvim/rplugin.vim manifest) """ # self.nvim.err_write("host init _load\n", async_=True) has_script = False @@ -185,9 +210,9 @@ def _load(self, plugins: Sequence[str]) -> None: else: directory, name = os.path.split(os.path.splitext(path)[0]) module = _handle_import(directory, name) - handlers = [] + handlers: List[Callable] = [] self._discover_classes(module, handlers, path) - self._discover_functions(module, handlers, path, False) + self._discover_functions(module, handlers, path, delay=False) if not handlers: error('{} exports no handlers'.format(path)) continue @@ -218,48 +243,65 @@ def _unload(self) -> None: self._specs = {} self._loaded = {} - def _discover_classes(self, module, handlers, plugin_path): + def _discover_classes( + self, + module: ModuleType, + handlers: List[Callable], + plugin_path: str, + ) -> None: for _, cls in inspect.getmembers(module, inspect.isclass): if getattr(cls, '_nvim_plugin', False): # discover handlers in the plugin instance - self._discover_functions(cls, handlers, plugin_path, True) - - def _discover_functions(self, obj, handlers, plugin_path, delay): - def predicate(o): + self._discover_functions(cls, handlers, plugin_path, delay=True) + + def _discover_functions( + self, + obj: Union[Type, ModuleType, Any], # class, module, or plugin instance + handlers: List[Callable], + plugin_path: str, + delay: bool, + ) -> None: + def predicate(o: Any) -> bool: return hasattr(o, '_nvim_rpc_method_name') - cls_handlers = [] - specs = [] - objdecode = getattr(obj, '_nvim_decode', self._decode_default) + cls_handlers: List[Callable] = [] + specs: List[decorators.RpcSpec] = [] + obj_decode: TDecodeMode = getattr(obj, '_nvim_decode', + self._decode_default) # type: ignore for _, fn in inspect.getmembers(obj, predicate): - method = fn._nvim_rpc_method_name + method: str = fn._nvim_rpc_method_name if fn._nvim_prefix_plugin_path: method = '{}:{}'.format(plugin_path, method) - sync = fn._nvim_rpc_sync + sync: bool = fn._nvim_rpc_sync if delay: + # TODO: Fix typing on obj. delay=True assumes obj is a class! + assert isinstance(obj, type), "obj must be a class type" fn_wrapped = partial(self._wrap_delayed_function, obj, cls_handlers, method, sync, handlers, plugin_path) else: - decode = getattr(fn, '_nvim_decode', objdecode) - nvim_bind = None + decode: TDecodeMode = getattr(fn, '_nvim_decode', obj_decode) + nvim_bind: Optional[Nvim] = None if fn._nvim_bind: nvim_bind = self._configure_nvim_for(fn) fn_wrapped = partial(self._wrap_function, fn, sync, decode, nvim_bind, method) self._copy_attributes(fn, fn_wrapped) - fn_wrapped._nvim_registered_name = method + fn_wrapped._nvim_registered_name = method # type: ignore[attr-defined] + # register in the rpc handler dict if sync: if method in self._request_handlers: - raise Exception(('Request handler for "{}" is ' - + 'already registered').format(method)) + raise Exception( + f'Request handler for "{method}" is already registered' + ) self._request_handlers[method] = fn_wrapped else: if method in self._notification_handlers: - raise Exception(('Notification handler for "{}" is ' - + 'already registered').format(method)) + raise Exception( + f'Notification handler for "{method}" is already registered' + ) self._notification_handlers[method] = fn_wrapped if hasattr(fn, '_nvim_rpc_spec'): specs.append(fn._nvim_rpc_spec) @@ -268,19 +310,21 @@ def predicate(o): if specs: self._specs[plugin_path] = specs - def _copy_attributes(self, fn, fn2): + def _copy_attributes(self, src: Any, dst: Any) -> None: # Copy _nvim_* attributes from the original function - for attr in dir(fn): + for attr in dir(src): if attr.startswith('_nvim_'): - setattr(fn2, attr, getattr(fn, attr)) + setattr(dst, attr, getattr(src, attr)) - def _on_specs_request(self, path): + def _on_specs_request(self, path: Union[str, bytes] + ) -> List[decorators.RpcSpec]: path = decode_if_bytes(path) + assert isinstance(path, str) if path in self._load_errors: self.nvim.out_write(self._load_errors[path] + '\n') - return self._specs.get(path, 0) + return self._specs.get(path, []) - def _configure_nvim_for(self, obj): + def _configure_nvim_for(self, obj: Any) -> Nvim: # Configure a nvim instance for obj (checks encoding configuration) nvim = self.nvim decode = getattr(obj, '_nvim_decode', self._decode_default) diff --git a/test/test_host.py b/test/test_host.py index 18cff327..860dc553 100644 --- a/test/test_host.py +++ b/test/test_host.py @@ -3,13 +3,14 @@ import os from typing import Sequence +from pynvim.api.nvim import Nvim from pynvim.plugin.host import Host, host_method_spec from pynvim.plugin.script_host import ScriptHost __PATH__ = os.path.abspath(os.path.dirname(__file__)) -def test_host_imports(vim): +def test_host_imports(vim: Nvim): h = ScriptHost(vim) try: assert h.module.__dict__['vim'] @@ -19,7 +20,7 @@ def test_host_imports(vim): h.teardown() -def test_host_import_rplugin_modules(vim): +def test_host_import_rplugin_modules(vim: Nvim): # Test whether a Host can load and import rplugins (#461). # See also $VIMRUNTIME/autoload/provider/pythonx.vim. h = Host(vim) @@ -47,18 +48,20 @@ def test_host_clientinfo(vim): # Smoke test for Host._on_error_event(). #425 -def test_host_async_error(vim): +def test_host_async_error(vim: Nvim): h = Host(vim) h._load([]) # Invoke a bogus Ex command via notify (async). vim.command("lolwut", async_=True) event = vim.next_message() assert event[1] == 'nvim_error_event' - assert 'rplugin-host: Async request caused an error:\nboom\n' \ - in h._on_error_event(None, 'boom') + h._on_error_event(None, 'boom') + msg = vim.command_output('messages') + assert 'rplugin-host: Async request caused an error:\nboom' in msg -def test_legacy_vim_eval(vim): + +def test_legacy_vim_eval(vim: Nvim): h = ScriptHost(vim) try: assert h.legacy_vim.eval('1') == '1' From 950f441b742873fa6202bad94bbeb014b6914335 Mon Sep 17 00:00:00 2001 From: Jongwook Choi Date: Tue, 26 Dec 2023 12:53:10 -0500 Subject: [PATCH 2/3] refactor: add typing for Handler, revise typings in decorators - `Handler` is a structured subtype of `Callable` decorated by pynvim decorators for RPC handlers; this type provides static typing for the private fields internally set by pynvim decorators. This way, we have a better static typing and thus no need to rely on `hasattr` checks. - Add minimal test cases for other decorator types. --- pynvim/plugin/decorators.py | 221 +++++++++++++++++++++++++++--------- pynvim/plugin/host.py | 77 +++++++------ test/test_decorators.py | 94 ++++++++++++++- 3 files changed, 300 insertions(+), 92 deletions(-) diff --git a/pynvim/plugin/decorators.py b/pynvim/plugin/decorators.py index 6e090722..82b52bd8 100644 --- a/pynvim/plugin/decorators.py +++ b/pynvim/plugin/decorators.py @@ -1,27 +1,40 @@ """Decorators used by python host plugin system.""" +from __future__ import annotations + import inspect import logging import sys -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, TypeVar, Union +from functools import partial +from typing import (TYPE_CHECKING, Any, Callable, Dict, Optional, Type, + TypeVar, Union, cast, overload) if sys.version_info < (3, 8): - from typing_extensions import Literal, TypedDict + from typing_extensions import Literal, Protocol, TypedDict +else: + from typing import Literal, Protocol, TypedDict + +if sys.version_info < (3, 10): + from typing_extensions import ParamSpec else: - from typing import Literal, TypedDict + from typing import ParamSpec +from pynvim.api.common import TDecodeMode from pynvim.compat import unicode_errors_default logger = logging.getLogger(__name__) -debug, info, warn = (logger.debug, logger.info, logger.warning,) +debug, info, warn = ( + logger.debug, + logger.info, + logger.warning, +) __all__ = ('plugin', 'rpc_export', 'command', 'autocmd', 'function', 'encoding', 'decode', 'shutdown_hook') T = TypeVar('T') -F = TypeVar('F', bound=Callable[..., Any]) - if TYPE_CHECKING: + class RpcSpec(TypedDict): type: Literal['command', 'autocmd', 'function'] name: str @@ -30,6 +43,48 @@ class RpcSpec(TypedDict): else: RpcSpec = dict +# type variables for Handler, to represent Callable: P -> R +P = ParamSpec('P') +R = TypeVar('R') + + +class Handler(Protocol[P, R]): + """An interface to pynvim-decorated RPC handler. + + Handler is basically a callable (method) that is decorated by pynvim. + It will have some private fields (prefixed with `_nvim_`), set by + decorators that follow below. This generic type allows stronger, static + typing for all the private attributes (see `host.Host` for the usage). + + Note: Any valid Handler that is created by pynvim's decorator is guaranteed + to have *all* of the following `_nvim_*` attributes defined as per the + "Protocol", so there is NO need to check `hasattr(handler, "_nvim_...")`. + Exception is _nvim_decode; this is an optional attribute orthgonally set by + the decorator `@decode()`. + """ + __call__: Callable[P, R] + + _nvim_rpc_method_name: str + _nvim_rpc_sync: bool + _nvim_bind: bool + _nvim_prefix_plugin_path: bool + _nvim_rpc_spec: Optional[RpcSpec] + _nvim_shutdown_hook: bool + + _nvim_registered_name: Optional[str] # set later by host when discovered + + @classmethod + def wrap(cls, fn: Callable[P, R]) -> Handler[P, R]: + fn = cast(Handler[P, R], partial(fn)) + fn._nvim_bind = False + fn._nvim_rpc_method_name = None # type: ignore + fn._nvim_rpc_sync = None # type: ignore + fn._nvim_prefix_plugin_path = False + fn._nvim_rpc_spec = None + fn._nvim_shutdown_hook = False + fn._nvim_registered_name = None + return fn + def plugin(cls: T) -> T: """Tag a class as a plugin. @@ -38,24 +93,34 @@ def plugin(cls: T) -> T: plugin_load method of the host. """ cls._nvim_plugin = True # type: ignore[attr-defined] + # the _nvim_bind attribute is set to True by default, meaning that # decorated functions have a bound Nvim instance as first argument. # For methods in a plugin-decorated class this is not required, because # the class initializer will already receive the nvim object. - predicate = lambda fn: hasattr(fn, '_nvim_bind') + predicate = lambda fn: getattr(fn, '_nvim_bind', False) for _, fn in inspect.getmembers(cls, predicate): fn._nvim_bind = False return cls -def rpc_export(rpc_method_name: str, sync: bool = False) -> Callable[[F], F]: +def rpc_export( + rpc_method_name: str, + sync: bool = False, +) -> Callable[[Callable[P, R]], Handler[P, R]]: """Export a function or plugin method as a msgpack-rpc request handler.""" - def dec(f: F) -> F: - f._nvim_rpc_method_name = rpc_method_name # type: ignore[attr-defined] - f._nvim_rpc_sync = sync # type: ignore[attr-defined] - f._nvim_bind = True # type: ignore[attr-defined] - f._nvim_prefix_plugin_path = False # type: ignore[attr-defined] + + def dec(f: Callable[P, R]) -> Handler[P, R]: + f = cast(Handler[P, R], f) + f._nvim_rpc_method_name = rpc_method_name + f._nvim_rpc_sync = sync + f._nvim_bind = True + f._nvim_prefix_plugin_path = False + f._nvim_rpc_spec = None # not used + f._nvim_shutdown_hook = False # not used + f._nvim_registered_name = None # TBD return f + return dec @@ -70,15 +135,15 @@ def command( sync: bool = False, allow_nested: bool = False, eval: Optional[str] = None -) -> Callable[[F], F]: +) -> Callable[[Callable[P, R]], Handler[P, R]]: """Tag a function or plugin method as a Nvim command handler.""" - def dec(f: F) -> F: - f._nvim_rpc_method_name = ( # type: ignore[attr-defined] - 'command:{}'.format(name) - ) - f._nvim_rpc_sync = sync # type: ignore[attr-defined] - f._nvim_bind = True # type: ignore[attr-defined] - f._nvim_prefix_plugin_path = True # type: ignore[attr-defined] + + def dec(f: Callable[P, R]) -> Handler[P, R]: + f = cast(Handler[P, R], f) + f._nvim_rpc_method_name = ('command:{}'.format(name)) + f._nvim_rpc_sync = sync + f._nvim_bind = True + f._nvim_prefix_plugin_path = True opts: Dict[str, Any] = {} @@ -107,13 +172,16 @@ def dec(f: F) -> F: else: rpc_sync = sync - f._nvim_rpc_spec = { # type: ignore[attr-defined] + f._nvim_rpc_spec = { 'type': 'command', 'name': name, 'sync': rpc_sync, 'opts': opts } + f._nvim_shutdown_hook = False + f._nvim_registered_name = None # TBD return f + return dec @@ -123,19 +191,17 @@ def autocmd( sync: bool = False, allow_nested: bool = False, eval: Optional[str] = None -) -> Callable[[F], F]: +) -> Callable[[Callable[P, R]], Handler[P, R]]: """Tag a function or plugin method as a Nvim autocommand handler.""" - def dec(f: F) -> F: - f._nvim_rpc_method_name = ( # type: ignore[attr-defined] - 'autocmd:{}:{}'.format(name, pattern) - ) - f._nvim_rpc_sync = sync # type: ignore[attr-defined] - f._nvim_bind = True # type: ignore[attr-defined] - f._nvim_prefix_plugin_path = True # type: ignore[attr-defined] - - opts = { - 'pattern': pattern - } + + def dec(f: Callable[P, R]) -> Handler[P, R]: + f = cast(Handler[P, R], f) + f._nvim_rpc_method_name = ('autocmd:{}:{}'.format(name, pattern)) + f._nvim_rpc_sync = sync + f._nvim_bind = True + f._nvim_prefix_plugin_path = True + + opts = {'pattern': pattern} if eval: opts['eval'] = eval @@ -145,13 +211,16 @@ def dec(f: F) -> F: else: rpc_sync = sync - f._nvim_rpc_spec = { # type: ignore[attr-defined] + f._nvim_rpc_spec = { 'type': 'autocmd', 'name': name, 'sync': rpc_sync, 'opts': opts } + f._nvim_shutdown_hook = False + f._nvim_registered_name = None # TBD return f + return dec @@ -161,15 +230,15 @@ def function( sync: bool = False, allow_nested: bool = False, eval: Optional[str] = None -) -> Callable[[F], F]: +) -> Callable[[Callable[P, R]], Handler[P, R]]: """Tag a function or plugin method as a Nvim function handler.""" - def dec(f: F) -> F: - f._nvim_rpc_method_name = ( # type: ignore[attr-defined] - 'function:{}'.format(name) - ) - f._nvim_rpc_sync = sync # type: ignore[attr-defined] - f._nvim_bind = True # type: ignore[attr-defined] - f._nvim_prefix_plugin_path = True # type: ignore[attr-defined] + + def dec(f: Callable[P, R]) -> Handler[P, R]: + f = cast(Handler[P, R], f) + f._nvim_rpc_method_name = ('function:{}'.format(name)) + f._nvim_rpc_sync = sync + f._nvim_bind = True + f._nvim_prefix_plugin_path = True opts = {} @@ -184,37 +253,79 @@ def dec(f: F) -> F: else: rpc_sync = sync - f._nvim_rpc_spec = { # type: ignore[attr-defined] + f._nvim_rpc_spec = { 'type': 'function', 'name': name, 'sync': rpc_sync, 'opts': opts } + f._nvim_shutdown_hook = False # not used + f._nvim_registered_name = None # TBD return f + return dec -def shutdown_hook(f: F) -> F: +def shutdown_hook(f: Callable[P, R]) -> Handler[P, R]: """Tag a function or method as a shutdown hook.""" - f._nvim_shutdown_hook = True # type: ignore[attr-defined] - f._nvim_bind = True # type: ignore[attr-defined] + f = cast(Handler[P, R], f) + f._nvim_rpc_method_name = '' # Falsy value, not used + f._nvim_rpc_sync = True # not used + f._nvim_prefix_plugin_path = False # not used + f._nvim_rpc_spec = None # not used + + f._nvim_shutdown_hook = True + f._nvim_bind = True + f._nvim_registered_name = None # TBD return f -def decode(mode: str = unicode_errors_default) -> Callable[[F], F]: - """Configure automatic encoding/decoding of strings.""" - def dec(f: F) -> F: - f._nvim_decode = mode # type: ignore[attr-defined] +T_Decode = Union[Type, Handler[P, R]] + + +def decode( + mode: TDecodeMode = unicode_errors_default, +) -> Callable[[T_Decode], T_Decode]: + """Configure automatic encoding/decoding of strings. + + This decorator can be put around an individual Handler (@rpc_export, + @autocmd, @function, @command, or @shutdown_hook), or around a class + (@plugin, has an effect on all the methods unless overridden). + + The argument `mode` will be passed as an argument to: + bytes.decode("utf-8", errors=mode) + when decoding bytestream Nvim RPC responses. + + See https://docs.python.org/3/library/codecs.html#error-handlers for + the list of valid modes (error handler values). + + See also: + pynvim.api.Nvim.with_decode(mode) + pynvim.api.common.decode_if_bytes(..., mode) + """ + + @overload + def dec(f: Handler[P, R]) -> Handler[P, R]: + ... # decorator on method + + @overload + def dec(f: Type[T]) -> Type[T]: + ... # decorator on class + + def dec(f): # type: ignore + f._nvim_decode = mode return f - return dec + return dec # type: ignore -def encoding(encoding: Union[bool, str] = True) -> Callable[[F], F]: + +def encoding(encoding: Union[bool, str] = True): # type: ignore """DEPRECATED: use pynvim.decode().""" if isinstance(encoding, str): encoding = True - def dec(f: F) -> F: - f._nvim_decode = encoding # type: ignore[attr-defined] + def dec(f): # type: ignore + f._nvim_decode = encoding if encoding else None return f + return dec diff --git a/pynvim/plugin/host.py b/pynvim/plugin/host.py index 8aae15ed..b16eabdf 100644 --- a/pynvim/plugin/host.py +++ b/pynvim/plugin/host.py @@ -13,7 +13,7 @@ from traceback import format_exc from types import ModuleType from typing import (Any, Callable, Dict, List, Optional, Sequence, Type, - TypeVar, Union) + TypeVar, Union, cast) from pynvim.api import Nvim, decode_if_bytes, walk from pynvim.api.common import TDecodeMode @@ -32,6 +32,9 @@ T = TypeVar('T') +RpcSpec = decorators.RpcSpec +Handler = decorators.Handler + def _handle_import(path: str, name: str) -> ModuleType: """Import python module `name` from a known file path or module directory. @@ -53,7 +56,7 @@ class Host: Takes care of loading/unloading plugins and routing msgpack-rpc requests/notifications to the appropriate handlers. """ - _specs: Dict[str, list[decorators.RpcSpec]] # path -> rpc spec + _specs: Dict[str, list[RpcSpec]] # path -> list[ rpc handler spec ] _loaded: Dict[str, dict] # path -> {handlers: ..., modules: ...} _load_errors: Dict[str, str] # path -> error message @@ -63,13 +66,13 @@ def __init__(self, nvim: Nvim): self._specs = {} self._loaded = {} self._load_errors = {} - self._notification_handlers: Dict[str, Callable] = { - 'nvim_error_event': self._on_error_event + self._notification_handlers: Dict[str, Handler] = { + 'nvim_error_event': Handler.wrap(self._on_error_event), } - self._request_handlers: Dict[str, Callable] = { - 'poll': lambda: 'ok', - 'specs': self._on_specs_request, - 'shutdown': self.shutdown + self._request_handlers: Dict[str, Handler] = { + 'poll': Handler.wrap(lambda: 'ok'), + 'specs': Handler.wrap(self._on_specs_request), + 'shutdown': Handler.wrap(self.shutdown), } self._decode_default = True @@ -100,17 +103,18 @@ def shutdown(self) -> None: def _wrap_delayed_function( self, cls: Type[T], # a class type - delayed_handlers: List[Callable], + delayed_handlers: List[Handler], name: str, sync: bool, - module_handlers: List[Callable], + module_handlers: List[Handler], path: str, *args: Any, ) -> Any: # delete the delayed handlers to be sure for handler in delayed_handlers: - method_name = handler._nvim_registered_name # type: ignore[attr-defined] - if handler._nvim_rpc_sync: # type: ignore[attr-defined] + method_name = handler._nvim_registered_name + assert method_name is not None + if handler._nvim_rpc_sync: del self._request_handlers[method_name] else: del self._notification_handlers[method_name] @@ -210,7 +214,7 @@ def _load(self, plugins: Sequence[str]) -> None: else: directory, name = os.path.split(os.path.splitext(path)[0]) module = _handle_import(directory, name) - handlers: List[Callable] = [] + handlers: List[Handler] = [] self._discover_classes(module, handlers, path) self._discover_functions(module, handlers, path, delay=False) if not handlers: @@ -234,7 +238,7 @@ def _unload(self) -> None: handlers = plugin['handlers'] for handler in handlers: method_name = handler._nvim_registered_name - if hasattr(handler, '_nvim_shutdown_hook'): + if handler._nvim_shutdown_hook: handler() elif handler._nvim_rpc_sync: del self._request_handlers[method_name] @@ -246,7 +250,7 @@ def _unload(self) -> None: def _discover_classes( self, module: ModuleType, - handlers: List[Callable], + handlers: List[Handler], plugin_path: str, ) -> None: for _, cls in inspect.getmembers(module, inspect.isclass): @@ -257,18 +261,19 @@ def _discover_classes( def _discover_functions( self, obj: Union[Type, ModuleType, Any], # class, module, or plugin instance - handlers: List[Callable], + handlers: List[Handler], plugin_path: str, delay: bool, ) -> None: def predicate(o: Any) -> bool: - return hasattr(o, '_nvim_rpc_method_name') + return bool(getattr(o, '_nvim_rpc_method_name', False)) - cls_handlers: List[Callable] = [] + cls_handlers: List[Handler] = [] specs: List[decorators.RpcSpec] = [] - obj_decode: TDecodeMode = getattr(obj, '_nvim_decode', - self._decode_default) # type: ignore + obj_decode: TDecodeMode = cast( + TDecodeMode, getattr(obj, '_nvim_decode', self._decode_default)) for _, fn in inspect.getmembers(obj, predicate): + fn = cast(Handler, fn) # because hasattr(_nvim_rpc_method_name) method: str = fn._nvim_rpc_method_name if fn._nvim_prefix_plugin_path: method = '{}:{}'.format(plugin_path, method) @@ -276,34 +281,33 @@ def predicate(o: Any) -> bool: if delay: # TODO: Fix typing on obj. delay=True assumes obj is a class! assert isinstance(obj, type), "obj must be a class type" - fn_wrapped = partial(self._wrap_delayed_function, obj, - cls_handlers, method, sync, - handlers, plugin_path) + _fn_wrapped = partial(self._wrap_delayed_function, obj, + cls_handlers, method, sync, + handlers, plugin_path) else: decode: TDecodeMode = getattr(fn, '_nvim_decode', obj_decode) nvim_bind: Optional[Nvim] = None if fn._nvim_bind: nvim_bind = self._configure_nvim_for(fn) - fn_wrapped = partial(self._wrap_function, fn, - sync, decode, nvim_bind, method) - self._copy_attributes(fn, fn_wrapped) - fn_wrapped._nvim_registered_name = method # type: ignore[attr-defined] + _fn_wrapped = partial(self._wrap_function, fn, + sync, decode, nvim_bind, method) + self._copy_attributes(fn, _fn_wrapped) + fn_wrapped: Handler = cast(Handler, _fn_wrapped) + fn_wrapped._nvim_registered_name = method # register in the rpc handler dict if sync: if method in self._request_handlers: - raise Exception( - f'Request handler for "{method}" is already registered' - ) + raise Exception(f'Request handler for "{method}" ' + 'is already registered') self._request_handlers[method] = fn_wrapped else: if method in self._notification_handlers: - raise Exception( - f'Notification handler for "{method}" is already registered' - ) + raise Exception(f'Notification handler for "{method}" ' + 'is already registered') self._notification_handlers[method] = fn_wrapped - if hasattr(fn, '_nvim_rpc_spec'): + if fn._nvim_rpc_spec: specs.append(fn._nvim_rpc_spec) handlers.append(fn_wrapped) cls_handlers.append(fn_wrapped) @@ -317,7 +321,7 @@ def _copy_attributes(self, src: Any, dst: Any) -> None: setattr(dst, attr, getattr(src, attr)) def _on_specs_request(self, path: Union[str, bytes] - ) -> List[decorators.RpcSpec]: + ) -> List[RpcSpec]: path = decode_if_bytes(path) assert isinstance(path, str) if path in self._load_errors: @@ -327,7 +331,8 @@ def _on_specs_request(self, path: Union[str, bytes] def _configure_nvim_for(self, obj: Any) -> Nvim: # Configure a nvim instance for obj (checks encoding configuration) nvim = self.nvim - decode = getattr(obj, '_nvim_decode', self._decode_default) + decode: TDecodeMode = cast( + TDecodeMode, getattr(obj, '_nvim_decode', self._decode_default)) if decode: nvim = nvim.with_decode(decode) return nvim diff --git a/test/test_decorators.py b/test/test_decorators.py index 1f5c857e..be5f9482 100644 --- a/test/test_decorators.py +++ b/test/test_decorators.py @@ -1,8 +1,35 @@ # type: ignore -from pynvim.plugin.decorators import command +from pynvim.plugin.decorators import (Handler, autocmd, command, decode, + function, plugin, rpc_export, shutdown_hook) + + +def _ensure_attributes(decorated: Handler) -> Handler: + """Ensure that a Handler has all the private _nvim_* attributes set.""" + attrs = [ + k for k in Handler.__annotations__.keys() if k.startswith('_nvim_') + ] + + for attr in attrs: + assert hasattr(decorated, attr), \ + f"{decorated} does not have attr: {attr}" + + assert decorated._nvim_registered_name is None # shouldn't be set yet + return decorated + + +def test_rpc_export() -> None: + + @rpc_export("rpc_remote_point", sync=True) + def handler(): + pass + + _ensure_attributes(handler) + assert "rpc_remote_point" == handler._nvim_rpc_method_name + assert True == handler._nvim_rpc_sync # noqa def test_command_count() -> None: + def function() -> None: """A dummy function to decorate.""" return @@ -10,6 +37,7 @@ def function() -> None: # ensure absence with default value of None decorated = command('test')(function) assert 'count' not in decorated._nvim_rpc_spec['opts'] + _ensure_attributes(decorated) # ensure absence with explicit value of None count_value = None @@ -27,3 +55,67 @@ def function() -> None: decorated = command('test', count=count_value)(function) assert 'count' in decorated._nvim_rpc_spec['opts'] assert decorated._nvim_rpc_spec['opts']['count'] == count_value + + +def test_autocmd() -> None: + + @autocmd(name="BufEnter", pattern="*.py", sync=True) + def handler(afile): + print(afile) + + _ensure_attributes(handler) + assert 'autocmd:BufEnter:*.py' == handler._nvim_rpc_method_name + + +def test_function() -> None: + pass + + @function(name="MyRemoteFunction") + def MyRemoteFunc(a: int, b: int) -> int: + """Add two integers.""" + return a + b + + _ensure_attributes(MyRemoteFunc) + assert 'function:MyRemoteFunction' == MyRemoteFunc._nvim_rpc_method_name + + +def test_shutdown_hook() -> None: + + @shutdown_hook + def hook(): + print("shutdown...") + + _ensure_attributes(hook) + assert True == hook._nvim_shutdown_hook # noqa + assert not hook._nvim_rpc_method_name + + +def test_decode() -> None: + + # Case 1 + @decode(mode="strict") + @function(name="MyFunc") + def handler1(): + """A valid usage.""" + + # decode set, and all other attributes are preserved + assert "function:MyFunc" == handler1._nvim_rpc_method_name + assert "strict" == handler1._nvim_decode + + # Case 2: decode "inside" function + @function(name="MyFunc") + @decode(mode="strict") + def handler2(): + """Note the swapped order between function and decode.""" + + assert "function:MyFunc" == handler2._nvim_rpc_method_name + assert "strict" == handler2._nvim_decode + + # Case 3: on class + @decode(mode="strict") + @plugin + class MyPlugin: + pass + + assert "strict" == MyPlugin._nvim_decode + assert True == MyPlugin._nvim_plugin # noqa From dda270e80f5e9f413af9236b037e8ec39714bfe1 Mon Sep 17 00:00:00 2001 From: Jongwook Choi Date: Wed, 27 Dec 2023 05:30:25 -0500 Subject: [PATCH 3/3] test: add tests and docs for plugin.Host --- pynvim/__init__.py | 2 ++ pynvim/plugin/host.py | 76 +++++++++++++++++++++++++++------------- test/test_host.py | 81 ++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 134 insertions(+), 25 deletions(-) diff --git a/pynvim/__init__.py b/pynvim/__init__.py index daefc1ec..b604b482 100644 --- a/pynvim/__init__.py +++ b/pynvim/__init__.py @@ -44,6 +44,8 @@ def start_host(session: Optional[Session] = None) -> None: This function is normally called at program startup and could have been defined as a separate executable. It is exposed as a library function for testing purposes only. + + See also $VIMRUNTIME/autoload/provider/pythonx.vim for python host startup. """ plugins = [] for arg in sys.argv: diff --git a/pynvim/plugin/host.py b/pynvim/plugin/host.py index b16eabdf..6152c571 100644 --- a/pynvim/plugin/host.py +++ b/pynvim/plugin/host.py @@ -194,38 +194,36 @@ def _missing_handler_error(self, name: str, *, kind: str) -> str: def _load(self, plugins: Sequence[str]) -> None: """Load the remote plugins and register handlers defined in the plugins. - Args: - plugins: List of plugin paths to rplugin python modules - registered by remote#host#RegisterPlugin('python3', ...) - (see the generated ~/.local/share/nvim/rplugin.vim manifest) + Parameters + ---------- + plugins: List of plugin paths to rplugin python modules registered by + `remote#host#RegisterPlugin('python3', ...)`. Each element should + be either: + (1) "script_host.py": this is a special plugin for python3 + rplugin host. See $VIMRUNTIME/autoload/provider/python3.vim + ; or + (2) (absolute) path to the top-level plugin module directory; + e.g., for a top-level python module `mymodule`: it would be + `"/path/to/plugin/rplugin/python3/mymodule"`. + See the generated ~/.local/share/nvim/rplugin.vim manifest + for real examples. """ # self.nvim.err_write("host init _load\n", async_=True) has_script = False for path in plugins: path = os.path.normpath(path) # normalize path - err = None - if path in self._loaded: - warn('{} is already loaded'.format(path)) - continue try: - if path == "script_host.py": - module = script_host - has_script = True - else: - directory, name = os.path.split(os.path.splitext(path)[0]) - module = _handle_import(directory, name) - handlers: List[Handler] = [] - self._discover_classes(module, handlers, path) - self._discover_functions(module, handlers, path, delay=False) - if not handlers: - error('{} exports no handlers'.format(path)) + plugin_spec = self._load_plugin(path=path) + if not plugin_spec: continue - self._loaded[path] = {'handlers': handlers, 'module': module} + if plugin_spec["path"] == "script_host.py": + has_script = True except Exception as e: - err = ('Encountered {} loading plugin at {}: {}\n{}' - .format(type(e).__name__, path, e, format_exc(5))) - error(err) - self._load_errors[path] = err + errmsg: str = ( + 'Encountered {} loading plugin at {}: {}\n{}'.format( + type(e).__name__, path, e, format_exc(5))) + error(errmsg) + self._load_errors[path] = errmsg kind = ("script-host" if len(plugins) == 1 and has_script else "rplugin-host") @@ -233,6 +231,36 @@ def _load(self, plugins: Sequence[str]) -> None: self.name = info[0] self.nvim.api.set_client_info(*info, async_=True) + def _load_plugin( + self, path: str, *, + module: Optional[ModuleType] = None, + ) -> Union[Dict[str, Any], None]: + # Note: path must be normalized. + if path in self._loaded: + warn('{} is already loaded'.format(path)) + return None + + if path == "script_host.py": + module = script_host + elif module is not None: + pass # Note: module is provided only when testing + else: + directory, module_name = os.path.split(os.path.splitext(path)[0]) + module = _handle_import(directory, module_name) + handlers: List[Handler] = [] + self._discover_classes(module, handlers, path) + self._discover_functions(module, handlers, path, delay=False) + if not handlers: + error('{} exports no handlers'.format(path)) + return None + + self._loaded[path] = { + 'handlers': handlers, + 'module': module, + 'path': path, + } + return self._loaded[path] + def _unload(self) -> None: for path, plugin in self._loaded.items(): handlers = plugin['handlers'] diff --git a/test/test_host.py b/test/test_host.py index 860dc553..5f8c9e2b 100644 --- a/test/test_host.py +++ b/test/test_host.py @@ -1,9 +1,11 @@ # type: ignore # pylint: disable=protected-access import os +from types import SimpleNamespace from typing import Sequence from pynvim.api.nvim import Nvim +from pynvim.plugin import decorators from pynvim.plugin.host import Host, host_method_spec from pynvim.plugin.script_host import ScriptHost @@ -32,6 +34,8 @@ def test_host_import_rplugin_modules(vim: Nvim): ] h._load(plugins) assert len(h._loaded) == 2 + assert len(h._specs) == 2 + assert len(h._load_errors) == 0 # pylint: disable-next=unbalanced-tuple-unpacking simple_nvim, mymodule = list(h._loaded.values()) @@ -39,7 +43,82 @@ def test_host_import_rplugin_modules(vim: Nvim): assert mymodule['module'].__name__ == 'mymodule' -def test_host_clientinfo(vim): +# @pytest.mark.timeout(5.0) +def test_host_register_plugin_handlers(vim: Nvim): + """Test whether a Host can register plugin's RPC handlers.""" + h = Host(vim) + + @decorators.plugin + class TestPluginModule: + """A plugin for testing, having all types of the decorators.""" + def __init__(self, nvim: Nvim): + self._nvim = nvim + + @decorators.rpc_export('python_foobar', sync=True) + def foobar(self): + pass + + @decorators.command("MyCommandSync", sync=True) + def command(self): + pass + + @decorators.function("MyFunction", sync=True) + def function(self, a, b): + return a + b + + @decorators.autocmd("BufEnter", pattern="*.py", sync=True) + def buf_enter(self): + vim.command("echom 'BufEnter'") + + @decorators.rpc_export('python_foobar_async', sync=False) + def foobar_async(self): + pass + + @decorators.command("MyCommandAsync", sync=False) + def command_async(self): + pass + + @decorators.function("MyFunctionAsync", sync=False) + def function_async(self, a, b): + return a + b + + @decorators.autocmd("BufEnter", pattern="*.async", sync=False) + def buf_enter_async(self): + vim.command("echom 'BufEnter'") + + @decorators.shutdown_hook + def shutdown_hook(): + print("bye") + + @decorators.function("ModuleFunction") + def module_function(self): + pass + + dummy_module = SimpleNamespace( + TestPluginModule=TestPluginModule, + module_function=module_function, + ) + h._load_plugin("virtual://dummy_module", module=dummy_module) + assert list(h._loaded.keys()) == ["virtual://dummy_module"] + assert h._loaded['virtual://dummy_module']['module'] is dummy_module + + # _notification_handlers: async commands and functions + print(h._notification_handlers.keys()) + assert 'python_foobar_async' in h._notification_handlers + assert 'virtual://dummy_module:autocmd:BufEnter:*.async' in h._notification_handlers + assert 'virtual://dummy_module:command:MyCommandAsync' in h._notification_handlers + assert 'virtual://dummy_module:function:MyFunctionAsync' in h._notification_handlers + assert 'virtual://dummy_module:function:ModuleFunction' in h._notification_handlers + + # _request_handlers: sync commands and functions + print(h._request_handlers.keys()) + assert 'python_foobar' in h._request_handlers + assert 'virtual://dummy_module:autocmd:BufEnter:*.py' in h._request_handlers + assert 'virtual://dummy_module:command:MyCommandSync' in h._request_handlers + assert 'virtual://dummy_module:function:MyFunction' in h._request_handlers + + +def test_host_clientinfo(vim: Nvim): h = Host(vim) assert h._request_handlers.keys() == host_method_spec.keys() assert 'remote' == vim.api.get_chan_info(vim.channel_id)['client']['type']