forked from dask/dask
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_callbacks.py
116 lines (84 loc) · 2.56 KB
/
test_callbacks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from dask.local import get_sync
from dask.context import _globals
from dask.threaded import get as get_threaded
from dask.callbacks import Callback
from dask.utils_test import add
def test_start_callback():
    """The ``_start`` hook fires when a callback is active during a compute."""
    seen = {'started': False}

    class StartRecorder(Callback):
        def _start(self, dsk):
            # Record that the scheduler invoked the start hook.
            seen['started'] = True

    with StartRecorder():
        get_sync({'x': 1}, 'x')

    assert seen['started'] is True
def test_start_state_callback():
    """``_start_state`` receives the graph and the scheduler's initial state."""
    seen = {'start_state': False}

    class StateChecker(Callback):
        def _start_state(self, dsk, state):
            seen['start_state'] = True
            # The graph passed in is the one given to get_sync ...
            assert dsk['x'] == 1
            # ... and the test expects exactly one entry in the state cache
            # when this hook runs (the literal value for 'x').
            assert len(state['cache']) == 1

    with StateChecker():
        get_sync({'x': 1}, 'x')

    assert seen['start_state'] is True
def test_finish_always_called():
    """``_finish`` must run, with ``errored`` truthy, whenever a compute fails.

    Three failure modes are exercised: an ordinary exception under the
    synchronous scheduler, the same exception under the threaded scheduler,
    and a ``KeyboardInterrupt`` (a ``BaseException`` that is not an
    ``Exception``) under the synchronous scheduler.
    """
    flag = [False]

    class MyCallback(Callback):
        def _finish(self, dsk, state, errored):
            flag[0] = True
            assert errored

    def check_finish_called(get, dsk, expected):
        # Run ``dsk`` with scheduler ``get`` and require that ``expected``
        # escapes. Then confirm ``_finish`` was invoked. The ``else`` branch
        # makes the test fail if the compute unexpectedly succeeds, instead
        # of passing without checking anything. Catching only ``expected``
        # lets any other exception propagate with its real traceback.
        flag[0] = False
        try:
            with MyCallback():
                get(dsk, 'x')
        except expected:
            pass
        else:
            raise AssertionError("expected %s to be raised" % expected.__name__)
        assert flag[0]

    dsk = {'x': (lambda: 1 / 0,)}

    # `raise_on_exception=True`
    check_finish_called(get_sync, dsk, ZeroDivisionError)

    # `raise_on_exception=False`
    check_finish_called(get_threaded, dsk, ZeroDivisionError)

    # KeyboardInterrupt
    def raise_keyboard():
        raise KeyboardInterrupt()

    check_finish_called(get_sync, {'x': (raise_keyboard,)}, KeyboardInterrupt)
def test_nested_schedulers():
    """Callbacks stay scoped to the scheduler call that activated them.

    An outer threaded compute runs a task that itself starts an inner
    threaded compute under a different callback. Each callback must see
    only its own graph. No callback may be visible in the global registry
    from inside the running task or after the computes finish.
    """
    class GraphRecorder(Callback):
        def _start(self, dsk):
            # Remember the graph this callback was started with.
            self.dsk = dsk

        def _pretask(self, key, dsk, state):
            # Every task reported here must belong to this callback's graph.
            assert key in self.dsk

    inner_cb = GraphRecorder()
    inner_graph = {'x': (add, 1, 2), 'y': (add, 'x', 3)}

    def run_inner(x):
        # The outer callback must not be visible from inside a running task.
        assert not _globals['callbacks']
        with inner_cb:
            return get_threaded(inner_graph, 'y') + x

    outer_cb = GraphRecorder()
    outer_graph = {'a': (run_inner, 1), 'b': (add, 'a', 2)}

    with outer_cb:
        get_threaded(outer_graph, 'b')

    assert not _globals['callbacks']
    assert outer_cb.dsk == outer_graph
    assert inner_cb.dsk == inner_graph
def test_add_remove_mutates_not_replaces():
    """Entering and exiting a ``Callback`` mutates the registry in place.

    A reference to the ``_globals['callbacks']`` container is captured up
    front. If registration or removal rebound ``_globals['callbacks']`` to a
    new object instead of mutating the existing one, the identity and
    contents checks against that saved reference would fail.
    """
    callbacks = _globals['callbacks']
    assert not callbacks

    with Callback():
        # Registration must land in the very container captured above.
        # Checking only after exit would pass even if both registration and
        # removal replaced the container.
        assert _globals['callbacks'] is callbacks
        assert callbacks

    # Removal must empty that same container, not swap in a new one.
    assert _globals['callbacks'] is callbacks
    assert not callbacks