diff --git a/distributed/comm/registry.py b/distributed/comm/registry.py index 15d9af51781..47ba730a7d9 100644 --- a/distributed/comm/registry.py +++ b/distributed/comm/registry.py @@ -4,9 +4,18 @@ import sys from abc import ABC, abstractmethod from collections.abc import Iterable +from typing import Protocol + + +class _EntryPoints(Protocol): + def __call__(self, **kwargs: str) -> Iterable[importlib.metadata.EntryPoint]: + ... + if sys.version_info >= (3, 10): - _entry_points = importlib.metadata.entry_points + # py3.10 importlib.metadata type annotations are not in mypy yet + # https://github.com/python/typeshed/pull/7331 + _entry_points: _EntryPoints = importlib.metadata.entry_points # type: ignore[assignment] else: def _entry_points(