forked from google/jax
/
lax_numpy_test.py
4211 lines (3845 loc) · 186 KB
/
lax_numpy_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import functools
from functools import partial
import inspect
import itertools
import operator
from typing import cast, Optional
import unittest
from unittest import SkipTest
import warnings
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import jax
import jax.ops
from jax import api
from jax import lax
from jax import linear_util
from jax import numpy as jnp
from jax import test_util as jtu
from jax import dtypes
from jax import tree_util
from jax.interpreters import partial_eval, xla
from jax.test_util import check_grads
from jax.config import config
# Let the JAX config pick up absl command-line flags; FLAGS is read by the
# tests below (e.g. FLAGS.jax_enable_x64 in testBitwiseOp).
config.parse_flags_with_absl()
FLAGS = config.FLAGS
# --- Shape families used to parameterize the tests -------------------------
# Array shapes with at least one dimension and no zero-sized dimension.
nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)]
# The same shapes plus the rank-0 array shape ().
nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes
one_dim_array_shapes = [(1,), (6,), (12,)]
# Shapes that contain a zero-sized dimension.
empty_array_shapes = [(0,), (0, 4), (3, 0),]
# Sentinel "shapes" supplied by jtu for NumPy scalars and Python scalars.
scalar_shapes = [jtu.NUMPY_SCALAR_SHAPE, jtu.PYTHON_SCALAR_SHAPE]
array_shapes = nonempty_array_shapes + empty_array_shapes
nonzerodim_shapes = nonempty_nonscalar_array_shapes + empty_array_shapes
nonempty_shapes = scalar_shapes + nonempty_array_shapes
all_shapes = scalar_shapes + array_shapes
# --- Dtype families; the base lists come from jtu.dtypes -------------------
float_dtypes = jtu.dtypes.all_floating
complex_dtypes = jtu.dtypes.complex
int_dtypes = jtu.dtypes.all_integer
unsigned_dtypes = jtu.dtypes.all_unsigned
bool_dtypes = jtu.dtypes.boolean
default_dtypes = float_dtypes + int_dtypes
inexact_dtypes = float_dtypes + complex_dtypes
# NOTE: number_dtypes does not include unsigned_dtypes; records that want
# unsigned coverage add it explicitly.
number_dtypes = float_dtypes + complex_dtypes + int_dtypes
all_dtypes = number_dtypes + bool_dtypes
# The only dtypes _valid_dtypes_for_shape allows for jtu.PYTHON_SCALAR_SHAPE.
python_scalar_dtypes = [jnp.bool_, jnp.int_, jnp.float_, jnp.complex_]
def _valid_dtypes_for_shape(shape, dtypes):
  """Return the entries of `dtypes` that may be paired with `shape`.

  Not every (shape, dtype) pair is valid: Python scalars have only one type
  per category (float, bool, etc.), so for `jtu.PYTHON_SCALAR_SHAPE` the
  candidates are restricted to those listed in `python_scalar_dtypes`. For any
  other shape `dtypes` is returned unchanged (the same object, not a copy).
  """
  if shape is not jtu.PYTHON_SCALAR_SHAPE:
    return dtypes
  return list(filter(lambda candidate: candidate in python_scalar_dtypes,
                     dtypes))
def _shape_and_dtypes(shapes, dtypes):
  """Lazily yield each `(shape, dtype)` pair accepted for `shapes`.

  Shapes are visited in order; for each one, the dtypes come from
  `_valid_dtypes_for_shape(shape, dtypes)`.
  """
  for shape in shapes:
    yield from ((shape, dtype)
                for dtype in _valid_dtypes_for_shape(shape, dtypes))
# Immutable description of one operation under test; see op_record() for the
# meaning of the defaults.
OpRecord = collections.namedtuple(
    "OpRecord",
    "name nargs dtypes shapes rng_factory diff_modes "
    "test_name check_dtypes tolerance inexact")


def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes,
              test_name=None, check_dtypes=True,
              tolerance=None, inexact=False):
  """Build an `OpRecord`, defaulting `test_name` to `name`.

  Args:
    name: attribute name of the operation (looked up on `np`/`jnp` or used as
      a dunder method name by the tests).
    nargs: number of positional operands.
    dtypes: candidate dtypes for each operand.
    shapes: candidate shapes for each operand.
    rng_factory: callable that turns a random state into a value generator.
    diff_modes: list of differentiation-mode strings (e.g. `["rev"]`).
    test_name: label used in generated test names; a falsy value means
      "use `name`".
    check_dtypes: stored on the record and forwarded by the tests.
    tolerance: optional tolerance spec stored on the record.
    inexact: stored on the record and forwarded by the tests.

  Returns:
    The populated `OpRecord`.
  """
  label = test_name if test_name else name
  return OpRecord(name=name, nargs=nargs, dtypes=dtypes, shapes=shapes,
                  rng_factory=rng_factory, diff_modes=diff_modes,
                  test_name=label, check_dtypes=check_dtypes,
                  tolerance=tolerance, inexact=inexact)
