-
Notifications
You must be signed in to change notification settings - Fork 45
/
grouping.py
323 lines (250 loc) · 11.4 KB
/
grouping.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
# Copyright 2024 The Penzai Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Grouping layers, for chaining sequential computations."""
import dataclasses
import typing
from typing import Any, Callable, Sequence
import jax
from penzai.core import formatting_util
from penzai.core import layer as layer_base
from penzai.core import selectors
from penzai.core import shapecheck
from penzai.core import struct
# Module-level alias for `layer_base.LayerLike`, so annotations in this file
# can refer to the type as plain `LayerLike`.
LayerLike: typing.TypeAlias = layer_base.LayerLike
@struct.pytree_dataclass
class Sequential(layer_base.Layer):
  """An ordered group of layers that are applied one after another.

  Many networks are simply a composition of layers, which makes ``Sequential``
  one of the most frequently used layer types in a penzai.nn model. When the
  structure of the inputs and outputs is known ahead of time, consider
  `CheckedSequential` instead.

  A common pattern in penzai is:

  * subclass ``Sequential`` to give the group a more specific layer name,
  * inherit ``__init__`` and ``__call__`` from ``Sequential``,
  * add a classmethod (often called ``from_config``) that builds an instance
    of the subclass together with its contents.

  This keeps the configuration and initialization logic for one part of a
  network (a self-attention layer, for example) in a single place, while the
  resulting network can still be modified interactively afterwards.

  Subclasses of ``Sequential`` are NOT allowed to override ``__call__``: anyone
  holding a subclass of ``Sequential`` should be able to assume that it simply
  calls each child in order. (If you need finer control, keep a ``Sequential``
  as a child attribute instead, or duplicate the relevant logic in your own
  class.)

  Attributes:
    sublayers: A sequence of layers to call in order. These are usually pz.Layer
      instances, but are allowed to be other types of callable PyTree as well.
  """

  sublayers: list[layer_base.LayerLike]

  @typing.final
  def __call__(self, value: Any) -> Any:
    """Feeds ``value`` through every sublayer, in order.

    Args:
      value: The input to the layer.

    Returns:
      The output of the final sublayer. With no sublayers, ``value`` is
      returned as-is.
    """
    result = value
    for position, sublayer in enumerate(self.sublayers):
      # Each sublayer runs inside a JAX name scope labelled with its index.
      with jax.named_scope(f"{position}"):
        result = sublayer(result)
    return result

  def treescope_color(self) -> str | tuple[str, str]:
    """Returns the pair of colors used to render this node in treescope.

    Plain ``Sequential`` instances always use a fixed grey; subclasses use an
    accent color computed from their fully qualified class name.
    """
    cls = type(self)
    if cls is not Sequential:
      qualified_name = cls.__module__ + "." + cls.__qualname__
      accent = formatting_util.color_from_string(qualified_name)
      return accent, f"color-mix(in oklab, {accent} 25%, white)"
    return "#cdcdcd", "color-mix(in oklab, #cdcdcd 25%, white)"
