-
Notifications
You must be signed in to change notification settings - Fork 0
/
tree_util.py
441 lines (335 loc) · 11.8 KB
/
tree_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
r"""Extended utilities for tree-like data structures"""
# Public names of this module, i.e. what `from <module> import *` exports.
__all__ = [
    'Namespace',
    'Partial',
    'Static',
    'tree_mask',
    'tree_unmask',
    'tree_partition',
    'tree_combine',
    'tree_repr',
]
import jax._src.tree_util as jtu
import numpy as np
from jax import Array
from textwrap import indent
from typing import Any, Callable, Dict, Hashable, Tuple, TypeVar, Union
from warnings import warn
# Annotation placeholder for trees, that is arbitrarily nested containers of leaves.
PyTree = TypeVar('PyTree', bound=Any)
# Annotation placeholder for tree definitions, such as returned by `jtu.tree_structure`.
PyTreeDef = TypeVar('PyTreeDef')
def is_array(x: Any) -> bool:
    r"""Checks whether an object is a NumPy array or a JAX array.

    Arguments:
        x: Any object.

    Returns:
        :py:`True` if `x` is an instance of :class:`numpy.ndarray` or
        :class:`jax.Array`, :py:`False` otherwise.
    """

    return isinstance(x, (np.ndarray, Array))
class PyTreeMeta(type):
    r"""PyTree meta-class.

    Every class created with this meta-class is registered as a JAX PyTree node at
    definition time. Classes that provide a `tree_flatten_with_keys` method are
    registered with key paths, the others are registered as plain nodes.
    """

    def __new__(mcs, *args, **kwargs) -> type:
        new_class = super().__new__(mcs, *args, **kwargs)

        # Key-aware registration is only possible when the class can flatten with keys.
        register = (
            jtu.register_pytree_with_keys_class
            if hasattr(new_class, 'tree_flatten_with_keys')
            else jtu.register_pytree_node_class
        )
        register(new_class)

        return new_class
class Namespace(metaclass=PyTreeMeta):
    r"""PyTree class for name-value mappings.

    The values are stored as instance attributes. When the tree is flattened, the
    attribute values are the children, ordered by attribute name, and the sorted names
    are kept as auxiliary data.

    Arguments:
        kwargs: A name-value mapping.

    Example:
        >>> tree = Namespace(a=1, b='2'); tree
        Namespace(
         a = 1,
         b = '2'
        )
        >>> tree.c = [3, False]; tree
        Namespace(
         a = 1,
         b = '2',
         c = [3, False]
        )
        >>> jax.tree_util.tree_leaves(tree)
        [1, '2', 3, False]
    """

    def __init__(self, **kwargs):
        vars(self).update(kwargs)

    def __repr__(self) -> str:
        return tree_repr(self)

    def tree_repr(self, **kwargs) -> str:
        r"""Represents the namespace with one `name = value` entry per line.

        The keyword arguments are forwarded to :func:`tree_repr` for each value.
        """

        entries = [
            f'{name} = {tree_repr(getattr(self, name), **kwargs)}'
            for name in sorted(vars(self))
        ]

        # An empty namespace is represented on a single line.
        if not entries:
            return f'{self.__class__.__name__}()'

        body = indent(',\n'.join(entries), ' ')

        return f'{self.__class__.__name__}(\n{body}\n)'

    def tree_flatten(self):
        r"""Returns the attribute values and names, both ordered by name."""

        items = sorted(self.__dict__.items())
        names = tuple(name for name, _ in items)
        values = tuple(value for _, value in items)

        return values, names

    def tree_flatten_with_keys(self):
        r"""Same as `tree_flatten`, but each value is paired with an attribute key."""

        values, names = self.tree_flatten()
        children = [(jtu.GetAttrKey(name), value) for name, value in zip(names, values)]

        return children, names

    @classmethod
    def tree_unflatten(cls, names, values):
        r"""Rebuilds an instance from names and values, without calling `__init__`."""

        instance = object.__new__(cls)
        instance.__dict__ = dict(zip(names, values))

        return instance
