-
Notifications
You must be signed in to change notification settings - Fork 0
/
tree_util.py
270 lines (206 loc) · 6.42 KB
/
tree_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
r"""Extended utilities for tree-like data structures.

This module provides PyTree container classes registered with
:mod:`jax.tree_util` (:class:`Namespace`, :class:`Static`, :class:`Auto`)
and a pretty-printer for nested trees (:func:`tree_repr`).
"""

# Public API of this module.
__all__ = [
    'Namespace',
    'Static',
    'Auto',
    'tree_repr',
]
import jax
import jax._src.tree_util as jtu
import numpy as np
from jax import Array
from textwrap import indent
from typing import *
from warnings import warn
# Type placeholder for an arbitrary tree of containers and leaves, used in
# annotations only.
PyTree = TypeVar('PyTree', bound=Any)
# Type placeholder for a tree-structure object; not referenced in the code
# visible here — presumably used by annotations elsewhere.
PyTreeDef = TypeVar('PyTreeDef')
def is_array(x: Any) -> bool:
    r"""Returns whether `x` is a NumPy ``ndarray`` or a JAX ``Array``."""
    return isinstance(x, (np.ndarray, Array))
def is_static(x: Any) -> bool:
    r"""Returns whether `x` is a :class:`Static` wrapper (or subclass instance)."""
    return isinstance(x, Static)
def is_auto(x: Any) -> bool:
    r"""Returns whether `x` is an :class:`Auto` namespace (or subclass instance)."""
    return isinstance(x, Auto)
class PyTreeMeta(type):
    r"""PyTree meta-class.

    Every class created with this metaclass (including subclasses) is
    registered with JAX as a PyTree node at class-creation time:

    * classes exposing ``tree_flatten_with_keys`` are registered through
      ``register_pytree_with_keys_class`` (key-path aware flattening);
    * all other classes are registered through ``register_pytree_node_class``,
      which relies on their ``tree_flatten``/``tree_unflatten`` methods.
    """

    def __new__(cls, *args, **kwargs) -> type:
        # `cls` is the metaclass itself; `new_cls` is the class being created.
        new_cls = super().__new__(cls, *args, **kwargs)
        register = (
            jtu.register_pytree_with_keys_class
            if hasattr(new_cls, 'tree_flatten_with_keys')
            else jtu.register_pytree_node_class
        )
        register(new_cls)
        return new_cls
class Namespace(metaclass=PyTreeMeta):
    r"""PyTree class for name-value mappings.

    Attributes live in the instance ``__dict__``. When flattened, the
    attribute values become the children of the node, ordered by attribute
    name, and the sorted names become the node's auxiliary data.

    Arguments:
        kwargs: A name-value mapping.

    Example:
        >>> ns = Namespace(a=1, b='2'); ns
        Namespace(
         a = 1,
         b = '2'
        )
        >>> ns.c = [3, False]; ns
        Namespace(
         a = 1,
         b = '2',
         c = [3, False]
        )
        >>> jax.tree_util.tree_leaves(ns)
        [1, '2', 3, False]
    """

    def __init__(self, **kwargs):
        # Write straight into the instance dictionary (bypasses setattr hooks).
        vars(self).update(kwargs)

    def __repr__(self) -> str:
        return tree_repr(self)

    def tree_repr(self, **kwargs) -> str:
        # One "name = value" entry per attribute, in sorted name order; the
        # options in `kwargs` are forwarded to the module-level `tree_repr`.
        entries = [
            f'{name} = {tree_repr(getattr(self, name), **kwargs)}'
            for name in sorted(self.__dict__)
        ]
        body = ',\n'.join(entries)
        if body:
            # Non-empty namespaces list each entry on its own indented line.
            body = '\n' + indent(body, ' ') + '\n'
        return f'{self.__class__.__name__}({body})'

    def tree_flatten(self):
        pairs = sorted(self.__dict__.items())
        names = tuple(name for name, _ in pairs)
        values = tuple(value for _, value in pairs)
        return values, names

    def tree_flatten_with_keys(self):
        # Reuse `tree_flatten` (possibly overridden by subclasses) and attach
        # attribute keys so JAX key paths read like `.name`.
        values, names = self.tree_flatten()
        children = [
            (jtu.GetAttrKey(name), value)
            for name, value in zip(names, values)
        ]
        return children, names

    @classmethod
    def tree_unflatten(cls, names, values):
        # Bypass `__init__` so subclasses with custom constructors can still
        # be rebuilt from their flattened form.
        instance = object.__new__(cls)
        instance.__dict__ = dict(zip(names, values))
        return instance
class Static(metaclass=PyTreeMeta):
    r"""Wraps a hashable value as a leafless PyTree.

    The wrapped value is stored as the auxiliary data of the tree node, so it
    belongs to the tree structure rather than to its leaves. Two wrappers of
    the same type compare equal when their values compare equal, consistent
    with :meth:`__hash__`.

    Arguments:
        value: A hashable value to wrap. Unhashable values are accepted, but a
            warning is emitted.

    Example:
        >>> x = Static((0, 'one', None))
        >>> x.value
        (0, 'one', None)
        >>> jax.tree_util.tree_leaves(x)
        []
        >>> jax.tree_util.tree_structure(x)
        PyTreeDef(CustomNode(Static[(0, 'one', None)], []))
        >>> Static(1) == Static(1)
        True
    """

    def __init__(self, value: Hashable):
        if not isinstance(value, Hashable):
            # stacklevel=2 attributes the warning to the code creating the wrapper.
            warn(f"'{type(value).__name__}' object is not hashable.", stacklevel=2)
        self.value = value

    def __hash__(self) -> int:
        return hash((type(self), self.value))

    def __eq__(self, other: Any) -> bool:
        # Equality must agree with `__hash__` (type + value); without this,
        # equality falls back to identity and equal wrappers never match as
        # dict/set keys or cached hashable arguments.
        if type(other) is not type(self):
            return NotImplemented
        return self.value == other.value

    def __repr__(self) -> str:
        return self.tree_repr()

    def tree_repr(self, **kwargs) -> str:
        return f'{self.__class__.__name__}({tree_repr(self.value, **kwargs)})'

    def tree_flatten(self):
        # No children: the wrapped value travels as auxiliary data.
        return (), self.value

    @classmethod
    def tree_unflatten(cls, value, _):
        self = object.__new__(cls)
        self.value = value
        return self
