-
Notifications
You must be signed in to change notification settings - Fork 189
/
optimizer_utils.py
570 lines (472 loc) · 21.5 KB
/
optimizer_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
# Copyright 2019, Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Optimizer utilities supporting federated averaging experiments."""
import collections
import inspect
from typing import Any, Callable, Dict, List, Optional
from absl import flags
from absl import logging
import tensorflow as tf
import tensorflow_addons.optimizers as tfao
from utils.optimizers import lars
from utils.optimizers import shampoo
from utils.optimizers import yogi
def _optimizer_canonical_name(optimizer_cls):
"""Return a short, canonical name for an optimizer for us in flags."""
return optimizer_cls.__name__.lower()
# Optimizer classes that can be selected and configured through flags.
_SUPPORTED_OPTIMIZERS_CLS = [
    tf.keras.optimizers.SGD,
    tf.keras.optimizers.Adagrad,
    tf.keras.optimizers.Adam,
    yogi.Yogi,
    lars.LARS,
    tfao.lamb.LAMB,
    shampoo.Shampoo,
]

# Lookup table from canonical (lowercased class) name to optimizer class.
_SUPPORTED_OPTIMIZERS = dict(
    (_optimizer_canonical_name(optimizer_cls), optimizer_cls)
    for optimizer_cls in _SUPPORTED_OPTIMIZERS_CLS)
def define_optimizer_flags(prefix: str) -> None:
  """Defines flags with `prefix` to configure an optimizer.

  This method is intended to be paired with `create_optimizer_fn_from_flags`
  using the same `prefix`, to allow Python binaries to construct TensorFlow
  optimizers parameterized by commandline flags.

  This creates two new flags:
    * `--<prefix>_optimizer=<optimizer name>`
    * `--<prefix>_learning_rate`

  In addition to a suite of flags for each optimizer:
    * `--<prefix>_<optimizer name>_<constructor_argument>`

  For example, given the prefix "client" this will create flags (non-exhaustive
  list):
    * `--client_optimizer`
    * `--client_learning_rate`
    * `--client_sgd_momentum`
    * `--client_sgd_nesterov`
    * `--client_adam_beta_1`
    * `--client_adam_beta_2`
    * `--client_adam_epsilon`

  Then calls to `create_optimizer_fn_from_flags('client')` will construct an
  optimizer of the type named in `--client_optimizer`, parameterized by the
  flags prefixed with the matching optimizer name. For example, if
  `--client_optimizer=sgd`, `--client_sgd_*` flags will be used.

  If `prefix` is empty, no leading underscore is added: the flags are named
  `--optimizer`, `--learning_rate` and `--<optimizer name>_<argument>`.

  IMPORTANT: For flags to be correctly parsed from the commandline, this method
  must be called before `absl.app.run(main)`, and is recommended to be called
  next to other flag definitions at the top of a py_binary.

  Args:
    prefix: A string (possibly empty) indicating which optimizer is being
      configured.

  Raises:
    NotImplementedError: If a constructor argument of a supported optimizer has
      a default/annotation that cannot be mapped to a flag type.
  """

  def top_level_flag_name(basename):
    # Same naming rule as the `prefixed` helper in
    # `create_optimizer_fn_from_flags`, so that an empty prefix produces
    # `optimizer` rather than `_optimizer`.
    return '{!s}_{!s}'.format(prefix, basename) if prefix else basename

  def optimizer_flag_name(optimizer_name, basename):
    if prefix:
      return '{!s}_{!s}_{!s}'.format(prefix, optimizer_name, basename)
    return '{!s}_{!s}'.format(optimizer_name, basename)

  def is_param_of_type(param, typ):
    # A `None` default carries no type information, so fall back to the
    # annotation in that case.
    if param.default is None:
      return param.annotation == Optional[typ]
    return isinstance(param.default, typ)

  def is_str_list_param(param):
    # `isinstance(x, List[str])` raises TypeError for subscripted generics, so
    # list-of-string parameters are detected by inspecting the default itself.
    default = param.default
    if default is None:
      return param.annotation == Optional[List[str]]
    if not isinstance(default, list):
      return False
    if not default:
      # An empty list says nothing about its element type; rely on annotation.
      return param.annotation in (List[str], Optional[List[str]])
    return all(isinstance(item, str) for item in default)

  # Checked in order. `bool` must precede `int` because `bool` is a subclass
  # of `int`.
  scalar_flag_definers = (
      (bool, flags.DEFINE_bool),
      (float, flags.DEFINE_float),
      (int, flags.DEFINE_integer),
      (str, flags.DEFINE_string),
  )

  # Create top-level, non-optimizer specific flags for picking the optimizer
  # type and the learning rate.
  optimizer_flag = top_level_flag_name('optimizer')
  flags.DEFINE_enum(
      name=optimizer_flag,
      default=None,
      enum_values=list(_SUPPORTED_OPTIMIZERS),
      help='The type of optimizer to construct for `{!s}`'.format(prefix))
  logging.info('Defined new flag: [%s]', optimizer_flag)

  learning_rate_flag = top_level_flag_name('learning_rate')
  flags.DEFINE_float(
      name=learning_rate_flag,
      default=None,
      help='Base learning rate for optimizer `{!s}`'.format(prefix))
  logging.info('Defined new flag: [%s]', learning_rate_flag)

  for optimizer_name, optimizer_cls in _SUPPORTED_OPTIMIZERS.items():
    # Pull out the constructor parameters except for `self`.
    constructor_signature = inspect.signature(optimizer_cls.__init__)
    constructor_params = list(constructor_signature.parameters.values())[1:]

    for param in constructor_params:
      # The learning rate has its own top-level flag; var-args carry no type.
      if param.name in ['kwargs', 'args', 'learning_rate']:
        continue

      flag_name = optimizer_flag_name(optimizer_name, param.name)
      define_flag_fn = None
      for typ, definer in scalar_flag_definers:
        if is_param_of_type(param, typ):
          define_flag_fn = definer
          break
      if define_flag_fn is None and is_str_list_param(param):
        define_flag_fn = flags.DEFINE_multi_string
      if define_flag_fn is None:
        raise NotImplementedError('Cannot define flag [{!s}] '
                                  'for parameter [{!s}] of type [{!s}] '
                                  '(default value type [{!s}]) '
                                  'on optimizer [{!s}]'.format(
                                      flag_name, param.name, param.annotation,
                                      type(param.default), optimizer_name))

      define_flag_fn(
          name=flag_name,
          default=param.default,
          help='{!s} argument for the {!s} optimizer.'.format(
              param.name, optimizer_name))
      logging.info('Defined new flag: [%s]', flag_name)
