-
Notifications
You must be signed in to change notification settings - Fork 45
/
shared_value_postprocessor.py
336 lines (278 loc) · 11.4 KB
/
shared_value_postprocessor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
# Copyright 2024 The Penzai Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Postprocessor to annotate repetitions of mutable values.
This can be useful for figuring out when multiple mutable objects in a tree are
all references to the same object.
"""
import contextlib
import dataclasses
import io
from typing import Any, Optional, Sequence
import jax
from penzai.core import context
from penzai.treescope import html_escaping
from penzai.treescope import renderer
from penzai.treescope.foldable_representation import basic_parts
from penzai.treescope.foldable_representation import part_interface
@dataclasses.dataclass()
class _SharedObjectTracker:
  """Bookkeeping for the object IDs encountered during one rendering pass.

  Attributes:
    seen_at_least_once: IDs of every tracked node encountered so far.
    seen_more_than_once: IDs of tracked nodes that were encountered again
      after their first appearance.
  """

  # Both sets are mutated in place while rendering, so parts that hold a
  # reference to them observe later updates.
  seen_at_least_once: set[int]
  seen_more_than_once: set[int]
_shared_object_ids_seen: context.ContextualValue[
    Optional[_SharedObjectTracker]
] = context.ContextualValue(
    module=__name__, qualname="_shared_object_ids_seen", initial_value=None
)
"""Contextual tracker of the IDs of mutable objects that have been rendered.

This is used to detect "unsafe" sharing of objects.

Normally references to the same object are ignored by JAX, and doing any
PyTree manipulation destroys the sharing. However, leaves of JAX PyTrees
may still involve mutable-object sharing, and we want to be able to
use Treescope to visualize this kind of structure as well.

The value is `None` by default, and a fresh `_SharedObjectTracker` inside the
scope opened by `setup_shared_value_context`. Note that
`check_for_shared_values` records a node's ID only after that node has been
rendered, so this tracker does not by itself prevent infinite recursion on
self-referential structures.

This contextual value should only be used by code in this module.
"""
class SharedWarningLabel(basic_parts.BaseSpanGroup):
  """Span group that colors its child as a shared-object warning comment.

  `check_for_shared_values` wraps its "# Repeated python obj at 0x..."
  annotation text in this part.
  """

  def _span_css_class(self) -> str:
    # Must match the selector used in `_span_css_rule` below.
    return "shared_warning_label"

  def _span_css_rule(
      self, setup_context: part_interface.HtmlContextForSetup
  ) -> part_interface.CSSStyleRule:
    # The raw CSS is passed through `without_repeated_whitespace` before being
    # wrapped in a style rule.
    css_text = html_escaping.without_repeated_whitespace("""
