/
config.py
2152 lines (1713 loc) · 79.1 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# coding=utf-8
# Copyright 2020 The Gin-Config Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines the Gin configuration framework.
Programs frequently have a number of "hyperparameters" that require variation
across different executions of the program. When the number of such parameters
grows even moderately large, or use of some parameter is deeply embedded in the
code, top-level flags become very cumbersome. This module provides an
alternative mechanism for setting such hyperparameters, by allowing injection of
parameter values for any function marked as "configurable".
For detailed documentation, please see the user guide:
https://github.com/google/gin-config/tree/master/docs/index.md
# Making functions and classes configurable
Functions and classes can be marked configurable using the `@configurable`
decorator, which associates a "configurable name" with the function or class (by
default, just the function or class name). Optionally, parameters can be
allowlisted or denylisted to mark only a subset of the function's parameters as
configurable. Once parameters have been bound (see below) to this function, any
subsequent calls will have those parameters automatically supplied by Gin.
If an argument supplied to a function by its caller (either as a positional
argument or as a keyword argument) corresponds to a parameter configured by Gin,
the caller's value will take precedence.
# A short example
Python code:
@gin.configurable
def mix_cocktail(ingredients):
...
@gin.configurable
def serve_random_cocktail(available_cocktails):
...
@gin.configurable
def drink(cocktail):
...
Gin configuration:
martini/mix_cocktail.ingredients = ['gin', 'vermouth', 'twist of lemon']
gin_and_tonic/mix_cocktail.ingredients = ['gin', 'tonic water']
serve_random_cocktail.available_cocktails = {
'martini': @martini/mix_cocktail,
'gin_and_tonic': @gin_and_tonic/mix_cocktail,
}
drink.cocktail = @serve_random_cocktail()
In the above example, there are three configurable functions: `mix_cocktail`
(with a parameter `ingredients`), `serve_random_cocktail` (with parameter
`available_cocktails`), and `drink` (with parameter `cocktail`).
When `serve_random_cocktail` is called, it will receive a dictionary
containing two scoped *references* to the `mix_cocktail` function (each scope
providing unique parameters, meaning calling the different references will
presumably produce different outputs).
On the other hand, when the `drink` function is called, it will receive the
*output* of calling `serve_random_cocktail` as the value of its `cocktail`
parameter, due to the trailing `()` in `@serve_random_cocktail()`.
"""
import collections
import contextlib
import copy
import enum
import functools
import inspect
import logging
import os
import pprint
import sys
import threading
import traceback
from typing import Optional, Sequence, Union
from gin import config_parser
from gin import selector_map
from gin import utils
class _ScopeManager(threading.local):
"""Manages currently active config scopes.
This ensures thread safety of config scope management by subclassing
`threading.local`. Scopes are tracked as a stack, where elements in the
stack are lists of the currently active scope names.
"""
def _maybe_init(self):
if not hasattr(self, '_active_scopes'):
self._active_scopes = [[]]
@property
def active_scopes(self):
self._maybe_init()
return self._active_scopes[:]
@property
def current_scope(self):
self._maybe_init()
return self._active_scopes[-1][:] # Slice to get copy.
def enter_scope(self, scope):
"""Enters the given scope, updating the list of active scopes.
Args:
scope: A list of active scope names, ordered from outermost to innermost.
"""
self._maybe_init()
self._active_scopes.append(scope)
def exit_scope(self):
"""Exits the most recently entered scope."""
self._maybe_init()
self._active_scopes.pop()
# Maintains the registry of configurable functions and classes.
_REGISTRY = selector_map.SelectorMap()
# Maps `(scope, selector)` tuples to dicts of `{parameter_name: value}`. This
# specifies the current global "configuration" set through `bind_parameter` or
# `parse_config`, but doesn't include any functions' default argument values.
_CONFIG = {}
# Keeps a set of module names that were dynamically imported via config files.
_IMPORTED_MODULES = set()
# Maps `(scope, selector)` tuples to all configurable parameter values used
# during program execution (including default argument values).
_OPERATIVE_CONFIG = {}
# Lock paired with `_OPERATIVE_CONFIG`; presumably serializes updates to it
# from configurables running on multiple threads (its uses are outside this
# view -- confirm).
_OPERATIVE_CONFIG_LOCK = threading.Lock()
# Keeps track of currently active config scopes (tracked per thread, since
# `_ScopeManager` subclasses `threading.local`).
_SCOPE_MANAGER = _ScopeManager()
# Keeps track of hooks to run when the Gin config is finalized.
_FINALIZE_HOOKS = []
# Keeps track of whether the config is locked.
_CONFIG_IS_LOCKED = False
# Keeps track of whether "interactive mode" is enabled, in which case redefining
# a configurable is not an error.
_INTERACTIVE_MODE = False
# Keeps track of constants created via gin.constant, to both prevent duplicate
# definitions and to avoid writing them to the operative config.
_CONSTANTS = selector_map.SelectorMap()
# Keeps track of singletons created via the singleton configurable.
_SINGLETONS = {}
# Keeps track of file readers used to load config files. Each element of this
# list is a tuple of `(reader, existence_check)`: `reader` is a function that
# behaves like Python's `open` (usable as a context manager), and
# `existence_check` is a predicate taking a path -- the default entry pairs
# `open` with `os.path.isfile`. Presumably the predicate decides whether
# `reader` can load a given path (the lookup code is outside this view).
_FILE_READERS = [(open, os.path.isfile)]
# Caches `inspect.getfullargspec` results, keyed by function (filled by
# `_get_cached_arg_spec`).
_ARG_SPEC_CACHE = {}
# List of location prefixes. Similar to PATH var in unix to be used to search
# for files with those prefixes.
_LOCATION_PREFIXES = ['']
# Sentinel default value marking a parameter as required;
# `_get_validated_required_kwargs` collects parameters whose default is this
# object.
REQUIRED = object()
def _find_class_construction_fn(cls):
"""Find the first __init__ or __new__ method in the given class's MRO."""
for base in type.mro(cls): # pytype: disable=wrong-arg-types
if '__init__' in base.__dict__:
return base.__init__
if '__new__' in base.__dict__:
return base.__new__
def _ensure_wrappability(fn):
"""Make sure `fn` can be wrapped cleanly by functools.wraps."""
# Handle "builtin_function_or_method", "wrapped_descriptor", and
# "method-wrapper" types.
unwrappable_types = (type(sum), type(object.__init__), type(object.__call__))
if isinstance(fn, unwrappable_types):
# pylint: disable=unnecessary-lambda
wrappable_fn = lambda *args, **kwargs: fn(*args, **kwargs)
wrappable_fn.__name__ = fn.__name__
wrappable_fn.__doc__ = fn.__doc__
wrappable_fn.__module__ = '' # These types have no __module__, sigh.
wrappable_fn.__wrapped__ = fn
return wrappable_fn
# Otherwise we're good to go...
return fn
def _decorate_fn_or_cls(decorator, fn_or_cls, subclass=False):
  """Applies `decorator` to a function, or to a class's construction method.

  For anything that is not a class, returns `decorator` applied to
  `_ensure_wrappability(fn_or_cls)`.

  For a class, the first `__init__` or `__new__` found along its MRO is
  decorated and installed on a target class, which is then returned:

    * with `subclass=False`, the target is `fn_or_cls` itself, modified in
      place;
    * with `subclass=True`, the target is a new subclass of `fn_or_cls` that
      copies its docstring, module, name and qualname. The original class is
      left untouched, and because the result subclasses it, `isinstance` and
      `issubclass` checks against the original keep working.

  A decorated `__new__` is re-wrapped in `staticmethod` before installation.

  Args:
    decorator: The decorator to apply.
    fn_or_cls: The function or class to decorate.
    subclass: Whether classes are decorated via a fresh subclass instead of in
      place. Ignored when `fn_or_cls` is not a class.

  Returns:
    The decorated function, or the class carrying the decorated constructor.
  """
  if not inspect.isclass(fn_or_cls):  # pytype: disable=wrong-arg-types
    return decorator(_ensure_wrappability(fn_or_cls))

  ctor = _find_class_construction_fn(fn_or_cls)

  target_cls = fn_or_cls
  if subclass:

    class DecoratedClass(fn_or_cls):
      __doc__ = fn_or_cls.__doc__
      __module__ = fn_or_cls.__module__

    DecoratedClass.__name__ = fn_or_cls.__name__
    DecoratedClass.__qualname__ = fn_or_cls.__qualname__
    target_cls = DecoratedClass

  ctor_name = ctor.__name__
  new_ctor = decorator(_ensure_wrappability(ctor))
  if ctor_name == '__new__':
    # `__new__` is implicitly static; keep it static after reassignment.
    new_ctor = staticmethod(new_ctor)
  setattr(target_cls, ctor_name, new_ctor)
  return target_cls