def remove_unused_flags(prefix: str,
                        hparam_dict: Dict[str, Any]) -> collections.OrderedDict:
  """Drops hyperparameters belonging to optimizers that were not selected.

  Companion to `define_optimizer_flags`, which defines a suite of flags for
  every supported optimizer, e.g. for the prefix "client":
    * `--client_optimizer`
    * `--client_learning_rate`
    * `--client_sgd_momentum`
    * `--client_sgd_nesterov`
    * `--client_adam_beta_1`
    * `--client_adam_beta_2`
    * `--client_adam_epsilon`

  When recording hyperparameters only the flags of the optimizer named by
  `--<prefix>_optimizer` are of interest. For example, with
  `--client_optimizer=sgd` and `prefix='client'`, every `client_<other
  optimizer>_*` entry is removed, while entries such as `client_optimizer`,
  `client_learning_rate`, `client_sgd_momentum` and `client_sgd_nesterov` are
  kept. Entries that do not start with `<prefix>_<optimizer>_` for some
  supported optimizer are always kept.

  Args:
    prefix: The prefix used to define optimizer flags, such as via
      `optimizer_utils.define_optimizer_flags(prefix)`. Standard examples
      include `prefix=client` and `prefix=server`.
    hparam_dict: A dictionary of (string, value) pairs corresponding to
      experiment hyperparameters.

  Returns:
    An ordered dictionary with the entries of `hparam_dict`, in their original
    order, minus those whose key starts with `<prefix>_<optimizer>_` for an
    optimizer other than the one selected via `<prefix>_optimizer`.

  Raises:
    ValueError: If `hparam_dict` has no `<prefix>_optimizer` entry, or if that
      entry is empty/None.
  """

  def with_prefix(basename):
    if not prefix:
      return basename
    return '{}_{}'.format(prefix, basename)

  optimizer_key = with_prefix('optimizer')
  if optimizer_key not in hparam_dict:
    raise ValueError('The flag {!s} was not defined.'.format(optimizer_key))

  selected_optimizer = hparam_dict[optimizer_key]
  if not selected_optimizer:
    raise ValueError('The flag {!s} was not set. Unable to determine the '
                     'relevant optimizer.'.format(optimizer_key))

  # Key prefixes (including the trailing underscore) owned by optimizers that
  # were not selected. `str.startswith` accepts a tuple of candidates.
  ignored_key_prefixes = tuple(
      '{}_'.format(with_prefix(name))
      for name in _SUPPORTED_OPTIMIZERS
      if name != selected_optimizer)

  return collections.OrderedDict(
      (name, value)
      for name, value in hparam_dict.items()
      if not name.startswith(ignored_key_prefixes))
