/
api_test.py
3756 lines (3060 loc) · 114 KB
/
api_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from contextlib import contextmanager
import copy
from functools import partial
import re
import unittest
import types
import warnings
import weakref
import functools
from absl import logging
from absl.testing import absltest, parameterized
import numpy as np
import concurrent.futures
import jax
import jax.numpy as jnp
from jax import jit, grad, device_put, jacfwd, jacrev, hessian
from jax import api, core, lax, lax_reference
from jax.core import Primitive
from jax.interpreters import ad
from jax.interpreters import xla
from jax.interpreters.sharded_jit import PartitionSpec as P
from jax.lib import xla_bridge as xb
from jax import test_util as jtu
from jax import tree_util
from jax import linear_util as lu
from jax.lib import version
from jax.config import config
# Hook JAX's configuration flags into absl flag parsing (per the helper's
# name) so they can be set from the test command line.
config.parse_flags_with_absl()
# Module-level alias; tests in this file read flag values through it, e.g.
# FLAGS.jax_enable_x64.
FLAGS = config.FLAGS
class CPPJitTest(jtu.JaxTestCase):
  """Shared tests between the Python and the C++ jax.jit implementations.

  Because the Python implementation supports more features, we need to have the
  Python tests that extend the C++ tests (and not the other way around).

  Every test reaches the implementation under test through the `jit` property,
  so a subclass only has to override that property.
  """

  @property
  def jit(self):
    """Returns `jax.api._cpp_jit`, or skips the test when jaxlib is too old."""
    # Right now, the CPP tests also test the Python code-path when jaxlib is
    # too old.
    # TODO(jblespiau,phawkins): Remove this when jaxlib has been released.
    # This is in the future, because we are making a breaking change to
    # Tensorflow.
    if version < (0, 1, 56):
      raise unittest.SkipTest("Disabled because it depends on some future "
                              "release of jax_jit.cc within jaxlib.")
    else:
      return jax.api._cpp_jit

  def test_jit_of_noncallable(self):
    """jit of a non-callable value raises a TypeError."""
    self.assertRaisesRegex(TypeError, "Expected a callable value.*",
                           lambda: self.jit(3))

  def test_jit_of_generator(self):
    """jit of a generator function raises a TypeError."""

    def gen(x):
      yield x

    self.assertRaisesRegex(TypeError,
                           "Expected a function, got a generator function.*",
                           lambda: self.jit(gen))

  @parameterized.parameters([
      # Integer support
      (1, 2, 3, 4, 5),
      # Numpy array support
      (
          np.asarray(1, np.int32),
          np.asarray(2, np.int32),
          np.asarray(3, np.int32),
          np.asarray(4, np.int32),
          np.asarray(5, np.int32),
      ),
  ])
  def test_jit_static_args(self, one, two, three, four, five):
    """The traced function re-runs only when a static argument value changes.

    `side` grows by one entry each time the Python body of `f` executes, so
    its length counts the number of traces.
    """
    side = []
    # For the CPP jit, we need to clear the cache to prevent cache hits between
    # parameterized tests.
    if hasattr(self.jit, "cache_clear"):
      self.jit.cache_clear()

    def f(x, y, z, flag=False, flag2=False):
      del flag2  # unused
      assert flag
      side.append(None)
      return 100 * x + 10 * y + z

    f1 = self.jit(f, static_argnums=(3, 4))
    self.assertEqual(f1(one, two, three, True, False), 123)
    self.assertLen(side, 1)
    self.assertEqual(f1(one, two, three, True, False), 123)
    self.assertLen(side, 1)  # Obvious cache hit.
    self.assertEqual(f1(two, one, three, True, False), 213)
    self.assertLen(side, 1)  # Should cache hit because same signature.

    self.assertEqual(f1(two, one, three, True, True), 213)
    self.assertLen(side, 2)

    side[:] = []
    f2 = self.jit(f, static_argnums=(0, 2, 3, 4))
    self.assertEqual(f2(one, two, three, True, False), 123)
    self.assertLen(side, 1)
    self.assertEqual(f2(one, three, three, True, False), 133)
    self.assertLen(side, 1)
    self.assertEqual(f2(two, two, three, True, False), 223)
    self.assertLen(side, 2)
    self.assertEqual(f2(two, four, three, True, False), 243)
    self.assertLen(side, 2)
    self.assertEqual(f2(two, four, three, True, True), 243)
    self.assertLen(side, 3)
    self.assertEqual(f2(two, five, three, True, True), 253)
    self.assertLen(side, 3)

  @parameterized.parameters([
      (1, 2, 3),
      (
          np.asarray(1, np.int32),
          np.asarray(2, np.int32),
          np.asarray(3, np.int32),
      ),
  ])
  def test_jit_kwargs(self, one, two, three):
    """Keyword arguments are accepted and get their own cache entry."""
    side = []
    # For the CPP jit, we need to clear the cache to prevent cache hits between
    # parameterized tests.
    if hasattr(self.jit, "cache_clear"):
      self.jit.cache_clear()

    def f(x, y, z):
      side.append(None)
      return 100 * x + 10 * y + z

    f = self.jit(f)
    self.assertEqual(f(one, two, three), 123)
    self.assertLen(side, 1)
    self.assertEqual(f(one, two, three), 123)
    self.assertLen(side, 1)

    self.assertEqual(f(one, two, z=three), 123)
    self.assertLen(side, 2)  # actually recompiles from kwarg
    self.assertEqual(f(one, two, z=three), 123)
    self.assertLen(side, 2)  # but should still cache

    f(one, two, z=np.zeros(3))  # doesn't crash
    if FLAGS.jax_enable_x64:
      # In the call above, z is a float array rather than an integer scalar,
      # so it should trigger a new compilation.
      # NOTE(review): why this is only asserted under x64 is not visible here.
      self.assertLen(side, 3)

  def test_jit_device(self):
    """The `device` argument places the result on the requested device."""
    device = xb.devices()[-1]
    x = self.jit(lambda x: x, device=device)(3.)
    self.assertIsInstance(x, xla.DeviceArray)
    self.assertEqual(x.device_buffer.device(), device)

  def test_complex_support(self):
    """Python complex scalars are valid jit arguments."""
    self.assertEqual(self.jit(lambda x: x + 1)(1 + 1j), 2 + 1j)

  def test_jit_with_many_args_works(self):
    """A 500-element list argument is handled."""

    @self.jit
    def f(args_list):
      return sum(args_list)

    self.assertEqual(f(list(range(500))), sum(range(500)))

  # Jit and Donate arguments

  def assertDeleted(self, x):
    """Asserts that every device buffer backing `x` has been deleted."""
    return self._assertDeleted(x, True)

  def assertNotDeleted(self, x):
    """Asserts that no device buffer backing `x` has been deleted."""
    return self._assertDeleted(x, False)

  def _assertDeleted(self, x, deleted):
    """Checks `is_deleted()` of x's buffer(s) against `deleted`.

    Handles both objects exposing a single `device_buffer` and objects
    exposing an iterable `device_buffers`.
    """
    if hasattr(x, "device_buffer"):
      self.assertEqual(x.device_buffer.is_deleted(), deleted)
    else:
      for buffer in x.device_buffers:
        self.assertEqual(buffer.is_deleted(), deleted)

  def test_jit_donate_argnums_warning_raised(self):
    """Donated buffers that cannot be reused produce a single UserWarning."""
    x = jnp.array([1.0, 2.0], jnp.float32)
    y = jnp.array([1, 2], jnp.int32)
    f = self.jit(lambda x, y: x.sum() + y.sum(), donate_argnums=(0, 1))
    with warnings.catch_warnings(record=True) as w:
      warnings.simplefilter("always")
      f(x, y)

      self.assertLen(w, 1)
      self.assertTrue(issubclass(w[-1].category, UserWarning))
      self.assertIn(
          "Some donated buffers were not usable: f32[2]{0}, s32[2]{0}",
          str(w[-1].message))

  @jtu.skip_on_devices("cpu")  # In/out aliasing not supported on CPU.
  def test_jit_donate_argnums_invalidates_input(self):
    """A donated input buffer is deleted after the call."""
    # We can't just use `lambda x: x` because JAX simplifies this away to an
    # empty XLA computation.
    move = self.jit(lambda x: x + x - x, donate_argnums=0)
    x = jnp.ones([])
    y = move(x)
    self.assertDeleted(x)
    self.assertEqual(y, 1.)

  @jtu.skip_on_devices("cpu")  # In/out aliasing not supported on CPU.
  def test_jit_donate_argnums_static_argnums(self):
    """Donation and static arguments can be combined in one jit call."""
    jit_fun = self.jit(
        lambda a, b, c, d: ((a + b + c), (a + b + d)),
        static_argnums=(0, 1),
        donate_argnums=(2, 3))

    a = jnp.array(1)
    b = jnp.array(2)
    c = jax.device_put(jnp.array([1., 1.]))
    d = jax.device_put(jnp.array([1., 1., 1.]))
    e, f = jit_fun(a, b, c, d)
    np.testing.assert_allclose(e, jnp.array([4., 4.]))
    np.testing.assert_allclose(f, jnp.array([4., 4., 4.]))
    # Only the donated (non-static) arguments lose their buffers.
    self.assertNotDeleted(a)
    self.assertNotDeleted(b)
    self.assertDeleted(c)
    self.assertDeleted(d)

  @jtu.skip_on_devices("cpu")  # In/out aliasing not supported on CPU.
  def test_jnp_array_copy(self):
    """A copy made with `copy=True` stays readable after the source is donated."""
    # https://github.com/google/jax/issues/3412

    @partial(self.jit, donate_argnums=(0,))
    def _test(array):
      return array.at[0].set(77)

    x = jnp.asarray([0, 1])
    x_copy = jnp.array(x, copy=True)
    with warnings.catch_warnings():
      warnings.simplefilter("ignore")
      _test(x)  # donation

    # Gives: RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer.
    print(x_copy)  # doesn't crash

  def test_jit_global_cache(self):
    """Two separate jit wrappers of the same function share one trace."""

    def f(x):
      assert python_should_be_executing
      return x

    python_should_be_executing = True
    self.jit(f)(2)
    python_should_be_executing = False
    self.jit(f)(3)

  def test_jit_shallow_copy(self):
    """`copy.copy` can be applied to a traced value."""

    def f(x):
      return copy.copy(x)

    self.jit(f)(1)

  def test_jit_deep_copy(self):
    """`copy.deepcopy` can be applied to a traced value."""

    def f(x):
      return copy.deepcopy(x)

    self.jit(f)(1)

  def test_disable_jit(self):
    """Under `disable_jit` the Python body runs on every call."""
    effects = []

    @self.jit
    def f(x):
      effects.append(1)
      return x

    with api.disable_jit():
      f(2)
      f(2)
    self.assertLen(effects, 2)

    # With jit active again, the two calls add only one more execution.
    f(2)
    f(2)
    self.assertLen(effects, 3)

  def test_static_argnum_on_method(self):
    """jit can decorate a method when `self` is marked static."""

    class A:

      @functools.partial(self.jit, static_argnums=(0,))
      def my_func_jit(self, x):
        return x+2

    A().my_func_jit(3)

  def test_static_argnum_on_static_method_is_not_supported(self):
    """Decorating a `classmethod` object with jit raises a TypeError."""
    # NOTE(review): despite the test name, the decorated object here is a
    # classmethod; the name is kept so existing test IDs stay stable.
    with self.assertRaisesRegex(TypeError, "Expected a callable value"):

      class A:

        @functools.partial(self.jit, static_argnums=(0,))
        @classmethod
        def my_classmethod_jit(cls, x):
          return x+2

  def test_classmethod_is_not_supported(self):
    """Decorating a `staticmethod` object with jit raises a TypeError."""
    # NOTE(review): despite the test name, the decorated object here is a
    # staticmethod; the name is kept so existing test IDs stay stable.
    with self.assertRaisesRegex(TypeError, "Expected a callable value"):

      class A:

        @functools.partial(self.jit)
        @staticmethod
        def my_staticmethod_jit(x):
          return x + 2

  def test_concurrent_jit(self):
    """A jitted function can be called from several threads at once."""

    @self.jit
    def f(x):
      return x + x - 3.

    xs = [np.random.randn(i) for i in range(10)]
    with concurrent.futures.ThreadPoolExecutor() as executor:
      futures = [executor.submit(partial(f, x)) for x in xs]
      # `future` (not `f`) so the jitted function above is not shadowed.
      ys = [future.result() for future in futures]
    for x, y in zip(xs, ys):
      self.assertAllClose(x * 2 - 3., y)