class Configurable(
    collections.namedtuple(
        'Configurable',
        'fn_or_cls name module allowlist denylist selector')):
  """Immutable record describing one registered configurable.

  Fields: the wrapped function or class (`fn_or_cls`), its configurable
  `name` and `module`, the optional parameter `allowlist` / `denylist`, and
  the `selector` under which it is registered.
  """
def _raise_unknown_reference_error(ref, additional_msg=''):
err_str = "No configurable matching reference '@{}{}'.{}"
maybe_parens = '()' if ref.evaluate else ''
raise ValueError(err_str.format(ref.selector, maybe_parens, additional_msg))
class ConfigurableReference:
  """Represents a reference to a configurable function or class.

  A reference is written in a config as `@scope1/scope2/name` (or with a
  trailing `()` when it should be evaluated). The `/`-separated components
  before the last one are scopes; the last component is a selector that is
  resolved against the registry of configurables at construction time.
  """

  def __init__(self, scoped_selector, evaluate):
    """Creates the reference, resolving its selector in the registry.

    Args:
      scoped_selector: A selector, optionally prefixed by `/`-separated scope
        names, e.g. `'scope/module.fn'`.
      evaluate: Whether resolving the reference calls the configurable
        (`@fn()`) rather than returning it (`@fn`).

    Raises:
      ValueError: If the selector matches no registered configurable.
    """
    self._scoped_selector = scoped_selector
    self._evaluate = evaluate
    scoped_selector_parts = self._scoped_selector.split('/')
    self._scopes = scoped_selector_parts[:-1]
    self._selector = scoped_selector_parts[-1]
    self._configurable = _REGISTRY.get_match(self._selector)
    if not self._configurable:
      _raise_unknown_reference_error(self)

    def reference_decorator(fn):
      # Only wrap when there are scopes to enter around each call.
      if self._scopes:

        @functools.wraps(fn)
        def scoping_wrapper(*args, **kwargs):
          with config_scope(self._scopes):
            return fn(*args, **kwargs)

        return scoping_wrapper

      return fn

    self._scoped_configurable_fn = _decorate_fn_or_cls(
        reference_decorator, self.configurable.fn_or_cls, True)

  @property
  def configurable(self):
    return self._configurable

  @property
  def scoped_configurable_fn(self):
    return self._scoped_configurable_fn

  @property
  def scopes(self):
    return self._scopes

  @property
  def selector(self):
    return self._selector

  @property
  def scoped_selector(self):
    return self._scoped_selector

  @property
  def config_key(self):
    return ('/'.join(self._scopes), self._configurable.selector)

  @property
  def evaluate(self):
    return self._evaluate

  def __eq__(self, other):
    """References are equal iff configurable, scopes and evaluation all match.

    Scopes must take part in the comparison: otherwise `@a/fn` and `@b/fn`
    (and likewise the macros `%a` and `%b`, which both reference `gin.macro`)
    would compare equal even though they resolve to different values.
    """
    if isinstance(other, self.__class__):
      # pylint: disable=protected-access
      return (self._configurable == other._configurable and
              self._scopes == other._scopes and
              self._evaluate == other._evaluate)
      # pylint: enable=protected-access
    return False

  def __ne__(self, other):
    return not self.__eq__(other)

  def __hash__(self):
    # Hash exactly what `__eq__` compares. Hashing `repr(self)` was
    # inconsistent with equality: `@fn` and `@module.fn` resolving to the same
    # configurable are equal but have different reprs.
    return hash((tuple(self._scopes), self._configurable.selector,
                 self._evaluate))

  def __repr__(self):
    # Check if this reference is a macro or constant, i.e. @.../macro() or
    # @.../constant(). Only macros and constants correspond to the %... syntax.
    configurable_fn = self._configurable.fn_or_cls
    if configurable_fn in (macro, _retrieve_constant) and self._evaluate:
      return '%' + '/'.join(self._scopes)
    maybe_parens = '()' if self._evaluate else ''
    return '@{}{}'.format(self._scoped_selector, maybe_parens)

  def __deepcopy__(self, memo):
    """Dishonestly implements the __deepcopy__ special method.

    Configurable references may be deeply nested inside other Python data
    structures. By providing this implementation, `copy.deepcopy` can be used
    on the containing structure to replace every reference with the
    (scope-wrapped) configurable it points to, or, for references marked for
    evaluation, with the output of calling that configurable.

    Args:
      memo: The memoization dict (unused).

    Returns:
      When `self._evaluate` is `False`, the underlying configurable (wrapped,
      if the reference has scopes, to be called within them). When
      `self._evaluate` is `True`, the result of calling it with no arguments.
    """
    if self._evaluate:
      return self._scoped_configurable_fn()
    return self._scoped_configurable_fn
