-
Notifications
You must be signed in to change notification settings - Fork 88
/
utils.py
266 lines (212 loc) · 7.87 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
"""
Copyright (C) 2013 Stefan Pfenninger.
Licensed under the Apache 2.0 License (see LICENSE file).
utils.py
~~~~~~~~
Various utility functions, particularly the AttrDict class (a subclass
of regular dict) used for managing model configuration.
"""
from __future__ import print_function
from __future__ import division
from contextlib import contextmanager
from cStringIO import StringIO
from functools import partial
import yaml
class _AttrDictKeyError(KeyError, AttributeError):
    """Raised when a missing key is accessed as an attribute of an AttrDict.

    Deriving from ``KeyError`` keeps existing ``except KeyError`` handlers
    working; deriving from ``AttributeError`` lets ``hasattr()``,
    ``getattr(obj, name, default)`` and ``copy.deepcopy()`` treat the
    missing key as a missing attribute, which is what they expect.
    """


class AttrDict(dict):
    """A subclass of ``dict`` with key access by attributes::

        d = AttrDict({'a': 1, 'b': 2})
        d.a == 1  # True

    Includes a range of additional methods to read and write to YAML,
    and to deal with nested keys of the form ``'foo.bar'``.

    Real attributes and methods (``keys``, ``get_key``, ...) take
    precedence over dict keys of the same name for attribute access;
    use item access (``d['keys']``) for such keys.
    """
    # Attribute assignment / deletion map straight onto the dict items.
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def __getattr__(self, name):
        """Return the item ``name``.

        Only called when normal attribute lookup fails. A missing key
        raises ``_AttrDictKeyError``, which is both a ``KeyError`` and an
        ``AttributeError``.
        """
        try:
            return self[name]
        except KeyError:
            raise _AttrDictKeyError(name)

    def __init__(self, source_dict=None):
        """Create an AttrDict, optionally populated from ``source_dict``.

        Anything that is not a dict (including the default None) yields
        an empty AttrDict.
        """
        super(AttrDict, self).__init__()
        if isinstance(source_dict, dict):
            self.init_from_dict(source_dict)

    def init_from_dict(self, d):
        """Initialize a new AttrDict from the given dict. Handles any
        nested dicts by turning them into AttrDicts too::

            d = AttrDict({'a': 1, 'b': {'x': 1, 'y': 2}})
            d.b.x == 1  # True

        Integer keys are converted to strings. Non-dict values are stored
        via ``set_key()``, so a dotted key such as ``'a.b'`` becomes a
        nested entry.
        """
        for k, v in d.items():
            if isinstance(k, int):
                # Keys must be strings: set_key() does ``'.' in key``
                k = str(k)
            if isinstance(v, dict):
                self[k] = AttrDict(v)
            else:
                self.set_key(k, v)

    @classmethod
    def from_yaml(cls, f):
        """Returns an AttrDict initialized from the given path or
        file object ``f``, which must point to a YAML file.

        Anything with a ``read`` method is treated as an open file;
        everything else is treated as a path and opened for reading.
        """
        # NOTE(review): yaml.Loader can construct arbitrary Python objects
        # (e.g. ``!!python/object`` tags). Only load trusted files; consider
        # yaml.SafeLoader if configs never need Python-specific tags.
        # The Loader is passed explicitly because PyYAML >= 6 raises
        # TypeError for yaml.load(stream) without one; yaml.Loader matches
        # the historical default.
        if hasattr(f, 'read'):
            return cls(yaml.load(f, Loader=yaml.Loader))
        with open(f, 'r') as src:
            return cls(yaml.load(src, Loader=yaml.Loader))

    @classmethod
    def from_yaml_string(cls, string):
        """Returns an AttrDict initialized from the given string, which
        must be valid YAML.
        """
        # Explicit Loader: required by PyYAML >= 6 (see from_yaml note on
        # yaml.Loader and untrusted input).
        return cls(yaml.load(string, Loader=yaml.Loader))

    def set_key(self, key, value):
        """Set the given ``key`` to the given ``value``. Handles nested
        keys, e.g.::

            d = AttrDict()
            d.set_key('foo.bar', 1)
            d.foo.bar == 1  # True

        Missing or None intermediate entries are replaced by new AttrDicts,
        and plain-dict intermediates are converted to AttrDicts.

        Raises:
            UserWarning: if an intermediate entry holds a non-dict value,
                which is not overwritten.
        """
        if '.' not in key:
            self[key] = value
            return
        head, remainder = key.split('.', 1)
        sub = self.get(head)
        if sub is None:
            # Missing key, or an explicit None placeholder: start a subdict
            sub = self[head] = AttrDict()
        elif not isinstance(sub, AttrDict):
            if isinstance(sub, dict):
                # Plain dict (e.g. assigned directly): upgrade so nesting works
                sub = self[head] = AttrDict(sub)
            else:
                # Something else is there; don't silently clobber it
                raise UserWarning('Cannot set nested key on non-dict key.')
        sub.set_key(remainder, value)

    def get_key(self, key, default=None):
        """Looks up the given ``key``. Like set_key(), deals with nested
        keys of the form ``'foo.bar'``, descending through AttrDicts and
        plain dicts alike.

        If ``default`` is None (the default), a key that cannot be
        resolved raises KeyError. Any other ``default`` (e.g. False or 0)
        is returned instead when the key is missing at any level, or when
        the path runs through a non-dict value.
        """
        try:
            value = self
            for part in key.split('.'):
                if not isinstance(value, dict):
                    # Cannot descend into a scalar: the key does not exist
                    raise KeyError(key)
                value = value[part]
            return value
        except KeyError:
            if default is None:
                raise
            return default

    def as_dict(self):
        """Return the AttrDict as a pure dict (with nested dicts if
        necessary).
        """
        d = {}
        for k, v in self.items():
            if isinstance(v, AttrDict):
                d[k] = v.as_dict()
            else:
                d[k] = v
        return d

    def to_yaml(self, path):
        """Saves the AttrDict to the given path as YAML file"""
        with open(path, 'w') as f:
            yaml.dump(self.as_dict(), f)

    def keys_nested(self, subkeys_as='list'):
        """Returns all keys in the AttrDict, including the keys of
        nested subdicts (which may be either regular dicts or AttrDicts).

        If ``subkeys_as='list'`` (default), then a (sorted) list of
        all keys is returned, in the form ``['a', 'b.b1']``.

        If ``subkeys_as='dict'``, a list containing keys and dicts of
        subkeys is returned, in the form ``['a', {'b': ['b1']}]``. The list
        is sorted (subdicts first, by their key, then string keys).

        Raises:
            ValueError: if ``subkeys_as`` is neither 'list' nor 'dict'.
        """
        if subkeys_as not in ('list', 'dict'):
            raise ValueError("subkeys_as must be 'list' or 'dict', got %r"
                             % (subkeys_as,))
        subdicts = []  # {key: [subkeys]} entries ('dict' mode only)
        flat = []      # plain keys, plus 'parent.child' keys in 'list' mode
        for k, v in self.items():
            if isinstance(v, dict):
                if not isinstance(v, AttrDict):
                    v = AttrDict(v)  # plain dicts have no keys_nested()
                if subkeys_as == 'list':
                    flat.extend(k + '.' + kk for kk in v.keys_nested())
                else:
                    subdicts.append({k: v.keys_nested(subkeys_as='dict')})
            else:
                flat.append(k)
        # Sort explicitly instead of relying on Python 2's cross-type
        # ordering (dict < str), which Python 3 rejects.
        subdicts.sort(key=lambda entry: list(entry.keys())[0])
        return subdicts + sorted(flat)

    def union(self, other):
        """
        Merges the AttrDict in-place with the passed ``other`` dict or
        AttrDict. Keys in ``other`` take precedence, and nested keys
        are properly handled.
        """
        if not isinstance(other, AttrDict):
            # Plain dicts lack keys_nested()/get_key()
            other = AttrDict(other)
        for k in other.keys_nested():
            self.set_key(k, other.get_key(k))
