-
Notifications
You must be signed in to change notification settings - Fork 45
/
canonical_aliases.py
592 lines (495 loc) · 20.7 KB
/
canonical_aliases.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
# Copyright 2024 The Penzai Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Registry of certain "well-known" objects, such as module functions.
Taking the `repr` of a function or callable usually produces something like
<function vmap at 0x7f98bf556440>
and in some cases produces something like
<jax.custom_derivatives.custom_jvp object at 0x7f98c0b5f130>
This can make it hard to determine what object this actually is from a user
perspective, and it would likely be more user-friendly to just output `jax.vmap`
or `jax.nn.relu` as the representation of this object.
Many functions and classes store a reference to the location where they were
originally defined (in their __module__ and __qualname__ attributes), which
can be used to find an alias for them. However, this may not be the "canonical"
alias, because some modules re-export private symbols under a public namespace
(in particular, JAX, Penzai, and Equinox all do this).
This module contains a registry of canonical paths for specific functions and
other objects, so that they can be rendered in a predictable way by treescope.
The intended purpose is primarily for interactive printing and debugging,
although it also helps for reifying objects into executable code through
round-trippable pretty printing, which can enable a simple form of
serialization.
This module also supports walking the public API of a package to automatically
set aliases; this is done by default for JAX and a few other libraries to ensure
the pretty-printed outputs avoid private module paths whenever possible. This
is intended as a heuristic to construct readable aliases for common objects on a
best-effort basis. It is not guaranteed that these inferred aliases will always
be stable across different versions of the external libraries.
"""
import collections
import contextlib
import dataclasses
import inspect
import sys
import types
from typing import Any, Callable, Literal, Mapping
import warnings
from penzai.core import context
@dataclasses.dataclass(frozen=True)
class ModuleAttributePath:
    """Expected path where we can find a particular object in a module.

    Attributes:
      module_name: Fully-qualified name of the module (the key of the module
        in sys.modules) in which this object can be found.
      attribute_path: Sequence of attribute names identifying this object
        inside the module. For instance, if this is ("foo", "bar") then we
        expect to find the object at "{module_name}.foo.bar". An empty
        sequence refers to the module itself.
    """

    module_name: str
    attribute_path: tuple[str, ...]

    def __str__(self):
        if not self.attribute_path:
            # The path refers to the module itself. This also covers the
            # `__main__` module, which has no shorter spelling than its name.
            return self.module_name
        dotted_attributes = ".".join(self.attribute_path)
        if self.module_name == "__main__":
            # Objects defined in `__main__` are rendered without a module
            # prefix, i.e. the way they would be typed interactively.
            return dotted_attributes
        return f"{self.module_name}.{dotted_attributes}"

    def retrieve(self, forgiving: bool = False) -> Any:
        """Retrieves the object at this path.

        Args:
          forgiving: If True, return None on failure instead of raising an
            error.

        Returns:
          The retrieved object, or None if it wasn't found and `forgiving`
          was True.

        Raises:
          ValueError: If the object wasn't found and `forgiving` is False.
        """
        # A single lookup avoids the check-then-index race on `sys.modules`.
        try:
            current_object = sys.modules[self.module_name]
        except KeyError:
            if forgiving:
                return None
            raise ValueError(
                f"Invalid alias {self} pointing to a non-imported module"
                f" {self.module_name}"
            ) from None

        # Walk the attribute chain one step at a time.
        for attr in self.attribute_path:
            try:
                current_object = getattr(current_object, attr)
            except AttributeError:
                if forgiving:
                    return None
                raise ValueError(
                    f"Can't retrieve {self}: {self.module_name} does"
                    f" not expose an attribute {self.attribute_path!r}."
                ) from None
        return current_object
@dataclasses.dataclass(frozen=True)
class LocalNamePath:
    """Expected path where we can find a particular object in a local scope.

    The "local scope" can be any dictionary of values with string keys, but
    it is usually either the locals() or globals() for a particular scope (or
    the union of these). A local name path is only valid relative to the
    particular scope that was used to create it.

    Attributes:
      local_name: Name of the variable in the local scope that we start from.
      attribute_path: Sequence of attribute names to follow from that
        variable. For instance, if this is ("foo", "bar") then we expect to
        find the object at "{local_name}.foo.bar".
    """

    local_name: str
    attribute_path: tuple[str, ...]

    def __str__(self):
        return ".".join((self.local_name, *self.attribute_path))

    def retrieve(
        self, local_scope: dict[str, Any], forgiving: bool = False
    ) -> Any:
        """Retrieves the object at this path.

        Args:
          local_scope: The scope in which we should retrieve this value.
          forgiving: If True, return None on failure instead of raising an
            error.

        Returns:
          The retrieved object, or None if it wasn't found and `forgiving`
          was True.

        Raises:
          KeyError, AttributeError: If the object wasn't found and
            `forgiving` is False.
        """
        # Step 1: find the starting variable in the scope.
        try:
            target = local_scope[self.local_name]
        except KeyError:
            if not forgiving:
                raise
            return None

        # Step 2: follow the attribute chain from that variable.
        for attribute_name in self.attribute_path:
            try:
                target = getattr(target, attribute_name)
            except AttributeError:
                if not forgiving:
                    raise
                return None
        return target
@dataclasses.dataclass(frozen=True)
class CanonicalAliasEnvironment:
    """A collection of canonical aliases plus modules still to be scanned.

    The dataclass itself is frozen, but both fields are mutable containers
    that are updated in place as aliases are registered.

    Attributes:
      aliases: Maps id(some_object) to the module path where we expect to
        find that object.
      lazy_populate_if_imported: Pairs of (module name, predicate). Each
        named module is populated lazily once it has been imported, without
        importing it directly, using the paired predicate as the filter.
    """

    aliases: dict[int, ModuleAttributePath]
    lazy_populate_if_imported: list[
        tuple[str, Callable[[Any, ModuleAttributePath], bool]]
    ]
# Starts out empty: no registered aliases and no modules queued for lazy
# population.
_alias_environment: context.ContextualValue[CanonicalAliasEnvironment] = (
    context.ContextualValue(
        module=__name__,
        qualname="_alias_environment",
        initial_value=CanonicalAliasEnvironment(
            aliases={},
            lazy_populate_if_imported=[],
        ),
    )
)
"""The current environment for module-level canonical aliases.
All alias mutation and lookups occur relative to the current
environment. Usually, this will not need to be modified, but it can
be useful to modify for tests or other local modifications.
"""
def add_alias(
    the_object: Any,
    path: ModuleAttributePath,
    on_conflict: Literal["ignore", "overwrite", "warn", "error"] = "warn",
):
    """Adds an alias to this object.

    Args:
      the_object: Object to add an alias to.
      path: The path where we expect to find this object.
      on_conflict: What to do if we try to add an alias for an object that
        already has one. "ignore" silently keeps the existing alias,
        "overwrite" replaces it, "warn" keeps the existing alias and emits a
        warning, and "error" raises a ValueError.

    Raises:
      ValueError: If `on_conflict` is "error" and this object already has an
        alias, or if this object is not accessible at this path.
    """
    alias_env = _alias_environment.get()
    object_id = id(the_object)

    if object_id in alias_env.aliases:
        existing_alias = alias_env.aliases[object_id]
        if on_conflict == "ignore":
            return
        if on_conflict == "warn":
            warnings.warn(
                f"Not defining alias {path} for {the_object!r}: it already has an"
                f" alias {existing_alias}."
            )
            # The warning states that the alias is NOT being defined, so stop
            # here instead of falling through and overwriting it.
            return
        if on_conflict == "error":
            raise ValueError(
                f"Can't define alias {path} for {the_object!r}: it already has an"
                f" alias {existing_alias}."
            )
        # on_conflict == "overwrite": continue and replace the existing alias.

    # Only register paths that actually resolve to this exact object.
    retrieved = path.retrieve()
    if retrieved is not the_object:
        raise ValueError(
            f"Can't define alias {path} for {the_object!r}: {path} "
            f" is a different object {retrieved!r}."
        )

    # OK, it's probably safe to add this object as a well-known alias.
    alias_env.aliases[object_id] = path
# None means that no relative-alias scope is currently active.
_current_scope_for_local_aliases: context.ContextualValue[
    dict[int, LocalNamePath] | None
] = context.ContextualValue(
    module=__name__,
    qualname="_current_scope_for_local_aliases",
    initial_value=None,
)
"""A mapping from object IDs to local names.
Should only be set by `relative_alias_names` and read by `lookup_alias`.
"""
def update_lazy_aliases() -> None:
    """Defines aliases for lazily-registered modules that are now imported.

    Walks the `lazy_populate_if_imported` entries of the active environment.
    Every entry whose module is present in `sys.modules` is populated from
    that module's public API and then dropped from the pending list, so each
    module is processed only once.
    """
    pending = _alias_environment.get().lazy_populate_if_imported

    # Collect the handled entries first; removing them while iterating over
    # `pending` would skip elements.
    populated = []
    for module_name, predicate in pending:
        if module_name not in sys.modules:
            continue
        populate_from_public_api(sys.modules[module_name], predicate)
        populated.append((module_name, predicate))

    for entry in populated:
        pending.remove(entry)
def lookup_alias(
    the_object: Any,
    infer_from_attributes: bool = True,
    allow_outdated: bool = False,
    allow_relative: bool = False,
) -> ModuleAttributePath | LocalNamePath | None:
    """Retrieves an alias for this object, if possible.

    Candidates are considered in this order: a relative alias from the active
    `relative_alias_names` scope (only with `allow_relative`), an alias
    registered in the current alias environment, and finally an alias
    inferred from the `__module__` and `__qualname__` attributes (only with
    `infer_from_attributes`). Unless `allow_outdated` is set, a global alias
    is only used if the object is still reachable through it. When a relative
    scope is active, the global alias is rewritten in terms of local names
    where possible.

    Args:
      the_object: Object to get a well-known alias for.
      infer_from_attributes: Whether to use __module__ and __qualname__
        attributes to infer an alias if no explicit path is registered.
      allow_outdated: If True, return old aliases even if the object is no
        longer accessible at this path. (For instance, if the module was
        reloaded after this class/function was defined.)
      allow_relative: If True, return aliases that are local to the current
        `relative_alias_names` context if possible.

    Returns:
      A path at which we can find this object, or None if we do not have a
      path for it (or if the path is no longer correct).
    """
    if the_object is None:
        # None itself should never have an alias. Checking for it here lets
        # us easily catch the "broken alias" case in `retrieve` below, where
        # a forgiving retrieval reports failure as None.
        return None

    alias_env = _alias_environment.get()
    object_id = id(the_object)

    # Is this object in the local aliases? If so, return that.
    local_aliases = (
        _current_scope_for_local_aliases.get() if allow_relative else None
    )
    if local_aliases and object_id in local_aliases:
        return local_aliases[object_id]

    # Find a global alias: registered explicitly, or inferred.
    if object_id in alias_env.aliases:
        alias = alias_env.aliases[object_id]
    elif not infer_from_attributes:
        return None
    else:
        # Unwrap first, in case this is a function-like object wrapping a
        # function.
        unwrapped = inspect.unwrap(the_object)
        has_definition_site = (
            hasattr(unwrapped, "__module__")
            and hasattr(unwrapped, "__qualname__")
            # Qualified names containing "<" (e.g. "<locals>", "<lambda>")
            # are not importable attribute paths.
            and "<" not in unwrapped.__qualname__
        )
        if has_definition_site:
            alias = ModuleAttributePath(
                unwrapped.__module__,
                tuple(unwrapped.__qualname__.split(".")),
            )
        elif isinstance(the_object, types.ModuleType):
            alias = ModuleAttributePath(the_object.__name__, ())
        else:
            # Can't infer an alias.
            return None

    # Unless outdated aliases are acceptable, make sure the alias still
    # leads back to this object.
    if not allow_outdated:
        found = alias.retrieve(forgiving=True)
        if isinstance(the_object, types.MethodType):
            # Methods get different IDs on each access, so compare by
            # equality instead of identity.
            is_stale = found != the_object
        else:
            is_stale = found is not the_object
        if is_stale:
            return None

    if not local_aliases:
        return alias

    # Check if any of the objects along this path have a local alias,
    # starting from the object's closest ancestor and moving outward to the
    # module itself.
    attribute_path = alias.attribute_path
    for split_point in range(len(attribute_path) - 1, -1, -1):
        parent_object = ModuleAttributePath(
            alias.module_name, attribute_path[:split_point]
        ).retrieve(forgiving=True)
        if id(parent_object) in local_aliases:
            local_path_to_parent = local_aliases[id(parent_object)]
            return LocalNamePath(
                local_name=local_path_to_parent.local_name,
                attribute_path=(
                    local_path_to_parent.attribute_path
                    + attribute_path[split_point:]
                ),
            )

    # Check if any parent module of the module in which this global alias is
    # defined has a local alias that we should use instead (e.g. using `np`
    # instead of `numpy`). Try the full module name first, then ever shorter
    # prefixes, moving the dropped components onto the attribute path.
    module_parts = alias.module_name.split(".")
    for cut in range(len(module_parts), 0, -1):
        parent_module_alias = ModuleAttributePath(
            ".".join(module_parts[:cut]),
            tuple(module_parts[cut:]) + attribute_path,
        )
        # Is this a real parent module? (It probably should be unless someone
        # mucked with import paths.)
        if parent_module_alias.module_name not in sys.modules:
            continue
        module_id = id(sys.modules[parent_module_alias.module_name])
        # Do we have a local alias for this parent module? And, if so, can
        # we access this particular object through the parent module in the
        # expected way?
        if (
            module_id in local_aliases
            and parent_module_alias.retrieve(forgiving=True) is the_object
        ):
            # First look up the module in the local scope, then look up the
            # value in the module.
            local_path_to_module = local_aliases[module_id]
            return LocalNamePath(
                local_name=local_path_to_module.local_name,
                attribute_path=(
                    local_path_to_module.attribute_path
                    + parent_module_alias.attribute_path
                ),
            )

    return alias
def maybe_local_module_name(module: types.ModuleType) -> str:
    """Returns a name for this module, preferring a local alias if one exists.

    Args:
      module: The module to name.

    Returns:
      The string form of the alias that `lookup_alias` finds for the module,
      with relative and outdated aliases both allowed.
    """
    module_alias = lookup_alias(
        module, allow_relative=True, allow_outdated=True
    )
    assert module_alias is not None
    return str(module_alias)
def default_well_known_filter(
    the_object: Any, path: ModuleAttributePath | LocalNamePath
) -> bool:
    """Checks if an object looks like something we want to define an alias for.

    Args:
      the_object: The candidate object.
      path: The path at which the object was found.

    Returns:
      True if the object should get an alias at this path, False otherwise.
    """
    wraps_function = isinstance(
        inspect.unwrap(the_object), types.FunctionType
    )
    is_module_or_class = isinstance(the_object, (types.ModuleType, type))

    # Only define aliases for objects that are either (a) mutable or (b)
    # classes/functions/modules. Hashability is used as the stand-in for
    # "not mutable" here.
    if not (is_module_or_class or wraps_function):
        if getattr(the_object, "__hash__", None) is not None:
            return False

    # Don't allow classes and functions to be assigned to any name other than
    # the name they were given at creation time.
    if isinstance(the_object, type) or (
        wraps_function and hasattr(the_object, "__name__")
    ):
        creation_name = the_object.__name__
        attribute_path = path.attribute_path
        if attribute_path and creation_name != attribute_path[-1]:
            return False

    # Assume any name that starts with an underscore is private.
    for attr in path.attribute_path:
        if attr.startswith("_"):
            return False
    if isinstance(path, LocalNamePath) and path.local_name.startswith("_"):
        return False

    return True
def relative_alias_names(
    relative_scope: Mapping[str, Any] | Literal["magic"],
    predicate: Callable[[Any, LocalNamePath], bool] = default_well_known_filter,
) -> contextlib.AbstractContextManager[None]:
    """Context manager that makes `lookup_alias` return relative aliases.

    Args:
      relative_scope: A dictionary mapping in-scope names to their values,
        e.g. `globals()` or `{**globals(), **locals()}`. Objects that are
        reachable from this scope will have aliases that reference the keys
        in this dict. Alternatively, can be the string "magic", in which case
        we will walk the stack and use the `{**globals(), **locals()}` of the
        caller. ("magic" is only recommended for interactive usage.)
      predicate: A filter function to check if an object should be given an
        alias.

    Returns:
      A context manager in which `lookup_alias` will try to return relative
      aliases when `allow_relative=True`.
    """
    if relative_scope == "magic":
        # Infer the scope by walking the stack. 1 is the caller's frame.
        # `context=0` skips loading source lines, which are not needed here.
        caller_frame = inspect.stack(context=0)[1].frame
        # ChainMap gives priority to its FIRST mapping, so locals must come
        # first to shadow globals, matching `{**globals(), **locals()}`.
        relative_scope = collections.ChainMap(
            caller_frame.f_locals,
            caller_frame.f_globals,
        )

    # Index the accepted values by identity. If one object is bound to
    # several names, the name visited last wins.
    local_alias_map = {}
    for key, value in relative_scope.items():
        path = LocalNamePath(key, ())
        if predicate(value, path):
            local_alias_map[id(value)] = path

    return _current_scope_for_local_aliases.set_scoped(local_alias_map)
def populate_from_public_api(
    module: types.ModuleType,
    predicate: Callable[
        [Any, ModuleAttributePath], bool
    ] = default_well_known_filter,
):
    """Populates canonical aliases with all public symbols in a module.

    Attempts to walk this module and its submodules to extract well-known
    symbols. Symbols that already have an alias defined will be ignored.

    If the module defines __all__, we assume the symbols in __all__ are the
    well-known symbols in this module. (See
    https://docs.python.org/3/reference/simple_stmts.html#the-import-statement)
    If the module does not define __all__, we look for all names that do not
    start with "_".

    We then additionally filter down this set to the set of objects for which
    the `predicate` argument returns True.

    This function should only be called on modules with a well-defined public
    API, that exports only functions defined in the module itself. If a
    module re-exports symbols from another external module (e.g. importing
    `partial` from `functools`), we might otherwise end up using the
    unrelated module as the "canonical" source of that object. (The
    `prefix_filter` function tries to prevent this if possible when used as a
    predicate.)

    Args:
      module: The module we will collect symbols from.
      predicate: A filter function to check if an object should be given an
        alias.
    """
    if hasattr(module, "__all__"):
        public_names = module.__all__
    else:
        public_names = [
            key for key in module.__dict__ if not key.startswith("_")
        ]

    # A real submodule's name extends ours by a dot. A bare prefix test would
    # also match unrelated siblings such as `jaxlib` for `jax`.
    submodule_prefix = module.__name__ + "."

    for name in public_names:
        try:
            value = getattr(module, name)
        except AttributeError:
            # `__all__` may list names the module does not provide. This walk
            # is best-effort, so skip them instead of aborting.
            continue

        if isinstance(value, types.ModuleType):
            if value.__name__.startswith(submodule_prefix):
                # Process submodules of this module also.
                populate_from_public_api(value, predicate)
            # Don't process external modules that are being re-exported.
            continue

        path = ModuleAttributePath(module.__name__, (name,))
        if predicate(value, path):
            add_alias(value, path, on_conflict="ignore")
def prefix_filter(prefix: str):
    """Builds a filter that only defines aliases within a given prefix.

    Args:
      prefix: Module name that accepted objects must live under: either the
        module called `prefix` itself or one of its dotted submodules.

    Returns:
      A predicate taking `(the_object, path)`. It accepts an object only if
      `default_well_known_filter` accepts it, the path's module is under the
      prefix, and the object's own `__module__` (when set and non-empty) is
      under the prefix too.
    """

    def _is_within_prefix(module_name: str):
        if module_name == prefix:
            return True
        return module_name.startswith(prefix + ".")

    def predicate(the_object: Any, path: ModuleAttributePath) -> bool:
        if not default_well_known_filter(the_object, path):
            return False
        if not _is_within_prefix(path.module_name):
            return False
        # Reject objects that were defined elsewhere and are merely
        # re-exported from a module under the prefix.
        defining_module = getattr(the_object, "__module__", None)
        if defining_module and not _is_within_prefix(defining_module):
            return False
        return True

    return predicate
# Register well-known aliases for the functions defined in these modules, since
# they are likely to be used in penzai code. Each entry is registered lazily:
# a module is only scanned once it has been imported. Each row below is
# (module to scan, package prefix that accepted objects must live under).
_alias_environment.get().lazy_populate_if_imported.extend([
    (module_name, prefix_filter(package_prefix))
    for module_name, package_prefix in (
        # Third-party libraries with useful APIs:
        ("numpy", "numpy"),
        ("jax.lax", "jax"),
        ("jax.numpy", "jax"),
        ("jax.scipy", "jax"),
        ("jax.random", "jax"),
        ("jax.nn", "jax"),
        ("jax.custom_derivatives", "jax"),
        ("jax.experimental.pjit", "jax"),
        ("jax.experimental.shard_map", "jax"),
        ("jax", "jax"),
        ("equinox", "equinox"),
    )
])