class _UnknownConfigurableReference:
"""Represents a reference to an unknown configurable.
This class acts as a substitute for `ConfigurableReference` when the selector
doesn't match any known configurable.
"""
def __init__(self, selector, evaluate):
self._selector = selector.split('/')[-1]
self._evaluate = evaluate
@property
def selector(self):
return self._selector
@property
def evaluate(self):
return self._evaluate
def __deepcopy__(self, memo):
"""Dishonestly implements the __deepcopy__ special method.
See `ConfigurableReference` above. If this method is called, it means there
was an attempt to use this unknown configurable reference, so we throw an
error here.
Args:
memo: The memoization dict (unused).
Raises:
ValueError: To report that there is no matching configurable.
"""
addl_msg = '\n\n To catch this earlier, ensure gin.finalize() is called.'
_raise_unknown_reference_error(self, addl_msg)
def _validate_skip_unknown(skip_unknown):
if not isinstance(skip_unknown, (bool, list, tuple, set)):
err_str = 'Invalid value for `skip_unknown`: {}'
raise ValueError(err_str.format(skip_unknown))
def _should_skip(selector, skip_unknown):
  """Returns whether an unknown `selector` may be skipped.

  Known configurables are never skipped. Otherwise a bool `skip_unknown`
  answers directly, and a list/tuple/set answers by membership.

  Args:
    selector: The (unscoped) configurable selector.
    skip_unknown: A bool, or a list/tuple/set of skippable selectors.

  Raises:
    ValueError: If `skip_unknown` has an invalid type.
  """
  _validate_skip_unknown(skip_unknown)
  if _REGISTRY.matching_selectors(selector):
    return False  # Never skip known configurables.
  if isinstance(skip_unknown, bool):
    return skip_unknown
  # Validation guarantees a list, tuple or set here.
  return selector in skip_unknown
