-
Notifications
You must be signed in to change notification settings - Fork 29
/
utils.py
352 lines (274 loc) · 10.6 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
import time
from girder_worker_utils.tee import Tee, tee_stderr, tee_stdout
import requests
from requests import HTTPError
# Disable urllib3 warnings about certificate validation. As they are printed in the console, the
# messages are sent to Girder, creating an infinite loop.
import urllib3
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
# Names of the tasks that Celery registers itself. One name per line, each
# followed by a comma: a missing comma makes Python silently concatenate two
# adjacent string literals into a single bogus entry.
BUILTIN_CELERY_TASKS = [
    'celery.accumulate',
    'celery.backend_cleanup',
    'celery.chain',
    'celery.chord',
    'celery.chord_unlock',
    'celery.chunks',
    'celery.group',
    'celery.map',
    'celery.starmap',
]
def is_builtin_celery_task(task):
    """
    Report whether ``task`` is listed in ``BUILTIN_CELERY_TASKS``.

    :param task: The value to look up; it is compared against the entries
        of ``BUILTIN_CELERY_TASKS``.
    :return: True if ``task`` is one of the listed names, False otherwise.
    :rtype: bool
    """
    found = task in BUILTIN_CELERY_TASKS
    return found
def _maybe_model_repr(obj):
    """
    Return ``obj._repr_model_()`` when the object offers that hook.

    :param obj: Any object.
    :return: The result of calling ``obj._repr_model_()`` if that attribute
        exists and is callable; otherwise ``obj`` itself, unchanged.
    """
    model_repr = getattr(obj, '_repr_model_', None)
    if callable(model_repr):
        return model_repr()
    return obj
# Cached "Inspect" instance for this worker; built on first use by
# _worker_inspector() and reused afterwards.
_inspector = None


def _worker_inspector(task):
    """
    Return the inspector for this worker, creating and caching it on first use.

    The inspector is built once, targeting ``task.request.hostname``, and kept
    in the module-level ``_inspector``; later calls return the cached object
    without looking at ``task``.

    :param task: The task whose ``request.hostname`` identifies the worker.
    :return: The cached inspector instance.
    """
    global _inspector
    if _inspector is not None:
        return _inspector

    try:
        # Celery >= 5
        from .app import app
        make_inspector = app.control.inspect
    except Exception:
        # Celery < 5
        from celery.app.control import Inspect as make_inspector

    _inspector = make_inspector([task.request.hostname])
    return _inspector
def _revoked_tasks(task):
    """
    Get the list of currently revoked tasks for this worker.

    :param task: The task whose ``request.hostname`` identifies the worker.
    :return: The entry for this worker's hostname in the inspector's
        ``revoked()`` reply, or an empty list when the reply is ``None`` or
        has no entry for the hostname.
    """
    reply = _worker_inspector(task).revoked()
    return [] if reply is None else reply.get(task.request.hostname, [])
def deserialize_job_info_spec(**kwargs):
    """
    Build a ``JobManager`` from a job info spec.

    :param kwargs: Keyword arguments passed straight to the ``JobManager``
        constructor.
    :return: The newly created ``JobManager``.
    """
    manager = JobManager(**kwargs)
    return manager
class JobSpecNotFound(Exception):
    """Raised when no job info spec is found on the request, headers or kwargs."""
def _job_manager(request=None, headers=None, kwargs=None):
    """
    Create a job manager from whichever source carries the job info spec.

    Sources are tried in this order:

    1. ``request.jobInfoSpec``
    2. ``headers['jobInfoSpec']`` (the case when called from the revoked
       signal)
    3. ``kwargs['jobInfo']`` -- deprecated. Newer versions of girder pass this
       information automatically as a part of the header metadata in the
       worker scheduler. The key is popped from ``kwargs``.

    :param request: Optional request object; its
        ``girder_client_session_kwargs`` attribute, when present, is forwarded
        to the job manager.
    :param headers: Optional mapping of task headers.
    :param kwargs: Optional dict of task keyword arguments (may be mutated).
    :return: The object returned by ``deserialize_job_info_spec``.
    :raises JobSpecNotFound: If none of the sources holds a job spec.
    """
    session_kwargs = getattr(request, 'girder_client_session_kwargs', {})

    if hasattr(request, 'jobInfoSpec'):
        spec = request.jobInfoSpec
    elif headers is not None and 'jobInfoSpec' in headers:
        # We are being called from revoked signal
        spec = headers['jobInfoSpec']
    elif kwargs and 'jobInfo' in kwargs:
        # Deprecated way of passing the job information.
        spec = kwargs.pop('jobInfo', {})
    else:
        raise JobSpecNotFound

    return deserialize_job_info_spec(
        **spec, girder_client_session_kwargs=session_kwargs)
