-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
call_tf.py
683 lines (579 loc) · 26.8 KB
/
call_tf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Allows JAX to call TensorFlow functions with support for autodiff.
**Experimental: please give feedback, and expect changes.**
This module introduces the function :func:`call_tf` that allows JAX to call
TensorFlow functions.
For examples and details, see
https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax.
"""
from __future__ import annotations
from collections.abc import Sequence
import dataclasses
import functools
from typing import Any, Callable, Optional
from absl import logging
import jax
from jax import dlpack
from jax import dtypes
from jax import numpy as jnp
from jax import tree_util
from jax._src import ad_util
from jax._src import core
from jax._src import effects
from jax._src import util
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib.mlir.dialects import hlo
from jax.experimental.jax2tf import jax2tf as jax2tf_internal
from jax.interpreters import mlir
import numpy as np
import tensorflow as tf
# Rebind `map` and `zip` for this module to the `safe_map`/`safe_zip` helpers
# from `jax._src.util`; every use of these names below resolves to them rather
# than to the builtins.
map = util.safe_map
zip = util.safe_zip
# Alias (currently just `Any`) naming a TensorFlow concrete function.
TfConcreteFunction = Any
# The TensorFlow value type alias, taken from the jax2tf implementation module.
TfVal = jax2tf_internal.TfVal
# The platforms for which to use DLPack to avoid copying (only works on GPU
# and CPU at the moment, and only for Array). For CPU we don't need
# DLPack, if we are careful.
_DLPACK_PLATFORMS = ("gpu",)
class UnspecifiedOutputShapeDtype:
  """Marker type: the caller of `call_tf` did not pass `output_shape_dtype`.

  `call_tf` uses an instance of this class as the default value of its
  `output_shape_dtype` parameter and tests for it with `isinstance`.
  """
def call_tf(
    callable_tf: Callable,
    has_side_effects=True,
    ordered=False,
    output_shape_dtype=UnspecifiedOutputShapeDtype(),
    call_tf_graph=False,
) -> Callable:
  """Wraps a TensorFlow function so that JAX can call and differentiate it.

  ``callable_tf`` is invoked with TensorFlow-compatible arguments
  (numpy.ndarray, ``tf.Tensor`` or ``tf.Variable``, or pytrees of those) and
  must return results of the same kinds.

  When the returned callable is used inside a JAX staging context
  (:func:`jax.jit`, :func:`jax.pmap`, :func:`jax.xmap`, or a control-flow
  primitive), ``callable_tf`` is compiled with
  ``tf.function(callable_tf, jit_compile=True)`` and the resulting XLA
  computation is embedded in JAX's XLA computation. Outside a staging context
  it is called inline using TensorFlow eager mode.

  JAX's reverse-mode autodiff is supported: ``callable_tf`` is differentiated
  with ``tf.GradientTape``, so the gradient is TensorFlow-accurate, e.g., it
  respects custom gradients defined by the code in ``callable_tf``.

  For an example and more details see the
  `README
  <https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax>`_.

  Args:
    callable_tf: a TensorFlow Callable that can take a pytree of TensorFlow
      arguments.
    has_side_effects: if True then it ensures that instances of this primitive
      are not removed or replicated by JAX optimizations such as dead-code
      elimination.
    ordered: If true, calls are modeled as having ordered effects.
    output_shape_dtype: An optional declaration of the expected shape and dtype
      of the result of the called TensorFlow function. If given it will be used
      during JAX tracing to form the abstract values of the results of the
      `call_tf`. If not given then we form a `tf.Graph` for the called
      TensorFlow function and we use the TensorFlow-inferred shapes and types.
      Must be a pytree matching the structure of the nested structure returned
      from the TensorFlow function, containing objects with `.shape` and
      `.dtype` attributes, e.g., `jax.ShapeDtypeStruct` or `jax.Array`.
    call_tf_graph: EXPERIMENTAL, DO NOT USE. We may change the name in the
      future.

  Returns: a JAX callable that can be invoked with JAX pytree arguments, in
    op-by-op mode or in a staged context. This callable can be used with JAX's
    reverse-mode autodiff (:func:`jax.grad`).
  """

  @jax.custom_vjp
  def make_call(*args_jax):
    """We wrap it all in `make_call` so that we can attach custom VJP."""
    leaves_jax, args_treedef = tree_util.tree_flatten(args_jax)

    def canonical_arg(v):
      # Canonicalize the argument dtype; e.g., x32 if JAX is in 32-bit mode.
      # Values without a (truthy) `dtype` attribute go through numpy first.
      if not getattr(v, "dtype", None):
        v = np.asarray(v)
      dtype = dtypes.canonicalize_dtype(v.dtype)
      return v.astype(dtype) if dtype != v.dtype else v

    args_flat_jax = tuple(map(canonical_arg, leaves_jax))

    def make_tensorspec(a_jax):
      a_tf_dtype = jax2tf_internal._to_tf_dtype(a_jax.dtype)
      # Dimensions that are not constant are declared as unknown (None).
      a_tf_shape = [d if core.is_constant_dim(d) else None
                    for d in a_jax.shape]
      return tf.TensorSpec(a_tf_shape, a_tf_dtype)

    args_flat_sig_tf = tuple(map(make_tensorspec, args_flat_jax))

    if isinstance(output_shape_dtype, UnspecifiedOutputShapeDtype):
      output_avals = None
      output_shape_dtype_tree = None
    else:
      declared_flat, output_shape_dtype_tree = tree_util.tree_flatten(
          output_shape_dtype)
      output_avals = tuple(
          core.ShapedArray(st.shape, st.dtype) for st in declared_flat)

    # State shared with `callable_flat_tf` through `nonlocal`.
    res_treedef = None  # The pytree structure of the TF results.
    res_tf_flat = None  # The flat TF results, kept for error reporting.

    # This function runs at least once before `make_call` returns: either in
    # eager mode during jax2tf_call_tf, or in graph mode during
    # _get_concrete_function_tf().
    def callable_flat_tf(*args_tf_flat: TfVal) -> Sequence[TfVal]:
      nonlocal res_treedef, res_tf_flat
      res_tf = callable_tf(*args_treedef.unflatten(args_tf_flat))

      # b/279454591: a tf function with zero outputs returns a
      # `StatefulPartitionedCall` (stateful) or `PartitionedCall` (stateless)
      # op instead of tf.Tensors. Substitute the op's (empty) output list.
      if isinstance(res_tf, tf.Operation):
        assert res_tf.type in ("StatefulPartitionedCall", "PartitionedCall")
        t_out = res_tf.get_attr("Tout")
        # t_out should be an empty list.
        assert not t_out, (
            "The TF function returned an unexpected result, please check its"
            f" function body. res_tf = {res_tf}"
        )
        res_tf = t_out

      res_tf_flat, res_treedef_now = tree_util.tree_flatten(res_tf)
      assert res_treedef is None or res_treedef == res_treedef_now, (
          f"Subsequent calls had different results. Previous {res_treedef} and now {res_treedef_now}")
      res_treedef = res_treedef_now

      if output_avals is None:
        # Nothing was declared: check only convertibility of each result.
        expected_avals = (None,) * len(res_tf_flat)
      else:
        if res_treedef != output_shape_dtype_tree:
          raise ValueError(
              "The pytree of the TensorFlow function results does not match the "
              "pytree of the declared output_shape_dtype:\n"
              f"results pytree: {res_treedef}\noutput_shape_dtype tree: {output_shape_dtype_tree}")
        assert len(output_avals) == len(res_tf_flat)
        expected_avals = output_avals

      return [
          check_tf_result(i, r_tf, r_aval)
          for i, (r_tf, r_aval) in enumerate(zip(res_tf_flat, expected_avals))
      ]

    # Build the tf.function up front so its concrete functions can be cached.
    # It is not used in op-by-op execution mode.
    function_flat_tf = tf.function(
        callable_flat_tf, autograph=False, jit_compile=not call_tf_graph)

    res_jax_flat = call_tf_p.bind(
        *args_flat_jax,
        # The plain Python function is carried along so that the op-by-op
        # implementation can call it in TF eager mode.
        callable_flat_tf=callable_flat_tf,
        function_flat_tf=function_flat_tf,
        args_flat_sig_tf=args_flat_sig_tf,
        output_avals=output_avals,
        has_side_effects=has_side_effects,
        ordered=ordered,
        call_tf_graph=call_tf_graph,
    )

    # callable_flat_tf must have been called by now.
    assert res_treedef is not None
    return res_treedef.unflatten(res_jax_flat)

  # The forward and backward functions for the custom VJP.
  def make_call_vjp_fwd(*args_jax):
    # The residual is simply the primal arguments.
    primal_out = make_call(*args_jax)
    return primal_out, args_jax

  def make_call_vjp_bwd(residual_jax, ct_res_jax):
    args_jax = residual_jax  # residual is the primal argument

    def tf_vjp_fun(args_tf, ct_res_tf):
      """Invoke TF gradient."""

      def replace_non_float(arg_tf):
        # TF does not like us to watch non-float vars.
        if not (arg_tf.dtype.is_floating or arg_tf.dtype.is_complex):
          # When watched, this will be ignored. When used in results it will
          # result in a floating 0. gradient, which JAX will ignore (and
          # replace it with a float0)
          return tf.zeros((), dtype=tf.float32)
        return arg_tf

      def as_tensor(x):
        return x if x is None else tf.convert_to_tensor(x)

      watched_args_tf = tf.nest.map_structure(replace_non_float, args_tf)
      with tf.GradientTape(persistent=True) as tape:
        tape.watch(watched_args_tf)
        res = callable_tf(*args_tf)

      tf.nest.assert_same_structure(res, ct_res_tf)
      targets_tf = tf.nest.map_structure(replace_non_float, res)
      dres_darg = tape.gradient(
          targets_tf,
          sources=watched_args_tf,
          output_gradients=ct_res_tf,
          unconnected_gradients=tf.UnconnectedGradients.ZERO)

      dres_darg = tree_util.tree_map(as_tensor, dres_darg)
      tf.nest.assert_same_structure(dres_darg, args_tf)
      return dres_darg

    # The VJP function is itself a TF function, so it is called via call_tf.
    ct_args_jax = call_tf(tf_vjp_fun)(args_jax, ct_res_jax)

    def fix_float0(arg_jax, ct_arg_jax):
      # Produce the float0 cotangents that JAX expects wherever the dtype of
      # the TF cotangent differs from the tangent dtype of the primal.
      arg_dtype = dtypes.result_type(arg_jax)  # May be scalar
      ct_arg_dtype = core.primal_dtype_to_tangent_dtype(arg_dtype)
      if ct_arg_dtype != ct_arg_jax.dtype:
        zero_aval = core.ShapedArray(np.shape(arg_jax), ct_arg_dtype)
        return ad_util.zeros_like_aval(zero_aval)
      return ct_arg_jax

    return tree_util.tree_map(fix_float0, args_jax, ct_args_jax)

  make_call.defvjp(make_call_vjp_fwd, make_call_vjp_bwd)
  return util.wraps(callable_tf)(make_call)
def check_tf_result(idx: int, r_tf: TfVal, r_aval: core.ShapedArray | None) -> TfVal:
  """Validates one flat result of the called TensorFlow function.

  Checking here improves error reporting and prevents hard-to-diagnose errors
  downstream.

  Args:
    idx: the position of the result among the flat results; used only in the
      error message.
    r_tf: the value returned by the TensorFlow function.
    r_aval: the declared abstract value for this result, or None when no
      `output_shape_dtype` was declared.

  Returns:
    `r_tf` unchanged when `r_aval` is None or when the shape of `r_tf` cannot
    be obtained; otherwise the result of `tf.ensure_shape` on `r_tf` with the
    declared shape.

  Raises:
    ValueError: if `r_tf` is not convertible to JAX, or if its dtype or shape
      does not match the declared `r_aval`.
  """
  try:
    jax2tf_internal._tfval_to_tensor_jax_dtype(r_tf)
  except Exception as e:
    msg = ("The called TF function returns a result that is not "
           f"convertible to JAX: {r_tf}.")
    raise ValueError(msg) from e

  if r_aval is None:
    return r_tf

  # We convert to TF type, and canonicalize to 32-bit if necessary
  r_aval_dtype_tf = jax2tf_internal._to_tf_dtype(r_aval.dtype)
  # Checking shapes is trickier in presence of dynamic shapes. I wish we could
  # check at runtime that the returned shape matches the declared shape. I wish
  # that tf.ensure_shape did this, but it can only take shapes that contain None
  # not computed shapes. However, in eager mode we should be able to resolve
  # the declared shapes to constants and we get better checking.
  if tf.executing_eagerly():
    r_aval_shape_tf = jax2tf_internal._eval_shape(r_aval.shape)
  else:
    r_aval_shape_tf = jax2tf_internal._aval_to_tf_shape(r_aval)

  # We do as much checking as we can here, instead of relying on tf.ensure_shape
  # because the latter gives different errors in eager vs. compiled mode.
  # TODO(b/279454591): This strange error is from TF. An eager function is
  # supposed to return a TF value with a concrete shape, but sometimes does
  # not. We downgrade the exception to a warning and skip the shape check.
  # This case needs to be revisited on the TF side.
  try:
    _ = len(r_tf.shape)
  except ValueError as e:
    # The pieces are separated explicitly: adjacent string literals are
    # concatenated verbatim, without any whitespace between them.
    msg = (
        "The shape check test cannot be performed because the shape of the "
        "`r_tf` tensor cannot be obtained. "
        f"r_tf = {r_tf}, r_aval = {r_aval}. {e}"
    )
    logging.warning(msg)
    return r_tf

  if (r_tf.dtype != r_aval_dtype_tf or
      len(r_tf.shape) != len(r_aval_shape_tf) or
      any(r_aval_d is not None and r_tf_d is not None and r_aval_d != r_tf_d
          for r_tf_d, r_aval_d in zip(r_tf.shape, r_aval_shape_tf))):
    msg = ("The shapes or dtypes returned by the TensorFlow function "
           "do not match the declared output_shape_dtype:\n"
           f"Result[{idx}] is {r_tf.dtype}[{r_tf.shape}] vs. expected {r_aval_dtype_tf}[{r_aval_shape_tf}]")
    raise ValueError(msg)

  # At this point tf.ensure_shape does not do much, it should never throw an
  # error, albeit it may refine the shape a bit.
  return tf.ensure_shape(r_tf, r_aval_shape_tf)
# The JAX primitive bound by `call_tf`. It is marked as producing multiple
# results, so its rules below take and return flat sequences of values.
call_tf_p = core.Primitive("call_tf")
call_tf_p.multiple_results = True
# The impl will be used in op-by-op mode and calls callable_tf in TF eager mode.
def _call_tf_impl(*args_jax_flat, callable_flat_tf, **_):
  """Runs `callable_flat_tf` on the given JAX values in TF eager mode.

  The JAX arguments are converted to TF values, the function is called inside
  the `inside_call_tf` context, and every TF result is converted back to JAX.
  The remaining primitive parameters are ignored.

  Returns:
    a list with one JAX value per flat TF result.
  """

  def _arg_jax_to_tf(arg_jax):
    # On GPU we use dlpack to avoid copies of data to the host.
    use_dlpack = (
        isinstance(arg_jax, jax.Array) and
        list(arg_jax.devices())[0].platform in _DLPACK_PLATFORMS and
        arg_jax.dtype in dlpack.SUPPORTED_DTYPES)
    if use_dlpack:
      capsule = jax.dlpack.to_dlpack(arg_jax, take_ownership=False)
      return tf.experimental.dlpack.from_dlpack(capsule)
    # float0 values are passed to TF as zeros of the stand-in dtype.
    if getattr(arg_jax, 'dtype', None) == dtypes.float0:
      return tf.zeros(shape=arg_jax.shape,
                      dtype=jax2tf_internal._tf_np_dtype_for_float0)
    # The following avoids copies to the host on CPU, always for Array
    # and even for ndarray if they are sufficiently aligned.
    # TODO(necula): on TPU this copies to the host!
    return tf.constant(np.asarray(arg_jax))

  def _res_tf_to_jax(res_tf: TfVal):
    res_tf, _ = jax2tf_internal._tfval_to_tensor_jax_dtype(res_tf)
    if isinstance(res_tf, tf.Tensor) and res_tf.dtype in dlpack.SUPPORTED_DTYPES:
      device_spec = tf.DeviceSpec.from_string(res_tf.backing_device)
      if device_spec.device_type.lower() in _DLPACK_PLATFORMS:
        capsule = tf.experimental.dlpack.to_dlpack(res_tf)
        return jax.dlpack.from_dlpack(capsule)

    # When working with a bfloat16 scalar tf.Tensor, np.asarray() can fail.
    # To handle this special case, we create a numpy copy.
    is_bf16_scalar = (res_tf.shape == tf.TensorShape([]) and
                      res_tf.dtype == tf.bfloat16)
    if is_bf16_scalar:
      host_value = jnp.array(res_tf.numpy())
    else:
      host_value = np.asarray(res_tf)
    return jax.device_put(host_value)

  args_tf_flat = tuple(_arg_jax_to_tf(a) for a in args_jax_flat)
  with jax2tf_internal.inside_call_tf():
    # Call in TF eager mode
    res_tf_flat = callable_flat_tf(*args_tf_flat)

  return [_res_tf_to_jax(r) for r in res_tf_flat]


call_tf_p.def_impl(_call_tf_impl)
@functools.lru_cache(maxsize=128)
def _get_concrete_function_tf(function_flat_tf, args_flat_sig_tf):  # -> tf.ConcreteFunction
  """Returns the concrete function of `function_flat_tf` for a signature.

  Results are memoized (up to 128 entries) on the pair of arguments. The
  concrete function is requested inside the `inside_call_tf` context.
  """
  with jax2tf_internal.inside_call_tf():
    concrete_function = function_flat_tf.get_concrete_function(
        *args_flat_sig_tf)
  return concrete_function
# Mark the effectful instances of call_tf
@dataclasses.dataclass(frozen=True)
class CallTfEffect(effects.Effect):
  """The effect attached to `call_tf` calls that have unordered side effects."""

  def __str__(self):
    return "CallTfEffect"
# The instance that `_call_tf_abstract_eval` adds to the effects of a call made
# with `has_side_effects=True` and `ordered=False`.
call_tf_effect = CallTfEffect()
# Register the effect type with the lowerable, control-flow, remat and
# custom-derivatives effect sets of `jax._src.effects`.
effects.lowerable_effects.add_type(CallTfEffect)
effects.control_flow_allowed_effects.add_type(CallTfEffect)
effects.remat_allowed_effects.add_type(CallTfEffect)
effects.custom_derivatives_allowed_effects.add_type(CallTfEffect)
# Mark the effectful instances of call_tf that must also be ordered.
# Declared as a frozen dataclass, like `CallTfEffect`, so that instances
# compare and hash by value instead of by identity.
@dataclasses.dataclass(frozen=True)
class CallTfOrderedEffect(effects.Effect):
  """The effect attached to `call_tf` calls made with `ordered=True`."""

  def __str__(self):
    return "CallTfOrderedEffect"
# The instance that `_call_tf_abstract_eval` adds to the effects of a call made
# with `ordered=True`; it is also the key used for the input/output tokens in
# `emit_tf_embedded_graph_custom_call`.
call_tf_ordered_effect = CallTfOrderedEffect()
# Register the effect type with the same effect sets as `CallTfEffect`, and
# additionally with the ordered and shardable-ordered effect sets.
effects.lowerable_effects.add_type(CallTfOrderedEffect)
effects.control_flow_allowed_effects.add_type(CallTfOrderedEffect)
effects.remat_allowed_effects.add_type(CallTfOrderedEffect)
effects.custom_derivatives_allowed_effects.add_type(CallTfOrderedEffect)
effects.ordered_effects.add_type(CallTfOrderedEffect)
effects.shardable_ordered_effects.add_type(CallTfOrderedEffect)
def _call_tf_abstract_eval(
    *args_flat_avals,
    function_flat_tf,
    args_flat_sig_tf,
    has_side_effects,
    ordered,
    output_avals,
    call_tf_graph,
    **__,
):
  """Computes the output abstract values and the effects of a `call_tf`.

  Called only when we form a Jaxpr, i.e., under jit, scan, etc. The
  `args_flat_avals` and `call_tf_graph` parameters are not used here.

  Returns:
    a pair of (output avals, set of effects). The avals are `()` when the TF
    function has no outputs, the declared `output_avals` when given, and
    otherwise are formed from the shapes and dtypes inferred by TensorFlow.

  Raises:
    ValueError: if no `output_avals` were declared and TensorFlow infers an
      output shape that is not fully known.
  """
  if ordered:
    call_effects = {call_tf_ordered_effect}
  elif has_side_effects:
    call_effects = {call_tf_effect}
  else:
    call_effects = set()

  # If no output_avals is given, then we ask TF to infer the output shapes.
  # The concrete function is requested even when output_avals is given,
  # because doing so ensures that callable_flat_tf is called. Since
  # _get_concrete_function_tf is cached, the extra calls are cheap.
  concrete_function_flat_tf = _get_concrete_function_tf(function_flat_tf,
                                                        args_flat_sig_tf)

  # In the case that the tf.function has no return value
  if len(concrete_function_flat_tf.outputs) == 0:
    return (), call_effects

  if output_avals is not None:
    return output_avals, call_effects

  def is_fully_known_shape(s):
    return s.rank is not None and all(d is not None for d in s)

  tf_output_shapes = concrete_function_flat_tf.output_shapes
  if not all(is_fully_known_shape(s) for s in tf_output_shapes):
    msg = ("call_tf cannot call functions whose output has dynamic shape. "
           f"Found output shapes: {tf_output_shapes}. "
           "Consider using the `output_shape_dtype` argument to call_tf. "
           "\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf"
           " for a discussion.")
    raise ValueError(msg)

  avals_from_tf = tuple(
      # We convert to JAX type, and canonicalize to 32-bit if necessary
      core.ShapedArray(shape, jax2tf_internal._to_jax_dtype(dtype))
      for dtype, shape in zip(concrete_function_flat_tf.output_dtypes,
                              tf_output_shapes))
  return avals_from_tf, call_effects


call_tf_p.def_effectful_abstract_eval(_call_tf_abstract_eval)
def _call_tf_lowering(
    ctx: mlir.LoweringRuleContext,
    *args_op,
    platform,
    function_flat_tf,
    args_flat_sig_tf,
    has_side_effects,
    ordered,
    call_tf_graph,
    output_avals,
    **_,
):
  """Lowers a `call_tf` primitive for the given platform.

  With `call_tf_graph=True` the lowering is delegated to
  `emit_tf_embedded_graph_custom_call`. Otherwise the TF function is compiled
  to HLO, merged into the current MLIR module, and called from there.

  Args:
    ctx: the lowering rule context.
    *args_op: the operands of the primitive.
    platform: one of "cpu", "tpu" or "cuda".
    function_flat_tf: the `tf.function` wrapping the flat TF callable.
    args_flat_sig_tf: the flat TF signature of the arguments.
    has_side_effects: forwarded to the custom call when `call_tf_graph`.
    ordered: whether the call has ordered effects; supported only together
      with `call_tf_graph=True`.
    call_tf_graph: whether to emit a custom call referencing the TF graph.
    output_avals: forwarded to the custom call when `call_tf_graph`.

  Returns:
    the MLIR values holding the flat results.

  Raises:
    ValueError: for an unsupported `platform`, if the TF function cannot be
      compiled, or if the compiled function has a dynamic output shape.
    NotImplementedError: if `ordered` is set without `call_tf_graph`.
  """
  # We use the same TF lowering device as for the embedding JAX computation.
  # One example when this is needed is when the code refers to variables on one
  # device. Or, for sharding annotations (only supported on TPU).
  if platform in ["cpu", "tpu"]:
    tf_platform = platform.upper()
  elif platform == "cuda":
    tf_platform = "GPU"
  else:
    # This must be an f-string for the platform name to appear in the message.
    raise ValueError(f"platform {platform} not supported")

  concrete_function_flat_tf = _get_concrete_function_tf(function_flat_tf, args_flat_sig_tf)

  captured_inputs = []
  if concrete_function_flat_tf.captured_inputs:
    # The function uses either captured variables or tensors.
    msg = (
        "call_tf works best with a TensorFlow function that does not capture "
        "variables or tensors from the context. "
        "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion. "
        f"The following captures were found {concrete_function_flat_tf.captured_inputs}")
    logging.warning(msg)
    for inp in concrete_function_flat_tf.captured_inputs:
      if inp.dtype == tf.resource:  # A variable; lookup by handle
        inp_vars = [v for v in concrete_function_flat_tf.variables if inp is v.handle]
        assert len(inp_vars) == 1, f"Found {inp_vars}"
        captured_inputs.append(inp_vars[0])
      else:
        captured_inputs.append(inp)

  # The captured values are embedded as constants and passed as extra operands.
  captured_ops = tuple(
      mlir.ir_constant(np.asarray(inp))
      for inp in captured_inputs
  )

  if call_tf_graph:
    with jax2tf_internal.inside_call_tf():
      return emit_tf_embedded_graph_custom_call(
          ctx,
          concrete_function_flat_tf,
          tuple(args_op) + captured_ops,
          has_side_effects,
          ordered,
          output_avals,
      )

  def convert_to_spec(x):
    if isinstance(x, tf.TensorSpec):
      return x
    else:
      return tf.TensorSpec.from_tensor(x)

  args_tf_flat = [convert_to_spec(a) for a in args_flat_sig_tf]

  with jax2tf_internal.inside_call_tf():
    # When the TF computation uses variables on a particular device, we must
    # get_compiler_ir for that exact device.
    tf_device_name = f"/device:{tf_platform}:0"
    try:
      func_tf_hlo = function_flat_tf.experimental_get_compiler_ir(*args_tf_flat)(
          stage="hlo_serialized", device_name=tf_device_name)
    except Exception as e:
      msg = ("Error compiling TensorFlow function (see below for the caught exception)." +
             "\ncall_tf can used " +
             "in a staged context (under jax.jit, lax.scan, etc.) only with " +
             "compilable functions with static output shapes.\n" +
             "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion." +
             "\n\nCaught TensorFlow exception: " + str(e))
      raise ValueError(msg) from e

  xla_comp = xla_client.XlaComputation(func_tf_hlo)

  # Canonicalize the results; e.g., makes them x32 if JAX is in 32-bit mode
  def canonical_res_aval(res_shape: xla_client.Shape) -> core.ShapedArray:
    if not res_shape.is_static():
      msg = ("Compiled TensorFlow function has dynamic output shape " +
             f"{res_shape}. call_tf can used " +
             "in a staged context (under jax.jit, lax.scan, etc.) only with " +
             "compilable functions with static output shapes. " +
             "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion.")
      raise ValueError(msg)

    res_dtype = res_shape.numpy_dtype()
    jax_res_dtype = dtypes.canonicalize_dtype(res_dtype)
    return core.ShapedArray(res_shape.dimensions(), jax_res_dtype)

  result_shape = xla_comp.program_shape().result_shape()
  if not result_shape.is_tuple():
    # TF does not wrap singletons as tuples, but JAX expects tuples because
    # call_tf is a multiple_results primitive.
    result_shapes = (result_shape,)
  else:
    result_shapes = result_shape.tuple_shapes()  # type: ignore

  result_avals = tuple(map(canonical_res_aval, result_shapes))  # type: ignore

  # Merge the compiled TF computation into the current module and call it.
  submodule = mlir.xla_computation_to_mlir_module(xla_comp)
  symtab = ir.SymbolTable(submodule.operation)
  callee_result_types = symtab["main"].type.results
  fn = mlir.merge_mlir_modules(ctx.module_context.module,
                               f"call_tf_{function_flat_tf.name}",
                               submodule)
  call = func_dialect.CallOp(callee_result_types,
                             ir.FlatSymbolRefAttr.get(fn),
                             tuple(args_op) + captured_ops)
  if result_shape.is_tuple():
    flat_results = [hlo.get_tuple_element(call, mlir.i32_attr(i))
                    for i in range(len(result_shapes))]
  else:
    flat_results = call.results

  if ordered:
    raise NotImplementedError(
        "ordered=True is not supported in the jitted context without"
        " `call_tf_graph=True`"
    )

  # Convert each result whose compiled dtype differs from the canonical JAX
  # dtype.
  outputs = []
  for op, res_aval, res_shape in zip(flat_results, result_avals,
                                     result_shapes):
    if res_aval.dtype != res_shape.numpy_dtype():
      op = hlo.ConvertOp(mlir.aval_to_ir_type(res_aval), op).result
    outputs.append(op)
  return outputs
def _register_call_lowering(platform):
  """Registers `_call_tf_lowering`, specialized to `platform`, for `call_tf_p`."""
  lowering_rule = functools.partial(_call_tf_lowering, platform=platform)
  mlir.register_lowering(call_tf_p, lowering_rule, platform=platform)


# One lowering rule per supported platform.
for platform in ("cpu", "cuda", "tpu"):
  _register_call_lowering(platform)
# Support the call_tf under jax2tf.convert in eager mode
def _jax2tf_call_tf(*args: TfVal,
                    callable_flat_tf: Callable,
                    **_) -> TfVal:
  """Calls `callable_flat_tf` on the TF arguments, inside `inside_call_tf`.

  All other primitive parameters are ignored.
  """
  with jax2tf_internal.inside_call_tf():
    return callable_flat_tf(*args)


jax2tf_internal.tf_impl[call_tf_p] = _jax2tf_call_tf
def emit_tf_embedded_graph_custom_call(
    ctx: mlir.LoweringRuleContext,
    concrete_function_flat_tf,
    operands: Sequence[ir.Value],
    has_side_effects,
    ordered,
    output_avals,
):
  """Emits a custom call referencing a tf.Graph embedding of the TF function.

  The information about the called function is stored in the
  `tf.backend_config` attribute of the custom call. It contains:
  (1) `called_index`: the index of the called function in the thread-local
  list of concrete functions (the XLACallModule `function_list` attribute).
  (2) `has_token_input_output`: whether the call is ordered and therefore
  takes and returns a token.

  The `output_avals` parameter is not used; the result types are derived from
  `ctx.avals_out`.

  Returns:
    the list of results of the custom call, without the token.

  Raises:
    ValueError: if there is no thread-local concrete function list.
  """
  function_list = jax2tf_internal.get_thread_local_state_call_tf_concrete_function_list()
  if function_list is None:
    raise ValueError(
        "call_tf_graph=True only support exporting by jax2tf.convert currently."
    )
  # TODO(necula): It is dangerous to modify global state when lowering because
  # there are a number of lowering caches that only cache the StableHLO.
  # See call_tf_test.py:test_multi_platform_call_tf_graph.
  called_index = add_to_call_tf_concrete_function_list(
      concrete_function_flat_tf, function_list)
  tf_backend_config = {
      "has_token_input_output": ir.BoolAttr.get(ordered),
      "called_index": mlir.i64_attr(called_index),
  }

  result_avals = () if ctx.avals_out is None else ctx.avals_out
  operands = list(operands)
  result_types = list(
      util.flatten([mlir.aval_to_ir_types(aval) for aval in result_avals])
  )
  if ordered:
    # Ordered calls are threaded through a token: first operand, first result.
    token_in = ctx.tokens_in.get(call_tf_ordered_effect)[0]
    operands = [token_in, *operands]
    result_types = [mlir.token_type()[0], *result_types]

  custom_call = hlo.CustomCallOp(
      result_types,
      operands,
      call_target_name=ir.StringAttr.get("tf.call_tf_function"),
      has_side_effect=ir.BoolAttr.get(has_side_effects),
      api_version=mlir.i32_attr(2),
      called_computations=ir.ArrayAttr.get([]),
      backend_config=ir.StringAttr.get(""),
  )
  # Store TF metadata in unregistered attribute
  custom_call.attributes["tf.backend_config"] = ir.DictAttr.get(
      tf_backend_config
  )

  results = list(custom_call.results)
  if ordered:
    token_out = results.pop(0)
    ctx.set_tokens_out(mlir.TokenSet({call_tf_ordered_effect: (token_out,)}))
  return results
def add_to_call_tf_concrete_function_list(concrete_tf_fn: Any, call_tf_concrete_function_list: list[Any]) -> int:
  """Returns the index of `concrete_tf_fn` in the list, appending it if absent.

  The list is modified in place only when the function is not already present.
  """
  for position, registered in enumerate(call_tf_concrete_function_list):
    if registered is concrete_tf_fn or registered == concrete_tf_fn:
      return position
  call_tf_concrete_function_list.append(concrete_tf_fn)
  return len(call_tf_concrete_function_list) - 1