class PythonJitTest(CPPJitTest):
  """Runs the shared jit tests, plus extra ones, against the Python jit."""

  @property
  def jit(self):
    """Returns the Python jit implementation, `jax.api._python_jit`."""
    return jax.api._python_jit

  def test_jit_reference_dropping(self):
    """The jitted wrapper keeps closed-over values alive, and only it does."""
    x = jnp.ones(10)
    f = (lambda x: lambda: x)(x)  # reference to x in f's closure
    g = self.jit(f)
    x = weakref.ref(x)      # no more strong ref to x in this scope
    assert x() is not None  # x is still around
    f()                     # f runs
    g()                     # g runs
    g()                     # g runs a second time
    del f                   # delete the raw callable
    assert x() is not None  # x is still around
    g()                     # g still runs
    del g                   # no more references to x
    assert x() is None      # x is gone

  def test_trivial_computations(self):
    """Outputs that are just forwarded inputs are returned as the same object."""
    x = jnp.array([1, 2, 3])
    y = self.jit(lambda x: x)(x)
    self.assertIs(x, y)

    z1, z2 = self.jit(lambda x: (x, x))(x)
    self.assertIs(z1, z2)

    x1, x2 = jnp.array([1, 2]), jnp.array([2, 3])
    z1, z2, z3 = self.jit(lambda x, y: (y, 1, x))(x1, x2)
    self.assertIs(z1, x2)
    self.assertIs(z3, x1)
    self.assertEqual(z2, 1)

  def test_jit_bad_input(self):
    """A string argument is rejected with a TypeError naming the value."""

    def f(x):
      return x

    # Use self.jit (not the module-level jit) so the implementation this
    # class targets is the one being exercised.
    self.assertRaisesRegex(
        TypeError, ".* 'foo' of type <.*'str'> is not a valid JAX type",
        lambda: self.jit(f)("foo"))

  def test_jit_on_all_devices(self):
    """One jitted function runs on inputs placed on each local device."""
    # Verifies we can run the same computation on every device present, even
    # if they are, for example, different models of GPU.
    data = np.random.rand(1000).astype(np.float32)
    f = self.jit(jnp.negative)
    for device in jax.local_devices():
      x = device_put(data, device=device)
      np.testing.assert_array_equal(-data, f(x))

  def test_jit_nested_donate_ignored(self):
    """`donate_argnums` on an inner jit does not raise."""
    jit_fun = self.jit(lambda x: self.jit(lambda y: y**2, donate_argnums=0)(x))
    a = jax.device_put(jnp.array(1))

    # NOTE(mattjj): stopped raising error here and instead just ignored
    # with self.assertRaisesRegex(ValueError, "nested.*not supported"):
    #   jit_fun(a)

    jit_fun(a)  # doesn't crash
