-
Notifications
You must be signed in to change notification settings - Fork 0
/
_tree.py
324 lines (257 loc) · 11 KB
/
_tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
"""Tools for manipulation of PyTrees.
:copyright: Copyright 2023-2024 by Matt L Laporte.
:license: Apache 2.0, see LICENSE for details.
"""
from collections.abc import Callable, Sequence
from functools import partial
import logging
from typing import Any, Optional, Tuple, TypeVar, TypeVarTuple, get_type_hints
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
from jaxtyping import Array, ArrayLike, PRNGKeyArray, PyTree, PyTreeDef, Shaped
logger = logging.getLogger(__name__)
T = TypeVar("T")
def filter_spec_leaves(
tree: PyTree[Any, "T"], leaf_func: Callable
) -> PyTree[bool, "T"]:
"""Returns a filter specification for tree leaves matching `leaf_func`."""
filter_spec = jax.tree_util.tree_map(lambda _: False, tree)
filter_spec = eqx.tree_at(
leaf_func,
filter_spec,
replace_fn=lambda x: True,
)
return filter_spec
def tree_index(tree: PyTree[Any, "T"], index: int) -> PyTree[Any, "T"]:
"""Returns the same PyTree, indexing out all of its array leaves."""
models_arrays, models_other = eqx.partition(tree, eqx.is_array)
return eqx.combine(
jax.tree_map(lambda x: x[index], models_arrays),
models_other,
)
def get_ensemble(
func: Callable[..., PyTree[Any, "S"]],
*args: Any,
n_ensemble: int,
key: PRNGKeyArray,
**kwargs: Any,
) -> PyTree[Any, "S"]:
"""Vmap a function over a set of random keys.
Arguments:
func: A function that returns a PyTree, and whose final argument is a key.
n_ensemble: The number of keys to split; i.e. the size of the batch
dimensions in the array leaves of the returned PyTree.
*args: The positional arguments to `func`.
key: The key to split to perform the vmap.
**kwargs: The keyword arguments to `func`.
"""
keys = jr.split(key, n_ensemble)
func_ = partial(func, *args, **kwargs)
return eqx.filter_vmap(func_)(keys)
@jax.named_scope("fbx.tree_take")
def tree_take(
tree: PyTree[Any, "T"],
indices: ArrayLike,
axis: int = 0,
**kwargs,
) -> PyTree[Any, "T"]:
"""Indexes elements out of each array leaf of a PyTree.
Any non-array leaves are returned unchanged.
Arguments:
tree: Any PyTree whose array leaves are equivalently indexable,
according to the other arguments to this function. For example,
`axis=0` could be used when the first dimension of every array leaf
is a batch dimension, and `indices` specifies a subset of examples
from the batch.
indices: The indices of the values to take from each array leaf.
axis: The axis of the array leaves over which to take their values.
Defaults to 0.
Returns:
A PyTree with the same structure as `tree`, where array leaves from `tree` have been replaced by indexed-out elements.
"""
arrays, other = eqx.partition(tree, eqx.is_array)
values = jax.tree_map(
lambda xs: jnp.take(xs, indices, axis=axis, **kwargs),
arrays,
)
return eqx.combine(values, other)
@jax.named_scope("fbx.tree_set")
def tree_set(
tree: PyTree[Any | Shaped[Array, "batch *?dims"], "T"],
items: PyTree[Any | Shaped[Array, "*?dims"], "T"],
idx: int,
) -> PyTree[Any | Shaped[Array, "batch *?dims"], "T"]:
"""Perform an out-of-place update of each array leaf of a PyTree.
Non-array leaves are simply replaced by their matching leaves in `items`.
For example, if `tree` is a PyTree of states over time, whose first dimension
is the time step, and `items` is a PyTree of states for a single time step,
this function can be used to insert the latter into the former at a given time index.
Arguments:
tree: Any PyTree whose array leaves share a first dimension of the same
length, for example a batch dimension.
items: Any PyTree with the same structure as `tree`, and whose array
leaves have the same shape as the corresponding leaves in `tree`,
but lacking the first dimension.
idx: The index along the first dimension of the array leaves of `tree`
into which to insert the array leaves of `items`.
Returns:
A PyTree with the same structure as `tree`, where the array leaves of `items` have been inserted as the `idx`-th elements of the corresponding array leaves of `tree`.
"""
arrays = eqx.filter(tree, eqx.is_array)
vals_update, other_update = eqx.partition(
items, jax.tree_map(lambda x: x is not None, arrays)
)
arrays_update = jax.tree_map(lambda xs, x: xs.at[idx].set(x), arrays, vals_update)
return eqx.combine(arrays_update, other_update)
def random_split_like_tree(
key: PRNGKeyArray,
tree: PyTree[Any, "T"],
is_leaf: Optional[Callable[[Any], bool]] = None,
) -> PyTree[PRNGKeyArray | None, "T"]:
"""Returns a split of random keys, as leaves of a target PyTree structure.
Derived from [this](https://github.com/google/jax/discussions/9508#discussioncomment-2144076) comment
on a discussion in the JAX GitHub repository.
Arguments:
key: The random key from which to split the tree of random keys.
tree: Any PyTree.
is_leaf: An optional function that decides whether each node in `tree`
should be treated as a leaf, or traversed as a subtree.
"""
treedef = jax.tree_structure(tree, is_leaf=is_leaf)
return _random_split_like_treedef(key, treedef)
def _random_split_like_treedef(
key: PRNGKeyArray,
treedef: PyTreeDef,
):
keys = jr.split(key, treedef.num_leaves)
return jax.tree_unflatten(treedef, keys)
def tree_stack(
trees: Sequence[PyTree[Any, "T"]],
axis: int = 0,
) -> PyTree[Any, "T"]:
"""Returns a PyTree whose array leaves stack those of the PyTrees in `trees`.
!!! Example
```python
a = [jnp.array([1, 2]), jnp.array([3, 4])]
b = [jnp.array([5, 6]), jnp.array([7, 8])]
tree_stack([a, b], axis=0)
# [jnp.array([[1, 2], [5, 6]]), jnp.array([[3, 4], [7, 8]])]
```
Derived from [this](https://gist.github.com/willwhitney/dd89cac6a5b771ccff18b06b33372c75?permalink_comment_id=4634557#gistcomment-4634557)
GitHub gist.
Arguments:
trees: A sequence of PyTrees with the same structure, and whose array
leaves have the same shape.
axis: The axis along which to stack the array leaves.
"""
return jax.tree_util.tree_map(lambda *v: jnp.stack(v, axis=axis), *trees)
def tree_sum_squares(tree: PyTree[Array]) -> ArrayLike:
"""Sum the sums of squares of the leaves of a PyTree."""
return jax.tree_util.tree_reduce(
lambda x, y: x + y, jax.tree_map(lambda x: jnp.sum(x**2), tree)
)
def tree_sum_n_features(tree: PyTree[Array]) -> int:
"""Returns the sum the sizes of the last dimensions of all leaves."""
return jax.tree_util.tree_reduce(
lambda x, y: x + y, jax.tree_map(lambda x: x.shape[-1], tree)
)
def _tree_map(
f: Callable[..., Any],
tree: PyTree[Any, "T"],
*rest,
is_leaf: Optional[Callable[[Any], bool]] = None,
) -> PyTree[Any, "T"]:
"""Custom version of `jax.tree_util.tree_map`.
The only difference is that by default, it will infer `is_leaf` from
the annotation of the first argument to `f`. This is useful when mapping
user-defined functions over PyTrees of user-defined objects, when it is
acceptable to have slightly worse performance of `tree_map` in exchange
for not needing to import the objects' class to be able to manually
define `is_leaf` with an `isinstance` check.
Unfortunately this doesn't work with string annotations, which I've had to
use in places due to issues with circular imports, so it is not very
useful at the moment...
"""
if is_leaf is None:
is_leaf = lambda x: isinstance(x, next(iter(f.__annotations__.values())))
return jax.tree_map(f, tree, *rest, is_leaf=is_leaf)
def _is_module(obj: Any):
return isinstance(obj, eqx.Module)
S = TypeVar("S")
def tree_map_module(
f: Callable[[Any], S],
tree: PyTree[Any, "T"],
*rest,
) -> PyTree[S, "T"]:
"""Custom `tree_map` that treats `eqx.Module`s as leaves.
This is a convenience for performing analyses involving mapping
repeatedly over PyTrees of `eqx.Module`, where it would be repetitive
to write `is_leaf=lambda x: isinstance(x, eqx.Module)` every time.
"""
return jax.tree_map(f, tree, *rest, is_leaf=_is_module)
def tree_map_unzip(
f: Callable[..., Tuple[Any, ...]],
tree: PyTree[Any, "T"],
*rest,
is_leaf: Optional[Callable[[Any], bool]] = None,
) -> Tuple[PyTree[Any, "T"], ...]:
"""Maps a tuple-valued function over a PyTree. Returns a tuple of PyTrees.
For example, for a function `f(x) -> (y, z)`, we can do `ys, zs =
tree_map_unzip(f, xs)` where `ys`, `zs` are PyTrees, whereas with a normal
`tree_map` we'd get a single PyTree of tuples `(y, z)`.
"""
results = jax.tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
return tree_unzip(results)
def tree_unzip(
tree: PyTree[Tuple[Any, ...], "T"],
) -> Tuple[PyTree[Any, "T"], ...]:
"""Unzips a PyTree of tuples into a tuple of PyTrees."""
tree_flat, treedef = jax.tree_flatten(tree, is_leaf=lambda x: isinstance(x, tuple))
tree_flat_unzipped = zip(*tree_flat)
return tuple(jax.tree_unflatten(treedef, x) for x in tree_flat_unzipped)
def tree_call(
tree: PyTree[Any, "T"],
*args: Any,
exclude: Callable = lambda _: False,
is_leaf: Optional[Callable] = None,
**kwargs: Any,
) -> PyTree[Any, "T"]:
"""Returns a tree of the return values of a PyTree's callable leaves.
!!! Note ""
Every callable leaf is passed the same `*args, **kwargs`.
Non-callable leaves, callable leaves that satisfy `exclude`, are passed through
as-is.
Arguments:
tree: Any PyTree.
*args: Positional arguments to pass to each callable leaf.
exclude: A function that returns `True` for any callable leaf that
should not be called.
**kwargs: Keyword arguments to pass to each callable leaf.
"""
callables, other_values = eqx.partition(
tree,
lambda x: isinstance(x, Callable) and not exclude(x),
is_leaf=is_leaf,
)
callables_values = jax.tree_map(
lambda x: x(*args, **kwargs),
callables,
is_leaf=is_leaf,
)
return eqx.combine(callables_values, other_values)
def tree_array_bytes(tree: PyTree) -> int:
"""Returns the total bytes of memory over all array leaves of a PyTree."""
arrays = eqx.filter(tree, eqx.is_array)
array_bytes = jax.tree_map(lambda x: x.nbytes, arrays)
return jtu.tree_reduce(
lambda x, y: x + y,
array_bytes,
)
def tree_struct_bytes(tree: PyTree[jax.ShapeDtypeStruct]) -> int:
"""Returns the total bytes of memory implied by a PyTree of `ShapeDtypeStruct`s."""
structs = eqx.filter(tree, lambda x: isinstance(x, jax.ShapeDtypeStruct))
struct_bytes = jax.tree_map(lambda x: x.size * x.dtype.itemsize, structs)
return jtu.tree_reduce(lambda x, y: x + y, struct_bytes)