Skip to content

Commit

Permalink
check_process_leak overhaul (#5739)
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Feb 3, 2022
1 parent 5634434 commit 834421b
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 14 deletions.
59 changes: 59 additions & 0 deletions distributed/tests/test_utils_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import os
import pathlib
import signal
import socket
import threading
from contextlib import contextmanager
Expand All @@ -13,12 +14,15 @@
import dask.config

from distributed import Client, Nanny, Scheduler, Worker, config, default_client
from distributed.compatibility import WINDOWS
from distributed.core import Server, rpc
from distributed.metrics import time
from distributed.utils import mp_context
from distributed.utils_test import (
_LockedCommPool,
_UnhashableCallable,
assert_worker_story,
check_process_leak,
cluster,
dump_cluster_state,
gen_cluster,
Expand Down Expand Up @@ -531,3 +535,58 @@ async def test_dump_cluster_unresponsive_remote_worker(c, s, a, b, tmpdir):
)

clog_fut.cancel()


def garbage_process(barrier, ignore_sigterm: bool = False, t: float = 3600) -> None:
    """Subprocess target: rendezvous on ``barrier``, then sleep for ``t`` seconds.

    Parameters
    ----------
    barrier
        Synchronisation object; its ``wait()`` is called once before sleeping.
    ignore_sigterm : bool, optional
        If True, set SIGTERM, SIGHUP and SIGINT to be ignored before waiting
        on the barrier.
    t : float, optional
        Number of seconds to sleep after the barrier has been passed.
    """
    # The signal constants are only looked up when they are actually needed.
    ignored = (signal.SIGTERM, signal.SIGHUP, signal.SIGINT) if ignore_sigterm else ()
    for sig in ignored:
        signal.signal(sig, signal.SIG_IGN)

    barrier.wait()
    sleep(t)


def test_check_process_leak():
    """check_process_leak raises AssertionError when a subprocess outlives the
    context, and the subprocess is no longer alive once the context has exited.
    """
    sync = mp_context.Barrier(parties=2)
    with pytest.raises(AssertionError), check_process_leak(
        check=True, check_timeout=0.01
    ):
        leaked = mp_context.Process(target=garbage_process, args=(sync,))
        leaked.start()
        # Make sure the child is really up and running before leaving the context
        sync.wait()

    assert not leaked.is_alive()


def test_check_process_leak_slow_cleanup():
    """A subprocess that exits on its own 0.2s after the rendezvous does not trip
    check_process_leak(check=True).
    """
    sync = mp_context.Barrier(parties=2)
    child_args = (sync, False, 0.2)
    with check_process_leak(check=True):
        short_lived = mp_context.Process(target=garbage_process, args=child_args)
        short_lived.start()
        sync.wait()

    assert not short_lived.is_alive()


@pytest.mark.parametrize(
    "ignore_sigterm",
    [
        False,
        pytest.param(True, marks=pytest.mark.skipif(WINDOWS, reason="no SIGKILL")),
    ],
)
def test_check_process_leak_pre_cleanup(ignore_sigterm):
    """A subprocess started before the context is entered is no longer alive
    inside the context.
    """
    sync = mp_context.Barrier(parties=2)
    stale = mp_context.Process(target=garbage_process, args=(sync, ignore_sigterm))
    stale.start()
    # Only proceed once the child has (optionally) installed its signal handlers
    sync.wait()

    with check_process_leak(term_timeout=0.2):
        assert not stale.is_alive()


@pytest.mark.parametrize(
    "ignore_sigterm",
    [
        False,
        pytest.param(True, marks=pytest.mark.skipif(WINDOWS, reason="no SIGKILL")),
    ],
)
def test_check_process_leak_post_cleanup(ignore_sigterm):
    """With check=False no error is raised for a subprocess started inside the
    context, and the subprocess is no longer alive after the context exits.
    """
    sync = mp_context.Barrier(parties=2)
    with check_process_leak(check=False, term_timeout=0.2):
        leaked = mp_context.Process(
            target=garbage_process, args=(sync, ignore_sigterm)
        )
        leaked.start()
        sync.wait()

    assert not leaked.is_alive()
72 changes: 58 additions & 14 deletions distributed/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import io
import logging
import logging.config
import multiprocessing
import os
import queue
import re
Expand Down Expand Up @@ -1595,24 +1596,67 @@ def check_thread_leak():
assert False, (bad_thread, call_stacks)


@contextmanager
def check_process_leak(check=True):
for proc in mp_context.active_children():
def wait_active_children(timeout: float) -> list[multiprocessing.Process]:
    """Block until mp_context.active_children() is empty or ``timeout`` expires.

    Parameters
    ----------
    timeout : float
        Maximum number of seconds to wait.

    Returns
    -------
    list
        The child processes still listed when the timeout expired; an empty
        list if none were left.
    """
    started = time()
    # The children are re-queried on every pass rather than sampled once up
    # front with `for proc in mp_context.active_children(): ...`, because new
    # children processes may be spawned before the timeout expires.
    alive = mp_context.active_children()
    while alive:
        remaining = timeout - time() + started
        if remaining <= 0:
            return alive
        # Joining a single child is enough; the list is refreshed right after.
        alive[0].join(timeout=remaining)
        alive = mp_context.active_children()
    return []


def term_or_kill_active_children(timeout: float) -> None:
"""Send SIGTERM to mp_context.active_children(), wait up to ``timeout`` seconds for
processes to die, then send SIGKILL to the survivors.

After the SIGKILL round, wait up to 30 more seconds for the children to go away.
If any are still listed, log a warning and assert that the platform is Windows.
"""
children = mp_context.active_children()
for proc in children:
proc.terminate()

# NOTE(review): this bare ``yield`` looks like a removed line of the previous
# check_process_leak body left in by the diff rendering - confirm against the
# actual file.
yield
children = wait_active_children(timeout=timeout)
for proc in children:
proc.kill()

# NOTE(review): the ``if check:`` polling loop below (200 iterations of
# sleep(0.2)) presumably also belongs to the removed implementation;
# ``check`` is not a parameter of this function - confirm.
if check:
for i in range(200):
if not set(mp_context.active_children()):
break
else:
sleep(0.2)
else:
assert not mp_context.active_children()
# Give the killed children up to 30 seconds to actually disappear.
children = wait_active_children(timeout=30)
if children: # pragma: nocover
logger.warning("Leaked unkillable children processes: %s", children)
# It should be impossible to ignore SIGKILL on Linux/MacOSX
assert WINDOWS

for proc in mp_context.active_children():
proc.terminate()

@contextmanager
def check_process_leak(
    check: bool = True, check_timeout: float = 40, term_timeout: float = 3
):
    """Context manager that cleans up subprocesses on entry and on exit.

    term_or_kill_active_children() is called before the body runs and again,
    unconditionally, when the context is left.

    Parameters
    ----------
    check : bool, optional
        If True, raise AssertionError if any subprocesses are still active
        ``check_timeout`` seconds after the body completed.
    check_timeout : float, optional
        Seconds to wait for subprocesses to terminate on their own before
        failing the check.
    term_timeout : float, optional
        Timeout, in seconds, forwarded to term_or_kill_active_children() for
        both the entry and the exit cleanup.
    """
    term_or_kill_active_children(timeout=term_timeout)
    try:
        yield
        if not check:
            # Nothing to verify; the ``finally`` clause still does the cleanup.
            return
        children = wait_active_children(timeout=check_timeout)
        assert not children, f"Test leaked subprocesses: {children}"
    finally:
        term_or_kill_active_children(timeout=term_timeout)


@contextmanager
Expand Down

0 comments on commit 834421b

Please sign in to comment.