.shared_warning_label
{
color: #ffac13;
}
""")
    return part_interface.CSSStyleRule(css_text)
@dataclasses.dataclass(frozen=False)
class DynamicSharedCheck(part_interface.RenderableTreePart):
  """Renders `if_shared` only if `node_id` turns out to be shared.

  Whether a node is shared is generally unknown when this part is built: the
  same object may only be encountered again later in the tree. The part
  therefore stores the node's ID together with a reference to a *mutable* set
  of IDs seen more than once (updated elsewhere), and consults that set lazily
  at render time.

  For layout purposes the part always behaves as if it were empty: zero
  collapsed width, no newlines, no foldables and no tags.

  Attributes:
    if_shared: Child that is rendered only when the node is shared.
    node_id: ID of the node being rendered; used for the warning and looked
      up in `seen_more_than_once`.
    seen_more_than_once: Reference to an externally-maintained set of node IDs
      seen more than once. Usually this is the set from the tracker in the
      active `_shared_object_ids_seen` context.
  """

  if_shared: part_interface.RenderableTreePart
  node_id: int
  seen_more_than_once: set[int]

  def _compute_collapsed_width(self) -> int:
    # Takes no horizontal space, regardless of whether it ends up rendering.
    return 0

  def _compute_newlines_in_expanded_parent(self) -> int:
    return 0

  def foldables_in_this_part(self) -> Sequence[part_interface.FoldableTreeNode]:
    return []

  def _compute_tags_in_this_part(self) -> frozenset[Any]:
    return frozenset()

  def html_setup_parts(
      self, setup_context: part_interface.HtmlContextForSetup
  ) -> set[part_interface.CSSStyleRule | part_interface.JavaScriptDefn]:
    # Always request the child's setup: sharing is decided by the mutable set
    # only when rendering happens.
    return self.if_shared.html_setup_parts(setup_context)

  def render_to_html(
      self,
      stream: io.TextIOBase,
      *,
      at_beginning_of_line: bool = False,
      render_context: dict[Any, Any],
  ):
    if self.node_id not in self.seen_more_than_once:
      return
    self.if_shared.render_to_html(
        stream,
        at_beginning_of_line=at_beginning_of_line,
        render_context=render_context,
    )

  def render_to_text(
      self,
      stream: io.TextIOBase,
      *,
      expanded_parent: bool,
      indent: int,
      roundtrip_mode: bool,
      render_context: dict[Any, Any],
  ):
    if self.node_id not in self.seen_more_than_once:
      return
    self.if_shared.render_to_text(
        stream,
        expanded_parent=expanded_parent,
        indent=indent,
        roundtrip_mode=roundtrip_mode,
        render_context=render_context,
    )
@dataclasses.dataclass(frozen=False)
class WithDynamicSharedPip(basic_parts.DeferringToChild):
  """Wraps a child and prefixes it with an orange pip if its node is shared.

  Whether a node is shared is generally unknown when this part is built: the
  same object may only be encountered again later in the tree. The part
  therefore stores the node's ID together with a reference to a *mutable* set
  of IDs seen more than once (updated elsewhere), and checks membership only
  when rendering to HTML.

  Apart from the optional pip in HTML output, this part behaves like its
  child. Everything not overridden here is inherited from
  `basic_parts.DeferringToChild`. It also does not disrupt children drawing
  themselves first.

  Attributes:
    child: Child to render.
    node_id: ID of the node being rendered; used for the warning and looked
      up in `seen_more_than_once`.
    seen_more_than_once: Reference to an externally-maintained set of node IDs
      seen more than once. Usually this is the set from the tracker in the
      active `_shared_object_ids_seen` context.
  """

  child: part_interface.RenderableTreePart
  node_id: int
  seen_more_than_once: set[int]

  def html_setup_parts(
      self, setup_context: part_interface.HtmlContextForSetup
  ) -> set[part_interface.CSSStyleRule | part_interface.JavaScriptDefn]:
    child_setup = self.child.html_setup_parts(setup_context)
    # The pip is a small orange triangle drawn as a gradient background. The
    # negative margins pull it into the surrounding text. Pips at the start
    # of a line get an extra left shift, except inside collapsed nodes.
    pip_rule = part_interface.CSSStyleRule(
        html_escaping.without_repeated_whitespace(f"""