# Records consumed by LaxBackedNumpyTests.testOp: each entry names a function
# that is looked up on both `np` and `jnp` and compared over the listed
# dtypes and shapes.
# Positional fields: name, nargs, dtypes, shapes, rng_factory, diff_modes.
# NOTE(review): "one to one" presumably means each op maps onto a single lax
# primitive -- inferred from the name only; confirm against the implementation.
JAX_ONE_TO_ONE_OP_RECORDS = [
    op_record("abs", 1, number_dtypes + unsigned_dtypes,
              all_shapes, jtu.rand_default, ["rev"]),
    op_record("add", 2, all_dtypes, all_shapes, jtu.rand_default, ["rev"]),
    op_record("ceil", 1, float_dtypes, all_shapes, jtu.rand_default, []),
    op_record("ceil", 1, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_default, [], check_dtypes=False),
    op_record("conj", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
    op_record("equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
    op_record("exp", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
              inexact=True),
    op_record("fabs", 1, float_dtypes, all_shapes, jtu.rand_default, ["rev"]),
    op_record("float_power", 2, inexact_dtypes, all_shapes,
              partial(jtu.rand_default, scale=1), ["rev"],
              tolerance={jnp.bfloat16: 1e-2, np.float32: 1e-3,
                         np.float64: 1e-12, np.complex64: 2e-4,
                         np.complex128: 1e-12}, check_dtypes=False),
    op_record("floor", 1, float_dtypes, all_shapes, jtu.rand_default, []),
    op_record("floor", 1, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_default, [], check_dtypes=False),
    op_record("greater", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
    op_record("greater_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
    op_record("less", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
    op_record("less_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
    op_record("log", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"],
              inexact=True),
    op_record("logical_and", 2, all_dtypes, all_shapes, jtu.rand_bool, []),
    op_record("logical_not", 1, all_dtypes, all_shapes, jtu.rand_bool, []),
    op_record("logical_or", 2, all_dtypes, all_shapes, jtu.rand_bool, []),
    op_record("logical_xor", 2, all_dtypes, all_shapes, jtu.rand_bool, []),
    op_record("maximum", 2, all_dtypes, all_shapes, jtu.rand_some_inf, []),
    op_record("minimum", 2, all_dtypes, all_shapes, jtu.rand_some_inf, []),
    op_record("multiply", 2, all_dtypes, all_shapes, jtu.rand_default, ["rev"]),
    op_record("negative", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
    # bfloat16 is filtered out of the dtype list for nextafter.
    op_record("nextafter", 2, [f for f in float_dtypes if f != jnp.bfloat16],
              all_shapes, jtu.rand_default, ["rev"], inexact=True, tolerance=0),
    op_record("not_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
    op_record("array_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
    op_record("reciprocal", 1, inexact_dtypes, all_shapes, jtu.rand_default, []),
    op_record("subtract", 2, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
    op_record("signbit", 1, default_dtypes + bool_dtypes, all_shapes,
              jtu.rand_some_inf_and_nan, ["rev"]),
    op_record("trunc", 1, float_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
    op_record("trunc", 1, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_some_inf_and_nan, [], check_dtypes=False),
    op_record("sin", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
              inexact=True),
    op_record("cos", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
              inexact=True),
    # Inputs for tan are drawn from [-1.5, 1.5].
    op_record("tan", 1, number_dtypes, all_shapes,
              partial(jtu.rand_uniform, low=-1.5, high=1.5), ["rev"],
              inexact=True),
    op_record("sinh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
              inexact=True),
    op_record("cosh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
              inexact=True),
    # TODO(b/142975473): on CPU, tanh for complex128 is only accurate to
    # ~float32 precision.
    # TODO(b/143135720): on GPU, tanh has only ~float32 precision.
    op_record("tanh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
              tolerance={np.float64: 1e-7, np.complex128: 1e-7},
              inexact=True),
    op_record("arcsin", 1, number_dtypes, all_shapes, jtu.rand_small, ["rev"],
              inexact=True),
    op_record("arccos", 1, number_dtypes, all_shapes, jtu.rand_small, ["rev"],
              inexact=True),
    op_record("arctan", 1, number_dtypes, all_shapes, jtu.rand_small, ["rev"],
              inexact=True),
    op_record("arctan2", 2, float_dtypes, all_shapes, jtu.rand_small, ["rev"],
              inexact=True),
    op_record("arcsinh", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"],
              inexact=True),
    op_record("arccosh", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"],
              inexact=True),
    op_record("arctanh", 1, number_dtypes, all_shapes, jtu.rand_small, ["rev"],
              inexact=True, tolerance={np.float64: 1e-9}),
]
# Second group of records consumed by LaxBackedNumpyTests.testOp (chained after
# JAX_ONE_TO_ONE_OP_RECORDS). The list is extended further down the file when
# the installed NumPy is new enough.
# Positional fields: name, nargs, dtypes, shapes, rng_factory, diff_modes.
# Two records may share a name (e.g. "expm1", "log1p"); `test_name`
# distinguishes them in the generated test names.
JAX_COMPOUND_OP_RECORDS = [
    # angle has inconsistent 32/64-bit return types across numpy versions.
    op_record("angle", 1, number_dtypes, all_shapes, jtu.rand_default, [],
              check_dtypes=False, inexact=True),
    op_record("atleast_1d", 1, default_dtypes, all_shapes, jtu.rand_default, []),
    op_record("atleast_2d", 1, default_dtypes, all_shapes, jtu.rand_default, []),
    op_record("atleast_3d", 1, default_dtypes, all_shapes, jtu.rand_default, []),
    op_record("cbrt", 1, default_dtypes, all_shapes, jtu.rand_default, ["rev"],
              inexact=True),
    op_record("conjugate", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
    op_record("deg2rad", 1, float_dtypes, all_shapes, jtu.rand_default, []),
    op_record("divide", 2, number_dtypes, all_shapes, jtu.rand_nonzero, ["rev"],
              inexact=True),
    op_record("divmod", 2, int_dtypes + float_dtypes, all_shapes,
              jtu.rand_nonzero, []),
    op_record("exp2", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
              tolerance={jnp.bfloat16: 4e-2, np.float16: 1e-2}, inexact=True),
    # TODO(b/142975473): on CPU, expm1 for float64 is only accurate to ~float32
    # precision.
    op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_positive, [],
              test_name="expm1_large", tolerance={np.float64: 1e-8}, inexact=True),
    op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_small_positive,
              [], tolerance={np.float64: 1e-8}, inexact=True),
    op_record("fix", 1, float_dtypes, all_shapes, jtu.rand_default, []),
    op_record("fix", 1, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_default, [], check_dtypes=False),
    op_record("floor_divide", 2, number_dtypes, all_shapes,
              jtu.rand_nonzero, ["rev"]),
    op_record("floor_divide", 2, unsigned_dtypes, all_shapes,
              jtu.rand_nonzero, ["rev"]),
    op_record("fmin", 2, number_dtypes, all_shapes, jtu.rand_some_nan, []),
    op_record("fmax", 2, number_dtypes, all_shapes, jtu.rand_some_nan, []),
    op_record("fmod", 2, default_dtypes, all_shapes, jtu.rand_some_nan, []),
    op_record("heaviside", 2, default_dtypes, all_shapes, jtu.rand_default, [],
              inexact=True),
    op_record("hypot", 2, default_dtypes, all_shapes, jtu.rand_default, [],
              inexact=True),
    op_record("kron", 2, number_dtypes, nonempty_shapes, jtu.rand_default, []),
    op_record("outer", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("imag", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []),
    op_record("iscomplex", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []),
    op_record("isfinite", 1, inexact_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
    op_record("isinf", 1, inexact_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
    op_record("isnan", 1, inexact_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
    op_record("isneginf", 1, float_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
    op_record("isposinf", 1, float_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
    op_record("isreal", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []),
    op_record("isrealobj", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []),
    op_record("log2", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"],
              inexact=True),
    op_record("log10", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"],
              inexact=True),
    op_record("log1p", 1, number_dtypes, all_shapes, jtu.rand_positive, [],
              test_name="log1p_large", tolerance={np.float64: 1e-12},
              inexact=True),
    op_record("log1p", 1, number_dtypes, all_shapes, jtu.rand_small_positive, [],
              tolerance={np.float64: 1e-12}, inexact=True),
    op_record("logaddexp", 2, float_dtypes, all_shapes,
              jtu.rand_some_inf_and_nan, ["rev"],
              tolerance={np.float64: 1e-12}, inexact=True),
    op_record("logaddexp2", 2, float_dtypes, all_shapes,
              jtu.rand_some_inf_and_nan, ["rev"],
              tolerance={np.float16: 1e-2, np.float64: 2e-14}, inexact=True),
    op_record("polyval", 2, number_dtypes, nonempty_nonscalar_array_shapes,
              jtu.rand_default, [], check_dtypes=False,
              tolerance={dtypes.bfloat16: 4e-2, np.float16: 1e-2,
                         np.float64: 1e-12}),
    op_record("positive", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
    op_record("power", 2, number_dtypes, all_shapes, jtu.rand_positive, ["rev"],
              tolerance={np.complex128: 1e-14}),
    op_record("rad2deg", 1, float_dtypes, all_shapes, jtu.rand_default, []),
    op_record("ravel", 1, all_dtypes, all_shapes, jtu.rand_default, ["rev"]),
    op_record("real", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []),
    op_record("remainder", 2, default_dtypes, all_shapes, jtu.rand_nonzero, [],
              tolerance={np.float16: 1e-2}),
    op_record("mod", 2, default_dtypes, all_shapes, jtu.rand_nonzero, []),
    op_record("modf", 1, float_dtypes, all_shapes, jtu.rand_default, []),
    op_record("modf", 1, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_default, [], check_dtypes=False),
    op_record("rint", 1, inexact_dtypes, all_shapes, jtu.rand_some_inf_and_nan,
              []),
    op_record("rint", 1, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_default, [], check_dtypes=False),
    op_record("sign", 1, number_dtypes + unsigned_dtypes,
              all_shapes, jtu.rand_some_inf_and_nan, []),
    op_record('copysign', 2, default_dtypes, all_shapes, jtu.rand_some_inf_and_nan, [],
              check_dtypes=False),
    # bfloat16 is filtered out of the dtype list for sinc.
    op_record("sinc", 1, [t for t in number_dtypes if t != jnp.bfloat16],
              all_shapes, jtu.rand_default, ["rev"],
              tolerance={np.complex64: 1e-5}, inexact=True,
              check_dtypes=False),
    op_record("square", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
    op_record("sqrt", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"],
              inexact=True),
    op_record("transpose", 1, all_dtypes, all_shapes, jtu.rand_default, ["rev"],
              check_dtypes=False),
    op_record("true_divide", 2, all_dtypes, all_shapes, jtu.rand_nonzero,
              ["rev"], inexact=True),
    op_record("diff", 1, number_dtypes, nonzerodim_shapes, jtu.rand_default, ["rev"]),
    op_record("ediff1d", 3, [np.int32], all_shapes, jtu.rand_default, []),
    op_record("unwrap", 1, float_dtypes, nonempty_nonscalar_array_shapes,
              jtu.rand_default, ["rev"],
              # numpy.unwrap always returns float64
              check_dtypes=False,
              # numpy cumsum is inaccurate, see issue #3517
              tolerance={dtypes.bfloat16: 1e-1, np.float16: 1e-1})
]
# Records consumed by LaxBackedNumpyTests.testBitwiseOp. All entries use
# signed plus unsigned integer dtypes and jtu.rand_bool as the value generator.
JAX_BITWISE_OP_RECORDS = [
    op_record("bitwise_and", 2, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_bool, []),
    op_record("bitwise_not", 1, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_bool, []),
    op_record("bitwise_or", 2, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_bool, []),
    op_record("bitwise_xor", 2, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_bool, []),
]
# Records consumed by LaxBackedNumpyTests.testReducer, which calls each
# reduction with an explicit `dtype=` argument (including None) as well as
# `axis` and `keepdims`.
JAX_REDUCER_RECORDS = [
    op_record("mean", 1, number_dtypes, nonempty_shapes, jtu.rand_default, [],
              inexact=True),
    op_record("prod", 1, all_dtypes, all_shapes, jtu.rand_small_positive, []),
    op_record("sum", 1, all_dtypes, all_shapes, jtu.rand_default, []),
    op_record("nanmean", 1, inexact_dtypes, nonempty_shapes, jtu.rand_some_nan,
              [], inexact=True),
    op_record("nanprod", 1, inexact_dtypes, all_shapes, jtu.rand_some_nan, []),
    op_record("nansum", 1, number_dtypes, all_shapes, jtu.rand_some_nan, []),
]
# Records consumed by LaxBackedNumpyTests.testReducerNoDtype, which calls each
# reduction with only `axis` and `keepdims` (no `dtype=` argument). The list
# is extended further down the file when the installed NumPy is new enough.
JAX_REDUCER_NO_DTYPE_RECORDS = [
    op_record("all", 1, all_dtypes, all_shapes, jtu.rand_some_zero, []),
    op_record("any", 1, all_dtypes, all_shapes, jtu.rand_some_zero, []),
    op_record("max", 1, all_dtypes, nonempty_shapes, jtu.rand_default, []),
    op_record("min", 1, all_dtypes, nonempty_shapes, jtu.rand_default, []),
    op_record("var", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [],
              inexact=True),
    op_record("std", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [],
              inexact=True),
    op_record("nanvar", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan,
              [], inexact=True),
    op_record("nanstd", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan,
              [], inexact=True),
]
# Records for the argmin/argmax family. The test that consumes this list is
# not part of the portion of the file reviewed here.
JAX_ARGMINMAX_RECORDS = [
    op_record("argmin", 1, default_dtypes, nonempty_shapes, jtu.rand_some_equal, []),
    op_record("argmax", 1, default_dtypes, nonempty_shapes, jtu.rand_some_equal, []),
    op_record("nanargmin", 1, default_dtypes, nonempty_shapes, jtu.rand_some_nan, []),
    op_record("nanargmax", 1, default_dtypes, nonempty_shapes, jtu.rand_some_nan, []),
]
# Dunder-method records consumed by testOperatorOverload and
# testBinaryOperatorDefers. `name` is the dunder method name; the tests resolve
# it through the `operator` module. The two-argument entries also determine
# which methods are installed on _OverrideEverything / _OverrideNothing.
JAX_OPERATOR_OVERLOADS = [
    op_record("__add__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__sub__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__mul__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__eq__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__ne__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__lt__", 2, default_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__le__", 2, default_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__gt__", 2, default_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__ge__", 2, default_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__pos__", 1, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__neg__", 1, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__pow__", 2, inexact_dtypes, all_shapes, jtu.rand_positive, [],
              tolerance={np.float32: 2e-4, np.complex64: 2e-4, np.complex128: 1e-14}),
    op_record("__mod__", 2, default_dtypes, all_shapes, jtu.rand_nonzero, [],
              tolerance={np.float16: 1e-1}),
    op_record("__floordiv__", 2, default_dtypes, all_shapes,
              jtu.rand_nonzero, []),
    op_record("__truediv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, [],
              inexact=True),
    op_record("__abs__", 1, number_dtypes, all_shapes, jtu.rand_default, []),
    # TODO(mattjj): __invert__ fails on bool dtypes because ~True == -2
    op_record("__invert__", 1, int_dtypes, all_shapes, jtu.rand_default, []),
    # TODO(mattjj): investigate these failures
    # op_record("__or__", 2, number_dtypes, all_shapes, jtu.rand_bool, []),
    # op_record("__and__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    # op_record("__xor__", 2, number_dtypes, all_shapes, jtu.rand_bool, []),
    # op_record("__divmod__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []),
    # TODO(mattjj): lshift, rshift
]
# Reflected ("right-hand") dunder-method records consumed by
# testRightOperatorOverload, which invokes `getattr(snd, name)(fst)`.
# The commented-out entries are disabled in the original source.
JAX_RIGHT_OPERATOR_OVERLOADS = [
    op_record("__radd__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__rsub__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__rmul__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__rpow__", 2, inexact_dtypes, all_shapes, jtu.rand_positive, [],
              tolerance={np.float32: 2e-4, np.complex64: 1e-3}),
    op_record("__rmod__", 2, default_dtypes, all_shapes, jtu.rand_nonzero, [],
              tolerance={np.float16: 1e-1}),
    op_record("__rfloordiv__", 2, default_dtypes, all_shapes,
              jtu.rand_nonzero, []),
    op_record("__rtruediv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, [],
              inexact=True),
    # op_record("__ror__", 2, number_dtypes, all_shapes, jtu.rand_bool, []),
    # op_record("__rand__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    # op_record("__rxor__", 2, number_dtypes, all_shapes, jtu.rand_bool, []),
    # op_record("__rdivmod__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []),
]
class _OverrideEverything:
  """Operand whose binary operator methods always return the operand itself.

  The methods are attached by the loop below: one per two-argument record in
  the operator-overload tables. testBinaryOperatorDefers uses an instance to
  check that a JAX array's operators hand control to the other operand.
  """


for rec in itertools.chain(JAX_OPERATOR_OVERLOADS,
                           JAX_RIGHT_OPERATOR_OVERLOADS):
  if rec.nargs != 2:
    continue
  setattr(_OverrideEverything, rec.name, lambda self, other: self)
class _OverrideNothing:
  """Operand whose binary operator methods always return `NotImplemented`.

  The methods are attached by the loop below: one per two-argument record in
  the operator-overload tables. testBinaryOperatorDefers uses an instance to
  check what happens when neither operand implements the operation.
  """


for rec in itertools.chain(JAX_OPERATOR_OVERLOADS,
                           JAX_RIGHT_OPERATOR_OVERLOADS):
  if rec.nargs != 2:
    continue
  setattr(_OverrideNothing, rec.name, lambda self, other: NotImplemented)
numpy_version = tuple(map(int, np.version.version.split('.')))
# Extra records that are registered only when the installed NumPy is at
# least version 1.15.
if numpy_version >= (1, 15):
  JAX_COMPOUND_OP_RECORDS.extend([
      op_record("isclose", 2, [t for t in all_dtypes if t != jnp.bfloat16],
                all_shapes, jtu.rand_small_positive, []),
      # uint64 is problematic because with any other int type it promotes to
      # float.
      op_record("gcd", 2,
                [d for d in int_dtypes + unsigned_dtypes if d != np.uint64],
                all_shapes, jtu.rand_default, []),
      op_record("lcm", 2,
                [d for d in int_dtypes + unsigned_dtypes if d != np.uint64],
                all_shapes, jtu.rand_default, []),
  ])
  JAX_REDUCER_NO_DTYPE_RECORDS.append(
      op_record("ptp", 1, number_dtypes, nonempty_shapes, jtu.rand_default,
                []))
def _dtypes_are_compatible_for_bitwise_ops(args):
if len(args) <= 1:
return True
is_signed = lambda dtype: jnp.issubdtype(dtype, np.signedinteger)
width = lambda dtype: jnp.iinfo(dtype).bits
x, y = args
if width(x) > width(y):
x, y = y, x
# The following condition seems a little ad hoc, but seems to capture what
# numpy actually implements.
return (
is_signed(x) == is_signed(y)
or (width(x) == 32 and width(y) == 32)
or (width(x) == 32 and width(y) == 64 and is_signed(y)))
def _shapes_are_broadcast_compatible(shapes):
accumulator = np.zeros([])
for shape in shapes:
try:
accumulator = accumulator + np.zeros(shape)
except ValueError:
return False
return True
def _shapes_are_equal_length(shapes):
return all(len(shape) == len(shapes[0]) for shape in shapes[1:])
def _promote_like_jnp(fun, inexact=False):
"""Decorator that promotes the arguments of `fun` to `jnp.result_type(*args)`.
jnp and np have different type promotion semantics; this decorator allows
tests make an np reference implementation act more like an jnp
implementation.
"""
def wrapper(*args, **kw):
flat_args = tree_util.tree_leaves(args)
if inexact and not any(jnp.issubdtype(jnp.result_type(x), jnp.inexact)
for x in flat_args):
dtype = jnp.result_type(jnp.float_, *flat_args)
else:
dtype = jnp.result_type(*flat_args)
args = tree_util.tree_map(lambda a: np.asarray(a, dtype), args)
return fun(*args, **kw)
return wrapper
class LaxBackedNumpyTests(jtu.JaxTestCase):
"""Tests for LAX-backed Numpy implementation."""
def _GetArgsMaker(self, rng, shapes, dtypes, np_arrays=True):
def f():
out = [rng(shape, dtype or jnp.float_)
for shape, dtype in zip(shapes, dtypes)]
if np_arrays:
return out
return [jnp.asarray(a) if isinstance(a, (np.ndarray, np.generic)) else a
for a in out]
return f
def testNotImplemented(self):
  """Each name in `jnp._NOT_IMPLEMENTED` raises NotImplementedError when called.

  The attribute is looked up on `jnp` and called with no arguments.
  """
  for name in jnp._NOT_IMPLEMENTED:
    unimplemented = getattr(jnp, name)
    self.assertRaises(NotImplementedError, unimplemented)
@parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
        dict(testcase_name=jtu.format_test_name_suffix(rec.test_name, shapes,
                                                       dtypes),
             rng_factory=rec.rng_factory, shapes=shapes, dtypes=dtypes,
             np_op=getattr(np, rec.name), jnp_op=getattr(jnp, rec.name),
             check_dtypes=rec.check_dtypes, tolerance=rec.tolerance,
             inexact=rec.inexact)
        for shapes in filter(
            _shapes_are_broadcast_compatible,
            itertools.combinations_with_replacement(rec.shapes, rec.nargs))
        for dtypes in itertools.product(
            *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes)))
    for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS,
                               JAX_COMPOUND_OP_RECORDS)))
def testOp(self, np_op, jnp_op, rng_factory, shapes, dtypes, check_dtypes,
           tolerance, inexact):
  """Compare a jnp op with its np counterpart over shape/dtype combinations.

  Cases are generated per record from every broadcast-compatible combination
  (with replacement) of the record's shapes, crossed with the dtypes valid
  for each chosen shape. The NumPy op is wrapped with `_promote_like_jnp`
  before being handed to `self._CheckAgainstNumpy`; the jnp op is also passed
  to `self._CompileAndCheck`. Both helpers are inherited and not shown here.
  """
  # Suppress the RuntimeWarnings whose messages match these patterns while
  # the NumPy reference runs.
  for pattern in ("invalid value.*", "divide by zero.*"):
    np_op = jtu.ignore_warning(category=RuntimeWarning,
                               message=pattern)(np_op)

  args_maker = self._GetArgsMaker(rng_factory(self.rng()), shapes, dtypes,
                                  np_arrays=False)

  # Take the largest tolerance over the operand dtypes, then merge it with
  # the record's own spec and the default via jtu.join_tolerance.
  widest = max(jtu.tolerance(dtype, tolerance) for dtype in dtypes)
  tol = functools.reduce(jtu.join_tolerance,
                         [tolerance, widest, jtu.default_tolerance()])

  reference_op = _promote_like_jnp(np_op, inexact)
  self._CheckAgainstNumpy(reference_op, jnp_op, args_maker,
                          check_dtypes=check_dtypes, tol=tol)
  self._CompileAndCheck(jnp_op, args_maker, check_dtypes=check_dtypes,
                        atol=tol, rtol=tol)
@parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
        dict(testcase_name=jtu.format_test_name_suffix(rec.test_name, shapes,
                                                       dtypes),
             rng_factory=rec.rng_factory, shapes=shapes, dtypes=dtypes,
             name=rec.name, tol=rec.tolerance)
        for shapes in filter(
            _shapes_are_broadcast_compatible,
            itertools.combinations_with_replacement(rec.shapes, rec.nargs))
        for dtypes in itertools.product(
            *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes)))
    for rec in JAX_OPERATOR_OVERLOADS))
def testOperatorOverload(self, name, rng_factory, shapes, dtypes, tol):
  """Run each operator from JAX_OPERATOR_OVERLOADS through `_CompileAndCheck`.

  `name` is a dunder name such as "__add__"; its underscores are stripped and
  the matching function is looked up in the `operator` module.
  """
  # np and jnp arrays have different type promotion rules; force the use of
  # jnp arrays.
  args_maker = self._GetArgsMaker(rng_factory(self.rng()), shapes, dtypes,
                                  np_arrays=False)

  def fun(*xs):
    return getattr(operator, name.strip('_'))(*xs)

  self._CompileAndCheck(fun, args_maker, atol=tol, rtol=tol)
@parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
        dict(testcase_name=jtu.format_test_name_suffix(rec.test_name, shapes,
                                                       dtypes),
             rng_factory=rec.rng_factory, shapes=shapes, dtypes=dtypes,
             name=rec.name, op_tolerance=rec.tolerance)
        for shapes in filter(
            _shapes_are_broadcast_compatible,
            itertools.combinations_with_replacement(rec.shapes, rec.nargs))
        for dtypes in itertools.product(
            *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes)))
    for rec in JAX_RIGHT_OPERATOR_OVERLOADS))
def testRightOperatorOverload(self, name, rng_factory, shapes, dtypes,
                              op_tolerance):
  """Run each reflected operator from JAX_RIGHT_OPERATOR_OVERLOADS.

  The reflected dunder `name` is called on the second operand with the first
  operand as its argument. Cases whose second operand is a Python scalar are
  skipped.
  """
  if shapes[1] is jtu.PYTHON_SCALAR_SHAPE:
    raise SkipTest("scalars not implemented")  # TODO(mattjj): clean up

  args_maker = self._GetArgsMaker(rng_factory(self.rng()), shapes, dtypes,
                                  np_arrays=False)

  def fun(fst, snd):
    return getattr(snd, name)(fst)

  # Use the largest tolerance over the operand dtypes.
  tol = max(jtu.tolerance(dtype, op_tolerance) for dtype in dtypes)
  self._CompileAndCheck(fun, args_maker, atol=tol, rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": rec.test_name + "_{}".format(dtype),
     "rng_factory": rec.rng_factory,
     "op_name": rec.name, "dtype": dtype}
    for rec in JAX_OPERATOR_OVERLOADS if rec.nargs == 2
    for dtype in rec.dtypes))
def testBinaryOperatorDefers(self, op_name, rng_factory, dtype):
  """Binary operators on a device array defer to the other operand's type.

  With an `_OverrideEverything` operand (every dunder returns itself) the
  result must be that operand, whichever side it is on. With an
  `_OverrideNothing` operand (every dunder returns NotImplemented), `==` must
  give False, `!=` must give True, and every other operator must raise
  TypeError.

  unittest assertions are used instead of bare `assert` statements so that
  the checks still run under `python -O` and report the offending values on
  failure.
  """
  rng = rng_factory(self.rng())
  arg = jax.device_put(rng((), dtype))
  op = getattr(operator, op_name)

  other = _OverrideEverything()
  self.assertIs(op(other, arg), other)
  self.assertIs(op(arg, other), other)

  other = _OverrideNothing()
  if op_name == "__eq__":
    self.assertIs(op(other, arg), False)
    self.assertIs(op(arg, other), False)
  elif op_name == "__ne__":
    self.assertIs(op(other, arg), True)
    self.assertIs(op(arg, other), True)
  else:
    with self.assertRaises(TypeError):
      op(other, arg)
    with self.assertRaises(TypeError):
      op(arg, other)
@parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
        dict(testcase_name=jtu.format_test_name_suffix(
                 rec.test_name, shapes, dtypes),
             rng_factory=rec.rng_factory, shapes=shapes, dtypes=dtypes,
             np_op=getattr(np, rec.name), jnp_op=getattr(jnp, rec.name))
        for shapes in filter(
            _shapes_are_broadcast_compatible,
            itertools.combinations_with_replacement(rec.shapes, rec.nargs))
        for dtypes in filter(
            _dtypes_are_compatible_for_bitwise_ops,
            itertools.combinations_with_replacement(rec.dtypes, rec.nargs)))
    for rec in JAX_BITWISE_OP_RECORDS))