class Partial(Namespace):
    r"""A version of :class:`functools.partial` that is a PyTree.

    The function and the stored arguments are regular attributes (`func`, `args` and
    `kwds`) of the namespace. Like with :class:`functools.partial`, keyword arguments
    provided at call time extend and override the stored ones.

    Arguments:
        func: A function.
        args: Positional arguments for future calls.
        kwds: Keyword arguments for future calls.

    Examples:
        >>> increment = Partial(jax.numpy.add, 1)
        >>> increment(2)
        Array(3, dtype=int32, weak_type=True)
        >>> println = Partial(print, sep='\n')
        >>> println('Hello', 'World!')
        Hello
        World!
        >>> println('Hello', 'World!', sep=' ')
        Hello World!
    """

    def __init__(self, func: Callable, *args: Any, **kwds: Any):
        self.func = func
        self.args = args
        self.kwds = kwds

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        r"""Calls `func` with the stored arguments followed by the new ones.

        Arguments:
            args: Positional arguments appended to the stored ones.
            kwds: Keyword arguments, taking precedence over the stored ones.

        Returns:
            The result of the call.
        """

        # Merge before expanding: `f(**a, **b)` raises a TypeError on duplicated
        # keywords, whereas call-time keywords are expected to override stored ones.
        return self.func(*self.args, *args, **{**self.kwds, **kwds})
class Static(metaclass=PyTreeMeta):
    r"""Wraps a hashable value as a leafless PyTree.

    The wrapped value is returned as auxiliary data when the tree is flattened, such
    that it belongs to the tree definition instead of the leaves.

    Arguments:
        value: A hashable value to wrap.

    Example:
        >>> tree = Static((0, 'one', None))
        >>> tree.value
        (0, 'one', None)
        >>> jax.tree_util.tree_leaves(tree)
        []
        >>> jax.tree_util.tree_structure(tree)
        PyTreeDef(CustomNode(Static[(0, 'one', None)], []))
    """

    def __init__(self, value: Hashable):
        if not isinstance(value, Hashable):
            warn(f"considering a non-hashable object ('{type(value).__name__}') static could lead to frequent JIT recompilations.")  # fmt: off

        if callable(value):
            # Callable instances (e.g. functools.partial objects) have no
            # `__qualname__`, hence the attribute cannot be accessed unconditionally.
            qualname = getattr(value, '__qualname__', '')

            if '<lambda>' in qualname or '<locals>' in qualname:
                warn(f"considering a local function ('{qualname}') static could lead to frequent JIT recompilations.")  # fmt: off

        self.value = value

    def __eq__(self, other: Any) -> bool:
        return type(self) is type(other) and self.value == other.value

    def __hash__(self) -> int:
        return hash((type(self), self.value))

    def __repr__(self) -> str:
        return self.tree_repr()

    def tree_repr(self, **kwargs) -> str:
        r"""Represents the wrapper as the class name around the wrapped value."""

        return f'{self.__class__.__name__}({tree_repr(self.value, **kwargs)})'

    def tree_flatten(self):
        r"""Returns no children and the wrapped value as auxiliary data."""

        return (), self.value

    def tree_flatten_with_keys(self):
        # There is no child, hence no key to provide.
        return self.tree_flatten()

    @classmethod
    def tree_unflatten(cls, value, _):
        r"""Rebuilds an instance from its auxiliary data, without calling `__init__`."""

        self = object.__new__(cls)
        self.value = value

        return self
def tree_mask(
    tree: PyTree,
    is_static: Callable[[Any], bool] = None,
) -> PyTree:
    r"""Masks the static leaves of a tree.

    The structure of the tree remains unchanged, but leaves that are considered static
    are wrapped into a :class:`Static` instance, which hides them from
    :func:`jax.tree_util.tree_leaves` and :func:`jax.tree_util.tree_map`.

    See also:
        :func:`tree_unmask`

    Arguments:
        tree: The tree to mask.
        is_static: A predicate for what to consider static. If :py:`None`,
            all non-array leaves are considered static.

    Returns:
        The masked tree.

    Example:
        >>> tree = [1, jax.numpy.arange(2), 'three']
        >>> jax.tree_util.tree_leaves(tree)
        [1, Array([0, 1], dtype=int32), 'three']
        >>> tree = tree_mask(tree); tree
        [Static(1), Array([0, 1], dtype=int32), Static('three')]
        >>> jax.tree_util.tree_leaves(tree)
        [Array([0, 1], dtype=int32)]
    """

    # By default, everything that is not an array is static.
    predicate = is_static if is_static is not None else (lambda x: not is_array(x))

    def mask(leaf):
        if predicate(leaf):
            return Static(leaf)

        return leaf

    return jtu.tree_map(mask, tree)