class ParserDelegate(config_parser.ParserDelegate):
  """Builds configurable references and macro/constant references for parsing.

  Args:
    skip_unknown: A bool or a list/tuple/set of selectors (see `_should_skip`)
      controlling whether references to unknown configurables become
      `_UnknownConfigurableReference` placeholders instead of failing.
  """

  def __init__(self, skip_unknown=False):
    self._skip_unknown = skip_unknown

  def configurable_reference(self, scoped_selector, evaluate):
    _, _, bare_selector = scoped_selector.rpartition('/')
    if not _should_skip(bare_selector, self._skip_unknown):
      return ConfigurableReference(scoped_selector, evaluate)
    return _UnknownConfigurableReference(scoped_selector, evaluate)

  def macro(self, name):
    # A name that uniquely matches a constant resolves to that constant;
    # anything else is treated as a macro.
    candidates = _CONSTANTS.matching_selectors(name)
    if not candidates:
      return ConfigurableReference(name + '/gin.macro', True)
    if len(candidates) > 1:
      err_str = "Ambiguous constant selector '{}', matches {}."
      raise ValueError(err_str.format(name, candidates))
    return ConfigurableReference(candidates[0] + '/gin.constant', True)
class ParsedBindingKey(
    collections.namedtuple(
        'ParsedBindingKey',
        ['scope', 'given_selector', 'complete_selector', 'arg_name'])):
  """Represents a parsed and validated binding key.

  A "binding key" identifies a specific parameter (`arg_name`), of a specific
  configurable (`complete_selector`), in a specific scope (`scope`), to which a
  value may be bound in the global configuration. The `given_selector` field
  retains information about how the original configurable selector was
  specified, which can be helpful for error messages (but is ignored for the
  purposes of equality and hashing).
  """

  def __new__(cls, binding_key):
    """Parses and validates the given binding key.

    This function will parse `binding_key` (if necessary), and ensure that the
    specified parameter can be bound for the given configurable selector (i.e.,
    that the parameter isn't denylisted or not allowlisted if an allowlist was
    provided).

    Args:
      binding_key: A spec identifying a parameter of a configurable (maybe in
        some scope). This should either be a string of the form
        'maybe/some/scope/maybe.modules.configurable_name.parameter_name'; or a
        list or tuple of `(scope, selector, arg_name)`; or another instance of
        `ParsedBindingKey`.

    Returns:
      A new instance of `ParsedBindingKey`.

    Raises:
      ValueError: If no function can be found matching the configurable name
        specified by `binding_key`, or if the specified parameter name is
        denylisted or not in the function's allowlist (if present).
    """
    # Must precede the generic (list, tuple) branch: instances are tuples too.
    if isinstance(binding_key, ParsedBindingKey):
      return super(ParsedBindingKey, cls).__new__(cls, *binding_key)  # pytype: disable=missing-parameter

    if isinstance(binding_key, (list, tuple)):
      scope, selector, arg_name = binding_key
    elif isinstance(binding_key, str):
      scope, selector, arg_name = config_parser.parse_binding_key(binding_key)
    else:
      err_str = 'Invalid type for binding_key: {}.'
      raise ValueError(err_str.format(type(binding_key)))

    configurable_ = _REGISTRY.get_match(selector)
    if not configurable_:
      raise ValueError("No configurable matching '{}'.".format(selector))

    if not _might_have_parameter(configurable_.fn_or_cls, arg_name):
      err_str = "Configurable '{}' doesn't have a parameter named '{}'."
      raise ValueError(err_str.format(selector, arg_name))

    if configurable_.allowlist and arg_name not in configurable_.allowlist:
      err_str = "Configurable '{}' doesn't include kwarg '{}' in its allowlist."
      raise ValueError(err_str.format(selector, arg_name))

    if configurable_.denylist and arg_name in configurable_.denylist:
      err_str = "Configurable '{}' has denylisted kwarg '{}'."
      raise ValueError(err_str.format(selector, arg_name))

    return super(ParsedBindingKey, cls).__new__(
        cls,
        scope=scope,
        given_selector=selector,
        complete_selector=configurable_.selector,
        arg_name=arg_name)

  @property
  def config_key(self):
    return self.scope, self.complete_selector

  @property
  def scope_selector_arg(self):
    return self.scope, self.complete_selector, self.arg_name

  def __eq__(self, other):
    # Equality ignores the `given_selector` field, since two binding keys should
    # be equal whenever they identify the same parameter. (This method used to
    # be misspelled `__equal__`, so plain tuple equality -- which compares
    # `given_selector` too -- was silently used instead.)
    if isinstance(other, ParsedBindingKey):
      return self.scope_selector_arg == other.scope_selector_arg
    # Anything else keeps the previous (inherited tuple) comparison semantics.
    return super(ParsedBindingKey, self).__eq__(other)

  def __ne__(self, other):
    # `tuple` defines its own `__ne__`, which would bypass `__eq__` above.
    equal = self.__eq__(other)
    if equal is NotImplemented:
      return equal
    return not equal

  # Backward-compatible alias for the previous (misspelled) method name.
  __equal__ = __eq__

  def __hash__(self):
    # Consistent with `__eq__`: `given_selector` is excluded.
    return hash(self.scope_selector_arg)