def create_optimizer_fn_from_flags(
    prefix: str) -> Callable[..., tf.keras.optimizers.Optimizer]:
  """Returns an optimizer function based on prefixed flags.

  This method is intended to be paired with `define_optimizer_flags` using the
  same `prefix`, to allow Python binaries to construct TensorFlow optimizers
  parameterized by commandline flags.

  This method expects at least two flags to have been defined and set:
    * `--<prefix>_optimizer=<optimizer name>`
    * `--<prefix>_learning_rate`

  In addition to suites of flags for each optimizer:
    * `--<prefix>_<optimizer name>_<constructor_argument>`

  For example, if `prefix='client'` this method first reads the flags:
    * `--client_optimizer`
    * `--client_learning_rate`

  If the optimizer flag is `'sgd'`, then a `tf.keras.optimizers.SGD` optimizer
  is constructed using the values in the flags prefixed with `--client_sgd_`.

  Args:
    prefix: The same string prefix passed to `define_optimizer_flags`.

  Returns:
    A function taking an optional `learning_rate` argument (defaulting to the
    value of `--<prefix>_learning_rate`) and returning a
    `tf.keras.optimizers.Optimizer`.

  Raises:
    ValueError: If the optimizer flag has no value or names an unsupported
      optimizer, if flags belonging to a different optimizer were set, or if
      the learning rate flag was not set.
  """

  def prefixed(basename):
    return '{}_{}'.format(prefix, basename) if prefix else basename

  optimizer_flag_name = prefixed('optimizer')
  # `flags.FLAGS[name]` returns the flag object itself, which is never None;
  # whether the flag was specified has to be read from its value.
  optimizer_name = flags.FLAGS[optimizer_flag_name].value
  if optimizer_name is None:
    raise ValueError('Must specify flag --{!s}'.format(optimizer_flag_name))

  optimizer_cls = _SUPPORTED_OPTIMIZERS.get(optimizer_name)
  if optimizer_cls is None:
    # To support additional optimizers, implement it as a
    # `tf.keras.optimizers.Optimizer` and add to the `_SUPPORTED_OPTIMIZERS`
    # dict.
    logging.error(
        'Unknown optimizer [%s], known optimizers are [%s]. To add '
        'support for an optimizer, add the optimizer class to the '
        'optimizer_utils._SUPPORTED_OPTIMIZERS_CLS list.', optimizer_name,
        list(_SUPPORTED_OPTIMIZERS.keys()))
    raise ValueError('`{!s}` is not a valid optimizer for flag --{!s}, must be '
                     'one of {!s}. See error log for details.'.format(
                         optimizer_name, optimizer_flag_name,
                         list(_SUPPORTED_OPTIMIZERS.keys())))

  def _has_user_value(flag):
    """Check if a commandline flag has a user set value."""
    return flag.present or flag.value != flag.default

  # Validate that the optimizers that weren't picked don't have flag values set.
  # Settings that won't be used likely means there is an expectation gap between
  # the user and the system and we should notify them.
  unused_flag_prefixes = tuple(
      f'{prefixed(name)}_'
      for name in _SUPPORTED_OPTIMIZERS
      if name != optimizer_name)
  mistakenly_set_flags = [
      flag_name for flag_name in flags.FLAGS
      if flag_name.startswith(unused_flag_prefixes) and
      _has_user_value(flags.FLAGS[flag_name])
  ]
  if mistakenly_set_flags:
    raise ValueError('Commandline flags for optimizers other than [{!s}] '
                     '(value of --{!s}) are set. These would be ignored, '
                     'were the flags set by mistake? Flags: {!s}'.format(
                         optimizer_name, optimizer_flag_name,
                         mistakenly_set_flags))

  lr_flag_name = prefixed('learning_rate')
  lr_flag = flags.FLAGS[lr_flag_name]
  if not _has_user_value(lr_flag):
    raise ValueError(
        'Learning rate for {!s} must be set by the flag --{!s} .'.format(
            prefix, lr_flag_name))
  default_lr = lr_flag.value

  # Every flag named `<prefix>_<optimizer>_<arg>` becomes constructor kwarg
  # `<arg>` of the selected optimizer.
  kwarg_flag_prefix = f'{prefixed(optimizer_name)}_'
  kwargs = {
      flag_name[len(kwarg_flag_prefix):]: flags.FLAGS[flag_name].value
      for flag_name in flags.FLAGS
      if flag_name.startswith(kwarg_flag_prefix)
  }
  # The learning rate is supplied explicitly by `optimizer_fn` below.
  kwargs.pop('learning_rate', None)

  def optimizer_fn(learning_rate=default_lr):
    return optimizer_cls(learning_rate=learning_rate, **kwargs)

  return optimizer_fn