def testBitwiseOp(self, np_op, jnp_op, rng_factory, shapes, dtypes):
  """Compare jnp bitwise ops with NumPy on compatible integer dtype pairs.

  Cases that involve a 64-bit dtype are skipped when `jax_enable_x64` is off.
  Dtype checking is turned off whenever one of the operands is a Python
  scalar.
  """
  rng = rng_factory(self.rng())
  x64_disabled = not FLAGS.jax_enable_x64
  if x64_disabled and any(jnp.iinfo(dtype).bits == 64 for dtype in dtypes):
    self.skipTest("x64 types are disabled by jax_enable_x64")

  args_maker = self._GetArgsMaker(rng, shapes, dtypes)
  has_python_scalar = jtu.PYTHON_SCALAR_SHAPE in shapes
  self._CheckAgainstNumpy(np_op, jnp_op, args_maker,
                          check_dtypes=not has_python_scalar)
  self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
        dict(testcase_name="{}_inshape={}_axis={}_dtype={}_keepdims={}".format(
                 rec.test_name.capitalize(),
                 jtu.format_shape_dtype_string(shape, dtype), axis,
                 "None" if out_dtype is None else np.dtype(out_dtype).name,
                 keepdims),
             rng_factory=rec.rng_factory, shape=shape, dtype=dtype,
             out_dtype=out_dtype,
             np_op=getattr(np, rec.name), jnp_op=getattr(jnp, rec.name),
             axis=axis, keepdims=keepdims, inexact=rec.inexact)
        for shape in rec.shapes for dtype in rec.dtypes
        for out_dtype in [None] + rec.dtypes
        for axis in list(range(-len(shape), len(shape))) + [None]
        for keepdims in [False, True])
    for rec in JAX_REDUCER_RECORDS))
