
Commit

fix some tests
fjetter committed Jan 19, 2024
1 parent c4924c3 commit 5c99782
Showing 13 changed files with 64 additions and 60 deletions.
13 changes: 3 additions & 10 deletions distributed/core.py
@@ -7,7 +7,6 @@
import sys
import traceback
import types
import uuid
import warnings
import weakref
from collections import defaultdict
@@ -145,7 +144,6 @@ class Server:
default_ip: ClassVar[str] = ""
default_port: ClassVar[int] = 0

id: str
blocked_handlers: list[str]
handlers: dict[str, Callable]
stream_handlers: dict[str, Callable]
@@ -174,20 +172,15 @@ def __init__(
timeout=None,
):
self.handlers = {
"identity": self.identity,
"echo": self.echo,
"identity": self.identity,
"connection_stream": self.handle_stream,
}
self.handlers.update(handlers)
if blocked_handlers is None:
blocked_handlers = dask.config.get(
"distributed.%s.blocked-handlers" % type(self).__name__.lower(), []
)
self.blocked_handlers = blocked_handlers
self.blocked_handlers = blocked_handlers or {}
self.stream_handlers = {}
self.stream_handlers.update(stream_handlers or {})

self.id = type(self).__name__ + "-" + str(uuid.uuid4())
self._address = None
self._listen_address = None
self._port = None
@@ -350,7 +343,7 @@ def port(self):
return self._port

def identity(self) -> dict[str, str]:
return {"type": type(self).__name__, "id": self.id}
return {"type": type(self).__name__, "id": str(id(self))}

def echo(self, data=None):
return data
2 changes: 1 addition & 1 deletion distributed/event.py
@@ -50,7 +50,7 @@ def __init__(self, scheduler):
# we can remove the event
self._waiter_count = defaultdict(int)

