Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve robustness of PipInstall plugin #7111

Merged
merged 7 commits into from Oct 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
96 changes: 66 additions & 30 deletions distributed/diagnostics/plugin.py
Expand Up @@ -254,10 +254,6 @@ class PipInstall(WorkerPlugin):
libraries in the worker environment or image. This is
primarily intended for experimentation and debugging.

Additional issues may arise if multiple workers share the same
file system. Each worker might try to install the packages
simultaneously.

Parameters
----------
packages : List[str]
Expand All @@ -282,34 +278,74 @@ def __init__(self, packages, pip_options=None, restart=False):
self.packages = packages
self.restart = restart
self.pip_options = pip_options or []
self.id = f"pip-install-{uuid.uuid4()}"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should use the worker.id instead. I think that there is a 1-to-1 relationship between PipInstallPlugin and Worker instances, but either way, the uniqueness property should link to the worker, shouldn't it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This ID is unique per PipInstall instance. It's mainly there to avoid conflicts between multiple PipInstall plugins all trying to install a bunch of packages on the worker.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See PipInstall._compose_{installed|restarted}_key() for the namespacing. We achieve the uniqueness per worker of PipInstall._compose_restarted_key() using worker.nanny, which should be the correct key, but we may want to add the pid to that.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

worker.nanny is not perfect but I believe good enough. I don't think we have the PID of the nanny and I don't think this is necessary. Let's not overdo it here

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TIL: We tear down plugins that are overwritten, which allows us to nicely remove their metadata. We can also register plugins under custom names, ignoring their name attribute. However, we do not tell those plugins the name they are registered by. For PipInstall plugins, plugin.name will remain pip. We also do not pass that information on setup or teardown. Otherwise, we would be able to drop the ID and just use their name.