def testReducer(self, np_op, jnp_op, rng_factory, shape, dtype, out_dtype,
                axis, keepdims, inexact):
  """Compare reductions that take `dtype=` against their NumPy counterparts.

  Each reducer is exercised over every axis of the input (plus None), with
  and without `keepdims`, and with every candidate output dtype (plus None).
  On the NumPy side bfloat16 is replaced by float32, both for the input cast
  and for the requested output dtype; dtype checking is disabled for those
  cases.
  """
  rng = rng_factory(self.rng())

  def np_fun(x):
    if dtype != jnp.bfloat16:
      x_cast = x
    else:
      x_cast = x.astype(np.float32)
    if out_dtype != jnp.bfloat16:
      np_out_dtype = out_dtype
    else:
      np_out_dtype = np.float32
    return np_op(x_cast, axis, dtype=np_out_dtype, keepdims=keepdims)

  # Warning filters are applied innermost-first, in the same order as the
  # original decorator stack.
  np_fun = jtu.ignore_warning(category=RuntimeWarning,
                              message="overflow encountered.*")(np_fun)
  np_fun = jtu.ignore_warning(category=RuntimeWarning,
                              message="mean of empty slice.*")(np_fun)
  np_fun = jtu.ignore_warning(category=np.ComplexWarning)(np_fun)
  np_fun = _promote_like_jnp(np_fun, inexact)

  def reduce_with_jnp(x):
    return jnp_op(x, axis, dtype=out_dtype, keepdims=keepdims)
  jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(reduce_with_jnp)

  def args_maker():
    return [rng(shape, dtype)]

  tol_spec = {np.float16: 1e-2, np.int32: 1E-3, np.float32: 1e-3,
              np.complex64: 1e-3, np.float64: 1e-5, np.complex128: 1e-5}
  tol = jtu.tolerance(dtype, tol_spec)
  if out_dtype:
    # Use the larger of the input-dtype and output-dtype tolerances.
    tol = max(tol, jtu.tolerance(out_dtype, tol_spec))

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                          check_dtypes=jnp.bfloat16 not in (dtype, out_dtype),
                          tol=tol)
  self._CompileAndCheck(jnp_fun, args_maker, atol=tol,
                        rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "{}_inshape={}_axis={}_keepdims={}".format(
            rec.test_name.capitalize(),
            jtu.format_shape_dtype_string(shape, dtype),
            axis,
            keepdims),
        "rng_factory": rec.rng_factory,
        "shape": shape,
        "dtype": dtype,
        "np_op": getattr(np, rec.name),
        "jnp_op": getattr(jnp, rec.name),
        "axis": axis,
        "keepdims": keepdims,
        "inexact": rec.inexact,
    }
    for rec in JAX_REDUCER_NO_DTYPE_RECORDS
    for shape in rec.shapes
    for dtype in rec.dtypes
    for axis in list(range(-len(shape), len(shape))) + [None]
    for keepdims in [False, True]))
