/
linear.py
3774 lines (3072 loc) · 131 KB
/
linear.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Linear functions."""
import enum
import functools
import operator as op
import string
from typing import Callable, Iterable, Optional, Sequence, Tuple, TypeVar, Union
import warnings
import jax
from jax import lax
from jax import numpy as np
from jax import ops
from jax import random
from jax import ShapeDtypeStruct, ShapedArray, eval_shape, vmap
import jax.example_libraries.stax as ostax
import numpy as onp
from .requirements import Bool, Diagonal, get_diagonal_outer_prods, layer, mean_and_var, requires, supports_masking
from ..utils import utils
from ..utils.kernel import Kernel
from ..utils.typing import Axes, InternalLayer, InternalLayerMasked, PyTree
# Enums
class Padding(enum.Enum):
  """Padding mode accepted by the pooling and convolutional layers.

  Attributes:
    CIRCULAR:
      wrap-around padding; the input is treated as periodic, i.e. as if it
      lived on a torus.
    SAME:
      "same" padding, i.e. the input is padded with zeros.
    VALID:
      "valid" padding, i.e. the input is not padded at all.
  """
  # Member values equal their names so that string arguments such as
  # `padding='SAME'` can be converted with `Padding(padding)`.
  CIRCULAR = 'CIRCULAR'
  SAME = 'SAME'
  VALID = 'VALID'
class _Pooling(enum.Enum):
"""Type of pooling in pooling layers.
Attributes:
AVG:
average pooling, the output is normalized by the input receptive field
size.
SUM:
sum pooling, no normalization.
"""
AVG = 'AVG'
SUM = 'SUM'
class AggregateImplementation(enum.Enum):
  """Selects how the :obj:`Aggregate` layer is computed.

  Refer to the :obj:`Aggregate` docstring for the expected `pattern` formats.

  Attributes:
    DENSE:
      Recommended for dense graphs, i.e. when the number of edges `E` grows
      at least like the number of vertices `V` to the power of 1.5.
    SPARSE:
      Recommended for sparse graphs, i.e. when `E ~ O(V)` or smaller.
  """
  # String values allow `AggregateImplementation(implementation)` to accept
  # the plain strings "DENSE" / "SPARSE".
  DENSE = 'DENSE'
  SPARSE = 'SPARSE'
# LAYERS
@layer
@supports_masking(remask_kernel=False)
def Identity() -> InternalLayer:
  """Layer that leaves its inputs and kernels unchanged (no-op).

  The finite-width functions come from
  :obj:`jax.example_libraries.stax.Identity`; the infinite-width `kernel_fn`
  simply returns the incoming kernel.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  def kernel_fn(k, **kwargs):
    # An identity map does not change the kernel.
    return k

  stax_init_fn, stax_apply_fn = ostax.Identity
  return stax_init_fn, stax_apply_fn, kernel_fn
