diff --git a/distributed/cli/dask_scheduler.py b/distributed/cli/dask_scheduler.py
index 3b0aa5b4c70..57a7168a3a2 100755
--- a/distributed/cli/dask_scheduler.py
+++ b/distributed/cli/dask_scheduler.py
@@ -5,6 +5,7 @@
 import logging
 import gc
 import os
+import re
 import shutil
 import sys
 import tempfile
@@ -16,12 +17,7 @@
 from distributed import Scheduler
 from distributed.security import Security
-from distributed.utils import get_ip_interface
-from distributed.cli.utils import (
-    check_python_3,
-    install_signal_handlers,
-    uri_from_host_port,
-)
+from distributed.cli.utils import check_python_3, install_signal_handlers
 from distributed.preloading import preload_modules, validate_preload_argv
 from distributed.proctitle import (
     enable_proctitle_on_children,
@@ -151,6 +147,9 @@ def main(
         )
         dashboard_address = bokeh_port

+    if port is None and (not host or not re.search(r":\d", host)):
+        port = 8786
+
     sec = Security(
         tls_ca_file=tls_ca_file, tls_scheduler_cert=tls_cert, tls_scheduler_key=tls_key
     )
@@ -186,14 +185,6 @@ def del_pid_file():
         limit = max(soft, hard // 2)
         resource.setrlimit(resource.RLIMIT_NOFILE, (limit, hard))

-    if interface:
-        if host:
-            raise ValueError("Can not specify both interface and host")
-        else:
-            host = get_ip_interface(interface)
-
-    addr = uri_from_host_port(host, port, 8786)
-
     loop = IOLoop.current()
     logger.info("-" * 47)
@@ -213,9 +204,15 @@ def del_pid_file():
            logger.info("Unable to import bokeh: %s" % str(error))

     scheduler = Scheduler(
-        loop=loop, services=services, scheduler_file=scheduler_file, security=sec
+        loop=loop,
+        services=services,
+        scheduler_file=scheduler_file,
+        security=sec,
+        host=host,
+        port=port,
+        interface=interface,
     )
-    scheduler.start(addr)
+    scheduler.start()
     if not preload:
         preload = dask.config.get("distributed.scheduler.preload")
     if not preload_argv:
@@ -237,7 +234,7 @@ def del_pid_file():
         if local_directory_created:
             shutil.rmtree(local_directory)

-    logger.info("End scheduler at %r", addr)
+    logger.info("End scheduler at %r", scheduler.address)


 def go():
diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py
index a0bc801a960..6315939005d 100755
--- a/distributed/cli/dask_worker.py
+++ b/distributed/cli/dask_worker.py
@@ -10,14 +10,10 @@
 import click

 from distributed import Nanny, Worker
 from distributed.config import config
-from distributed.utils import get_ip_interface, parse_timedelta
+from distributed.utils import parse_timedelta
 from distributed.worker import _ncores
 from distributed.security import Security
-from distributed.cli.utils import (
-    check_python_3,
-    uri_from_host_port,
-    install_signal_handlers,
-)
+from distributed.cli.utils import check_python_3, install_signal_handlers
 from distributed.comm import get_address_host_port
 from distributed.preloading import validate_preload_argv
 from distributed.proctitle import (
     enable_proctitle_on_children,
@@ -328,18 +324,6 @@ def del_pid_file():
             "dask-worker SCHEDULER_ADDRESS:8786"
         )

-    if interface:
-        if host:
-            raise ValueError("Can not specify both interface and host")
-        else:
-            host = get_ip_interface(interface)
-
-    if host or port:
-        addr = uri_from_host_port(host, port, 0)
-    else:
-        # Choose appropriate address for scheduler
-        addr = None
-
     if death_timeout is not None:
         death_timeout = parse_timedelta(death_timeout, "s")
@@ -359,6 +343,9 @@ def del_pid_file():
             preload_argv=preload_argv,
             security=sec,
             contact_address=contact_address,
+            interface=interface,
+            host=host,
+            port=port,
             name=name if nprocs == 1 or not name else name + "-" + str(i),
             **kwargs
         )
@@ -377,7 +364,7 @@ def on_signal(signum):

     @gen.coroutine
     def run():
-        yield [n._start(addr) for n in nannies]
+        yield nannies
         while all(n.status != "closed" for n in nannies):
             yield gen.sleep(0.2)
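The `re.search(r":\d", host)` guard above decides whether `dask-scheduler` should fall back to the default port: only when the user supplied neither an explicit `--port` nor a `:<port>` suffix inside `--host`. A minimal sketch of that rule in isolation (the helper name here is illustrative, not part of the patch):

```python
import re

def resolved_scheduler_port(host, port, default=8786):
    # Fall back to the default only when no port was given explicitly
    # and none is embedded in the host string (e.g. "tcp://10.0.0.1:9000")
    if port is None and (not host or not re.search(r":\d", host)):
        return default
    return port

assert resolved_scheduler_port(None, None) == 8786
assert resolved_scheduler_port("tls://10.0.0.1", 9000) == 9000
assert resolved_scheduler_port("tcp://10.0.0.1:9000", None) is None  # host already carries the port
```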
diff --git a/distributed/cli/tests/test_cli_utils.py b/distributed/cli/tests/test_cli_utils.py
deleted file mode 100644
index 4f07f699de5..00000000000
--- a/distributed/cli/tests/test_cli_utils.py
+++ /dev/null
@@ -1,54 +0,0 @@
-from __future__ import print_function, division, absolute_import
-
-import pytest
-
-pytest.importorskip("requests")
-
-from distributed.cli.utils import uri_from_host_port
-from distributed.utils import get_ip
-
-
-external_ip = get_ip()
-
-
-def test_uri_from_host_port():
-    f = uri_from_host_port
-
-    assert f("", 456, None) == "tcp://:456"
-    assert f("", 456, 123) == "tcp://:456"
-    assert f("", None, 123) == "tcp://:123"
-    assert f("", None, 0) == "tcp://"
-    assert f("", 0, 123) == "tcp://"
-
-    assert f("localhost", 456, None) == "tcp://localhost:456"
-    assert f("localhost", 456, 123) == "tcp://localhost:456"
-    assert f("localhost", None, 123) == "tcp://localhost:123"
-    assert f("localhost", None, 0) == "tcp://localhost"
-
-    assert f("192.168.1.2", 456, None) == "tcp://192.168.1.2:456"
-    assert f("192.168.1.2", 456, 123) == "tcp://192.168.1.2:456"
-    assert f("192.168.1.2", None, 123) == "tcp://192.168.1.2:123"
-    assert f("192.168.1.2", None, 0) == "tcp://192.168.1.2"
-
-    assert f("tcp://192.168.1.2", 456, None) == "tcp://192.168.1.2:456"
-    assert f("tcp://192.168.1.2", 456, 123) == "tcp://192.168.1.2:456"
-    assert f("tcp://192.168.1.2", None, 123) == "tcp://192.168.1.2:123"
-    assert f("tcp://192.168.1.2", None, 0) == "tcp://192.168.1.2"
-
-    assert f("tcp://192.168.1.2:456", None, None) == "tcp://192.168.1.2:456"
-    assert f("tcp://192.168.1.2:456", 0, 0) == "tcp://192.168.1.2:456"
-    assert f("tcp://192.168.1.2:456", 0, 123) == "tcp://192.168.1.2:456"
-    assert f("tcp://192.168.1.2:456", 456, 123) == "tcp://192.168.1.2:456"
-
-    with pytest.raises(ValueError):
-        # Two incompatible port values
-        f("tcp://192.168.1.2:456", 123, None)
-
-    assert f("tls://192.168.1.2:456", None, None) == "tls://192.168.1.2:456"
-    assert f("tls://192.168.1.2:456", 0, 0) == "tls://192.168.1.2:456"
-    assert f("tls://192.168.1.2:456", 0, 123) == "tls://192.168.1.2:456"
-    assert f("tls://192.168.1.2:456", 456, 123) == "tls://192.168.1.2:456"
-
-    assert f("tcp://[::1]:456", None, None) == "tcp://[::1]:456"
-
-    assert f("tls://[::1]:456", None, None) == "tls://[::1]:456"
diff --git a/distributed/cli/tests/test_dask_worker.py b/distributed/cli/tests/test_dask_worker.py
index 72084e53141..2fa3779d9b4 100644
--- a/distributed/cli/tests/test_dask_worker.py
+++ b/distributed/cli/tests/test_dask_worker.py
@@ -7,7 +7,6 @@
 import requests
 import sys
 from time import sleep
-from toolz import first

 from distributed import Client
 from distributed.metrics import time
@@ -52,7 +51,7 @@ def test_memory_limit(loop):
                 while not c.ncores():
                     sleep(0.1)
                 info = c.scheduler_info()
-                d = first(info["workers"].values())
+                [d] = info["workers"].values()
                 assert isinstance(d["memory_limit"], int)
                 assert d["memory_limit"] == 2e9
diff --git a/distributed/cli/utils.py b/distributed/cli/utils.py
index 4ce1d845821..2c2088a7556 100644
--- a/distributed/cli/utils.py
+++ b/distributed/cli/utils.py
@@ -3,13 +3,6 @@
 from tornado import gen
 from tornado.ioloop import IOLoop

-from distributed.comm import (
-    parse_address,
-    unparse_address,
-    parse_host_port,
-    unparse_host_port,
-)
-
 py3_err_msg = """
 Warning: Your terminal does not set locales.
@@ -75,36 +68,3 @@ def cleanup_and_stop():

     for sig in [signal.SIGINT, signal.SIGTERM]:
         old_handlers[sig] = signal.signal(sig, handle_signal)
-
-
-def uri_from_host_port(host_arg, port_arg, default_port):
-    """
-    Process the *host* and *port* CLI options.
-    Return a URI.
-    """
-    # Much of distributed depends on a well-known IP being assigned to
-    # each entity (Worker, Scheduler, etc.), so avoid "universal" addresses
-    # like '' which would listen on all registered IPs and interfaces.
-    scheme, loc = parse_address(host_arg or "")
-
-    host, port = parse_host_port(
-        loc, port_arg if port_arg is not None else default_port
-    )
-
-    if port is None and port_arg is None:
-        port_arg = default_port
-
-    if port and port_arg and port != port_arg:
-        raise ValueError(
-            "port number given twice in options: "
-            "host %r and port %r" % (host_arg, port_arg)
-        )
-    if port is None and port_arg is not None:
-        port = port_arg
-    # Note `port = 0` means "choose a random port"
-    if port is None:
-        port = default_port
-    loc = unparse_host_port(host, port)
-    addr = unparse_address(scheme, loc)
-
-    return addr
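The `[d] = info["workers"].values()` change in `test_memory_limit` is more than a style tweak: unpacking asserts that exactly one worker is registered, where `toolz.first` would silently take the first of many. A small illustration with made-up sample data:

```python
info = {"workers": {"tcp://127.0.0.1:45678": {"memory_limit": 2000000000}}}

[d] = info["workers"].values()  # raises ValueError unless exactly one worker
assert d["memory_limit"] == 2e9

try:
    [d] = {}.values()
except ValueError:
    pass  # no workers fails loudly here rather than later
```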
diff --git a/distributed/comm/addressing.py b/distributed/comm/addressing.py
index 20ddb2c863f..3d79befe0f1 100644
--- a/distributed/comm/addressing.py
+++ b/distributed/comm/addressing.py
@@ -5,6 +5,7 @@
 import dask

 from . import registry
+from ..utils import get_ip_interface

 DEFAULT_SCHEME = dask.config.get("distributed.comm.default-scheme")

@@ -172,3 +173,71 @@ def resolve_address(addr):
     scheme, loc = parse_address(addr)
     backend = registry.get_backend(scheme)
     return unparse_address(scheme, backend.resolve_address(loc))
+
+
+def uri_from_host_port(host_arg, port_arg, default_port):
+    """
+    Process the *host* and *port* CLI options.
+    Return a URI.
+    """
+    # Much of distributed depends on a well-known IP being assigned to
+    # each entity (Worker, Scheduler, etc.), so avoid "universal" addresses
+    # like '' which would listen on all registered IPs and interfaces.
+    scheme, loc = parse_address(host_arg or "")
+
+    host, port = parse_host_port(
+        loc, port_arg if port_arg is not None else default_port
+    )
+
+    if port is None and port_arg is None:
+        port_arg = default_port
+
+    if port and port_arg and port != port_arg:
+        raise ValueError(
+            "port number given twice in options: "
+            "host %r and port %r" % (host_arg, port_arg)
+        )
+    if port is None and port_arg is not None:
+        port = port_arg
+    # Note `port = 0` means "choose a random port"
+    if port is None:
+        port = default_port
+    loc = unparse_host_port(host, port)
+    addr = unparse_address(scheme, loc)
+
+    return addr
+
+
+def address_from_user_args(
+    host=None, port=None, interface=None, protocol=None, peer=None, security=None
+):
+    """ Get an address to listen on from common user provided arguments """
+    if security and security.require_encryption and not protocol:
+        protocol = "tls"
+
+    if protocol and protocol.rstrip("://") == "inproc":
+        if host or port or interface:
+            raise ValueError(
+                "Can not specify inproc protocol and host or port or interface"
+            )
+        else:
+            return "inproc://"
+
+    if interface:
+        if host:
+            raise ValueError("Can not specify both interface and host", interface, host)
+        else:
+            host = get_ip_interface(interface)
+
+    if protocol and host and "://" not in host:
+        host = protocol.rstrip("://") + "://" + host
+
+    if host or port:
+        addr = uri_from_host_port(host, port, 0)
+    else:
+        addr = ""
+
+    if protocol and "://" not in addr:
+        addr = protocol.rstrip("://") + "://" + addr
+
+    return addr
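A few hedged examples of the behaviour `address_from_user_args` is meant to provide, assuming the default scheme is `tcp` and following the code above (note the comparison string was corrected to `"inproc"` to match the error message and return value):

```python
from distributed.comm.addressing import address_from_user_args

# host/port are expanded into a fully qualified URI
assert address_from_user_args(host="127.0.0.2", port=8786) == "tcp://127.0.0.2:8786"

# an explicit protocol is prefixed onto a bare host
assert address_from_user_args(host="10.0.0.1", port=9000, protocol="tls") == "tls://10.0.0.1:9000"

# with no arguments, the caller is left to choose its own address
assert address_from_user_args() == ""

# host and interface remain mutually exclusive, as in the old CLI check
try:
    address_from_user_args(host="10.0.0.1", interface="eth0")
except ValueError:
    pass
```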
diff --git a/distributed/nanny.py b/distributed/nanny.py
index 60e83e86da7..ef0e0a38f0e 100644
--- a/distributed/nanny.py
+++ b/distributed/nanny.py
@@ -16,6 +16,7 @@
 from tornado.locks import Event

 from .comm import get_address_host, get_local_address_for, unparse_host_port
+from .comm.addressing import address_from_user_args
 from .core import rpc, RPCClosed, CommClosedError, coerce_to_address
 from .metrics import time
 from .node import ServerNode
@@ -69,6 +70,10 @@ def __init__(
         listen_address=None,
         worker_class=None,
         env=None,
+        interface=None,
+        host=None,
+        port=None,
+        protocol=None,
         **worker_kwargs
     ):
@@ -135,6 +140,14 @@ def __init__(
         pc = PeriodicCallback(self.memory_monitor, 100, io_loop=self.loop)
         self.periodic_callbacks["memory"] = pc

+        self._start_address = address_from_user_args(
+            host=host,
+            port=port,
+            interface=interface,
+            protocol=protocol,
+            security=security,
+        )
+
         self._listen_address = listen_address
         self.status = "init"
@@ -175,6 +188,7 @@ def worker_dir(self):
     @gen.coroutine
     def _start(self, addr_or_port=0):
         """ Start nanny, start local process, start watching """
+        addr_or_port = addr_or_port or self._start_address

         # XXX Factor this out
         if not addr_or_port:
@@ -419,6 +433,7 @@ def start(self):
         self.process = AsyncProcess(
             target=self._run,
+            name="Dask Worker process (from Nanny)",
             kwargs=dict(
                 worker_args=self.worker_args,
                 worker_kwargs=self.worker_kwargs,
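The one-line fallback `addr_or_port = addr_or_port or self._start_address` works because `_start`'s legacy default is `0`, which is falsy, so both `0` and `None` defer to the address computed from `host`/`port`/`interface` in `__init__`, while any explicit argument still wins. A sketch of that precedence:

```python
def effective_start_address(addr_or_port, start_address):
    # 0 (the legacy "pick one for me" default) and None both fall through
    return addr_or_port or start_address

assert effective_start_address(0, "tcp://127.0.0.1:0") == "tcp://127.0.0.1:0"
assert effective_start_address(None, "tcp://127.0.0.1:0") == "tcp://127.0.0.1:0"
assert effective_start_address("tcp://10.0.0.1:9000", "tcp://127.0.0.1:0") == "tcp://10.0.0.1:9000"
```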
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 8442b6ddcea..8ba4cedf468 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -35,6 +35,7 @@
     get_address_host,
     unparse_host_port,
 )
+from .comm.addressing import address_from_user_args
 from .compatibility import finalize, unicode, Mapping, Set
 from .core import rpc, connect, send_recv, clean_exception, CommClosedError
 from . import profile
@@ -824,9 +825,12 @@ def __init__(
         security=None,
         worker_ttl=None,
         idle_timeout=None,
+        interface=None,
+        host=None,
+        port=8786,
+        protocol=None,
         **kwargs
     ):
-
         self._setup_logging()

         # Attributes
@@ -1056,6 +1060,14 @@ def __init__(

         connection_limit = get_fileno_limit() / 2

+        self._start_address = address_from_user_args(
+            host=host,
+            port=port,
+            interface=interface,
+            protocol=protocol,
+            security=security,
+        )
+
         super(Scheduler, self).__init__(
             handlers=self.handlers,
             stream_handlers=merge(worker_handlers, client_handlers),
@@ -1172,10 +1184,12 @@ def stop_services(self):
         for service in self.services.values():
             service.stop()

-    def start(self, addr_or_port=8786, start_queues=True):
+    def start(self, addr_or_port=None, start_queues=True):
         """ Clear out old state and restart all running coroutines """
         enable_gc_diagnosis()

+        addr_or_port = addr_or_port or self._start_address
+
         self.clear_task_state()

         with ignoring(AttributeError):
@@ -1234,6 +1248,15 @@ def del_scheduler_file():

         return self.finished()

+    def __await__(self):
+        self.start()
+
+        @gen.coroutine
+        def _():
+            return self
+
+        return _().__await__()
+
     @gen.coroutine
     def finished(self):
         """ Wait until all coroutines have ceased """
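With `__await__` defined, a `Scheduler` can be yielded directly inside a Tornado coroutine, which is the idiom the test changes below rely on. A minimal sketch of the new startup pattern:

```python
from tornado import gen
from tornado.ioloop import IOLoop
from distributed import Scheduler

@gen.coroutine
def main():
    # construction and start collapse into a single yield; the scheduler
    # listens on the address derived from host/port/interface
    s = yield Scheduler(port=0)
    try:
        assert s.address.startswith("tcp://")
    finally:
        yield s.close()

IOLoop.current().run_sync(main)
```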
diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py
index 8582c2abc83..3515be9ebcb 100644
--- a/distributed/tests/test_client.py
+++ b/distributed/tests/test_client.py
@@ -2824,8 +2824,7 @@ def test_diagnostic_nbytes(c, s, a, b):

 @gen_test()
 def test_worker_aliases():
-    s = Scheduler(validate=True)
-    s.start(0)
+    s = yield Scheduler(validate=True, port=0)
     a = Worker(s.ip, s.port, name="alice")
     b = Worker(s.ip, s.port, name="bob")
     w = Worker(s.ip, s.port, name=3)
@@ -3062,8 +3061,7 @@ def test_unrunnable_task_runs(c, s, a, b):
 def test_add_worker_after_tasks(c, s):
     futures = c.map(inc, range(10))

-    n = Nanny(s.ip, s.port, ncores=2, loop=s.loop)
-    n.start(0)
+    n = yield Nanny(s.ip, s.port, ncores=2, loop=s.loop, port=0)

     result = yield c.gather(futures)

@@ -3603,8 +3601,7 @@ def test_as_completed_next_batch(c):

 @gen_test()
 def test_status():
-    s = Scheduler()
-    s.start(0)
+    s = yield Scheduler(port=0)

     c = yield Client((s.ip, s.port), asynchronous=True)
     assert c.status == "running"
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index 9224ed69030..805b0e06ed0 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -533,8 +533,7 @@ def test_broadcast_nanny(s, a, b):

 @gen_test()
 def test_worker_name():
-    s = Scheduler(validate=True)
-    s.start(0)
+    s = yield Scheduler(validate=True, port=0)
     w = yield Worker(s.ip, s.port, name="alice")
     assert s.workers[w.address].name == "alice"
     assert s.aliases["alice"] == w.address
@@ -550,8 +549,7 @@ def test_worker_name():
 @gen_test()
 def test_coerce_address():
     with dask.config.set({"distributed.comm.timeouts.connect": "100ms"}):
-        s = Scheduler(validate=True)
-        s.start(0)
+        s = yield Scheduler(validate=True, port=0)
         print("scheduler:", s.address, s.listen_address)
         a = Worker(s.ip, s.port, name="alice")
         b = Worker(s.ip, s.port, name=123)
@@ -824,7 +822,7 @@ def test_file_descriptors(c, s):
     yield [n.close() for n in nannies]

     assert not s.rpc.open
-    assert not c.rpc.active
+    assert not c.rpc.active, list(c.rpc._created)
     assert not s.stream_comms

     start = time()
@@ -1133,8 +1131,7 @@ def test_fifo_submission(c, s, w):
 @gen_test()
 def test_scheduler_file():
     with tmpfile() as fn:
-        s = Scheduler(scheduler_file=fn)
-        s.start(0)
+        s = yield Scheduler(scheduler_file=fn, port=0)
         with open(fn) as f:
             data = json.load(f)
         assert data["address"] == s.address
@@ -1536,3 +1533,13 @@ def test_close_workers(s, a, b):
     yield s.close(close_workers=True)
     assert a.status == "closed"
     assert b.status == "closed"
+
+
+@pytest.mark.skipif(
+    not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost"
+)
+@gen_test()
+def test_host_address():
+    s = yield Scheduler(host="127.0.0.2")
+    assert "127.0.0.2" in s.address
+    yield s.close()
diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index 07864ab4b64..8ca3c5d9682 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -304,8 +304,7 @@ def test_broadcast(s, a, b):

 @gen_test()
 def test_worker_with_port_zero():
-    s = Scheduler()
-    s.start(8007)
+    s = yield Scheduler(port=8007)
     w = yield Worker(s.address)
     assert isinstance(w.port, int)
     assert w.port > 1024
@@ -1007,8 +1006,7 @@ def test_start_services(s):
 @gen_test()
 def test_scheduler_file():
     with tmpfile() as fn:
-        s = Scheduler(scheduler_file=fn)
-        s.start(8009)
+        s = yield Scheduler(scheduler_file=fn, port=8009)
         w = yield Worker(scheduler_file=fn)
         assert set(s.workers) == {w.address}
         yield w.close()
@@ -1384,3 +1382,18 @@ def __init__(self, x, y):
     assert w.data.x == 123
     assert w.data.y == 456
     yield w.close()
+
+
+@pytest.mark.skipif(
+    not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost"
+)
+@gen_cluster(ncores=[], client=True)
+def test_host_address(c, s):
+    w = yield Worker(s.address, host="127.0.0.2")
+    assert "127.0.0.2" in w.address
+    yield w.close()
+
+    n = yield Nanny(s.address, host="127.0.0.3")
+    assert "127.0.0.3" in n.address
+    assert "127.0.0.3" in n.worker_address
+    yield n.close()
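The new `test_host_address` tests are marked Linux-only because Linux routes the entire `127.0.0.0/8` block to loopback, so `127.0.0.2` and `127.0.0.3` behave as distinct bindable hosts with no setup. A quick way to check whether the current machine supports this:

```python
import socket

s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
    s.bind(("127.0.0.2", 0))  # works out of the box on Linux only
    print("extra loopback addresses available:", s.getsockname())
except OSError:
    print("127.0.0.2 needs an explicit interface alias on this platform")
finally:
    s.close()
```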
diff --git a/distributed/utils_test.py b/distributed/utils_test.py
index b0c0d2d48cc..7aaa5b1ed0d 100644
--- a/distributed/utils_test.py
+++ b/distributed/utils_test.py
@@ -11,7 +11,6 @@
 import logging
 import logging.config
 import os
-import psutil
 import re
 import shutil
 import signal
@@ -40,7 +39,7 @@
 from tornado.ioloop import IOLoop

 from .client import default_client, _global_clients, Client
-from .compatibility import PY3, Empty, WINDOWS, PY2
+from .compatibility import PY3, Empty, WINDOWS
 from .comm import Comm
 from .comm.utils import offload
 from .config import initialize_logging
@@ -156,10 +155,7 @@ def start():

         _cleanup_dangling()

-        if PY2:  # no forkserver, so no extra procs
-            for child in psutil.Process().children(recursive=True):
-                with ignoring(psutil.NoSuchProcess):
-                    child.terminate()
+        assert_no_leaked_processes()

         _global_clients.clear()
@@ -482,8 +478,8 @@ def run_scheduler(q, nputs, **kwargs):
     # On Python 2.7 and Unix, fork() is used to spawn child processes,
     # so avoid inheriting the parent's IO loop.
     with pristine_loop() as loop:
-        scheduler = Scheduler(validate=True, **kwargs)
-        done = scheduler.start("127.0.0.1")
+        scheduler = Scheduler(validate=True, host="127.0.0.1", **kwargs)
+        done = scheduler.start()

         for i in range(nputs):
             q.put(scheduler.address)
@@ -501,7 +497,7 @@ def run_worker(q, scheduler_q, **kwargs):
     with pristine_loop() as loop:
         scheduler_addr = scheduler_q.get()
         worker = Worker(scheduler_addr, validate=True, **kwargs)
-        loop.run_sync(lambda: worker._start(0))
+        loop.run_sync(lambda: worker._start())
         q.put(worker.address)
         try:
@@ -521,7 +517,7 @@ def run_nanny(q, scheduler_q, **kwargs):
     with pristine_loop() as loop:
         scheduler_addr = scheduler_q.get()
         worker = Nanny(scheduler_addr, validate=True, **kwargs)
-        loop.run_sync(lambda: worker._start(0))
+        loop.run_sync(lambda: worker._start())
         q.put(worker.address)
         try:
             loop.start()
@@ -657,6 +653,7 @@ def cluster(
             # Launch scheduler
             scheduler = mp_context.Process(
+                name="Dask cluster test: Scheduler",
                 target=run_scheduler,
                 args=(scheduler_q, nworkers + 1),
                 kwargs=scheduler_kwargs,
@@ -675,7 +672,10 @@ def cluster(
                     worker_kwargs,
                 )
                 proc = mp_context.Process(
-                    target=_run_worker, args=(q, scheduler_q), kwargs=kwargs
+                    name="Dask cluster test: Worker",
+                    target=_run_worker,
+                    args=(q, scheduler_q),
+                    kwargs=kwargs,
                 )
                 ws.add(proc)
                 workers.append({"proc": proc, "queue": q, "dir": fn})
@@ -774,6 +774,17 @@ def cluster(
             print("Unclosed Comms", L)
             # raise ValueError("Unclosed Comms", L)

+    assert_no_leaked_processes()
+
+
+def assert_no_leaked_processes():
+    # Give child processes up to two seconds to exit cleanly before failing
+    for i in range(20):
+        if not mp_context.active_children():
+            break
+        sleep(0.1)
+    assert not mp_context.active_children()
+

 @gen.coroutine
 def disconnect(addr, timeout=3, rpc_kwargs=None):
@@ -854,6 +864,7 @@ def start_cluster(
             security=security,
             loop=loop,
             validate=True,
+            host=ncore[0],
             **(merge(worker_kwargs, ncore[2]) if len(ncore) > 2 else worker_kwargs)
         )
         for i, ncore in enumerate(ncores)
@@ -861,7 +872,7 @@ def start_cluster(
     # for w in workers:
     #     w.rpc = workers[0].rpc

-    yield [w._start(ncore[0]) for ncore, w in zip(ncores, workers)]
+    yield workers

     start = time()
     while len(s.workers) < len(ncores) or any(
@@ -1061,6 +1072,9 @@ def coro():
                 _cleanup_dangling()
             with ignoring(AttributeError):
                 del thread_state.on_event_loop_thread
+
+            assert_no_leaked_processes()
+
             return result

         return test_func
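`assert_no_leaked_processes` (rewritten above so it actually fails when children linger, instead of asserting only on the already-empty branch) leans on `multiprocessing`'s `active_children()`, and giving each test `Process` a descriptive `name` is what makes any survivor identifiable. A self-contained sketch of the idea, assuming a `spawn` context like the test suite's `mp_context`:

```python
import multiprocessing

def noop():
    pass

if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")
    proc = ctx.Process(name="Dask cluster test: Worker", target=noop)
    proc.start()
    proc.join()
    # a leaked child would show up here, and its name would say which
    # part of the test harness created it
    assert not multiprocessing.active_children()
```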
diff --git a/distributed/worker.py b/distributed/worker.py
index 915784edd88..abbd2376c42 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -30,6 +30,7 @@
 from .batched import BatchedSend
 from .comm import get_address_host, get_local_address_for, connect
 from .comm.utils import offload
+from .comm.addressing import address_from_user_args
 from .compatibility import unicode, get_thread_identity, finalize, MutableMapping
 from .core import error_message, CommClosedError, send_recv, pingpong, coerce_to_address
 from .diskutils import WorkSpace
@@ -295,6 +296,10 @@ def __init__(
         extensions=None,
         metrics=None,
         data=None,
+        interface=None,
+        host=None,
+        port=None,
+        protocol=None,
         low_level_profiler=dask.config.get("distributed.worker.profile.low-level"),
         **kwargs
     ):
@@ -406,7 +411,16 @@ def __init__(
             scheduler_addr = coerce_to_address(scheduler_ip)
         else:
             scheduler_addr = coerce_to_address((scheduler_ip, scheduler_port))
-        self._port = 0
+        self.contact_address = contact_address
+
+        self._start_address = address_from_user_args(
+            host=host,
+            port=port,
+            interface=interface,
+            protocol=protocol,
+            security=security,
+        )
+
         self.ncores = ncores or _ncores
         self.total_resources = resources or {}
         self.available_resources = (resources or {}).copy()
@@ -417,7 +431,6 @@ def __init__(
         self.preload_argv = preload_argv
         if self.preload_argv is None:
             self.preload_argv = dask.config.get("distributed.worker.preload-argv")
-        self.contact_address = contact_address
         self.memory_monitor_interval = parse_timedelta(
             memory_monitor_interval, default="ms"
         )
@@ -888,6 +901,7 @@ def start_services(self, default_listen_ip):
     @gen.coroutine
     def _start(self, addr_or_port=0):
         assert self.status is None
+        addr_or_port = addr_or_port or self._start_address

         enable_gc_diagnosis()
         thread_state.on_event_loop_thread = True