# Copyright 2022 The PyGlove Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tuning protocols."""
import abc
import contextlib
import datetime
import time
import traceback
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
from pyglove.core import geno
from pyglove.core import logging
from pyglove.core import object_utils
from pyglove.core import symbolic
from pyglove.core import typing as pg_typing


class _DataEntity(symbolic.Object):
"""Base class for object that is used as data entity."""
# Allow assignment on symbolic attributes.
  allow_symbolic_assignment = True

def __hash__(self):
"""Hash code."""
    return hash(repr(self))


@symbolic.members([
('step', pg_typing.Int(), 'At which step the result is reported.'),
    ('elapse_secs', pg_typing.Float(),
     'Elapsed seconds since the trial started.'),
('reward', pg_typing.Float().noneable(),
('Reward of reported tunable target. Can be None if multi-objective '
'optimization is used.')),
('metrics', pg_typing.Dict([
(pg_typing.StrKey(), pg_typing.Float(), 'Metric item.')
]).noneable(), 'Metric in key/value pairs (optional).'),
('checkpoint_path', pg_typing.Str().noneable(),
'Path to the checkpoint of this specific measurement.')
])
class Measurement(_DataEntity):
"""Measurement of a trial at certain step."""
@symbolic.members([
('id', pg_typing.Int(), 'Identifier of the trial.'),
('description', pg_typing.Str().noneable(), 'Description of the trial.'),
('dna', pg_typing.Object(geno.DNA), 'Proposed DNA for the trial.'),
('status',
pg_typing.Enum('PENDING', [
'UNKNOWN',
'REQUESTED',
'PENDING',
'COMPLETED',
'DELETED',
'STOPPING',
]), 'Trial status.'),
('final_measurement', pg_typing.Object(Measurement).noneable(),
'Reported final results.'),
('infeasible', pg_typing.Bool(False), 'Whether trial is infeasible.'),
('measurements', pg_typing.List(pg_typing.Object(Measurement), default=[]),
'All reported measurements.'),
('metadata',
pg_typing.Dict(
[(pg_typing.StrKey(), pg_typing.Any(),
'Serializable key value pairs as metadata.')]),
'Trial metadata.'),
('related_links',
pg_typing.Dict([
(pg_typing.StrKey(), pg_typing.Str(), 'Related link.')]),
     'Related links.'),
# TODO(daiyip): consider change time from timestamp to datetime.
# Need to introduce a mechanism in symbolic to cherry pick serialization
# for individual fields.
('created_time', pg_typing.Int(), 'Created time in Unix timestamp.'),
('completed_time', pg_typing.Int().noneable(),
'Completed time in Unix timestamp.'),
])
class Trial(_DataEntity):
"""Metadata of a trial."""

  def get_reward_for_feedback(
      self, metric_names: Optional[Sequence[str]] = None
  ) -> Union[None, float, Tuple[float, ...]]:
"""Get reward for feedback."""
if self.status != 'COMPLETED' or self.infeasible:
return None
assert self.final_measurement is not None
measurement = self.final_measurement
if metric_names is None:
return measurement.reward
assert metric_names, metric_names
metric_values = []
for metric_name in metric_names:
if metric_name == 'reward':
v = measurement.reward
else:
v = measurement.metrics.get(metric_name, None)
if v is None:
raise ValueError(
f'Metric {metric_name!r} does not exist in final '
f'measurement {measurement!r} in trial {self.id}.')
metric_values.append(v)
return tuple(metric_values) if len(metric_values) > 1 else metric_values[0]
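
  # A worked sketch of `get_reward_for_feedback` (hypothetical trial values):
  # for a completed, feasible trial whose final measurement has reward=0.9 and
  # metrics={'latency': 2.0},
  #
  #   trial.get_reward_for_feedback()                      # -> 0.9
  #   trial.get_reward_for_feedback(['reward', 'latency']) # -> (0.9, 2.0)
  #   trial.get_reward_for_feedback(['latency'])           # -> 2.0
  #
  # A single metric name yields a float rather than a 1-tuple, and None is
  # returned if the trial is not completed or is infeasible.

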
class Result(object_utils.Formattable):
"""Interface for tuning result."""
@property
@abc.abstractmethod
def metadata(self) -> Dict[str, Any]:
"""Returns the metadata of current sampling."""
@property
@abc.abstractmethod
def is_active(self) -> bool:
"""Returns whether the tuner task is active."""
@property
@abc.abstractmethod
def last_updated(self) -> datetime.datetime:
"""Last updated time."""
@property
@abc.abstractmethod
def trials(self) -> List[Trial]:
"""Retrieve all trials."""
@property
@abc.abstractmethod
def best_trial(self) -> Optional[Trial]:
"""Get best trial so far."""
class Feedback(metaclass=abc.ABCMeta):
"""Interface for the feedback object for a trial.
  Feedback is an agent for communicating with the search algorithm and other
  workers about the current trial. It includes the following (a usage sketch
  follows this list):
* Information about current example:
    * :attr:`id`: The ID of the current example, starting from 1.
* :attr:`dna`: The DNA for current example.
* Methods to communicate with the search algorithm:
* :meth:`add_measurement`: Add a measurement for current example.
Multiple measurements can be added as progressive evaluation of the
example, which can be used by the early stopping policy to suggest
whether current evaluation can be stopped early.
    * :meth:`done`: Mark evaluation on current example as done, feed the
      reward from the latest measurement back to the algorithm, and move
      to the next example.
* :meth:`__call__`: A shortcut method that calls :meth:`add_measurement`
and :meth:`done` in sequence.
* :meth:`skip`: Mark evaluation on current example as done, and move to
the next example without providing feedback to the algorithm.
* :meth:`should_stop_early`: Tell if progressive evaluation on current
example can be stopped early.
* :meth:`end_loop`: Mark the loop as done. All workers will get out of
the loop after they finish evaluating their current examples.
* Methods to publish information associated with current trial:
* :meth:`set_metadata`: Set persistent metadata by key.
* :meth:`get_metadata`: Get persistent metadata by key.
* :meth:`add_link`: Add a related link by key.
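
  A minimal usage sketch (assuming a `pg.sample` search loop; `evaluate` is an
  illustrative user-defined function)::

    for example, feedback in pg.sample(...):
      reward = evaluate(example)
      feedback(reward)  # Calls `add_measurement` then `done`.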
"""
def __init__(self, metrics_to_optimize: Sequence[str]):
super().__init__()
self._metrics_to_optimize = metrics_to_optimize
self._sample_time = time.time()
@property
@abc.abstractmethod
def id(self) -> int:
"""Gets the ID of current trial."""
@property
@abc.abstractmethod
def dna(self) -> geno.DNA:
"""Gets the DNA of the example used in current trial."""
@property
@abc.abstractmethod
def checkpoint_to_warm_start_from(self) -> Optional[str]:
"""Gets checkpoint path to warm start from."""
def add_measurement(
self,
reward: Union[None, float, Sequence[float]] = None,
metrics: Optional[Dict[str, float]] = None,
step: int = 0,
checkpoint_path: Optional[str] = None,
elapse_secs: Optional[float] = None) -> None:
"""Add a measurement for current trial.
This method can be called multiple times on the same trial, e.g::
for model, feedback in pg.sample(...):
accuracy = train_and_evaluate(model, step=10)
feedback.add_measurement(accuracy, step=10)
accuracy = train_and_evaluate(model, step=15)
feedback.add_measurement(accuracy, step=25)
feedback.done()
Args:
      reward: An optional float value as the reward for single-objective
        optimization, or a sequence of float values for multi-objective
        optimization. In the multi-objective scenario, the float sequence is
        paired up with the `metrics_to_optimize` argument of `pg.sample`, so
        their lengths must be equal.
        Rewards for multiple objectives can also be provided through the
        `metrics` argument, a dict using metric names as keys and their
        measures as values (each key should match an element of
        `metrics_to_optimize`). When a multi-objective reward is provided via
        both the `reward` argument (as a sequence of floats) and the
        `metrics` argument, the values must agree with each other.
metrics: An optional dictionary of string to float as metrics. It can
be used to provide metrics for multi-objective optimization, and/or
carry additional metrics for study analysis.
      step: An optional integer as the step (e.g. a training step) at which
        the measurement applies. When a trial is completed, the measurement
        at the largest step will be chosen as the final measurement to feed
        back to the controller.
checkpoint_path: An optional string as the checkpoint path produced
from the evaluation (e.g. training a model), which can be used in
transfer learning.
elapse_secs: Time spent on evaluating current example so far. If None,
it will be automatically computed by the backend.
"""
metrics_to_optimize = self._metrics_to_optimize
metrics = metrics or {}
if isinstance(reward, (list, tuple)):
rewards = reward
if len(rewards) != len(metrics_to_optimize):
        raise ValueError(
            f'The number of items in the reward ({rewards!r}) does not '
            f'match the number of metrics to optimize '
            f'({metrics_to_optimize!r}).')
for k, v in zip(metrics_to_optimize, rewards):
if k in metrics and metrics[k] != v:
raise ValueError(
f'The value for metric {k} is provided from both the \'reward\' '
f'and the \'metrics\' arguments with different values: '
f'{[v, metrics[k]]!r}.')
metrics[k] = v
reward = metrics.pop('reward', None)
elif reward is not None:
reward = float(reward)
for metric_name in metrics_to_optimize:
if metric_name == 'reward':
if reward is None:
raise ValueError(
'\'reward\' must be provided as it is a goal to optimize.')
elif metric_name in metrics:
metrics[metric_name] = float(metrics[metric_name])
else:
raise ValueError(
f'Metric {metric_name!r} must be provided '
f'as it is a goal to optimize.')
if len(metrics_to_optimize) == 1 and metrics_to_optimize[0] != 'reward':
if reward is None:
reward = metrics[metrics_to_optimize[0]]
else:
raise ValueError(
f'\'reward\' {reward!r} is provided while it is '
f'not a goal to optimize.')
if elapse_secs is None:
elapse_secs = time.time() - self._sample_time
self._add_measurement(
reward, metrics, step, checkpoint_path, elapse_secs)
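
  # A multi-objective sketch (hypothetical metric names), assuming `pg.sample`
  # was called with `metrics_to_optimize=['accuracy', 'latency']`. The two
  # calls below are equivalent:
  #
  #   feedback.add_measurement(reward=[0.9, 2.0], step=100)
  #   feedback.add_measurement(metrics={'accuracy': 0.9, 'latency': 2.0},
  #                            step=100)
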
def _add_measurement(
self,
reward: Optional[float],
metrics: Dict[str, float],
step: int,
checkpoint_path: Optional[str],
elapse_secs: float) -> None:
"""Child class should implement."""
raise NotImplementedError()
@abc.abstractmethod
def get_trial(self) -> Trial:
"""Gets current Trial.
Returns:
An up-to-date `Trial` object. A distributed tuning backend should make
sure the return value is up-to-date not only locally, but among different
workers.
"""
@abc.abstractmethod
def set_metadata(self, key: str, value: Any, per_trial: bool = True) -> None:
"""Sets metadata for current trial or current sampling.

    Metadata can be used in two ways:

    * Worker processes that co-work on the same trial can use metadata to
      communicate with each other.
    * A worker can use metadata as a persistent store for information about
      the current trial, which can be retrieved later via the `poll_result`
      method.

Args:
key: A string as key to metadata.
value: A value that can be serialized by `pg.to_json_str`.
per_trial: If True, the key is set per current trial. Otherwise, it
is set per current sampling loop.
"""
@abc.abstractmethod
def get_metadata(self, key: str, per_trial: bool = True) -> Optional[Any]:
"""Gets metadata for current trial or current sampling.
Args:
key: A string as key to metadata.
      per_trial: If True, the key is retrieved per current trial. Otherwise,
        it is retrieved per current sampling.

Returns:
A value that can be deserialized by `pg.from_json_str`.
"""
@abc.abstractmethod
def add_link(self, name: str, url: str) -> None:
"""Adds a related link to current trial.
Added links can be retrieved from the `Trial.related_links` property via
`pg.poll_result`.
Args:
name: Name for the related link.
url: URL for this link.
"""
@abc.abstractmethod
def done(self,
metadata: Optional[Dict[str, Any]] = None,
related_links: Optional[Dict[str, str]] = None) -> None:
"""Marks current trial as done.
Args:
metadata: Additional metadata to add to current trial.
related_links: Additional links to add to current trial.
"""
@abc.abstractmethod
def skip(self, reason: Optional[str] = None) -> None:
"""Move to next example without providing the feedback to the algorithm."""
@abc.abstractmethod
def should_stop_early(self) -> bool:
"""Whether progressive evaluation can be stopped early on current trial."""
@abc.abstractmethod
def end_loop(self) -> None:
"""Ends current sapling loop."""
def __call__(
self,
reward: Union[None, float, Sequence[float]] = None,
metrics: Optional[Dict[str, float]] = None,
checkpoint_path: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
related_links: Optional[Dict[str, str]] = None,
step: int = 0) -> None:
"""Adds a measurement and marks the trial as done."""
self.add_measurement(
reward, metrics, step=step, checkpoint_path=checkpoint_path)
self.done(metadata=metadata, related_links=related_links)
  def skip_on_exceptions(
      self, exceptions: Sequence[
          Union[Type[Exception], Tuple[Type[Exception], str]]]):
"""Returns a context manager to skip trial on user-specified exceptions.
    Usages::

      with feedback.skip_on_exceptions((ValueError, KeyError)):
        ...

      with feedback.skip_on_exceptions(((ValueError, 'bad value for .*'),
                                        (ValueError, '.* invalid range'),
                                        TypeError)):
        ...

    Args:
      exceptions: A sequence of exception types, or tuples of (exception
        type, regular expression) for matching the error message.

Returns:
A context manager for skipping trials on user-specified exceptions.
"""
def skip_on_exception(unused_error):
error_stack = traceback.format_exc()
logging.warning('Skipping trial on unhandled exception: %s', error_stack)
self.skip(error_stack)
return object_utils.catch_errors(exceptions, skip_on_exception)
@contextlib.contextmanager
def ignore_race_condition(self):
"""Context manager for ignoring RaceConditionError within the scope.
    Race conditions may happen when multiple workers are working on the same
    trial (e.g. paired train/eval processes). Assuming there are two
    co-workers (X and Y), common race conditions are:

    1) Both X and Y call `feedback.done` or `feedback.skip` on the same trial.
    2) X calls `feedback.done`/`feedback.skip`, then Y calls
       `feedback.add_measurement`.

Users can use this context manager to simplify the code for handling
multiple co-workers. (See the `group` argument of `pg.sample`)
    Usages::

      feedback = ...

      def thread_fun():
        with feedback.ignore_race_condition():
          feedback.add_measurement(0.1)

          # Multiple workers working on the same trial might trigger this
          # code from different processes.
          feedback.done()

      x = threading.Thread(target=thread_fun)
      x.start()

      y = threading.Thread(target=thread_fun)
      y.start()

Yields:
None.
"""
try:
yield
except RaceConditionError:
      pass


class RaceConditionError(RuntimeError):
"""Race condition error.

  This error is raised when operations made to `Feedback` indicate a race
  condition. Such race conditions may arise among multiple co-workers (take X
  and Y for example) working on the same trial:

  * X calls `feedback.done`/`feedback.skip`, then Y calls
    `feedback.add_measurement`.
  """