# NOTE(review): This span is a GitHub PR diff view with leading indentation
# stripped. It appears to interleave the REMOVED (pre-PR, Lock-based) and
# ADDED (post-PR, Semaphore-based) implementations of PipInstall.setup
# without +/- markers, so it is NOT valid Python as captured — verify
# against the actual repository before reusing any of it.
async def setup(self, worker):
    # --- presumably the removed (pre-PR) implementation starts here ---
    from distributed.lock import Lock

    async with Lock(socket.gethostname()):  # don't clobber one installation
        logger.info("Pip installing the following packages: %s", self.packages)
        proc = subprocess.Popen(
            [sys.executable, "-m", "pip", "install"]
            + self.pip_options
            + self.packages,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        stdout, stderr = proc.communicate()
        returncode = proc.wait()

        if returncode:
            logger.error("Pip install failed with '%s'", stderr.decode().strip())
            return

        if self.restart and worker.nanny:
            lines = stdout.strip().split(b"\n")
            if not all(
                line.startswith(b"Requirement already satisfied") for line in lines
            ):
                worker.loop.add_callback(
                    worker.close_gracefully, restart=True
                )  # restart
    # --- presumably the added (post-PR) implementation starts here ---
    from distributed.semaphore import Semaphore

    # One lease per hostname: only one worker on a host runs pip at a time.
    async with (
        await Semaphore(max_leases=1, name=socket.gethostname(), register=True)
    ):
        if not await self._is_installed(worker):
            logger.info("Pip installing the following packages: %s", self.packages)
            await self._set_installed(worker)
            self._install()
        else:
            logger.info(
                "The following packages have already been installed: %s",
                self.packages,
            )

    # Restart at most once per nanny, tracked via scheduler metadata.
    if self.restart and worker.nanny and not await self._is_restarted(worker):
        logger.info("Restarting worker to refresh interpreter.")
        await self._set_restarted(worker)
        worker.loop.add_callback(worker.close_gracefully, restart=True)

def _install(self):
proc = subprocess.Popen(
[sys.executable, "-m", "pip", "install"] + self.pip_options + self.packages,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
_, stderr = proc.communicate()
returncode = proc.wait()
if returncode != 0:
msg = f"Pip install failed with '{stderr.decode().strip()}'"
logger.error(msg)
raise RuntimeError(msg)

async def _is_installed(self, worker):
return await worker.client.get_metadata(
self._compose_installed_key(), default=False
)

async def _set_installed(self, worker):
await worker.client.set_metadata(
self._compose_installed_key(),
True,
)

def _compose_installed_key(self):
return [
self.id,
"installed",
socket.gethostname(),
]

async def _is_restarted(self, worker):
return await worker.client.get_metadata(
self._compose_restarted_key(worker),
default=False,
)

async def _set_restarted(self, worker):
await worker.client.set_metadata(
self._compose_restarted_key(worker),
True,
)

def _compose_restarted_key(self, worker):
return [self.id, "restarted", worker.nanny]


# Adapted from https://github.com/dask/distributed/issues/3560#issuecomment-596138522
Expand Down
130 changes: 104 additions & 26 deletions distributed/tests/test_worker.py
Expand Up @@ -15,6 +15,7 @@
from concurrent.futures.process import BrokenProcessPool
from numbers import Number
from operator import add
from textwrap import dedent
from time import sleep
from unittest import mock

Expand Down Expand Up @@ -1629,51 +1630,128 @@ def bad_startup(w):
pass


# NOTE(review): This span is a GitHub PR diff view with indentation stripped.
# It interleaves the REMOVED and ADDED versions of test_pip_install without
# +/- markers (two decorators/signatures, two mock setups), so it is NOT
# valid Python as captured — consult the repository for either version.
# --- presumably the removed (pre-PR) test starts here ---
@gen_cluster(client=True)
async def test_pip_install(c, s, a, b):
    with mock.patch(
        "distributed.diagnostics.plugin.subprocess.Popen.communicate",
        return_value=(b"", b""),
    ) as p1:
# --- presumably the added (post-PR) test starts here ---
@gen_cluster(client=True, nthreads=[("", 1)])
async def test_pip_install(c, s, a):
    with captured_logger(
        "distributed.diagnostics.plugin", level=logging.INFO
    ) as logger:
        mocked = mock.Mock()
        mocked.configure_mock(
            **{"communicate.return_value": (b"", b""), "wait.return_value": 0}
        )
        with mock.patch(
            "distributed.diagnostics.plugin.subprocess.Popen", return_value=p1
        ) as p2:
            p1.communicate.return_value = b"", b""
            p1.wait.return_value = 0
            "distributed.diagnostics.plugin.subprocess.Popen", return_value=mocked
        ) as Popen:
            await c.register_worker_plugin(
                PipInstall(packages=["requests"], pip_options=["--upgrade"])
            )

            args = p2.call_args[0][0]
        args = Popen.call_args[0][0]
        assert "python" in args[0]
        assert args[1:] == ["-m", "pip", "install", "--upgrade", "requests"]
        assert Popen.call_count == 1
        logs = logger.getvalue()
        assert "Pip installing" in logs
        assert "failed" not in logs
        assert "restart" not in logs


# NOTE(review): Like test_pip_install above, this span is a diff view that
# interleaves the REMOVED and ADDED versions of test_pip_install_fails
# without +/- markers; it is NOT valid Python as captured.
# --- removed (pre-PR) decorator ---
@gen_cluster(client=True)
# --- added (post-PR) decorator and body ---
@gen_cluster(client=True, nthreads=[("", 2), ("", 2)])
async def test_pip_install_fails(c, s, a, b):
    with captured_logger(
        "distributed.diagnostics.plugin", level=logging.ERROR
    ) as logger:
        # --- removed (pre-PR) mock setup ---
        with mock.patch(
            "distributed.diagnostics.plugin.subprocess.Popen.communicate",
            return_value=(b"", b"error"),
        ) as p1:
            with mock.patch(
                "distributed.diagnostics.plugin.subprocess.Popen", return_value=p1
            ) as p2:
                p1.communicate.return_value = (
        # --- added (post-PR) mock setup ---
        mocked = mock.Mock()
        mocked.configure_mock(
            **{
                "communicate.return_value": (
                    b"",
                    b"Could not find a version that satisfies the requirement not-a-package",
                )
                p1.wait.return_value = 1
                ),
                "wait.return_value": 1,
            }
        )
        with mock.patch(
            "distributed.diagnostics.plugin.subprocess.Popen", return_value=mocked
        ) as Popen:
            with pytest.raises(RuntimeError):
                await c.register_worker_plugin(PipInstall(packages=["not-a-package"]))

        assert "not-a-package" in logger.getvalue()
        assert Popen.call_count == 1
        logs = logger.getvalue()
        assert "install failed" in logs
        assert "not-a-package" in logs


@gen_cluster(client=True, nthreads=[])
async def test_pip_install_restarts_on_nanny(c, s):
    """A successful install with restart=True makes the nanny restart its worker."""
    # Patch PipInstall._install inside the nanny's worker process via preload
    # code, so the plugin "succeeds" without actually invoking pip.
    preload = dedent(
        """\
        from unittest import mock

        mock.patch(
            "distributed.diagnostics.plugin.PipInstall._install", return_value=None
        ).start()
        """
    )
    async with Nanny(s.address, preload=preload):
        (addr,) = s.workers
        await c.register_worker_plugin(
            PipInstall(packages=["requests"], pip_options=["--upgrade"], restart=True)
        )

        # Wait until the worker is restarted: exactly one worker registered
        # again, under a new address.
        while len(s.workers) != 1 or set(s.workers) == {addr}:
            await asyncio.sleep(0.01)
@gen_cluster(client=True, nthreads=[])
async def test_pip_install_failing_does_not_restart_on_nanny(c, s):
    """A failing install must surface its error and leave the worker running."""
    # Make PipInstall._install blow up inside the nanny's worker process.
    failing_install = dedent(
        """\
        from unittest import mock

        mock.patch(
            "distributed.diagnostics.plugin.PipInstall._install", side_effect=RuntimeError
        ).start()
        """
    )
    async with Nanny(s.address, preload=failing_install) as nanny:
        (original_addr,) = s.workers
        plugin = PipInstall(
            packages=["requests"], pip_options=["--upgrade"], restart=True
        )
        with pytest.raises(RuntimeError):
            await c.register_worker_plugin(plugin)
        # The nanny must keep running the very same worker — no restart.
        assert nanny.status is Status.running
        assert set(s.workers) == {original_addr}


@gen_cluster(client=True, nthreads=[("", 1), ("", 1)])
async def test_pip_install_multiple_workers(c, s, a, b):
    """With two workers on one host, pip runs once; the second worker reuses it."""
    with captured_logger(
        "distributed.diagnostics.plugin", level=logging.INFO
    ) as log_buffer:
        fake_proc = mock.Mock()
        fake_proc.configure_mock(
            **{"communicate.return_value": (b"", b""), "wait.return_value": 0}
        )
        with mock.patch(
            "distributed.diagnostics.plugin.subprocess.Popen", return_value=fake_proc
        ) as popen_mock:
            await c.register_worker_plugin(
                PipInstall(packages=["requests"], pip_options=["--upgrade"])
            )

        cmd = popen_mock.call_args[0][0]
        assert "python" in cmd[0]
        assert cmd[1:] == ["-m", "pip", "install", "--upgrade", "requests"]
        # pip was invoked exactly once despite two workers on the host.
        assert popen_mock.call_count == 1
        logs = log_buffer.getvalue()
        assert "Pip installing" in logs
        assert "already been installed" in logs


@gen_cluster(nthreads=[])
Expand Down