@layer
@supports_masking(remask_kernel=False)
def DotGeneral(
    *,
    lhs: Optional[Union[np.ndarray, float]] = None,
    rhs: Optional[Union[np.ndarray, float]] = None,
    dimension_numbers: lax.DotDimensionNumbers = (((), ()), ((), ())),
    precision: Optional[lax.Precision] = None,
    batch_axis: int = 0,
    channel_axis: int = -1
) -> InternalLayerMasked:
  r"""Constant (non-trainable) rhs/lhs Dot General.

  Dot General allows expressing any linear transformation on the inputs,
  including but not limited to matrix multiplication, pooling, convolutions,
  permutations, striding, masking etc (but specialized implementations are
  typically much more efficient).

  Returned `apply_fn` is calling
  `jax.lax.dot_general(inputs, rhs, dimension_numbers, precision)` or
  `jax.lax.dot_general(lhs, inputs, dimension_numbers, precision)`, depending
  on whether `lhs` or `rhs` is specified (not `None`). If both are `None`,
  the layer is a no-op.

  Example:
    >>> from jax import random
    >>> import jax.numpy as np
    >>> from neural_tangents import stax
    >>> #
    >>> # Two time series stacked along the second (H) dimension.
    >>> x = random.normal(random.PRNGKey(1), (5, 2, 32, 3))  # NHWC
    >>> #
    >>> # Multiply all outputs by a scalar:
    >>> nn = stax.serial(
    >>>     stax.Conv(128, (1, 3)),
    >>>     stax.Relu(),
    >>>     stax.DotGeneral(rhs=2.),  # output shape is (5, 2, 30, 128)
    >>>     stax.GlobalAvgPool()      # (5, 128)
    >>> )
    >>> #
    >>> # Subtract second time series from the first one:
    >>> nn = stax.serial(
    >>>     stax.Conv(128, (1, 3)),
    >>>     stax.Relu(),
    >>>     stax.DotGeneral(
    >>>         rhs=np.array([1., -1.]),
    >>>         dimension_numbers=(((1,), (0,)), ((), ()))),  # (5, 30, 128)
    >>>     stax.GlobalAvgPool()                              # (5, 128)
    >>> )
    >>> #
    >>> # Flip outputs with each other
    >>> nn = stax.serial(
    >>>     stax.Conv(128, (1, 3)),
    >>>     stax.Relu(),
    >>>     stax.DotGeneral(
    >>>         lhs=np.array([[0., 1.], [1., 0.]]),
    >>>         dimension_numbers=(((1,), (1,)), ((), ()))),  # (5, 2, 30, 128)
    >>>     stax.GlobalAvgPool()                              # (5, 128)
    >>> )

  See Also:
    https://www.tensorflow.org/xla/operation_semantics#dotgeneral

  Args:
    lhs:
      a constant array to dot with. `None` means layer `inputs` are the
      left-hand side.

    rhs:
      a constant array to dot with. `None` means layer `inputs` are the
      right-hand side. If both `lhs` and `rhs` are `None` the layer is the same
      as `Identity`.

    dimension_numbers:
      a tuple of tuples of the form `((lhs_contracting_dims,
      rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`.

    precision:
      Optional. Either `None`, which means the default precision for the
      backend, or a `lax.Precision` enum value (`Precision.DEFAULT`,
      `Precision.HIGH` or `Precision.HIGHEST`).

    batch_axis:
      batch axis for `inputs`. Defaults to `0`, the leading axis. Can be present
      in `dimension_numbers`, but contraction along `batch_axis` will not allow
      for further layers to be applied afterwards.

    channel_axis:
      channel axis for `inputs`. Defaults to `-1`, the trailing axis. For
      `kernel_fn`, channel size is considered to be infinite. Cannot be present
      in `dimension_numbers`.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.

  Raises:
    ValueError: if both `lhs` and `rhs` are provided.
  """
  if rhs is not None and lhs is not None:
    raise ValueError('At most one of constant `rhs` and `lhs` can be non-`None`'
                     ', since the other factor is considered to be the layer '
                     '`inputs`.')

  is_lhs = rhs is None

  # With both factors `None` the layer is documented to behave as `Identity`.
  # `jax.numpy.array(None)` is not a valid array, so keep `other` as `None` and
  # short-circuit the functions below instead of converting it.
  is_identity = lhs is None and rhs is None
  other = None if is_identity else np.array(lhs if is_lhs else rhs)

  def dot_fn(x):
    if is_identity:
      return x
    # Order the operands so that `inputs` is the rhs when `lhs` is the constant.
    args = (x, other.astype(x.dtype))[::(-1 if is_lhs else 1)]
    return lax.dot_general(*args, dimension_numbers, precision)

  def init_fn(rng, input_shape):
    if is_identity:
      return input_shape, ()
    out = eval_shape(dot_fn, ShapeDtypeStruct(input_shape, other.dtype))
    return out.shape, ()

  def apply_fn(params, inputs, **kwargs):
    return dot_fn(inputs)

  # If a dimension is contracted, respective pairwise covariances are needed to
  # compute the covariance of contractions.
  input_cs = dimension_numbers[0][1 if is_lhs else 0]
  diagonal_batch = (batch_axis not in input_cs) or is_identity
  diagonal_spatial = Diagonal(
      input=Bool.YES
      if (input_cs in ((), (batch_axis,)) or is_identity)
      else Bool.NO)

  @requires(diagonal_batch=diagonal_batch,
            diagonal_spatial=diagonal_spatial,
            batch_axis=batch_axis,
            channel_axis=channel_axis)
  def kernel_fn(k: Kernel, **kwargs) -> Kernel:
    if is_identity:
      return k
    return k.dot_general(other, other, is_lhs, dimension_numbers)

  def mask_fn(mask, input_shape):
    # Propagate the mask through the same linear map: an output entry is masked
    # only if it does not depend on any unmasked input entry.
    mask_shape = list(input_shape)
    mask_shape[channel_axis] = mask.shape[channel_axis]
    return ~dot_fn(~np.broadcast_to(mask, mask_shape))

  return init_fn, apply_fn, kernel_fn, mask_fn
