-
Notifications
You must be signed in to change notification settings - Fork 279
/
meter_test_utils.py
328 lines (276 loc) · 11.1 KB
/
meter_test_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import multiprocessing
import queue
import tempfile
import unittest
import torch
# Commands understood by a meter worker. Each message placed on a worker's
# input queue is a (signal, payload) tuple whose first element is one of these.
UPDATE_SIGNAL = 0  # payload is (model_output, target); the worker calls meter.update
VALUE_SIGNAL = 1  # the worker syncs its meter and puts meter.value on its output queue
SHUTDOWN_SIGNAL = 2  # the worker leaves its command loop
# Timeout (seconds) passed to Queue.get() when waiting on a worker's output
# queue or error queue.
TIMEOUT = 100
def _get_value_or_raise_error(qout, qerr):
    """
    Return the next item from ``qout``, waiting up to ``TIMEOUT`` seconds.

    If nothing arrives in time, an item is fetched from ``qerr`` (again
    waiting up to ``TIMEOUT`` seconds) and raised. If ``qerr`` is empty as
    well, the ``queue.Empty`` from that second wait propagates to the caller.
    """
    try:
        result = qout.get(timeout=TIMEOUT)
    except queue.Empty:
        # No value was produced; surface whatever the worker reported instead.
        worker_error = qerr.get(timeout=TIMEOUT)
        raise worker_error
    else:
        return result
def _run(qin, qout, qerr, func, *args):
try:
func(qin, qout, *args)
except Exception as e:
print(e)
qerr.put(e)
def _meter_worker(qin, qout, meter, is_train, world_size, rank, filename):
    """
    Worker loop that drives ``meter`` from commands received on ``qin``.

    Joins a "gloo" process group through a ``file://`` init method built from
    ``filename``, then handles ``(signal, payload)`` messages until shutdown:

    * ``UPDATE_SIGNAL``: ``meter.update(payload[0], payload[1], is_train=is_train)``
    * ``VALUE_SIGNAL``: ``meter.sync_state()`` and put ``meter.value`` on ``qout``
    * ``SHUTDOWN_SIGNAL``: return

    Raises:
        NotImplementedError: if the signal is none of the above.
    """
    rendezvous = "file://{filename}".format(filename=filename)
    torch.distributed.init_process_group(
        backend="gloo",
        init_method=rendezvous,
        world_size=world_size,
        rank=rank,
    )

    # Serve commands from the input queue until asked to shut down.
    while True:
        try:
            command = qin.get()
        except queue.Empty:
            # Nothing to process; wait for the next command.
            continue
        signal, val = command

        if signal == SHUTDOWN_SIGNAL:
            return
        if signal == UPDATE_SIGNAL:
            meter.update(val[0], val[1], is_train=is_train)
            continue
        if signal == VALUE_SIGNAL:
            meter.sync_state()
            qout.put(meter.value)
            continue
        raise NotImplementedError("Bad signal value")