class APITest(jtu.JaxTestCase):
def test_grad_bad_input(self):
  """grad rejects a string argument with a TypeError naming the value."""
  def f(x):
    return x

  with self.assertRaisesRegex(
      TypeError, ".* 'foo' of type <.*'str'> is not a valid JAX type"):
    grad(f)("foo")
def test_grad_argnums(self):
  """grad differentiates with respect to the positions given by `argnums`."""
  def f(x, y, z, flag=False):
    assert flag
    return 1.0 * x + 2.0 * y + 3.0 * z

  wrt_default = grad(f)(1.0, 1.0, 1.0, flag=True)
  assert wrt_default == 1.0

  wrt_second = grad(f, argnums=1)(1.0, 1.0, 1.0, flag=True)
  assert wrt_second == 2.0

  wrt_third_and_first = grad(f, argnums=(2, 0))(1.0, 1.0, 1.0, flag=True)
  assert wrt_third_and_first == (3.0, 1.0)
def test_value_and_grad_argnums(self):
  """value_and_grad returns (value, gradient) for each `argnums` choice."""
  def f(x, y, z, flag=False):
    assert flag
    return 1.0 * x + 2.0 * y + 3.0 * z

  value = f(1.0, 1.0, 1.0, flag=True)

  default_result = api.value_and_grad(f)(1.0, 1.0, 1.0, flag=True)
  assert default_result == (value, 1.0)

  second_result = api.value_and_grad(f, argnums=1)(1.0, 1.0, 1.0, flag=True)
  assert second_result == (value, 2.0)

  pair_result = api.value_and_grad(f, argnums=(2, 0))(
      1.0, 1.0, 1.0, flag=True)
  assert pair_result == (value, (3.0, 1.0))
