/
stax_test.py
3296 lines (2865 loc) · 116 KB
/
stax_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for stax.py."""
import functools
import itertools
import random as prandom
import string
import time
from typing import Tuple
from absl.testing import absltest
from jax import lax
from jax import ops
from jax import test_util as jtu
from jax.api import jit, vjp
from jax.config import config
from jax.lib import xla_bridge
import jax.numpy as np
import jax.random as random
import more_itertools
from neural_tangents import stax
from neural_tangents.utils import monte_carlo, test_utils, utils, batch
import numpy as onp
# Parse absl flags into the JAX config, and turn implicit rank promotion into
# an error so that accidental broadcasting surfaces as a test failure.
config.parse_flags_with_absl()
config.update('jax_numpy_rank_promotion', 'raise')

# Architectures exercised by the parameterized tests.
MODELS = 'fc conv'.split()

# Input geometry in the default 'NHWC' layout; `_get_net` permutes it to the
# randomly chosen layout, and `_get_inputs` locates the batch axis by value.
BATCH_SIZE = 4
INPUT_SHAPE = (BATCH_SIZE, 8, 6, 2)

# Hidden width (2**10).
WIDTHS = [1024]

# Monte Carlo sample count and relative tolerance for analytic-vs-empirical
# kernel comparisons.
N_SAMPLES = 100
RTOL = 0.041

# Convolution / pooling hyper-parameters, given in 'HW' order.
FILTER_SHAPES = [(2, 1), (3, 2)]
PADDINGS = 'SAME VALID CIRCULAR'.split()
STRIDES = [(1, 2), (2, 1)]

# Activation layer -> name used in test case names.
ACTIVATIONS = {stax.Relu(): 'Relu'}

# Ways of projecting activations to 2D before the readout layer.
PROJECTIONS = 'FLAT POOL ATTN'.split()

# Layer-norm axes, given as letters of the data layout spec.
LAYER_NORM = 'C HC CHW NC NWC NCHW'.split()

POOL_TYPES = 'SUM AVG'.split()
PARAMETERIZATIONS = 'NTK STANDARD'.split()

# NOTE(review): presumably adjusts the global `jtu` tolerances for this
# project; see `neural_tangents.utils.test_utils`.
test_utils.update_test_tolerance()
def _skip_test(msg='Skipping large tests for speed.', platforms=('cpu',)):
  """Skip the calling test when the current XLA backend is in `platforms`.

  Args:
    msg: reason attached to the skip.
    platforms: backend platform names on which to skip.

  Raises:
    absltest.SkipTest: if `xla_bridge.get_backend().platform` is in
      `platforms`.
  """
  current_platform = xla_bridge.get_backend().platform
  if current_platform not in platforms:
    return
  raise absltest.SkipTest(msg)
def _get_inputs(
    key,
    same_inputs,
    shape,
    fn=np.cos
) -> Tuple[np.ndarray, np.ndarray]:
  """Draw a pair of random input batches for kernel tests.

  Args:
    key: PRNG key; split into one key for `x1` and one for `x2`.
    same_inputs: if `True`, the second batch is `None`.
    shape: shape of `x1` (a tuple). The batch axis is the first axis whose
      size equals `BATCH_SIZE`.
    fn: elementwise function applied to the standard-normal samples.

  Returns:
    `(x1, x2)`. `x2` is `None` when `same_inputs`; otherwise it has
    `2 * BATCH_SIZE` entries along the batch axis and is scaled by 2.
  """
  key_1, key_2 = random.split(key)
  x1 = fn(random.normal(key_1, shape))

  batch_axis = shape.index(BATCH_SIZE)
  if same_inputs:
    return x1, None

  shape_2 = shape[:batch_axis] + (2 * BATCH_SIZE,) + shape[batch_axis + 1:]
  return x1, fn(random.normal(key_2, shape_2)) * 2
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding,
             phi, strides, width, is_ntk, proj_into_2d, pool_type, layer_norm,
             parameterization, use_dropout):
  """Build a test network with a randomly chosen data (and filter) layout.

  The data layout spec (and, for CNNs, the filter spec) is drawn with the
  module-level `prandom` generator, and the dropout rate with `onp.random`,
  so repeated calls can produce different layouts.

  Args:
    W_std: weight standard deviation of every affine / attention layer.
    b_std: bias standard deviation of every affine / attention layer.
    filter_shape: conv filter size in 'HW' order; reordered to match the
      sampled filter spec. Ignored for FC networks.
    is_conv: build a CNN (`stax.Conv`) if `True`, else an FC net
      (`stax.Dense`).
    use_pooling: add a `pool_type` pooling layer (window `(2, 3)`) after the
      affine layer of the inner unit.
    is_res: wrap the inner unit in a `FanOut` / `parallel` / `FanInSum`
      residual block.
    padding: conv padding.
    phi: activation layer applied at the end of the block.
    strides: conv strides in 'HW' order; reordered like `filter_shape`.
    width: hidden channel / unit count.
    is_ntk: if `True` the readout has 1 output unit, otherwise `width`.
    proj_into_2d: `'FLAT'`, `'POOL'`, or a string starting with `'ATTN'`;
      selects how activations are reduced to 2D before the readout.
    pool_type: `'AVG'` or `'SUM'`.
    layer_norm: `None`, or a string of layout letters naming the axes for
      `stax.LayerNorm`.
    parameterization: forwarded to the affine layers.
    use_dropout: put `stax.Dropout(rate, mode='train')` at the start of the
      inner unit.

  Returns:
    `(layer, input_shape, device_count, channel_axis_fc)`: `layer` is the
    `(init_fn, apply_fn, kernel_fn)` triple; `input_shape` is `INPUT_SHAPE`
    arranged for the sampled layout; `device_count` is `-1` when the batch
    axis is leading and `0` otherwise; `channel_axis_fc` is the index of 'C'
    within the layout restricted to the letters 'N' and 'C'.

  Raises:
    ValueError: for an unknown `pool_type` or `proj_into_2d`.
  """
  if is_conv:
    # Select a random filter order.
    default_filter_spec = 'HW'
    filter_specs = [''.join(p) for p in itertools.permutations('HWIO')]
    filter_spec = prandom.choice(filter_specs)
    # Reorder the spatial filter sizes / strides to follow `filter_spec`.
    filter_shape = tuple(filter_shape[default_filter_spec.index(c)]
                         for c in filter_spec if c in default_filter_spec)
    strides = tuple(strides[default_filter_spec.index(c)]
                    for c in filter_spec if c in default_filter_spec)

    # Select the activation order.
    default_spec = 'NHWC'
    if xla_bridge.get_backend().platform == 'tpu':
      # Keep batch dimension leading for TPU for batching to work.
      specs = ['N' + ''.join(p) for p in itertools.permutations('CHW')]
    else:
      specs = [''.join(p) for p in itertools.permutations('NCHW')]
    spec = prandom.choice(specs)
    # Permute the default NHWC input shape into the sampled layout.
    input_shape = tuple(INPUT_SHAPE[default_spec.index(c)] for c in spec)

  else:
    # FC inputs: batch size by the product of all non-batch dimensions.
    input_shape = (INPUT_SHAPE[0], onp.prod(INPUT_SHAPE[1:]))
    if xla_bridge.get_backend().platform == 'tpu':
      spec = 'NC'
    else:
      spec = prandom.choice(['NC', 'CN'])
      if spec.index('N') == 1:
        input_shape = input_shape[::-1]

    filter_spec = None

  dimension_numbers = (spec, filter_spec, spec)
  batch_axis, channel_axis = spec.index('N'), spec.index('C')

  # Batch / channel axes of the layout with spatial letters removed, used by
  # the Dense readout and by Flatten.
  spec_fc = ''.join(c for c in spec if c in ('N', 'C'))
  batch_axis_fc, channel_axis_fc = spec_fc.index('N'), spec_fc.index('C')

  if not is_conv:
    batch_axis = batch_axis_fc
    channel_axis = channel_axis_fc

  if layer_norm:
    # Translate layout letters into integer axes of the sampled layout.
    layer_norm = tuple(spec.index(c) for c in layer_norm)

  def fc(out_dim):
    return stax.Dense(
        out_dim=out_dim,
        W_std=W_std,
        b_std=b_std,
        parameterization=parameterization,
        batch_axis=batch_axis_fc,
        channel_axis=channel_axis_fc
    )

  def conv(out_chan):
    return stax.Conv(out_chan=out_chan, filter_shape=filter_shape,
                     strides=strides, padding=padding, W_std=W_std,
                     b_std=b_std, dimension_numbers=dimension_numbers,
                     parameterization=parameterization)

  affine = conv(width) if is_conv else fc(width)

  # Drawn unconditionally, so `onp.random` state advances even when dropout
  # is not used.
  rate = onp.random.uniform(0.5, 0.9)
  dropout = stax.Dropout(rate, mode='train')

  if pool_type == 'AVG':
    pool_fn = stax.AvgPool
    global_pool_fn = stax.GlobalAvgPool
  elif pool_type == 'SUM':
    pool_fn = stax.SumPool
    global_pool_fn = stax.GlobalSumPool
  else:
    raise ValueError(pool_type)

  if use_pooling:
    # Window (2, 3) with `None` strides; padding is 'SAME' when the conv uses
    # 'SAME' and 'CIRCULAR' otherwise.
    pool_or_identity = pool_fn((2, 3),
                               None,
                               'SAME' if padding == 'SAME' else 'CIRCULAR',
                               batch_axis=batch_axis,
                               channel_axis=channel_axis)
  else:
    pool_or_identity = stax.Identity()
  dropout_or_identity = dropout if use_dropout else stax.Identity()
  layer_norm_or_identity = (stax.Identity() if layer_norm is None else
                            stax.LayerNorm(axis=layer_norm,
                                           batch_axis=batch_axis,
                                           channel_axis=channel_axis))
  res_unit = stax.serial(dropout_or_identity, affine, pool_or_identity)
  if is_res:
    # Identity branch summed with `res_unit` branch.
    block = stax.serial(
        affine,
        stax.FanOut(2),
        stax.parallel(stax.Identity(),
                      res_unit),
        stax.FanInSum(),
        layer_norm_or_identity,
        phi)
  else:
    block = stax.serial(
        affine,
        res_unit,
        layer_norm_or_identity,
        phi)

  if proj_into_2d == 'FLAT':
    proj_layer = stax.Flatten(batch_axis, batch_axis_fc)
  elif proj_into_2d == 'POOL':
    proj_layer = global_pool_fn(batch_axis, channel_axis)
  elif proj_into_2d.startswith('ATTN'):
    n_heads = int(np.sqrt(width))
    n_chan_val = int(np.round(float(width) / n_heads))
    proj_layer = stax.serial(
        stax.GlobalSelfAttention(
            n_chan_out=width,
            n_chan_key=width,
            n_chan_val=n_chan_val,
            n_heads=n_heads,
            linear_scaling=True,
            W_key_std=W_std,
            W_value_std=W_std,
            W_query_std=W_std,
            W_out_std=1.0,
            b_std=b_std,
            batch_axis=batch_axis,
            channel_axis=channel_axis),
        stax.Flatten(batch_axis, batch_axis_fc))
  else:
    raise ValueError(proj_into_2d)
  readout = stax.serial(proj_layer, fc(1 if is_ntk else width))

  device_count = -1 if spec.index('N') == 0 else 0

  return stax.serial(block, readout), input_shape, device_count, channel_axis_fc
def _get_net_pool(width, is_ntk, pool_type, padding,
                  filter_shape, strides, normalize_edges):
  """Build a small CNN whose middle layer is the pooling layer under test.

  Architecture: Conv -> Relu -> {Avg,Sum}Pool -> Conv -> Relu ->
  Global{Avg,Sum}Pool -> Dense, with 'ntk' parameterization,
  `W_std = 2**0.5` and `b_std = 0.5**0.5`. Convs use a (3, 2) filter,
  `None` strides and 'SAME' padding.

  Args:
    width: channel count of both conv layers (and readout units for NNGP).
    is_ntk: if `True` the readout has a single unit, otherwise `width`.
    pool_type: `'AVG'` or `'SUM'`.
    padding: padding of the pooling layer.
    filter_shape: window of the pooling layer.
    strides: strides of the pooling layer.
    normalize_edges: forwarded to `stax.AvgPool`; unused for `'SUM'`.

  Returns:
    `(layer, INPUT_SHAPE, -1, -1)`, i.e. the network triple, the input shape,
    the device count and the channel axis.

  Raises:
    ValueError: for an unknown `pool_type`.
  """
  phi = stax.Relu()
  affine_kwargs = dict(W_std=2.**0.5, b_std=0.5**0.5, parameterization='ntk')

  if pool_type == 'AVG':
    pool = stax.AvgPool(filter_shape, strides, padding,
                        normalize_edges=normalize_edges)
    global_pool = stax.GlobalAvgPool()
  elif pool_type == 'SUM':
    pool = stax.SumPool(filter_shape, strides, padding)
    global_pool = stax.GlobalSumPool()
  else:
    raise ValueError(pool_type)

  def conv_layer():
    return stax.Conv(width, filter_shape=(3, 2), strides=None, padding='SAME',
                     **affine_kwargs)

  readout = stax.Dense(1 if is_ntk else width, **affine_kwargs)
  network = stax.serial(conv_layer(), phi, pool, conv_layer(), phi,
                        global_pool, readout)
  return network, INPUT_SHAPE, -1, -1
def _mask(x, mask_constant, mask_axis, key, p):
if mask_constant is not None:
mask_shape = [1 if i in mask_axis else s
for i, s in enumerate(x.shape)]
mask = random.bernoulli(key, p=p, shape=mask_shape)
x = np.where(mask, mask_constant, x)
x = np.sort(x, 1)
return x
class StaxTest(test_utils.NeuralTangentsTestCase):
  """Compares analytic `stax` kernels with Monte Carlo estimates.

  Also checks pooling layers against hand-computed values and manual
  kernel-function composition against `stax.serial`.
  """

  def _skip_test(self, filter_shape, is_conv, is_res, padding, proj_into_2d,
                 strides, use_pooling):
    """Skip slow, ill-shaped or duplicate configurations.

    Raises:
      absltest.SkipTest: for CNNs on CPU; for residual CNNs whose strides or
        'VALID' padding would change the output shape; and for FC models
        whose conv-only parameters differ from the first entry of each list
        (those cases would duplicate each other).
    """
    if is_conv:
      if xla_bridge.get_backend().platform == 'cpu':
        raise absltest.SkipTest('Not running CNN models on CPU to save time.')

      if (is_res and is_conv and ((strides is not None and strides != (1, 1)) or
                                  (padding == 'VALID' and filter_shape !=
                                   (1, 1)))):
        raise absltest.SkipTest('Different paths in a residual models need to '
                                'return outputs of the same shape.')
    elif (filter_shape != FILTER_SHAPES[0] or padding != PADDINGS[0] or
          strides != STRIDES[0] or proj_into_2d != PROJECTIONS[0] or
          use_pooling):
      raise absltest.SkipTest('FC models do not have these parameters.')

  @jtu.parameterized.named_parameters(
      jtu.cases_from_list({
          'testcase_name':
              '_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}'.format(
                  model, phi_name, width,
                  'same_inputs' if same_inputs else 'different_inputs',
                  'filter_shape=%s' % str(filter_shape),
                  'padding=%s' % padding,
                  'strides=%s' % str(strides),
                  'pool' if use_pooling else 'flatten',
                  'NTK' if is_ntk else 'NNGP',
                  'RESNET' if is_res else 'serial',
                  proj_into_2d),
          'model': model,
          'width': width,
          'strides': strides,
          'padding': padding,
          'phi': phi,
          'same_inputs': same_inputs,
          'filter_shape': filter_shape,
          'use_pooling': use_pooling,
          'is_ntk': is_ntk,
          'is_res': is_res,
          'proj_into_2d': proj_into_2d
      }
                          for model in MODELS
                          for width in WIDTHS
                          for phi, phi_name in ACTIVATIONS.items()
                          for same_inputs in [False]
                          for padding in PADDINGS for strides in STRIDES
                          for filter_shape in FILTER_SHAPES
                          for use_pooling in [False, True]
                          for is_ntk in [False, True]
                          for is_res in [False, True]
                          for proj_into_2d in PROJECTIONS))
  def test_exact(self, model, width, strides, padding, phi, same_inputs,
                 filter_shape, use_pooling, is_ntk, is_res, proj_into_2d):
    """Analytic kernel of a `_get_net` network matches its MC estimate."""
    is_conv = 'conv' in model

    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    self._skip_test(filter_shape, is_conv, is_res, padding, proj_into_2d,
                    strides, use_pooling)

    pool_type = 'AVG'
    W_std, b_std = 2.**0.5, 0.5**0.5
    layer_norm = None
    parameterization = 'ntk'
    use_dropout = False

    net = _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
                   padding, phi, strides, width, is_ntk, proj_into_2d,
                   pool_type, layer_norm, parameterization, use_dropout)
    self._check_agreement_with_empirical(
        net, same_inputs, use_dropout, is_ntk, RTOL)

  # pylint: disable=g-complex-comprehension
  @jtu.parameterized.named_parameters(
      jtu.cases_from_list({
          'testcase_name':
              '_{}_{}_{}_{}_{}_{}_{}'.format(
                  model, width,
                  'same_inputs' if same_inputs else 'different_inputs',
                  'filter_shape=%s' % str(filter_shape),
                  proj_into_2d,
                  'NTK' if is_ntk else 'NNGP',
                  'parameterization=%s' % str(parameterization)),
          'model': model,
          'width': width,
          'same_inputs': same_inputs,
          'filter_shape': filter_shape,
          'proj_into_2d': proj_into_2d,
          'is_ntk': is_ntk,
          'parameterization': parameterization
      } for model in MODELS for width in WIDTHS
                          for same_inputs in [False]
                          for is_ntk in [False, True]
                          for filter_shape in FILTER_SHAPES
                          for proj_into_2d in PROJECTIONS[:2]
                          for parameterization in PARAMETERIZATIONS))
  def test_parameterizations(self, model, width, same_inputs, is_ntk,
                             filter_shape, proj_into_2d, parameterization):
    """Analytic and MC kernels agree under each parameterization."""
    is_conv = 'conv' in model

    W_std, b_std = 2.**0.5, 0.5**0.5
    padding = PADDINGS[0]
    strides = STRIDES[0]
    phi = stax.Relu()
    use_pooling, is_res = False, False
    layer_norm = None
    pool_type = 'AVG'
    use_dropout = False

    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    if is_conv:
      if xla_bridge.get_backend().platform == 'cpu':
        raise absltest.SkipTest('Not running CNN models on CPU to save time.')
    elif proj_into_2d != PROJECTIONS[0]:
      raise absltest.SkipTest('FC models do not have these parameters.')

    net = _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
                   padding, phi, strides, width, is_ntk, proj_into_2d,
                   pool_type, layer_norm, parameterization, use_dropout)
    self._check_agreement_with_empirical(net, same_inputs, use_dropout, is_ntk)

  @jtu.parameterized.named_parameters(
      jtu.cases_from_list({
          'testcase_name':
              '_{}_{}_{}_{}_{}_{}'.format(
                  model,
                  width,
                  'same_inputs' if same_inputs else 'different_inputs',
                  'NTK' if is_ntk else 'NNGP',
                  proj_into_2d,
                  'layer_norm=%s' % str(layer_norm)),
          'model': model,
          'width': width,
          'same_inputs': same_inputs,
          'is_ntk': is_ntk,
          'proj_into_2d': proj_into_2d,
          'layer_norm': layer_norm
      }
                          for model in MODELS
                          for width in WIDTHS
                          for same_inputs in [False]
                          for is_ntk in [False, True]
                          for proj_into_2d in PROJECTIONS[:2]
                          for layer_norm in LAYER_NORM))
  def test_layernorm(self,
                     model,
                     width,
                     same_inputs,
                     is_ntk,
                     proj_into_2d,
                     layer_norm):
    """Analytic and MC kernels agree for networks with `stax.LayerNorm`."""
    is_conv = 'conv' in model
    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    if is_conv:
      if xla_bridge.get_backend().platform == 'cpu':
        raise absltest.SkipTest('Not running CNN models on CPU to save time.')
    elif proj_into_2d != PROJECTIONS[0] or layer_norm not in ('C', 'NC'):
      raise absltest.SkipTest('FC models do not have these parameters.')

    W_std, b_std = 2.**0.5, 0.5**0.5
    filter_shape = FILTER_SHAPES[0]
    padding = PADDINGS[0]
    strides = STRIDES[0]
    phi = stax.Relu()
    use_pooling, is_res = False, False
    parameterization = 'ntk'
    pool_type = 'AVG'
    use_dropout = False

    net = _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
                   padding, phi, strides, width, is_ntk, proj_into_2d,
                   pool_type, layer_norm, parameterization, use_dropout)
    self._check_agreement_with_empirical(net, same_inputs, use_dropout, is_ntk,
                                         0.05)

  @jtu.parameterized.named_parameters(
      jtu.cases_from_list({
          'testcase_name':
              '_{}_{}_{}_{}_{}_{}_{}_{}'.format(
                  width,
                  'same_inputs' if same_inputs else 'different_inputs',
                  'filter_shape=%s' % str(filter_shape),
                  'padding=%s' % padding,
                  'strides=%s' % str(strides),
                  'NTK' if is_ntk else 'NNGP',
                  'pool_type=%s' % str(pool_type),
                  'normalize_edges=%s' % str(normalize_edges)),
          'width': width,
          'same_inputs': same_inputs,
          'is_ntk': is_ntk,
          'pool_type': pool_type,
          'padding': padding,
          'filter_shape': filter_shape,
          'strides': strides,
          'normalize_edges': normalize_edges
      } for width in WIDTHS for same_inputs in [False]
                          for is_ntk in [False, True]
                          for pool_type in POOL_TYPES for padding in PADDINGS
                          for filter_shape in FILTER_SHAPES
                          for strides in STRIDES
                          for normalize_edges in [True, False]))
  def test_pool(self, width, same_inputs, is_ntk, pool_type,
                padding, filter_shape, strides, normalize_edges):
    """Analytic and MC kernels agree for `_get_net_pool` networks."""
    use_dropout = False
    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    if xla_bridge.get_backend().platform == 'cpu':
      raise absltest.SkipTest('Not running CNN models on CPU to save time.')
    if pool_type == 'SUM' and normalize_edges:
      raise absltest.SkipTest('normalize_edges not applicable to SumPool.')

    net = _get_net_pool(width, is_ntk, pool_type,
                        padding, filter_shape, strides, normalize_edges)
    self._check_agreement_with_empirical(net, same_inputs, use_dropout, is_ntk)

  def test_avg_pool(self):
    """`AvgPool` outputs and kernels match hand-computed values.

    With `normalize_edges=True` the result must agree with `ostax.AvgPool`
    and give all-ones kernels on all-ones inputs. Without normalization, the
    edge windows are partially filled with padding and produce the fractions
    in `out_unnorm`.
    """
    X1 = np.ones((4, 2, 3, 2))
    X2 = np.ones((3, 2, 3, 2))

    _, apply_fn, kernel_fn = stax.AvgPool((2, 2), (1, 1), 'SAME',
                                          normalize_edges=False)
    _, apply_fn_norm, kernel_fn_norm = stax.AvgPool((2, 2), (1, 1), 'SAME',
                                                    normalize_edges=True)
    _, apply_fn_stax = stax.ostax.AvgPool((2, 2), (1, 1), 'SAME')

    out1 = apply_fn((), X1)
    out2 = apply_fn((), X2)

    out1_norm = apply_fn_norm((), X1)
    out2_norm = apply_fn_norm((), X2)

    out1_stax = apply_fn_stax((), X1)
    out2_stax = apply_fn_stax((), X2)

    self.assertAllClose((out1_stax, out2_stax), (out1_norm, out2_norm))

    out_unnorm = np.array([[1., 1., 0.5], [0.5, 0.5, 0.25]]).reshape(
        (1, 2, 3, 1))
    out1_unnormalized = np.broadcast_to(out_unnorm, X1.shape)
    out2_unnormalized = np.broadcast_to(out_unnorm, X2.shape)

    self.assertAllClose((out1_unnormalized, out2_unnormalized), (out1, out2))

    ker = kernel_fn(X1, X2)
    ker_norm = kernel_fn_norm(X1, X2)

    self.assertAllClose(np.ones_like(ker_norm.nngp), ker_norm.nngp)
    self.assertAllClose(np.ones_like(ker_norm.cov1), ker_norm.cov1)
    self.assertAllClose(np.ones_like(ker_norm.cov2), ker_norm.cov2)

    self.assertEqual(ker_norm.nngp.shape, ker.nngp.shape)
    self.assertEqual(ker_norm.cov1.shape, ker.cov1.shape)
    self.assertEqual(ker_norm.cov2.shape, ker.cov2.shape)

    ker_unnorm = np.outer(out_unnorm, out_unnorm).reshape((2, 3, 2, 3))
    ker_unnorm = np.transpose(ker_unnorm, axes=(0, 2, 1, 3))
    nngp = np.broadcast_to(
        ker_unnorm.reshape((1, 1) + ker_unnorm.shape), ker.nngp.shape)
    cov1 = np.broadcast_to(np.expand_dims(ker_unnorm, 0), ker.cov1.shape)
    cov2 = np.broadcast_to(np.expand_dims(ker_unnorm, 0), ker.cov2.shape)
    self.assertAllClose((nngp, cov1, cov2), (ker.nngp, ker.cov1, ker.cov2))

  @jtu.parameterized.named_parameters(
      jtu.cases_from_list({
          'testcase_name':
              '_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}'.format(
                  model, phi_name, width,
                  'same_inputs' if same_inputs else 'different_inputs',
                  'filter_shape=%s' % str(filter_shape),
                  'padding=%s' % padding,
                  'strides=%s' % str(strides),
                  'pool' if use_pooling else 'flatten',
                  'NTK' if is_ntk else 'NNGP',
                  proj_into_2d),
          'model': model,
          'width': width,
          'same_inputs': same_inputs,
          'is_ntk': is_ntk,
          'padding': padding,
          'strides': strides,
          'filter_shape': filter_shape,
          'phi': phi,
          'use_pooling': use_pooling,
          'proj_into_2d': proj_into_2d
      } for model in MODELS for width in WIDTHS
                          for same_inputs in [True, False]
                          for phi, phi_name in ACTIVATIONS.items()
                          for padding in ['SAME'] for strides in STRIDES
                          for filter_shape in [(2, 1)]
                          for is_ntk in [True, False]
                          for use_pooling in [True, False]
                          for proj_into_2d in ['FLAT', 'POOL']))
  def test_dropout(self, model, width, same_inputs, is_ntk, padding, strides,
                   filter_shape, phi, use_pooling, proj_into_2d):
    """Analytic and MC kernels agree for networks containing dropout."""
    pool_type = 'AVG'
    use_dropout = True
    is_conv = 'conv' in model
    is_res = False
    W_std, b_std = 2.**0.5, 0.5**0.5
    layer_norm = None
    parameterization = 'ntk'
    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    self._skip_test(filter_shape, is_conv, is_res, padding, proj_into_2d,
                    strides, use_pooling)

    net = _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
                   padding, phi, strides, width, is_ntk, proj_into_2d,
                   pool_type, layer_norm, parameterization, use_dropout)
    self._check_agreement_with_empirical(net, same_inputs, use_dropout, is_ntk)

  @jtu.parameterized.named_parameters(
      jtu.cases_from_list({
          'testcase_name':
              f'_act={act}_kernel={kern}_do_stabilize={do_stabilize}',
          'act': act,
          'kernel': kern,
          'do_stabilize': do_stabilize
      }
                          for act in ['erf', 'relu']
                          for do_stabilize in [True, False]
                          for kern in ['nngp', 'ntk']))
  def test_sparse_inputs(self, act, kernel, do_stabilize):
    """Analytic kernels of all-zero inputs are finite; the rest match MC.

    The first `sparse_count` inputs are zeroed out; only the kernel block
    between the remaining inputs is compared with the MC estimate.
    """
    if do_stabilize and act != 'relu':
      raise absltest.SkipTest('Stabilization possible only in Relu.')

    key = random.PRNGKey(1)

    input_count = 4
    sparse_count = 2
    input_size = 128
    width = 4096

    # `jtu._default_tolerance` is module-level state shared by every test in
    # the process. Snapshot it and restore it when this test finishes so the
    # loosened tolerances set below do not leak into unrelated tests.
    saved_tolerance = dict(jtu._default_tolerance)

    def restore_tolerance():
      jtu._default_tolerance.clear()
      jtu._default_tolerance.update(saved_tolerance)

    self.addCleanup(restore_tolerance)

    # NOTE(schsam): It seems that convergence is slower when inputs are sparse.
    samples = N_SAMPLES

    if xla_bridge.get_backend().platform == 'gpu':
      jtu._default_tolerance[onp.dtype(onp.float64)] = 5e-4
      samples = 100 * N_SAMPLES
    else:
      jtu._default_tolerance[onp.dtype(onp.float32)] = 5e-2
      jtu._default_tolerance[onp.dtype(onp.float64)] = 5e-3

    # a batch of dense inputs
    x_dense = random.normal(key, (input_count, input_size))
    x_sparse = ops.index_update(x_dense, ops.index[:sparse_count, :], 0.)

    activation = (stax.Relu(do_stabilize=do_stabilize) if act == 'relu'
                  else stax.Erf())

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(width),
        activation,
        stax.Dense(1 if kernel == 'ntk' else width))
    exact = kernel_fn(x_sparse, None, kernel)

    mc = monte_carlo.monte_carlo_kernel_fn(init_fn, apply_fn,
                                           random.split(key, 2)[0],
                                           samples,
                                           vmap_axes=0,
                                           implementation=2)(
                                               x_sparse, None, kernel)
    mc = np.reshape(mc, exact.shape)

    self.assertFalse(np.any(np.isnan(exact)))
    self.assertAllClose(exact[sparse_count:, sparse_count:],
                        mc[sparse_count:, sparse_count:])

  def test_composition_dense(self):
    """Applying a block's kernel_fn twice equals `serial(Block, Block)`."""
    # Use independent keys: drawing both inputs from one key would make `x2`
    # identical to `x1` and leave the two-input code path untested.
    key1, key2 = random.split(random.PRNGKey(0))
    x1 = random.normal(key1, (10, 10))
    x2 = random.normal(key2, (10, 10))

    Block = stax.serial(stax.Dense(256), stax.Relu())

    _, _, ker_fn = Block
    _, _, composed_ker_fn = stax.serial(Block, Block)

    ker_out = ker_fn(ker_fn(x1))
    composed_ker_out = composed_ker_fn(x1)
    self.assertAllClose(ker_out, composed_ker_out)

    ker_out = ker_fn(ker_fn(x1, x2))
    composed_ker_out = composed_ker_fn(x1, x2)
    self.assertAllClose(ker_out, composed_ker_out)

  @jtu.parameterized.named_parameters(
      jtu.cases_from_list({
          'testcase_name': '_avg_pool={}_same_inputs={}'.format(avg_pool,
                                                                same_inputs),
          'avg_pool': avg_pool,
          'same_inputs': same_inputs
      } for avg_pool in [True, False] for same_inputs in [True, False]))
  def test_composition_conv(self, avg_pool, same_inputs):
    """Manually composed conv kernel_fns equal `serial(Block, Readout)`.

    With a `GlobalAvgPool` readout, feeding a spatially diagonal kernel must
    raise `ValueError`; with a `Flatten` readout it must give the same result.
    """
    # Use independent keys so that `x2 != x1` when `same_inputs` is False.
    key1, key2 = random.split(random.PRNGKey(0))
    x1 = random.normal(key1, (5, 10, 10, 3))
    x2 = None if same_inputs else random.normal(key2, (5, 10, 10, 3))

    Block = stax.serial(stax.Conv(256, (3, 3)), stax.Relu())
    if avg_pool:
      Readout = stax.serial(stax.Conv(256, (3, 3)),
                            stax.GlobalAvgPool(),
                            stax.Dense(10))
    else:
      Readout = stax.serial(stax.Flatten(), stax.Dense(10))

    block_ker_fn, readout_ker_fn = Block[2], Readout[2]
    _, _, composed_ker_fn = stax.serial(Block, Readout)

    composed_ker_out = composed_ker_fn(x1, x2)
    ker_out_no_marg = readout_ker_fn(block_ker_fn(x1, x2,
                                                  diagonal_spatial=False))
    ker_out_default = readout_ker_fn(block_ker_fn(x1, x2))
    self.assertAllClose(composed_ker_out, ker_out_no_marg)
    self.assertAllClose(composed_ker_out, ker_out_default)

    if avg_pool:
      with self.assertRaises(ValueError):
        readout_ker_fn(block_ker_fn(x1, x2, diagonal_spatial=True))
    else:
      ker_out_marg = readout_ker_fn(block_ker_fn(x1, x2,
                                                 diagonal_spatial=True))
      self.assertAllClose(composed_ker_out, ker_out_marg)

  def _check_agreement_with_empirical(
      self,
      net,
      same_inputs,
      use_dropout,
      is_ntk,
      rtol=RTOL
  ):
    """Assert that a network's analytic kernel matches its MC estimate.

    Also asserts that the `shape1` / `shape2` reported by `kernel_fn` equal
    the output shapes computed by `init_fn`.

    Args:
      net: `((init_fn, apply_fn, kernel_fn), input_shape, device_count,
        channel_axis)`, as returned by `_get_net` / `_get_net_pool`.
      same_inputs: if `True`, the kernel of `x1` with itself is compared.
      use_dropout: if `True`, 5x more MC samples are drawn and, on TPU,
        the MC estimate is batched.
      is_ntk: compare the NTK if `True`, else the NNGP kernel.
      rtol: relative tolerance for `test_utils.assert_close_matrices`.
    """
    ((init_fn, apply_fn, kernel_fn),
     input_shape, device_count, channel_axis) = net

    num_samples = N_SAMPLES * 5 if use_dropout else N_SAMPLES
    key = random.PRNGKey(1)
    x1, x2 = _get_inputs(key, same_inputs, input_shape)
    if xla_bridge.get_backend().platform == 'tpu' and use_dropout:
      # including a test case for tpu + dropout with (parallel + batching)
      batch_size = 2
    else:
      batch_size = 0

    x1_out_shape, params = init_fn(key, x1.shape)
    if same_inputs:
      self.assertIsNone(x2)
    if x2 is None:
      x2_out_shape = x1_out_shape
    else:
      x2_out_shape, params = init_fn(key, x2.shape)
    del params

    get = 'ntk' if is_ntk else 'nngp'
    exact, shape1, shape2 = kernel_fn(x1, x2, (get, 'shape1', 'shape2'))

    kernel_fn_empirical = monte_carlo.monte_carlo_kernel_fn(
        init_fn, apply_fn, key, num_samples, device_count=device_count,
        trace_axes=(channel_axis,), batch_size=batch_size,
        implementation=2
    )
    empirical = kernel_fn_empirical(x1, x2, get)

    test_utils.assert_close_matrices(self, exact, empirical, rtol)
    self.assertEqual(shape1, x1_out_shape)
    self.assertEqual(shape2, x2_out_shape)