def define_lr_schedule_flags(prefix: str) -> None:
  """Defines flags with `prefix` to configure a learning rate schedule.

  This method is intended to be paired with `create_lr_schedule_from_flags`
  with the same `prefix`, to allow Python binaries to construct a learning rate
  schedule from flags.

  This creates five new flags:
    * `--<prefix>_lr_schedule`
    * `--<prefix>_lr_warmup_steps`
    * `--<prefix>_lr_decay_steps`
    * `--<prefix>_lr_decay_rate`
    * `--<prefix>_lr_staircase`

  If `prefix` is empty, no leading underscore is added (e.g. `--lr_schedule`),
  matching the names `create_lr_schedule_from_flags` looks up.

  Note that this should generally be preceded by `define_optimizer_flags`, and
  followed by `create_lr_schedule_from_flags`. This will then create a learning
  rate scheduling function governed by the flags defined herein.

  Args:
    prefix: A string (possibly empty) indicating which optimizer is being
      configured.
  """

  def prefixed(basename):
    return '{}_{}'.format(prefix, basename) if prefix else basename

  # `flags.FLAGS[name]` raises KeyError for an unknown flag and otherwise
  # returns a flag object (never None), so use a membership test to detect that
  # the base learning rate flag has not been defined yet.
  base_lr_flag_name = prefixed('learning_rate')
  if base_lr_flag_name not in flags.FLAGS:
    logging.warning(
        'The flag %s is not defined. It must be defined and set before '
        'calling `create_lr_schedule_from_flags`.', base_lr_flag_name)

  flags.DEFINE_enum(
      prefixed('lr_schedule'),
      default='constant',
      enum_values=['constant', 'exp_decay', 'inv_lin_decay', 'inv_sqrt_decay'],
      help='Type of learning rate decay schedule to use for `{!s}`.'.format(
          prefix))

  flags.DEFINE_integer(
      prefixed('lr_warmup_steps'),
      default=None,
      help='An int number of steps to warm up the `{!s}` learning rate (e.g. '
      'increase linearly from 0 to the base value).'.format(prefix))

  decay_steps_flag_name = prefixed('lr_decay_steps')
  flags.DEFINE_integer(
      decay_steps_flag_name,
      default=None,
      help='An int used to compute the learning rate schedule. '
      'If staircase is set to True, then the learning rate changes every '
      '`{!s}` rounds.'.format(decay_steps_flag_name))

  flags.DEFINE_float(
      prefixed('lr_decay_rate'),
      default=None,
      help='The decay rate of the {!s} learning rate schedule.'.format(prefix))

  flags.DEFINE_bool(
      prefixed('lr_staircase'),
      default=False,
      help='Whether to decay the `{!s}` learning rate at discrete intervals.'
      .format(prefix))