@contextmanager
def capture_output():
    """Capture stdout and stderr output produced inside the ``with`` block::

        with capture_output() as out:
            # do things that create stdout or stderr output

    While the block runs, ``out`` is ``[stdout_buffer, stderr_buffer]``.
    On exit (normal or via an exception) the original streams are restored
    and the two entries are replaced in place by the captured strings, so
    afterwards ``out == [stdout_text, stderr_text]``.
    """
    import sys
    old_out, old_err = sys.stdout, sys.stderr
    # Build the buffers before the try so the finally clause can always
    # refer to ``out``.
    out = [StringIO(), StringIO()]
    try:
        # out[0] receives stdout, out[1] receives stderr
        sys.stdout, sys.stderr = out
        yield out
    finally:
        sys.stdout, sys.stderr = old_out, old_err
        out[0] = out[0].getvalue()
        out[1] = out[1].getvalue()
def memoize(f):
    """Memoization decorator for functions called with positional arguments.

    Results are stored in a dict keyed by the tuple of positional
    arguments, so every argument must be hashable. ``f`` is called at most
    once per distinct argument tuple.
    """
    class _ResultCache(dict):
        def __missing__(self, args):
            # First call with these arguments: compute and remember
            self[args] = result = f(*args)
            return result

        def __getitem__(self, *args):
            # Collect all positional arguments into one tuple key
            return super(_ResultCache, self).__getitem__(args)

    cache = _ResultCache()
    return cache.__getitem__
class memoize_instancemethod(object):
    """Cache the return value of a method on the instance it was called on.

    Source: http://code.activestate.com/recipes/577452/

    This class is meant to be used as a decorator of methods. The return
    value from a given method invocation will be cached on the instance
    whose method was invoked. All arguments passed to a method decorated
    with memoize must be hashable.

    If a memoized method is invoked directly on its class the result
    will not be cached. Instead the method will be invoked like a
    static method.
    """
    def __init__(self, func):
        from functools import update_wrapper
        self.func = func
        # Copy __name__, __doc__, __module__ etc. from the decorated method
        # so introspection and documentation tools see the method instead
        # of this decorator class.
        update_wrapper(self, func)

    def __get__(self, obj, objtype=None):
        if obj is None:
            # Accessed on the class: return the plain, uncached function
            return self.func
        # Bind the instance as the first argument passed to __call__
        return partial(self, obj)

    def __call__(self, *args, **kw):
        obj = args[0]
        # Per-instance cache dict. The double-underscore name is mangled to
        # ``_memoize_instancemethod__cache``, keeping it apart from the
        # instance's own attributes.
        try:
            cache = obj.__cache
        except AttributeError:
            cache = obj.__cache = {}
        # The key includes the function, so all memoized methods of one
        # instance can share the same cache dict.
        key = (self.func, args[1:], frozenset(kw.items()))
        try:
            res = cache[key]
        except KeyError:
            res = cache[key] = self.func(*args, **kw)
        return res
def replace(string, placeholder, replacement):
    """Substitute ``replacement`` for every ``{{ placeholder }}`` and
    ``{{placeholder}}`` occurrence in ``string``.

    The spaced form is replaced first, then the compact form. Other
    spacings (e.g. ``{{placeholder }}``) are left untouched.
    """
    result = string
    delimiters = (('{{ ', ' }}'), ('{{', '}}'))
    for opening, closing in delimiters:
        token = opening + placeholder + closing
        result = result.replace(token, replacement)
    return result