-
Notifications
You must be signed in to change notification settings - Fork 45
/
selectors.py
1398 lines (1151 loc) · 50.2 KB
/
selectors.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2024 The Penzai Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A toolkit for pytree manipulation."""
from __future__ import annotations
import collections
import contextlib
import dataclasses
import functools
import typing
from typing import Any, Callable, Collection, Generic, Iterable, Literal, Mapping, Sequence
import equinox as eqx
import jax
import numpy as np
from penzai.core import partitioning
from penzai.core import struct
from penzai.core import tree_util as penzai_tree_util
# A key path: a tuple of keys identifying the location of a node in a PyTree.
KeyPath = tuple[Any, ...]
# Type of the subtrees currently held by a `Selection`.
SelectedSubtree = typing.TypeVar("SelectedSubtree")
# Type of subtrees produced by a transformation or by a refined selection.
OtherSubtree = typing.TypeVar("OtherSubtree")
# General-purpose type variable.
T = typing.TypeVar("T")
@struct.pytree_dataclass
class SelectionHole(struct.Struct):
  """A hole in a structure, taking the place of a selected subtree.

  When building a selection, the nodes that are selected are moved out of the
  original tree for easier processing. They are replaced by a ``SelectionHole``,
  which points to the node that was here originally.

  A ``SelectionHole`` is a PyTree with no children. This ensures that the
  selected elements are actually "removed" from the tree from JAX's point of
  view.

  Users should not need to create a ``SelectionHole`` directly, and should
  instead use the `select(...)` function and other selector traversals. However,
  you may see a ``SelectionHole`` if inspecting the contents of a `Selection`
  object.

  Note that the `Selection` machinery assumes that the selected PyTree nodes do
  not require their children to be a specific type, so that we can insert
  ``SelectionHole`` in arbitrary places in the tree.
  If a node makes strong assumptions about the types of its children, it may
  not be possible to select those children, since rebuilding that node with
  a ``SelectionHole`` may fail.
  See
  https://jax.readthedocs.io/en/latest/pytrees.html#custom-pytrees-and-initialization
  for more information on how to implement your PyTrees to avoid this problem.

  Attributes:
    path: Keypath to this hole. Used to index back into the selected
      components.
  """

  # Flagged with `pytree_node: False`, consistent with a hole having no PyTree
  # children (the path is carried as metadata rather than as a child).
  path: KeyPath = dataclasses.field(metadata={"pytree_node": False})
@struct.pytree_dataclass
class SelectionQuote(struct.Struct):
  """Marks a particular subtree as relating to an inner selection.

  ``SelectionQuote`` is primarily used to handle the situation where a
  `Selection` contains another `Selection`. In this situation, the inner
  `Selection` holes must be kept distinct from the outer `Selection` holes,
  so that only the outer `Selection`'s holes are modified by operations on the
  selection. When this occurs, the inner `Selection`'s `SelectionHole` instances
  are wrapped in a ``SelectionQuote``, so that they aren't processed until the
  outer `Selection` is resolved.

  Note that quoting is only applied to the `remainder` tree, since this is the
  tree where we expect to find a `SelectionHole`. If there are `SelectionHole`
  instances inside the selected subtrees themselves, these are not quoted, since
  we never look at those subtrees when rebuilding the tree from a selection.

  One situation where selections-of-selections may appear is when using
  `penzai.treescope` to visualize a `Selection`. To
  support even higher levels of nesting, or trees where the user has inserted
  their own `SelectionHole` or ``SelectionQuote`` for some reason, we also
  quote ``SelectionQuote``; it's not clear that there are many uses for this
  but `penzai` supports it regardless.

  Users should not need to create a ``SelectionQuote`` directly, and should
  instead use the `select(...)` function and other selector traversals. However,
  you may see ``SelectionQuote`` if inspecting the contents of a `Selection`
  object.

  .. note:: Aside for programming languages geeks:
    ``SelectionQuote`` can be seen as a Peano-arithmetic "successor" function
    for De Bruijn levels, a particular convention for encoding variable binding
    in the lambda calculus, with `SelectionHole` representing zero. (De Bruijn
    levels are the conceptual opposite of De Bruijn indices, which use lower
    indices for the innermost variables instead of for the outermost variables.)
    See https://randall-holmes.github.io/Watson/babydocs/node26.html for a bit
    of related discussion.
    This is related to the interpretation of `Selection` as
    a lens, for which the `remainder` tree conceptually represents the
    partially-applied setter function.

  Attributes:
    quoted: The hole (or already-quoted hole) wrapped by this quote. Removing
      one level of quoting yields this value.
  """

  # Flagged with `pytree_node: False`, so the wrapped hole/quote is carried as
  # metadata rather than as a PyTree child.
  quoted: SelectionHole | SelectionQuote = dataclasses.field(
      metadata={"pytree_node": False}
  )
@dataclasses.dataclass(frozen=True)
class _InProgressSelectionBoundary:
  """Helper object for building a selection.

  Not intended to be exposed to user code! If you're seeing this and you don't
  expect to, please file an issue.

  This class is only used temporarily, while building a selection. It denotes
  the boundary between the selected part and the unselected part, allowing
  us to pull out the selected nodes using an ordinary ``tree_map`` call. All
  public selection-creation functions assume without checking that their
  input does NOT already contain any instances of this class; if they do,
  the resulting selection may be incorrect.

  Attributes:
    selected: The subtree we are selecting.
  """

  # The wrapped subtree; the dataclass is frozen, so it cannot be reassigned.
  selected: Any
def _is_hole_or_quote(subtree: Any) -> bool:
  """Tells whether a node is a selection marker needing special handling.

  Args:
    subtree: Any node encountered while walking a tree.

  Returns:
    True if the node is a `SelectionHole` or a `SelectionQuote`, else False.
  """
  if isinstance(subtree, SelectionHole):
    return True
  return isinstance(subtree, SelectionQuote)
@struct.pytree_dataclass
class Selection(Generic[SelectedSubtree], struct.Struct):
"""A selected subset of nodes within a larger PyTree.
Penzai selectors (such as ``.at(...)``) return ``Selection`` objects, which
indicate a specific subset of nodes within a larger PyTree, allowing those
nodes to be pulled out and modified in a functional way.
Selected nodes are required to be non-overlapping: no selected node can be
the ancestor of any other selected node in the same selection.
For convenience, a ``Selection`` is also a PyTree, and its leaves are the same
as the leaves in the original PyTree, but they are likely to be in a different
order.
.. note:: Aside for functional programming geeks:
A ``Selection`` is conceptually related to an "optic", specifically a
"lens". If you're familiar with optics, you can
think of a ``Selection`` as a partially-applied lens: it allows either
retrieving the selected values, or setting the selected values in the
structure. (If you're not familiar with optics, you can ignore this.)
Attributes:
selected_by_path: A mapping whose values are the selected parts from the
original structure, and whose keys are the keypaths for those parts (as
registered with JAX's PyTree registry). This is an `OrderedDict` to
prevent JAX from trying to sort the keys, which may be arbitrarily
hashable objects without an ordering.
remainder: The rest of the structure. The locations where the selected
components were are marked with `SelectionHole` nodes. If the remainder
also includes a ``Selection`` itself, the remainder may also include
`SelectionQuote` nodes.
"""
selected_by_path: collections.OrderedDict[KeyPath, SelectedSubtree]
remainder: Any
def count(self) -> int:
"""Returns the number of elements in the selection."""
return len(self.selected_by_path)
def __len__(self) -> int:
"""Returns the number of elements in the selection."""
return len(self.selected_by_path)
def is_empty(self) -> bool:
"""Returns True if the selection is empty."""
return not self.selected_by_path
def deselect(self) -> Any:
  """Rebuilds the tree, forgetting which nodes were selected.

  Returns:
    A copy of `remainder` in which every `SelectionHole` has been replaced by
    the matching value from `selected_by_path` and every `SelectionQuote` has
    had one level of quoting removed. If called on an ordinary selection, this
    rebuilds the original tree.
  """

  def fill(node: Any) -> Any:
    """Maps one leaf-or-marker of the remainder to its rebuilt value."""
    if isinstance(node, SelectionQuote):
      # Strip exactly one layer of quoting.
      return node.quoted
    if isinstance(node, SelectionHole):
      # Put the selected value back where it was taken from.
      return self.selected_by_path[node.path]
    # Plain leaf of the remainder: nothing to do.
    return node

  with _wrap_selection_errors(self):
    # Holes and quotes are treated as leaves so that `fill` sees them whole.
    return jax.tree_util.tree_map(
        fill, self.remainder, is_leaf=_is_hole_or_quote
    )
def get(self) -> SelectedSubtree:
"""Returns the result of a singleton selection.
Returns:
The selected subtree from this selection.
Raises:
ValueError: If this selection does not have exactly one selected subtree.
"""
if len(self.selected_by_path) != 1:
raise ValueError(
"Selection.get() can only be called on selections with one selected"
f" element, but there were {len(self.selected_by_path)}. Consider"
" using .selected_by_path instead."
)
(value,) = self.selected_by_path.values()
return value
def get_keypaths(self) -> tuple[KeyPath, ...]:
"""Returns the collection of selected key paths."""
return tuple(self.selected_by_path.keys())
# pytype: disable=invalid-annotation
@typing.overload
def apply(
    self,
    fn: Callable[[SelectedSubtree], Any],
    *,
    keep_selected: Literal[False] = False,
    with_keypath: Literal[False] = False,
) -> Any:
  ...

@typing.overload
def apply(
    self,
    fn: Callable[[SelectedSubtree], OtherSubtree],
    *,
    keep_selected: Literal[True],
    with_keypath: Literal[False] = False,
) -> Selection[OtherSubtree]:
  ...

@typing.overload
def apply(
    self,
    fn: Callable[[KeyPath, SelectedSubtree], Any],
    *,
    with_keypath: Literal[True],
    keep_selected: Literal[False] = False,
) -> Any:
  ...

@typing.overload
def apply(
    self,
    fn: Callable[[KeyPath, SelectedSubtree], OtherSubtree],
    *,
    with_keypath: Literal[True],
    keep_selected: Literal[True],
) -> Selection[OtherSubtree]:
  ...

# pytype: enable=invalid-annotation
def apply(
    self,
    fn: Callable[..., Any],
    *,
    with_keypath: bool = False,
    keep_selected: bool = False,
) -> Any:
  """Replaces each selected node with the result of applying this function.

  Args:
    fn: Function to apply to each selected node. It is called as ``fn(value)``
      when ``with_keypath`` is false, or as ``fn(keypath, value)`` when it is
      true, and should return the replacement PyTree.
    with_keypath: Whether to pass the keypath as the first argument to the
      callable.
    keep_selected: Whether to keep the nodes selected. If True, returns the
      modified selection; if False, rebuilds the tree after replacing.

  Returns:
    Either a modified `Selection` (if keep_selected=True) or a rebuilt version
    of the original tree, each with the replacements applied.
  """
  # Normalize both calling conventions to a (keypath, value) signature.
  if with_keypath:
    invoke = fn
  else:
    invoke = lambda _, subtree: fn(subtree)

  replaced = collections.OrderedDict(
      (path, invoke(path, subtree))
      for path, subtree in self.selected_by_path.items()
  )
  # The remainder (with its holes) is reused untouched; only values change.
  updated = Selection(selected_by_path=replaced, remainder=self.remainder)
  return updated if keep_selected else updated.deselect()
def set(self, replacement: Any) -> Any:
  """Replaces the selected subtree(s) with a fixed replacement.

  Args:
    replacement: The pytree to put in place of every selected subtree.

  Returns:
    A modified version of the original tree, with this replacement in place
    of any selected subtrees.
  """

  def _constant(_ignored: Any) -> Any:
    """Ignores the old value and yields the fixed replacement."""
    return replacement

  return self.apply(_constant)
def get_by_path(self) -> collections.OrderedDict[KeyPath, SelectedSubtree]:
"""Retrieves the selected subtree(s) based on their path(s).
Returns:
A dictionary of selected nodes, indexed by their path
"""
return self.selected_by_path
def set_by_path(
    self, replacer: Mapping[KeyPath, Any] | Callable[[KeyPath], Any]
) -> Any:
  """Replaces the selected subtree(s) based on their path(s).

  If you need both the value and the key, see
  ``.apply(fn, with_keypath=True)``.

  Args:
    replacer: A mapping from key paths to replacements, or a function that
      builds such a mapping by being called once per selected key path.
      Passing ``self.selected_by_path`` will return the original tree
      unchanged.

  Returns:
    A modified version of the original tree, with replacements taken from the
    replacer.
  """
  if callable(replacer):
    # Materialize the callable into a lookup table over the selected paths.
    lookup = {path: replacer(path) for path in self.selected_by_path.keys()}
  else:
    lookup = replacer

  def pick(path: KeyPath, _old_value: Any) -> Any:
    """Looks up the replacement for one selected path."""
    return lookup[path]

  return self.apply(pick, with_keypath=True)
def select_and_set_by_path(
    self, replacements_by_path: dict[KeyPath, Any]
) -> Any:
  """Selects subtrees and replaces them based on relative keypaths.

  Convenience method that combines `at_keypaths` and `set_by_path`.

  Args:
    replacements_by_path: A mapping from key paths to replacements. Key paths
      are relative to the current selected nodes.

  Returns:
    A modified version of the original tree, with replacements taken from
    ``replacements_by_path``.
  """

  def replace_within(subtree: Any) -> Any:
    """Applies the relative replacements inside one selected subtree."""
    inner = select(subtree).at_keypaths(replacements_by_path.keys())
    return inner.set_by_path(replacements_by_path)

  return self.apply(replace_within)
def get_sequence(self) -> tuple[SelectedSubtree, ...]:
"""Gets the selected subtree(s) in order.
Convenience wrapper for ``.selected_by_path.values()``.
Returns:
A tuple containing the selected subtrees.
"""
return tuple(self.selected_by_path.values())
def set_sequence(self, replacements: Iterable[Any]) -> Any:
  """Replaces the selected subtree(s) in order.

  Args:
    replacements: An iterable of PyTrees to insert at the selected locations,
      in order. It is paired with the selected key paths using ``zip``.

  Returns:
    A modified version of the original tree, with replacements taken from the
    iterable.
  """
  by_path = dict(zip(self.selected_by_path.keys(), replacements))
  return self.set_by_path(by_path)
def flatten_selected_selections(
    self: "Selection[Selection[SelectedSubtree]]",
) -> "Selection[SelectedSubtree]":
  """Flattens a selection whose selected values are all selections.

  This function takes a selection for which all of the selected values are
  already selections, and merges them into a single selection that selects
  all of the values from each individual selection.

  You can use this to build more complex selections by chaining your own
  logic. For instance, if you have written a function ``f`` that selects part
  of a tree, you can run
  ::

    selection.apply(f, keep_selected=True).flatten_selected_selections()

  to "broadcast" that logic across all of the already-selected subtrees in
  the original selection.

  See also `refine`, which allows you to express similar transformations more
  directly.

  Returns:
    A flattened selection object.

  Raises:
    ValueError: If any selected value is not itself a `Selection`.
  """

  # Approach: wrap every value selected by an inner selection in a boundary
  # marker while rebuilding the inner and outer trees, then construct a fresh
  # selection from those markers.
  def mark_inner_selected(inner: Selection) -> Any:
    if isinstance(inner, Selection):
      return inner.apply(_InProgressSelectionBoundary)
    raise ValueError(
        "flatten_selected_selections can only be called on Selections for"
        " which all values in `.selected_by_path` are also Selections. Got"
        f" {inner}"
    )

  with _wrap_selection_errors(self):
    marked_tree = self.apply(mark_inner_selected)
    return _build_selection_from_boundary(marked_tree)
def refine(
    self, selector_fn: "Callable[[Any], Selection[OtherSubtree]]"
) -> "Selection[OtherSubtree]":
  """Refines a selection by selecting within each selected subtree.

  Although selectors can already be refined by making additional calls,
  chained calls generally treat all selected subtrees the same way. In
  contrast, this method allows each selected node to be processed
  independently. Additionally, similar to `apply`, the additional logic is
  free to modify the subtree as it goes.

  Args:
    selector_fn: A function that takes a selected subtree from this selection
      and returns a new selection object, usually a selection of some nodes in
      the input subtree.

  Returns:
    A new selection that selects every node selected by ``selector_fn``, but
    in the context of the original tree rather than the individual selected
    subtrees.
  """
  # First produce a selection-of-selections, then merge the two levels.
  nested = self.apply(selector_fn, keep_selected=True)
  return nested.flatten_selected_selections()
def at(
    self,
    accessor_fn: Callable[[SelectedSubtree], Any | tuple[Any, ...]],
) -> "Selection":
  """Selects a specific child of each selected node.

  ``Selection.at(...)`` allows you to modify a tree with an almost-imperative
  style while maintaining a functional interface, similar to the
  ``Array.at[...]`` syntax for ordinary NDArrays. It takes a callable that
  picks out a subtree of the tree, and returns a new selection that selects
  the part that was picked out.

  For instance, if you have an object
  ::

    obj = Foo(bar=[1, 2, {"baz": 5}])

  you could select the 5 using
  ::

    pz.select(obj).at(lambda x: x.bar[2]["baz"])

  ``Selection.at`` is implemented using `equinox.tree_at`.

  Args:
    accessor_fn: A function which takes each element of the current selection
      and returns a node or tuple of nodes within that selection. This
      function must be structural; it must depend only on the PyTree structure
      of its input and not on the actual values of the leaves. See
      `equinox.tree_at` for the full set of requirements.

  Returns:
    A modified selection that selects the specific child of each node in the
    original selection.
  """

  def mark_accessed(subtree: Any) -> Any:
    """Wraps the node(s) picked by `accessor_fn` in a boundary marker."""
    return eqx.tree_at(
        accessor_fn, subtree, replace_fn=_InProgressSelectionBoundary
    )

  marked_tree = self.apply(mark_accessed)
  with _wrap_selection_errors(self):
    return _build_selection_from_boundary(marked_tree)
def at_pytree_leaves(self) -> "Selection":
  """Selects all PyTree leaves of each selected subtree.

  This selects all of the leaves of the PyTree according to `jax.tree_util`,
  giving the most-specific selection expressible with a `Selection` object.
  (Note that, if any objects in the tree are not registered as JAX PyTree
  nodes, they will be selected in their entirety even if they contain children
  when printed out by treescope.)

  Returns:
    A new selection that selects every leaf of each selected subtree.
  """

  def mark_leaves(subtree: Any) -> Any:
    """Wraps every leaf of one selected subtree in a boundary marker."""
    return jax.tree_util.tree_map(_InProgressSelectionBoundary, subtree)

  marked_tree = self.apply(mark_leaves)
  return _build_selection_from_boundary(marked_tree)
def at_children(self) -> "Selection":
  """Selects all direct children of each selected subtree.

  This can be used to implement recursive tree traversals in a generic way,
  using something like::

    def traverse(subtree):
      # ... process the subtree before recursive call ...
      subtree = select(subtree).at_children().apply(traverse)
      # ... process the subtree after the recursive call ...
      return subtree

    new_value = traverse(value)

  Returns:
    A new selection that selects every direct child of a selected subtree.
    If any leaves were previously selected, those leaves will no longer be
    selected (since they have no children).
  """

  def mark_children(subtree: Any) -> Any:
    """Wraps each direct child of `subtree` in a boundary marker."""
    flattened = penzai_tree_util.tree_flatten_exactly_one_level(subtree)
    if not flattened:
      # Nothing to descend into; the node is returned unmarked.
      return subtree
    keyed_children, treedef = flattened
    wrapped = []
    for _, child in keyed_children:
      wrapped.append(_InProgressSelectionBoundary(child))
    return treedef.unflatten(wrapped)

  with _wrap_selection_errors(self):
    marked_tree = self.apply(mark_children)
    return _build_selection_from_boundary(marked_tree)
@typing.overload
def where(
    self: "Selection[SelectedSubtree]",
    filter_fn: Callable[[SelectedSubtree], bool],
    *,
    with_keypath: Literal[False] = False,
) -> "Selection[SelectedSubtree]":
  ...

@typing.overload
def where(
    self: "Selection[SelectedSubtree]",
    filter_fn: Callable[[KeyPath, SelectedSubtree], bool],
    *,
    with_keypath: Literal[True],
) -> "Selection[SelectedSubtree]":
  ...

def where(
    self: "Selection[SelectedSubtree]",
    filter_fn: Callable[..., bool],
    *,
    with_keypath: bool = False,
) -> "Selection[SelectedSubtree]":
  """Filters to only a subset of selected nodes based on a condition.

  Args:
    filter_fn: Function to determine whether to keep a node in the selection.
      This function should take a PyTree (if ``with_keypath=False``) or a
      KeyPath and a PyTree (if ``with_keypath=True``).
    with_keypath: Whether to pass the keypath as the first argument to the
      callable.

  Returns:
    A new selection that includes only the selected parts where ``filter_fn``
    evaluates to true.
  """
  with _wrap_selection_errors(self):

    def mark_if_kept(*args: Any) -> Any:
      """Re-marks a selected value only when the filter accepts it.

      Receives ``(value,)`` or ``(keypath, value)`` depending on
      ``with_keypath``; the value is always the last argument.
      """
      subtree = args[-1]
      if filter_fn(*args):
        return _InProgressSelectionBoundary(subtree)
      return subtree

    if with_keypath:
      marked_tree = self.apply(mark_if_kept, with_keypath=True)
    else:
      marked_tree = self.apply(mark_if_kept)
    return _build_selection_from_boundary(marked_tree)
@typing.overload
def at_subtrees_where(
    self,
    filter_fn: Callable[[SelectedSubtree], bool],
    *,
    with_keypath: Literal[False] = False,
    absolute_keypath: bool = False,
    innermost: bool = False,
) -> "Selection":
  ...

@typing.overload
def at_subtrees_where(
    self,
    filter_fn: Callable[[KeyPath, SelectedSubtree], bool],
    *,
    with_keypath: Literal[True],
    absolute_keypath: bool = False,
    innermost: bool = False,
) -> "Selection":
  ...

def at_subtrees_where(
    self,
    filter_fn: Callable[..., bool],
    *,
    with_keypath: bool = False,
    absolute_keypath: bool = False,
    innermost: bool = False,
) -> "Selection":
  """Selects subtrees of selected nodes where a function evaluates to True.

  Note that a selection cannot contain a node that is the descendant of
  another selected node. If ``innermost=False``, we return the outermost
  node, whereas if ``innermost=True`` we return the innermost.

  If you want to apply a modification to *all* matches of a function, even
  if they are nested, you can use a pattern like ::

    selection = select(value)
    while not selection.is_empty():
      selection = selection.at_subtrees_where(foo).apply(
          bar, keep_selected=True)
    new_value = selection.deselect()

  More complex modifications can also be made using a manual traversal, e.g.::

    def traverse(subtree):
      # ... process the subtree before recursive call ...
      subtree = select(subtree).at_children().apply(traverse)
      # ... process the subtree after the recursive call ...
      return subtree

    new_value = traverse(value)

  Args:
    filter_fn: A function determining which subtrees to select. Should be
      deterministic, and may be called more than once. This function should
      take a PyTree (if ``with_keypath=False``) or a KeyPath and a PyTree (if
      ``with_keypath=True``), and must return a ``bool``.
    with_keypath: Whether to pass a keypath as the first argument to the
      callable.
    absolute_keypath: Whether to pass the keypath relative to the root of the
      original tree (if True) or the keypath relative to the currently
      selected node (if False). Ignored if ``with_keypath`` is False.
    innermost: Whether to select the innermost subtree(s) for which the filter
      function is true, instead of the first subtrees encountered. A leaf for
      which the filter is true counts as an innermost match.

  Returns:
    A new selection that selects the desired subtrees.

  Raises:
    TypeError: If ``filter_fn`` returns something that is not a ``bool``.
  """

  def safe_filter_fn(*args):
    """Calls `filter_fn` and insists on a genuine bool result."""
    result = filter_fn(*args)
    if not isinstance(result, bool):
      raise TypeError(
          "Filter function for at_subtrees_where must return a bool, not"
          f" {type(result)}"
      )
    return result

  with _wrap_selection_errors(self):
    if with_keypath or innermost:
      # Manual traversal: needed to track key paths and/or to look at
      # descendants before deciding whether to select a node.
      if with_keypath:
        wrapped_filter_fn = safe_filter_fn
      else:
        # Give the filter a uniform (keypath, subtree) signature.
        wrapped_filter_fn = lambda _, s: safe_filter_fn(s)

      def process_subtree(keypath, leaf_or_subtree) -> tuple[bool, Any]:
        """Recursively walks subtrees one level at a time.

        Args:
          keypath: Keypath to this subtree
          leaf_or_subtree: The subtree (or a leaf)

        Returns:
          (found_any, processed_pytree)
        """
        if not innermost:
          # Outermost mode: the first match on the way down wins.
          if wrapped_filter_fn(keypath, leaf_or_subtree):
            return True, _InProgressSelectionBoundary(leaf_or_subtree)

        maybe_children = penzai_tree_util.tree_flatten_exactly_one_level(
            leaf_or_subtree
        )
        if not maybe_children:
          # This is a leaf. In outermost mode it was already tested (and
          # rejected) above. In innermost mode it has not been tested yet, and
          # since it has no descendants it is itself the innermost candidate.
          if innermost and wrapped_filter_fn(keypath, leaf_or_subtree):
            return True, _InProgressSelectionBoundary(leaf_or_subtree)
          return False, leaf_or_subtree

        # Recursive step.
        keyed_children, treedef = maybe_children
        new_children = []
        any_descendant_selected = False
        for key, child in keyed_children:
          found_in_child, new_child = process_subtree(keypath + (key,), child)
          new_children.append(new_child)
          any_descendant_selected = any_descendant_selected or found_in_child

        if any_descendant_selected:
          # A descendant was selected, use the descendant.
          return True, treedef.unflatten(new_children)
        if wrapped_filter_fn(keypath, leaf_or_subtree):
          # No descendant was selected, so select this tree.
          return True, _InProgressSelectionBoundary(leaf_or_subtree)
        # Didn't find anything.
        return False, leaf_or_subtree

      def process_selected(keypath, selected_subtree):
        """Processes one of the selected subtrees."""
        # Start from this selection's keypath when absolute paths are wanted,
        # otherwise from the empty keypath (relative to this selection).
        start = keypath if absolute_keypath else ()
        _, result = process_subtree(start, selected_subtree)
        return result

      return _build_selection_from_boundary(
          self.apply(process_selected, with_keypath=True)
      )
    else:
      # Fast path: let `tree_map` do the traversal, stopping at the first
      # (outermost) node for which the filter is true.

      def process_leaf_or_filtered(leaf_or_subtree):
        """Processes leaves or values where filter_fn is True."""
        if safe_filter_fn(leaf_or_subtree):
          return _InProgressSelectionBoundary(leaf_or_subtree)
        else:
          return leaf_or_subtree

      def process_selected(selected_subtree):
        """Processes one of the selected subtrees."""
        return jax.tree_util.tree_map(
            process_leaf_or_filtered, selected_subtree, is_leaf=safe_filter_fn
        )

      return _build_selection_from_boundary(self.apply(process_selected))
def at_instances_of(
    self,
    cls: type[OtherSubtree] | tuple[type[OtherSubtree], ...],
    innermost: bool = False,
) -> "Selection[OtherSubtree]":
  """Selects subtrees that are an instance of the given type.

  Convenience wrapper for::

    .at_subtrees_where(lambda subtree: isinstance(subtree, cls))

  Args:
    cls: The class (or tuple of classes) to retrieve instances of.
    innermost: Whether to return the innermost instances of the class (instead
      of the outermost).

  Returns:
    A refined selection that selects instances of this class within the
    original selection. If instances of this class are nested, only selects
    the outermost (if ``innermost=False``) or the innermost (if
    ``innermost=True``), but never both.
  """

  def is_instance_of_cls(subtree: Any) -> bool:
    """Filter predicate: is this node an instance of the requested type?"""
    return isinstance(subtree, cls)

  return self.at_subtrees_where(is_instance_of_cls, innermost=innermost)
def at_equal_to(self, template: OtherSubtree) -> Selection[OtherSubtree]:  # pytype: disable=invalid-annotation
  """Selects subtrees that are equal to a particular object.

  Mostly a convenience wrapper for ::

    .at_subtrees_where(lambda subtree: subtree == template)

  except that subtrees which are instances of `jax.Array`, `np.ndarray`, or
  `penzai.core.named_axes.NamedArrayBase` never match, since they override
  ``==`` to return arrays.

  Args:
    template: The object to select occurrences of. Must not itself be one of
      the array types listed above.

  Returns:
    A refined selection that selects the subtrees that compare equal to this
    object (with the subtree on the left-hand side of ``==``).

  Raises:
    ValueError: If ``template`` is an instance of one of the array types
      listed above.
  """
  # Lazy import to avoid circular dependencies
  import penzai.core.named_axes  # pylint: disable=g-import-not-at-top

  array_like_types = (
      jax.Array,
      np.ndarray,
      penzai.core.named_axes.NamedArrayBase,
  )
  if isinstance(template, array_like_types):
    raise ValueError(
        "Cannot use at_equal_to to check for equality of an array, since"
        " arrays override the == operator. Consider using `at_subtrees_where`"
        " instead."
    )

  def _matches_template(subtree: Any) -> bool:
    """True for non-array subtrees that compare equal to the template."""
    return not isinstance(subtree, array_like_types) and bool(
        subtree == template
    )

  return self.at_subtrees_where(_matches_template)
def partition(self, at_leaves: bool = False) -> tuple[Any, Any]:
  """Partitions the tree into ``(selected_tree, remainder_tree)`` parts.

  This function can be used to separate out the selected components of a tree
  into their own separate tree, so that JAX functions and other JAX libraries
  can process them like ordinary PyTrees. It splits its input into two
  disjoint trees ``(selected_tree, remainder_tree)``, where ``selected_tree``
  only contains the leaves that were selected, and `remainder_tree` only
  contains the remainder. The parts that were removed are identified using a
  sentinel `pz.NotInThisPartition` object, which has no PyTree children.

  The main use case for ``partition`` is to identify subsets of models that
  should be treated in different ways by JAX API functions. For instance, if
  you want to take a gradient with respect to a specific subset of parameters,
  you can select those parameters, call ``partition`` to separate them from
  the rest, then call `jax.grad` and use `argnums` to identify the partition
  of interest. Similarly, if you want to donate only a subset of the state to
  `jax.jit`, you can partition it and then use JAX's ``donate_argnums``
  argument to `jax.jit` to identify the parts you want to donate. Inside the
  function, you can then use `pz.combine` to rebuild the original tree.

  It is possible to repeatedly call ``partition`` to split a tree into
  more than two parts. In particular, you can select the ``remainder_tree``,
  target some additional nodes, and call ``.partition()`` again, repeating
  this process as needed. All of the partitioned trees can then be re-combined
  using a single call to `pz.combine`.

  Note that `NotInThisPartition` is a PyTree node with no children, which
  means that partitioned trees are safe to pass through JAX transformations,
  and the set of leaves in the two partitioned trees together are the same as
  the set of leaves in the original selected tree.

  This function is inspired by Equinox's `equinox.partition`, but is designed
  to work with Penzai's selector system. Unlike `equinox.partition`, missing
  nodes are identified with the `pz.NotInThisPartition` sentinel, and can
  replace arbitrary PyTree subtrees instead of just leaves. (Partitioning is
  also somewhat less important in Penzai than in Equinox because all PyTree
  leaves are arraylike by convention; partitioning is only necessary when
  different parts of the tree need special treatment e.g. for ``argnums`` or
  ``donate_argnums`` parameters.)

  Args:
    at_leaves: Whether to do the partitioning at the leaf level, so that the
      returned trees have exactly the same structure. (Note that `pz.combine`
      is OK with entire subtrees missing, so this is not necessary, but can
      make the partitions easier to manipulate manually if desired.) If False,
      the entire selected subtrees will be replaced by `NotInThisPartition` in
      the remainder tree.

  Returns:
    A tuple ``(selected_tree, remainder_tree)``, where both trees have the
    same structure (if ``at_leaves=True``) or the same prefix (if
    ``at_leaves=False``) except that `NotInThisPartition` is used to replace
    parts that are in the other partition.
  """
  # What must be blanked out of each result: the selected tree loses
  # everything that is NOT selected, the remainder tree loses what IS.
  to_blank_for_selected = self.invert()
  to_blank_for_remainder = self
  if at_leaves:
    # Refine down to individual leaves so both outputs share one structure.
    to_blank_for_selected = to_blank_for_selected.at_pytree_leaves()
    to_blank_for_remainder = to_blank_for_remainder.at_pytree_leaves()

  selected_tree = to_blank_for_selected.set(partitioning.NotInThisPartition())
  remainder_tree = to_blank_for_remainder.set(
      partitioning.NotInThisPartition()
  )
  return selected_tree, remainder_tree
def at_keypaths(self, keypaths: Collection[KeyPath]) -> "Selection":
    """Selects the nodes found at the given keypaths.

    Keypaths are interpreted relative to the current selection.

    Args:
      keypaths: A collection of keypaths.

    Returns:
      A new selection containing every node whose keypath appears in
      ``keypaths``. Selected nodes cannot be nested, so when one requested
      path is a prefix of another, only the shorter prefix is used.
    """

    def _is_requested(keypath, _):
        # The subtree value is irrelevant; only the node's location matters.
        return keypath in keypaths

    return self.at_subtrees_where(
        _is_requested,
        with_keypath=True,
        absolute_keypath=False,
    )
def invert(self) -> "Selection":
    """Returns the complement of this selection.

    ``selection.invert()`` picks the largest possible subtrees that do NOT
    contain any node selected in the original selection. Put differently, the
    result selects the common ancestors of all unselected nodes while never
    selecting a node that was originally selected.

    Returns:
      An inverted selection.
    """
    # Approach: a node belongs in the result when its keypath is NOT a prefix
    # of any originally-selected keypath. Descendants of selected nodes also
    # fail the prefix test, so the traversal must be stopped at the selected
    # nodes themselves to avoid picking up their children by accident.
    selected_paths = set(self.selected_by_path)
    # Every prefix of every selected keypath, including the empty prefix and
    # the selected keypath itself.
    prefix_paths = {
        path[:end] for path in selected_paths for end in range(len(path) + 1)
    }

    def _stops_traversal(keypath, _):
        # Keep descending only through strict ancestors of selected nodes;
        # stop at selected nodes and at anything off the selected paths.
        return not (keypath in prefix_paths and keypath not in selected_paths)

    def _is_off_selected_paths(keypath, _):
        return keypath not in prefix_paths

    whole_tree = select(self.deselect())
    stopping_points = whole_tree.at_subtrees_where(
        _stops_traversal,
        with_keypath=True,
    )
    # Originally-selected nodes were only included to block traversal into
    # their children; filter them back out.
    return stopping_points.where(
        _is_off_selected_paths,
        with_keypath=True,
    )
def at_childless(self) -> "Selection":
"""Selects all PyTree nodes with no children, including PyTree leaves.
This is different than `at_pytree_leaves` in that it additionally selects
pytree nodes that are childless, e.g. empty lists, ``None``, and structures
without any PyTree children. Those nodes are not considered leaves according
to JAX, but it may still be useful to select them, e.g. for visualization