class ClassificationMeterTest(unittest.TestCase):
    """
    Base test case providing reusable checks for meter implementations.

    The helpers only touch the following members of a meter: ``update``,
    ``reset``, ``validate``, ``sync_state``, ``value``, ``get_classy_state``
    and ``set_classy_state``. Worker processes started through ``_spawn`` are
    tracked in ``self.processes`` and cleaned up in ``tearDown``.
    """

    def setUp(self):
        # All worker processes are created from a "spawn" multiprocessing context.
        self.mp = multiprocessing.get_context("spawn")
        self.processes = []

    def tearDown(self):
        """
        Terminate every worker process started by this test, then join it.
        """
        for p in self.processes:
            p.terminate()
        # Join after terminating so that exited workers are reaped rather
        # than left behind for the lifetime of the test runner.
        for p in self.processes:
            p.join()

    def _spawn(self, func, *args):
        """
        Start a daemon process running ``_run(qin, qout, qerr, func, *args)``.

        Returns:
            The ``(qin, qout, qerr)`` queues created for the new process.
        """
        name = "process #%d" % len(self.processes)
        qin = self.mp.Queue()
        qout = self.mp.Queue()
        qerr = self.mp.Queue()
        qio = (qin, qout, qerr)
        args = qio + (func,) + args
        process = self.mp.Process(target=_run, name=name, args=args, daemon=True)
        process.start()
        self.processes.append(process)
        return qio

    def _apply_updates_and_test_meter(
        self, meter, model_output, target, expected_value, **kwargs
    ):
        """
        Runs a valid meter test. Does not reset meter before / after running.

        ``model_output`` and ``target`` may each be a single item or a list;
        single items are wrapped in a list. Every pair is passed to
        ``meter.update`` (with ``**kwargs``), the meter is synced, and each
        entry of ``expected_value`` is compared against ``meter.value``.
        """
        if not isinstance(model_output, list):
            model_output = [model_output]

        if not isinstance(target, list):
            target = [target]

        # Index into target (rather than zip) so a too-short target list
        # fails loudly instead of silently dropping updates.
        for i, output in enumerate(model_output):
            meter.update(output, target[i], **kwargs)

        meter.sync_state()
        meter_value = meter.value
        for key, val in expected_value.items():
            self.assertIn(
                key, meter_value, msg="{0} not in meter value!".format(key)
            )
            if torch.is_tensor(meter_value[key]):
                self.assertTrue(
                    torch.all(torch.eq(meter_value[key], val)),
                    msg="{0} meter value mismatch!".format(key),
                )
            else:
                self.assertAlmostEqual(
                    meter_value[key],
                    val,
                    places=4,
                    msg="{0} meter value mismatch!".format(key),
                )

    def _values_match_expected_value(self, value0, value1, expected_value):
        """
        Assert that both ``value0`` and ``value1`` agree with ``expected_value``.

        Tensor expectations are compared element-wise for equality; anything
        else is compared with ``assertAlmostEqual`` to 4 places.
        """
        for key, val in expected_value.items():
            self.assertIn(key, value0, msg="{0} not in meter value!".format(key))
            self.assertIn(key, value1, msg="{0} not in meter value!".format(key))
            if torch.is_tensor(val):
                for value in (value0, value1):
                    self.assertTrue(
                        torch.all(torch.eq(value[key], val)),
                        "{0} meter value mismatch!".format(key),
                    )
            else:
                for value in (value0, value1):
                    self.assertAlmostEqual(
                        value[key],
                        val,
                        places=4,
                        msg="{0} meter value mismatch!".format(key),
                    )

    def meter_update_and_reset_test(
        self, meter, model_outputs, targets, expected_value, **kwargs
    ):
        """
        This test verifies that a single update on the meter is successful,
        resets the meter, then applies the update again.
        """
        # If a single output is provided, wrap in list
        if not isinstance(model_outputs, list):
            model_outputs = [model_outputs]
            targets = [targets]

        for i, output in enumerate(model_outputs):
            meter.validate(output.size(), targets[i].size())

        self._apply_updates_and_test_meter(
            meter, model_outputs, targets, expected_value, **kwargs
        )

        meter.reset()

        # Verify reset works by reusing single update test
        self._apply_updates_and_test_meter(
            meter, model_outputs, targets, expected_value, **kwargs
        )

    def meter_invalid_meter_input_test(self, meter, model_output, target):
        """
        Verifies that ``meter.validate`` raises ``AssertionError`` for the
        shapes of ``model_output`` and ``target``.
        """
        # Invalid model
        with self.assertRaises(AssertionError):
            meter.validate(model_output.shape, target.shape)

    def meter_invalid_update_test(self, meter, model_output, target, **kwargs):
        """
        Verifies that applying the given update(s) raises ``AssertionError``.

        ``model_output`` and ``target`` may each be a single item or a list.
        The test passes as soon as any ``meter.update`` call raises. Does not
        reset meter before / after running.
        """
        if not isinstance(model_output, list):
            model_output = [model_output]

        if not isinstance(target, list):
            target = [target]

        with self.assertRaises(AssertionError):
            for i, output in enumerate(model_output):
                meter.update(output, target[i], **kwargs)

    def meter_get_set_classy_state_test(
        self, meters, model_outputs, targets, expected_value, **kwargs
    ):
        """
        Tests get and set classy state methods of meter.

        Two meters receive different updates and must report different
        values; after meter1's state is loaded into meter0, meter0's value
        must match meter1's value and ``expected_value``.
        """
        assert len(meters) == 2, "Incorrect number of meters passed to test"
        assert (
            len(model_outputs) == 2
        ), "Incorrect number of model_outputs passed to test"
        assert len(targets) == 2, "Incorrect number of targets passed to test"
        meter0 = meters[0]
        meter1 = meters[1]

        meter0.update(model_outputs[0], targets[0], **kwargs)
        meter1.update(model_outputs[1], targets[1], **kwargs)

        meter0.sync_state()
        value0 = meter0.value

        meter1.sync_state()
        value1 = meter1.value
        # Before the state transfer the two meters must disagree.
        for key, val in value0.items():
            if torch.is_tensor(value1[key]):
                self.assertFalse(
                    torch.all(torch.eq(value1[key], val)),
                    msg="{0} meter values should not be same!".format(key),
                )
            else:
                self.assertNotEqual(
                    value1[key],
                    val,
                    msg="{0} meter values should not be same!".format(key),
                )

        meter0.set_classy_state(meter1.get_classy_state())
        value0 = meter0.value
        # After the transfer meter0 must mirror meter1 and the ground truth.
        for key, val in value0.items():
            if torch.is_tensor(value1[key]):
                self.assertTrue(
                    torch.all(torch.eq(value1[key], val)),
                    msg="{0} meter value mismatch after state transfer!".format(key),
                )
                self.assertTrue(
                    torch.all(torch.eq(value1[key], expected_value[key])),
                    msg="{0} meter value mismatch from ground truth!".format(key),
                )
            else:
                self.assertAlmostEqual(
                    value1[key],
                    val,
                    places=4,
                    msg="{0} meter value mismatch after state transfer!".format(key),
                )
                self.assertAlmostEqual(
                    value1[key],
                    expected_value[key],
                    places=4,
                    msg="{0} meter value mismatch from ground truth!".format(key),
                )

    def _spawn_all_meter_workers(self, world_size, meters, is_train):
        """
        Start one ``_meter_worker`` process per rank, all sharing one init file path.

        Returns:
            Three lists ``(qins, qouts, qerrs)`` indexed by rank.
        """
        # Only a unique path is needed here. The temporary file is closed and
        # removed explicitly when the with-block exits, and the path is then
        # handed to the workers for their file:// init method.
        with tempfile.NamedTemporaryFile(delete=True) as tmp:
            filename = tmp.name

        qins = []
        qerrs = []
        qouts = []

        for rank in range(world_size):
            qin, qout, qerr = self._spawn(
                _meter_worker, meters[rank], is_train, world_size, rank, filename
            )
            qins.append(qin)
            qouts.append(qout)
            qerrs.append(qerr)

        return qins, qouts, qerrs

    def _request_values(self, qins, qouts, qerrs):
        """
        Send VALUE_SIGNAL to every worker, then collect the replies in rank order.
        """
        for qin in qins:
            qin.put_nowait((VALUE_SIGNAL, None))
        return [
            _get_value_or_raise_error(qout, qerr) for qout, qerr in zip(qouts, qerrs)
        ]

    def meter_distributed_test(
        self, meters, model_outputs, targets, expected_values, is_train=False
    ):
        """
        Sets up two processes each with a given meter on that process.
        Verifies that sync code path works.
        """
        world_size = len(meters)
        assert world_size == 2, "This test only works for world_size of 2"
        assert len(model_outputs) == 4, (
            "Test assumes 4 model outputs, "
            "0, 2 passed to meter0 and 1, 3 passed to meter1"
        )
        assert (
            len(targets) == 4
        ), "Test assumes 4 targets, 0, 2 passed to meter0 and 1, 3 passed to meter1"
        assert len(expected_values) == 2, (
            "Test assumes 2 expected values, "
            "first is result of applying updates 0,1 to the meter, "
            "second is result of applying all 4 updates to meter"
        )

        qins, qouts, qerrs = self._spawn_all_meter_workers(
            world_size, meters, is_train=is_train
        )

        # First update each meter, then get value from each meter
        qins[0].put_nowait((UPDATE_SIGNAL, (model_outputs[0], targets[0])))
        qins[1].put_nowait((UPDATE_SIGNAL, (model_outputs[1], targets[1])))

        value0, value1 = self._request_values(qins, qouts, qerrs)
        self._values_match_expected_value(value0, value1, expected_values[0])

        # Verify that calling value again does not break things
        value0, value1 = self._request_values(qins, qouts, qerrs)
        self._values_match_expected_value(value0, value1, expected_values[0])

        # Second, update each meter, then get value from each meter
        qins[0].put_nowait((UPDATE_SIGNAL, (model_outputs[2], targets[2])))
        qins[1].put_nowait((UPDATE_SIGNAL, (model_outputs[3], targets[3])))

        value0, value1 = self._request_values(qins, qouts, qerrs)
        self._values_match_expected_value(value0, value1, expected_values[1])

        qins[0].put_nowait((SHUTDOWN_SIGNAL, None))
        qins[1].put_nowait((SHUTDOWN_SIGNAL, None))