def tree_unmask(tree: PyTree) -> PyTree:
    r"""Unmasks the static leaves of a masked tree.

    Each :class:`Static` instance in the tree is replaced by the value it wraps, while
    the other leaves are left untouched.

    See also:
        :func:`tree_mask`

    Arguments:
        tree: The masked tree to unmask.

    Returns:
        The unmasked tree.

    Example:
        >>> tree = [Static(1), jax.numpy.arange(2), Static('three')]
        >>> tree_unmask(tree)
        [1, Array([0, 1], dtype=int32), 'three']
    """

    def is_masked(node):
        return type(node) is Static

    def unwrap(node):
        return node.value if is_masked(node) else node

    # Static instances have no leaves, hence they must be treated as leaves to be mapped.
    return jtu.tree_map(unwrap, tree, is_leaf=is_masked)
def tree_partition(
    tree: PyTree,
    *filters: Union[type, Callable[[Any], bool]],
    is_leaf: Callable[[Any], bool] = None,
) -> Tuple[PyTreeDef, Dict[str, Any]]:
    r"""Flattens a tree and partitions the leaves.

    The leaves are partitioned into a set of path-leaf mappings. Each mapping contains
    the leaves of the subset of nodes satisfying the corresponding filtering constraint.
    The last mapping is dedicated to leaves that do not satisfy any constraint. If a
    node satisfies several constraints, the first one prevails.

    See also:
        :func:`tree_combine`

    Arguments:
        tree: The tree to flatten.
        filters: A set of filtering constraints. Types are transformed into
            :py:`isinstance` constraints.
        is_leaf: A predicate for what to consider as a leaf.

    Returns:
        The tree definition and leaf partitions.

    Example:
        >>> tree = Namespace(a=1, b=jax.numpy.arange(2), c=['three', False])
        >>> treedef, leaves = tree_partition(tree)
        >>> leaves
        {'.a': 1, '.b': Array([0, 1], dtype=int32), '.c[0]': 'three', '.c[1]': False}
        >>> treedef, arrays, others = tree_partition(tree, jax.Array)
        >>> arrays
        {'.b': Array([0, 1], dtype=int32)}
        >>> others
        {'.a': 1, '.c[0]': 'three', '.c[1]': False}
    """

    treedef = jtu.tree_structure(tree, is_leaf)

    def as_predicate(constraint):
        if not isinstance(constraint, type):
            return constraint

        return lambda x: isinstance(x, constraint)

    predicates = [as_predicate(constraint) for constraint in filters]

    # One mapping per constraint, plus a last one for the unmatched leaves.
    partitions = [{} for _ in range(len(predicates) + 1)]

    def matches_any(x):
        for predicate in predicates:
            if predicate(x):
                return True

        return False

    def is_node(x):
        # The traversal stops at nodes matching a constraint, as well as at user leaves.
        return matches_any(x) or (is_leaf is not None and is_leaf(x))

    for path, node in jtu.tree_leaves_with_path(tree, is_node):
        # Index of the first satisfied constraint, or of the last mapping if none.
        index = next((i for i, p in enumerate(predicates) if p(node)), -1)
        bucket = partitions[index]

        # A matched node can be a sub-tree, whose leaves all go to the same mapping.
        for subpath, leaf in jtu.tree_leaves_with_path(node, is_leaf):
            bucket[jtu.keystr(path + subpath)] = leaf

    return (treedef, *partitions)
