-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
utils.py
291 lines (231 loc) · 7.98 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Copyright 2017 The Ray Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is adapted from
# https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/utils.py
import collections
from contextlib import closing, contextmanager
import logging
import numpy as np
import socket
import time
import torch.distributed as dist
import torch
from bigdl.dllib.utils.log4Error import *
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Keys of the stats dict produced by AverageMeterCollection.summary().
BATCH_COUNT = "batch_count"
NUM_SAMPLES = "num_samples"
# NOTE(review): the leading "*" is intentional (present in upstream Ray) —
# presumably it marks the metric as internal/hidden; confirm before changing.
BATCH_SIZE = "*batch_size"
class TimerStat:
    """Running timing statistics over a sliding window of code-block durations.

    Use as a context manager to time a block; the elapsed seconds are pushed
    into a bounded window of the most recent ``window_size`` samples, while
    ``count`` and the cumulative total keep growing unbounded.

    Note that this class is *not* thread-safe.

    Examples:
        Time a call to 'time.sleep'.
        >>> import time
        >>> sleep_timer = TimerStat()
        >>> with sleep_timer:
        ...     time.sleep(1)
        >>> round(sleep_timer.mean)
        1
    """
    def __init__(self, window_size=10):
        # Maximum number of samples retained for the windowed statistics.
        self._window_size = window_size
        self._samples = []
        self._units_processed = []
        self._start_time = None
        self._total_time = 0.0
        self.count = 0

    def __enter__(self):
        # A non-None start time means another timing is in flight.
        invalidInputError(self._start_time is None, "concurrent updates not supported")
        self._start_time = time.time()

    def __exit__(self, type, value, tb):
        invalidInputError(self._start_time is not None, "expect start time is not none")
        elapsed = time.time() - self._start_time
        self.push(elapsed)
        self._start_time = None

    def push(self, time_delta):
        """Record one duration sample (seconds), trimming to the window."""
        self._samples.append(time_delta)
        # Keep only the newest window_size samples.
        if len(self._samples) > self._window_size:
            del self._samples[0]
        self.count += 1
        self._total_time += time_delta

    def push_units_processed(self, n):
        """Record how many units the last timed block processed."""
        self._units_processed.append(n)
        if len(self._units_processed) > self._window_size:
            del self._units_processed[0]

    @property
    def mean(self):
        """Mean of the windowed duration samples."""
        return np.mean(self._samples)

    @property
    def median(self):
        """Median of the windowed duration samples."""
        return np.median(self._samples)

    @property
    def sum(self):
        """Sum of the windowed duration samples."""
        return np.sum(self._samples)

    @property
    def max(self):
        """Maximum windowed duration sample."""
        return np.max(self._samples)

    @property
    def first(self):
        """Oldest sample in the window, or None if empty."""
        if not self._samples:
            return None
        return self._samples[0]

    @property
    def last(self):
        """Newest sample in the window, or None if empty."""
        if not self._samples:
            return None
        return self._samples[-1]

    @property
    def size(self):
        """Number of samples currently held in the window."""
        return len(self._samples)

    @property
    def mean_units_processed(self):
        """Mean of the windowed units-processed samples."""
        return float(np.mean(self._units_processed))

    @property
    def mean_throughput(self):
        """Windowed units processed per second (0.0 when no time recorded)."""
        total = sum(self._samples)
        return sum(self._units_processed) / total if total else 0.0

    def reset(self):
        """Discard all samples and counters, returning to the initial state."""
        self._samples = []
        self._units_processed = []
        self._start_time = None
        self._total_time = 0.0
        self.count = 0
@contextmanager
def _nullcontext(enter_result=None):
"""Used for mocking timer context."""
yield enter_result
class TimerCollection:
    """A named registry of TimerStat objects with a global on/off switch."""

    def __init__(self):
        self._enabled = True
        # Timers are created lazily on first access by key.
        self._timers = collections.defaultdict(TimerStat)

    def disable(self):
        """Make record() hand out no-op contexts instead of real timers."""
        self._enabled = False

    def enable(self):
        """Re-enable real timing for subsequent record() calls."""
        self._enabled = True

    def reset(self):
        """Reset every registered timer to its initial state."""
        for t in self._timers.values():
            t.reset()

    def record(self, key):
        """Return the timer context for ``key`` (a no-op when disabled)."""
        return self._timers[key] if self._enabled else _nullcontext()

    def stats(self, mean=True, last=False):
        """Collect mean and/or last durations for every timer that has data."""
        aggregates = {}
        for name, timer in self._timers.items():
            if timer.count == 0:
                continue
            if mean:
                aggregates["mean_%s_s" % name] = timer.mean
            if last:
                aggregates["last_%s_s" % name] = timer.last
        return aggregates