def testReducerNoDtype(self, np_op, jnp_op, rng_factory, shape, dtype, axis,
                       keepdims, inexact):
  # Checks a reducer that takes no `dtype=` argument by comparing the NumPy
  # op with the jax.numpy op of the same name.
  rng = rng_factory(self.rng())
  # bfloat16 inputs drawn from the NaN-producing generator are reduced in
  # float32 on the NumPy side and the result is cast back to bfloat16.
  via_float32 = (dtype == jnp.bfloat16
                 and rng_factory.__name__ == 'rand_some_nan')

  def np_reduce(x):
    if not via_float32:
      return np_op(x, axis, keepdims=keepdims)
    reduced = np_op(x.astype(np.float32), axis, keepdims=keepdims)
    return reduced.astype(jnp.bfloat16)

  np_reduce = jtu.ignore_warning(
      category=RuntimeWarning,
      message="Degrees of freedom <= 0 for slice.*")(np_reduce)
  np_fun = jtu.ignore_warning(category=np.ComplexWarning)(
      _promote_like_jnp(np_reduce, inexact))

  @jtu.ignore_warning(category=jnp.ComplexWarning)
  def jnp_fun(x):
    return jnp_op(x, axis, keepdims=keepdims)

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_shape={}_axis={}".format(
            jtu.format_shape_dtype_string(shape, dtype), axis),
        "shape": shape,
        "dtype": dtype,
        "axis": axis,
    }
    for shape in all_shapes
    for dtype in all_dtypes
    for axis in list(range(-len(shape), len(shape))) + [None]))
