forked from dask/distributed
/
test_utils_comm.py
263 lines (204 loc) · 7.82 KB
/
test_utils_comm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
from __future__ import annotations
import asyncio
import random
from unittest import mock
import pytest
from dask.optimization import SubgraphCallable
from distributed import wait
from distributed.compatibility import asyncio_run
from distributed.config import get_loop_factory
from distributed.core import ConnectionPool, Status
from distributed.utils_comm import (
WrappedKey,
gather_from_workers,
pack_data,
retry,
subs_multiple,
unpack_remotedata,
)
from distributed.utils_test import BarrierGetData, BrokenComm, gen_cluster, inc
def test_pack_data():
    """pack_data substitutes known keys in nested containers with their values."""
    lookup = {"x": 1}
    # Tuples, dict values, and lists are all traversed; unknown keys pass through.
    assert pack_data(("x", "y"), lookup) == (1, "y")
    assert pack_data({"a": "x", "b": "y"}, lookup) == {"a": 1, "b": "y"}
    assert pack_data({"a": ["x"], "b": "y"}, lookup) == {"a": [1], "b": "y"}
def test_subs_multiple():
    """subs_multiple rewrites literal keys inside task tuples, including nested
    lists, whole graphs, and tuple-typed keys."""
    values = {"x": 1, "y": 2}
    # Direct task tuple, and a task with a nested list argument.
    assert subs_multiple((sum, [0, "x", "y", "z"]), values) == (sum, [0, 1, 2, "z"])
    assert subs_multiple((sum, [0, ["x", "y", "z"]]), values) == (sum, [0, [1, 2, "z"]])
    # Substitution also applies across a whole graph mapping.
    graph = {"a": (sum, ["x", "y"])}
    assert subs_multiple(graph, values) == {"a": (sum, [1, 2])}
    # Keys may themselves be tuples.
    values = {"x": 1, ("y", 0): 2}
    graph = {"a": (sum, ["x", ("y", 0)])}
    assert subs_multiple(graph, values) == {"a": (sum, [1, 2])}
@gen_cluster(client=True, nthreads=[("", 1)] * 10)
async def test_gather_from_workers_missing_replicas(c, s, *workers):
    """When a key is replicated on multiple workers, but the who_has is slightly
    obsolete, gather_from_workers retries fetching from all known holders of a
    replica until it finds the key.
    """
    holder = random.choice(workers)
    # Keep a reference to the future so the scattered key stays alive.
    fut = await c.scatter({"x": 1}, workers=holder.address)
    assert len(s.workers) == 10
    assert len(s.tasks["x"].who_has) == 1
    rpc = await ConnectionPool()
    # Claim (wrongly) that every worker holds a replica; only one actually does.
    out = await gather_from_workers(
        {"x": [w.address for w in workers]}, rpc=rpc
    )
    # (data, missing, failed, bad_workers)
    assert out == ({"x": 1}, [], [], [])
@gen_cluster(client=True)
async def test_gather_from_workers_permissive(c, s, a, b):
    """gather_from_workers fetches multiple keys, of which some are missing.
    Test that the available data is returned with a note for missing data.
    """
    rpc = await ConnectionPool()
    # Keep the future alive so "x" is not released; "y" was never created.
    fut = await c.scatter({"x": 1}, workers=a.address)
    out = await gather_from_workers(
        {"x": [a.address], "y": [b.address]}, rpc=rpc
    )
    # (data, missing, failed, bad_workers)
    assert out == ({"x": 1}, ["y"], [], [])
class BrokenConnectionPool(ConnectionPool):
    """ConnectionPool variant whose every connection attempt yields a BrokenComm.

    Used to simulate a worker that cannot be reached; presumably the returned
    BrokenComm errors when used (see BrokenComm in distributed.utils_test —
    TODO confirm).
    """

    async def connect(self, address, *args, **kwargs):
        return BrokenComm()
@gen_cluster(client=True)
async def test_gather_from_workers_permissive_flaky(c, s, a, b):
    """gather_from_workers fails to connect to a worker"""
    # Keep the future alive so "x" is not released while we try to fetch it.
    fut = await c.scatter({"x": 1}, workers=a.address)
    pool = await BrokenConnectionPool()
    out = await gather_from_workers({"x": [a.address]}, rpc=pool)
    # Nothing retrieved; the unreachable worker is reported among bad_workers.
    assert out == ({}, ["x"], [], [a.address])
@gen_cluster(
    client=True,
    nthreads=[],
    config={"distributed.worker.memory.pause": False},
)
async def test_gather_from_workers_busy(c, s):
    """gather_from_workers receives a 'busy' response from a worker"""
    async with BarrierGetData(s.address, barrier_count=2) as w:
        fut = await c.scatter({"x": 1}, workers=[w.address])
        await wait(fut)
        # Pausing throttles the worker to 1 simultaneous connection, so one of
        # the two concurrent fetches below gets a 'busy' reply and must retry.
        w.status = Status.paused
        pool_a = await ConnectionPool()
        pool_b = await ConnectionPool()
        first, second = await asyncio.gather(
            gather_from_workers({"x": [w.address]}, rpc=pool_a),
            gather_from_workers({"x": [w.address]}, rpc=pool_b),
        )
        assert w.barrier_count == -1  # w.get_data() has been hit 3 times
        assert first == second == ({"x": 1}, [], [], [])