@layer
@supports_masking(remask_kernel=True)
def Aggregate(
    aggregate_axis: Optional[Axes] = None,
    batch_axis: int = 0,
    channel_axis: int = -1,
    to_dense: Optional[Callable[[np.ndarray], np.ndarray]] = lambda p: p,
    implementation: str = AggregateImplementation.DENSE.value
) -> InternalLayer:
  r"""Aggregation operator (graphical neural network).

  See e.g.
  "`Graph Neural Tangent Kernel: Fusing Graph Neural Networks with Graph Kernels
  <https://arxiv.org/abs/1905.13192>`_".

  Specifically, each `N+2`-D `input` of shape `(batch, X_1, ..., X_N, channels)`
  (subject to `batch_axis` and `channel_axis`) is accompanied by an array
  `pattern` specifying the directed edges (arcs, arrows) of the graph. The
  format of `pattern` depends on `implementation`:

  `implementation = "DENSE"`:
    Is recommended for dense graphs, where the number of
    edges `E` is proportional to the number of vertices `V` to the power of 1.5
    or more. In this case, `pattern` is a [weighted] adjacency
    `2K+1`-D tensor of shape `(batch, X_i1, ..., X_iK, X_i1, ..., X_iK)` (i.e.
    leading batch dimensions, repeated spatial dimensions, no channel dimension)
    and the output tensor is
    `lax.dot_general(inputs, pattern, ((aggregate_axes, range(1, K + 1)),
    ((batch_axis,), (0,))))` with the `batch_axis` and `channel_axis`
    preserved. `K = len(aggregate_axes)`.

    Having `pattern[n, i1, ..., iK, j1, ..., jK] == w` represents a directed
    edge (arc) from tail pixel / token `(i1, ..., iK)` to head `(j1, ..., jK)`
    with weight `w` in an individual input sample `n`. The `apply_fn` of this
    layer replaces all vertices with the (weighted) sum of all direct
    predecessors to the given vertex.

    Note that individual inputs can have more than `K` dimensions (e.g.
    channels, other coordinates), in which case slices along these coordinates
    are processed in the same way independently.

    This implementation uses matrix multiplication, and for a graph with `V`
    vertices and `E` edges, `apply_fn` costs `O(V^2)` memory and time, while
    `kernel_fn` costs `O(V^2)` memory and `O(V^3)` time.

    The adjacency tensor `pattern` can be specified in a sparse format. If
    you provide a `to_dense` function (defaults to identity), then `pattern` is
    decoded into a dense representation as described above
    (`pattern_dense = to_dense(pattern)`) each time `apply_fn` or `kernel_fn`
    are called. This avoids storing the whole graph in the dense format in
    advance, and instead converts it to dense format on the fly, for each
    individual batch `x` / `(x1, x2)`. However, this does not improve the
    runtime or memory of the `Aggregate` layer (in fact makes it a bit slower
    due to an extra `to_dense` call).

  `implementation = "SPARSE"`:
    Is recommended for sparse graphs, where `E ~ O(V)` or less. In this case,
    `pattern` must be an integer array of shape `(batch, n_edges, K, 2)`,
    specifying `n_edges` directed edges (arcs) of weight `w = 1` for each of
    the `batch` input samples (if `K == 1` `pattern` can also have the shape
    `(batch, n_edges, 2)`). Trailing dimension of size 2 corresponds to tails
    (sources, senders) and heads (targets, receivers). Edges can be repeated,
    which is interpreted as having their weight be the number of repetitions.
    If any of the `K` coordinates of a given vertex in `heads` is negative
    (e.g. `-1`), it is discarded. This can be used for padding, when different
    input samples have different `n_edges`. Note that this means you can't use
    negative indexing to specify vertices.

    This implementation uses :obj:`jax.ops.segment_sum` instead of matrix
    multiplication. This makes `apply_fn` cost `O(V + E)` memory and `O(V + E)`
    time, and `kernel_fn` cost `O(V^2)` memory and `O(V^2 + E^2 + V * E)` time.
    This is beneficial for sparse graphs, i.e. `E << V^2`, but detrimental for
    dense graphs (when `E ~ V^2`).

  See Also:
    `AggregateTest` in `tests/stax_test.py` for examples and conversion between
    sparse and dense patterns.

  Example:
    >>> # 1D inputs
    >>> x = random.normal(random.PRNGKey(1), (5, 3, 32))  # NCH
    >>> #
    >>> # 1) NHH dense binary adjacency matrix
    >>> A = random.bernoulli(random.PRNGKey(2), 0.5, (5, 32, 32))
    >>> # `A[n, h1, h2] == True`
    >>> # means an edge between tokens `h1` and `h2` in sample `n`.
    >>> #
    >>> init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=2,
    >>>                                               batch_axis=0,
    >>>                                               channel_axis=1)
    >>> #
    >>> out = apply_fn((), x, pattern=A)
    >>> # output is the same as `x @ A` of shape (5, 3, 32)
    >>> #
    >>> # Sparse NHH binary pattern with 10 edges
    >>> n_edges = 10
    >>> A_sparse = random.randint(random.PRNGKey(3),
    >>>                           shape=(x.shape[0], n_edges, 1, 2),
    >>>                           minval=0,
    >>>                           maxval=x.shape[2])
    >>> #
    >>> # Setting `implementation="SPARSE"` to invoke the segment sum
    >>> # implementation.
    >>> init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=2,
    >>>                                               batch_axis=0,
    >>>                                               channel_axis=1,
    >>>                                               implementation="SPARSE")
    >>> #
    >>> out = apply_fn((), x, pattern=A_sparse)
    >>> # output is of shape (5, 3, 32), computed via `jax.ops.segment_sum`.
    >>> #
    >>> # 2D inputs
    >>> x = random.normal(random.PRNGKey(1), (5, 3, 32, 16))  # NCHW
    >>> #
    >>> # 2) NHWHW dense binary adjacency matrix
    >>> A = random.bernoulli(random.PRNGKey(2), 0.5, (5, 32, 16, 32, 16))
    >>> # `A[n, h1, w1, h2, w2] == True`
    >>> # means an edge between pixels `(h1, w1)` and `(h2, w2)` in image `n`.
    >>> #
    >>> init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=(2, 3),
    >>>                                               batch_axis=0,
    >>>                                               channel_axis=1)
    >>> #
    >>> out = apply_fn((), x, pattern=A)
    >>> # output is of shape (5, 3, 32, 16), the same as
    >>> # `(x.reshape((5, 3, 32 * 16)) @ A.reshape((5, 32 * 16, 32 * 16))
    >>> #  ).reshape(x.shape)`
    >>> #
    >>> # 3) NWW binary adjacency matrix
    >>> A = random.bernoulli(random.PRNGKey(2), 0.5, (5, 16, 16))
    >>> # `A[n, w1, w2] == True`
    >>> # means an edge between rows `w1` and `w2` in image `n`.
    >>> #
    >>> init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=(3,),
    >>>                                               batch_axis=0,
    >>>                                               channel_axis=1)
    >>> #
    >>> out = apply_fn((), x, pattern=A)
    >>> # output is of shape (5, 3, 32, 16), the same as
    >>> # `(x.reshape((5, 3 * 32, 16)) @ A).reshape(x.shape)`
    >>> #
    >>> # 4) Infinite width example
    >>> x1 = random.normal(random.PRNGKey(1), (5, 3, 32))  # NCH
    >>> x2 = random.normal(random.PRNGKey(2), (2, 3, 32))  # NCH
    >>> #
    >>> # NHH binary adjacency matrices
    >>> A1 = random.bernoulli(random.PRNGKey(2), 0.5, (5, 32, 32))
    >>> A2 = random.bernoulli(random.PRNGKey(2), 0.5, (2, 32, 32))
    >>> #
    >>> _, _, kernel_fn_id = stax.Identity()
    >>> #
    >>> _, _, kernel_fn_agg = stax.Aggregate(aggregate_axis=2,
    >>>                                      batch_axis=0,
    >>>                                      channel_axis=1)
    >>> #
    >>> nngp = kernel_fn_id(x1, x2, get='nngp', channel_axis=1)
    >>> # initial NNGP of shape (5, 2, 32, 32)
    >>> K_agg = kernel_fn_agg(x1, x2, get='nngp', pattern=(A1, A2))
    >>> # output NNGP of same shape (5, 2, 32, 32):
    >>> # `K_agg[n1, n2] == A1[n1].T @ nngp[n1, n2] @ A2[n2]`

  Args:
    aggregate_axis:
      axes (non-batch and non-channel) to aggregate predecessor vertices over.

    batch_axis:
      batch axis for `inputs`. Defaults to `0`, the leading axis.

    channel_axis:
      channel axis for `inputs`. Defaults to `-1`, the trailing axis. For
      `kernel_fn`, channel size is considered to be infinite.

    to_dense:
      Ignored unless `implementation == "DENSE"`. A function to convert
      potentially sparse `pattern` matrices into dense `2K+1`-D tensors of shape
      `(batch, X_i1, ..., X_iK, X_i1, ..., X_iK)`, with the batch leading
      dimension, and no channel dimension, where `K = len(aggregate_axes)`.
      Will be called on input `pattern` (or a pair `(pattern1, pattern2)`)
      every time `apply_fn` or `kernel_fn` is called. Defaults to identity,
      meaning that `pattern` is expected in the dense format. `None` is also
      treated as identity.

    implementation:
      `"DENSE"` or `"SPARSE"`, specifying which implementation to use.
      `"DENSE"` uses matrix multiplications and is recommended for dense graphs
      (`E ~> O(V^1.5)`), while `"SPARSE"` uses :obj:`jax.ops.segment_sum` and is
      recommended for sparse graphs (`E ~< O(V)`). Note that different
      `implementation` require different `pattern` array format - see the
      :obj:`Aggregate` layer docstring above for details.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  implementation = AggregateImplementation(implementation)
  if implementation == AggregateImplementation.SPARSE:
    warnings.warn('Negative indices in `pattern` are considered as padding '
                  '(i.e. ignored), unlike typical numpy negative indexing.')

  # The signature allows `to_dense=None`; calling `None` would fail later with
  # an opaque "object is not callable" error, so treat it as identity.
  if to_dense is None:
    to_dense = lambda p: p

  init_fn = lambda rng, input_shape: (input_shape, ())

  def get_agg_axes(ndim: int) -> Tuple[Tuple[int, ...], int, int]:
    """Returns `(aggregate_axes, batch_axis, channel_axis)`, all non-negative."""
    _batch_axis, _channel_axis = utils.mod((batch_axis, channel_axis), ndim)
    if aggregate_axis is None:
      agg_axes = tuple(i for i in range(ndim)
                       if i not in (_batch_axis, _channel_axis))
    else:
      agg_axes = tuple(utils.canonicalize_axis(aggregate_axis, ndim))
    return agg_axes, _batch_axis, _channel_axis

  def get_dimension_numbers(ndim: int) -> lax.DotDimensionNumbers:
    """`dot_general` numbers contracting `aggregate_axes` with the pattern."""
    agg_axes, batch_axis, _ = get_agg_axes(ndim)
    agg_ndim = len(agg_axes)
    return (agg_axes, (range(1, agg_ndim + 1))), ((batch_axis,), (0,))

  @functools.partial(vmap, in_axes=(0, None))
  def make_indices(index_array, agg_shape):
    """Ravels `K`-D vertex coordinates into flat indices (per batch element)."""
    index_array = np.moveaxis(index_array, -1, 0)
    raveled = np.ravel_multi_index(index_array, agg_shape, 'wrap')
    # Vertices with any negative coordinate are padding; map them to `-1`.
    return np.where(np.all(index_array >= 0, axis=0), raveled, -1)

  def get_senders_receivers(pattern, batch_size: int, agg_ndim: int):
    """Unpack `pattern` and make sure it has correct shape."""
    if pattern.shape[-1] != 2:
      raise ValueError('`pattern` must have a trailing dimension of 2, got '
                       f'{pattern.shape[-1]}.')

    s, r = pattern[..., 0], pattern[..., 1]

    # Allow for `(batch, n_edges, 2)` shape for single aggregation
    # dimension `K == 1`.
    if agg_ndim == 1 and s.ndim == 2:
      s, r = np.expand_dims(s, -1), np.expand_dims(r, -1)

    if s.ndim != 3:
      raise ValueError(f'Tails and heads need to be 3-dimensional, '
                       f'got {s.ndim}.')

    if s.shape[2] != agg_ndim:
      raise ValueError(f'Trailing dimension of tails and heads need to have '
                       f'the same size as the number of aggregate axes of '
                       f'`aggregate_axis` ({agg_ndim}), got {s.shape[2]}.')

    if s.shape[0] != batch_size:
      raise ValueError(f'Tails and heads need to have leading dimension equal '
                       f'to batch size, got {s.shape[0]}.')

    return s, r

  def apply_fn(params,
               inputs: np.ndarray,
               *,
               pattern: Optional[np.ndarray] = None,
               **kwargs):
    """Compute the transformed tensors after an aggregation layer.

    Args:
      params:
        Not used.

      inputs:
        An input `N+2`-D tensor of shape `(batch, X_1, ..., X_N, channels)`
        (subject to `batch_axis` and `channel_axis`).

      pattern:
        A tensor specifying the directed edges between `inputs`. The shape and
        type of `pattern` depends on `implementation` (see docstring of
        `stax.Aggregate` above).

        `implementation == "DENSE"`:
          `pattern` must be a (float) `2K+1`-D tensor of shape
          `(batch, X_i1, ..., X_iK, X_i1, ..., X_iK)`, with the batch leading
          dimension, and no channel dimension, where `K = len(aggregate_axes)`.
          Can have another shape (e.g. a sparse matrix), as long as
          `to_dense(pattern)` has the correct (dense) shape (if `nt.batch` is
          used, the leading dimension of `pattern` must be the batch dimension,
          of size `batch`).

        `implementation == "SPARSE"`:
          `pattern` must be an integer array of shape `(batch, n_edges, K, 2)`,
          specifying tail and head (source and target / sender and receiver)
          vertices along the trailing dimension (if `K == 1`, `pattern` is also
          allowed to have the shape `(batch, n_edges, 2)`).

        `pattern=None` means identity adjacency, i.e. `apply_fn` is an identity
        function.

      **kwargs:
        unused.

    Returns:
      An `N+2`-D tensor of the same shape as `inputs`.
    """
    if pattern is None:
      return inputs

    del params
    ndim = inputs.ndim
    agg_axes, batch_axis, channel_axis = get_agg_axes(ndim)
    agg_ndim = len(agg_axes)

    if implementation == AggregateImplementation.DENSE:
      # Dense implementation through matrix multiplication.
      pattern = to_dense(pattern)

      dn = get_dimension_numbers(ndim)
      out = lax.dot_general(inputs, pattern.astype(inputs.dtype), dn)

      # Put back potentially displaced batch and channel axes.
      out_c_axis = utils.axis_after_dot(channel_axis % ndim, dn[0][0], dn[1][0])
      out_b_axis = utils.axis_after_dot(batch_axis % ndim, dn[0][0], dn[1][0])

      out = np.moveaxis(out,
                        (out_b_axis, out_c_axis) + tuple(range(-agg_ndim, 0)),
                        (batch_axis, channel_axis) + agg_axes)

    elif implementation == AggregateImplementation.SPARSE:
      # Sparse implementation through `jax.ops.segment_sum`.
      s, r = get_senders_receivers(pattern, inputs.shape[batch_axis], agg_ndim)

      # Canonicalize axes
      src_axes = (batch_axis,) + agg_axes + (channel_axis,)
      dst_axes = (0,) + tuple(range(1, agg_ndim + 1)) + (-1,)

      inputs = np.moveaxis(inputs, src_axes, dst_axes)
      input_shape = inputs.shape
      inputs = inputs.reshape((inputs.shape[0],
                               functools.reduce(
                                   op.mul, inputs.shape[1:agg_ndim + 1], 1))
                              + inputs.shape[agg_ndim + 1:])

      agg_shape = input_shape[1:agg_ndim + 1]
      s, r = make_indices(s, agg_shape), make_indices(r, agg_shape)

      @vmap
      def pass_messages(s, r, inputs):
        n_nodes = inputs.shape[0]
        sender_in = inputs[s]
        messages = ops.segment_sum(sender_in, r, num_segments=n_nodes)
        return messages

      out = pass_messages(s, r, inputs)
      out = out.reshape(input_shape)
      out = np.moveaxis(out, dst_axes, src_axes)

    else:
      raise ValueError(f'Unrecognized `implementation == {implementation}`.')

    return out

  @requires(batch_axis=batch_axis,
            channel_axis=channel_axis,
            diagonal_spatial=Diagonal(input=Bool.NO, output=Bool.NO))
  def kernel_fn(k: Kernel,
                *,
                pattern: Tuple[Optional[np.ndarray],
                               Optional[np.ndarray]] = (None, None),
                **kwargs):
    """Compute the transformed kernels after an aggregation kernel layer.

    Specifically, the `nngp`/`ntk` is a `2N+2`-D tensor of shape
    `(B_1, B_2, X_1, X_1, ..., X_N, X_N)`.

    If `implementation == "DENSE"`, this tensor will be aggregated
    (via matrix multiplication) on the left by `to_dense(pattern[0])` of
    shape `(B_1, X_i1, ..., X_iK, X_i1, ..., X_iK)` and on the right by
    `to_dense(pattern[1])` of shape `(B_2, X_i1, ..., X_iK, X_i1, ..., X_iK)`.
    Ignoring the batch dimensions, the output `nngp/ntk` is
    `pattern[0].T @ nngp/ntk @ pattern[1]`.

    If `implementation == "SPARSE"`, result is computed using
    `jax.ops.segment_sum` given `pattern[0]` and `pattern[1]` as integer
    arrays of shapes `(B_1, n_edges_1, K, 2)` and `(B_2, n_edges_2, K, 2)`
    respectively.
    """
    pattern1, pattern2 = pattern

    if pattern1 is None and pattern2 is None:
      return k

    if pattern1 is None or pattern2 is None:
      raise NotImplementedError(
          'Having exactly one of two `pattern1/2=None` is not implemented. '
          'Please file a bug at '
          'https://github.com/google/neural-tangents/issues/new.')

    ndim = len(k.shape1)
    agg_axes, batch_axis, channel_axis = get_agg_axes(ndim)
    agg_ndim = len(agg_axes)
    agg_shape = tuple(k.shape1[a] for a in agg_axes)
    agg_size = functools.reduce(op.mul, agg_shape, 1)

    def bucket_axes(ndim, start_axis):
      """Bucket kernel axes into batch, aggregate, and non-aggregate."""
      ndim_spatial = (ndim - start_axis) // 2
      agg_1 = tuple(
          a - int(batch_axis < a) - int(channel_axis < a) + start_axis
          for a in agg_axes)
      agg_2 = tuple(
          a + ndim_spatial
          for a in agg_1)
      non_agg_1 = tuple(
          a for a in range(start_axis, start_axis + ndim_spatial)
          if a not in agg_1)
      non_agg_2 = tuple(
          a for a in range(start_axis + ndim_spatial, ndim)
          if a not in agg_2)
      return tuple(range(start_axis)), agg_1, agg_2, non_agg_1, non_agg_2

    if implementation == AggregateImplementation.DENSE:
      # Dense implementation through matrix multiplication. Both patterns are
      # guaranteed non-`None` by the checks above.
      pattern1 = to_dense(pattern1)
      pattern2 = to_dense(pattern2)

      k = k.dot_general(
          other1=pattern1,
          other2=pattern2,
          is_lhs=False,
          dimension_numbers=get_dimension_numbers(ndim)
      )

      # Put back potentially displaced axes.
      def transpose(k, diagonal_batch):
        if k is None or k.ndim == 0:
          return k

        start_axis = 1 if diagonal_batch else 2
        k = utils.unzip_axes(k, start_axis)
        b, agg_1, agg_2, non_agg_1, non_agg_2 = bucket_axes(k.ndim, start_axis)
        permutation = b + non_agg_1 + agg_1 + non_agg_2 + agg_2
        k = np.transpose(k, onp.argsort(permutation))
        return utils.zip_axes(k, start_axis)

      k = k.replace(
          cov1=transpose(k.cov1, k.diagonal_batch),
          cov2=transpose(k.cov2, k.diagonal_batch),
          nngp=transpose(k.nngp, False),
          ntk=transpose(k.ntk, False),
          batch_axis=batch_axis % ndim,
          channel_axis=channel_axis % ndim
      )

    elif implementation == AggregateImplementation.SPARSE:
      # Sparse implementation through `jax.ops.segment_sum`.
      def pass_messages(s1, s2, r1, r2, k):
        v1, v2 = k.shape[:2]

        def send(s, r, num_segments):
          return ops.segment_sum(s, r, num_segments=num_segments)

        send_inner = vmap(functools.partial(send, num_segments=v2), (0, None))
        k = k[s1[:, None], s2[None, :]]
        k = send_inner(k, r2)
        k = send(k, r1, num_segments=v1)
        return k

      pass_messages_self = vmap(pass_messages)
      pass_messages_cross = vmap(vmap(pass_messages,
                                      (None, 0, None, 0, 0)),
                                 (0, None, 0, None, 0))

      s1, r1 = get_senders_receivers(pattern1, k.shape1[batch_axis], agg_ndim)
      s2, r2 = get_senders_receivers(pattern2, k.shape2[batch_axis], agg_ndim)

      s1, r1 = make_indices(s1, agg_shape), make_indices(r1, agg_shape)
      s2, r2 = make_indices(s2, agg_shape), make_indices(r2, agg_shape)

      def agg(k, diagonal_batch, s1, r1, s2, r2):
        if k is None or k.ndim == 0:
          return k

        start_axis = 1 if diagonal_batch else 2
        k = utils.unzip_axes(k, start_axis)
        b, agg_1, agg_2, non_agg_1, non_agg_2 = bucket_axes(k.ndim, start_axis)
        permutation = b + agg_1 + agg_2 + non_agg_1 + non_agg_2
        k = np.transpose(k, permutation)
        k_shape = k.shape
        k = k.reshape(
            k.shape[:start_axis] +
            (agg_size,) * 2 +
            k.shape[start_axis + 2 * len(agg_axes):]
        )
        fn = pass_messages_self if diagonal_batch else pass_messages_cross
        k = fn(s1, s2, r1, r2, k)
        k = k.reshape(k_shape)
        k = np.transpose(k, onp.argsort(permutation))
        return utils.zip_axes(k, start_axis)

      nngp = agg(k.nngp, False, s1, r1, s2, r2)
      ntk = agg(k.ntk, False, s1, r1, s2, r2)
      cov1 = agg(k.cov1, k.diagonal_batch, s1, r1, s1, r1)
      cov2 = agg(k.cov2, k.diagonal_batch, s2, r2, s2, r2)
      k = k.replace(nngp=nngp, ntk=ntk, cov1=cov1, cov2=cov2)

    else:
      raise ValueError(f'Unrecognized `implementation == {implementation}`.')

    return k

  return init_fn, apply_fn, kernel_fn
@layer
@supports_masking(remask_kernel=True)
def Dense(
out_dim: int,
W_std: float = 1.,
b_std: Optional[float] = None,
batch_axis: int = 0,
channel_axis: int = -1,
parameterization: str = 'ntk',
s: Tuple[int, int] = (1, 1),
) -> InternalLayerMasked:
r"""Dense (fully-connected, matrix product).
Based on :obj:`jax.example_libraries.stax.Dense`.
Args:
out_dim:
The output feature / channel dimension. This is ignored in by the
`kernel_fn` in `"ntk"` parameterization.
W_std:
Specifies the standard deviation of the weights.
b_std:
Specifies the standard deviation of the biases. `None` means no bias.
batch_axis:
Specifies which axis is contains different elements of the batch.
Defaults to `0`, the leading axis.
channel_axis: Specifies which axis contains the features / channels.
Defaults to `-1`, the trailing axis. For `kernel_fn`, channel size is
considered to be infinite.
parameterization:
Either `"ntk"` or `"standard"`.
Under `"ntk"` parameterization (page 3 in "`Neural Tangent Kernel:
Convergence and Generalization in Neural Networks
<https://arxiv.org/abs/1806.07572>`_"),
weights and biases are initialized as
:math:`W_{ij} \sim \mathcal{N}(0,1)`, :math:`b_i \sim \mathcal{N}(0,1)`,
and the finite width layer equation is
:math:`z_i = \sigma_W / \sqrt{N} \sum_j W_{ij} x_j + \sigma_b b_i`, where
`N` is `out_dim`.
Under `"standard"` parameterization ("`On the infinite width limit of
neural networks with a standard parameterization
<https://arxiv.org/abs/2001.07301>`_".),
weights and biases are initialized as :math:`W_{ij} \sim \mathcal{N}(0,
W_{std}^2/N)`,
:math:`b_i \sim \mathcal{N}(0,\sigma_b^2)`, and the finite width layer
equation is
:math:`z_i = \frac{1}{s} \sum_j W_{ij} x_j + b_i`, where `N` is `out_dim`.
`N` corresponds to the respective variable in
"`On the infinite width limit of neural networks with a standard
parameterization <https://arxiv.org/abs/2001.07301>`_".
s:
only applicable when `parameterization="standard"`. A tuple of integers
specifying the width scalings of the input and the output of the layer,
i.e. the weight matrix `W` of the layer has shape
`(s[0] * in_dim, s[1] * out_dim)`, and the bias has size `s[1] * out_dim`.
.. note::
We need `s[0]` (scaling of the previous layer) to infer `in_dim` from
`input_shape`. Further, for the bottom layer, `s[0]` must be `1`, and
for all other layers `s[0]` must be equal to `s[1]` of the previous
layer. For the top layer, `s[1]` is expected to be `1` (recall that the
output size is `s[1] * out_dim`, and in common infinite network
research input and output sizes are considered fixed).
`s` corresponds to the respective variable in
"`On the infinite width limit of neural networks with a standard
parameterization <https://arxiv.org/abs/2001.07301>`_".
For `parameterization="ntk"`, or for standard, finite-width networks
corresponding to He initialization, `s=(1, 1)`.
Returns:
`(init_fn, apply_fn, kernel_fn)`.
"""
# TODO(jaschasd): after experimentation, evaluate whether to change default
# parameterization from "ntk" to "standard"
parameterization = parameterization.lower()
def _init_fn(rng, input_shape, out_dim):
  """Draws unit-scale dense-layer parameters shared by both parameterizations.

  Parameterization-specific scaling is applied elsewhere: `standard_init_fn`
  rescales `W` and `b` after calling this, while the NTK parameterization
  applies `W_std` / `b_std` inside `apply_fn`.

  Args:
    rng: PRNG key; split into one key for `W` and one for `b`.
    input_shape: shape of the layer input.
    out_dim: size of the output along `channel_axis`.

  Returns:
    `(output_shape, (W, b))`. `output_shape` is `input_shape` with the
    channel axis replaced by `out_dim`; `W` has shape
    `(input_shape[channel_axis], out_dim)`; `b` is `None` when `b_std` is
    `None`, otherwise an array of ones-shape except `out_dim` along the
    channel axis.
  """
  # Normalize a possibly-negative `channel_axis` into `[0, len(input_shape))`.
  _channel_axis = channel_axis % len(input_shape)
  output_shape = (input_shape[:_channel_axis] + (out_dim,)
                  + input_shape[_channel_axis + 1:])
  rng1, rng2 = random.split(rng)
  W = random.normal(rng1, (input_shape[_channel_axis], out_dim))
  if b_std is None:
    b = None
  else:
    b_shape = [1] * len(input_shape)
    # Use the normalized axis (as for `output_shape` and `W`) so the bias
    # shape always agrees with `output_shape`.
    b_shape[_channel_axis] = out_dim
    b = random.normal(rng2, b_shape)
  return output_shape, (W, b)
def ntk_init_fn(rng, input_shape):
  """Initializes parameters for the NTK parameterization.

  Weights and biases are left at unit scale; `W_std` and `b_std` are
  applied at call time in `apply_fn`.
  """
  shape_and_params = _init_fn(rng, input_shape, out_dim)
  return shape_and_params
def standard_init_fn(rng, input_shape):
  """Initializes parameters for the standard parameterization.

  Draws `s[1] * out_dim` output units, scales `W` by `W_std / sqrt(in_dim)`
  with `in_dim = input_shape[channel_axis] / s[0]`, and scales `b` (when
  present) by `b_std`.
  """
  widened_out_dim = out_dim * s[1]
  output_shape, (W, b) = _init_fn(rng, input_shape, widened_out_dim)
  # Width of the previous layer before its `s[0]` scaling.
  in_dim = input_shape[channel_axis] / s[0]
  W *= W_std / in_dim**0.5
  if b is not None:
    b = b * b_std
  return output_shape, (W, b)
# Select the initializer for the (already lowercased) parameterization.
init_fns = {
    'ntk': ntk_init_fn,
    'standard': standard_init_fn,
}
if parameterization not in init_fns:
  raise ValueError(f'Parameterization not supported: {parameterization}')
init_fn = init_fns[parameterization]
def apply_fn(params, inputs, **kwargs):
  """Applies the dense layer along `channel_axis` of `inputs`.

  Contracts `W`'s first axis with the input channel axis, then applies the
  parameterization-specific scaling and adds the bias (if any).
  """
  W, b = params
  contracted = np.tensordot(W, inputs, (0, channel_axis))
  # `tensordot` puts the output-feature axis first; move it back into place.
  prod = np.moveaxis(contracted, 0, channel_axis)

  if parameterization == 'ntk':
    scale = W_std / inputs.shape[channel_axis]**0.5
    outputs = scale * prod
    bias_term = None if b is None else b_std * b
  elif parameterization == 'standard':
    outputs = prod / s[0]**0.5
    bias_term = b
  else:
    raise ValueError(f'Parameterization not supported: {parameterization}')

  if bias_term is not None:
    outputs += bias_term
  return outputs
@requires(batch_axis=batch_axis,
          channel_axis=channel_axis,
          diagonal_spatial=Diagonal())
def kernel_fn(k: Kernel, **kwargs):
  """Compute the transformed kernels after a `Dense` layer.

  Args:
    k: input `Kernel`; its `cov1`, `nngp`, `cov2`, `ntk` and `shape1` are
      read.
    **kwargs: unused in this function.

  Returns:
    `k` with `cov1`, `nngp`, `cov2` passed through `_affine(x, W_std, b_std)`
    (NOTE(review): presumably `W_std**2 * x + b_std**2`; confirm against
    `_affine`), `ntk` updated according to the parameterization (left as
    `None` if it was `None`), and `is_gaussian=True`, `is_input=False`.
  """
  cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk

  def fc(x):
    return _affine(x, W_std, b_std)

  if parameterization == 'ntk':
    # Transform the covariances first: the NTK update below uses the
    # *output* NNGP.
    cov1, nngp, cov2 = map(fc, (cov1, nngp, cov2))
    if ntk is not None:
      ntk = nngp + W_std**2 * ntk
  elif parameterization == 'standard':
    # Effective input width (`in_dim`): input channels divided by `s[0]`.
    input_width = k.shape1[channel_axis] / s[0]
    if ntk is not None:
      # Uses the *input* NNGP, so this must run before the affine update of
      # `nngp` below.
      ntk = input_width * nngp + W_std**2 * ntk
      if b_std is not None:
        # Bias contribution: in the standard parameterization the bias
        # enters the output with unit coefficient (see `apply_fn`).
        ntk += 1.
    cov1, nngp, cov2 = map(fc, (cov1, nngp, cov2))

  return k.replace(cov1=cov1,
                   nngp=nngp,
                   cov2=cov2,
                   ntk=ntk,
                   is_gaussian=True,
                   is_input=False)
def mask_fn(mask, input_shape):
  """Reduces `mask` with `np.all` over `channel_axis`, keeping that axis."""
  del input_shape  # Not needed: the reduction depends only on `mask`.
  return np.all(mask, keepdims=True, axis=channel_axis)
return init_fn, apply_fn, kernel_fn, mask_fn
@layer
@supports_masking(remask_kernel=True)
def Conv(
    out_chan: int,
    filter_shape: Sequence[int],
    strides: Optional[Sequence[int]] = None,
    padding: str = Padding.VALID.name,
    W_std: float = 1.0,
    b_std: Optional[float] = None,
    dimension_numbers: Optional[Tuple[str, str, str]] = None,
    parameterization: str = 'ntk',
    s: Tuple[int, int] = (1, 1),
) -> InternalLayerMasked:
  """General convolution.

  Based on :obj:`jax.example_libraries.stax.GeneralConv`.

  Args:
    out_chan:
      The number of output channels / features of the convolution. This is
      ignored by the `kernel_fn` in NTK parameterization.
    filter_shape:
      The shape of the filter. The shape of the tuple should agree with the
      number of spatial dimensions in `dimension_numbers`.
    strides:
      The stride of the convolution. The shape of the tuple should agree with
      the number of spatial dimensions in `dimension_numbers`.
    padding:
      Specifies padding for the convolution. Can be one of `"VALID"`, `"SAME"`,
      or `"CIRCULAR"`. `"CIRCULAR"` uses periodic convolutions.
    W_std:
      The standard deviation of the weights.
    b_std:
      The standard deviation of the biases.
    dimension_numbers:
      Specifies which axes should be convolved over. Should match the
      specification in :obj:`jax.lax.conv_general_dilated`.
    parameterization:
      Either `"ntk"` or `"standard"`. These parameterizations are the direct
      analogues for convolution of the corresponding parameterizations for
      :obj:`Dense` layers.
    s:
      A tuple of integers, a direct convolutional analogue of the respective
      parameters for the :obj:`Dense` layer.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  # NOTE(review): the two trailing positional booleans are `_Conv`-specific
  # flags (presumably the first selects a transposed convolution and is off
  # here); confirm their meaning against `_Conv`'s signature.
  return _Conv(out_chan, filter_shape, strides, padding, W_std, b_std,
               dimension_numbers, parameterization, s, False, True)
@layer
@supports_masking(remask_kernel=True)
def ConvTranspose(
out_chan: int,
filter_shape: Sequence[int],
strides: Optional[Sequence[int]] = None,
padding: str = Padding.VALID.name,
W_std: float = 1.0,
b_std: Optional[float] = None,
dimension_numbers: Optional[Tuple[str, str, str]] = None,
parameterization: str = 'ntk',
s: Tuple[int, int] = (1, 1),
) -> InternalLayerMasked:
"""General transpose convolution.
Based on :obj:`jax.example_libraries.stax.GeneralConvTranspose`.
Args:
out_chan:
The number of output channels / features of the convolution. This is