def _format_value(value):
  """Returns `repr(value)` if it parses back to an equal value, else `None`.

  Whenever a string is returned, the round trip
  `parse_value(_format_value(value)) == value` holds.

  Args:
    value: The value to format.

  Returns:
    `repr(value)` when it is parseable by `parse_value` and the parsed result
    compares equal to `value`; `None` otherwise (including when parsing raises
    `SyntaxError`).
  """
  literal = repr(value)
  try:
    round_trips = bool(parse_value(literal) == value)
  except SyntaxError:
    round_trips = False
  return literal if round_trips else None
def _is_literally_representable(value):
  """Returns whether `value` survives a `repr` -> `parse_value` round trip.

  Args:
    value: The value to check.

  Returns:
    `True` if `_format_value(value)` produces a string, i.e. `repr(value)` is
    parseable by `parse_value` into something equal to `value`; else `False`.
  """
  formatted = _format_value(value)
  return formatted is not None
def clear_config(clear_constants=False):
  """Resets the global configuration state.

  Unlocks the config and empties the parameter bindings made through
  `bind_parameter` / `parse_config`, the cached singletons, the set of
  dynamically imported modules and the operative config. Registered
  configurables are kept. Constants are discarded when `clear_constants` is
  true; otherwise they are cleared and re-created through `constant`, which
  re-adds their bindings.

  Args:
    clear_constants: Whether to also discard constants created by `constant`.
      Defaults to False.
  """
  _set_config_is_locked(False)
  _CONFIG.clear()
  _SINGLETONS.clear()
  constants_to_restore = None if clear_constants else _CONSTANTS.copy()
  _CONSTANTS.clear()
  if constants_to_restore is not None:
    # Re-registering through `constant` re-adds each constant's bindings.
    for constant_name, constant_value in constants_to_restore.items():
      constant(constant_name, constant_value)
  _IMPORTED_MODULES.clear()
  _OPERATIVE_CONFIG.clear()
def bind_parameter(binding_key, value):
  """Binds `value` to the parameter identified by `binding_key`.

  `binding_key` is either a string of the form
  `maybe/scope/optional.module.names.configurable_name.parameter_name`, or a
  list/tuple `(scope, selector, parameter_name)` whose `selector` corresponds
  to `optional.module.names.configurable_name`. Afterwards, calls to that
  configurable made within the given scope receive `value` for
  `parameter_name`.

  Example:

    @configurable('fully_connected_network')
    def network_fn(num_layers=5, units_per_layer=1024):
      ...

    def main(_):
      config.bind_parameter('fully_connected_network.num_layers', 3)
      network_fn()  # Called with num_layers == 3, not the default of 5.

  Args:
    binding_key: The parameter whose value should be set: a string, or a
      tuple of the form `(scope, selector, parameter)`.
    value: The desired value.

  Raises:
    RuntimeError: If the config is locked.
    ValueError: If no function can be found matching the configurable name
      specified by `binding_key`, or if the specified parameter name is
      denylisted or not in the function's allowlist (if present).
  """
  if config_is_locked():
    raise RuntimeError('Attempted to modify locked Gin config.')
  parsed_key = ParsedBindingKey(binding_key)
  bindings_for_configurable = _CONFIG.setdefault(parsed_key.config_key, {})
  bindings_for_configurable[parsed_key.arg_name] = value