self.scheduler.handlers.update(
self.scheduler.server.handlers.update(
{
"event_wait": self.event_wait,
"event_set": self.event_set,
2 changes: 1 addition & 1 deletion distributed/http/scheduler/json.py
@@ -55,7 +55,7 @@ def get(self):

class IdentityJSON(RequestHandler):
def get(self):
self.write(self.server.identity())
self.write(self.identity())


class IndexJSON(RequestHandler):
2 changes: 1 addition & 1 deletion distributed/lock.py
@@ -27,7 +27,7 @@ def __init__(self, scheduler):
self.events = defaultdict(deque)
self.ids = dict()

self.scheduler.handlers.update(
self.scheduler.server.handlers.update(
{"lock_acquire": self.acquire, "lock_release": self.release}
)

2 changes: 1 addition & 1 deletion distributed/nanny.py
@@ -334,7 +334,7 @@ async def start_unsafe(self):
security=self.security,
)
try:
await self.listen(
await self.server.listen(
start_address, **self.security.get_listen_args("worker")
)
except OSError as e:
28 changes: 22 additions & 6 deletions distributed/node.py
@@ -546,12 +546,20 @@ def __init__(

_handlers = {
"dump_state": self._to_dict,
"identity": self.identity,
}
if handlers:
_handlers.update(handlers)
import uuid

self.id = type(self).__name__ + "-" + str(uuid.uuid4())

if blocked_handlers is None:
blocked_handlers = dask.config.get(
"distributed.%s.blocked-handlers" % type(self).__name__.lower(), []
)
self.server = Server(
handlers=handlers,
handlers=_handlers,
blocked_handlers=blocked_handlers,
stream_handlers=stream_handlers,
connection_limit=connection_limit,
@@ -566,6 +574,17 @@
needs_workdir=needs_workdir,
)

def identity(self) -> dict[str, str]:
return {"type": type(self).__name__, "id": self.id}

@property
def port(self):
return self.server.port

@property
def listen_address(self):
return self.server.address

@property
def address(self):
return self.server.address
@@ -574,10 +593,6 @@ def address(self):
def address_safe(self):
return self.server.address_safe

@property
def id(self):
return self.server.id

async def start_unsafe(self):
await self.server
await super().start_unsafe()
@@ -592,6 +607,7 @@ async def close(self, reason: str | None = None) -> None:
# Close network connections and background tasks
await self.server.close()
await Node.close(self, reason=reason)
self.status = Status.closed
finally:
self._event_finished.set()

@@ -605,7 +621,7 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict[str, Any]:
Client.dump_cluster_state
distributed.utils.recursive_to_dict
"""
info: dict[str, Any] = self.server.identity()
info: dict[str, Any] = self.identity()
extra = {
"address": self.server.address,
"status": self.status.name,
2 changes: 1 addition & 1 deletion distributed/pubsub.py
@@ -122,7 +122,7 @@ class PubSubWorkerExtension:

def __init__(self, worker):
self.worker = worker
self.worker.stream_handlers.update(
self.worker.server.stream_handlers.update(
{
"pubsub-add-subscriber": self.add_subscriber,
"pubsub-remove-subscriber": self.remove_subscriber,
13 changes: 8 additions & 5 deletions distributed/scheduler.py
@@ -50,6 +50,7 @@
valmap,
)
from tornado.ioloop import IOLoop
from typing_extensions import Self

import dask
from dask.core import get_deps, validate_key
@@ -4027,7 +4028,7 @@ def get_worker_service_addr(
else:
return ws.host, port

async def start_unsafe(self):
async def start_unsafe(self) -> Self:
"""Clear out old state and restart all running coroutines"""
await super().start_unsafe()

@@ -4042,7 +4043,7 @@ async def start_unsafe(self):
handshake_overrides={"pickle-protocol": 4, "compression": None},
**self.security.get_listen_args("scheduler"),
)
self.ip = get_address_host(self.listen_address)
self.ip = get_address_host(self.server.listen_address)
listen_ip = self.ip

if listen_ip == "0.0.0.0":
@@ -4054,7 +4055,7 @@
# Services listen on all addresses
self.start_services(listen_ip)

for listener in self.listeners:
for listener in self.server.listeners:
logger.info(" Scheduler at: %25s", listener.contact_address)
for name, server in self.services.items():
if name == "dashboard":
@@ -4089,8 +4090,10 @@ def del_scheduler_file():
if self.jupyter:
# Allow insecure communications from local users
if self.server.address.startswith("tls://"):
await self.listen("tcp://localhost:0")
os.environ["DASK_SCHEDULER_ADDRESS"] = self.listeners[-1].contact_address
await self.server.listen("tcp://localhost:0")
os.environ["DASK_SCHEDULER_ADDRESS"] = self.server.listeners[
-1
].contact_address

await asyncio.gather(
*[plugin.start(self) for plugin in list(self.plugins.values())]
10 changes: 0 additions & 10 deletions distributed/tests/test_core.py
@@ -16,7 +16,6 @@

from distributed.batched import BatchedSend
from distributed.comm.core import CommClosedError
from distributed.comm.registry import backends
from distributed.comm.tcp import TCPBackend, TCPListener
from distributed.core import (
AsyncTaskGroup,
@@ -1287,15 +1286,6 @@ class TCPAsyncListenerBackend(TCPBackend):
_listener_class = AsyncStopTCPListener


@gen_test()
async def test_async_listener_stop(monkeypatch):
monkeypatch.setitem(backends, "tcp", TCPAsyncListenerBackend())
with pytest.warns(DeprecationWarning):
async with Server({}) as s:
await s.listen(0)
assert s.listeners


@gen_test()
async def test_messages_are_ordered_bsend():
ledger = []
8 changes: 3 additions & 5 deletions distributed/tests/test_nanny.py
@@ -77,7 +77,6 @@ async def test_nanny_process_failure(c, s):
assert not os.path.exists(second_dir)
assert not os.path.exists(first_dir)
assert first_dir != n.worker_dir
s.stop()


@gen_cluster(nthreads=[])
@@ -201,10 +200,9 @@ def func(dask_worker):
@gen_test()
async def test_scheduler_file():
with tmpfile() as fn:
s = await Scheduler(scheduler_file=fn, dashboard_address=":0")
async with Nanny(scheduler_file=fn) as n:
assert set(s.workers) == {n.worker_address}
s.stop()
async with Scheduler(scheduler_file=fn, dashboard_address=":0") as s:
async with Nanny(scheduler_file=fn) as n:
assert set(s.workers) == {n.worker_address}


@pytest.mark.xfail(
28 changes: 14 additions & 14 deletions distributed/tests/test_scheduler.py
@@ -824,8 +824,8 @@ async def test_retire_workers_concurrently(c, s, w1, w2):
async def test_server_listens_to_other_ops(s, a, b):
async with rpc(s.address) as r:
ident = await r.identity()
assert ident["type"] == "Scheduler"
assert ident["id"].lower().startswith("scheduler")
assert ident["type"] == "Scheduler", ident["type"]
assert ident["id"].lower().startswith("scheduler"), ident["id"]


@gen_cluster(client=True)
@@ -928,7 +928,7 @@ def func(scheduler):
nthreads=[], config={"distributed.scheduler.blocked-handlers": ["test-handler"]}
)
async def test_scheduler_init_pulls_blocked_handlers_from_config(s):
assert s.blocked_handlers == ["test-handler"]
assert s.server.blocked_handlers == ["test-handler"]


@gen_cluster()
@@ -1326,7 +1326,7 @@ async def test_broadcast_nanny(s, a, b):

@gen_cluster(config={"distributed.comm.timeouts.connect": "200ms"})
async def test_broadcast_on_error(s, a, b):
a.stop()
a.server.stop()

with pytest.raises(OSError):
await s.broadcast(msg={"op": "ping"}, on_error="raise")
@@ -2007,7 +2007,7 @@ async def test_profile_metadata_timeout(c, s, a, b):
def raise_timeout(*args, **kwargs):
raise TimeoutError

b.handlers["profile_metadata"] = raise_timeout
b.server.handlers["profile_metadata"] = raise_timeout

futures = c.map(slowinc, range(10), delay=0.05, workers=a.address)
await wait(futures)
@@ -2071,7 +2071,7 @@ async def test_statistical_profiling_failure(c, s, a, b):
def raise_timeout(*args, **kwargs):
raise TimeoutError

b.handlers["profile"] = raise_timeout
b.server.handlers["profile"] = raise_timeout
await wait(futures)

profile = await s.get_profile()
@@ -3050,7 +3050,7 @@ async def connect(self, *args, **kwargs):
async def test_gather_failing_cnn_recover(c, s, a, b):
x = await c.scatter({"x": 1}, workers=a.address)
rpc = await FlakyConnectionPool(failing_connections=1)
with mock.patch.object(s, "rpc", rpc), dask.config.set(
with mock.patch.object(s.server, "rpc", rpc), dask.config.set(
{"distributed.comm.retry.count": 1}
), captured_handler(
logging.getLogger("distributed").handlers[0]
@@ -3068,7 +3068,7 @@ async def test_gather_failing_cnn_recover(c, s, a, b):
async def test_gather_failing_cnn_error(c, s, a, b):
x = await c.scatter({"x": 1}, workers=a.address)
rpc = await FlakyConnectionPool(failing_connections=10)
with mock.patch.object(s, "rpc", rpc):
with mock.patch.object(s.server, "rpc", rpc):
res = await s.gather(keys=["x"])
assert res["status"] == "error"
assert list(res["keys"]) == ["x"]
@@ -3101,7 +3101,7 @@ async def test_gather_bad_worker(c, s, a, direct):
"""
x = c.submit(inc, 1, key="x")
c.rpc = await FlakyConnectionPool(failing_connections=3)
s.rpc = await FlakyConnectionPool(failing_connections=1)
s.server.rpc = await FlakyConnectionPool(failing_connections=1)

with captured_logger("distributed.scheduler") as sched_logger:
with captured_logger("distributed.client") as client_logger:
@@ -3116,12 +3116,12 @@ async def test_gather_bad_worker(c, s, a, direct):
# 3. try direct=True again; fail
# 4. fall back to direct=False again; success
assert c.rpc.cnn_count == 2
assert s.rpc.cnn_count == 2
assert s.server.rpc.cnn_count == 2
else:
# 1. try direct=False; fail
# 2. try again direct=False; success
assert c.rpc.cnn_count == 0
assert s.rpc.cnn_count == 2
assert s.server.rpc.cnn_count == 2


@gen_cluster(client=True)
@@ -3152,8 +3152,8 @@ async def test_multiple_listeners(dashboard_link_template, expected_dashboard_li
async with Scheduler(
dashboard_address=":0", protocol=["inproc", "tcp"]
) as s:
async with Worker(s.listeners[0].contact_address) as a:
async with Worker(s.listeners[1].contact_address) as b:
async with Worker(s.server.listeners[0].contact_address) as a:
async with Worker(s.server.listeners[1].contact_address) as b:
assert a.address.startswith("inproc")
assert a.scheduler.address.startswith("inproc")
assert b.address.startswith("tcp")
@@ -4602,7 +4602,7 @@ class BrokenGatherDep(Worker):
async def gather_dep(self, worker, *args, **kwargs):
w = workers.pop(worker, None)
if w is not None and workers:
w.listener.stop()
w.server.listener.stop()
s.stream_comms[worker].abort()

return await super().gather_dep(worker, *args, **kwargs)
1 change: 0 additions & 1 deletion distributed/utils_test.py
@@ -856,7 +856,6 @@ async def end_worker(w):

await asyncio.gather(*(end_worker(w) for w in workers))
await s.close() # wait until scheduler stops completely
s.stop()
check_invalid_worker_transitions(s)
check_invalid_task_states(s)
check_worker_fail_hard(s)