Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

drop pkg_resources in favour of importlib.metadata #5923

Merged
merged 3 commits into from Mar 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Expand Up @@ -39,7 +39,6 @@ repos:
- types-docutils
- types-requests
- types-paramiko
- types-pkg_resources
- types-PyYAML
- types-setuptools
- types-psutil
Expand Down
67 changes: 37 additions & 30 deletions distributed/comm/registry.py
@@ -1,6 +1,29 @@
from __future__ import annotations

import importlib.metadata
import sys
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Protocol


class _EntryPoints(Protocol):
def __call__(self, **kwargs: str) -> Iterable[importlib.metadata.EntryPoint]:
...


if sys.version_info >= (3, 10):
# py3.10 importlib.metadata type annotations are not in mypy yet
# https://github.com/python/typeshed/pull/7331
_entry_points: _EntryPoints = importlib.metadata.entry_points # type: ignore[assignment]
else:

def _entry_points(
*, group: str, name: str
) -> Iterable[importlib.metadata.EntryPoint]:
graingert marked this conversation as resolved.
Show resolved Hide resolved
for ep in importlib.metadata.entry_points().get(group, []):
if ep.name == name:
yield ep


class Backend(ABC):
Expand Down Expand Up @@ -59,40 +82,24 @@ def get_local_address_for(self, loc):
backends: dict[str, Backend] = {}


def get_backend(scheme: str, require: bool = True) -> Backend:
def get_backend(scheme: str) -> Backend:
"""
Get the Backend instance for the given *scheme*.
It looks for matching scheme in dask's internal cache, and falls-back to
package metadata for the group name ``distributed.comm.backends``

Parameters
----------

require : bool
Verify that the backends requirements are properly installed. See
https://setuptools.readthedocs.io/en/latest/pkg_resources.html for more
information.
"""

backend = backends.get(scheme)
if backend is None:
import pkg_resources

backend = None
for backend_class_ep in pkg_resources.iter_entry_points(
"distributed.comm.backends", scheme
):
# resolve and require are equivalent to load
backend_factory = backend_class_ep.resolve()
if require:
backend_class_ep.require()
backend = backend_factory()

if backend is None:
raise ValueError(
"unknown address scheme %r (known schemes: %s)"
% (scheme, sorted(backends))
)
else:
backends[scheme] = backend
return backend
if backend is not None:
return backend

for backend_class_ep in _entry_points(
name=scheme, group="distributed.comm.backends"
):
backend = backend_class_ep.load()()
backends[scheme] = backend
return backend

raise ValueError(
f"unknown address scheme {scheme!r} (known schemes: {sorted(backends)})"
)
41 changes: 13 additions & 28 deletions distributed/comm/tests/test_comms.py
Expand Up @@ -2,18 +2,15 @@
import os
import sys
import threading
import types
import warnings
from functools import partial

import pkg_resources
import pytest
from tornado import ioloop
from tornado.concurrent import Future

import dask

import distributed
from distributed.comm import (
CommClosedError,
asyncio_tcp,
Expand All @@ -30,7 +27,7 @@
from distributed.comm.registry import backends, get_backend
from distributed.metrics import time
from distributed.protocol import Serialized, deserialize, serialize, to_serialize
from distributed.utils import get_ip, get_ipv6
from distributed.utils import get_ip, get_ipv6, mp_context
from distributed.utils_test import (
get_cert,
get_client_ssl_context,
Expand Down Expand Up @@ -1313,30 +1310,18 @@ async def test_inproc_adresses():
await check_addresses(a, b)


def test_register_backend_entrypoint():
# Code adapted from pandas backend entry point testing
# https://github.com/pandas-dev/pandas/blob/2470690b9f0826a8feb426927694fa3500c3e8d2/pandas/tests/plotting/test_backend.py#L50-L76
Comment on lines -1317 to -1318
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pandas test_register_entrypoint now looks like this: https://github.com/pandas-dev/pandas/blob/48d515958d5805f0e62e34b7424097e5575089a8/pandas/tests/plotting/test_backend.py#L47-L62

I used:

    monkeypatch.syspath_prepend(tmp_path)
    monkeypatch.setitem(sys.modules, "pandas_dummy_backend", dummy_backend)

in the pandas version of this PR: their tests already mutate sys.modules a lot, so I figured one more sys.path mutation wasn't so bad. Here I use a multiprocessing.Pool(1) to isolate the sys.modules and sys.path changes instead.

def _get_backend_on_path(path):
    """Append *path* to ``sys.path`` and return the backend for ``"udp"``.

    Module-level helper for the entry-point registration test, which runs it
    through a single-worker pool (``pool.apply``) so that the ``sys.path``
    change is isolated from the test process.

    Parameters
    ----------
    path : path-like
        Directory to add to ``sys.path``; it is passed through
        ``os.fsdecode`` first.
    """
    search_dir = os.fsdecode(path)
    sys.path.append(search_dir)
    return get_backend("udp")

dist = pkg_resources.get_distribution("distributed")
if dist.module_path not in distributed.__file__:
# We are running from a non-installed distributed, and this test is invalid
pytest.skip("Testing a non-installed distributed")

mod = types.ModuleType("dask_udp")
mod.UDPBackend = lambda: 1
sys.modules[mod.__name__] = mod

entry_point_name = "distributed.comm.backends"
backends_entry_map = pkg_resources.get_entry_map("distributed")
if entry_point_name not in backends_entry_map:
backends_entry_map[entry_point_name] = dict()
backends_entry_map[entry_point_name]["udp"] = pkg_resources.EntryPoint(
"udp", mod.__name__, attrs=["UDPBackend"], dist=dist
def test_register_backend_entrypoint(tmp_path):
(tmp_path / "dask_udp.py").write_bytes(b"def udp_backend():\n return 1\n")
dist_info = tmp_path / "dask_udp-0.0.0.dist-info"
dist_info.mkdir()
(dist_info / "entry_points.txt").write_bytes(
b"[distributed.comm.backends]\nudp = dask_udp:udp_backend\n"
)

# The require is disabled here since particularly unit tests may install
# dirty or dev versions which are conflicting with backend entrypoints if
# they are demanding for exact, stable versions. This should not fail the
# test
result = get_backend("udp", require=False)
assert result == 1
with mp_context.Pool(1) as pool:
assert pool.apply(_get_backend_on_path, args=(tmp_path,)) == 1
pool.join()
2 changes: 0 additions & 2 deletions distributed/utils.py
Expand Up @@ -74,8 +74,6 @@ def _initialize_mp_context():
if method == "forkserver":
# Makes the test suite much faster
preload = ["distributed"]
if "pkg_resources" in sys.modules:
preload.append("pkg_resources")

from distributed.versions import optional_packages, required_packages

Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Expand Up @@ -11,4 +11,3 @@ toolz >= 0.8.2
tornado >= 6.0.3
zict >= 0.1.3
pyyaml
setuptools