10 changes: 6 additions & 4 deletions distributed/bokeh/tests/test_scheduler_bokeh.py
@@ -565,10 +565,12 @@ def test_GraphPlot_order(c, s, a, b):
)
def test_profile_server(c, s, a, b):
ptp = ProfileServer(s)
ptp.trigger_update()
yield gen.sleep(0.200)
ptp.trigger_update()
assert 2 < len(ptp.ts_source.data["time"]) < 20
start = time()
yield gen.sleep(0.100)
while len(ptp.ts_source.data["time"]) < 2:
yield gen.sleep(0.100)
ptp.trigger_update()
assert time() < start + 2


@gen_cluster(client=True, scheduler_kwargs={"services": {("bokeh", 0): BokehScheduler}})
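The rewritten test replaces a fixed sleep with a poll-until-ready loop bounded by a deadline, which is less flaky on slow CI machines. A generic sketch of that pattern (the helper name and defaults are illustrative, not part of the diff):

```python
from time import time

from tornado import gen


@gen.coroutine
def poll_until(condition, timeout=2, interval=0.1, on_retry=None):
    """Poll `condition` until it is true, failing after `timeout` seconds."""
    start = time()
    while not condition():
        assert time() < start + timeout, "condition not met within timeout"
        yield gen.sleep(interval)
        if on_retry is not None:
            on_retry()  # e.g. ptp.trigger_update in the test above
```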
5 changes: 4 additions & 1 deletion distributed/cli/tests/test_dask_worker.py
@@ -40,7 +40,10 @@ def test_nanny_worker_ports(loop):
else:
assert time() - start < 5
sleep(0.1)
assert d["workers"]["tcp://127.0.0.1:9684"]["services"]["nanny"] == 5273
assert (
d["workers"]["tcp://127.0.0.1:9684"]["nanny"]
== "tcp://127.0.0.1:5273"
)


def test_memory_limit(loop):
2 changes: 1 addition & 1 deletion distributed/nanny.py
@@ -259,7 +259,7 @@ def instantiate(self, comm=None):
ncores=self.ncores,
local_dir=self.local_dir,
services=self.services,
service_ports={"nanny": self.port},
nanny=self.address,
name=self.name,
memory_limit=self.memory_limit,
reconnect=self.reconnect,
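The one-line change above alters what the nanny advertises to its worker: a full contact address instead of a bare port stuffed into `service_ports`. A before/after sketch with illustrative addresses (not taken from the diff):

```python
# Before: only a port number travelled; the worker rebuilt the nanny
# address from its own listener prefix and IP, which breaks down when
# the protocol (e.g. tls://) or host differs from the worker's own.
service_ports = {"nanny": 5273}
nanny_address = "tcp://127.0.0.1:%d" % service_ports["nanny"]

# After: the complete address travels verbatim.
nanny_address = "tls://10.0.0.5:5273"
```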
52 changes: 52 additions & 0 deletions distributed/node.py
@@ -1,7 +1,10 @@
from __future__ import print_function, division, absolute_import

import warnings

from tornado.ioloop import IOLoop

from .compatibility import unicode
from .core import Server, ConnectionPool
from .versions import get_versions

@@ -78,3 +81,52 @@ def __init__(

def versions(self, comm=None, packages=None):
return get_versions(packages=packages)

def start_services(self, default_listen_ip):
if default_listen_ip == "0.0.0.0":
default_listen_ip = "" # for IPV6

for k, v in self.service_specs.items():
listen_ip = None
if isinstance(k, tuple):
k, port = k
else:
port = 0

if isinstance(port, (str, unicode)):
port = port.split(":")

if isinstance(port, (tuple, list)):
if len(port) == 2:
listen_ip, port = (port[0], int(port[1]))
elif len(port) == 1:
[listen_ip], port = port, 0
else:
raise ValueError(port)

if isinstance(v, tuple):
v, kwargs = v
else:
kwargs = {}

try:
service = v(self, io_loop=self.loop, **kwargs)
service.listen(
(listen_ip if listen_ip is not None else default_listen_ip, port)
)
self.services[k] = service
except Exception as e:
warnings.warn(
"\nCould not launch service '%s' on port %s. " % (k, port)
+ "Got the following message:\n\n"
+ str(e),
stacklevel=3,
)

def stop_services(self):
for service in self.services.values():
service.stop()

@property
def service_ports(self):
return {k: v.port for k, v in self.services.items()}
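The `start_services` method moved here accepts several spellings for `service_specs` keys and values. A minimal runnable sketch of the accepted forms, where `MyService` is a stand-in (not a real distributed class) with the constructor and listen/stop interface the loop expects:

```python
class MyService:
    """Stand-in service: (server, io_loop=..., **kwargs) constructor,
    plus the listen()/stop()/.port interface used by start_services."""

    def __init__(self, server, io_loop=None, **kwargs):
        self.server, self.kwargs = server, kwargs
        self.port = None

    def listen(self, addr):
        self.ip, self.port = addr

    def stop(self):
        pass


service_specs = {
    "svc-a": MyService,                        # bare name: OS-assigned port (0)
    ("svc-b", 8000): MyService,                # (name, port): fixed port
    ("svc-c", "127.0.0.1:8001"): MyService,    # (name, "ip:port"): fixed IP and port
    ("svc-d", "127.0.0.1"): MyService,         # (name, "ip"): fixed IP, OS-assigned port
    ("svc-e", 8002): (MyService, {"opt": 1}),  # (class, kwargs) value form
}
```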
63 changes: 14 additions & 49 deletions distributed/scheduler.py
@@ -12,7 +12,6 @@
import pickle
import random
import six
import warnings

import psutil
import sortedcontainers
@@ -190,6 +189,10 @@ class WorkerState(object):

The current status of the worker, either ``'running'`` or ``'closed'``

.. attribute:: nanny: str

Address of the associated Nanny, if present

.. attribute:: last_seen: Number

The last time we received a heartbeat from this worker, in local
@@ -214,6 +217,7 @@ class WorkerState(object):
"memory_limit",
"metrics",
"name",
"nanny",
"nbytes",
"ncores",
"occupancy",
@@ -235,6 +239,7 @@ def __init__(
memory_limit=0,
local_directory=None,
services=None,
nanny=None,
):
self.address = address
self.pid = pid
@@ -243,6 +248,7 @@ def __init__(
self.memory_limit = memory_limit
self.local_directory = local_directory
self.services = services or {}
self.nanny = nanny

self.status = "running"
self.nbytes = 0
@@ -271,6 +277,7 @@ def clean(self):
memory_limit=self.memory_limit,
local_directory=self.local_directory,
services=self.services,
nanny=self.nanny,
)
ws.processing = {ts.key for ts in self.processing}
return ws
@@ -298,6 +305,7 @@ def identity(self):
"last_seen": self.last_seen,
"services": self.services,
"metrics": self.metrics,
"nanny": self.nanny,
}


@@ -1157,46 +1165,6 @@ def get_worker_service_addr(self, worker, service_name, protocol=False):
else:
return ws.host, port

def start_services(self, default_listen_ip):
if default_listen_ip == "0.0.0.0":
default_listen_ip = "" # for IPV6

for k, v in self.service_specs.items():
listen_ip = None
if isinstance(k, tuple):
k, port = k
else:
port = 0

if isinstance(port, (str, unicode)):
port = port.split(":")

if isinstance(port, (tuple, list)):
listen_ip, port = (port[0], int(port[1]))

if isinstance(v, tuple):
v, kwargs = v
else:
kwargs = {}

try:
service = v(self, io_loop=self.loop, **kwargs)
service.listen(
(listen_ip if listen_ip is not None else default_listen_ip, port)
)
self.services[k] = service
except Exception as e:
warnings.warn(
"\nCould not launch service '%s' on port %s. " % (k, port)
+ "Got the following message:\n\n"
+ str(e),
stacklevel=3,
)

def stop_services(self):
for service in self.services.values():
service.stop()

def start(self, addr_or_port=None, start_queues=True):
""" Clear out old state and restart all running coroutines """
enable_gc_diagnosis()
@@ -1347,7 +1315,7 @@ def close_worker(self, stream=None, worker=None, safe=None):
logger.info("Closing worker %s", worker)
with log_errors():
self.log_event(worker, {"action": "close-worker"})
nanny_addr = self.get_worker_service_addr(worker, "nanny", protocol=True)
nanny_addr = self.workers[worker].nanny
address = nanny_addr or worker

self.worker_send(worker, {"op": "close", "report": False})
@@ -1434,6 +1402,7 @@ def add_worker(
pid=0,
services=None,
local_directory=None,
nanny=None,
):
""" Add a new worker to the cluster """
with log_errors():
@@ -1453,6 +1422,7 @@
name=name,
local_directory=local_directory,
services=services,
nanny=nanny,
)

if name in self.aliases:
@@ -2608,10 +2578,7 @@ def restart(self, client=None, timeout=3):
keys=[ts.key for ts in cs.wants_what], client=cs.client_key
)

nannies = {
addr: self.get_worker_service_addr(addr, "nanny", protocol=True)
for addr in self.workers
}
nannies = {addr: ws.nanny for addr, ws in self.workers.items()}

for addr in list(self.workers):
try:
@@ -2694,9 +2661,7 @@ def broadcast(
# TODO replace with worker_list

if nanny:
addresses = [
self.get_worker_service_addr(w, "nanny", protocol=True) for w in workers
]
addresses = [self.workers[w].nanny for w in workers]
else:
addresses = workers

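Across the scheduler changes above, an address reconstruction is traded for a stored attribute. A side-by-side sketch (the function names and the `tcp://` rebuild are illustrative):

```python
# Old: recover the nanny endpoint from the worker's advertised service
# port, rebuilding protocol and host on the scheduler side.
def nanny_addr_old(scheduler, worker):
    ws = scheduler.workers[worker]
    return "tcp://%s:%d" % (ws.host, ws.services["nanny"])

# New: the worker reports the full address at registration time; it is
# stored on WorkerState.nanny (None when the worker has no nanny).
def nanny_addr_new(scheduler, worker):
    return scheduler.workers[worker].nanny
```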
2 changes: 1 addition & 1 deletion distributed/tests/test_client.py
@@ -3479,7 +3479,7 @@ def test_bad_tasks_fail(c, s, a, b):
with pytest.raises(KilledWorker) as info:
yield f

assert info.value.last_worker.services["nanny"] in {a.port, b.port}
assert info.value.last_worker.nanny in {a.address, b.address}


def test_get_processing_sync(c, s, a, b):
4 changes: 2 additions & 2 deletions distributed/tests/test_nanny.py
@@ -28,7 +28,7 @@ def test_nanny(s):
with rpc(n.address) as nn:
assert n.is_alive()
assert s.ncores[n.worker_address] == 2
assert s.workers[n.worker_address].services["nanny"] > 1024
assert s.workers[n.worker_address].nanny == n.address

yield nn.kill()
assert not n.is_alive()
@@ -43,7 +43,7 @@ def test_nanny(s):
yield nn.instantiate()
assert n.is_alive()
assert s.ncores[n.worker_address] == 2
assert s.workers[n.worker_address].services["nanny"] > 1024
assert s.workers[n.worker_address].nanny == n.address

yield nn.terminate()
assert not n.is_alive()
12 changes: 12 additions & 0 deletions distributed/tests/test_scheduler.py
@@ -1542,3 +1542,15 @@ def test_host_address():
s = yield Scheduler(host="127.0.0.2")
assert "127.0.0.2" in s.address
yield s.close()


@gen_test()
def test_dashboard_address():
pytest.importorskip("bokeh")
s = yield Scheduler(dashboard_address="127.0.0.1:8901")
assert s.services["bokeh"].port == 8901
yield s.close()

s = yield Scheduler(dashboard_address="127.0.0.1")
assert s.services["bokeh"].port
yield s.close()
43 changes: 7 additions & 36 deletions distributed/worker.py
@@ -261,6 +261,8 @@ class Worker(ServerNode):
executor: concurrent.futures.Executor
resources: dict
Resources that this worker has like ``{'GPU': 2}``
nanny: str
Address on which to contact the nanny, if it exists

Examples
--------
@@ -311,6 +313,7 @@ def __init__(
port=None,
protocol=None,
dashboard_address=None,
nanny=None,
low_level_profiler=dask.config.get("distributed.worker.profile.low-level"),
**kwargs
):
@@ -323,6 +326,7 @@
self.who_has = dict()
self.has_what = defaultdict(set)
self.pending_data_per_worker = defaultdict(deque)
self.nanny = nanny
self._lock = threading.Lock()

self.data_needed = deque() # TODO: replace with heap?
@@ -535,7 +539,6 @@ def __init__(
sys.path.insert(0, self.local_dir)

self.services = {}
self.service_ports = service_ports or {}
self.service_specs = services or {}

if dashboard_address is not None:
@@ -731,6 +734,7 @@ def _register_with_scheduler(self):
memory_limit=self.memory_limit,
local_directory=self.local_dir,
services=self.service_ports,
nanny=self.nanny,
pid=os.getpid(),
metrics=self.get_metrics(),
),
@@ -895,34 +899,6 @@ def get_logs(self, comm=None, n=None):
# Lifecycle #
#############

def start_services(self, default_listen_ip):
if default_listen_ip == "0.0.0.0":
default_listen_ip = "" # for IPV6

for k, v in self.service_specs.items():
listen_ip = None
if isinstance(k, tuple):
k, port = k
else:
port = 0

if isinstance(port, (str, unicode)):
port = port.split(":")

if isinstance(port, (tuple, list)):
listen_ip, port = (port[0], int(port[1]))

if isinstance(v, tuple):
v, kwargs = v
else:
kwargs = {}

self.services[k] = v(self, io_loop=self.loop, **kwargs)
self.services[k].listen(
(listen_ip if listen_ip is not None else default_listen_ip, port)
)
self.service_ports[k] = self.services[k].port

@gen.coroutine
def _start(self, addr_or_port=0):
assert self.status is None
@@ -1047,13 +1023,8 @@ def close(self, report=True, timeout=10, nanny=True, executor_wait=True):
if self.batched_stream:
self.batched_stream.close()

if nanny and "nanny" in self.service_ports:
nanny_address = "%s%s:%d" % (
self.listener.prefix,
self.ip,
self.service_ports["nanny"],
)
with self.rpc(nanny_address) as r:
if nanny and self.nanny:
with self.rpc(self.nanny) as r:
yield r.terminate()

self.rpc.close()
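Taken together, the PR routes the nanny's full address end to end rather than a bare port. A toy sketch of that plumbing; the `Toy*` classes are stand-ins for Nanny/Worker, not the real distributed classes, and elide process boundaries and RPC:

```python
class ToyNanny:
    def __init__(self, address):
        self.address = address  # e.g. "tcp://127.0.0.1:5273"

    def instantiate(self):
        # the real Nanny.instantiate passes nanny=self.address to Worker(...)
        return ToyWorker(nanny=self.address)


class ToyWorker:
    def __init__(self, nanny=None):
        self.nanny = nanny  # stored verbatim; no address rebuilding later

    def registration_payload(self):
        # forwarded by _register_with_scheduler; kept on WorkerState.nanny
        return {"nanny": self.nanny}

    def close(self, nanny=True):
        if nanny and self.nanny:
            print("terminate nanny at", self.nanny)  # real code: rpc + terminate


worker = ToyNanny("tcp://127.0.0.1:5273").instantiate()
assert worker.registration_payload()["nanny"] == "tcp://127.0.0.1:5273"
worker.close()
```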