-
Notifications
You must be signed in to change notification settings - Fork 45
/
named_axes.py
1666 lines (1405 loc) · 62.4 KB
/
named_axes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2024 The Penzai Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A lightweight and minimal implementation of named axes.

As argued by "Tensors Considered Harmful", relying on axis indices for complex
tensor operations can be brittle and difficult to read. This has led to a
number of proposals for indexing axes by name instead of by position. However,
due to the large API surface for NDArray manipulation, building a fully-featured
named axis implementation requires making named-axis versions of many individual
operations.

This module provides a lightweight implementation of named axes using a
"locally positional" style. The key idea is to reuse positional-axis operations
in their original form, but use *local* bindings of positional axes to the
underlying named axes; positional axes then become aliases for particular
named axes within a local context.

To see how this works, suppose we want to implement dot-product attention
between queries and keys. We'd start with two named arrays::

  queries = pz.nx.wrap(...).tag("batch", "query_pos", "heads", "embed")
  keys = pz.nx.wrap(...).tag("batch", "key_pos", "heads", "embed")

We then contract them against each other, discarding the "embed" dimension
and broadcasting jointly over the others (e.g. "bqhf, bkhf -> bqkh" in einsum
notation). In our "locally positional" style, we could write this as::

  dot_prods = nmap(jnp.dot)(queries.untag("embed"), keys.untag("embed"))

Here `jnp.dot` is called with two one-axis views of ``queries`` and ``keys``,
respectively. (More specifically, it is called with `jax.vmap` tracers that
have a single logical axis and three implicitly vectorized axes.) We could just
as easily use our own function::

  def my_dot(a, b):
    print("a:", a)
    print("b:", b)
    print("a.shape:", a.shape)
    print("b.shape:", b.shape)
    return jnp.dot(a, b)

  dot_prods = nmap(my_dot)(queries.untag("embed"), keys.untag("embed"))

We can similarly apply ``softmax`` over one of the axes::

  attn_weights = nmap(jax.nn.softmax)(
      dot_prods.untag("key_pos")).tag("key_pos")

In this case, we need to "tag" the positional axis produced by softmax with a
name, and we choose to give it the same name as the original axis.

One advantage of the locally-positional style is that it does not require
wrapping/modifying any of the functions in the numpy/JAX API to take axis names;
instead, the primitives are written in terms of ordinary positional-axis logic.
This means that the full API surface for named axes can be very small. It
also means that it's easy to "drop down" into positional-axis code and do more
complex modifications (e.g. slicing, updating) without losing the readability
or flexibility of named-axis code.

The locally-positional style is fairly similar to the notation used in
the paper `"Named Tensor Notation"`_ (Chiang, Rush, and Barak, 2022), in which
ordinary mathematical notation is extended with subscripts to identify which
axis or axes they should operate over. In both cases, any names that do NOT
appear as part of the operation are implicitly vectorized over. The primary
difference is that named axes are specified (by ``untag``) separately for each
argument instead of being necessarily shared; this simplifies operations that
act over different names for each argument or that produce new axis names as
outputs.

.. _"Named Tensor Notation": https://arxiv.org/abs/2102.13196

