
Non-trivial changes
crusaderky committed Jun 11, 2022
1 parent 401e1b4 commit f953262
Showing 9 changed files with 554 additions and 332 deletions.
distributed/diagnostics/plugin.py (2 changes: 1 addition & 1 deletion)
@@ -334,7 +334,7 @@ def __init__(self, filepath):

     async def setup(self, worker):
         response = await worker.upload_file(
-            comm=None, filename=self.filename, data=self.data, load=True
+            filename=self.filename, data=self.data, load=True
         )
         assert len(self.data) == response["nbytes"]

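The only change here is dropping the comm=None keyword from the worker.upload_file(...) call in the plugin's setup hook. For context, a minimal sketch of a worker plugin whose setup ships a local file to the worker with exactly this call; the class name and the WorkerPlugin base are illustrative assumptions, not part of this diff.

from distributed.diagnostics.plugin import WorkerPlugin


class UploadOnSetup(WorkerPlugin):  # hypothetical name, for illustration only
    """Read a file at construction time and upload it to each worker on setup."""

    def __init__(self, filepath):
        self.filename = filepath
        with open(filepath, "rb") as f:
            self.data = f.read()  # raw bytes shipped to the worker

    async def setup(self, worker):
        # Same call shape as in the diff above, without the comm= keyword
        response = await worker.upload_file(
            filename=self.filename, data=self.data, load=True
        )
        assert len(self.data) == response["nbytes"]
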
distributed/node.py (7 changes: 4 additions & 3 deletions)
@@ -77,15 +77,16 @@ def stop_services(self):
     def service_ports(self):
         return {k: v.port for k, v in self.services.items()}
 
-    def _setup_logging(self, logger):
+    def _setup_logging(self, *loggers):
         self._deque_handler = DequeHandler(
             n=dask.config.get("distributed.admin.log-length")
         )
         self._deque_handler.setFormatter(
             logging.Formatter(dask.config.get("distributed.admin.log-format"))
         )
-        logger.addHandler(self._deque_handler)
-        weakref.finalize(self, logger.removeHandler, self._deque_handler)
+        for logger in loggers:
+            logger.addHandler(self._deque_handler)
+            weakref.finalize(self, logger.removeHandler, self._deque_handler)
 
     def get_logs(self, start=0, n=None, timestamps=False):
         """
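
_setup_logging now accepts any number of loggers: one DequeHandler is created and attached to each of them, and weakref.finalize removes it again when the node is garbage-collected, so module-level loggers do not accumulate dead handlers. A self-contained sketch of the same pattern, using a simplified stand-in for Dask's DequeHandler and a hard-coded format instead of the dask.config keys:

import logging
import weakref
from collections import deque


class DequeHandler(logging.Handler):
    """Keep the last n log records in memory (simplified stand-in)."""

    def __init__(self, n=10000):
        super().__init__()
        self.deque = deque(maxlen=n)

    def emit(self, record):
        self.deque.append(record)


class Node:
    def _setup_logging(self, *loggers):
        self._deque_handler = DequeHandler(n=10000)
        self._deque_handler.setFormatter(
            logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
        )
        for logger in loggers:
            logger.addHandler(self._deque_handler)
            # Detach automatically once this node is garbage-collected
            weakref.finalize(self, logger.removeHandler, self._deque_handler)


node = Node()
node._setup_logging(
    logging.getLogger("distributed.core"),
    logging.getLogger("distributed.worker"),
)
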
distributed/shuffle/shuffle_extension.py (2 changes: 1 addition & 1 deletion)
@@ -230,7 +230,7 @@ def __init__(self, worker: Worker) -> None:
         # Initialize
         self.worker: Worker = worker
         self.shuffles: dict[ShuffleId, Shuffle] = {}
-        self.executor = ThreadPoolExecutor(worker.nthreads)
+        self.executor = ThreadPoolExecutor(worker.state.nthreads)
 
         # Handlers
         ##########
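
This one-line change follows the pattern that runs through the whole commit: per-worker attributes such as nthreads are now read from the worker's state object (worker.state) rather than from the Worker itself. A toy sketch of that shape; the Fake* classes below are illustrative stand-ins, not the real Worker/WorkerState:

from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass


@dataclass
class FakeWorkerState:  # illustrative stand-in for distributed's WorkerState
    nthreads: int


@dataclass
class FakeWorker:  # illustrative stand-in for distributed's Worker
    state: FakeWorkerState


worker = FakeWorker(state=FakeWorkerState(nthreads=4))
# The shuffle extension sizes its thread pool from the state object
executor = ThreadPoolExecutor(worker.state.nthreads)
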
distributed/tests/test_client.py (3 changes: 2 additions & 1 deletion)
@@ -1612,6 +1612,7 @@ def g():
     os.remove("myfile.zip")
 
 
+@pytest.mark.slow
 @gen_cluster(client=True)
 async def test_upload_file_egg(c, s, a, b):
     pytest.importorskip("setuptools")
@@ -6810,7 +6811,7 @@ async def test_workers_collection_restriction(c, s, a, b):
     assert a.data and not b.data
 
 
-@gen_cluster(client=True, nthreads=[("127.0.0.1", 0)])
+@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)])
 async def test_get_client_functions_spawn_clusters(c, s, a):
     # see gh4565

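Both edits in this file adjust test fixtures: test_upload_file_egg is marked slow, and the zero-thread worker in test_get_client_functions_spawn_clusters becomes a regular single-threaded one. In @gen_cluster, nthreads is a list of (host, nthreads) pairs, one entry per worker to start, and the decorated coroutine receives the client (when client=True), the scheduler, and then one argument per worker. A usage sketch under those assumptions; the test name is made up, and inc is the trivial helper from distributed.utils_test:

from distributed.utils_test import gen_cluster, inc


@gen_cluster(client=True, nthreads=[("127.0.0.1", 1), ("127.0.0.1", 2)])
async def test_two_workers(c, s, a, b):
    # One single-threaded worker and one two-threaded worker
    assert a.state.nthreads == 1
    assert b.state.nthreads == 2
    assert await c.submit(inc, 1) == 2
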
distributed/tests/test_worker.py (38 changes: 21 additions & 17 deletions)
@@ -1557,7 +1557,9 @@ async def f(ev):
         task for task in asyncio.all_tasks() if "execute(f1)" in task.get_name()
     )
     start = time()
-    with captured_logger("distributed.worker", level=logging.ERROR) as logger:
+    with captured_logger(
+        "distributed.worker_state_machine", level=logging.ERROR
+    ) as logger:
         await a.close(timeout=1)
     assert "Failed to cancel asyncio task" in logger.getvalue()
     assert time() - start < 5
@@ -2030,7 +2032,7 @@ async def test_gather_dep_from_remote_workers_if_all_local_workers_are_busy(
     assert_story(a.story("receive-dep"), [("receive-dep", rw.address, {"f"})])
 
 
-@gen_cluster(client=True, nthreads=[("127.0.0.1", 0)])
+@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)])
 async def test_worker_client_uses_default_no_close(c, s, a):
     """
     If a default client is available in the process, the worker will pick this
@@ -2057,7 +2059,7 @@ def get_worker_client_id():
     assert c is c_def
 
 
-@gen_cluster(nthreads=[("127.0.0.1", 0)])
+@gen_cluster(nthreads=[("127.0.0.1", 1)])
 async def test_worker_client_closes_if_created_on_worker_one_worker(s, a):
     async with Client(s.address, set_as_default=False, asynchronous=True) as c:
         with pytest.raises(ValueError):
@@ -2542,7 +2544,7 @@ def raise_exc(*args):
         await asyncio.sleep(0.01)
 
 
-@gen_cluster(client=True, nthreads=[("127.0.0.1", x) for x in range(4)])
+@gen_cluster(client=True, nthreads=[("", x) for x in (1, 2, 3, 4)])
 async def test_hold_on_to_replicas(c, s, *workers):
     f1 = c.submit(inc, 1, workers=[workers[0].address], key="f1")
     f2 = c.submit(inc, 2, workers=[workers[1].address], key="f2")
@@ -3283,10 +3285,24 @@ async def test_Worker__to_dict(c, s, a):
"type",
"id",
"scheduler",
"nthreads",
"address",
"status",
"thread_id",
"logs",
"config",
"incoming_transfer_log",
"outgoing_transfer_log",
# attributes of WorkerMemoryManager
"data",
"max_spill",
"memory_limit",
"memory_monitor_interval",
"memory_pause_fraction",
"memory_spill_fraction",
"memory_target_fraction",
# attributes of WorkerState
"nthreads",
"running",
"ready",
"constrained",
"long_running",
@@ -3298,20 +3314,8 @@ async def test_Worker__to_dict(c, s, a):
"stimulus_log",
"transition_counter",
"tasks",
"logs",
"config",
"incoming_transfer_log",
"outgoing_transfer_log",
"data_needed",
"data_needed_per_worker",
# attributes of WorkerMemoryManager
"data",
"max_spill",
"memory_limit",
"memory_monitor_interval",
"memory_pause_fraction",
"memory_spill_fraction",
"memory_target_fraction",
}
assert d["tasks"]["x"]["key"] == "x"
assert d["data"] == ["x"]
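
The first hunk in this file retargets captured_logger from the distributed.worker logger to distributed.worker_state_machine, presumably because the "Failed to cancel asyncio task" error is now emitted from the state-machine module; the remaining hunks swap zero-thread workers for single-threaded ones and regroup the expected keys of Worker._to_dict() by owning component. For reference, a minimal sketch of how captured_logger is used, with a made-up logger message:

import logging

from distributed.utils_test import captured_logger

with captured_logger("distributed.worker_state_machine", level=logging.ERROR) as sio:
    logging.getLogger("distributed.worker_state_machine").error("something went wrong")

assert "something went wrong" in sio.getvalue()
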
distributed/utils_test.py (21 changes: 12 additions & 9 deletions)
@@ -70,7 +70,8 @@
     reset_logger_locks,
     sync,
 )
-from distributed.worker import WORKER_ANY_RUNNING, InvalidTransition, Worker
+from distributed.worker import WORKER_ANY_RUNNING, Worker
+from distributed.worker_state_machine import InvalidTransition
 
 try:
     import ssl
@@ -1271,8 +1272,10 @@ def validate_state(*servers: Scheduler | Worker | Nanny) -> None:
     Excludes workers wrapped by Nannies and workers manually started by the test.
     """
     for s in servers:
-        if s.validate and hasattr(s, "validate_state"):
-            s.validate_state()  # type: ignore
+        if isinstance(s, Scheduler) and s.validate:
+            s.validate_state()
+        elif isinstance(s, Worker) and s.state.validate:
+            s.validate_state()
 
 
 def raises(func, exc=Exception):
@@ -2322,13 +2325,13 @@ def freeze_data_fetching(w: Worker, *, jump_start: bool = False):
     If True, trigger ensure_communicating on exit; this simulates e.g. an unrelated
     worker moving out of in_flight_workers.
     """
-    old_out_connections = w.total_out_connections
-    old_comm_threshold = w.comm_threshold_bytes
-    w.total_out_connections = 0
-    w.comm_threshold_bytes = 0
+    old_out_connections = w.state.total_out_connections
+    old_comm_threshold = w.state.comm_threshold_bytes
+    w.state.total_out_connections = 0
+    w.state.comm_threshold_bytes = 0
     yield
-    w.total_out_connections = old_out_connections
-    w.comm_threshold_bytes = old_comm_threshold
+    w.state.total_out_connections = old_out_connections
+    w.state.comm_threshold_bytes = old_comm_threshold
     if jump_start:
         w.status = Status.paused
         w.status = Status.running
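
freeze_data_fetching pins the worker's outgoing-connection count and comm threshold to zero so that no dependency fetches can start, then restores the originals on exit; after this commit those two knobs live on w.state rather than on the Worker. A generic sketch of the same save/zero/restore pattern as a context manager; FakeState and the try/finally are additions for illustration, while the attribute names come from the diff:

from contextlib import contextmanager
from dataclasses import dataclass


@dataclass
class FakeState:  # illustrative stand-in for the worker state used in the diff
    total_out_connections: int = 10
    comm_threshold_bytes: int = 10_000_000


@contextmanager
def freeze_data_fetching(state: FakeState):
    """Temporarily forbid data fetching, then restore the old limits."""
    old_out_connections = state.total_out_connections
    old_comm_threshold = state.comm_threshold_bytes
    state.total_out_connections = 0
    state.comm_threshold_bytes = 0
    try:
        yield
    finally:
        state.total_out_connections = old_out_connections
        state.comm_threshold_bytes = old_comm_threshold


state = FakeState()
with freeze_data_fetching(state):
    assert state.total_out_connections == 0
assert state.total_out_connections == 10
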