def _update_status(task, status):
    """
    Push ``status`` to the job through the task's job manager.

    :param task: The task; its ``job_manager.updateStatus`` is called.
    :param status: The status value to pass along.
    """
    manager = task.job_manager
    manager.updateStatus(status)
def is_revoked(task):
    """
    Check whether a task has been revoked.

    :param task: The task.
    :type task: celery.app.task.Task
    :return: True if ``task.request.id`` is in the revoked list for this
        worker, False otherwise.
    :rtype: bool
    """
    revoked_ids = _revoked_tasks(task)
    return task.request.id in revoked_ids
def girder_job(title=None, type='celery', public=False,
               handler='celery_handler', otherFields=None):
    """Decorator factory that tags a girder_worker celery task with
    girder's job metadata.

    The returned decorator stores each value on the task object as a
    ``_girder_job_*`` attribute and returns the same task object.

    :param title: The title of the job in girder.
    :type title: str
    :param type: The type of the job in girder.
    :type type: str
    :param public: Public read access flag for girder.
    :type public: bool
    :param handler: The handler that should process this job. Per the
        original note, jobs using the default 'celery_handler' cannot be
        scheduled in girder.
    :param otherFields: Any additional fields to set on the job in girder;
        an empty dict is used when omitted or falsy.
    :type otherFields: dict
    """
    # Evaluated once per girder_job(...) call, so every task decorated by the
    # same decorator object shares these values.
    metadata = {
        '_girder_job_title': title,
        '_girder_job_type': type,
        '_girder_job_public': public,
        '_girder_job_handler': handler,
        '_girder_job_other_fields': otherFields or {},
    }

    def _girder_job(task_obj):
        for attr_name, attr_value in metadata.items():
            setattr(task_obj, attr_name, attr_value)
        return task_obj

    return _girder_job
class JobStatus:
    """
    Integer status codes used when reporting a job's state.

    NOTE(review): 0-5 look like the core job lifecycle states and 820-824
    like finer-grained worker-specific states -- confirm against the server
    side definitions before relying on that split.
    """

    # Basic lifecycle states.
    INACTIVE = 0
    QUEUED = 1
    RUNNING = 2
    SUCCESS = 3
    ERROR = 4
    CANCELED = 5

    # Finer-grained states.
    FETCHING_INPUT = 820
    CONVERTING_INPUT = 821
    CONVERTING_OUTPUT = 822
    PUSHING_OUTPUT = 823
    CANCELING = 824
class StateTransitionException(Exception):
    """Raised when the server rejects a job status update on the ``status`` field."""
class TeeCustomWrite(Tee):
    """``Tee`` subclass that hands every write to a caller-supplied callable first."""

    def __init__(self, func, *args, **kwargs):
        """
        :param func: Callable invoked with the same arguments as each
            ``write`` call.
        :param args: Positional arguments forwarded to ``Tee.__init__``.
        :param kwargs: Keyword arguments forwarded to ``Tee.__init__``.
        """
        super().__init__(*args, **kwargs)
        self._write_func = func

    def write(self, *args, **kwargs):
        """Call the custom write callable, then delegate to ``Tee.write``."""
        callback = self._write_func
        callback(*args, **kwargs)
        super().write(*args, **kwargs)
@tee_stdout
class TeeStdOutCustomWrite(TeeCustomWrite):
    """``TeeCustomWrite`` wrapped by the ``tee_stdout`` decorator."""
@tee_stderr
class TeeStdErrCustomWrite(TeeCustomWrite):
    """``TeeCustomWrite`` wrapped by the ``tee_stderr`` decorator."""