@struct.pytree_dataclass
@typing.final
class NamedGroup(layer_base.Layer):
  """A layer that names an activation or a sequence of layers.

  This layer does not do anything interesting on its own, but exists primarily
  to facilitate manipulation and inspection of a complex network:

  * The name will show up in ``treescope`` when inspecting the network
    interactively, giving context for the wrapped layers.

  * ``NamedGroup`` layers can be selected with ``pz.select`` based on their
    name, using something like ::

      (...).at_instances_of(NamedGroup).where(lambda n: n.name == NAME)

  * When traced in JAX, ``NamedGroup`` layers add their name to the name scope,
    which will be visible in the TensorBoard profiler and in JAXPRs.

  You can also omit the sublayers, in which case this serves as a lightweight
  way to assign a name to an activation (mostly useful in combination with
  `pz.select`).

  Suggestion for when to use ``NamedGroup`` vs subclass `Sequential`: If you
  have a function that builds a particular collection of sub-layers in a
  reusable way, consider subclassing `Sequential` and having that function be a
  constructor classmethod. If you just need to group some sublayers together,
  but want to name them for later reference, just use ``NamedGroup``.

  You shouldn't subclass ``NamedGroup``; either subclass `Sequential` or define
  your own layer.

  Attributes:
    name: The name for the layer. Stored as static metadata rather than as a
      PyTree child (``pytree_node=False``).
    sublayers: A sequence of layers to call in order. These are usually pz.Layer
      instances, but are allowed to be other types of callable PyTree as well.
  """

  name: str = dataclasses.field(metadata={"pytree_node": False})
  sublayers: Sequence[LayerLike]

  def __call__(self, value: Any) -> Any:
    """Runs each of the sublayers in sequence, inside this group's name scope.

    Args:
      value: The input to the layer.

    Returns:
      The output of the final sublayer. With no sublayers, ``value`` is
      returned unchanged.
    """
    # Open the scope for the group's own name first, so that the per-index
    # scopes of the sublayers nest underneath it. Without this outer scope the
    # group's name would never reach the JAX name stack, contradicting the
    # class documentation.
    with jax.named_scope(self.name):
      for i, layer in enumerate(self.sublayers):
        with jax.named_scope(f"{i}"):
          value = layer(value)
    return value

  def treescope_color(self) -> str | tuple[str, str]:
    """Returns treescope colors computed from the group's name."""
    accent = formatting_util.color_from_string(self.name)
    return accent, f"color-mix(in oklab, {accent} 25%, white)"
@struct.pytree_dataclass
class CheckedSequential(layer_base.Layer):
  """An ordered group of layers with declared input and output structures.

  ``CheckedSequential`` is the "typed" counterpart of `Sequential`: it carries
  an annotation for the structure it accepts and one for the structure it
  produces. The two annotations share state variables, which makes it possible
  to assert relationships between the shape of the inputs and the shape of the
  outputs.

  Attributes:
    input_like: An input structure, represented as a PyTree of
      `pz.chk.ArraySpec` nodes. This defines the type of input this layer
      expects to receive. Passing anything else will raise an error.
    sublayers: A sequence of layers to call in order. These are usually pz.Layer
      instances, but are allowed to be other types of callable PyTree as well.
    output_like: An output structure, represented as a PyTree of
      `pz.chk.ArraySpec` nodes. This defines the type of output this layer will
      produce. Returning anything else will raise an error.
  """

  # Both annotations are static metadata (not PyTree children) and are always
  # shown collapsed in treescope.
  input_like: shapecheck.StructureAnnotation = dataclasses.field(
      metadata={"pytree_node": False, "treescope_always_collapse": True}
  )
  output_like: shapecheck.StructureAnnotation = dataclasses.field(
      metadata={"pytree_node": False, "treescope_always_collapse": True}
  )
  sublayers: list[layer_base.LayerLike]

  @layer_base.checked_layer_call
  @typing.final
  def __call__(self, value: Any) -> Any:
    """Feeds ``value`` through every sublayer, in order.

    Args:
      value: The input to the layer.

    Returns:
      The output of the final sublayer.
    """
    result = value
    for position, sublayer in enumerate(self.sublayers):
      # Each sublayer runs inside a JAX name scope labelled with its index.
      with jax.named_scope(f"{position}"):
        result = sublayer(result)
    return result

  def treescope_color(self) -> str | tuple[str, str]:
    """Returns the color(s) used to render this node in treescope.

    Plain ``CheckedSequential`` instances use a single fixed grey; subclasses
    pair an accent computed from their fully qualified class name with that
    same grey.
    """
    cls = type(self)
    if cls is not CheckedSequential:
      qualified_name = cls.__module__ + "." + cls.__qualname__
      accent = formatting_util.color_from_string(qualified_name)
      return accent, "#cdcdcd"
    return "#cdcdcd"

  @typing.final
  def input_structure(self) -> Any:
    """Returns the declared input structure (``input_like``)."""
    return self.input_like

  @typing.final
  def output_structure(self) -> Any:
    """Returns the declared output structure (``output_like``)."""
    return self.output_like
@struct.pytree_dataclass
class Identity(layer_base.Layer):
  """A no-op layer: hands back whatever it is given, with no side effects."""

  @typing.final
  def __call__(self, value: Any) -> Any:
    """Passes ``value`` through untouched."""
    return value

  def treescope_color(self) -> str | tuple[str, str]:
    """Returns the fixed grey color pair used to render this node."""
    accent = "#cdcdcd"
    blend = "color-mix(in oklab, #cdcdcd 25%, white)"
    return accent, blend