class Auto(Namespace):
    r"""Subclass of :class:`Namespace` that automatically detects non-array leaves
    and considers them as static.

    When flattened, every leaf that is neither an array nor a plain
    :py:`object()` is wrapped in a :class:`Static` node, so it becomes part of
    the tree structure. When unflattened, those wrappers are removed again.
    Nested :class:`Auto` instances are left untouched, so JAX flattens them
    with their own rules.

    Important:
        :py:`object()` leaves are never considered static.

    Note:
        NOTE(review): a :class:`Static` stored explicitly in an :class:`Auto`
        cannot be told apart from the wrappers added here, so it comes back
        unwrapped (its bare value) after a flatten/unflatten round trip.

    Arguments:
        kwargs: A name-value mapping.

    Example:
        >>> auto = Auto(a=1, b=jnp.array(2.0)); auto
        Auto(
         a = 1,
         b = float32[]
        )
        >>> auto.c = ['3', jnp.arange(4)]; auto
        Auto(
         a = 1,
         b = float32[],
         c = ['3', int32[4]]
        )
        >>> jax.tree_util.tree_leaves(auto) # only arrays
        [Array(2., dtype=float32, weak_type=True), Array([0, 1, 2, 3], dtype=int32)]
    """

    def tree_flatten(self):
        values, names = super().tree_flatten()

        def wrap(leaf):
            # Keep object() sentinels, arrays and nested Auto nodes dynamic;
            # everything else is hidden inside a leafless Static node.
            if type(leaf) is object or is_array(leaf) or is_auto(leaf):
                return leaf
            return Static(leaf)

        wrapped = jtu.tree_map(wrap, values, is_leaf=is_auto)
        return wrapped, names

    @classmethod
    def tree_unflatten(cls, names, values):
        def unwrap(leaf):
            return leaf.value if is_static(leaf) else leaf

        def stop_at(node):
            # Static nodes have no children, so they must be treated as
            # leaves here or `unwrap` would never see them.
            return is_auto(node) or is_static(node)

        unwrapped = jtu.tree_map(unwrap, values, is_leaf=stop_at)
        return super().tree_unflatten(names, unwrapped)
def tree_repr(
    x: PyTree,
    /,
    linewidth: int = 88,
    typeonly: bool = True,
    **kwargs,
) -> str:
    r"""Creates a pretty representation of a tree.

    Objects defining a ``tree_repr`` method are delegated to. Tuples, lists and
    dicts are rendered recursively, with the same options forwarded to every
    nested element (dict keys and values included). Arrays are rendered as
    ``dtype[shape]`` when `typeonly` is set. Anything else falls back to
    :func:`repr`.

    Arguments:
        x: The tree to represent.
        linewidth: The maximum line width before elements of tuples, lists and dicts
            are represented on separate lines.
        typeonly: Whether to represent the type of arrays instead of their elements.
        kwargs: Extra keyword arguments forwarded to nested ``tree_repr`` calls.

    Returns:
        The representation string.

    Example:
        >>> tree = [1, 'two', (True, False), list(range(5)), {'6': jnp.arange(7)}]
        >>> print(tree_repr(tree, linewidth=32))
        [
         1,
         'two',
         (True, False),
         [0, 1, 2, 3, 4],
         {'6': int32[7]}
        ]
    """
    kwargs.update(
        linewidth=linewidth,
        typeonly=typeonly,
    )

    if hasattr(x, 'tree_repr'):
        return x.tree_repr(**kwargs)
    elif isinstance(x, tuple):
        bra, ket = '(', ')'
        lines = [tree_repr(y, **kwargs) for y in x]
    elif isinstance(x, list):
        bra, ket = '[', ']'
        lines = [tree_repr(y, **kwargs) for y in x]
    elif isinstance(x, dict):
        bra, ket = '{', '}'
        # Forward the options so dict contents honour `typeonly`/`linewidth`
        # like tuple and list elements do.
        lines = [
            f'{tree_repr(key, **kwargs)}: {tree_repr(value, **kwargs)}'
            for key, value in x.items()
        ]
    elif is_array(x):
        if typeonly:
            return f'{x.dtype}{list(x.shape)}'
        else:
            return repr(x)
    else:
        return repr(x).strip(' \n')

    # Break across lines if any element already spans several lines or the
    # flat rendering (element + ", " separator each) exceeds `linewidth`.
    if any('\n' in line for line in lines) or sum(len(line) + 2 for line in lines) > linewidth:
        lines = ',\n'.join(lines)
    else:
        lines = ', '.join(lines)

    # A one-element tuple needs a trailing comma, e.g. `(1,)`, otherwise it is
    # indistinguishable from a parenthesized scalar `(1)`.
    if isinstance(x, tuple) and len(x) == 1:
        lines += ','

    if '\n' in lines:
        lines = '\n' + indent(lines, ' ') + '\n'

    return f'{bra}{lines}{ket}'