def test_grad_of_jit(self):
  """grad of a jitted function traces the function body only once."""
  trace_log = []  # one entry per execution of the Python body

  @jit
  def square(x):
    trace_log.append(None)
    return x * x

  for point, slope in ((1.0, 2.0), (2.0, 4.0)):
    assert grad(square)(point) == slope
    assert len(trace_log) == 1
def test_jit_of_grad(self):
  """jit of grad of a jitted function traces the inner body only once."""
  trace_log = []  # one entry per execution of the Python body

  @jit
  def square(x):
    trace_log.append(None)
    return x * x

  slope_fn = jit(grad(square))
  for point, slope in ((1.0, 2.0), (2.0, 4.0)):
    assert slope_fn(point) == slope
    assert len(trace_log) == 1
def test_bad_input(self):
  """Both grad and jit reject a string argument with the same TypeError."""
  def f(x):
    return x

  # grad first, then jit, as two independent checks.
  for transform in (grad, jit):
    with self.assertRaisesRegex(
        TypeError, ".* 'foo' of type <.*'str'> is not a valid JAX type"):
      transform(f)("foo")
def test_grad_tuple_output(self):
  """grad of a tuple-returning function raises the scalar-output TypeError."""
  def take_grad():
    return grad(lambda x: (x,x))(1.0)

  jtu.check_raises(take_grad, TypeError,
                   "Gradient only defined for scalar-output functions. ")
def test_grad_unit_output(self):
  """grad of a function returning `()` raises the scalar-output TypeError."""
  def take_grad():
    return grad(lambda x: ())(np.zeros(3))

  jtu.check_raises(take_grad, TypeError,
                   "Gradient only defined for scalar-output functions. ")
def test_grad_nonscalar_output(self):
  """grad of an array-valued function raises the scalar-output TypeError."""
  def take_grad():
    return grad(lambda x: x)(np.zeros(3))

  jtu.check_raises(take_grad, TypeError,
                   "Gradient only defined for scalar-output functions. ")
def test_unwrapped_numpy(self):
  """Calling a raw numpy function on a traced value raises an error."""
  def numpy_exp(x):
    return np.exp(x)

  self.assertRaisesRegex(
      Exception, "The numpy.ndarray conversion .*",
      lambda: grad(numpy_exp)(np.zeros(3)))
def test_binop_mismatch(self):
  """Adding shape-(3,) and shape-(4,) arrays fails, directly and under grad."""
  def f(x, y):
    return x + y

  thunks = (
      lambda: f(jnp.zeros(3), jnp.zeros(4)),
      lambda: grad(f)(np.zeros(3), np.zeros(4)),
  )
  for thunk in thunks:
    jtu.check_raises(
        thunk,
        TypeError,
        "add got incompatible shapes for broadcasting: (3,), (4,).")
def test_dot_mismatch(self):
  """jnp.dot of length-3 and length-4 vectors raises a TypeError under grad."""
  def f(x, y):
    return jnp.dot(x, y)

  with self.assertRaisesRegex(
      TypeError, "Incompatible shapes for dot: got \\(3L?,\\) and \\(4L?,\\)."):
    grad(f)(np.zeros(3), np.zeros(4))
def test_abstract_error_message(self):
  """Casting a traced value with float/complex/int suggests `x.astype(...)`."""
  for castfun in (float, complex, int):
    def f(x):
      return castfun(x)

    with self.assertRaisesRegex(
        TypeError,
        f"[Tt]ry using `x.astype\\({castfun.__name__}\\)`"):
      jit(f)(1.0)
def test_switch_value_jit(self):
  """Branching on a value works under grad but raises under jit."""
  def f(x):
    is_positive = x > 0
    return x if is_positive else -x

  for point, slope in ((1.0, 1.0), (-1.0, -1.0)):
    assert grad(f)(point) == slope

  with self.assertRaisesRegex(core.ConcretizationTypeError,
                              "Abstract tracer value"):
    jit(f)(1)
def test_range_err(self):
  """`range(n)` works when n is a static argument and raises when traced."""
  def f(x, n):
    for i in range(n):
      x = x + i
    return x

  static_n = jit(f, static_argnums=(1,))
  assert static_n(0, 5) == 10

  with self.assertRaisesRegex(
      TypeError,
      "('(?:JaxprTracer|DynamicJaxprTracer)' object cannot be interpreted as an integer"
      "|Abstract value passed to .*)"):
    jit(f)(0, 5)