def testCountNonzero(self, shape, dtype, axis):
  # Compares jnp.count_nonzero with np.count_nonzero along `axis` on inputs
  # produced by jtu.rand_some_zero; result dtypes are not compared.
  rng = jtu.rand_some_zero(self.rng())

  def np_fun(x):
    return np.count_nonzero(x, axis)

  def jnp_fun(x):
    return jnp.count_nonzero(x, axis)

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_shape={}".format(
            jtu.format_shape_dtype_string(shape, dtype)),
        "shape": shape,
        "dtype": dtype,
    }
    for shape in all_shapes
    for dtype in all_dtypes))
def testNonzero(self, shape, dtype):
  # Compares jnp.nonzero with np.nonzero on inputs produced by
  # jtu.rand_some_zero; result dtypes are not compared.
  rng = jtu.rand_some_zero(self.rng())

  # NumPy's DeprecationWarning about 0-d inputs is silenced on the
  # reference side.
  @jtu.ignore_warning(category=DeprecationWarning,
                      message="Calling nonzero on 0d arrays.*")
  def np_fun(x):
    return np.nonzero(x)

  def jnp_fun(x):
    return jnp.nonzero(x)

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_shape={}".format(
            jtu.format_shape_dtype_string(shape, dtype)),
        "shape": shape,
        "dtype": dtype,
    }
    for shape in all_shapes
    for dtype in all_dtypes))
def testFlatNonzero(self, shape, dtype):
  # Compares jnp.flatnonzero with np.flatnonzero on inputs produced by
  # jtu.rand_some_zero; result dtypes are not compared.
  rng = jtu.rand_some_zero(self.rng())

  # NumPy's DeprecationWarning about 0-d inputs is silenced on the
  # reference side.
  silence_0d_warning = jtu.ignore_warning(
      category=DeprecationWarning,
      message="Calling nonzero on 0d arrays.*")
  np_fun = silence_0d_warning(np.flatnonzero)
  jnp_fun = jnp.flatnonzero

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}".format(
        jtu.format_shape_dtype_string(shape, dtype)),
     "shape": shape, "dtype": dtype}
    for shape in all_shapes for dtype in all_dtypes))
def testArgWhere(self, shape, dtype):
  # Compares jnp.argwhere with np.argwhere on inputs produced by
  # jtu.rand_some_zero; result dtypes are not compared. Scalar-shaped cases
  # are skipped on NumPy releases older than 1.18.
  rng = jtu.rand_some_zero(self.rng())
  np_fun = jtu.ignore_warning(
      category=DeprecationWarning,
      message="Calling nonzero on 0d arrays.*")(np.argwhere)
  jnp_fun = jnp.argwhere
  args_maker = lambda: [rng(shape, dtype)]
  # Compare (major, minor) numerically. Comparing the raw version strings is
  # lexicographic, so e.g. "1.9" would sort *after* "1.18" and the skip would
  # be missed.
  numpy_version = tuple(int(part) for part in np.__version__.split(".")[:2])
  if shape in (scalar_shapes + [()]) and numpy_version < (1, 18):
    self.skipTest("np.argwhere() result for scalar input changed in numpy 1.18.")
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "{}_inshape={}_axis={}".format(
            rec.test_name.capitalize(),
            jtu.format_shape_dtype_string(shape, dtype),
            axis),
        "rng_factory": rec.rng_factory,
        "shape": shape,
        "dtype": dtype,
        "np_op": getattr(np, rec.name),
        "jnp_op": getattr(jnp, rec.name),
        "axis": axis,
    }
    for rec in JAX_ARGMINMAX_RECORDS
    for shape, dtype in _shape_and_dtypes(rec.shapes, rec.dtypes)
    for axis in range(-len(shape), len(shape))))
def testArgMinMax(self, np_op, jnp_op, rng_factory, shape, dtype, axis):
  # Compares an arg-min/arg-max style op between NumPy and jax.numpy along
  # `axis`. The NumPy result is cast to jnp.int_ before comparison.
  rng = rng_factory(self.rng())
  if dtype == np.complex128 and jtu.device_under_test() == "gpu":
    raise unittest.SkipTest("complex128 reductions not supported on GPU")
  if "nan" in np_op.__name__ and dtype == jnp.bfloat16:
    raise unittest.SkipTest("NumPy doesn't correctly handle bfloat16 arrays")

  np_fun = lambda operand: np_op(operand, axis).astype(jnp.int_)
  jnp_fun = lambda operand: jnp_op(operand, axis)

  def args_maker():
    return [rng(shape, dtype)]

  try:
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  except ValueError as err:
    # Only the all-NaN ValueError is turned into a skip; any other
    # ValueError propagates unchanged.
    if str(err) != "All-NaN slice encountered":
      raise
    self.skipTest("JAX doesn't support checking for all-NaN slices")
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": rec.test_name.capitalize(), "name": rec.name,
     "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name)}
    for rec in JAX_ARGMINMAX_RECORDS))