class JobManager:
    """
    This class can be used to write log messages to Girder by capturing
    stdout/stderr printed within the context and sending them in a
    rate-limited manner to Girder. This is not threadsafe since it changes
    the global values of sys.stdout/sys.stderr.

    It also exposes utilities for updating other job fields such as progress
    and status.
    """

    def __init__(self, logPrint, url, method=None, headers=None, interval=0.5,
                 reference=None, girder_client_session_kwargs=None):
        """
        :param logPrint: Whether print messages should be logged to the job
            log.
        :type logPrint: bool
        :param url: The job update URL.
        :param method: The HTTP method to use when updating the job
            ('PUT' when omitted).
        :param headers: Optional HTTP header dict
        :param interval: Minimum time interval at which to send log updates
            back to Girder over HTTP (seconds).
        :type interval: int or float
        :param reference: optional reference to store with the job.
        :param girder_client_session_kwargs: Optional mapping of attribute
            names to values; each pair is set on the ``requests.Session``
            used for all requests made by this object.
        :type girder_client_session_kwargs: dict or None
        """
        self.logPrint = logPrint
        self.method = method or 'PUT'
        self.url = url
        self.headers = headers or {}
        self.interval = interval
        self.status = None
        self.reference = reference

        self._last = time.time()
        self._buf = b''
        self._progressTotal = None
        self._progressCurrent = None
        self._progressMessage = None

        # One shared session with a retrying adapter for both schemes.
        self._session = requests.Session()
        retryAdapter = requests.adapters.HTTPAdapter(max_retries=10)
        self._session.mount('http://', retryAdapter)
        self._session.mount('https://', retryAdapter)
        if girder_client_session_kwargs:
            for attr, value in girder_client_session_kwargs.items():
                setattr(self._session, attr, value)

        if logPrint:
            # Everything written to stdout/stderr is also passed to
            # self.write.
            self._stdout = TeeStdOutCustomWrite(self.write)
            self._stderr = TeeStdErrCustomWrite(self.write)

    def cleanup(self):
        """
        Close the HTTP session and, if log capturing was enabled, reset the
        stdout/stderr tees.
        """
        self._session.close()
        if self.logPrint:
            self._stdout.reset()
            self._stderr.reset()

    def _flush(self):
        """
        If there are contents in the buffer or any progress information is
        set, send them up to the server. Otherwise (or when no URL is
        configured) this is a no-op.

        The buffer is cleared only after the request succeeded, so the
        pending log text is kept when the request fails.

        :raises requests.HTTPError: If the server answers with an error
            status.
        """
        if not self.url:
            return

        if len(self._buf) or self._progressTotal or self._progressMessage or \
                self._progressCurrent is not None:
            data = {
                'progressTotal': self._progressTotal,
                'progressCurrent': self._progressCurrent,
                'progressMessage': self._progressMessage
            }
            if self._buf:
                data['log'] = self._buf

            resp = self._session.request(
                self.method.upper(), self.url, allow_redirects=True,
                headers=self.headers, data=data)
            resp.raise_for_status()
            self._buf = b''

    def write(self, message, forceFlush=False):
        """
        Append a message to the log for this job. If logPrint is enabled, this
        will be called whenever stdout or stderr is printed to. Otherwise it
        can be called manually and will still perform rate-limited flushing to
        the server.

        :param message: The message to append to the job log. ``str`` values
            are encoded as UTF-8; ``bytes`` are appended as-is.
        :type message: str or bytes
        :param forceFlush: Whether to force the write of this event to the
            server. Useful if you don't expect another update for some time.
        :type forceFlush: bool
        """
        if isinstance(message, str):
            message = message.encode('utf8')

        self._buf += message
        if forceFlush or time.time() - self._last > self.interval:
            self._flush()
            self._last = time.time()

    def updateStatus(self, status):
        """
        Update the status field of a job.

        Nothing is sent when no URL is configured, when ``status`` is None,
        or when it equals the locally recorded status.

        :param status: The status to set on the job.
        :type status: JobStatus
        :raises StateTransitionException: If the server answers 400 with a
            JSON object whose ``field`` is ``'status'``.
        :raises requests.HTTPError: For any other HTTP error response.
        """
        if not self.url or status is None or status == self.status:
            return

        # Ensure that the logs are flushed before the status is changed
        self._flush()
        # The new status is recorded before the request is sent, so it stays
        # set locally even if the request below fails.
        self.status = status

        try:
            resp = self._session.request(
                self.method.upper(), self.url, headers=self.headers,
                data={'status': status}, allow_redirects=True)
            resp.raise_for_status()
        except HTTPError as http_err:
            response = http_err.response
            if response is None or response.status_code != 400:
                raise

            # A 400 whose body is not a JSON object must not mask the
            # original HTTP error with a decoding/type error.
            try:
                payload = response.json()
            except ValueError:
                payload = None

            if isinstance(payload, dict) and payload.get('field') == 'status':
                print(payload['message'])
                raise StateTransitionException(
                    payload['message'], http_err) from http_err
            raise

    def updateProgress(self, total=None, current=None, message=None,
                       forceFlush=False):
        """
        Update the progress information about a job.

        :param total: The total progress value, or None to leave the same.
        :type total: int, float, or None
        :param current: The current progress value, or None to leave the same.
        :type current: int, float, or None
        :param message: Progress message to set, or None to leave the same.
        :type message: str or None
        :param forceFlush: Whether to force the write of this event to the
            server. Useful if you don't expect another update for some time.
        :type forceFlush: bool
        """
        if total is not None:
            self._progressTotal = total
        if current is not None:
            self._progressCurrent = current
        if message is not None:
            self._progressMessage = message

        if forceFlush or time.time() - self._last > self.interval:
            self._flush()
            self._last = time.time()

    def refreshStatus(self):
        """
        Refresh the status field from Girder

        The HTTP status of the response is not checked; the ``status`` key of
        the JSON body is stored and returned.

        :return: The status reported by the server.
        """
        r = self._session.get(self.url, headers=self.headers, allow_redirects=True)
        self.status = r.json()['status']
        return self.status