def test_casts(self):
  """hex/oct/int applied to a traced value raise a TypeError under jit."""
  for castfun in (hex, oct, int):
    f = lambda x: castfun(x)
    with self.assertRaisesRegex(
        TypeError,
        "('(?:JaxprTracer|DynamicJaxprTracer)' object cannot be interpreted as an integer"
        "|Abstract tracer value encountered where concrete value is expected.*)"):
      jit(f)(0)
def test_unimplemented_interpreter_rules(self):
  """Each missing rule of a fresh primitive yields its own NotImplementedError.

  Rules are registered one at a time, so the order of the checks matters.
  """
  foo_p = Primitive('foo')

  def foo(x):
    return foo_p.bind(x)

  # Thunks are re-evaluated on each use, so every check builds a fresh
  # jit/grad wrapper.
  run_eager = lambda: foo(1.0)
  run_jit = lambda: jit(foo)(1.0)
  run_grad = lambda: grad(foo)(1.0)

  # Nothing registered yet.
  jtu.check_raises(run_eager, NotImplementedError,
                   "Evaluation rule for 'foo' not implemented")
  jtu.check_raises(run_jit, NotImplementedError,
                   "Abstract evaluation for 'foo' not implemented")
  jtu.check_raises(run_grad, NotImplementedError,
                   "Differentiation rule for 'foo' not implemented")

  # With abstract evaluation defined, jit complains about the next rule.
  foo_p.def_abstract_eval(lambda x: x)
  jtu.check_raises(run_jit, NotImplementedError,
                   "XLA translation rule for primitive 'foo' not found")

  # With an impl and a JVP rule defined, grad complains about transposition.
  foo_p.def_impl(lambda x: x)
  ad.defjvp(foo_p, lambda g, x: foo(g))
  jtu.check_raises(run_grad, NotImplementedError,
                   "Transpose rule (for reverse-mode differentiation) for 'foo' not implemented")
def test_device_put_and_get(self):
  """device_put/device_get round-trip a single array and a nested container."""
  host = np.arange(12.).reshape((3, 4)).astype("float32")

  # Single array.
  on_device = api.device_put(host)
  self.assertIsInstance(on_device, xla.DeviceArray)
  fetched = api.device_get(on_device)
  self.assertIsInstance(fetched, np.ndarray)
  assert np.all(host == fetched)

  # List holding an array and a tuple of arrays.
  tree = [host, (2 * host, 3 * host)]
  tree_on_device = api.device_put(tree)
  fetched_tree = api.device_get(tree_on_device)
  self.assertIsInstance(fetched_tree, list)
  self.assertIsInstance(fetched_tree[0], np.ndarray)
  assert np.all(fetched_tree[0] == host)
  self.assertIsInstance(fetched_tree[1], tuple)
  self.assertIsInstance(fetched_tree[1][0], np.ndarray)
  assert np.all(fetched_tree[1][0] == 2 * host)
  self.assertIsInstance(fetched_tree[1][1], np.ndarray)
  assert np.all(fetched_tree[1][1] == 3 * host)
def test_device_get_scalar(self):
  """device_get leaves a plain Python int in a container untouched."""
  on_device = api.device_put(
      np.arange(12.).reshape((3, 4)).astype("float32"))
  self.assertIsInstance(on_device, xla.DeviceArray)

  fetched = api.device_get([on_device, 2])
  self.assertIsInstance(fetched, list)
  self.assertIsInstance(fetched[0], np.ndarray)
  assert np.all(fetched[0] == on_device)
  self.assertIsInstance(fetched[1], int)
  self.assertEqual(fetched[1], 2)
@parameterized.parameters([(3,)], [(2, 0)])
def test_device_put_across_devices(self, shape):
  """device_put moves an array from one local device to another."""
  local = api.local_devices()
  if len(local) < 2:
    raise unittest.SkipTest("this test requires multiple devices")
  src_device, dst_device = local[:2]

  data = np.random.randn(*shape).astype(np.float32)
  on_src = api.device_put(data, device=src_device)
  self.assertEqual(on_src.device_buffer.device(), src_device)

  on_dst = api.device_put(on_src, device=dst_device)
  self.assertEqual(on_dst.device_buffer.device(), dst_device)
  np.testing.assert_array_equal(data, np.array(on_dst))

  # Make sure these don't crash
  for placed in (on_src, on_dst):
    api.device_put(placed)
@jtu.skip_on_devices("cpu")
def test_device_put_across_platforms(self):
  """device_put to a CPU device works for numpy, device and scalar inputs."""
  default_device = jax.devices()[0]
  cpu_device = jax.devices("cpu")[0]

  device_arr = jnp.array([1,2,3])
  assert device_arr.device_buffer.device() is default_device

  candidates = [np.array([1,2,3]), device_arr, 1]
  for val in candidates:
    placed = api.device_put(val, device=cpu_device)
    self.assertEqual(placed.device_buffer.device(), cpu_device)
def test_device_put_sharded_array(self):
  """device_put_sharded places one shard on each given device, in order."""
  devices = api.local_devices()
  shards = [np.arange(i, i + 4) for i in range(len(devices))]

  sharded = api.device_put_sharded(shards, devices)

  self.assertEqual(len(sharded.device_buffers), len(devices))
  buffers_match = all(
      b.device() == d for b, d in zip(sharded.device_buffers, devices))
  self.assertTrue(buffers_match)
  self.assertAllClose(sharded, jnp.stack(shards))
