# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Primitives for calling Python functions on the host from JAX accelerator code.
**Experimental: please give feedback, and expect changes.**
This module introduces the host callback functions :func:`call`,
:func:`id_tap`, and :func:`id_print`, that send their arguments from the device
to the host and invoke user-defined Python functions on the host, optionally
returning results back to the device computation.
We show below how these functions can be used. We start with :func:`call`,
and we discuss examples of calling from JAX to arbitrary Python functions
on the CPU, e.g., to use NumPy CPU custom kernels. Then we
show uses of :func:`id_tap` and :func:`id_print`, which have the restriction
that they cannot return values from the host to the device.
These primitives are generally faster
because they are executed asynchronously with the device code.
In particular, they can be used to tap into and to debug JAX code.
Using :func:`call` to call a host function and return results to device
-----------------------------------------------------------------------
Use :func:`call` to invoke a computation on the host and return
NumPy arrays to the device computation.
Host computation is useful, e.g., when a device computation needs some data
that requires I/O on the host, or it needs a library that is available on the
host and you do not want to code it in JAX.
For example, the eigendecomposition of general matrices does not work on TPU
in JAX. We can call the NumPy implementation from any JAX accelerator
computation, using a host computation::
  # This function runs on the host.
  def host_eig(m: np.ndarray) -> np.ndarray:
    return np.linalg.eigvals(m)

  # This function is used in JAX.
  def device_fun(m):
    # We send "m" to the host, asking it to call "host_eig" and return the
    # result. We have to specify the result shape and dtype, either in the
    # form of an example return value or any object that has `shape` and
    # `dtype` attributes, e.g., a NumPy array or a `jax.ShapeDtypeStruct`.
    return hcb.call(host_eig, m,
                    # Given an input of shape (..., d, d), eig output has shape (..., d)
                    result_shape=jax.ShapeDtypeStruct(m.shape[:-1], m.dtype))
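As a usage sketch (the concrete matrix below is just an illustration; a
symmetric matrix keeps the eigenvalues real, so the result dtype matches
``m.dtype``)::

  m = np.array([[2., 1.],
                [1., 2.]], dtype=np.float32)
  eigs = jax.jit(device_fun)(m)  # the eigenvalues are computed on the host by NumPy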
The :func:`call` function and the Python host function both take a single argument
and return a single result, but those can be pytrees. Note that we must tell
the :func:`call` what shape and dtype to expect from the host invocation, using
the ``result_shape`` keyword argument.
This is important because the device code is compiled with that expectation.
An error is raised at runtime if the actual invocation produces a different
result shape. In general, **such errors, and also exceptions raised by the
host computation, may be difficult to debug**. See the Debugging section
below.
This is a problem for :func:`call` but not for :func:`id_tap`, because for the
latter the device code does not expect a returned value.
The :func:`call` API can be used inside a jit or pmap computation or inside
cond/scan/while control flow. When used inside :func:`jax.pmap`, there will be
separate calls to the host from each of the participating devices::
  def host_sin(x, *, device):
    # The ``device`` argument is passed due to ``call_with_device=True`` below.
    print(f"Invoking host_sin with {x.shape} on {device}")
    return np.sin(x)

  # Use pmap to run the computation on two devices
  jax.pmap(lambda x: hcb.call(host_sin, x,
                              result_shape=x,
                              # Ask that the `host_sin` function be passed `device=dev`
                              call_with_device=True))(
      np.ones((2, 4), dtype=np.float32))

  # prints (in arbitrary order)
  # Invoking host_sin with (4,) on cpu:0
  # Invoking host_sin with (4,) on cpu:1
Note that :func:`call` does not support any JAX transformations, but as we
show below one can make use of the
existing support for `Custom differentiation in JAX <https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html>`_.
Using :func:`id_tap` to call a Python function on the host, with no returned values
-----------------------------------------------------------------------------------
:func:`id_tap` and :func:`id_print` are special cases of :func:`call`, for
when you just want the side effects of your Python callback. These functions have
the advantage that once the arguments have been sent to the host, the device
computation can proceed without waiting for the Python callback to return.
For :func:`id_tap` you can specify your Python callback to be called, while
:func:`id_print` uses a built-in callback that prints the arguments to
`stdout` on the host.
The Python function passed
to :func:`id_tap` takes two positional arguments (the value tapped
from the device computation along with a ``transforms`` tuple,
described below). Optionally, the function may be passed a keyword argument
``device`` with the Device from which the value was tapped.
A few examples::
  def host_func(arg, transforms):
     ...do something with arg...

  # calls host_func(2x, []) on host
  id_tap(host_func, 2 * x)

  # calls host_func((2x, 3x), [])
  id_tap(host_func, (2 * x, 3 * x))  # The argument can be a pytree

  # calls host_func(2x, [], device=jax.devices()[0])
  id_tap(host_func, 2 * x, tap_with_device=True)  # Pass the device to the tap

  # calls host_func(2x, [], what='activation')
  id_tap(functools.partial(host_func, what='activation'), 2 * x)

  # calls host_func(dict(x=x, y=y), what='data')
  id_tap(lambda tap, transforms: host_func(tap, what='data'), dict(x=x, y=y))
The above examples can all be adapted to use :func:`id_print` instead, with
the difference that :func:`id_print` prints on the host the positional argument,
along with any additional kwargs and the automatic kwarg ``transforms``.
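For instance, a minimal sketch (the printed output shown is indicative,
assuming ``x`` holds ``3.``; the ``sys`` import is an assumption)::

  id_print(2 * x, what="x2")             # prints: what: x2 : 6.00
  id_print(x, output_stream=sys.stdout)  # write through a stream instead of the built-in print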
Using :func:`barrier_wait` to wait until all callbacks have executed
--------------------------------------------------------------------
If your Python callbacks have side-effects you may need to wait until the
computation has finished to ensure that the side-effects have been observed.
You can use the :func:`barrier_wait` function for that purpose::
  accumulator = []
  def host_log(arg, transforms):
    # We just record the arguments in a list
    accumulator.append(arg)

  def device_fun(x):
    id_tap(host_log, x)
    id_tap(host_log, 2. * x)

  jax.jit(device_fun)(1.)
  jax.jit(device_fun)(1.)

  # At this point, we have started two computations, each with two
  # taps, but they may not have yet executed.
  barrier_wait()
  # Now we know that all the computations started before `barrier_wait`,
  # on all devices, have finished, and all the callbacks have finished
  # executing.
Note that :func:`barrier_wait` will start one
tiny computation with one tap on each of the devices in ``jax.local_devices()``
and will wait for all these taps to be received.
An alternative to using :func:`barrier_wait` is to just wait for the end
of the computation, if all the callbacks are :func:`call`::
  accumulator = []
  def host_log(arg):
    # We just record the arguments in a list
    accumulator.append(arg)
    return 0.  # return something

  def device_fun(x):
    y = call(host_log, x, result_shape=jax.ShapeDtypeStruct((), np.float32))
    z = call(host_log, 2. * x, result_shape=jax.ShapeDtypeStruct((), np.float32))
    return y + z  # return something that uses both results

  res1 = jax.jit(device_fun)(1.)
  res2 = jax.jit(device_fun)(1.)
  res1.block_until_ready()
  res2.block_until_ready()
Behavior under parallelization transformations
----------------------------------------------
In the presence of :func:`jax.pmap` the code will run on multiple devices and
each device will tap its values independently.
It may be helpful to use the ``tap_with_device`` option for :func:`id_print`
or :func:`id_tap`, so that you see which device is sending which data::
  jax.pmap(power3, devices=jax.local_devices()[:2])(np.array([3., 4.]))
  # device=cpu:0 what=x,x^2: (3., 9.)   # from the first device
  # device=cpu:1 what=x,x^2: (4., 16.)  # from the second device
When using :func:`jax.pmap` with multiple devices on multiple hosts, every
host will receive callbacks from all of its local devices, with an operand
that corresponds to each device slice. For a
:func:`call`, the callback must return to each device only the slice of the
result that pertains to the corresponding device.
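A sketch of this per-device behavior for :func:`call` (``host_double`` is a
hypothetical callback; under :func:`jax.pmap` each device makes its own call
with its own slice, so the callback simply returns the doubled slice)::

  def host_double(x):
    # Under pmap, ``x`` is the operand slice from one device, and the
    # returned array goes back to that same device.
    return 2 * x

  jax.pmap(lambda x: hcb.call(host_double, x, result_shape=x))(
      np.ones((2, 4), dtype=np.float32))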
When using the experimental :func:`pjit.pjit` the code will run on multiple
devices on different shards of the input. The current implementation of
host callbacks will ensure that a single device will collect and outfeed
the entire operand, in a single callback. The callback function is supposed
to return the entire array, which will then be sent in a single infeed to the
same device that issued the outfeed. This device is then responsible for
sending the required shards to the other devices::
  with maps.Mesh(jax.local_devices()[:2], ["d"]):
    pjit.pjit(power3, in_axis_resources=(P("d"),),
              out_axis_resources=(P("d"),))(np.array([3., 4.]))

  # device=TPU:0 what=x,x^2: ( [3., 4.],
  #                            [9., 16.] )
Note that the collection of the operand on one device may result in OOM if
the operand was sharded across devices.
When using :func:`pjit.pjit` with multiple devices on multiple hosts, only
the host of device 0 (w.r.t. the mesh) will receive the callback, with
the operand collected from all participating devices on all hosts. For a
:func:`call`, the callback must return the entire array for all devices on all
hosts.
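A hedged sketch of the corresponding :func:`call` behavior under ``pjit``
(``host_double`` is hypothetical; unlike the per-device slices under
:func:`jax.pmap`, the callback here sees the whole operand and must return the
whole result)::

  def host_double(x):
    # ``x`` is the full operand, collected from all participating devices.
    return 2 * x

  with maps.Mesh(jax.local_devices()[:2], ["d"]):
    pjit.pjit(lambda x: hcb.call(host_double, x, result_shape=x),
              in_axis_resources=P("d"),
              out_axis_resources=P("d"))(np.ones((4, 4), dtype=np.float32))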
Behavior under JAX autodiff transformations
-------------------------------------------
When used under a JAX autodiff transformation, the host callback functions
operate on the primal values only. Consider the following example::
  def power3(x):
    y = x * x
    # Print both 'x' and 'x^2'. Must pack as a tuple.
    hcb.id_print((x, y), what="x,x^2")
    return y * x

  power3(3.)
  # what: x,x^2 : (3., 9.)
(You can see these examples tested in `host_callback_test.HostCallbackTapTest.test_tap_transforms`.)
When used under :func:`jax.jvp` there will be one callback with the primal
values only::
jax.jvp(power3, (3.,), (0.1,))
# what: x,x^2 : (3., 9.)
Similarly for :func:`jax.grad`, we get a callback from the forward computation
only::
jax.grad(power3)(3.)
# what: x,x^2 : (3., 9.)
If you want to invoke the callback on the tangents during a :func:`jax.jvp`,
you can use a custom_jvp. For example, you can define a function that does
nothing interesting except that its custom_jvp will print the tangents::
  @jax.custom_jvp
  def print_tangents(arg):
    return None

  @print_tangents.defjvp
  def print_tangents_jvp(primals, tangents):
    arg_dot, = tangents
    hcb.id_print(arg_dot, what="tangents")
    return primals, tangents
Then you use this function in the places where you want to tap the tangents::
  def power3_with_tangents(x):
    y = x * x
    # Print both 'x' and 'x^2'. Must pack as a tuple.
    hcb.id_print((x, y), what="x,x^2")
    print_tangents((x, y))
    return y * x

  jax.jvp(power3_with_tangents, (3.,), (0.1,))
  # what: x,x^2 : (3., 9.)
  # what: tangents : (0.1, 0.6)
You can do a similar thing for the cotangents during :func:`jax.grad`. This
time you must be careful to use, in the rest of the computation, the values
whose cotangents you want to tap. Hence we make ``print_cotangents`` return
its argument::
  @jax.custom_vjp
  def print_cotangents(arg):
    # Must return the argument for which we want the cotangent.
    return arg

  # f_fwd: a -> (b, residual)
  def print_cotangents_fwd(arg):
    return print_cotangents(arg), None
  # f_bwd: (residual, CT b) -> [CT a]
  def print_cotangents_bwd(residual, ct_b):
    hcb.id_print(ct_b, what="cotangents", output_stream=testing_stream)
    return ct_b,

  print_cotangents.defvjp(print_cotangents_fwd, print_cotangents_bwd)

  def power3_with_cotangents(x):
    y = x * x
    # Print both 'x' and 'x^2'. Must pack as a tuple.
    hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream)
    (x1, y1) = print_cotangents((x, y))
    # Must use the output of print_cotangents
    return y1 * x1

  jax.grad(power3_with_cotangents)(3.)
  # what: x,x^2 : (3., 9.)
  # what: cotangents : (9., 3.)
If you use :func:`ad_checkpoint.checkpoint` to rematerialize the residuals
for the backward pass, then the callbacks from the primal computation will
be called twice::
  jax.grad(lambda x: power3(ad_checkpoint.checkpoint(power3)(x)))(3.)
  # what: x,x^2 : (3., 9.)
  # what: x,x^2 : (27., 729.)
  # what: x,x^2 : (3., 9.)
The callbacks are, in order, from: the primal computation of the inner
``power3``, the primal computation of the outer ``power3``, and the
rematerialization of the residuals for the inner ``power3``.
Behavior under jax.vmap
-----------------------
The host callback functions :func:`id_print` and :func:`id_tap` support the
vectorization transformation :func:`jax.vmap`.
For :func:`jax.vmap` the arguments to the callback are batched,
and the callback function is passed an additional special ``transforms``
argument containing a list of transformation descriptors of the form
``("batch", {"batch_dims": ...})``, where ``...`` denotes the
batched dimensions for the tapped values (one entry per argument;
``None`` denotes an argument that was broadcast)::

  jax.vmap(power3)(np.array([2., 3.]))
  # transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 : ([2., 3.], [4., 9.])
See documentation for :func:`id_tap`, :func:`id_print`, and :func:`call`.
For more usage examples, see tests/host_callback_test.py.
Using :func:`call` to call a TensorFlow function, with reverse-mode autodiff support
------------------------------------------------------------------------------------
Another possible use for host computation is to invoke a library written for
another framework, such as TensorFlow.
In this case it becomes interesting to support JAX autodiff for host callbacks
by deferring to the autodiff mechanism in TensorFlow,
using the :func:`jax.custom_vjp` mechanism.
This is relatively easy to do, once one understands both the JAX custom VJP
and the TensorFlow autodiff mechanisms.
The code for how this can be done is shown in the ``call_tf_full_ad``
function in `host_callback_to_tf_test.py <https://github.com/google/jax/blob/main/tests/host_callback_to_tf_test.py>`_.
This example supports arbitrary higher-order differentiation as well.
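As a flavor of the forward pass alone, here is a minimal sketch with no
autodiff support (it assumes the TensorFlow function preserves the shape and
dtype of its argument)::

  import tensorflow as tf

  def call_tf_no_ad(tf_fun, arg):
    def callback(a):
      # Runs on the host: convert to a TF tensor, call TF, convert back.
      return np.asarray(tf_fun(tf.constant(a)))
    return hcb.call(callback, arg, result_shape=arg)

  y = call_tf_no_ad(tf.math.sin, np.float32(1.))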
Note that if you just want to call TensorFlow functions from JAX, you can also
use the `jax2tf.call_tf function <https://github.com/google/jax/blob/main/jax/experimental/jax2tf/call_tf.py>`_.
Using :func:`call` to call a JAX function on another device, with reverse-mode autodiff support
------------------------------------------------------------------------------------------------
It should not be surprising that we can use host computation to invoke a JAX
computation on another device. The arguments are sent from the accelerator to
the host, and then to the outside device on which the JAX host
computation will run, and then the results are sent back to the original accelerator.
The code for how this can be done is shown in the ``call_jax_other_device`` function
in `host_callback_test.py <https://github.com/google/jax/blob/main/tests/host_callback_test.py>`_.
Low-level details and debugging
-------------------------------
The host callback functions will be executed for each device in the order in
which the send operations were performed on the device.
The host callback functions for multiple devices may be interleaved.
The data from the devices is received by separate threads managed by the JAX
runtime (one thread per device). The runtime maintains a buffer of
configurable size (see the flag ``--jax_host_callback_max_queue_byte_size``).
When the buffer is full, all the receiving threads are paused,
which eventually pauses the computation on the devices. The runtime has one
additional thread for each device to invoke the Python user functions with the
received data. If the processing of the callbacks is slow, it may actually
lead to the runtime buffer filling up, and eventually pausing the computation
on the devices when they need to send something.
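For example, one way to enlarge this buffer from Python (a sketch; this
assumes the flag is set before the first host callback use, and it can also be
passed on the command line)::

  from jax.config import config
  config.update("jax_host_callback_max_queue_byte_size", 4 * 1024 * 1024)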
For more details on the outfeed receiver runtime mechanism see
`runtime code
<https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/outfeed_receiver.cc>`_.
In order to pause the execution until all data from computations already
started on devices has arrived and has been processed, use :func:`barrier_wait`.
Exceptions from the user-defined callback functions are logged along with their
stack traces, but the receiving threads are not stopped. Instead, the last
exception is recorded and a subsequent :func:`barrier_wait` will
raise :exc:`CallbackException` if any exception occurred
in one of the tap functions. This exception will include the text and the
stack trace of the last exception encountered.
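For example, a sketch of catching this (``hcb`` refers to this module)::

  def faulty_tap(arg, transforms):
    raise ValueError("tap failed")

  jax.jit(lambda x: hcb.id_tap(faulty_tap, x))(1.)
  try:
    hcb.barrier_wait()
  except hcb.CallbackException as e:
    # The exception text includes the stack trace of the last failure.
    print(e)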
One further complication arises for callback functions that must return
results to the call origin device, such as :func:`call()`. This is handled
differently on CPU/GPU devices compared to TPU devices.
On CPU/GPU devices, in order to avoid the device computation
being stuck waiting for a result that will never arrive, if any
error occurs while processing the callback (whether raised by the user code
itself or due to a mismatch between the returned value and the expected
``result_shape``), we send the device a "fake" result of shape ``int8[12345]``.
This will make the device computation abort because the received data differs
from what it expects. On CPU the runtime will crash with a distinctive error
message::

  Check failed: buffer->length() == buffer_length (12345 vs. ...)

On GPU, the failure is more user-friendly and will be surfaced to the Python
program as::

  RET_CHECK failure ... Mismatch between infeed source buffer shape s8[12345] ...
To debug the underlying cause for these messages, see the Debugging section.
On TPU devices, there is currently no shape check for infeed, so we take the
safer route of not sending this fake result in case of errors. This means
that the computation will hang, and no exception will be raised (but any
exceptions in the callback functions will still appear in the logs).
The current implementation uses the outfeed mechanism provided by XLA. The
mechanism itself is quite primitive in the sense that a receiver must know
exactly the shape of each incoming packet, and how many packets are expected.
This makes it hard to use for multiple kinds of data in the same computation,
and it is practically impossible to use it under conditionals or in loops
of non-constant iteration count. Furthermore, code that uses the outfeed
mechanism directly cannot be transformed by JAX. All these limitations are
addressed by the host callback functions. The tapping API introduced here
makes it easy to share the outfeed mechanism for multiple purposes, while
supporting all transformations.
**Note that after you have used the host callback functions, you cannot
use lax.outfeed directly**. You may want to call :func:`stop_outfeed_receiver`
if you later need to use ``lax.outfeed``.
Since the actual calls to your callback functions are made from the C++
receiver, it may be hard to debug the calls. In particular, the stack trace
will not include the calling code. You can use the flag
``jax_host_callback_inline`` (or the environment variable
``JAX_HOST_CALLBACK_INLINE``) to ensure that the calls to the callbacks are
inlined. This works only if the calls are outside a staging context (``jit``
or a control-flow primitive).
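For example (a sketch of setting the flag from Python; the command-line flag
or the environment variable work as well)::

  from jax.config import config
  config.update("jax_host_callback_inline", True)

  id_print(42)  # outside jit: the callback now runs inline, on the caller's stack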
The C++ `receiver
<https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/outfeed_receiver.cc>`_
is started automatically on the first call to :func:`id_tap`. In order to stop
it properly, upon start an ``atexit`` handler is registered to call
:func:`barrier_wait` with the logging name "at_exit".
There are a few environment variables that you can use to turn on logging
for the C++ outfeed `receiver backend
<https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/outfeed_receiver.cc>`_.
* ``TF_CPP_MIN_LOG_LEVEL=0``: will turn on INFO logging, needed for all below.
* ``TF_CPP_MIN_VLOG_LEVEL=3``: will make all VLOG logging up to level 3 behave
like INFO logs. This may be too much, but you will see which modules are
logging relevant info, and then you can select which modules to log from.
* ``TF_CPP_VMODULE=<module_name>=3`` (the module name can be either C++ or
Python, without the extension).
You should also use the ``--verbosity=2`` flag so that you see the logs
from Python.
For example, you can try to enable logging in the ``host_callback`` module:
``TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=host_callback=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple``
If you want to enable logging in lower-level implementation modules try:
``TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=outfeed_receiver=3,host_callback=3,outfeed_receiver_py=3,outfeed_thunk=3,infeed_thunk=3,cpu_transfer_manager=3,cpu_runtime=3,xfeed_manager=3,pjrt_client=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple``
(For Bazel tests, use ``--test_arg=--vmodule=...``.)
Still to do:
* More performance tests.
* Explore implementation with outside compilation for TPU.
* Explore implementation with XLA CustomCall for CPU and GPU.
"""
import atexit
import functools
import itertools
import threading
import traceback
from typing import (Any, Callable, Dict, List, Optional, Sequence,
                    Tuple, cast)
import warnings

from absl import logging

from jax._src import api
from jax import core
from jax.config import config
from jax import custom_derivatives
from jax._src import dtypes
from jax import lax
from jax.experimental import pjit
from jax.interpreters import ad, xla, batching, masking, pxla
from jax.interpreters import partial_eval as pe
from jax.interpreters import mlir
from jax._src import dispatch
from jax._src import pretty_printer as pp
from jax._src import source_info_util
from jax._src import util
from jax._src import lib as jaxlib
from jax._src.lib import pytree
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client
from jax._src.lib import xla_extension
from jax._src.lib.mlir.dialects import mhlo

import numpy as np

FLAGS = config.FLAGS

def _inline_host_callback() -> bool:
  return FLAGS.jax_host_callback_inline

def _use_outfeed(platform: str) -> bool:
  return (platform in ("tpu", "gpu", "cuda", "rocm") or
          FLAGS.jax_host_callback_outfeed)

xops = xla_client._xla.ops

XlaOp = xla_client.XlaOp
XlaShape = xla_client.Shape
XlaBuilder = xla_client.XlaBuilder
XlaDevice = xla_client.Device
XlaLocalClient = xla_client.Client
DType = Any
def id_tap(tap_func, arg, *, result=None, tap_with_device=False, **kwargs):
  """Host-callback tap primitive, like the identity function with a call to ``tap_func``.

  **Experimental: please give feedback, and expect changes!**

  ``id_tap`` behaves semantically like the identity function but has the
  side-effect that a user-defined Python function is called with the runtime
  value of the argument.

  Args:
    tap_func: tap function to call like ``tap_func(arg, transforms)``, with
      ``arg`` as described below and where ``transforms`` is the sequence of
      applied JAX transformations in the form ``(name, params)``. If the
      ``tap_with_device`` optional argument is True, then the invocation also
      includes the device from which the value is tapped as a keyword argument:
      ``tap_func(arg, transforms, device=dev)``.
    arg: the argument passed to the tap function, can be a pytree of JAX
      types.
    result: if given, specifies the return value of ``id_tap``. This value is
      not passed to the tap function, and in fact is not sent from the device
      to the host. If the ``result`` parameter is not specified then the return
      value of ``id_tap`` is ``arg``.
    tap_with_device: if True then the tap function is invoked with the
      device from which the tap originates as a keyword argument.

  Returns:
    ``arg``, or ``result`` if given.

  Execution is ordered by data dependency: the tap runs after all the
  arguments (and the value of ``result``, if present) are computed, and
  before the returned value is used. At least one of the returned values of
  ``id_tap`` must be used in the rest of the computation, or else this
  operation has no effect.

  Tapping works even for code executed on accelerators and even for code under
  JAX transformations.
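  A minimal sketch of the ``result`` parameter (``tap_func`` and ``x`` are
  assumed from the surrounding context; the tap stays live through the data
  dependency on the returned value)::

    y = id_tap(tap_func, x, result=3. * x)  # tap_func sees x; y is 3. * x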
  For more details see the `module documentation
  <jax.experimental.host_callback.html>`_.
  """
  if kwargs:
    msg = (
        "Support for **kwargs in ``id_tap`` has been removed. Instead, "
        "pre-apply keyword arguments, either by using a closure or by passing "
        "``functools.partial(tap_func, **kwargs)``.")
    raise TypeError(msg)

  if FLAGS.jax_host_callback_ad_transforms:
    warnings.warn('The flag jax_host_callback_ad_transforms is for temporary '
                  'backwards compatibility mode. This flag, and the behavior '
                  'it enabled will be removed soon.',
                  FutureWarning)

  if result is not None:
    flat_results, result_treedef = pytree.flatten(result)
    for r in flat_results:
      api._check_arg(r)

  call_res = _call(tap_func, arg, call_with_device=tap_with_device,
                   result_shape=None, identity=True)

  if result is not None:
    # Return the results, but add a dependency on the call, to ensure it
    # is kept in the graph.
    if FLAGS.jax_host_callback_ad_transforms:
      call_flat_results, _ = pytree.flatten(call_res)
      if call_flat_results:
        call_flat_results = [id_tap_dep_p.bind(r, call_flat_results[0])
                             for r in flat_results]
      else:
        call_flat_results = flat_results
      return result_treedef.unflatten(call_flat_results)
    else:
      return result
  else:
    return call_res
def id_print(arg, *, result=None, tap_with_device=False,
             output_stream=None, threshold=None, **kwargs):
  """Like :func:`id_tap` with a printing tap function.

  **Experimental: please give feedback, and expect changes!**

  On each invocation of the printing tap, the ``kwargs``, if present,
  are printed first (sorted by key). Then ``arg`` is printed, with the
  arrays stringified with ``numpy.array2string``.

  See the :func:`id_tap` documentation.

  Additional keyword arguments:

  * ``tap_with_device`` if True, will print also the device from which
    the value originates.
  * ``output_stream`` if given then it will be used instead of the
    built-in ``print``. The string will be passed as
    ``output_stream.write(s)``.
  * ``threshold`` is passed to ``numpy.array2string``.
  """
  printer = functools.partial(_print_tap_func,
                              output_stream=output_stream,
                              threshold=threshold, **kwargs)
  return id_tap(printer, arg, result=result, tap_with_device=tap_with_device)
def call(callback_func: Callable, arg, *,
         result_shape=None,
         call_with_device=False):
  """Make a call to the host, and expect a result.

  **Experimental: please give feedback, and expect changes!**

  Args:
    callback_func: The Python function to invoke on the host as
      ``callback_func(arg)``. If the ``call_with_device`` optional argument is
      True, then the invocation also includes the ``device`` kwarg with the
      device from which the call originates: ``callback_func(arg, device=dev)``.
      This function must return a pytree of NumPy ndarrays.
    arg: the argument passed to the callback function, can be a pytree of JAX
      types.
    result_shape: a value that describes the expected shape and dtype of the
      result. This can be a numeric scalar, from which a shape and dtype are
      obtained, or an object that has ``.shape`` and ``.dtype`` attributes.
      If the result of the callback is a pytree, then ``result_shape`` should
      also be a pytree with the same structure. In particular, ``result_shape``
      can be ``()`` or ``None`` if the function does not have any results.
      The device code containing ``call`` is compiled with the expected result
      shape and dtype, and an error will be raised at runtime if the actual
      ``callback_func`` invocation returns a different kind of result.
    call_with_device: if True then the callback function is invoked with the
      device from which the call originates as a keyword argument.

  Returns:
    the result of the ``callback_func`` invocation.
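
  A minimal usage sketch (assuming ``x`` is a JAX float array; the lambda runs
  on the host with a NumPy argument, and ``np.sin`` preserves shape and
  dtype)::

    y = call(lambda a: np.sin(a), x, result_shape=x)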
  For more details see the `module documentation
  <jax.experimental.host_callback.html>`_.
"""
return _call(callback_func, arg, result_shape=result_shape,
call_with_device=call_with_device, identity=False)
# We need the wrapper function to have hash and equality defined since it is
# used as a primitive keyword argument, and we want a compilation cache hit if
# the user uses the same function twice.
class _CallbackWrapper:
  def __init__(self, callback_func, identity, call_with_device):
    self.callback_func = callback_func
    self.identity = identity
    self.call_with_device = call_with_device

  def __hash__(self):
    return hash((self.callback_func, self.identity, self.call_with_device))

  def __eq__(self, other):
    return (self.callback_func == other.callback_func and
            self.identity == other.identity and
            self.call_with_device == other.call_with_device)

  def __call__(self, arg, device, transforms):
    if self.identity:
      # For id_tap, we pass the transforms, for backwards compatibility
      if self.call_with_device:
        return self.callback_func(arg, transforms, device=device)
      else:
        return self.callback_func(arg, transforms)
    else:
      if self.call_with_device:
        return self.callback_func(arg, device=device)
      else:
        return self.callback_func(arg)
# Helper function to implement both `call` and `id_tap`. The two cases are
# differentiated by the `identity` flag.
def _call(callback_func: Callable, arg, *,
          result_shape=None,
          call_with_device=False,
          identity=False):
  # Lazy initialization
  _initialize_outfeed_receiver(
      max_callback_queue_size_bytes=FLAGS.jax_host_callback_max_queue_byte_size)
  api._check_callable(callback_func)
  flat_args, arg_treedef = pytree.flatten(arg)
  for arg in flat_args:
    api._check_arg(arg)
  # See definition of outside_call_p for what parameters it takes
  params: Dict[str, Any] = {}
  # TODO: wrap function
  params["callback"] = _CallbackWrapper(callback_func, identity,
                                        call_with_device)
  params["identity"] = identity
  params["arg_treedef"] = arg_treedef

  if not identity:
    # Turn the result shapes into ShapedArray abstract values.
    flat_results_shape, result_treedef = pytree.flatten(result_shape)
    try:
      flat_results_aval = [core.ShapedArray(np.shape(r), dtypes.result_type(r))
                           for r in flat_results_shape]
    except Exception:
      msg = ("result_shape should be a pytree of values with structure "
             "matching the expected result of the callback function. The "
             "values must be either numeric scalars, or must have 'shape' and "
             f"'dtype' attributes. Got {result_shape}")
      raise ValueError(msg)
    params["result_treedef"] = result_treedef
    params["flat_results_aval"] = tuple(flat_results_aval)
  flat_results = outside_call_p.bind(*flat_args, **params)
  return result_treedef.unflatten(flat_results) if not identity else arg_treedef.unflatten(flat_results)
# We need the lock for when we use the CustomCall implementation of callbacks.
# The outfeed implementation is driven by a single thread from C++.
_print_tap_lock = threading.Lock()

def _print_tap_func(
    arg, transforms, *, device=None,
    output_stream=None, threshold=1024, **kwargs):
  """The consumer for id_print.

  We provide this as a simple tapping function for printing.
  This is **experimental**, and we may not want to add many features to it;
  it should be easy for the user to roll their own printing function.

  Args:
    device: the device from which the value originates (only if
      ``tap_with_device`` was used for :func:`id_print`).
    output_stream: a function whose `write` method is called with the strings
      to be output.
    threshold: the value of the ``numpy.array2string`` threshold parameter.
    **kwargs: all other keyword args are printed before printing `arg`.
  """
  def emit_str(s: str):
    if output_stream is not None:
      output_stream.write(s + "\n")
    else:
      print(s)

  if transforms:
    kwargs['transforms'] = [(name, params) if params else name
                            for name, params in transforms]
  if device is not None:
    kwargs['device'] = device
  kv_pairs = " ".join([
      f"{k}: {v}" for k, v in sorted(kwargs.items())
  ])

  def pp_val(arg) -> pp.Doc:
    if isinstance(arg, tuple):
      return pp.group(pp.concat([
          pp.text("( "),
          pp.nest(2, pp.join(pp.brk(), [pp_val(e) for e in arg])),
          pp.text(" )")
      ]))
    elif isinstance(arg, list):
      return pp.group(pp.concat([
          pp.text("[ "),
          pp.nest(2, pp.join(pp.brk(), [pp_val(e) for e in arg])),
          pp.text(" ]")
      ]))
    elif isinstance(arg, dict):
      return pp.group(pp.concat([
          pp.text("{ "),
          pp.nest(2, pp.join(pp.brk(), [
              pp.text(f"{k}=") + pp_val(v) for k, v in sorted(arg.items())
          ])),
          pp.text(" }")
      ]))
    elif isinstance(arg, np.ndarray):
      return pp.text(np.array2string(arg, threshold=threshold))
    else:
      return pp.text(str(arg))

  with _print_tap_lock:
    if kv_pairs:
      emit_str(kv_pairs)
    emit_str(str(pp_val(arg)))
def _values_to_avals(vals) -> Sequence[core.ShapedArray]:
  return tuple(core.raise_to_shaped(core.get_aval(v)) for v in vals)
### The id_tap_dep primitive
# The id_tap_dep_p primitive is used to create a dependency of the result of
# id_tap on the actual tap operation. This is only needed when the
# id_tap function is used with the `result` parameter. This primitive acts
# as the identity operator on the first argument.
#
# For example, given `id_tap(f, (a, b), result=(r, s))`, we convert this to
#
#    a1, b1 = outside_call_p(f, a, b)
#    r1 = id_tap_dep_p(r, a1)
#    s1 = id_tap_dep_p(s, a1)
#
# There are always two arguments and the result is equal to the first.
id_tap_dep_p = core.Primitive("id_tap_dep")
id_tap_dep_p.multiple_results = False
id_tap_dep_p.def_impl(lambda r, _: r)
xla.register_translation(id_tap_dep_p,
                         lambda ctx, avals_in, avals_out, a_res, a_tap: [a_res])
id_tap_dep_p.def_abstract_eval(lambda r_a, _: r_a)

def _id_tap_dep_jvp_rule(primals, tangents):
  if FLAGS.jax_host_callback_ad_transforms:
    assert False
  tangents_instantiated = tuple(map(_instantiate_zeros, tangents, primals))
  return (id_tap_dep_p.bind(primals[0], primals[1]),
          id_tap_dep_p.bind(tangents_instantiated[0],
                            tangents_instantiated[1]))

ad.primitive_jvps[id_tap_dep_p] = _id_tap_dep_jvp_rule

def _id_tap_dep_transpose_rule(cts, arg_res, arg_tap):
  if FLAGS.jax_host_callback_ad_transforms:
    assert False
  if ad.is_undefined_primal(arg_res):
    ct_res = _instantiate_zeros(cts, arg_res)
  else:
    ct_res = None
  if ad.is_undefined_primal(arg_tap):
    ct_tap = ad.Zero(arg_tap.aval)
  else:
    ct_tap = None
  return (ct_res, ct_tap)

ad.primitive_transposes[id_tap_dep_p] = _id_tap_dep_transpose_rule

def _id_tap_dep_batching_rule(batched_args, batch_dims):
  if FLAGS.jax_host_callback_ad_transforms:
    assert False
  arg_res, arg_tap = batched_args
  return id_tap_dep_p.bind(arg_res, arg_tap), batch_dims[0]

batching.primitive_batchers[id_tap_dep_p] = _id_tap_dep_batching_rule

def _id_tap_dep_masking_rule(operands, operands_logical_shapes):
  if FLAGS.jax_host_callback_ad_transforms:
    assert False
  arg_res, arg_tap = operands
  return id_tap_dep_p.bind(arg_res, arg_tap)

masking.masking_rules[id_tap_dep_p] = _id_tap_dep_masking_rule
### The outside_call primitive
"""
This primitive is used to implement the `call` and `id_tap` functions.
It takes several positional arguments that are flattened
according to `arg_treedef`.
The result of the primitive is computed based on the `identity` parameter,
as follows:

  * if `identity` is True, then the results are the same as the
    positional arguments of the primitive (except perhaps the last couple of
    arguments, see `has_token`). In this case, `result_treedef` and
    `flat_results_aval` are ignored, and `arg_treedef` describes the result
    also.
  * if `identity` is False, then the results are those from
    the call to the outside computation:

       flatten(callback(arg_treedef.unflatten(args), device=...))

    In this case, the callback results must match `result_treedef`
    and `flat_results_aval`.

It takes the following parameters:

  * callback: the function to invoke with the unflattened arguments,
    the device and the transforms: `callback(arrays, device, transforms)`
  * arg_treedef: the treedef for the argument.
  * identity: see description above.
  * result_treedef, flat_results_aval: describes the expected result of the
    callback. Only used when not `identity`.
  * transforms: a tuple of the transformations that have been applied. Each
    element of the tuple is itself a tuple with the first element the name
    of the transform. The remaining elements depend on the transform. For
    example, for `batch`, the parameters are the dimensions that have been
    batched, and for `mask` the logical shapes. These are unpacked by
    _outside_call_run_callback before passing to the user function.
  * has_token: a boolean, when True it means that the last positional argument
    is the current token. In this case, the result of the primitive is
    going to be the non-token positional arguments, along with the updated
    token. The tokens and this parameter are added after all the JAX
    transformations, just before staging XLA.
"""
outside_call_p = core.Primitive("outside_call")
outside_call_p.multiple_results = True
core.outfeed_primitives.add(outside_call_p)
def _outside_call_abstract_eval(*args_a: pe.AbstractValue,
                                identity, **params) -> Sequence[pe.AbstractValue]:
  if identity:
    # Do some validation here
    assert "result_treedef" not in params
    assert "flat_results_aval" not in params
    return args_a

  assert params["result_treedef"] is not None
  assert params["flat_results_aval"] is not None
  flat_results_aval = params["flat_results_aval"]
  if "has_token" in params and params["has_token"]:
    assert len(args_a) >= 2
    return flat_results_aval + args_a[-2:]
  else:
    return flat_results_aval

outside_call_p.def_abstract_eval(_outside_call_abstract_eval)
def _outside_call_impl(*args, **params):
  assert "has_token" not in params
  if _inline_host_callback():
    device = api.devices()[0]
    results = _outside_call_run_callback(args, device, send_infeed=False, **params)
    return results
  else:
    # We use the jitted-version of the primitive even for eager execution, both
    # so that we do not duplicate logic, but also so that all outfeed is
    # received by the outfeed_listeners, in the same thread from a given
    # device. If we were to process the tap here, it would be coming from the
    # main thread. Also, even in eager execution some primitives, such as
    # while, are compiled. It would be confusing to process a sequence
    # "id_tap; while" in two different threads.
    return dispatch.apply_primitive(outside_call_p, *args, **params)

outside_call_p.def_impl(_outside_call_impl)
def _outside_call_translation_rule(ctx, avals_in, avals_out,
                                   *args_op: XlaOp,
                                   has_token,
                                   identity,
                                   flat_results_aval=(),
                                   **params):
  # We expect the current tokens at the end, inserted by _rewrite_jaxpr.
  assert has_token
  current_token = args_op[-2]
  current_itoken = args_op[-1]

  comp = ctx.builder