class ActivationTest(test_utils.NeuralTangentsTestCase):
  """Compares analytic kernels of nonlinearities with Monte Carlo estimates."""

  @stax.layer
  def _RBF(self, gamma):
    """Reference RBF layer used to cross-check `stax.Rbf` on FC networks.

    Its `kernel_fn` replaces the NNGP with
    `exp(-input_dim * gamma * (cov1 + cov2 - 2 * nngp))`, where `input_dim`
    is `kernels.shape1[1]`. `apply_fn` is not implemented, and an incoming
    NTK or an NNGP with more than 2 dimensions raises `ValueError`.
    """
    init_fn = lambda key, input_shape: (input_shape, ())

    def apply_fn(unused_params, unused_xs, **kwargs):
      raise NotImplementedError()

    def kernel_fn(kernels, **kwargs):
      if kernels.ntk is not None:
        raise ValueError('RBF Kernel does not have an associated NTK.')

      if kernels.nngp.ndim > 2:
        raise ValueError(
            ('RBF Kernel is not defined for covariance matrices with dimension'
             ' greater than two.'))

      input_dim = kernels.shape1[1]
      # Column / row vectors so that `cov1 + cov2` broadcasts to the NNGP.
      cov1 = kernels.cov1
      cov1 = np.reshape(cov1, (cov1.shape[0], 1))
      cov2 = cov1 if kernels.cov2 is None else kernels.cov2
      cov2 = np.reshape(cov2, (1, cov2.shape[0]))
      nngp = kernels.nngp

      # TODO(schsam): Update cov1 and cov2 if we want to compose this kernel
      # with other kernels.
      return kernels.replace(
          nngp=np.exp(-input_dim * gamma * (cov1 + cov2 - 2 * nngp)))

    return init_fn, apply_fn, kernel_fn

  def _test_activation(self, activation_fn, same_inputs, model, get,
                       rbf_gamma=None):
    """Check `activation_fn`'s analytic kernel against an MC estimate.

    Args:
      activation_fn: `(init_fn, apply_fn, kernel_fn)` layer under test. The
        `__name__` of its `kernel_fn` selects special weight / bias scales
        for `Sin` and `Rbf`.
      same_inputs: if `True`, the second input batch is `None`.
      model: `'fc'`, or a conv model name; conv names containing `'pool'`
        read out with `GlobalAvgPool`, other conv names with `Flatten`.
      get: `'nngp'` or `'ntk'`.
      rbf_gamma: if given, and `get == 'nngp'` and `model == 'fc'`, the
        analytic kernel is also compared with the reference `_RBF` kernel.
    """
    platform = xla_bridge.get_backend().platform
    if platform == 'cpu' and 'conv' in model:
      raise absltest.SkipTest('Not running CNNs on CPU to save time.')

    key = random.PRNGKey(1)
    key, split = random.split(key)
    output_dim = 2048 if get == 'nngp' else 1
    b_std = 0.5
    W_std = 2.0
    if activation_fn[2].__name__ == 'Sin':
      W_std = 0.9
    if activation_fn[2].__name__ == 'Rbf':
      W_std = 1.0
      b_std = 0.0

    if model == 'fc':
      rtol = 0.05
      X0_1 = random.normal(key, (6, 7))
      X0_2 = None if same_inputs else random.normal(split, (10, 7))
      affine = stax.Dense(1024, W_std, b_std)
      readout = stax.Dense(output_dim)
      depth = 1

    else:
      rtol = 0.1
      X0_1 = random.normal(key, (4, 8, 8, 3))
      X0_2 = None if same_inputs else random.normal(split, (6, 8, 8, 3))
      affine = stax.Conv(1024, (3, 2), W_std=W_std, b_std=b_std, padding='SAME')
      readout = stax.serial(stax.GlobalAvgPool() if 'pool' in model else
                            stax.Flatten(),
                            stax.Dense(output_dim))
      depth = 2

    if platform == 'cpu':
      num_samplings = 200
      rtol *= 2
    else:
      num_samplings = (500 if activation_fn[2].__name__ in ('Sin', 'Rbf')
                       else 300)

    init_fn, apply_fn, kernel_fn = stax.serial(
        *[affine, activation_fn]*depth, readout)
    analytic_kernel = kernel_fn(X0_1, X0_2, get)
    mc_kernel_fn = monte_carlo.monte_carlo_kernel_fn(
        init_fn, apply_fn, split, num_samplings, implementation=2,
        vmap_axes=0
    )
    empirical_kernel = mc_kernel_fn(X0_1, X0_2, get)
    test_utils.assert_close_matrices(self, analytic_kernel,
                                     empirical_kernel, rtol)

    # Check match with explicit RBF
    if rbf_gamma is not None and get == 'nngp' and model == 'fc':
      input_dim = X0_1.shape[1]
      _, _, kernel_fn = self._RBF(rbf_gamma / input_dim)
      direct_rbf_kernel = kernel_fn(X0_1, X0_2, get)
      test_utils.assert_close_matrices(self, analytic_kernel,
                                       direct_rbf_kernel, rtol)

  @jtu.parameterized.named_parameters(
      jtu.cases_from_list({
          'testcase_name':
              '_{}_{}_{}_{}_{}'.format(
                  model,
                  phi_name,
                  'Same_inputs' if same_inputs else 'Different_inputs',
                  get,
                  abc),
          'model': model,
          'phi_name': phi_name,
          'same_inputs': same_inputs,
          'get': get,
          'abc': abc,
      }
                          for model in ['fc', 'conv-pool', 'conv-flatten']
                          for phi_name in ['Sin', 'Erf', 'Gelu', 'Sign']
                          for same_inputs in [False]
                          for get in ['nngp', 'ntk']
                          for abc in itertools.product(
                              [2., 0.3],
                              [1.5, 0.3],
                              [0., -np.pi/4., np.pi/2.])))
  def test_activation(self, same_inputs, model, phi_name, get, abc):
    """Analytic and MC kernels agree for parameterized activations.

    On CPU only the `(a, b, c) == (0.3, 1.5, -pi/4)` case is run.
    """
    platform = xla_bridge.get_backend().platform
    if platform == 'cpu':
      # `abc` comes from `itertools.product` and is a tuple; comparing it with
      # a list would always be unequal and silently skip every CPU case.
      if tuple(abc) != (0.3, 1.5, -np.pi / 4.):
        raise absltest.SkipTest('Skipping Activation test on CPU to save time.')

    a, b, c = abc
    if phi_name == 'Sin':
      activation = stax.Sin(a=a, b=b, c=c)
    elif phi_name == 'Erf':
      activation = stax.Erf(a=a, b=b, c=c)
    elif phi_name in ['Gelu', 'Sign']:
      if a != 0.3 or b != 0.3 or c != 0.:
        raise absltest.SkipTest('Skip `Gelu/Sign` test if '
                                ' (a, b, c) != (.3, .3, 0.).')
      activation = stax.Gelu() if phi_name == 'Gelu' else stax.Sign()
    else:
      raise absltest.SkipTest(f'Activation {phi_name} is not implemented.')
    self._test_activation(activation, same_inputs, model, get)

  @jtu.parameterized.named_parameters(
      jtu.cases_from_list({
          'testcase_name':
              '_{}_Rbf_{}_{}_{}'.format(
                  model,
                  'Same_inputs' if same_inputs else 'Different_inputs',
                  get,
                  gamma),
          'model': model,
          'same_inputs': same_inputs,
          'get': get,
          'gamma': gamma,
      }
                          for model in ['fc', 'conv-pool', 'conv-flatten']
                          for same_inputs in [False, True]
                          for get in ['nngp', 'ntk']
                          for gamma in [1e-6, 1e-4, 1e-2, 1.0, 2.]
                         ))
  def test_rbf(self, same_inputs, model, get, gamma):
    """`stax.Rbf` matches MC and, for FC NNGP, the reference `_RBF` kernel."""
    activation = stax.Rbf(gamma)
    self._test_activation(activation, same_inputs, model, get,
                          rbf_gamma=gamma)
class ElementwiseNumericalTest(test_utils.NeuralTangentsTestCase):
@jtu.parameterized.named_parameters(
jtu.cases_from_list({
'testcase_name':
'_{}_{}_{}_{}'.format(
model,
phi[0].__name__,
'Same_inputs' if same_inputs else 'Different_inputs',
get),
'model': model,
'phi': phi,
'same_inputs': same_inputs,
'get': get,
}
for model in ['fc', 'conv-pool', 'conv-flatten']
for phi in [
stax.Erf(),
stax.Gelu(),
stax.Sin(),
]
for same_inputs in [False, True]
for get in ['nngp', 'ntk']))
def test_elementwise_numerical(self, same_inputs, model, phi, get):
platform = xla_bridge.get_backend().platform
if platform == 'cpu' and 'conv' in model:
raise absltest.SkipTest('Not running CNNs on CPU to save time.')
key, split = random.split(random.PRNGKey(1))