def warmup_and_decay_schedule_builder(base_value, warmup_steps, decay_fn):
  """Builds a schedule that linearly warms up and then follows `decay_fn`.

  Args:
    base_value: The value the warmup ramps up to; `decay_fn` is expected to
      decay from it afterwards.
    warmup_steps: Number of rounds over which the value increases linearly
      (from `base_value / warmup_steps` at round 0 to `base_value`). If None,
      zero or negative, no warmup is applied.
    decay_fn: A 1-arg callable mapping a round number to the decayed value.
      With warmup enabled it receives `round_num - warmup_steps`.

  Returns:
    A 1-arg callable mapping the (unadjusted) current round number to the
    warmed up, then decayed, value.
  """
  if warmup_steps is None or warmup_steps <= 0:
    # No warmup: the schedule is just the decay function.
    def warmup_and_decay_fn(round_num):
      return decay_fn(round_num)

    return warmup_and_decay_fn

  def warmup_and_decay_fn(round_num):
    step = tf.cast(round_num, tf.float32)
    ramped_value = base_value * (step + 1) / warmup_steps
    still_warming_up = tf.less(step, warmup_steps)
    return tf.cond(
        still_warming_up,
        true_fn=lambda: ramped_value,
        false_fn=lambda: decay_fn(step - warmup_steps))

  return warmup_and_decay_fn
def exp_decay_schedule_builder(base_value, decay_steps, decay_rate, staircase):
  """Builds an exponentially decaying schedule.

  The produced value is `base_value * decay_rate ** (round_num / decay_steps)`,
  using floor division for the exponent when `staircase` is True.

  Args:
    base_value: The base value of the quantity to decay over time.
    decay_steps: A positive scalar that governs how much the value decays at a
      given round number.
    decay_rate: A float between 0 and 1 that governs how quickly the decay
      occurs.
    staircase: A boolean. If set to True, the decaying occurs in discrete
      intervals.

  Returns:
    A 1-arg callable that produces a decayed version of the base value when
    passed the current round_num.
  """
  # Pick the continuous or discrete notion of "elapsed decay periods" once, at
  # build time.
  if staircase:

    def elapsed_periods(step):
      return step // decay_steps
  else:

    def elapsed_periods(step):
      return step / decay_steps

  def exp_decay_fn(round_num):
    step = tf.cast(round_num, tf.float32)
    return base_value * tf.pow(decay_rate, elapsed_periods(step))

  return exp_decay_fn
def inv_lin_schedule_builder(base_value, decay_steps, decay_rate, staircase):
  """Builds a schedule with inverse linear decay.

  The produced value is
  `base_value / (1 + decay_rate * (round_num / decay_steps))`, using floor
  division for the inner quotient when `staircase` is True.

  Args:
    base_value: The base value of the quantity to decay over time.
    decay_steps: A positive scalar that governs how much the value decays at a
      given round number.
    decay_rate: A positive scalar that governs how quickly the decay occurs.
    staircase: A boolean. If set to True, the decaying occurs in discrete
      intervals.

  Returns:
    A 1-arg callable that produces a decayed version of the base value when
    passed the current round_num.
  """
  # Pick the continuous or discrete notion of "elapsed decay periods" once, at
  # build time.
  if staircase:

    def elapsed_periods(step):
      return step // decay_steps
  else:

    def elapsed_periods(step):
      return step / decay_steps

  def inv_lin_decay_fn(round_num):
    step = tf.cast(round_num, tf.float32)
    return base_value / (1.0 + decay_rate * elapsed_periods(step))

  return inv_lin_decay_fn