For more information, see the named axis tutorial in ``penzai/notebooks``.
"""
from __future__ import annotations
import abc
import collections
import dataclasses
import functools
import operator
import typing
from typing import Any, Callable, Hashable, Mapping, Sequence
import jax
import jax.numpy as jnp
import numpy as np
import ordered_set
from penzai.core import struct
# Type alias for an axis name.
#
# Axis names are almost always strings, but can be arbitrary hashable objects
# (except integers, which aren't allowed to avoid confusion with positional
# axes.) Note that the alias itself is just `Hashable`, so the "no integers"
# rule is a convention that static type checking does not enforce.
AxisName: typing.TypeAlias = Hashable
class TmpPosAxisMarker(object):
  """Placeholder axis name that can never collide with another axis name.

  The class defines no equality or hashing methods, so instances use the
  default identity-based ``object`` behavior: a marker is equal only to
  itself and hashes by identity. A freshly created marker can therefore be
  bound to a positional axis without any risk of an axis-name conflict.
  """
def nmap(fun: Callable[..., Any]) -> Callable[..., Any]:
  """Vectorizes ``fun`` over the named axes of its NamedArray inputs.

  ``nmap`` is a "named-axis vectorizing map". It turns an ordinary function
  written in terms of positional axes into one that accepts NamedArrays (or
  NamedArrayViews) and returns NamedArrays. ``fun`` is invoked on
  positionally-indexed slices of each argument, each shaped like that
  argument's `positional_shape`, while every named axis is mapped over.

  In contrast to `jax.vmap`, the axes to map over are not passed to the
  transformation. They are read off the NamedArray / NamedArrayView
  arguments themselves. Each axis name present in any argument is mapped
  jointly across every argument that has it, and then appears as a named
  axis of the output. To let ``fun`` see an axis positionally instead, call
  `untag` on the argument with that axis name. ``fun`` will then receive it
  as a positional axis rather than mapping over it.

  Together, `untag` and ``nmap`` are the main way to apply individual
  operations to the axes of a NamedArray. `tag` can be used afterwards to
  give names back to the positional axes of the result.

  Mapped-over axes can be referenced from inside ``fun`` with standard JAX
  collectives such as ``psum``, though this is rarely needed.

  Args:
    fun: The function to vectorize by name. Its arguments are unrestricted
      (they may even be non-JAX-arraylike values or "static" axis sizes), but
      it must return a PyTree of JAX ArrayLike outputs.

  Returns:
    A name-vectorized version of ``fun``. It may be called with NamedArrays
    (or NamedArrayViews) in place of ordinary arrays, and it always returns
    NamedArrays (or NamedArrayViews) for each output leaf. The named axes of
    every argument leaf that is a NamedArray(View) are mapped over, so
    ``fun`` sees batch tracers for slices shaped like
    ``named_array_arg.positional_shape``. Every axis name that appears in any
    input also appears in every output.
  """
  # `hasattr` only swallows AttributeError, so catching exactly that gives the
  # same fallback for callables without a `__name__` (e.g. functools.partial).
  try:
    fun_name = fun.__name__
  except AttributeError:
    fun_name = repr(fun)
  fun_doc = getattr(fun, "__doc__", None)
  return _nmap_with_doc(fun, fun_name, fun_doc)
def _nmap_with_doc(
    fun: Callable[..., Any], fun_name: str, fun_doc: str | None = None
) -> Callable[..., Any]:
  """Builds a nmap-wrapped function with a docstring.

  This is the implementation behind `nmap`. The returned wrapper works in
  four steps:

  * It flattens its arguments and finds every NamedArray / NamedArrayView
    leaf.
  * It checks that axis names shared between arguments have consistent
    sizes.
  * It applies one `jax.vmap` per axis name (recursively) and calls ``fun``
    on the resulting positional slices.
  * It wraps the outputs back into NamedArrays / NamedArrayViews.

  Args:
    fun: The positional-axis function to vectorize.
    fun_name: Name used to refer to ``fun`` in the generated docstring.
    fun_doc: Original documentation of ``fun``. If truthy, it is appended to
      the generated docstring.

  Returns:
    The name-vectorized wrapper around ``fun``.

  The wrapper raises ValueError if the same axis name appears with different
  sizes in different arguments.
  """

  @functools.wraps(fun)
  def wrapped_fun(*args, **kwargs):
    # Flatten (args, kwargs) together, stopping at NamedArray/NamedArrayView so
    # that each named array is a single leaf. Key paths are kept so error
    # messages can point at the offending argument.
    arg_leaves_and_paths, arg_treedef = jax.tree_util.tree_flatten_with_path(
        (args, kwargs),
        is_leaf=lambda node: isinstance(node, NamedArray | NamedArrayView),
    )
    arg_leaves = [leaf for _, leaf in arg_leaves_and_paths]

    # Extract any argument leaves that were NamedArrays or NamedArrayViews. The
    # rest of the arguments will just be closed over, so they don't have to be
    # arraylike. We also check named shapes here.
    # To simplify the implementation, we ensure that all arguments are views
    # rather than positional-prefix NamedArrays; this can always be done without
    # any device computations.
    named_array_arg_leaves = []
    # `known_sizes` records the first size seen for each axis name.
    # `bad_names` collects (once each) names later seen with a different size.
    known_sizes = {}
    bad_names = []
    for leaf in arg_leaves:
      if isinstance(leaf, NamedArray | NamedArrayView):
        for name, size in leaf.named_shape.items():
          if name in known_sizes:
            if known_sizes[name] != size and name not in bad_names:
              bad_names.append(name)
          else:
            known_sizes[name] = size
        named_array_arg_leaves.append(leaf.as_namedarrayview())

    if bad_names:
      # Report the named shape of every named-array argument, labelled with
      # its location as `args[...]` or `kwargs[...]`.
      msg = [
          f"Inconsistent named axes in a call to nmap({fun}) for axes"
          f" {bad_names}:"
      ]
      for keypath, leaf in arg_leaves_and_paths:
        if isinstance(leaf, NamedArray | NamedArrayView):
          assert keypath
          # The first path entry indexes into the `(args, kwargs)` tuple.
          if keypath[0] == jax.tree_util.SequenceKey(0):
            prefix = f"args{jax.tree_util.keystr(keypath[1:])}"
          elif keypath[0] == jax.tree_util.SequenceKey(1):
            prefix = f"kwargs{jax.tree_util.keystr(keypath[1:])}"
          else:
            # Should never happen!
            prefix = f"tree{jax.tree_util.keystr(keypath)}"
          msg.append(f" {prefix}.named_shape == {leaf.named_shape}")
      raise ValueError("\n".join(msg))

    # Prepare a version of the function that accepts tracers for each of those
    # extracted arguments, and rebuilds the full arguments to `fun`.
    def flat_array_fun(batch_tracers):
      # Replace each NamedArray with its batch tracer.
      # `batch_tracers` is in the same order as the named-array leaves. It is
      # reversed so that `pop()` hands them out front-to-back.
      batch_tracers_stack = batch_tracers[::-1]
      new_arg_leaves = []
      for leaf in arg_leaves:
        if isinstance(leaf, NamedArray | NamedArrayView):
          # Substitute the tracer.
          new_arg_leaves.append(batch_tracers_stack.pop())
        else:
          new_arg_leaves.append(leaf)
      # Call the function.
      args, kwargs = jax.tree_util.tree_unflatten(arg_treedef, new_arg_leaves)
      return fun(*args, **kwargs)

    # Collect the axis names. This determines what we will be vectorizing over.
    # We use an ordered set to guarantee a deterministic ordering based on the
    # arguments, since sets have nondeterministic iteration order in Python.
    all_names = ordered_set.OrderedSet()
    for named_arg in named_array_arg_leaves:
      all_names.update(named_arg.data_axis_for_name.keys())
    all_names = list(all_names)

    # Recursively vectorize over each axis.
    # Each level strips one name and wraps the next level in `jax.vmap` with
    # `out_axes=0`. As a result, `all_names[0]` becomes the outermost (leading)
    # output axis, `all_names[1]` the next one, and so on. `handle_result`
    # below relies on this ordering.
    def recursive_vectorize_step(current_views, remaining_names):
      if not remaining_names:
        # All names have been processed, so none of the args should have names.
        # Unwrap them and call the function.
        return flat_array_fun([view.unwrap() for view in current_views])
      # Otherwise, we still have names to vectorize over. Pop one name off the
      # stack and vectorize over it as needed.
      vmap_name = remaining_names[0]
      reduced_views = []
      vmap_axes = []
      for view in current_views:
        if vmap_name in view.data_axis_for_name:
          vmap_axis = view.data_axis_for_name[vmap_name]
          vmap_axes.append(vmap_axis)

          # Renumbers a data axis as if `vmap_axis` had been removed. Capturing
          # the loop variable is safe because `_shift_axis` is only called
          # within this same iteration.
          # pylint: disable=cell-var-from-loop
          def _shift_axis(other_axis):
            assert other_axis != vmap_axis
            if other_axis < vmap_axis:
              return other_axis
            else:
              return other_axis - 1

          # pylint: enable=cell-var-from-loop
          # We are temporarily constructing an "invalid" view here because
          # data_array will still have the extra axis. But after running `vmap`,
          # it will be valid again.
          reduced_views.append(
              NamedArrayView(
                  data_array=view.data_array,
                  data_axis_for_name={
                      name: _shift_axis(data_axis)
                      for name, data_axis in view.data_axis_for_name.items()
                      if name != vmap_name
                  },
                  data_axis_for_logical_axis=tuple(
                      _shift_axis(data_axis)
                      for data_axis in view.data_axis_for_logical_axis
                  ),
                  data_shape=(
                      view.data_shape[:vmap_axis]
                      + view.data_shape[vmap_axis + 1 :]
                  ),
              )
          )
        else:
          # This argument doesn't have this axis, so don't map over anything.
          vmap_axes.append(None)
          reduced_views.append(view)
      # `in_axes=(vmap_axes,)` supplies one entry per view in `reduced_views`.
      # `axis_name` exposes the mapped axis to collectives such as `psum`.
      return jax.vmap(
          functools.partial(
              recursive_vectorize_step,
              remaining_names=remaining_names[1:],
          ),
          in_axes=(vmap_axes,),
          out_axes=0,
          axis_name=vmap_name,
      )(reduced_views)

    # Run the function.
    result_data = recursive_vectorize_step(named_array_arg_leaves, all_names)

    # Wrap all leaves in NamedArray or NamedArrayView, assigning the names from
    # `all_names` to their mapped-over axes. The mapped-over named axes always
    # end up at the front, followed by positional axes, so if there are any
    # positional axes we need to return a NamedArrayView.
    def handle_result(leaf):
      # Convert any ArrayLike output (e.g. a Python scalar) into a JAX array.
      leaf = jnp.array(leaf)
      if leaf.ndim == len(all_names):
        return NamedArray(
            data_array=leaf,
            named_axes=collections.OrderedDict(zip(all_names, leaf.shape)),
        )
      else:
        assert leaf.ndim > len(all_names)
        return NamedArrayView(
            data_array=leaf,
            data_shape=leaf.shape,
            data_axis_for_name={name: i for i, name in enumerate(all_names)},
            data_axis_for_logical_axis=tuple(range(len(all_names), leaf.ndim)),
        )

    return jax.tree_util.tree_map(handle_result, result_data)

  # The wrapper's docstring describes it in terms of `fun_name`, optionally
  # followed by the wrapped function's own documentation.
  docstr = (
      f"Name-vectorized version of `{fun_name}`. Takes similar arguments as"
      f" `{fun_name}` but accepts and returns NamedArrays (or NamedArrayViews)"
      " in place of regular arrays."
  )
  if fun_doc:
    docstr += f"\n\nOriginal documentation:\n\n{fun_doc}"
  wrapped_fun.__doc__ = docstr
  return wrapped_fun
def _swapped_binop(binop):
  """Wraps ``binop`` so that its two operands are passed in reverse order.

  Args:
    binop: A callable taking two positional arguments.

  Returns:
    A function ``swapped`` such that ``swapped(x, y)`` evaluates
    ``binop(y, x)``.
  """

  def swapped(x, y):
    lhs, rhs = y, x
    return binop(lhs, rhs)

  return swapped
def _wrap_scalar_conversion(scalar_conversion):
  """Adapts a scalar conversion callable (e.g. ``float``) to named arrays.

  Args:
    scalar_conversion: Callable applied to the unwrapped array.

  Returns:
    A method-style function taking a named array ``self``. It raises
    ``ValueError`` if ``self`` has any named or positional axes. Otherwise it
    returns ``scalar_conversion(self.unwrap())``.
  """

  def wrapped_scalar_conversion(self: NamedArrayBase):
    # Only a fully scalar array (no named and no positional axes) converts.
    if not self.named_shape and not self.positional_shape:
      return scalar_conversion(self.unwrap())
    raise ValueError(
        "Cannot convert a non-scalar NamedArray or NamedArrayView with"
        f" {scalar_conversion}"
    )

  return wrapped_scalar_conversion
def _wrap_array_method(name):
  """Creates a named-axis version of the `jax.Array` method called ``name``.

  Args:
    name: Attribute name of a method on `jax.Array`.

  Returns:
    An ``nmap``-vectorized function whose first argument is the array the
    method is invoked on.
    ``__name__``, ``__qualname__`` and ``__annotations__`` are copied from
    ``jax.Array.<name>``, and ``__module__`` is set to this module. The
    docstring cross-references the NumPy method of the same name.
  """
  # Look the method up first, so an unknown name fails with AttributeError
  # before anything is wrapped.
  array_method = getattr(jax.Array, name)

  def call_method(array, *args, **kwargs):
    bound_method = getattr(array, name)
    return bound_method(*args, **kwargs)

  wrapped_func = nmap(call_method)
  functools.update_wrapper(
      wrapped_func,
      array_method,
      assigned=("__name__", "__qualname__", "__annotations__"),
      updated=(),
  )
  wrapped_func.__module__ = __name__
  sphinx_ref = f"`{name} <numpy.ndarray.{name}>`"
  wrapped_func.__doc__ = (
      f"Name-vectorized version of array method {sphinx_ref}. Takes similar"
      f" arguments as {sphinx_ref} but accepts and returns NamedArrays"
      " (or NamedArrayViews) in place of regular arrays."
  )
  return wrapped_func
@struct.pytree_dataclass
class _StaticThunk(struct.Struct):
  """Wraps one index component in a field marked ``pytree_node=False``."""

  value: Any = dataclasses.field(metadata={"pytree_node": False})

  def unwrap(self) -> Any:
    """Returns the wrapped value unchanged."""
    return self.value
@struct.pytree_dataclass
class _DynamicThunk(struct.Struct):
  """Wraps one index component in an ordinary dataclass field.

  Unlike `_StaticThunk`, ``value`` carries no ``pytree_node=False`` metadata.
  It is presumably treated as a PyTree child (e.g. a traced index array);
  confirm against `penzai.core.struct`.
  """

  value: Any

  def unwrap(self) -> Any:
    """Returns the wrapped value unchanged."""
    return self.value
@struct.pytree_dataclass
class _SliceThunk(struct.Struct):
  """Stores the parts of a ``slice`` in ``pytree_node=False`` fields."""

  start: Any = dataclasses.field(metadata={"pytree_node": False})
  stop: Any = dataclasses.field(metadata={"pytree_node": False})
  step: Any = dataclasses.field(metadata={"pytree_node": False})

  def unwrap(self) -> slice:
    """Rebuilds ``slice(start, stop, step)`` from the stored fields."""
    start, stop, step = self.start, self.stop, self.step
    return slice(start, stop, step)
@jax.jit
@nmap
def _jitted_nmapped_getitem(
    array, index_thunks: tuple[_StaticThunk | _DynamicThunk | _SliceThunk, ...]
):
  """Indexes ``array`` with the tuple obtained by unwrapping each thunk.

  Args:
    array: The array to index. Through ``nmap`` this is a per-slice view with
      only positional axes.
    index_thunks: One thunk per index component, in order.

  Returns:
    ``array[indexer]``, where ``indexer`` is the tuple of unwrapped thunks.
  """
  return array[tuple(map(operator.methodcaller("unwrap"), index_thunks))]
class NamedArrayBase(abc.ABC):
"""Base class for named arrays and their transposed views."""
# Abstract methods.
@property
@abc.abstractmethod
def dtype(self) -> np.dtype:
  """The dtype of the wrapped array.

  Abstract: concrete subclasses must provide this property.
  """
@abc.abstractmethod
def check_valid(self) -> None:
  """Checks that the names in the array are correct.

  Abstract: concrete subclasses implement the actual check. How a failed
  check is reported is defined by those subclasses (not visible here).
  """
@property
@abc.abstractmethod
def named_shape(self) -> Mapping[AxisName, int]:
  """A mapping of axis names to their sizes.

  This covers only the named axes. Anonymous axes are described by
  `positional_shape`.
  """
@property
@abc.abstractmethod
def positional_shape(self) -> tuple[int, ...]:
  """A tuple of axis sizes for any anonymous (positional) axes.

  Named axes are not included. They are described by `named_shape`.
  """
@abc.abstractmethod
def unwrap(self, *names: AxisName) -> jax.Array:
  """Unwraps this array, possibly mapping axis names to positional axes.

  Unwrap can be called either on arrays with no named axes, or arrays with
  no positional axes (in which case ``names`` should be a permutation of its
  axis names).

  Args:
    *names: Sequence of axis names to map to positional axes, if this array
      has named axes. Shorthand for ``untag(*names).unwrap()``.

  Returns:
    An equivalent ordinary positional array.

  Raises:
    ValueError: If the array has a mixture of positional and named axes, or if
      the names do not match the named axes.
  """
@abc.abstractmethod
def with_positional_prefix(self) -> NamedArray:
  """Ensures a view is a `NamedArray` by moving positional axes.

  The resulting `NamedArray` has the same named and positional shapes as
  this object, but the data array may be transposed so that all the positional
  axes are in the front. This makes it possible to manipulate those prefix
  axes safely using `jax.tree_util`, or to scan/map over them using JAX
  control flow primitives.

  Returns:
    An equivalent `NamedArray` for this view, or the original `NamedArray`
    if it already was one.
  """
@abc.abstractmethod
def as_namedarrayview(self) -> NamedArrayView:
  """Converts into a `NamedArrayView`, keeping positional axes.

  This function is usually not necessary for ordinary named-array
  manipulation, since `NamedArray` and `NamedArrayView` define the same
  methods. However, it can be useful for simplifying library code that wishes
  to access the fields of `NamedArrayView` directly, or handle arbitrary
  named array objects without handling each case separately.

  Converting a `NamedArray` to a `NamedArrayView` never involves any
  device computations. (The reverse is not true.)

  Returns:
    An equivalent `NamedArrayView` for this array if it isn't one already.
  """
@abc.abstractmethod
def untag(self, *axis_order: AxisName) -> NamedArray | NamedArrayView:
  """Produces a positional view of the requested axis names.

  `untag` can only be called on a `NamedArray` or `NamedArrayView` that
  does not have any positional axes. It produces a new `NamedArrayView` where
  the axes with the requested names (the arguments to this function) are now
  treated as positional, in the given order.

  If you want to use `untag` on an array that already has positional axes,
  you can use `untag_prefix` instead.

  Args:
    *axis_order: Axis names to make positional, in the order they should
      appear in the positional view.

  Raises:
    ValueError: If this array already has positional axes, or if the provided
      axis ordering is not valid.

  Returns:
    A view with the given axes treated as positional for the purposes of
    later calls to `apply`, `nmap`, or `with_positional_prefix`. If passed
    an empty axis order, returns an ordinary NamedArray with no positional
    axes.
  """
@abc.abstractmethod
def tag(self, *names) -> NamedArray:
  """Attaches names to the positional axes of an array or view.

  Args:
    *names: Axis names to assign to each positional axis in the array or view.
      Must be the same length as `positional_shape`. If you only want to tag a
      subset of axes, use `tag_prefix` instead.

  Raises:
    ValueError: If the names are invalid, or if they aren't the same length
      as `positional_shape`.

  Returns:
    A NamedArray with the given names assigned to the positional axes, and no
    remaining positional axes.
  """
# Inherited methods that can already be implemented in terms of the above.
def untag_prefix(self, *axis_order: AxisName) -> NamedArray | NamedArrayView:
  """Adds the requested axes to the front of the array's positional axes.

  Unlike `untag`, this may be called on NamedArrays or NamedArrayViews that
  already have positional axes.

  Args:
    *axis_order: Axis names to make positional, in the order they should
      appear in the positional view.

  Returns:
    A view whose positional axes are the requested named axes (in the given
    order) followed by the array's existing positional axes.
  """
  # `untag` requires an array with no positional axes. So temporarily give
  # every existing positional axis a fresh, collision-free marker name.
  num_positional = len(self.positional_shape)
  markers = tuple(TmpPosAxisMarker() for _ in range(num_positional))
  # Listing the markers after `axis_order` keeps the original positional axes
  # at the end of the new positional shape.
  return self.tag(*markers).untag(*axis_order, *markers)
def tag_prefix(self, *axis_order: AxisName) -> NamedArray | NamedArrayView:
  """Attaches names to the first positional axes in an array or view.

  This is a version of `tag` that allows you to name only a subset of the
  array's positional axes.

  Args:
    *axis_order: Names to assign, in order, to the leading positional axes.
      There must be at most as many names as positional axes. Any remaining
      positional axes are left unnamed.

  Returns:
    A NamedArray or view with the given names assigned to the first positional
    axes, and whose positional shape includes only the suffix of axes that
    have not been given names.

  Raises:
    ValueError: If more names are given than there are positional axes.
  """
  num_positional = len(self.positional_shape)
  if len(axis_order) > num_positional:
    # Without this check the error would surface from `tag` with a message
    # that does not mention `tag_prefix`.
    raise ValueError(
        f"tag_prefix received {len(axis_order)} names {axis_order}, but the"
        f" array only has {num_positional} positional axes"
        f" ({self.positional_shape=})."
    )
  # Name the untagged suffix axes with unique temporary markers so that all
  # positional axes can be tagged at once. Then untag those markers to turn
  # the suffix back into positional axes.
  tmp_axis_ids = [
      TmpPosAxisMarker() for _ in range(num_positional - len(axis_order))
  ]
  return self.tag(*axis_order, *tmp_axis_ids).untag(*tmp_axis_ids)
def order_as(self, *axis_order: AxisName) -> NamedArray:
  """Stores the named axes in the given order, keeping them named.

  Use this when the axis names must appear in a consistent order, e.g. so
  that two `NamedArray` instances have exactly the same PyTree structure.

  For a canonical ordering that doesn't require knowing all the axis names in
  advance, you could do something like
  ``array.order_as(*sorted(array.named_shape.keys()))``.

  See also `order_like`.

  Args:
    *axis_order: Axis names in the order they should appear in the data array.
      Must be a permutation of all of the axis names in this array.

  Returns:
    Equivalent `NamedArray` whose data array contains the positional axes
    followed by the named axes in the given order.
  """
  # Marker objects hash by identity, so they can never clash with real names.
  pos_names = [TmpPosAxisMarker() for _ in self.positional_shape]
  # Physically lay out the data as [positional axes..., axis_order...].
  laid_out = self.tag(*pos_names).untag(*pos_names, *axis_order)
  data_array = laid_out.unwrap()
  # Reattach the same names to that layout, turn the marker axes back into
  # positional axes, and store them as a prefix.
  renamed = NamedArray.wrap(data_array).tag(*pos_names, *axis_order)
  return renamed.untag(*pos_names).with_positional_prefix()
def order_like(
    self, other: NamedArray | NamedArrayView
) -> NamedArray | NamedArrayView:
  """Ensures that this array's PyTree structure matches another array's.

  This can be used to ensure that one named array has the same PyTree
  structure as another, so that the two can be jointly processed by
  non-namedarray-aware tree functions (e.g. `jax.tree_util` functions,
  `jax.lax.cond`, `jax.jvp`, etc).

  To ensure compatibility of entire PyTrees, you can use something like::

    jax.tree_util.tree_map(
        lambda a, b: a.order_like(b), tree1, tree2,
        is_leaf=pz.nx.is_namedarray,
    )

  Args:
    other: Another named array or named array view. Must have the same set of
      named axes as this one. If this is a `NamedArrayView`, must also have
      the same positional axes.

  Returns:
    A new `NamedArray` or `NamedArrayView` that has the content of ``self``
    but is possibly transposed to have the same PyTree structure as ``other``
    (as long as the arrays have the same shape).

  Raises:
    TypeError: If ``other`` is not a `NamedArray` or `NamedArrayView`.
    ValueError: If ``other`` is a `NamedArrayView` whose positional or named
      shape differs from this array's.
  """
  # Check the type before touching `other`. Otherwise `other.check_valid()`
  # on e.g. a plain `jax.Array` fails with AttributeError, and this TypeError
  # is never raised.
  if not isinstance(other, NamedArray | NamedArrayView):
    raise TypeError(
        "`order_like` requires a NamedArray or NamedArrayView, but got"
        f" {type(other).__name__}"
    )
  self.check_valid()
  other.check_valid()

  if isinstance(other, NamedArray):
    # `order_as` produces a NamedArray with positional axes first and named
    # axes in the requested order.
    return self.order_as(*other.named_shape.keys())

  # `other` is a NamedArrayView: reproduce its exact data layout. Named shapes
  # are compared as plain dicts (as in `broadcast_to`). Insertion order is
  # exactly what this method fixes, so it must not count as a mismatch even
  # if `named_shape` is an OrderedDict.
  if self.positional_shape != other.positional_shape or dict(
      self.named_shape
  ) != dict(other.named_shape):
    raise ValueError(
        "Calling `order_like` with a NamedArrayView requires the two"
        " arrays have the same positional and named shapes."
        f" {self.positional_shape=}, {self.named_shape=},"
        f" {other.positional_shape=}, {other.named_shape=}"
    )
  self_view = self.as_namedarrayview()
  # Map each data axis in `other`'s layout to the data axis of `self_view`
  # holding the same logical (positional or named) axis.
  new_to_old_data_axis = dict(
      zip(
          other.data_axis_for_logical_axis,
          self_view.data_axis_for_logical_axis,
      )
  )
  for name, new_data_axis in other.data_axis_for_name.items():
    new_to_old_data_axis[new_data_axis] = self_view.data_axis_for_name[name]
  # `jnp.transpose` places input axis `perm[i]` at output position `i`.
  new_data_array = jnp.transpose(
      self_view.data_array,
      [new_to_old_data_axis[i] for i in range(self_view.data_array.ndim)],
  )
  assert new_data_array.shape == other.data_shape
  return NamedArrayView(
      data_shape=other.data_shape,
      data_axis_for_logical_axis=other.data_axis_for_logical_axis,
      data_axis_for_name=other.data_axis_for_name,
      data_array=new_data_array,
  )
def broadcast_to(
    self,
    positional_shape: Sequence[int] = (),
    named_shape: Mapping[AxisName, int] | None = None,
) -> NamedArrayBase:
  """Broadcasts a named array to a possibly-larger shape.

  Args:
    positional_shape: Desired positional shape for the array. Will be
      broadcast using numpy broadcasting rules.
    named_shape: Desired named shape for the array. Will be broadcast using
      `nmap`-style vectorizing rules (e.g. new named axes will be introduced
      if missing, but length-1 axes will not be broadcast).

  Returns:
    A named array that has the given positional and named shapes. Note that
    if this array has axis names that are not in ``named_shape``, these will
    be preserved in the answer as well. If the shapes already match exactly,
    ``self`` is returned unchanged.
  """
  # Normalize each target exactly once. `positional_shape` is used twice
  # below, so converting it a single time also keeps one-shot iterables
  # (e.g. generators) from being exhausted before their second use.
  positional_shape = tuple(positional_shape)
  named_shape = {} if named_shape is None else dict(named_shape)
  if (
      self.positional_shape == positional_shape
      and dict(self.named_shape) == named_shape
  ):
    return self
  # Trick: create a size-zero array with the right shape so that we can
  # broadcast using nmap's vectorization rules. The trailing length-0 axis
  # makes the prototype empty; it is dropped again via `b.shape[:-1]`.
  prototype_data = jnp.zeros(
      tuple(named_shape.values()) + positional_shape + (0,)
  )
  assert prototype_data.size == 0
  prototype = NamedArray.wrap(prototype_data).tag_prefix(*named_shape.keys())
  return nmap(lambda a, b: jnp.broadcast_to(a, b.shape[:-1]))(self, prototype)
def broadcast_like(
    self, other: NamedArrayBase | jax.typing.ArrayLike
) -> NamedArrayBase:
  """Broadcasts this named array so it is compatible with ``other``.

  Args:
    other: A named array, or any ArrayLike (treated as purely positional).

  Returns:
    A named array with the same positional and named shapes as ``other``
    (possibly with extra named axes carried over from this array).
  """
  if not isinstance(other, NamedArrayBase):
    # Plain arrays contribute only a positional shape.
    return self.broadcast_to(jnp.shape(other), {})
  return self.broadcast_to(other.positional_shape, other.named_shape)
def canonicalize(self) -> NamedArray:
  """Stores the named axes in a canonical (sorted) order.

  Returns:
    Equivalent `NamedArray` whose data array contains the positional axes
    followed by the named axes in sorted order.
  """
  # Iterating a mapping yields its keys, i.e. the axis names.
  sorted_names = sorted(self.named_shape)
  return self.order_as(*sorted_names)
# Indexing.
def __getitem__(self, indexer) -> NamedArray | NamedArrayView:
  """Retrieves slices from an indexer.

  `NamedArray` and `NamedArrayView` can be indexed in two different ways,
  depending on whether they have positional axes or not.

  If they do have positional axes, those positional axes can be indexed into
  using ordinary Numpy-style indexing. Indexing operations will be
  automatically vectorized over all of the named axes. For instance, an
  embedding lookup could look something like::

    embedding_table.untag("vocab")[token_ids]

  which first untags the "vocab" named axis as positional, then indexes into
  that axis using another array (which can be a `NamedArray` or an ordinary
  array).

  If they do NOT have positional axes, the named axes can be sliced directly
  using dictionaries mapping names to indices or slices, e.g.::

    my_array[{"position": 1, "feature": pz.slice[2:5]}]

  Here ``pz.slice[2:5]`` is syntactic sugar for ``slice(2, 5, None)``.

  Dict-style indexing is selected only when ``indexer`` is a ``dict``
  instance. Any existing positional axes are temporarily given unique marker
  names, so they pass through dict-style indexing unchanged. Each dict value
  may be:

  * ``None``: inserts a new axis with that name.
  * a ``slice``: slices that axis, which stays in the result.
  * a Python ``int``, or a 0-d integer ``jax.Array`` / ``np.ndarray``:
    selects one element, removing that axis.
  * a `NamedArrayBase` with no positional axes: used as a scalar index that
    is vectorized over its named axes automatically.

  Args:
    indexer: Either a normal Numpy-style indexer into the positional axes, or
      a mapping from a subset of axes names to the indices or slice it should
      return.

  Returns:
    A slice of the array.

  Raises:
    TypeError: For dict-style indexing, if an index value is not one of the
      supported kinds listed above, or is a `NamedArrayBase` that still has
      positional axes.
  """
  if isinstance(indexer, dict):
    # Dict indexing => desugar it to positional indexing over the requested
    # names.
    self.check_valid()
    # Create temporary "names" for existing positional axes by creating new
    # objects, which are hashable but only compare equal by ID and are thus
    # guaranteed unique.
    tmp_names_for_pos = [TmpPosAxisMarker() for _ in self.positional_shape]
    # Figure out how to bind the requested keys to new positional axes to
    # apply indexing to them.
    # input_names: axes to untag (in dict order) so they become positional.
    # output_names: names to re-tag onto the positional axes left afterwards.
    # index_seq: the positional indexer, aligned with the dict's order; a
    #   `None` entry inserts a new axis at that point.
    input_names = []
    output_names = []
    index_seq = []
    for name, index in indexer.items():
      if index is None:
        # New axis.
        index_seq.append(None)
        output_names.append(name)
      else:
        # Slicing an existing axis.
        input_names.append(name)
        index_seq.append(index)
        if isinstance(index, slice):
          # If a slice is provided, this axis will still appear in the output.
          output_names.append(name)
        elif isinstance(index, NamedArrayBase):
          if index.positional_shape:
            raise TypeError(
                "Dict-style indexing of a named array with another named"
                " array is only supported if the indexer has only named axes,"
                " no positional axes. Please tag the positional axes with a"
                " name, either the same as one of the target array's names or"
                " a new one, depending on the desired behavior."
            )
          # Named axes get mapped over automatically, so this temporary
          # positional axis will go away in the mapped computation.
        # NOTE(review): NumPy integer scalars such as `np.int64(3)` are
        # neither `int` nor `np.ndarray` instances, so they fall through to
        # the TypeError below even though the message advertises "scalar
        # integers" -- confirm whether that is intended.
        elif isinstance(index, int) or (
            isinstance(index, jax.Array | np.ndarray)
            and jnp.issubdtype(index.dtype, np.integer)
            and index.ndim == 0
        ):
          # Scalar indexing, this is OK. This axis will be removed.
          pass
        else:
          raise TypeError(
              "Dict-style indexing of a named array only supports indexing"
              " with scalar integers, None, slices, and fully-named"
              f" NamedArray(View)s. Got: {index}\nNote: If you are trying to"
              " perform Numpy-style advanced indexing with an integer"
              " positional array, you can instead index with a NamedArray"
              " with a fresh named axis, which will be vectorized over"
              " automatically."
          )
    # 1. Name the original positional axes with unique markers, so that every
    #    axis is named.
    # 2. Untag the requested axes, in dict order, making them the only
    #    positional axes.
    # 3. Apply ordinary positional indexing (the non-dict branch below).
    # 4. Name the surviving and new axes.
    # 5. Turn the marker-named axes back into positional axes.
    return (
        self.tag(*tmp_names_for_pos)
        .untag(*input_names)[tuple(index_seq)]
        .tag(*output_names)
        .untag(*tmp_names_for_pos)
    )
  else:
    # Normal indexing => map it over any named axes.
    # We always jit-compile `getitem`, because eagerly `vmap`-ing (and, by
    # extension, `nmap`-ing) a gather operation can lead to spurious
    # transposes that waste device memory when indexing into a large array.
    # We have to do a bit of trickery to deal with slices and non-jittable
    # array data.
    if not isinstance(indexer, tuple):
      indexer = (indexer,)
    # Each index component is wrapped in a thunk class defined elsewhere in
    # this module:
    # * arrays, named arrays and Python ints -> `_DynamicThunk`;
    # * slices -> `_SliceThunk`, split into start/stop/step;
    # * anything else (e.g. `None`, `Ellipsis`) -> `_StaticThunk`.
    # Presumably the dynamic parts are traced by the jit while static parts
    # are treated as compile-time constants -- confirm against the thunk
    # definitions.
    # NOTE(review): NumPy integer scalars (e.g. `np.int64`) do not match the
    # first isinstance check, so they take the `_StaticThunk` path.
    index_thunks = []
    for c in indexer:
      if isinstance(c, jax.Array | np.ndarray | NamedArrayBase | int):
        index_thunks.append(_DynamicThunk(c))
      elif isinstance(c, slice):
        index_thunks.append(_SliceThunk(c.start, c.stop, c.step))
      else:
        index_thunks.append(_StaticThunk(c))
    return _jitted_nmapped_getitem(self, tuple(index_thunks))
# Array conversion operators
def __array__(self, *args, **kwargs):
  """Converts this array to a Numpy array, if it has no named axes.

  Any extra positional or keyword arguments are forwarded to ``np.array``
  along with the unwrapped data.

  Raises:
    ValueError: If the array still has any named axes.
  """
  if not self.named_shape:
    return np.array(self.unwrap(), *args, **kwargs)
  raise ValueError(
      "Only NamedArray(View)s with no named axes can be converted to numpy"
      " arrays. Use `unwrap` or `untag` to assign positions to named axes"
      " first, or use `penzai.named_axes.nmap` with a JAX function instead."
  )
def __jax_array__(self):
  """Returns the underlying JAX array, if there are no named axes.

  Raises:
    ValueError: If the array still has any named axes.
  """
  if not self.named_shape:
    return self.unwrap()
  raise ValueError(
      "Only NamedArray(View)s with no named axes can be converted to JAX"
      " arrays. Use `unwrap` or `untag` to assign positions to named axes"
      " first, or use `penzai.named_axes.nmap` with a JAX function instead."
  )
# Iteration. This must be defined explicitly: otherwise Python would fall back
# to calling __getitem__ with 0, 1, 2, ... until an IndexError is raised, and
# that never happens here because JAX clips out-of-bounds array indices.
def __iter__(self):
  """Yields ``self[i]`` for each index ``i`` along the first positional axis.

  Raises:
    ValueError: When iteration starts, if the array has no positional axes.
  """
  leading_shape = self.positional_shape
  if not leading_shape:
    raise ValueError("Cannot iterate over an array with no positional axes!")
  yield from (self[idx] for idx in range(leading_shape[0]))
# Convenience wrappers: Elementwise infix operators.
# Each operator is built by `_nmap_with_doc` (defined elsewhere in this
# module) from a plain Python operator, with the second argument naming the
# `jax.Array` method whose docs it mirrors. Presumably the wrapper applies the
# operator through `nmap`, i.e. elementwise and vectorized over named axes --
# confirm against `_nmap_with_doc`.
# NOTE: because `__eq__`/`__ne__` are overridden here, `==` and `!=` do not
# return a plain Python bool (they are listed as elementwise operators).
__lt__ = _nmap_with_doc(operator.lt, "jax.Array.__lt__")
__le__ = _nmap_with_doc(operator.le, "jax.Array.__le__")
__eq__ = _nmap_with_doc(operator.eq, "jax.Array.__eq__")
__ne__ = _nmap_with_doc(operator.ne, "jax.Array.__ne__")
__ge__ = _nmap_with_doc(operator.ge, "jax.Array.__ge__")
__gt__ = _nmap_with_doc(operator.gt, "jax.Array.__gt__")
__add__ = _nmap_with_doc(operator.add, "jax.Array.__add__")
__sub__ = _nmap_with_doc(operator.sub, "jax.Array.__sub__")
__mul__ = _nmap_with_doc(operator.mul, "jax.Array.__mul__")
__truediv__ = _nmap_with_doc(operator.truediv, "jax.Array.__truediv__")
__floordiv__ = _nmap_with_doc(operator.floordiv, "jax.Array.__floordiv__")
__mod__ = _nmap_with_doc(operator.mod, "jax.Array.__mod__")
__divmod__ = _nmap_with_doc(divmod, "jax.Array.__divmod__")
__pow__ = _nmap_with_doc(operator.pow, "jax.Array.__pow__")
__lshift__ = _nmap_with_doc(operator.lshift, "jax.Array.__lshift__")
__rshift__ = _nmap_with_doc(operator.rshift, "jax.Array.__rshift__")
__and__ = _nmap_with_doc(operator.and_, "jax.Array.__and__")
__or__ = _nmap_with_doc(operator.or_, "jax.Array.__or__")
__xor__ = _nmap_with_doc(operator.xor, "jax.Array.__xor__")
# Reflected operators (used by Python when the left operand does not handle
# the operation, e.g. `2 - arr`). `_swapped_binop` (defined elsewhere) is
# presumably `lambda a, b: op(b, a)`, so that `arr.__rsub__(2)` computes
# `2 - arr` -- confirm against its definition.
__radd__ = _nmap_with_doc(_swapped_binop(operator.add), "jax.Array.__radd__")
__rsub__ = _nmap_with_doc(_swapped_binop(operator.sub), "jax.Array.__rsub__")
__rmul__ = _nmap_with_doc(_swapped_binop(operator.mul), "jax.Array.__rmul__")
__rtruediv__ = _nmap_with_doc(
    _swapped_binop(operator.truediv), "jax.Array.__rtruediv__"
)
__rfloordiv__ = _nmap_with_doc(
    _swapped_binop(operator.floordiv), "jax.Array.__rfloordiv__"
)
__rmod__ = _nmap_with_doc(_swapped_binop(operator.mod), "jax.Array.__rmod__")
__rdivmod__ = _nmap_with_doc(_swapped_binop(divmod), "jax.Array.__rdivmod__")
__rpow__ = _nmap_with_doc(_swapped_binop(operator.pow), "jax.Array.__rpow__")
__rlshift__ = _nmap_with_doc(
    _swapped_binop(operator.lshift), "jax.Array.__rlshift__"
)
__rrshift__ = _nmap_with_doc(
    _swapped_binop(operator.rshift), "jax.Array.__rrshift__"
)
__rand__ = _nmap_with_doc(_swapped_binop(operator.and_), "jax.Array.__rand__")
__ror__ = _nmap_with_doc(_swapped_binop(operator.or_), "jax.Array.__ror__")
__rxor__ = _nmap_with_doc(_swapped_binop(operator.xor), "jax.Array.__rxor__")
# Unary operators. `operator.inv` is the standard-library alias of
# `operator.invert` (bitwise `~`).
__abs__ = _nmap_with_doc(operator.abs, "jax.Array.__abs__")
__neg__ = _nmap_with_doc(operator.neg, "jax.Array.__neg__")
__pos__ = _nmap_with_doc(operator.pos, "jax.Array.__pos__")
__invert__ = _nmap_with_doc(operator.inv, "jax.Array.__invert__")
# Convenience wrappers: Scalar conversions.
# These hooks are what Python calls for `bool(arr)` / `if arr:`, `complex(arr)`,
# `int(arr)`, `float(arr)`, and `operator.index(arr)`. `__index__` is also used
# when the array is used as a sequence index or passed to `range`. The
# conversion logic lives in `_wrap_scalar_conversion` (defined elsewhere);
# presumably it only succeeds for arrays holding a single element -- confirm.
__bool__ = _wrap_scalar_conversion(bool)
__complex__ = _wrap_scalar_conversion(complex)
__int__ = _wrap_scalar_conversion(int)
__float__ = _wrap_scalar_conversion(float)
__index__ = _wrap_scalar_conversion(operator.index)
# Convenience wrappers: np.ndarray / jax.Array methods.
# Each attribute is produced by `_wrap_array_method` (defined elsewhere) from
# the name of the corresponding array method. Presumably the wrapper calls
# that method on the underlying data via `nmap`, so axis arguments such as
# `axis=` refer only to positional axes -- confirm against
# `_wrap_array_method`.
# NOTE(review): on `jax.Array`, `real` and `imag` are properties rather than
# methods; verify that `_wrap_array_method` handles them correctly (e.g. that
# callers are expected to write `arr.real()`).
all = _wrap_array_method("all")
any = _wrap_array_method("any")
argmax = _wrap_array_method("argmax")
argmin = _wrap_array_method("argmin")
argpartition = _wrap_array_method("argpartition")
argsort = _wrap_array_method("argsort")
astype = _wrap_array_method("astype")
choose = _wrap_array_method("choose")
clip = _wrap_array_method("clip")
compress = _wrap_array_method("compress")
conj = _wrap_array_method("conj")
conjugate = _wrap_array_method("conjugate")
# copy not implemented
cumprod = _wrap_array_method("cumprod")
cumsum = _wrap_array_method("cumsum")
diagonal = _wrap_array_method("diagonal")
dot = _wrap_array_method("dot")
flatten = _wrap_array_method("flatten")
imag = _wrap_array_method("imag")
item = _wrap_array_method("item")
max = _wrap_array_method("max")
mean = _wrap_array_method("mean")
min = _wrap_array_method("min")
# nbytes not implemented
nonzero = _wrap_array_method("nonzero")
prod = _wrap_array_method("prod")
ptp = _wrap_array_method("ptp")
ravel = _wrap_array_method("ravel")
real = _wrap_array_method("real")
repeat = _wrap_array_method("repeat")
reshape = _wrap_array_method("reshape")
round = _wrap_array_method("round")
searchsorted = _wrap_array_method("searchsorted")
sort = _wrap_array_method("sort")
squeeze = _wrap_array_method("squeeze")
std = _wrap_array_method("std")
sum = _wrap_array_method("sum")
swapaxes = _wrap_array_method("swapaxes")
take = _wrap_array_method("take")
# tobytes / tolist not implemented
trace = _wrap_array_method("trace")
transpose = _wrap_array_method("transpose")