@struct.pytree_dataclass
@typing.final
class CheckStructure(layer_base.Layer):
  """A pass-through layer that validates the structure of its input.

  Attributes:
    expected: An expected structure, represented as a PyTree of
      `pz.chk.ArraySpec` nodes. This defines the type of input this layer
      expects to receive. Passing anything else will raise an error.
  """

  expected: Any

  def __call__(self, value: Any) -> Any:
    """Validates ``value`` against ``expected`` and hands it back unchanged."""
    expected_structure = self.expected
    # The check's return value is not used; only the input is passed along.
    shapecheck.check_structure(value, expected_structure)
    return value

  def treescope_color(self) -> str:
    """Returns the fixed grey used to render this node in treescope."""
    return "#cdcdcd"
def is_sequential_or_named(tree: Any) -> bool:
  """Tells whether a node is one of this module's grouping layers.

  Args:
    tree: The node to inspect.

  Returns:
    True if ``tree`` is an instance of `Sequential`, `CheckedSequential`, or
    `NamedGroup` (subclasses included), False otherwise.
  """
  group_types = (Sequential, CheckedSequential, NamedGroup)
  return isinstance(tree, group_types)
def is_anonymous_sequential(tree: Any) -> bool:
  """Tells whether a node is a plain `Sequential` rather than a subclass.

  Args:
    tree: The node to inspect.

  Returns:
    True only when the type of ``tree`` is exactly `Sequential`; instances of
    subclasses of `Sequential` yield False.
  """
  node_type = type(tree)
  # Identity comparison on purpose: isinstance would also accept subclasses.
  return node_type is Sequential
def inline_groups(
    tree: Any,
    parent_filter: Callable[[Any], bool],
    child_filter: Callable[[Any], bool],
) -> Any:
  """Flattens nested layer groups by splicing children into their parents.

  Whenever a group node accepted by ``parent_filter`` directly contains a group
  node accepted by ``child_filter``, the child is removed and its sublayers
  take its place in the parent's ``sublayers``. This can be used to turn a
  nested structure of `Sequential` or `NamedGroup` objects into a new structure
  with a smaller depth.

  The logic applies recursively, bottom-up: a node that matches both filters
  first absorbs its own matching children, and may then itself be absorbed by
  its parent.

  For the common case of inlining "anonymous" groups (instances whose type is
  exactly `Sequential`, not a more specific subclass), use the convenience
  wrapper `inline_anonymous_sequentials`.

  Args:
    tree: The tree to process.
    parent_filter: A function that returns True on the nodes that we want to
      inline sublayers into.
    child_filter: A function that returns True on the nodes that we want to
      remove and replace with the inlined sequence of its sublayers.

  Returns:
    A copy of ``tree`` that inlines nodes that match ``child_filter`` into their
    parents whenever their parents match ``parent_filter``, as long as they
    are subclasses of `Sequential`, `CheckedSequential`, or `NamedGroup`.
  """

  def _flatten(node):
    # Recurse into the children first, so that every nested group has already
    # been flattened by the time this node is examined.
    processed = selectors.select(node).at_children().apply(_flatten)

    # Only group nodes accepted by the parent filter can absorb children.
    if not (is_sequential_or_named(processed) and parent_filter(processed)):
      return processed

    merged = []
    for child in processed.sublayers:
      if is_sequential_or_named(child) and child_filter(child):
        # Replace the child by its own sublayers.
        merged += child.sublayers
      else:
        merged += [child]

    # Rebuild the parent around the flattened list of sublayers.
    return dataclasses.replace(processed, sublayers=merged)

  return _flatten(tree)
def inline_anonymous_sequentials(tree: Any) -> Any:
  """Splices plain `Sequential` nodes (not subclasses) into enclosing groups.

  Convenience wrapper around `inline_groups`: any grouping layer may act as the
  parent, and only nodes whose type is exactly `Sequential` are inlined.

  Args:
    tree: The tree to process.

  Returns:
    The result of `inline_groups` for ``tree`` with those two filters.
  """
  return inline_groups(tree, is_sequential_or_named, is_anonymous_sequential)