def query_parameter(binding_key):
  """Returns the currently bound value to the specified `binding_key`.

  The `binding_key` argument should look like
  'maybe/some/scope/maybe.modules.configurable_name.parameter_name', or be a
  list/tuple `(scope, selector, parameter_name)` as accepted by
  `bind_parameter`. A string that uniquely matches a constant (created via
  `constant`) returns that constant's value. Note that this will not include
  default parameters.

  Args:
    binding_key: The parameter whose value should be queried.

  Returns:
    The value bound to the configurable/parameter combination given in
    `binding_key`.

  Raises:
    ValueError: If a string `binding_key` ambiguously matches several
      constants, if no function can be found matching the configurable name
      specified by `binding_key`, if the specified parameter name is
      denylisted or not in the function's allowlist (if present), or if there
      is no value bound for the queried parameter or configurable.
  """
  # Constants can only be named by strings. Guarding on `str` lets tuple keys
  # (which `ParsedBindingKey` accepts) through; previously the regex match was
  # applied to them too and raised `TypeError`.
  if (isinstance(binding_key, str) and
      config_parser.MODULE_RE.match(binding_key)):
    matching_selectors = _CONSTANTS.matching_selectors(binding_key)
    if len(matching_selectors) == 1:
      return _CONSTANTS[matching_selectors[0]]
    elif len(matching_selectors) > 1:
      err_str = "Ambiguous constant selector '{}', matches {}."
      raise ValueError(err_str.format(binding_key, matching_selectors))

  pbk = ParsedBindingKey(binding_key)
  if pbk.config_key not in _CONFIG:
    err_str = "Configurable '{}' has no bound parameters."
    raise ValueError(err_str.format(pbk.given_selector))
  bound_params = _CONFIG[pbk.config_key]
  if pbk.arg_name not in bound_params:
    err_str = "Configurable '{}' has no value bound for parameter '{}'."
    raise ValueError(err_str.format(pbk.given_selector, pbk.arg_name))
  return bound_params[pbk.arg_name]
def _might_have_parameter(fn_or_cls, arg_name):
  """Returns True if `arg_name` might be a valid parameter for `fn_or_cls`.

  That is the case when the (unwrapped) function, or a class's construction
  method, either names `arg_name` among its positional or keyword-only
  parameters, or accepts `**kwargs`.

  Args:
    fn_or_cls: The function or class to check.
    arg_name: The name of the parameter.

  Returns:
    Whether `arg_name` might be a valid argument of `fn_or_cls`.
  """
  target = fn_or_cls
  if inspect.isclass(fn_or_cls):  # pytype: disable=wrong-arg-types
    target = _find_class_construction_fn(fn_or_cls)
  # Follow the `__wrapped__` chain (set by functools.wraps) to the original.
  while hasattr(target, '__wrapped__'):
    target = target.__wrapped__
  spec = _get_cached_arg_spec(target)
  return (bool(spec.varkw) or
          arg_name in spec.args or
          arg_name in spec.kwonlyargs)
def _validate_parameters(fn_or_cls, arg_name_list, err_prefix):
  """Raises `ValueError` for the first name that can't be a parameter.

  Args:
    fn_or_cls: The function or class whose parameters are checked.
    arg_name_list: Parameter names to validate (`None` means none).
    err_prefix: Describes where the names came from, for the error message.

  Raises:
    ValueError: If some name fails `_might_have_parameter`.
  """
  for candidate in arg_name_list or ():
    if _might_have_parameter(fn_or_cls, candidate):
      continue
    err_str = "Argument '{}' in {} not a parameter of '{}'."
    raise ValueError(err_str.format(candidate, err_prefix, fn_or_cls.__name__))
def _get_cached_arg_spec(fn):
  """Returns `inspect.getfullargspec` for `fn`, memoized in `_ARG_SPEC_CACHE`.

  If `fn` can't be inspected directly (`TypeError`), it is treated as a
  callable object and its `__call__` is inspected instead.
  """
  cached = _ARG_SPEC_CACHE.get(fn)
  if cached is not None:
    return cached
  try:
    spec = inspect.getfullargspec(fn)
  except TypeError:
    # `fn` might be a callable object.
    spec = inspect.getfullargspec(fn.__call__)
  _ARG_SPEC_CACHE[fn] = spec
  return spec
def _get_supplied_positional_parameter_names(fn, args):
  """Returns the parameter names bound by positional `args` in a call to `fn`.

  The result may be shorter than `args` when some of them land in `*args`.
  """
  positional_names = _get_cached_arg_spec(fn).args
  supplied_count = len(args)
  return positional_names[:supplied_count]
def _get_all_positional_parameter_names(fn):
  """Returns the names of `fn`'s positional parameters that lack defaults.

  Despite the name, positional parameters that have default values are
  excluded. When there are no defaults, the cached argspec's `args` list
  itself is returned.
  """
  spec = _get_cached_arg_spec(fn)
  if not spec.defaults:
    return spec.args
  return spec.args[:-len(spec.defaults)]
