/
configuration.py
130 lines (91 loc) · 3.58 KB
/
configuration.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from __future__ import print_function
import contextlib
import sys
import threading
class GlobalConfig(object):

    """The plain object that represents the global configuration of Chainer."""

    def show(self, file=None):
        """show(file=sys.stdout)

        Prints the global config entries.

        The entries are sorted in the lexicographical order of the entry name.

        Args:
            file: Output file-like object. ``None`` (the default) means the
                ``sys.stdout`` in effect at call time, so stream redirection
                such as :func:`contextlib.redirect_stdout` is honored.

        """
        if file is None:
            # Resolve lazily: a ``file=sys.stdout`` default would be frozen to
            # whichever stream was installed when this module was imported.
            file = sys.stdout
        keys = sorted(self.__dict__)
        _print_attrs(self, keys, file)
class LocalConfig(object):

    """Thread-local configuration of Chainer.

    This class implements the local configuration. When a value is set to this
    object, the configuration is only updated in the current thread. When a
    user tries to access an attribute and there is no local value, it
    automatically retrieves a value from the global configuration.

    """

    def __init__(self, global_config):
        # Bypass our own __setattr__, which would otherwise put these two
        # attributes into the thread-local namespace instead of the instance.
        super(LocalConfig, self).__setattr__('_global', global_config)
        super(LocalConfig, self).__setattr__('_local', threading.local())

    def __delattr__(self, name):
        # Removes only the current thread's override; afterwards reads of
        # ``name`` fall back to the global configuration again.
        delattr(self._local, name)

    def __getattr__(self, name):
        # __getattr__ runs only after normal lookup fails. An instance created
        # without __init__ (e.g. by copy.copy or unpickling) lacks
        # '_local'/'_global'. Looking them up here would re-enter
        # __getattr__ forever, so fail with AttributeError instead.
        if name in ('_global', '_local'):
            raise AttributeError(name)
        if hasattr(self._local, name):
            return getattr(self._local, name)
        return getattr(self._global, name)

    def __setattr__(self, name, value):
        # Every assignment is a per-thread override.
        setattr(self._local, name, value)

    def show(self, file=None):
        """show(file=sys.stdout)

        Prints the config entries.

        The entries are sorted in the lexicographical order of the entry names.

        Args:
            file: Output file-like object. ``None`` (the default) means the
                ``sys.stdout`` in effect at call time, so stream redirection
                such as :func:`contextlib.redirect_stdout` is honored.

        .. admonition:: Example

           You can easily print the list of configurations used in
           the current thread.

              >>> chainer.config.show()  # doctest: +SKIP
              debug           False
              enable_backprop True
              train           True
              type_check      True

        """
        if file is None:
            # Resolve lazily rather than freezing the import-time stdout.
            file = sys.stdout
        # Union of global entries and this thread's overrides; each is then
        # read through __getattr__, so local values win.
        keys = sorted(set(self._global.__dict__) | set(self._local.__dict__))
        _print_attrs(self, keys, file)
def _print_attrs(obj, keys, file):
max_len = max(len(key) for key in keys)
for key in keys:
spacer = ' ' * (max_len - len(key))
print(u'{} {}{}'.format(key, spacer, getattr(obj, key)), file=file)
# Process-wide configuration object. The thread-local ``config`` defined
# below reads from it whenever the current thread has no override for an entry.
global_config = GlobalConfig()
'''Global configuration of Chainer.
It is an instance of :class:`chainer.configuration.GlobalConfig`.
See :ref:`configuration` for details.
'''
# Module-level LocalConfig wrapping ``global_config``. Assignments made through
# it are stored per thread; reads fall back to ``global_config``. It is also
# the default ``config`` argument of ``using_config``.
config = LocalConfig(global_config)
'''Thread-local configuration of Chainer.
It is an instance of :class:`chainer.configuration.LocalConfig`, and is
referring to :data:`~chainer.global_config` as its default configuration.
See :ref:`configuration` for details.
'''
@contextlib.contextmanager
def using_config(name, value, config=config):
    """using_config(name, value, config=chainer.config)

    Temporarily overrides one entry of a thread-local configuration.

    Inside the ``with`` block, ``config.<name>`` reads as ``value`` in the
    current thread. On exit, whether normal or via an exception, any previous
    thread-local override is put back. If there was none, the local entry is
    deleted so that reads fall back to the global configuration again.

    Args:
        name (str): Name of the configuration to change.
        value: Temporary value of the configuration entry.
        config (~chainer.configuration.LocalConfig): Configuration object.
            Chainer's thread-local configuration is used by default.

    .. seealso::
        :ref:`configuration`

    """
    # Only a *local* override needs restoring; a value that came from the
    # global config is recovered simply by deleting the local entry.
    overridden_locally = hasattr(config._local, name)
    saved = getattr(config, name) if overridden_locally else None
    setattr(config, name, value)
    try:
        yield
    finally:
        if overridden_locally:
            setattr(config, name, saved)
        else:
            delattr(config, name)