@pytest.mark.parametrize("when", ["pickle", "unpickle"])
@gen_cluster(client=True)
async def test_gather_from_workers_serialization_error(c, s, a, b, when):
    """A task fails to (de)serialize. Tasks from other workers are fetched
    successfully.
    """

    class BadReduce:
        def __reduce__(self):
            if when == "pickle":
                # Explode while pickling on the sending worker.
                1 / 0
            else:
                # Explode while unpickling on the receiving side.
                return lambda: 1 / 0, ()

    rpc = await ConnectionPool()
    futures = [
        c.submit(BadReduce, key="x", workers=[a.address]),
        c.submit(inc, 1, key="y", workers=[a.address]),
        c.submit(inc, 2, key="z", workers=[b.address]),
    ]
    await wait(futures)
    out = await gather_from_workers(
        {"x": [a.address], "y": [a.address], "z": [b.address]}, rpc=rpc
    )
    # Only z, on the healthy worker, is retrieved. x and y were serialized
    # together with a single call to pickle, so we can't tell which raised;
    # both are reported as failed. (data, missing, failed, bad_workers)
    assert out == ({"z": 3}, [], ["x", "y"], [])
def test_retry_no_exception(cleanup):
    """retry returns the coroutine's result untouched when nothing raises."""
    call_count = 0
    sentinel = object()

    async def succeed():
        nonlocal call_count
        call_count += 1
        return sentinel

    async def run():
        return await retry(succeed, count=0, delay_min=-1, delay_max=-1)

    # Identity check: the very object returned by the coroutine comes back.
    assert asyncio_run(run(), loop_factory=get_loop_factory()) is sentinel
    assert call_count == 1
def test_retry0_raises_immediately(cleanup):
    """With count=0, retry gives up after a single failing attempt."""
    call_count = 0

    async def fail():
        nonlocal call_count
        call_count += 1
        raise RuntimeError(f"RT_ERROR {call_count}")

    async def run():
        return await retry(fail, count=0, delay_min=-1, delay_max=-1)

    with pytest.raises(RuntimeError, match="RT_ERROR 1"):
        asyncio_run(run(), loop_factory=get_loop_factory())
    assert call_count == 1
def test_retry_does_retry_and_sleep(cleanup):
    """retry re-invokes the coroutine on the configured exception type and
    backs off between attempts, ramping from delay_min up to delay_max."""
    attempt = 0

    class MyEx(Exception):
        pass

    async def fail():
        nonlocal attempt
        attempt += 1
        raise MyEx(f"RT_ERROR {attempt}")

    recorded_delays = []

    async def fake_sleep(amount):
        # Record the requested delay instead of actually sleeping.
        recorded_delays.append(amount)

    async def run():
        return await retry(
            fail,
            retry_on_exceptions=(MyEx,),
            count=5,
            delay_min=1.0,
            delay_max=6.0,
            jitter_fraction=0.0,
        )

    with mock.patch("asyncio.sleep", fake_sleep):
        with pytest.raises(MyEx, match="RT_ERROR 6"):
            asyncio_run(run(), loop_factory=get_loop_factory())
    # count=5 means 1 initial call + 5 retries, with a sleep before each retry.
    assert attempt == 6
    assert recorded_delays == [0.0, 1.0, 3.0, 6.0, 6.0]
def test_unpack_remotedata():
    """unpack_remotedata replaces WrappedKey objects inside a task with their
    underlying keys and collects the set of WrappedKeys it found, including
    ones hidden inside SubgraphCallables."""

    def assert_eq(keys1: set[WrappedKey], keys2: set[WrappedKey]) -> None:
        # Compare two sets of WrappedKeys by their underlying .key values.
        # The previous version used no-op `assert True`, the
        # `if cond: assert False` anti-pattern, and only type-checked the
        # intersection `keys1 & keys2`, silently skipping elements present on
        # one side only; check the full union instead.
        assert len(keys1) == len(keys2)
        if not keys1:
            return  # both empty; nothing further to compare
        assert all(isinstance(k, WrappedKey) for k in keys1 | keys2)
        assert sorted(k.key for k in keys1) == sorted(k.key for k in keys2)

    # Plain values pass through unchanged, with no keys collected.
    assert unpack_remotedata(1) == (1, set())
    assert unpack_remotedata(()) == ((), set())

    # A bare WrappedKey unpacks to its key.
    res, keys = unpack_remotedata(WrappedKey("mykey"))
    assert res == "mykey"
    assert_eq(keys, {WrappedKey("mykey")})

    # Check unpack of SC that contains a wrapped key
    sc = SubgraphCallable({"key": (WrappedKey("data"),)}, outkey="key", inkeys=["arg1"])
    dsk = (sc, "arg1")
    res, keys = unpack_remotedata(dsk)
    assert res[0] != sc  # Notice, the first item (the SC) has been changed
    assert res[1:] == ("arg1", "data")
    assert_eq(keys, {WrappedKey("data")})

    # Check unpack of SC when it takes a wrapped key as argument
    sc = SubgraphCallable({"key": ("arg1",)}, outkey="key", inkeys=[WrappedKey("arg1")])
    dsk = (sc, "arg1")
    res, keys = unpack_remotedata(dsk)
    assert res == (sc, "arg1")  # Notice, the first item (the SC) has NOT been changed
    assert_eq(keys, set())