def _get_kwarg_defaults(fn):
  """Maps each parameter of `fn` that has a default to that default.

  Positional parameters with defaults come first (in signature order),
  followed by keyword-only parameters with defaults.
  """
  spec = _get_cached_arg_spec(fn)
  defaults_by_name = {}
  if spec.defaults:
    names_with_defaults = spec.args[-len(spec.defaults):]
    defaults_by_name.update(zip(names_with_defaults, spec.defaults))
  defaults_by_name.update(spec.kwonlydefaults or {})
  return defaults_by_name
def _get_validated_required_kwargs(fn, fn_descriptor, allowlist, denylist):
  """Returns names of `fn`'s parameters defaulting to `REQUIRED`.

  Args:
    fn: The function to inspect.
    fn_descriptor: Human-readable description of `fn` for error messages.
    allowlist: The allowlist (or `None`) associated with `fn`.
    denylist: The denylist (or `None`) associated with `fn`.

  Returns:
    The required parameter names, in the order `_get_kwarg_defaults` yields.

  Raises:
    ValueError: If a required parameter is denylisted, or is missing from a
      non-empty allowlist (checked per parameter, in order).
  """
  required = [
      name for name, default in _get_kwarg_defaults(fn).items()
      if default is REQUIRED
  ]
  for name in required:
    if denylist and name in denylist:
      err_str = "Argument '{}' of {} marked REQUIRED but denylisted."
      raise ValueError(err_str.format(name, fn_descriptor))
    if allowlist and name not in allowlist:
      err_str = "Argument '{}' of {} marked REQUIRED but not allowlisted."
      raise ValueError(err_str.format(name, fn_descriptor))
  return required
def _get_default_configurable_parameter_values(fn, allowlist, denylist):
  """Retrieve all default values for configurable parameters of a function.

  Any parameters included in the supplied denylist, or not included in the
  supplied allowlist, are excluded, as are parameters whose default value is
  not literally representable (see `_is_literally_representable`).

  Args:
    fn: The function whose parameter values should be retrieved.
    allowlist: The allowlist (or `None`) associated with the function.
    denylist: The denylist (or `None`) associated with the function.

  Returns:
    A dictionary mapping configurable parameter names to their default values,
    in the order produced by `_get_kwarg_defaults`.
  """

  def _is_included(name, default):
    if allowlist and name not in allowlist:
      return False
    if denylist and name in denylist:
      return False
    # Checked last: representability requires a repr + parse round trip, so
    # it is skipped for parameters the allow/denylist already excludes
    # (previously it ran for every parameter, excluded or not).
    return _is_literally_representable(default)

  return {
      name: default
      for name, default in _get_kwarg_defaults(fn).items()
      if _is_included(name, default)
  }
def _order_by_signature(fn, arg_names):
  """Orders `arg_names` by their position in `fn`'s signature.

  Names found among `fn`'s positional or keyword-only parameters come first,
  in signature order; the rest (e.g. names absorbed by `**kwargs`) follow in
  the order they appear in `arg_names`.
  """
  spec = _get_cached_arg_spec(fn)
  signature_order = [*spec.args, *(spec.kwonlyargs or ())]
  in_signature = [name for name in signature_order if name in arg_names]
  leftovers = [name for name in arg_names if name not in in_signature]
  return in_signature + leftovers
def current_scope():
  """Returns a copy of the innermost active scope (a list of scope names)."""
  manager = _SCOPE_MANAGER
  return manager.current_scope
def current_scope_str():
  """Returns the innermost active scope as a `/`-joined string."""
  scope_names = current_scope()
  return '/'.join(scope_names)
@contextlib.contextmanager
def config_scope(name_or_scope):
  """Opens a new configuration scope.

  Provides a context manager that opens a new explicit configuration
  scope. Explicit configuration scopes restrict parameter bindings to only
  certain sections of code that run within the scope. Scopes can be nested to
  arbitrary depth; any configurable functions called within a scope inherit
  parameters defined by higher level scopes.

  For example, suppose a function named `preprocess_images` is called in two
  places in a codebase: Once when loading data for a training task, and once
  when loading data for an evaluation task:

    def load_training_data():
      ...
      with gin.config_scope('train'):
        images = preprocess_images(images)
      ...

    def load_eval_data():
      ...
      with gin.config_scope('eval'):
        images = preprocess_images(images)
      ...

  By using a `config_scope` to wrap each invocation of `preprocess_images` as
  above, it is possible to use Gin to supply specific parameters to each. Here
  is a possible configuration for the above example:

    preprocess_images.crop_size = [64, 64]
    preprocess_images.normalize_image = True

    train/preprocess_images.crop_location = 'random'
    train/preprocess_images.random_flip_lr = True

    eval/preprocess_images.crop_location = 'center'

  The `crop_size` and `normalize_image` parameters above will be shared by both
  the `train` and `eval` invocations; only `train` will receive
  `random_flip_lr`, and the two invocations receive different values for
  `crop_location`.

  Passing `None` or `''` to `config_scope` will temporarily clear all currently
  active scopes (within the `with` block; they will be restored afterwards).

  Args:
    name_or_scope: A name for the config scope, or an existing scope (e.g.,
      captured from `with gin.config_scope(...) as scope`), or `None` to clear
      currently active scopes.

  Raises:
    ValueError: If `name_or_scope` is not a list, string, or None, or if any
      resulting scope name doesn't match `config_parser.MODULE_RE`.

  Yields:
    The resulting config scope (a list of all active scope names, ordered from
    outermost to innermost).
  """
  # Build and validate the new scope *before* touching the scope stack. The
  # previous version ran all of this inside `try/finally`, so an exception
  # raised before `enter_scope` (e.g. from `bool()` on an exotic argument)
  # made the `finally` pop an enclosing scope that belonged to the caller.
  if isinstance(name_or_scope, list):
    valid_value = True
    new_scope = name_or_scope
  elif isinstance(name_or_scope, str) and name_or_scope:
    valid_value = True
    new_scope = current_scope()  # Returns a copy.
    new_scope.extend(name_or_scope.split('/'))
  else:
    # Only `None` and `''` (any str reaching here is empty) are accepted; both
    # clear the active scopes for the duration of the block.
    valid_value = name_or_scope is None or isinstance(name_or_scope, str)
    new_scope = []

  if not valid_value or not all(map(config_parser.MODULE_RE.match, new_scope)):
    err_str = 'Invalid value for `name_or_scope`: {}.'
    raise ValueError(err_str.format(name_or_scope))

  _SCOPE_MANAGER.enter_scope(new_scope)
  try:
    yield new_scope
  finally:
    _SCOPE_MANAGER.exit_scope()