def inv_sqrt_schedule_builder(base_value, decay_steps, decay_rate, staircase):
  """Builds a schedule with inverse square root decay.

  The produced value is
  `base_value / sqrt(1 + decay_rate * (round_num / decay_steps))`, using floor
  division for the inner quotient when `staircase` is True.

  Args:
    base_value: The base value of the quantity to decay over time.
    decay_steps: A positive scalar that governs how much the value decays at a
      given round number.
    decay_rate: A positive scalar that governs how quickly the decay occurs.
    staircase: A boolean. If set to True, the decaying occurs in discrete
      intervals.

  Returns:
    A 1-arg callable that produces a decayed version of the base value when
    passed the current round_num.
  """
  # Pick the continuous or discrete notion of "elapsed decay periods" once, at
  # build time.
  if staircase:

    def elapsed_periods(step):
      return step // decay_steps
  else:

    def elapsed_periods(step):
      return step / decay_steps

  def inv_sqrt_decay_fn(round_num):
    step = tf.cast(round_num, tf.float32)
    return base_value / tf.sqrt(1.0 + decay_rate * elapsed_periods(step))

  return inv_sqrt_decay_fn
def create_lr_schedule_from_flags(
    prefix: str) -> Callable[[tf.Tensor], tf.Tensor]:
  """Returns a callable learning rate schedule based on prefix flags.

  This method is intended to be paired with `define_lr_schedule_flags` using
  the same `prefix`, to construct a callable learning rate schedule
  parameterized by commandline flags.

  This method expects the following flags to have been defined:
    * `--<prefix>_learning_rate` (must have a value)
    * `--<prefix>_lr_schedule` (must have a value)
    * `--<prefix>_lr_warmup_steps` (optional value; unset means no warmup)

  If `<prefix>_lr_schedule` is not `constant`, then this method expects the
  following flags to be defined as well:
    * `--<prefix>_lr_decay_steps` (must have a value)
    * `--<prefix>_lr_decay_rate` (must have a value)
    * `--<prefix>_lr_staircase`

  Args:
    prefix: The same string prefix passed to `define_optimizer_flags`.

  Returns:
    A callable that accepts a `round_num` and returns a learning rate.

  Raises:
    ValueError: If a flag that must have a value is unset (None), or if the
      schedule type is not recognized.
  """

  def prefixed(basename):
    return '{}_{}'.format(prefix, basename) if prefix else basename

  def required_flag_value(basename):
    # `flags.FLAGS[name]` returns the flag object itself (never None), so the
    # "was it specified" check has to look at the flag's value.
    flag_name = prefixed(basename)
    value = flags.FLAGS[flag_name].value
    if value is None:
      raise ValueError('Must specify flag --{!s}'.format(flag_name))
    return value

  base_lr = required_flag_value('learning_rate')
  lr_schedule_type = required_flag_value('lr_schedule')
  # Warmup is optional: `warmup_and_decay_schedule_builder` treats None (the
  # flag default) as "no warmup", so an unset value is not an error here.
  lr_warmup_steps = flags.FLAGS[prefixed('lr_warmup_steps')].value

  if lr_schedule_type == 'constant':
    return warmup_and_decay_schedule_builder(base_lr, lr_warmup_steps,
                                             lambda _: base_lr)

  decay_schedule_builders = {
      'exp_decay': exp_decay_schedule_builder,
      'inv_lin_decay': inv_lin_schedule_builder,
      'inv_sqrt_decay': inv_sqrt_schedule_builder,
  }
  decay_schedule_builder = decay_schedule_builders.get(lr_schedule_type)
  if decay_schedule_builder is None:
    raise ValueError(
        'Unrecognized schedule type {!s}'.format(lr_schedule_type))

  # Decaying schedules divide by the decay steps and scale by the decay rate,
  # so fail early with a clear message instead of inside a TensorFlow op.
  lr_decay_steps = required_flag_value('lr_decay_steps')
  lr_decay_rate = required_flag_value('lr_decay_rate')
  lr_staircase = flags.FLAGS[prefixed('lr_staircase')].value

  return warmup_and_decay_schedule_builder(
      base_lr, lr_warmup_steps,
      decay_schedule_builder(base_lr, lr_decay_steps, lr_decay_rate,
                             lr_staircase))