def test_device_put_sharded_pytree(self):
  """device_put_sharded shards each leaf of a list of (scalar, array) pairs."""
  devices = api.local_devices()
  shards = [(i, np.arange(i, i + 4)) for i in range(len(devices))]

  scalars, vectors = api.device_put_sharded(shards, devices)

  self.assertAllClose(scalars, jnp.array([a for a, _ in shards]))
  self.assertTrue(
      all(b.device() == d for b, d in zip(scalars.device_buffers, devices)))
  self.assertAllClose(vectors, jnp.vstack([b for _, b in shards]))
  self.assertTrue(
      all(b.device() == d for b, d in zip(vectors.device_buffers, devices)))
@jtu.skip_on_devices("tpu")
def test_jacobian(self):
  """jacfwd and jacrev recover A for x -> A @ x and agree on a tanh variant."""
  randn = np.random.RandomState(0).randn
  A = randn(4, 3)
  x = randn(3)

  linear = lambda x: jnp.dot(A, x)
  for jac in (jacfwd, jacrev):
    assert np.allclose(jac(linear)(x), A)

  squashed = lambda x: jnp.tanh(jnp.dot(A, x))
  assert np.allclose(jacfwd(squashed)(x), jacrev(squashed)(x))
@jtu.skip_on_devices("tpu")
def test_hessian(self):
  """The Hessian of x -> x.T @ A @ x equals A + A.T."""
  randn = np.random.RandomState(0).randn
  A = randn(4, 4)
  x = randn(4)

  def quadratic_form(x):
    return jnp.dot(x, jnp.dot(A, x))

  assert np.allclose(hessian(quadratic_form)(x), A + A.T)
def test_std_basis(self):
  """_std_basis returns identity-like bases shaped after the input pytree."""
  # Vector input.
  vec_basis = api._std_basis(jnp.zeros(3))
  assert getattr(vec_basis, "shape", None) == (3, 3)
  assert np.allclose(vec_basis, np.eye(3))

  # Matrix input.
  mat_basis = api._std_basis(jnp.zeros((3, 3)))
  assert getattr(mat_basis, "shape", None) == (9, 3, 3)
  assert np.allclose(mat_basis, np.eye(9).reshape(9, 3, 3))

  # Nested input: 1 + 3 + 12 = 16 scalar entries in total.
  tree_basis = api._std_basis([0., (jnp.zeros(3), jnp.zeros((3, 4)))])
  assert isinstance(tree_basis, list) and len(tree_basis) == 2
  assert getattr(tree_basis[0], "shape", None) == (16,)
  assert isinstance(tree_basis[1], tuple) and len(tree_basis[1]) == 2
  assert getattr(tree_basis[1][0], "shape", None) == (16, 3)
  assert getattr(tree_basis[1][1], "shape", None) == (16, 3, 4)
@jtu.skip_on_devices("tpu")
def test_jacobian_on_pytrees(self):
  """jacfwd and jacrev handle tuple and dict outputs and tuple inputs."""
  pair = lambda x, y: (x, y)

  for jacfun in [jacfwd, jacrev]:
    # Default argnums (first argument).
    ans = jacfun(pair)(0., 1.)
    self.assertAllClose(ans, (1., 0.), check_dtypes=False)

    # Second argument only.
    ans = jacfun(pair, 1)(0., 1.)
    self.assertAllClose(ans, (0., 1.), check_dtypes=False)

    # Both arguments.
    ans = jacfun(pair, (0, 1))(0., 1.)
    expected = ((1., 0.),
                (0., 1.),)
    self.assertAllClose(ans, expected, check_dtypes=False)

    # Tuple input, sliced tuple output.
    ans = jacfun(lambda x: x[:2])((1., 2., 3.))
    expected = ((1., 0., 0.),
                (0., 1., 0.))
    self.assertAllClose(ans, expected, check_dtypes=False)

    # Dict output.
    randn = np.random.RandomState(0).randn
    x = randn(2)
    y = randn(3)
    ans = jacfun(lambda x, y: {'x': x, 'xy': jnp.outer(x, y)})(x, y)
    expected = {'x': np.eye(2),
                'xy': np.kron(np.eye(2), y[:, None]).reshape(2, 3, 2)}
    self.assertAllClose(ans, expected, check_dtypes=False)
@jtu.skip_on_devices("tpu")
def test_hessian_on_pytrees(self):
  """hessian of an elementwise square over a tuple input is a nested tuple."""
  def squares(x):
    return jnp.array(x)**2

  ans = hessian(squares)((1., 2.))
  first_row = (np.array([2., 0.]), np.array([0., 0.]))
  second_row = (np.array([0., 0.]), np.array([0., 2.]))
  self.assertAllClose(ans, (first_row, second_row), check_dtypes=False)
@jtu.skip_on_devices("tpu")
def test_issue1372(self):
  """Nested jacfwd/jacrev produce correctly shaped second-derivative blocks.

  With x of length 5 and u of length 2, every forward/reverse combination is
  checked for the (x, x), (u, u), (u, x) and (x, u) blocks.
  """
  def quad(x):
    return jnp.dot(x, x)

  def f(x, u):
    return quad(x) + quad(u)

  x, u = jnp.ones(5), jnp.ones(2)
  modes = (jacrev, jacfwd)

  # Diagonal entries
  for argnum, size in ((0, 5), (1, 2)):
    for outer in modes:
      for inner in modes:
        block = outer(inner(f, argnum), argnum)(x, u)
        self.assertEqual(block.shape, (size, size))

  # Off-diagonal entries, by reverse-mode on the outside and then by
  # forward-mode on the outside
  for outer in modes:
    for inner_arg, outer_arg, shape in ((1, 0, (2, 5)), (0, 1, (5, 2))):
      for inner in modes:
        block = outer(inner(f, inner_arg), outer_arg)(x, u)
        self.assertEqual(block.shape, shape)