def _make_gin_wrapper(fn, fn_or_cls, name, selector, allowlist, denylist):
"""Creates the final Gin wrapper for the given function.
Args:
fn: The function that will be wrapped.
fn_or_cls: The original function or class being made configurable. This will
differ from `fn` when making a class configurable, in which case `fn` will
be the constructor/new function, while `fn_or_cls` will be the class.
name: The name given to the configurable.
selector: The full selector of the configurable (name including any module
components).
allowlist: An allowlist of configurable parameters.
denylist: A denylist of non-configurable parameters.
Returns:
The Gin wrapper around `fn`.
"""
# At this point we have access to the final function to be wrapped, so we
# can cache a few things here.
fn_descriptor = "'{}' ('{}')".format(name, fn_or_cls)
signature_required_kwargs = _get_validated_required_kwargs(
fn, fn_descriptor, allowlist, denylist)
initial_configurable_defaults = _get_default_configurable_parameter_values(
fn, allowlist, denylist)
@functools.wraps(fn)
def gin_wrapper(*args, **kwargs):
"""Supplies fn with parameter values from the configuration."""
scope_components = current_scope()
new_kwargs = {}
for i in range(len(scope_components) + 1):
partial_scope_str = '/'.join(scope_components[:i])
new_kwargs.update(_CONFIG.get((partial_scope_str, selector), {}))
gin_bound_args = list(new_kwargs.keys())
scope_str = partial_scope_str
arg_names = _get_supplied_positional_parameter_names(fn, args)
for arg in args[len(arg_names):]:
if arg is REQUIRED:
raise ValueError(
'gin.REQUIRED is not allowed for unnamed (vararg) parameters. If '
'the function being called is wrapped by a non-Gin decorator, '
'try explicitly providing argument names for positional '
'parameters.')
required_arg_names = []
required_arg_indexes = []
for i, arg in enumerate(args[:len(arg_names)]):
if arg is REQUIRED:
required_arg_names.append(arg_names[i])
required_arg_indexes.append(i)
caller_required_kwargs = []
for kwarg, value in kwargs.items():
if value is REQUIRED:
caller_required_kwargs.append(kwarg)
# If the caller passed arguments as positional arguments that correspond to
# a keyword arg in new_kwargs, remove the keyword argument from new_kwargs
# to let the caller win and avoid throwing an error. Unless it is an arg
# marked as REQUIRED.
for arg_name in arg_names:
if arg_name not in required_arg_names:
new_kwargs.pop(arg_name, None)
# Get default values for configurable parameters.
operative_parameter_values = initial_configurable_defaults.copy()
# Update with the values supplied via configuration.
operative_parameter_values.update(new_kwargs)
# Remove any values from the operative config that are overridden by the
# caller. These can't be configured, so they won't be logged. We skip values
# that are marked as REQUIRED.
for k in arg_names:
if k not in required_arg_names:
operative_parameter_values.pop(k, None)
for k in kwargs:
if k not in caller_required_kwargs:
operative_parameter_values.pop(k, None)
# An update is performed in case another caller of this same configurable
# object has supplied a different set of arguments. By doing an update, a
# Gin-supplied or default value will be present if it was used (not
# overridden by the caller) at least once.
with _OPERATIVE_CONFIG_LOCK:
op_cfg = _OPERATIVE_CONFIG.setdefault((scope_str, selector), {})