def tree_combine(
    treedef: PyTreeDef,
    *leaves: Dict[str, Any],
) -> PyTree:
    r"""Reconstructs a tree from the tree definition and leaf partitions.

    The partitions are merged into a single path-leaf mapping, which must contain
    exactly one entry for each leaf of the tree definition.

    See also:
        :func:`tree_partition`

    Arguments:
        treedef: The tree definition.
        leaves: The set of leaf partitions.

    Returns:
        The reconstructed tree.

    Raises:
        KeyError: If the path of a leaf is missing from the partitions, or if the
            partitions contain a path that is not in the tree definition.

    Example:
        >>> tree = Namespace(a=1, b=jax.numpy.arange(2), c=['three', False])
        >>> treedef, arrays, others = tree_partition(tree, jax.Array)
        >>> others = {key: str(leaf).upper() for key, leaf in others.items()}
        >>> tree_combine(treedef, arrays, others)
        Namespace(
         a = '1',
         b = int32[2],
         c = ['THREE', 'FALSE']
        )
    """

    merged = {}

    for partition in leaves:
        merged.update(partition)

    absent = []

    def fill(path, placeholder):
        key = jtu.keystr(path)

        try:
            # Consumed entries are removed, such that leftovers are unexpected keys.
            return merged.pop(key)
        except KeyError:
            absent.append(key)
            return placeholder

    # Build a tree of placeholders to recover the path of each leaf.
    skeleton = jtu.tree_unflatten(treedef, [object()] * treedef.num_leaves)
    tree = jtu.tree_map_with_path(fill, skeleton)

    if absent:
        keys = ', '.join(f'"{key}"' for key in absent)
        raise KeyError(f"Missing key(s) in leaves: {keys}.")

    if merged:
        keys = ', '.join(f'"{key}"' for key in merged)
        raise KeyError(f"Unexpected key(s) in leaves: {keys}.")

    return tree
def tree_repr(
    tree: PyTree,
    linewidth: int = 88,
    typeonly: bool = True,
    **kwargs,
) -> str:
    r"""Creates a pretty representation of a tree.

    Objects that implement a `tree_repr` method are represented by that method, which
    receives all the options as keyword arguments. Tuples, named tuples, lists and dicts
    are represented recursively. Anything else is represented by :py:`repr`.

    Arguments:
        tree: The tree to represent.
        linewidth: The maximum line width before elements of tuples, lists and dicts
            are represented on separate lines.
        typeonly: Whether to represent the type of arrays instead of their elements.
        kwargs: Additional options forwarded to the `tree_repr` methods.

    Returns:
        The representation string.

    Example:
        >>> tree = [1, 'two', (True, False), list(range(5)), {'6': jnp.arange(7)}]
        >>> print(tree_repr(tree))
        [1, 'two', (True, False), [0, 1, 2, 3, 4], {'6': int32[7]}]
        >>> print(tree_repr(tree, linewidth=32))
        [
         1,
         'two',
         (True, False),
         [0, 1, 2, 3, 4],
         {'6': int32[7]}
        ]
    """

    kwargs.update(
        linewidth=linewidth,
        typeonly=typeonly,
    )

    # Appended after the elements, only used for one-element tuples.
    trailing = ''

    # A class (as opposed to an instance) only exposes `tree_repr` as an unbound
    # function, which cannot be called without an instance.
    if hasattr(tree, 'tree_repr') and not isinstance(tree, type):
        return tree.tree_repr(**kwargs)
    elif isinstance(tree, tuple):
        if hasattr(tree, '_fields'):  # named tuple
            bra, ket = f'{type(tree).__name__}(', ')'
            lines = [
                f'{field}={tree_repr(value, **kwargs)}'
                for field, value in tree._asdict().items()
            ]
        else:
            bra, ket = '(', ')'
            lines = [tree_repr(x, **kwargs) for x in tree]

            # Without the trailing comma, `(x,)` would read as a parenthesized `x`.
            if len(lines) == 1:
                trailing = ','
    elif isinstance(tree, list):
        bra, ket = '[', ']'
        lines = [tree_repr(x, **kwargs) for x in tree]
    elif isinstance(tree, dict):
        bra, ket = '{', '}'
        lines = [f'{repr(key)}: {tree_repr(value, **kwargs)}' for key, value in tree.items()]
    elif typeonly and is_array(tree):
        return f'{tree.dtype}{list(tree.shape)}'
    else:
        return repr(tree).strip(' \n')

    # Elements are put on separate lines if one of them is already multi-line, or if
    # their total width (separators included) exceeds the line width.
    multiline = any('\n' in line for line in lines)
    toolong = sum(len(line) + 2 for line in lines) > linewidth

    if multiline or toolong:
        lines = ',\n'.join(lines)
    else:
        lines = ', '.join(lines)

    lines = lines + trailing

    if '\n' in lines:
        lines = '\n' + indent(lines, ' ') + '\n'

    return f'{bra}{lines}{ket}'