def testArgMinMaxEmpty(self, name, np_op, jnp_op):
  # Checks behaviour on empty inputs: reducing over a zero-length axis must
  # raise ValueError with the expected message, while reducing a (2, 0) array
  # over axis 0 is compared against NumPy.
  # The expected message uses the op name with any "nan" prefix stripped.
  name = name[3:] if name.startswith("nan") else name
  msg = "attempt to get {} of an empty sequence".format(name)
  # assertRaisesRegex actually matches the exception text. With assertRaises,
  # `msg=` is only the message shown when the assertion fails, so the error
  # text was never verified.
  with self.assertRaisesRegex(ValueError, msg):
    jnp_op(np.array([]))
  with self.assertRaisesRegex(ValueError, msg):
    jnp_op(np.zeros((2, 0)), axis=1)
  np_fun = partial(np_op, axis=0)
  jnp_fun = partial(jnp_op, axis=0)
  args_maker = lambda: [np.zeros((2, 0))]
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_{}_{}_{}".format(
            jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
            jtu.format_shape_dtype_string(rhs_shape, rhs_dtype),
            axes),
        "lhs_shape": lhs_shape,
        "lhs_dtype": lhs_dtype,
        "rhs_shape": rhs_shape,
        "rhs_dtype": rhs_dtype,
        "axes": axes,
        "rng_factory": rng_factory,
    }
    for rng_factory in [jtu.rand_default]
    for lhs_shape, rhs_shape, axes in [
        [(2,), (2,), (-1, -1, -1, None)],  # scalar output
        [(2, 4), (2, 4), (-1, -1, -1, 0)],  # 2D vectors
        [(3, 4), (3, 4), (-1, -1, -1, 0)],  # 3D vectors
        [(3, 4), (3, 6, 5, 4), (-1, -1, -1, 0)],  # broadcasting
        [(4, 3), (3, 6, 5, 4), (1, 0, -1, None)],  # different axes
        [(6, 1, 3), (5, 3), (-1, -1, -1, None)],  # more broadcasting
        [(6, 1, 2), (5, 3), (-1, -1, -1, None)],  # mixed 2D and 3D vectors
        [(10, 5, 2, 8), (1, 5, 1, 3), (-2, -1, -3, None)],  # axes/broadcasting
        [(4, 5, 2), (4, 5, 2), (-1, -1, 0, None)],  # axisc should do nothing
        [(4, 5, 2), (4, 5, 2), (-1, -1, -1, None)],  # same as before
    ]
    for lhs_dtype, rhs_dtype in itertools.combinations_with_replacement(
        number_dtypes, 2)))
def testCross(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes, rng_factory):
  # Compares jnp.cross with np.cross for the given operand shapes, dtypes and
  # (axisa, axisb, axisc, axis) tuple.
  rng = rng_factory(self.rng())
  axisa, axisb, axisc, axis = axes

  def args_maker():
    return [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]

  def jnp_fun(a, b):
    return jnp.cross(a, b, axisa, axisb, axisc, axis)

  def np_fun(a, b):
    # bfloat16 operands are converted to float32 for the NumPy call; the
    # result is then cast to the dtype jnp.promote_types gives for the pair.
    if lhs_dtype == jnp.bfloat16:
      a = a.astype(np.float32)
    if rhs_dtype == jnp.bfloat16:
      b = b.astype(np.float32)
    expected_dtype = jnp.promote_types(lhs_dtype, rhs_dtype)
    return np.cross(a, b, axisa, axisb, axisc, axis).astype(expected_dtype)

  tol_spec = {dtypes.bfloat16: 3e-1, np.float16: 0.15}
  # Use the looser of the two operand tolerances.
  lhs_tol = jtu.tolerance(lhs_dtype, tol_spec)
  rhs_tol = jtu.tolerance(rhs_dtype, tol_spec)
  tol = max(lhs_tol, rhs_tol)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol)
  self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_{}_{}_{}".format(
            name,
            jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
            jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
        "lhs_shape": lhs_shape,
        "lhs_dtype": lhs_dtype,
        "rhs_shape": rhs_shape,
        "rhs_dtype": rhs_dtype,
        "rng_factory": rng_factory,
    }
    for rng_factory in [jtu.rand_default]
    for name, lhs_shape, rhs_shape in [
        ("matrix-scalar", (3, 3), ()),
        ("scalar-matrix", (), (3, 3)),
        ("matrix-vector", (4, 5), (5,)),
        ("vector-matrix", (6,), (6, 4)),
        ("matrix-matrix", (3, 4), (4, 5)),
        ("tensor-vector", (4, 3, 2), (2,)),
        ("vector-tensor", (2,), (3, 2, 4)),
        ("tensor-matrix", (4, 3, 2), (2, 5)),
        ("matrix-tensor", (5, 2), (3, 2, 4)),
        ("tensor-tensor", (2, 3, 4), (5, 4, 1)),
    ]
    for lhs_dtype, rhs_dtype in itertools.combinations_with_replacement(
        number_dtypes, 2)))
def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng_factory):
  # Compares jnp.dot with np.dot for the given operand shapes and dtypes.
  rng = rng_factory(self.rng())

  def args_maker():
    return [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]

  tol = {
      np.float16: 1e-2,
      np.float32: 1e-5,
      np.float64: 1e-14,
      np.complex128: 1e-14,
  }
  if jtu.device_under_test() == "tpu":
    # Single-precision tolerances are loosened when running on TPU.
    tol.update({np.float32: 2e-1, np.complex64: 2e-1})

  def np_dot(x, y):
    # bfloat16 operands are converted to float32 for the NumPy call; the
    # result is then cast to the dtype jnp.promote_types gives for the pair.
    if lhs_dtype == jnp.bfloat16:
      x = x.astype(np.float32)
    if rhs_dtype == jnp.bfloat16:
      y = y.astype(np.float32)
    expected_dtype = jnp.promote_types(lhs_dtype, rhs_dtype)
    return np.dot(x, y).astype(expected_dtype)

  self._CheckAgainstNumpy(np_dot, jnp.dot, args_maker, tol=tol)
  self._CompileAndCheck(jnp.dot, args_maker, atol=tol, rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_{}_{}_{}".format(
            name,
            jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
            jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
        "lhs_shape": lhs_shape,
        "lhs_dtype": lhs_dtype,
        "rhs_shape": rhs_shape,
        "rhs_dtype": rhs_dtype,
        "rng_factory": rng_factory,
    }
    for rng_factory in [jtu.rand_default]
    for name, lhs_shape, rhs_shape in [
        ("vector-vector", (3,), (3,)),
        ("matrix-vector", (3, 3), (3,)),
        ("vector-matrix", (3,), (3, 3)),
        ("matrix-matrix", (3, 3), (3, 3)),
        ("vector-tensor", (3,), (5, 3, 2)),
        ("tensor-vector", (5, 3, 2), (2,)),
        ("matrix-tensor", (5, 2), (3, 2, 4)),
        ("tensor-matrix", (5, 2, 3), (3, 2)),
        ("tensor-tensor", (5, 3, 4), (5, 4, 1)),
        ("tensor-tensor-broadcast", (3, 1, 3, 4), (5, 4, 1)),
    ]
    for lhs_dtype, rhs_dtype in itertools.combinations_with_replacement(
        number_dtypes, 2)))
def testMatmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng_factory):
  # Compares jnp.matmul with np.matmul for the given operand shapes and
  # dtypes.
  rng = rng_factory(self.rng())

  def np_fun(x, y):
    # The NumPy result is cast to the dtype jnp.promote_types gives for the
    # operand pair.
    expected_dtype = jnp.promote_types(lhs_dtype, rhs_dtype)
    product = np.matmul(x, y)
    return product.astype(expected_dtype)

  def args_maker():
    return [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]

  tol = {
      np.float16: 1e-2,
      np.float32: 2e-2,
      np.float64: 1e-12,
      np.complex128: 1e-12,
  }
  if jtu.device_under_test() == "tpu":
    # Single-precision tolerances are loosened when running on TPU.
    tol.update({np.float32: 4e-2, np.complex64: 4e-2})

  self._CheckAgainstNumpy(np_fun, jnp.matmul, args_maker, tol=tol)
  self._CompileAndCheck(jnp.matmul, args_maker, atol=tol, rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_{}_{}_{}".format(
            jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
            jtu.format_shape_dtype_string(rhs_shape, rhs_dtype),
            axes),
        "lhs_shape": lhs_shape,
        "lhs_dtype": lhs_dtype,
        "rhs_shape": rhs_shape,
        "rhs_dtype": rhs_dtype,
        "axes": axes,
        "rng_factory": rng_factory,
    }
    for rng_factory in [jtu.rand_default]
    for lhs_shape, rhs_shape, axes in [
        [(3,), (), 0],
        [(2, 3, 4), (5, 6, 7), 0],  # from issue #740
        [(2, 3, 4), (3, 4, 5, 6), 2],
        [(2, 3, 4), (5, 4, 3, 6), [1, 2]],
        [(2, 3, 4), (5, 4, 3, 6), [[1, 2], [2, 1]]],
        [(1, 2, 3, 4), (4, 5, 3, 6), [[2, 3], [2, 0]]],
    ]
    for lhs_dtype, rhs_dtype in itertools.combinations_with_replacement(
        number_dtypes, 2)))
def testTensordot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes, rng_factory):
  # Compares jnp.tensordot with np.tensordot for the given operand shapes,
  # dtypes and `axes` specification.
  rng = rng_factory(self.rng())

  def args_maker():
    return [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]

  def jnp_fun(a, b):
    return jnp.tensordot(a, b, axes)

  def np_fun(a, b):
    # bfloat16 operands are converted to float32 for the NumPy call; the
    # result is then cast to the dtype jnp.promote_types gives for the pair.
    if lhs_dtype == jnp.bfloat16:
      a = a.astype(np.float32)
    if rhs_dtype == jnp.bfloat16:
      b = b.astype(np.float32)
    expected_dtype = jnp.promote_types(lhs_dtype, rhs_dtype)
    return np.tensordot(a, b, axes).astype(expected_dtype)

  tol = {
      np.float16: 1e-1,
      np.float32: 1e-3,
      np.float64: 1e-12,
      np.complex64: 1e-3,
      np.complex128: 1e-12,
  }
  if jtu.device_under_test() == "tpu":
    # Single-precision tolerances are loosened when running on TPU.
    tol.update({np.float32: 2e-1, np.complex64: 2e-1})

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol)
  self._CompileAndCheck(jnp_fun, args_maker)
def testTensordotErrors(self):
  # Verifies that jnp.tensordot raises TypeError, with a message matching
  # the listed pattern, for each invalid `axes` argument.
  a = np.random.random((3, 2, 2))
  b = np.random.random((2,))
  invalid_cases = (
      ("Number of tensordot axes.*exceeds input ranks.*", 2),
      ("tensordot requires axes lists to have equal length.*",
       ([0], [0, 1])),
      ("tensordot requires both axes lists to be either ints, tuples or lists.*",
       ('bad', 'axes')),
      ("tensordot axes argument must be an int, a pair of ints, or a pair of lists.*",
       'badaxes'),
  )
  for pattern, bad_axes in invalid_cases:
    with self.assertRaisesRegex(TypeError, pattern):
      jnp.tensordot(a, b, axes=bad_axes)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_{}_{}_invert={}".format(
            jtu.format_shape_dtype_string(element_shape, dtype),
            jtu.format_shape_dtype_string(test_shape, dtype),
            invert),
        "element_shape": element_shape,
        "test_shape": test_shape,
        "dtype": dtype,
        "invert": invert,
    }
    for element_shape in all_shapes
    for test_shape in all_shapes
    for dtype in default_dtypes
    for invert in [True, False]))
def testIsin(self, element_shape, test_shape, dtype, invert):
  # Compares jnp.isin with np.isin for the given element/test shapes, a
  # shared dtype, and the `invert` flag.
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(element_shape, dtype), rng(test_shape, dtype)]

  def jnp_fun(elements, test_elements):
    return jnp.isin(elements, test_elements, invert=invert)

  def np_fun(elements, test_elements):
    return np.isin(elements, test_elements, invert=invert)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_{}_{}_invert={}".format(
            jtu.format_shape_dtype_string(element_shape, dtype),
            jtu.format_shape_dtype_string(test_shape, dtype),
            invert),
        "element_shape": element_shape,
        "test_shape": test_shape,
        "dtype": dtype,
        "invert": invert,
    }
    for element_shape in all_shapes
    for test_shape in all_shapes
    for dtype in default_dtypes
    for invert in [True, False]))
def testIn1d(self, element_shape, test_shape, dtype, invert):
  # Compares jnp.in1d with np.in1d for the given element/test shapes, a
  # shared dtype, and the `invert` flag.
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(element_shape, dtype), rng(test_shape, dtype)]

  def jnp_fun(elements, test_elements):
    return jnp.in1d(elements, test_elements, invert=invert)

  def np_fun(elements, test_elements):
    return np.in1d(elements, test_elements, invert=invert)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_{}_{}_assume_unique={}_return_indices={}".format(
            jtu.format_shape_dtype_string(shape1, dtype1),
            jtu.format_shape_dtype_string(shape2, dtype2),
            assume_unique,
            return_indices),
        "shape1": shape1,
        "dtype1": dtype1,
        "shape2": shape2,
        "dtype2": dtype2,
        "assume_unique": assume_unique,
        "return_indices": return_indices,
    }
    for dtype1 in [s for s in default_dtypes if s != jnp.bfloat16]
    for dtype2 in [s for s in default_dtypes if s != jnp.bfloat16]
    for shape1 in all_shapes
    for shape2 in all_shapes
    for assume_unique in [False, True]
    for return_indices in [False, True]))
def testIntersect1d(self, shape1, dtype1, shape2, dtype2, assume_unique, return_indices):
  # Compares jnp.intersect1d with np.intersect1d for every combination of
  # the `assume_unique` and `return_indices` flags; result dtypes are not
  # compared.
  rng = jtu.rand_default(self.rng())
  options = {"assume_unique": assume_unique, "return_indices": return_indices}

  def args_maker():
    return [rng(shape1, dtype1), rng(shape2, dtype2)]

  def jnp_fun(ar1, ar2):
    return jnp.intersect1d(ar1, ar2, **options)

  def np_fun(ar1, ar2):
    return np.intersect1d(ar1, ar2, **options)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)