def test_large_device_constant(self):
  """jit handles a two-million-element input."""
  size = int(2e6)
  double = jit(lambda x: 2 * x)
  ans = double(jnp.ones(size))  # doesn't crash
  self.assertAllClose(ans, np.ones(size) * 2., check_dtypes=False)
def test_grad_and_aux_basic(self):
  """With has_aux=True, grad returns the gradient and the auxiliary output."""
  def cube_with_aux(x):
    return x**3, [x**2]

  g, aux = grad(cube_with_aux, has_aux=True)(3.)
  self.assertAllClose(g, grad(lambda x: x**3)(3.))
  self.assertAllClose(aux, [9.], check_dtypes=False)
def test_grad_and_aux_nested(self):
  """Auxiliary outputs of an inner grad stay differentiable by an outer grad."""
  def inner_aux(x):
    # Only the auxiliary output of the inner grad is used.
    _, aux = grad(lambda x: (x**3, [x**3]), has_aux=True)(x)
    return aux[0]

  def check_matches(f, f2):
    # Compare plain, jitted, and doubly-jitted gradients against f2's.
    self.assertEqual(grad(f)(4.), grad(f2)(4.))
    self.assertEqual(jit(grad(f))(4.), grad(f2)(4.))
    self.assertEqual(jit(grad(jit(f)))(4.), grad(f2)(4.))

  check_matches(inner_aux, lambda x: x**3)
  check_matches(lambda x: inner_aux(x) * jnp.sin(x),
                lambda x: x**3 * jnp.sin(x))
def test_grad_and_aux_constant(self):
  """Constants inside the auxiliary output are returned unchanged."""
  reference = grad(lambda x: x**3)

  g, aux = grad(lambda x: (x**3, [4.]), has_aux=True)(4.)
  self.assertEqual(g, reference(4.))
  self.assertEqual(aux, [4.])

  g, aux = grad(lambda x: (x**3, [x**2, 4.]), has_aux=True)(4.)
  self.assertEqual(g, reference(4.))
  self.assertEqual(aux, [4.**2, 4.])
def test_grad_and_aux_no_tracers(self):
  """The auxiliary dict returned by grad contains no Tracer values."""
  # see https://github.com/google/jax/issues/1950
  def f(x):
    return x ** 2, dict(identity=x, p1=x+1)

  _, aux = jax.grad(f, has_aux=True)(3.)

  self.assertIsInstance(aux, dict)
  for leaf in aux.values():
    self.assertNotIsInstance(leaf, core.Tracer)
def test_jvp_mismatched_arguments(self):
  """jvp rejects primals and tangents that differ in structure or dtype."""
  tree_msg = ("primal and tangent arguments to jax.jvp must have the same tree "
              "structure")
  product = lambda x, y: x * y

  # Different number of leaves.
  with self.assertRaisesRegex(TypeError, tree_msg):
    api.jvp(product, (np.float32(2),), ())

  # Primals and tangents must both be tuples or both be lists.
  with self.assertRaisesRegex(TypeError, tree_msg):
    api.jvp(product, (np.float32(2),), [np.float32(2)])

  # Same structure, different dtypes.
  with self.assertRaisesRegex(
      TypeError,
      "primal and tangent arguments to jax.jvp must have equal types"):
    api.jvp(lambda x: -x, (np.float16(2),), (np.float32(4),))
def test_jvp_non_tuple_arguments(self):
  """jvp rejects primals or tangents that are not tuples or lists."""
  def f(x, y):
    return x + y

  with self.assertRaisesRegex(
      TypeError,
      "primal and tangent arguments to jax.jvp must be tuples or lists; found float and tuple."):
    api.jvp(f, 0., (1.,))

  with self.assertRaisesRegex(
      TypeError,
      "primal and tangent arguments to jax.jvp must be tuples or lists; found tuple and ndarray."):
    api.jvp(f, (0.,), np.array([1., 2.]))
def test_vjp_mismatched_arguments(self):
  """A vjp pullback rejects cotangents with the wrong structure or dtype."""
  _, pullback = api.vjp(lambda x, y: x * y, np.float32(3), np.float32(4))

  # A tuple where a single leaf is expected.
  with self.assertRaisesRegex(
      TypeError, "Tree structure of cotangent input.*does not match"):
    pullback((np.float32(7), np.float32(100)))

  # A single leaf of the wrong dtype.
  with self.assertRaisesRegex(
      TypeError,
      "Type of cotangent input to vjp pullback.*does not match type"):
    pullback((np.float16(42)))
def test_jvp_jit_cached(self):
  """Bug in caching in presence of JVP and JIT."""

  def func(x):
    def inner(y):
      return y * x

    # Must have two calls to the inner jit (the second one hits the cache)
    first = api.jit(inner)(4.)
    second = api.jit(inner)(5.)
    return first + second

  primal_and_tangent = api.jvp(func, (5.,), (1.,))
  self.assertAllClose((45., 9.), primal_and_tangent)
def test_linear_transpose_abstract(self):
  """linear_transpose accepts an object that only carries shape and dtype."""
  abstract_x = types.SimpleNamespace(shape=(3,), dtype=np.float32)
  cotangent = jnp.arange(3, dtype=np.float32)

  transpose_fun = api.linear_transpose(lambda x: 2 * x, abstract_x)
  (result,) = transpose_fun(cotangent)

  self.assertArraysEqual(2 * cotangent, result, check_dtypes=True)