.shared_warning_pip
{{
padding-right: 1ch;
margin-right: -0.5ch;
background: linear-gradient(135deg, orange 0 0.6ch, transparent 0.6ch );
}}
.shared_warning_pip.is_first_on_line:not({setup_context.collapsed_selector} *)
{{
margin-left: -0.5ch;
}}
""")
    )
    return child_setup | {pip_rule}

  def render_to_html(
      self,
      stream: io.TextIOBase,
      *,
      at_beginning_of_line: bool = False,
      render_context: dict[Any, Any],
  ):
    if self.node_id in self.seen_more_than_once:
      pip_html = (
          "<span class='shared_warning_pip is_first_on_line'></span>"
          if at_beginning_of_line
          else "<span class='shared_warning_pip'></span>"
      )
      stream.write(pip_html)
    # The child is rendered with the caller's line position either way.
    self.child.render_to_html(
        stream,
        at_beginning_of_line=at_beginning_of_line,
        render_context=render_context,
    )
def setup_shared_value_context() -> contextlib.AbstractContextManager[None]:
  """Opens a scope with a fresh, empty shared-object tracker.

  While the returned context manager is active, `_shared_object_ids_seen`
  holds one new `_SharedObjectTracker`, so every lookup inside the scope sees
  the same tracker. That shared tracker is what lets repeated appearances of
  the same mutable object be detected.

  Any renderer that checks for shared values should list this function in
  its `context_builders` argument.
  """
  fresh_tracker = _SharedObjectTracker(
      seen_at_least_once=set(), seen_more_than_once=set()
  )
  return _shared_object_ids_seen.set_scoped(fresh_tracker)
# Types whose instances may be referenced from several places in the same tree
# without the shared reference being worth highlighting (their instances are
# treated as immutable by `check_for_shared_values`).
_SAFE_TO_SHARE_TYPES = {jax.Array}
def check_for_shared_values(
    node: Any,
    path: tuple[Any, ...] | None,
    node_renderer: renderer.TreescopeSubtreeRenderer,
) -> (
    part_interface.RenderableTreePart
    | part_interface.RenderableAndLineAnnotations
    | type(NotImplemented)
):
  # pylint: disable=g-doc-args,g-doc-return-or-yield
  """Wrapper hook to check for and annotate shared values.

  Values that are neither hashable nor instances of `_SAFE_TO_SHARE_TYPES` are
  tracked by identity. If such a value is encountered more than once while
  rendering, its HTML rendering is prefixed with an orange pip. It also gets a
  "# Repeated python obj at 0x..." line annotation. All other values are
  returned exactly as `node_renderer` rendered them.

  This wrapper should only be used by renderers that also include
  `setup_shared_value_context` in their context builders. The CSS needed for
  the markers is supplied through the returned parts' `html_setup_parts`.

  Args:
    node: The node to render.
    path: Optionally, a path to this node, as a tuple of path entries (or
      None). It is forwarded unchanged to `node_renderer`.
    node_renderer: The inner renderer for this node. This is used to render
      `node` itself.

  Returns:
    A possibly-modified representation of this object.

  Raises:
    RuntimeError: If called outside of the context constructed via
      `setup_shared_value_context`.
  """
  # pylint: enable=g-doc-args,g-doc-return-or-yield
  shared_object_tracker = _shared_object_ids_seen.get()
  if shared_object_tracker is None:
    raise RuntimeError(
        "`check_for_shared_values` should only be called in a shared value"
        " context! Make sure the current treescope renderer has"
        " `setup_shared_value_context` in its `context_builders`."
    )

  # Use object identity to track shared references and loops.
  # NOTE(review): `id()` is only unique among objects alive at the same time.
  # If `node` is a temporary that is freed after rendering, an unrelated later
  # object could reuse its ID and be reported as shared.
  node_id = id(node)

  # For types that we know are immutable, it's not necessary to render shared
  # references in a special way. (Hashable objects can _technically_ be
  # modified but we trust the user to know what they are doing if so.)
  # Classes that are unhashable (e.g. list, dict, set) have `__hash__` None.
  safe_to_share = (
      hasattr(node, "__hash__") and node.__hash__ is not None
  ) or isinstance(node, tuple(_SAFE_TO_SHARE_TYPES))

  # Render the node normally. The node's ID is recorded only after this
  # returns, so repeats inside its own subtree are recorded first.
  rendering = node_renderer(node, path)

  if not safe_to_share:
    # Mark this as possibly shared.
    if node_id in shared_object_tracker.seen_at_least_once:
      shared_object_tracker.seen_more_than_once.add(node_id)
    else:
      shared_object_tracker.seen_at_least_once.add(node_id)

    # Wrap it in a shared value wrapper; this will check to see if the same
    # node was seen more than once, and add an annotation if so. The wrappers
    # hold a reference to the live set, so later sightings still count.
    return part_interface.RenderableAndLineAnnotations(
        renderable=WithDynamicSharedPip(
            rendering.renderable,
            node_id=node_id,
            seen_more_than_once=shared_object_tracker.seen_more_than_once,
        ),
        annotations=basic_parts.siblings(
            DynamicSharedCheck(
                if_shared=SharedWarningLabel(
                    basic_parts.Text(f" # Repeated python obj at 0x{node_id:x}")
                ),
                node_id=node_id,
                seen_more_than_once=shared_object_tracker.seen_more_than_once,
            ),
            rendering.annotations,
        ),
    )

  return rendering