/
tree_util.py
287 lines (214 loc) · 8.18 KB
/
tree_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tree utilities."""
import functools
import itertools
import operator
import jax
from jax import tree_util as tu
import jax.numpy as jnp
import numpy as onp
# Re-export the core `jax.tree_util` primitives under module-level names so
# that callers can reach them through this module.
# Structure <-> flat list of leaves.
tree_flatten = tu.tree_flatten
tree_unflatten = tu.tree_unflatten
tree_leaves = tu.tree_leaves
# Leaf-wise transformation and reduction.
tree_map = tu.tree_map
tree_reduce = tu.tree_reduce
def broadcast_pytrees(*trees):
  """Broadcasts leaf pytrees to match the treedef shared by the other arguments.

  Args:
    *trees: A `Sequence` of pytrees such that all elements that are *not* leaf
      pytrees (i.e. single arrays) have the same treedef.

  Returns:
    A tuple with one entry per input. Non-leaf pytrees are rebuilt from their
    own leaves; each leaf pytree (i.e. single array) is replaced by a pytree
    with the shared treedef in which every leaf is that single value. If every
    input is a leaf pytree, `trees` is returned unchanged.

  Raises:
    ValueError: If two or more pytrees in `*trees` that are not leaf pytrees
      differ in their structure (treedef).
  """
  flat_trees, leaf_flags, treedef = [], [], None
  for tree in trees:
    leaves_i, treedef_i = tu.tree_flatten(tree)
    is_leaf_i = tu.treedef_is_leaf(treedef_i)
    if not is_leaf_i:
      # The first non-leaf treedef encountered is the reference structure.
      if treedef is None:
        treedef = treedef_i
      if treedef_i != treedef:
        raise ValueError('Pytrees are not broadcastable.: '
                         f'{treedef} != {treedef_i}')
    flat_trees.append(leaves_i)
    leaf_flags.append(is_leaf_i)

  if treedef is None:
    # All pytrees are leaves.
    return trees

  # `unflatten` needs exactly `treedef.num_leaves` leaves. Using the largest
  # flattened length instead would give 1 (from the leaf pytrees) when the
  # shared structure holds no leaves at all, and unflattening would fail.
  num_leaves = treedef.num_leaves
  return tuple(
      treedef.unflatten(
          itertools.repeat(leaves_i[0], num_leaves) if is_leaf_i else leaves_i)
      for leaves_i, is_leaf_i in zip(flat_trees, leaf_flags))
# Leaf-wise binary arithmetic on pytrees with matching structure. Each name is
# `tree_map` with the corresponding `operator` function bound as the mapped
# callable.
tree_add = functools.partial(tree_map, operator.add)
tree_sub = functools.partial(tree_map, operator.sub)
tree_mul = functools.partial(tree_map, operator.mul)
tree_div = functools.partial(tree_map, operator.truediv)

# `functools.partial` objects accept attribute assignment, so give each one a
# docstring of its own.
tree_add.__doc__ = "Tree addition."
tree_sub.__doc__ = "Tree subtraction."
tree_mul.__doc__ = "Tree multiplication."
tree_div.__doc__ = "Tree division."
def tree_scalar_mul(scalar, tree_x):
  """Scale every leaf of a pytree by the same factor.

  Args:
    scalar: factor placed on the left of each product.
    tree_x: pytree whose leaves are scaled.

  Returns:
    A pytree with the structure of `tree_x` whose leaves are `scalar * leaf`.
  """
  def _scale(leaf):
    return scalar * leaf

  return tree_map(_scale, tree_x)
def tree_add_scalar_mul(tree_x, scalar, tree_y):
  """Compute `tree_x + scalar * tree_y` leaf by leaf.

  Args:
    tree_x: pytree providing the additive term.
    scalar: factor applied to every leaf of `tree_y`.
    tree_y: pytree mapped together with `tree_x`.

  Returns:
    A pytree whose leaves are `x + scalar * y` for matching leaves `x`, `y`.
  """
  def _axpy(leaf_x, leaf_y):
    return leaf_x + scalar * leaf_y

  return tree_map(_axpy, tree_x, tree_y)
# `jnp.vdot` with the precision argument pinned to the highest setting.
_vdot = functools.partial(jnp.vdot, precision=jax.lax.Precision.HIGHEST)


def _vdot_safe(a, b):
  """Apply `_vdot` after converting both operands with `jnp.asarray`.

  Args:
    a: first operand; anything `jnp.asarray` accepts.
    b: second operand; anything `jnp.asarray` accepts.

  Returns:
    The result of `_vdot` on the two converted operands.
  """
  lhs = jnp.asarray(a)
  rhs = jnp.asarray(b)
  return _vdot(lhs, rhs)
def tree_vdot(tree_x, tree_y):
  """Compute the inner product <tree_x, tree_y>.

  Args:
    tree_x: first pytree.
    tree_y: second pytree, mapped together with `tree_x`.

  Returns:
    The sum over all leaves of the leaf-wise `_vdot_safe` results.
  """
  leafwise = tree_map(_vdot_safe, tree_x, tree_y)
  total = tree_reduce(operator.add, leafwise)
  return total
def _vdot_real(x, y):
  """Real-valued vector dot product for possibly complex inputs.

  Only the real part of `_vdot(x, y)` is kept, so the real-imaginary
  cross-terms are neglected and the result is a real float.

  Args:
    x: first operand.
    y: second operand.

  Returns:
    The real part of `_vdot(x, y)`.
  """
  # NOTE: an alternative is to add the dot product of the real parts to the
  # dot product of the imaginary parts. Per the original author's measurement,
  # taking `.real` of the full product is faster without jit and no different
  # with jit.
  return _vdot(x, y).real
def tree_vdot_real(tree_x, tree_y):
  """Compute the real part of the inner product <tree_x, tree_y>.

  Args:
    tree_x: first pytree.
    tree_y: second pytree, mapped together with `tree_x`.

  Returns:
    The sum of the leaf-wise `_vdot_real` results, accumulated with the
    builtin `sum`.
  """
  per_leaf = tree_map(_vdot_real, tree_x, tree_y)
  return sum(tree_leaves(per_leaf))
def tree_dot(tree_x, tree_y):
  """Compute the leaf-wise dot product between two pytrees of arrays.

  Useful to store block diagonal linear operators: each leaf of the tree
  corresponds to a block.

  Args:
    tree_x: pytree of left operands.
    tree_y: pytree of right operands, mapped together with `tree_x`.

  Returns:
    A pytree whose leaves are `jnp.dot(x, y)` for matching leaves `x`, `y`.
  """
  products = tree_map(jnp.dot, tree_x, tree_y)
  return products
def tree_sum(tree_x):
  """Compute sum(tree_x).

  Args:
    tree_x: pytree of arrays.

  Returns:
    The `jnp.sum` of every leaf, added together across leaves.
  """
  leaf_totals = tree_map(jnp.sum, tree_x)
  grand_total = tree_reduce(operator.add, leaf_totals)
  return grand_total
def tree_l2_norm(tree_x, squared=False):
  """Compute the l2 norm ||tree_x||.

  Args:
    tree_x: pytree whose leaves expose `.real` and `.imag`.
    squared: if True, return the squared norm and skip the square root.

  Returns:
    The square root of the summed squared magnitudes of all entries, or the
    sum itself when `squared` is True.
  """
  def _abs2(leaf):
    # Squared magnitude written so that it also covers complex leaves.
    return jnp.square(leaf.real) + jnp.square(leaf.imag)

  sqnorm = tree_sum(tree_map(_abs2, tree_x))
  return sqnorm if squared else jnp.sqrt(sqnorm)
def tree_zeros_like(tree_x):
  """Create an all-zero tree with the same structure as `tree_x`.

  Args:
    tree_x: template pytree.

  Returns:
    A pytree in which every leaf is `jnp.zeros_like` of the matching leaf.
  """
  zeros = tree_map(jnp.zeros_like, tree_x)
  return zeros
def tree_ones_like(tree_x):
  """Create an all-ones tree with the same structure as `tree_x`.

  Args:
    tree_x: template pytree.

  Returns:
    A pytree in which every leaf is `jnp.ones_like` of the matching leaf.
  """
  ones = tree_map(jnp.ones_like, tree_x)
  return ones
def tree_average(trees, weights):
  """Return the linear combination of a batch of trees.

  Args:
    trees: tree of arrays with shape (m,...)
    weights: array of shape (m,)

  Returns:
    a single tree that is the linear combination of all trees: for every leaf,
    `weights` is contracted against the leaf's leading axis.
  """
  def _combine(stacked_leaf):
    return jnp.tensordot(weights, stacked_leaf, axes=1)

  return tree_map(_combine, trees)
def tree_gram(a):
  """Compute the Gram matrix from a pytree of batches of vectors.

  Args:
    a: pytree of arrays of shape (m,...)

  Returns:
    array of shape (m,m) of all dot products
  """
  # Inner vmap batches the first argument of `tree_vdot`; the outer vmap
  # batches the second, so together they cover every pair of batch entries.
  over_first = jax.vmap(tree_vdot, in_axes=(0, None))
  over_both = jax.vmap(over_first, in_axes=(None, 0))
  return over_both(a, a)
def tree_inf_norm(tree_x):
  """Compute the infinity norm of a pytree.

  Args:
    tree_x: pytree of arrays.

  Returns:
    The largest absolute value found across all entries of all leaves.
  """
  flat_leaves = [jnp.ravel(leaf) for leaf in tree_leaves(tree_x)]
  all_entries = jnp.concatenate(flat_leaves)
  return jnp.max(jnp.abs(all_entries))
def tree_where(cond, a, b):
  """jnp.where for trees.

  Mimics the broadcasting semantics of jnp.where: cond, a and b can be arrays
  (including scalars) broadcastable to the leaves of the other input
  arguments.

  Args:
    cond: pytree of boolean arrays, or single array broadcastable to the
      shapes of leaves of `a` and `b`.
    a: pytree of arrays, or single array broadcastable to the shapes of leaves
      of `cond` and `b`.
    b: pytree of arrays, or single array broadcastable to the shapes of leaves
      of `cond` and `a`.

  Returns:
    pytree of arrays, or single array
  """
  # Bring the three arguments to a common structure first, then select
  # leaf by leaf.
  cond_tree, a_tree, b_tree = broadcast_pytrees(cond, a, b)
  return tree_map(jnp.where, cond_tree, a_tree, b_tree)
def tree_negative(tree):
  """Compute the elementwise negation -x.

  Args:
    tree: pytree to negate.

  Returns:
    `tree` with every leaf multiplied by -1 via `tree_scalar_mul`.
  """
  negated = tree_scalar_mul(-1, tree)
  return negated
def tree_reciproqual(tree):
  """Compute the elementwise inverse 1/x.

  Args:
    tree: pytree of arrays.

  Returns:
    A pytree in which every leaf is `jnp.reciprocal` of the matching leaf.
  """
  return tree_map(jnp.reciprocal, tree)
def tree_mean(tree):
  """Mean reduction for trees.

  Each leaf is first reduced with `jnp.mean`; the result is the average of
  those per-leaf means, so every leaf counts equally whatever its size.

  Args:
    tree: pytree of arrays.

  Returns:
    The sum of the per-leaf means divided by the number of leaves.
  """
  leaf_means = tree_map(jnp.mean, tree)
  num_leaves = len(tree_leaves(leaf_means))
  return tree_sum(leaf_means) / num_leaves
def tree_single_dtype(tree, convert_in_jax_dtype=True):
  """Return the dtype shared by all values in a tree.

  Only leaves that are Python scalars (bool, int, float, complex) or
  numpy / jax.numpy arrays are inspected; any other leaf is ignored.

  Args:
    tree: tree to get the dtype of
    convert_in_jax_dtype: whether to convert the types in JAX precision.
      Namely, a numpy int64 type is converted in a jax.numpy int32 type
      by default unless one enables double precision using
      jax.config.update("jax_enable_x64", True)

  Returns:
    dtype shared by all inspected leaves of the tree, or None when no leaf is
    inspected.

  Raises:
    ValueError: If the inspected leaves have more than one dtype.
  """
  inspected_types = (bool, int, float, complex, onp.ndarray, jnp.ndarray)
  # The two modes differ only in which `asarray` decides the dtype.
  as_array = jnp.asarray if convert_in_jax_dtype else onp.asarray
  dtypes = {
      as_array(leaf).dtype
      for leaf in tu.tree_leaves(tree)
      if isinstance(leaf, inspected_types)
  }
  if len(dtypes) > 1:
    raise ValueError("Found more than one dtype in the tree.")
  return dtypes.pop() if dtypes else None
def get_real_dtype(dtype):
  """Return the dtype of the real part of a complex dtype.

  Args:
    dtype: a dtype, or anything numpy can interpret as one (dtype instance,
      scalar type such as `onp.complex64`, or dtype name).

  Returns:
    The dtype of the real component when `dtype` is complex. Otherwise
    `dtype` is returned unchanged; this includes None and objects numpy
    cannot interpret as a dtype.
  """
  if dtype is None:
    # numpy would read None as its default float dtype; pass it through.
    return dtype
  try:
    # Test the dtype's kind rather than matching against a list of names, so
    # that scalar types and every complex width are recognised too.
    is_complex = onp.issubdtype(dtype, onp.complexfloating)
  except TypeError:
    # Not interpretable as a dtype: leave untouched.
    return dtype
  if not is_complex:
    return dtype
  return onp.dtype(dtype).type(0).real.dtype
def tree_conj(tree):
  """Complex conjugate of a tree.

  Args:
    tree: pytree of arrays.

  Returns:
    A pytree in which every leaf is `jnp.conj` of the matching leaf.
  """
  conjugated = tree_map(jnp.conj, tree)
  return conjugated
def tree_real(tree):
  """Real part of a tree.

  Args:
    tree: pytree of arrays.

  Returns:
    A pytree in which every leaf is `jnp.real` of the matching leaf.
  """
  real_parts = tree_map(jnp.real, tree)
  return real_parts
def tree_imag(tree):
  """Imaginary part of a tree.

  Args:
    tree: pytree of arrays.

  Returns:
    A pytree in which every leaf is `jnp.imag` of the matching leaf.
  """
  imag_parts = tree_map(jnp.imag, tree)
  return imag_parts