def test_linear_transpose_error(self):
  """linear_transpose rejects integer inputs and mismatched cotangents."""
  with self.assertRaisesRegex(
      TypeError, "linear_transpose only supports float and complex inputs"):
    api.linear_transpose(lambda x: x, 1)

  # (function, cotangent passed to the transpose, expected message); the
  # primal example is always 1.0.
  cases = (
      (lambda x: [x, x], 1.0, "cotangent tree does not match"),
      (lambda x: jnp.stack([x, x]), 1.0, "cotangent type does not match"),
      (lambda x: 1j * x, 1.0, "cotangent type does not match"),
      (lambda x: x, 1j, "cotangent type does not match"),
  )
  for fun, cotangent, message in cases:
    transpose_fun = api.linear_transpose(fun, 1.0)
    with self.assertRaisesRegex(TypeError, message):
      transpose_fun(cotangent)
def test_linear_transpose_complex(self):
  """Transposing x -> (1 + 2j) * x maps 3 + 4j to -5 + 10j."""
  def scale(x):
    return (1 + 2j) * x

  transpose = api.linear_transpose(scale, 1j)
  (actual,) = transpose(3 + 4j)
  self.assertEqual(actual, -5 + 10j)
def test_complex_grad_raises_error(self):
  """grad of a complex-input function without holomorphic=True raises."""
  with self.assertRaises(TypeError):
    grad(lambda x: jnp.sin(x))(1 + 2j)
def test_holomorphic_grad(self):
  """grad with holomorphic=True differentiates sin at a complex point."""
  d_sin = grad(lambda x: jnp.sin(x), holomorphic=True)
  out = d_sin(1 + 2j)
  expected = 2.0327230070196656 - 3.0518977991518j
  self.assertAllClose(out, expected, check_dtypes=False)
def test_nonholomorphic_grad(self):
  """grad of a real-valued function of complex inputs matches known values."""
  zs = 0.5j * np.arange(5) + np.arange(5)

  def f(z):
    return jnp.sum(jnp.cos(jnp.abs(z)))

  expected = np.array([
      0. + 0.j,
      -0.80430663 + 0.40215331j,
      -0.70368982 + 0.35184491j,
      0.1886467 - 0.09432335j,
      0.86873727 - 0.43436864j,
  ])
  tolerance = jtu.default_gradient_tolerance
  self.assertAllClose(grad(f)(zs), expected, check_dtypes=False,
                      atol=tolerance, rtol=tolerance)
def test_complex_output_jacrev_raises_error(self):
  """jacrev of sin at a complex point raises a TypeError."""
  with self.assertRaises(TypeError):
    jacrev(lambda x: jnp.sin(x))(1 + 2j)
def test_nonholomorphic_jacrev(self):
  """jacrev agrees with grad on a real-valued function of complex inputs."""
  # code based on https://github.com/google/jax/issues/603
  zs = 0.5j * np.arange(5) + np.arange(5)

  def f(z):
    return jnp.cos(jnp.linalg.norm(2 * z))

  from_jacrev = jacrev(f)(zs)
  from_grad = grad(f)(zs)
  self.assertAllClose(from_jacrev, from_grad)
def test_complex_input_jacfwd_raises_error(self):
  """jacfwd of sin at a complex point raises a TypeError."""
  with self.assertRaises(TypeError):
    jacfwd(lambda x: jnp.sin(x))(1 + 2j)
def test_legacy_devicearray_repr(self):
  """Converting a device scalar's `.item()` to a string does not crash."""
  scalar = device_put(3.).item()
  str(scalar)  # doesn't crash
def test_devicearray_repr(self):
  """repr works on real and complex DeviceArrays."""
  real_values = jnp.zeros(3)
  complex_values = jnp.ones(3) + 1j * jnp.ones(3)
  for values in (real_values, complex_values):
    x = device_put(values)
    self.assertIsInstance(x, xla.DeviceArray)
    repr(x)  # doesn't crash
def test_devicearray_delete(self):
  """repr of a deleted DeviceArray raises a ValueError."""
  x = device_put(1.)
  x.delete()
  with self.assertRaisesRegex(ValueError, "DeviceArray has been deleted."):
    repr(x)
def test_devicearray_block_until_ready(self):
  """block_until_ready returns the very array it was called on."""
  x = device_put(1.)
  # Tests mostly that block_until_ready() does not produce an error.
  returned = x.block_until_ready()
  self.assertIs(returned, x)
def test_namedtuple_transparency(self):
  """namedtuples pass through grad (type preserved) and jit."""
  # See https://github.com/google/jax/issues/446
  Point = collections.namedtuple("Point", ["x", "y"])

  def norm(pt):
    return jnp.sqrt(pt.x ** 2 + pt.y ** 2)

  pt = Point(1., 2.)

  norm(pt)  # doesn't crash
  gradient = api.grad(norm)(pt)
  self.assertIsInstance(gradient, Point)

  jitted_norm = api.jit(norm)
  self.assertAllClose(norm(pt), jitted_norm(pt), check_dtypes=False)
def test_namedtuple_subclass_transparency(self):
# See https://github.com/google/jax/issues/806
Point = collections.namedtuple("Point", ["x", "y"])
class ZeroPoint(Point):
def is_zero(self):
return (self.x == 0) and (self.y == 0)
pt = ZeroPoint(0., 0.)