-
Notifications
You must be signed in to change notification settings - Fork 1.4k
/
configuration.py
170 lines (119 loc) · 4.74 KB
/
configuration.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import sys
import threading
import typing as tp # NOQA
from chainer import types # NOQA
if types.TYPE_CHECKING:
import numpy # NOQA
from chainer.graph_optimizations import static_graph # NOQA
class GlobalConfig(object):

    """The plain object that represents the global configuration of Chainer."""

    # Class-level placeholders; ``show`` only lists entries actually assigned
    # on the instance (it reads ``self.__dict__``, not the class).
    debug = None  # type: bool
    cudnn_deterministic = None  # type: bool
    warn_nondeterministic = None  # type: bool
    enable_backprop = None  # type: bool
    keep_graph_on_report = None  # type: bool
    train = None  # type: bool
    type_check = None  # type: bool
    use_cudnn = None  # type: str
    use_cudnn_tensor_core = None  # type: str
    autotune = None  # type: bool
    schedule_func = None  # type: tp.Optional[static_graph.StaticScheduleFunction] # NOQA
    use_ideep = None  # type: str
    lazy_grad_sum = None  # type: bool
    cudnn_fast_batch_normalization = None  # type: bool
    dtype = None  # type: numpy.dtype
    in_recomputing = None  # type: bool

    def show(self, file=None):
        """show(file=sys.stdout)

        Prints the global config entries.

        The entries are sorted in the lexicographical order of the entry name.

        Args:
            file: Output file-like object. When omitted (or ``None``), the
                ``sys.stdout`` that is current at call time is used, so
                stdout redirection in effect during the call is honored.

        """
        # Resolve the stream lazily: a ``file=sys.stdout`` default would be
        # captured once at import time and ignore later redirection.
        if file is None:
            file = sys.stdout
        keys = sorted(self.__dict__)
        _print_attrs(self, keys, file)
class LocalConfig(object):

    """Thread-local configuration of Chainer.

    This class implements the local configuration. When a value is set to this
    object, the configuration is only updated in the current thread. When a
    user tries to access an attribute and there is no local value, it
    automatically retrieves a value from the global configuration.

    """

    def __init__(self, global_config):
        # Bypass our own ``__setattr__`` (which writes into the thread-local
        # store) so the two bookkeeping attributes live on the instance itself.
        super(LocalConfig, self).__setattr__('_global', global_config)
        super(LocalConfig, self).__setattr__('_local', threading.local())

    def __delattr__(self, name):
        # Removes only the calling thread's override; raises AttributeError
        # when this thread holds no local value for ``name``.
        delattr(self._local, name)

    def __getattr__(self, name):
        # Python only calls this after normal lookup fails. If the bookkeeping
        # attributes are themselves missing (an instance created through
        # ``__new__`` without ``__init__``, as copy/pickle do), evaluating
        # ``self._local`` below would re-enter this method forever; raise a
        # plain AttributeError instead.
        if name in ('_global', '_local'):
            raise AttributeError(name)
        local_values = self._local.__dict__
        if name in local_values:
            return local_values[name]
        # No override in this thread: fall back to the global configuration.
        return getattr(self._global, name)

    def __setattr__(self, name, value):
        # Writes are visible only to the calling thread.
        setattr(self._local, name, value)

    def show(self, file=None):
        """show(file=sys.stdout)

        Prints the config entries.

        The entries are sorted in the lexicographical order of the entry names.

        Args:
            file: Output file-like object. When omitted (or ``None``), the
                ``sys.stdout`` that is current at call time is used, so
                stdout redirection in effect during the call is honored.

        .. admonition:: Example

           You can easily print the list of configurations used in
           the current thread.

              >>> chainer.config.show()  # doctest: +SKIP
              debug           False
              enable_backprop True
              train           True
              type_check      True

        """
        # Resolve the stream lazily; see the Args note above.
        if file is None:
            file = sys.stdout
        # Union of global and thread-local names; values are read through
        # ``self`` so local overrides win.
        keys = sorted(set(self._global.__dict__) | set(self._local.__dict__))
        _print_attrs(self, keys, file)
def _print_attrs(obj, keys, file):
    """Write one ``name value`` line per attribute name in ``keys``.

    Values are read with ``getattr(obj, key)``; they are aligned in a single
    column starting one space after the longest name.

    Args:
        obj: Object whose attributes are printed.
        keys: Iterable of attribute names, printed in the given order.
        file: Writable file-like object that receives the text.

    """
    # Materialize so a one-shot iterator survives both the width pass and
    # the printing pass.
    keys = list(keys)
    if not keys:
        # ``max()`` of an empty sequence raises ValueError; there is simply
        # nothing to print (e.g. a config object with no entries set).
        return
    max_len = max(len(key) for key in keys)
    for key in keys:
        spacer = ' ' * (max_len - len(key))
        file.write(u'{} {}{}\n'.format(key, spacer, getattr(obj, key)))
# Process-wide default configuration object. ``config`` below reads from it
# whenever the current thread has no local value for an entry.
global_config = GlobalConfig()
'''Global configuration of Chainer.
It is an instance of :class:`chainer.configuration.GlobalConfig`.
See :ref:`configuration` for details.
'''
# Thread-local view layered over ``global_config``: attribute writes only
# affect the calling thread; reads fall back to ``global_config``.
# NOTE: the bare string right after each assignment is read by Sphinx as the
# attribute's documentation, so keep it directly below its assignment.
config = LocalConfig(global_config)
'''Thread-local configuration of Chainer.
It is an instance of :class:`chainer.configuration.LocalConfig`, and is
referring to :data:`~chainer.global_config` as its default configuration.
See :ref:`configuration` for details.
'''
class _ConfigContext(object):

    """Context manager that temporarily assigns one entry of a config object.

    On entry it remembers whether ``config`` already held a thread-local value
    for ``name`` (and that value), then assigns ``value``. On exit it either
    restores the remembered value or deletes the temporary local entry.
    ``__enter__`` returns ``None``.

    """

    # Class-level defaults, overwritten per instance by ``__enter__``.
    is_local = False
    old_value = None

    def __init__(self, config, name, value):
        self.config = config
        self.name = name
        self.value = value

    def __enter__(self):
        cfg = self.config
        key = self.name
        # Only a value stored in the thread-local slot needs restoring; a value
        # inherited from the global config reappears once the local is deleted.
        had_local = hasattr(cfg._local, key)
        if had_local:
            self.old_value = getattr(cfg, key)
        self.is_local = had_local
        setattr(cfg, key, self.value)

    def __exit__(self, typ, value, traceback):
        if not self.is_local:
            # There was no local override before entering: drop ours.
            delattr(self.config, self.name)
            return
        setattr(self.config, self.name, self.old_value)
def using_config(name, value, config=config):
    """using_config(name, value, config=chainer.config)

    Context manager to temporarily change the thread-local configuration.

    Entering the returned context assigns ``value`` to the entry ``name`` of
    ``config``. Leaving it restores the value ``config`` held locally before
    entry, or deletes the entry again if there was no local value.

    Args:
        name (str): Name of the configuration to change.
        value: Temporary value of the configuration entry.
        config (~chainer.configuration.LocalConfig): Configuration object.
            Chainer's thread-local configuration is used by default.

    .. seealso::
        :ref:`configuration`

    """
    context = _ConfigContext(config, name, value)
    return context