def find_free_port():
    """Return a TCP port number that was free at the time of the call.

    Binds an ephemeral socket to port 0 so the OS picks an unused port,
    reads the chosen port back, and closes the socket.  There is an
    inherent race: the port may be claimed by another process before the
    caller binds to it.

    Returns:
        int: an available TCP port number.
    """
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        # SO_REUSEADDR must be set *before* bind() to take effect on this
        # socket (the previous code set it afterwards, where it was a
        # no-op); it also lets the returned port be rebound immediately
        # even if this socket lingers in TIME_WAIT.
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(("", 0))
        return s.getsockname()[1]
class AverageMeter:
    """Tracks the most recent value and a running weighted average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated state back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
class AverageMeterCollection:
    """Aggregates named AverageMeters plus batch and sample counts."""

    def __init__(self):
        # Meters are created lazily per metric name.
        self._meters = collections.defaultdict(AverageMeter)
        self._batch_count = 0
        self.n = 0

    def update(self, metrics, n=1):
        """Feed one batch of metric values (dict of name -> value) of size n."""
        self._batch_count += 1
        self.n += n
        for name, value in metrics.items():
            self._meters[name].update(value, n=n)

    def summary(self, sync_stats=False, dist_backend=None):
        """Returns a dict of average and most recent values for each metric."""
        stats = {BATCH_COUNT: self._batch_count, NUM_SAMPLES: self.n}
        for name, meter in self._meters.items():
            if sync_stats:
                # All-reduce sums values across workers; dividing by the
                # world size yields the cross-worker mean.
                world_size = dist_backend.get_world_size()
                avg_tensor = torch.tensor(meter.avg)
                dist_backend.all_reduce(avg_tensor)
                val_tensor = torch.tensor(meter.val)
                dist_backend.all_reduce(val_tensor)
                avg = avg_tensor.item() / world_size
                last_val = val_tensor.item() / world_size
            else:
                avg = meter.avg
                last_val = meter.val
            stats[str(name)] = avg
            stats["last_" + str(name)] = last_val
        return stats
def check_for_failure(remote_values):
    """Wait on Ray object refs and report whether all completed successfully.

    Args:
        remote_values (list): List of object IDs representing functions
            that may fail in the middle of execution. For example, running
            a SGD training loop in multiple parallel actor calls.

    Returns:
        Bool for success in executing given remote tasks.
    """
    import ray
    from ray.exceptions import RayActorError
    pending = remote_values
    try:
        # Drain all refs; ray.get re-raises RayActorError when an actor died.
        while pending:
            done, pending = ray.wait(pending)
            ray.get(done)
        return True
    except RayActorError as exc:
        logger.exception(str(exc))
        return False
def override(interface_class):
    """Decorator factory asserting the decorated method overrides a method
    declared on ``interface_class``.

    Args:
        interface_class: the class expected to declare a method of the
            same name as the decorated method.

    Returns:
        A decorator that validates the method name against
        ``interface_class`` and returns the method unchanged.
    """
    def overrider(method):
        # The previous message printed the literal text "method.__name__";
        # interpolate the real method/class names so failures are actionable.
        invalidInputError(method.__name__ in dir(interface_class),
                          "method '%s' doesn't exist in interface class '%s'"
                          % (method.__name__, interface_class.__name__))
        return method
    return overrider
def get_filesystem(filepath):
    """Resolve the fsspec filesystem implementation for ``filepath``.

    Args:
        filepath: a path or URL (local path, "s3://...", "hdfs://...", etc.);
            converted to str before parsing.

    Returns:
        The fsspec filesystem object that handles the path's protocol.
    """
    from fsspec.core import url_to_fs
    filesystem, _path = url_to_fs(str(filepath))
    return filesystem
def get_batchsize(input):
    """Return the batch size (size of dimension 0) of ``input``.

    Recursively descends nested lists/tuples/dicts and reads ``.size(0)``
    from the first tensor-like leaf found.  Assumes every leaf shares the
    same batch dimension, so inspecting the first element suffices.

    Args:
        input: a tensor (anything exposing ``.size(0)``), or an arbitrarily
            nested list/tuple/dict of such objects.  (The parameter name
            shadows the ``input`` builtin but is kept for backward
            compatibility.)

    Returns:
        int: the size of dimension 0 of the first leaf.
    """
    if isinstance(input, (list, tuple)):
        return get_batchsize(input[0])
    elif isinstance(input, dict):
        # Take the first value without materializing the whole values list.
        return get_batchsize(next(iter(input.values())))
    else:
        return input.size(0)