-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
jax2tf.py
3440 lines (2798 loc) · 133 KB
/
jax2tf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides JAX and TensorFlow interoperation APIs."""
from collections.abc import Iterable, Sequence
from functools import partial
import contextlib
import math
import operator
import os
import re
import threading
from typing import Any, Callable, Optional, Union, cast
import warnings
from absl import logging
import numpy as np
import jax
from jax import lax
from jax import config
from jax import custom_derivatives
from jax import random
from jax import numpy as jnp
from jax import tree_util
from jax import sharding
from jax.experimental import maps
from jax.experimental.jax2tf import shape_poly
from jax.experimental.jax2tf import impl_no_xla
from jax.experimental.jax2tf import jax_export
from jax.interpreters import xla
from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import api
from jax._src import api_util
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import op_shardings
from jax._src import sharding_impls
from jax._src import pjit
from jax._src import prng
from jax._src import random as random_internal
from jax._src import source_info_util
from jax._src import util
from jax._src.interpreters import ad
from jax._src.lax import control_flow as lax_control_flow
from jax._src.lax import lax as lax_internal
from jax._src.lax import linalg as lax_linalg
from jax._src.lax import slicing as lax_slicing
from jax._src.lax import windowed_reductions as lax_windowed_reductions
from jax._src.lib import xla_client
from jax._src.numpy.ufuncs import logaddexp
import tensorflow as tf # type: ignore[import]
# These don't have public equivalents.
# pylint: disable=g-direct-tensorflow-import
from tensorflow.compiler.tf2xla.python import xla as tfxla # type: ignore[import]
from tensorflow.compiler.xla import xla_data_pb2 # type: ignore[import]
from tensorflow.core.framework import attr_value_pb2 # type: ignore[import]
try:
from tensorflow.python.compiler.xla.experimental import xla_sharding # type: ignore[import]
except ModuleNotFoundError:
# This can be removed when TF 2.10 support is no longer needed.
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding # type: ignore[import]
from tensorflow.python.framework import ops as tf_ops # type: ignore[import]
from tensorflow.python.eager import context as tf_context # type: ignore[import]
# pylint: enable=g-direct-tensorflow-import
# Public re-exports and module-level aliases.
NameStack = source_info_util.NameStack
PolyShape = shape_poly.PolyShape
DType = Any  # loose alias: dtypes here may be TF, NumPy, or JAX dtypes
DisabledSafetyCheck = jax_export.DisabledSafetyCheck

# A temporary internal flag, to enable the wrapping of jax.jit functions
# with tf.function(jit_compile=True). See #7389. This change has triggered a
# number of failures in TF. We keep this until we are confident that it does
# not create problems.
# TODO(b/207464757): figure out why this change breaks test
_WRAP_JAX_JIT_WITH_TF_FUNCTION = False
# The scope name need to be a valid TensorFlow name. See
# https://github.com/tensorflow/tensorflow/blob/r2.3/tensorflow/core/framework/node_def_util.cc#L731
_VALID_SCOPE_REGEX = re.compile("^[A-Za-z0-9.][A-Za-z0-9_.\\/>-]*$")
_INVALID_SCOPE_CHAR = re.compile("[^A-Za-z0-9_.\\/-]")

# Intentionally shadow the builtins with JAX's length-checked variants.
map = util.safe_map
zip = util.safe_zip


def _sanitize_scope_name(name):
  """Rewrites `name` so that it is a valid TF scope name.

  Replaces every disallowed character with an underscore, then prepends a
  dot if the result still does not match the TF scope-name grammar (e.g.
  because its first character is not alphanumeric or a dot).
  """
  cleaned = _INVALID_SCOPE_CHAR.sub("_", name)
  if _VALID_SCOPE_REGEX.match(cleaned):
    return cleaned
  return f".{cleaned}"
# A value suitable in a TF tracing context: tf.Tensor, tf.Variable,
# or Python scalar or numpy.ndarray. (A tf.EagerTensor is a tf.Tensor.)
TfVal = Any
PrecisionType = int  # Enum xla_data.PrecisionConfig.Precision


def _is_tfval(v: TfVal) -> bool:
  """Returns True if `v` is usable as a value in a TF tracing context.

  A value qualifies if it is already a tf.Tensor/tf.Variable, or if
  tf.constant can convert it (Python scalars, numpy arrays, ...).
  """
  if isinstance(v, (tf.Tensor, tf.Variable)):
    return True
  try:
    # Include all convertible types, even if not supported on accelerators.
    with tf.device("CPU"):
      tf.constant(v)
    return True
  except Exception:
    # Narrowed from a bare `except:` so that KeyboardInterrupt/SystemExit
    # are not swallowed while merely probing convertibility.
    return False
class _DefaultNativeSerialization:
  """Sentinel type for the default of the `native_serialization` parameter."""
  pass


# Sentinel instance: lets `convert` distinguish "caller did not pass
# native_serialization" from an explicit True/False.
DEFAULT_NATIVE_SERIALIZATION = _DefaultNativeSerialization()
# The implementation rules for primitives. The rule will be called with the
# arguments (TfVal) and must return TfVal (or a sequence thereof,
# if primitive.multiple_results). The exception are primarily the
# control-flow primitives.
tf_impl: dict[core.Primitive, Callable[..., Any]] = {}

# Some primitive implementation rules need the abstract values of arguments
# and the results. This is the case for the primitives implemented using
# _convert_jax_impl and those that need to adjust the shape of the outputs
# due to missing TF shape inference rules for TFXLA ops. The rules for these
# primitives should be added to `tf_impl_with_avals`.
# The abstract value are passed to the implementation as two special kwargs
# `_in_avals` (a tuple of core.ShapedArray) and `_out_aval` (a
# core.ShapedArray, or a tuple thereof when primitive.multiple_results).
tf_impl_with_avals: dict[core.Primitive, Callable[..., Any]] = {}

# XLA is not linked in all environments when converting a primitive. If this is
# the case, we first search for implementation rules for primitives in the
# following map. These implementations are workarounds, making use of TF ops
# that do work when XLA is not linked in.
tf_impl_no_xla = impl_no_xla.tf_impl_no_xla

# In order to ensure that JAX picks up the proper user-frame for source
# locations we will register the TensorFlow source path as an internal
# path with source_info_util. The typical stack when a JAX primitive
# conversion happens is:
#   jax2tf.process_primitive (top of stack)
#   jax tracing machinery ...
#   tf.custom_gradient machinery ...
#   jax2tf.converted_fun
#   tf function machinery ...
#   user code invokes the converted function on TF tensors
#
# We need to skip over not only JAX internal frames, but TF internal frames
# also.
# We register the TensorFlow source path lazily
_has_registered_tf_source_path = False
class _ThreadLocalState(threading.local):
  """Per-thread mutable state used by the jax2tf conversion machinery.

  A single module-level instance (`_thread_local_state`, below) is read and
  written by the conversion code; `threading.local` keeps each thread's
  values independent.
  """

  def __init__(self):
    # XLA is not linked in all environments; when converting a primitive, if this
    # variable is disabled, we try harder to use only standard TF ops if they are
    # applicable to the concrete use case; if the resulting conversion path ends up
    # requiring a TFXLA operation, an exception is thrown instead.
    self.enable_xla = True

    # Keep track if we are inside a call_tf. In that context we disable the
    # safety check that we are not inside JAX transformations.
    self.inside_call_tf = False

    # Maps dimension variables to TF expressions, for non-native lowering
    self.shape_env: Sequence[tuple[str, TfVal]] = ()

    # Whether to actually include XLA op metadata in the generated TF ops
    # TODO(b/189306134): implement support for XLA metadata
    self.include_xla_op_metadata = False

    # A cache for the tf.convert_to_tensor for constants. We try to preserve
    # sharing for constants, to enable tf.Graph to take advantage of it.
    # See https://github.com/google/jax/issues/7992.
    self.constant_cache = None  # None means that we don't use a cache. We
    # may be outside a conversion scope.

    # A cache for the outside tf name_scope when the converted
    # function is running. We will add this as the prefix to the generated tf op
    # name. For example, the tf op name will be like
    # "{tf_outer_name_scope}/JAX_NAME_STACKS"
    self.tf_outer_name_scope = ""

    # A dict collecting all tf concrete_functions called by stablehlo.custom_call
    # This is used only by native serialization (unlike all the other
    # thread-local state).
    self.call_tf_concrete_function_list: Optional[list[Any]] = None


_thread_local_state = _ThreadLocalState()
def _get_current_name_stack() -> Union[NameStack, str]:
  """Returns the name stack currently tracked by JAX's source-info machinery."""
  current_stack = source_info_util.current_name_stack()
  return current_stack
@contextlib.contextmanager
def inside_call_tf():
  """Context manager marking the current thread as being inside a call_tf.

  While active, the thread-local `inside_call_tf` flag is True; the previous
  value is restored on exit so nesting behaves correctly.
  """
  saved_flag = _thread_local_state.inside_call_tf
  _thread_local_state.inside_call_tf = True
  try:
    yield
  finally:
    # Restore whatever was in effect before entry (supports nesting).
    _thread_local_state.inside_call_tf = saved_flag
def get_thread_local_state_call_tf_concrete_function_list() -> (
    Optional[list[Any]]
):
  """Returns this thread's list of TF concrete functions collected by call_tf.

  May be None when no native-serialization conversion is in progress.
  """
  current_list = _thread_local_state.call_tf_concrete_function_list
  return current_list
@partial(api_util.api_hook, tag="jax2tf_convert")
def convert(fun_jax: Callable,
            *,
            polymorphic_shapes: Optional[str] = None,
            with_gradient: bool = True,
            enable_xla: bool = True,
            native_serialization: Union[bool, _DefaultNativeSerialization] = DEFAULT_NATIVE_SERIALIZATION,
            native_serialization_platforms: Sequence[str] = (),
            native_serialization_disabled_checks: Sequence[DisabledSafetyCheck] = (),
            ) -> Callable:
  """Allows calling a JAX function from a TensorFlow program.

  See
  [README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md)
  for more details about usage and common problems.

  Args:
    fun_jax: target JAX function to be called. Its arguments and return value
      should be JAX arrays, or nested standard Python containers
      (tuple/list/dict) thereof (pytrees).
    polymorphic_shapes: Specifies input shapes to be treated polymorphically
      during lowering.

      .. warning:: The shape-polymorphic lowering is an experimental feature.
        It is meant to be sound, but it is known to reject some JAX programs
        that are shape polymorphic. The details of this feature can change.

      It should be `None` (all arguments are monomorphic), a single PolyShape
      or string (applies to all arguments), or a tuple/list of the same length
      as the function arguments. For each argument the shape specification
      should be `None` (monomorphic argument), or a Python object with the
      same pytree structure as the argument.
      See [how optional parameters are matched to
      arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).

      A shape specification for an array argument should be an object
      `PolyShape(dim0, dim1, ..., dimn)`
      where each `dim` is a dimension specification: a positive integer denoting
      a monomorphic dimension of the given size, or a string denoting a
      dimension variable assumed to range over non-zero dimension sizes, or
      the special placeholder string "_" denoting a monomorphic dimension
      whose size is given by the actual argument. As a shortcut, an Ellipsis
      suffix in the list of dimension specifications stands for a list of "_"
      placeholders.

      For convenience, a shape specification can also be given as a string
      representation, e.g.: "batch, ...", "batch, height, width, _", possibly
      with surrounding parentheses: "(batch, ...)".

      The lowering fails if it cannot ensure that it would produce the same
      sequence of TF ops for any non-zero values of the dimension variables.

      polymorphic_shapes are only supported for positional arguments; shape
      polymorphism is not supported for keyword arguments.

      See [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
      for more details.

    with_gradient: if set (default), add a tf.custom_gradient to the lowered
      function, by converting the ``jax.vjp(fun)``. This means that reverse-mode
      TensorFlow AD is supported for the output TensorFlow function, and the
      value of the gradient will be JAX-accurate.
    enable_xla: if set (default), use the simplest conversion
      and use XLA TF ops when necessary. These ops are known to create issues
      for the TFLite and TFjs converters. For those cases, unset this parameter
      so the lowering tries harder to use non-XLA TF ops to lower the
      function and aborts if this is not possible. Cannot be set to `False`
      when using `native_serialization`.
    native_serialization: serialize the JAX function natively to
      StableHLO with compatibility guarantees. This makes it easier to have
      confidence that the code executed when calling this function from
      TensorFlow is exactly the same as JAX would run natively.
      The DEFAULT_NATIVE_SERIALIZATION value defers to `False` if `enable_xla`
      is set to `False` or to the configuration flag
      `--jax2tf_default_native_serialization` otherwise.
      Native serialization cannot be used with `enable_xla=False`.
    native_serialization_platforms: In conjunction with
      `native_serialization`, specify the platform(s)
      for which to lower the code. Must be a tuple of
      strings, including a subset of: 'cpu', 'cuda', 'rocm', 'tpu'.
      The default (empty tuple), specifies the JAX default
      backend on the machine where the lowering is done.
    native_serialization_disabled_checks: In conjunction with
      `native_serialization`, disable the specified safety checks.
      See docstring of `DisabledSafetyCheck`.

  Returns:
    A version of `fun_jax` that expects TfVals as arguments (or
    tuple/lists/dicts thereof), and returns TfVals as outputs, and uses
    only TensorFlow ops and thus can be called from a TensorFlow program.
  """
  if native_serialization is DEFAULT_NATIVE_SERIALIZATION:
    # Resolve the sentinel: native serialization is incompatible with
    # enable_xla=False; otherwise follow the configuration flag.
    if not enable_xla:
      native_serialization = False
    else:
      native_serialization = config.jax2tf_default_native_serialization

  if native_serialization and not enable_xla:
    raise ValueError(
        "native_serialization is not supported with enable_xla=False")

  if native_serialization_platforms:
    if not native_serialization:
      warnings.warn(
          "using native_serialization_platforms without native_serialization. "
          "The parameter will have no effect, since the same code is serialized "
          "for all platforms without native_serialization.")

    # NOTE(review): the docstring above mentions 'cuda'/'rocm' while this
    # validation accepts 'gpu' — confirm which set is intended.
    if (not isinstance(native_serialization_platforms, (list, tuple)) or
        not all(p in ["tpu", "cpu", "gpu"] for p in native_serialization_platforms)):
      raise ValueError(
          "native_serialization_platforms must be a sequence "
          "containing a subset of {'cpu', 'gpu', 'tpu'}. "
          f"Got: {native_serialization_platforms}")
    native_serialization_platforms = tuple(native_serialization_platforms)
    if len(native_serialization_platforms) > 1:
      raise NotImplementedError(
          "native_serialization_platforms is not yet implemented for multiple platforms")

  api.check_callable(fun_jax)

  def converted_fun_tf(*args_tf: TfVal, **kwargs_tf: TfVal) -> TfVal:
    # The TF-callable wrapper returned to the user.

    # TODO: is there a better way to check if we are inside a transformation?
    if not core.trace_state_clean() and not _thread_local_state.inside_call_tf:
      # It is Ok to nest convert when we are inside a call_tf
      # NOTE(review): the concatenation below lacks a space before "Trace state".
      raise ValueError(
          "convert must be used outside all JAX transformations." +
          f"Trace state: {core.thread_local_state.trace_state.trace_stack}")

    global _has_registered_tf_source_path
    if not _has_registered_tf_source_path:
      # Register the TF source tree as internal frames exactly once, so JAX
      # error messages point at user code rather than TF internals.
      source_info_util.register_exclusion(os.path.dirname(tf.__file__))
      _has_registered_tf_source_path = True

    def shape_and_dtype_tf(a: TfVal) -> tuple[Sequence[Optional[int]], DType]:
      # The shape and JAX dtype for a TF argument
      tf_arg_shape = np.shape(a)
      # Fix the shape for TF1
      tf_arg_shape = tuple(d.value if isinstance(d, tf.compat.v1.Dimension) else d for d in tf_arg_shape)
      _, a_jax_dtype = _tfval_to_tensor_jax_dtype(a)
      return tf_arg_shape, a_jax_dtype

    args_specs = jax_export.poly_specs(args_tf,
                                       polymorphic_shapes=polymorphic_shapes,
                                       get_shape_and_dtype=shape_and_dtype_tf)
    # The polymorphic_shapes argument refers to positional arguments only.
    # We assume None for the kwargs.
    kwargs_specs = jax_export.poly_specs(kwargs_tf,
                                         polymorphic_shapes=None,
                                         get_shape_and_dtype=shape_and_dtype_tf)
    combined_args_tf = (args_tf, kwargs_tf)
    args_flat_tf: Sequence[TfVal]
    args_flat_tf, args_kwargs_tree = tree_util.tree_flatten(combined_args_tf)
    # Validate, cast, and name each flat argument.
    args_flat_tf = tuple(
        map(preprocess_arg_tf, range(len(args_flat_tf)), args_flat_tf))

    impl: SerializationImpl
    if native_serialization:
      impl = NativeSerializationImpl(
          fun_jax,
          args_specs=args_specs, kwargs_specs=kwargs_specs,
          native_serialization_platforms=native_serialization_platforms,
          native_serialization_disabled_checks=native_serialization_disabled_checks)
    else:
      impl = GraphSerializationImpl(
          fun_jax,
          args_specs=args_specs, kwargs_specs=kwargs_specs,
          args_flat_tf=args_flat_tf,
          enable_xla=enable_xla)
    try:
      impl.before_conversion()

      outs_tree: tree_util.PyTreeDef = None  # type: ignore
      if with_gradient:
        @tf.custom_gradient
        def converted_fun_flat_with_custom_gradient_tf(*args_flat_tf: TfVal) -> TfVal:
          nonlocal outs_tree
          outs_tf, outs_avals, outs_tree = impl.run_fun_tf(args_flat_tf)
          return (tuple(outs_tf),
                  _make_custom_gradient_fn_tf(
                      impl=impl,
                      args_tf=args_flat_tf,
                      outs_avals=outs_avals,
                      outs_tf=outs_tf))

        outs_flat_tf = converted_fun_flat_with_custom_gradient_tf(*args_flat_tf)
      else:
        outs_tf, _, outs_tree = impl.run_fun_tf(args_flat_tf)
        message = ("The jax2tf-converted function does not support gradients. "
                   "Use `with_gradient` parameter to enable gradients")
        # We use PreventGradient, which is propagated through a SavedModel.
        outs_flat_tf = [
            tf.raw_ops.PreventGradient(input=o, message=message)
            for o in outs_tf
        ]
    finally:
      impl.after_conversion()

    outs_flat_tf = [tf.identity(x, "jax2tf_out") for x in outs_flat_tf]
    out_tf = tree_util.tree_unflatten(outs_tree, outs_flat_tf)
    return out_tf

  return converted_fun_tf
class SerializationImpl:
  """Implementation details for jax2tf serialization.

  Abstract superclass for subclassing. Subclasses implement either native
  (StableHLO) or graph serialization; `convert` drives them via
  `before_conversion` / `run_fun_tf` / `run_vjp_fun_tf` / `after_conversion`.
  """

  def before_conversion(self):
    """Called in the resulting TF function, before any other method.

    Useful to set any global context."""
    raise NotImplementedError

  def after_conversion(self):
    """Called in the resulting TF function, after conversion is done.

    Useful to restore any global context set up by `before_conversion`."""
    raise NotImplementedError

  def run_fun_tf(self,
                 args_flat_tf: Sequence[TfVal]
                 ) -> tuple[Sequence[TfVal], Sequence[core.ShapedArray], tree_util.PyTreeDef]:
    """Runs the resulting TF function.

    Args:
      args_flat_tf: a flat tuple of tf.Tensor arguments

    Returns: a tuple with:
      outs_tfs: a flat tuple of tf.Tensor results
      outs_avals: a flat tuple of JAX abstract values for the underlying JAX
        function.
      outs_tree: the PyTreeDef for the outputs
    """
    raise NotImplementedError

  def run_vjp_fun_tf(self,
                     vjp_args_flat_tf: Sequence[TfVal],
                     outs_avals: Sequence[core.AbstractValue]) -> Sequence[TfVal]:
    """Runs the VJP function as a TF function.

    Args:
      vjp_args_flat_tf: the flattened sequence of tf.Tensor, including the
        primal arguments followed by the output cotangents.
      outs_avals: the flattened primal outputs avals

    Returns: the flattened sequence of input cotangents.
    """
    raise NotImplementedError
class NativeSerializationImpl(SerializationImpl):
  """Serializes the JAX function natively to StableHLO via `jax_export`."""

  def __init__(self, fun_jax, *,
               args_specs, kwargs_specs,
               native_serialization_platforms: Sequence[str],
               native_serialization_disabled_checks: Sequence[DisabledSafetyCheck]):
    self.fun_jax = fun_jax
    self.args_specs = args_specs
    self.kwargs_specs = kwargs_specs
    self.native_serialization_disabled_checks = native_serialization_disabled_checks
    if native_serialization_platforms:
      # Only the first platform is used; `convert` rejects multiple platforms.
      self.lowering_platform: Optional[str] = native_serialization_platforms[0]
    else:
      # None lets jax_export pick the default JAX backend.
      self.lowering_platform = None

  def before_conversion(self):
    # Install a fresh thread-local list to collect TF concrete functions
    # produced by call_tf during export; `after_conversion` restores it.
    _prev_func_list = _thread_local_state.call_tf_concrete_function_list
    _thread_local_state.call_tf_concrete_function_list = []
    def _restore_context():
      _thread_local_state.call_tf_concrete_function_list = _prev_func_list
    self._restore_context = _restore_context
    # Export eagerly here so `run_fun_tf`/`run_vjp_fun_tf` can reuse it.
    self.exported = jax_export.export(
        self.fun_jax,
        lowering_platform=self.lowering_platform,
        disabled_checks=self.native_serialization_disabled_checks
    )(*self.args_specs, **self.kwargs_specs)

  def after_conversion(self):
    self._restore_context()

  def run_fun_tf(self,
                 args_flat_tf: Sequence[TfVal]
                 ) -> tuple[Sequence[TfVal], Sequence[core.ShapedArray], tree_util.PyTreeDef]:
    results = _run_exported_as_tf(args_flat_tf, self.exported)
    return results, tuple(self.exported.out_avals), self.exported.out_tree

  def run_vjp_fun_tf(self,
                     vjp_args_flat_tf: Sequence[TfVal],
                     outs_avals: Sequence[core.AbstractValue]) -> Sequence[TfVal]:
    # outs_avals are not needed: the exported VJP carries its own avals.
    del outs_avals
    exported_vjp = self.exported.vjp()
    vjp_args_flat_tf = tuple(tf.identity(arg, f"jax2tf_arg_{arg_idx}")
                             for arg_idx, arg in enumerate(vjp_args_flat_tf))
    in_cts_flat = _run_exported_as_tf(vjp_args_flat_tf, exported_vjp)
    return tuple(tf.identity(arg, "jax2tf_out") for arg in in_cts_flat)
class GraphSerializationImpl(SerializationImpl):
  """Serializes the JAX function by interpreting it primitive-by-primitive
  into a TF graph (the non-native, legacy lowering path)."""

  def __init__(self, fun_jax, *,
               args_specs, kwargs_specs,
               args_flat_tf: Sequence[TfVal],
               enable_xla: bool):
    self.fun_jax = fun_jax
    self.args_specs = args_specs
    self.kwargs_specs = kwargs_specs
    self.enable_xla = enable_xla
    fun_name = getattr(fun_jax, "__name__", "unknown")
    name_stack = util.wrap_name(fun_name, "jax2tf")
    self.name_stack = name_stack
    self.args_flat_tf = args_flat_tf

  def before_conversion(self):
    # Save the thread-local values we are about to overwrite; the closure
    # below restores them (and clears shape_env) in `after_conversion`.
    prev_enable_xla = _thread_local_state.enable_xla
    prev_include_xla_op_metadata = _thread_local_state.include_xla_op_metadata
    prev_tf_outer_name_scope = _thread_local_state.tf_outer_name_scope
    def _restore_context():
      _thread_local_state.enable_xla = prev_enable_xla
      _thread_local_state.include_xla_op_metadata = prev_include_xla_op_metadata
      _thread_local_state.tf_outer_name_scope = prev_tf_outer_name_scope
      _thread_local_state.shape_env = ()
    self._restore_context = _restore_context
    _thread_local_state.enable_xla = self.enable_xla
    # TODO(b/189306134): implement support for XLA metadata
    _thread_local_state.include_xla_op_metadata = False
    _thread_local_state.tf_outer_name_scope = tf.get_current_name_scope()
    assert not _thread_local_state.shape_env, f"Unexpected shape environment {_thread_local_state.shape_env}"

    args_specs_flat, self.in_tree = tree_util.tree_flatten(
        (self.args_specs, self.kwargs_specs))
    self.args_avals_flat = tuple(
        map(lambda a: core.raise_to_shaped(core.get_aval(a)), args_specs_flat))

    # Compute TF expressions for the dimension variables from the argument
    # shapes, and install them so primitive rules can resolve polymorphic dims.
    dim_vars = shape_poly.all_dim_vars(self.args_avals_flat)
    dim_values, _ = _interpret_fun_jax(
        partial(shape_poly.compute_dim_vars_from_arg_shapes,
                self.args_avals_flat, args_kwargs_tree=self.in_tree),
        self.args_flat_tf, self.args_avals_flat, self.name_stack)
    _thread_local_state.shape_env = zip(dim_vars, dim_values)

    fun_flat_jax, out_tree_thunk = flatten_fun_jax(self.fun_jax, self.in_tree)
    # out_tree_thunk will be ready after we call run_fun_tf below.
    self.fun_flat_jax = fun_flat_jax
    self.out_tree_thunk = out_tree_thunk

  def after_conversion(self):
    self._restore_context()

  def run_fun_tf(self,
                 args_flat_tf: Sequence[TfVal]
                 ) -> tuple[Sequence[TfVal], Sequence[core.ShapedArray], tree_util.PyTreeDef]:
    outs_tf, outs_avals = _interpret_fun_jax(
        self.fun_flat_jax,
        args_flat_tf, self.args_avals_flat,
        self.name_stack,
        fresh_constant_cache=True)
    return outs_tf, outs_avals, self.out_tree_thunk()

  def run_vjp_fun_tf(self,
                     vjp_args_flat_tf: Sequence[TfVal],
                     outs_avals: Sequence[core.AbstractValue]) -> Sequence[TfVal]:
    def fun_vjp_jax(*args_and_out_cts_flat_jax):
      # Takes a flat list of primals and output cotangents
      args_flat_jax, out_cts_flat_jax = util.split_list(args_and_out_cts_flat_jax, [len(self.args_avals_flat)])
      _, pullback_jax = jax.vjp(self.fun_flat_jax, *args_flat_jax)
      return pullback_jax(out_cts_flat_jax)

    vjp_in_avals = tuple(self.args_avals_flat) + tuple(outs_avals)
    vjp_polymorphic_shapes = tuple(str(a.shape)  # Note: may be _DimExpr, not just DimVar
                                   for a in vjp_in_avals)  # type: ignore
    # Recursively convert the VJP; with_gradient=False since we only support
    # first-order gradients, and the VJP must use graph serialization too.
    return convert(
        fun_vjp_jax,
        with_gradient=False,
        polymorphic_shapes=vjp_polymorphic_shapes,
        native_serialization=False)(*vjp_args_flat_tf)
def dtype_of_val(val: TfVal) -> DType:
  """Computes the TensorFlow dtype using JAX's typing rules.

  If the value is a tf.Tensor, it starts with its dtype. If the value is a
  constant it uses JAX to infer its dtype. The resulting dtype follows the
  JAX type inference rules, and depends on the value of the
  JAX_ENABLE_X64 flag.

  See README.md for how 64-bit values are treated.
  """
  converted, _ = _tfval_to_tensor_jax_dtype(val)
  return converted.dtype
@partial(api_util.api_hook, tag="jax2tf_eval_polymorphic_shapes")
def eval_polymorphic_shape(fun_jax: Callable,
                           *,
                           polymorphic_shapes=None) -> Callable:
  """Evaluates the output shape in presence of shape polymorphism.

  Like `jax.eval_shape`, this neither lowers nor executes the function.

  Args:
    fun_jax: target JAX function to be called. Its arguments and return value
      should be JAX arrays, or nested standard Python containers
      (tuple/list/dict) thereof (pytrees).
    polymorphic_shapes: Specifies input shapes to be treated polymorphically
      during shape evaluation. See discussion for `jax2tf.convert`.

      .. warning:: The shape-polymorphic lowering is an experimental feature.

  Returns: a function that takes `jax.ShapeDtypeStruct`s (or any values
    with `.shape` and `.dtype` attributes) corresponding to the inputs for
    `fun_jax`, and returns a tuple with:

    * the jax.ShapeDtypeStruct corresponding to the result, as for
     `jax.eval_shape`. The shape may contain symbolic dimension expressions.
    * the value that can be passed to `polymorphic_shapes` for a subsequent
     call to `jax2tf.eval_polymorphic_shape`, or `jax2tf.convert`.

  For example:

  >>> import jax
  >>> from jax.experimental import jax2tf
  >>> from jax import numpy as jnp
  >>>
  >>> f = lambda A, x: jnp.sin(jnp.dot(A, x))
  >>> A = jax.ShapeDtypeStruct((2000, 3000), jnp.float32)
  >>> x = jax.ShapeDtypeStruct((3000, 1000), jnp.float32)
  >>> out_spec, out_poly_shape = jax2tf.eval_polymorphic_shape(f, polymorphic_shapes=["a, b", "b, c"])(A, x)
  >>> print(out_spec.shape)
  ("a", "c")
  >>> print(out_poly_shape)
  (a, c)
  >>> res_spec, res_poly_shape = jax2tf.eval_polymorphic_shape(lambda x: x.T, polymorphic_shapes=[out_poly_shape])(out_spec)
  >>> print(res_poly_shape)
  (c, a)
  """
  def do_eval_polymorphic_shape(*args_specs) -> Any:
    # Attach polymorphic dimension variables to the input specs, then use
    # the ordinary abstract evaluation to get the result specs.
    poly_args = jax_export.poly_specs(
        args_specs, polymorphic_shapes=polymorphic_shapes)
    result_specs = jax.eval_shape(fun_jax, *poly_args)
    # TODO(necula): For now we export the polymorphic shapes using `str`.
    result_shape_strs = tree_util.tree_map(lambda r: str(r.shape), result_specs)
    return result_specs, result_shape_strs

  return do_eval_polymorphic_shape
# Internals
def flatten_fun_jax(fun_jax: Callable,
                    in_tree,
                    ) -> tuple[Callable, Callable]:
  """Wraps the function to take a (flat) list of positional args.

  jax2tf works better and is simpler when the JAX function takes and returns
  just a tuple of values (no pytrees, no kwargs). This is in part because
  jax.vjp does not support kwargs and we can only set
  tf.custom_gradient on functions with flat arguments and results

  Returns:
    * the wrapped JAX function taking and returning a flat list of arguments
    * a thunk that can be called after the wrapped function has been called
      to return the output pytree.
  """
  captured_out_tree = None

  def fun_flat_jax(*args_flat_jax):
    nonlocal captured_out_tree
    positional, keyword = tree_util.tree_unflatten(in_tree, args_flat_jax)
    result = fun_jax(*positional, **keyword)
    flat_results, result_tree = tree_util.tree_flatten(result)
    # The output structure must be identical across calls (e.g. the primal
    # call and the call made while differentiating).
    assert captured_out_tree is None or captured_out_tree == result_tree
    captured_out_tree = result_tree
    return flat_results

  return fun_flat_jax, lambda: captured_out_tree
def preprocess_arg_tf(arg_idx: int,
                      arg_tf: TfVal) -> TfVal:
  """Pre-processes one flat TF argument of the converted function.

  Args:
    arg_idx: index of the argument among the flattened arguments; used only
      to name the resulting tensor.
    arg_tf: the TF value passed by the caller.

  Returns: the pre-processed TF arg, cast to the JAX-inferred dtype and
    named "jax2tf_arg_{arg_idx}".

  Raises:
    TypeError: if the argument is not convertible to a TF value.
  """
  if not _is_tfval(arg_tf):
    msg = (f"Argument {arg_tf} of type {type(arg_tf)} of jax2tf.convert(f) should "
           "be NumPy array, scalar, tf.Variable, or tf.Tensor")
    raise TypeError(msg)

  # May cast the args_flat to JAX types, using JAX's interpretation
  # of types of constants.
  arg_tf, _ = _tfval_to_tensor_jax_dtype(arg_tf)
  # Name input tensors; do this after we have cast the arguments
  arg_tf = tf.identity(arg_tf, f"jax2tf_arg_{arg_idx}")
  return arg_tf
def _make_custom_gradient_fn_tf(*,
                                impl: SerializationImpl,
                                args_tf: Sequence[TfVal],
                                outs_avals: Sequence[core.ShapedArray],
                                outs_tf: Sequence[TfVal]):
  """Prepares the TF function to be used with tf.custom_gradient.

  Args:
    impl: the serialization implementation details
    args_tf: the flattened TF arguments of the primal function
    outs_avals: the flattened output JAX abstract values of the primal function
    outs_tf: the flattened TF outputs of the primal function

  Returns: the gradient function, which maps output cotangents to input
    cotangents via the serialized JAX VJP.
  """
  def grad_fn_tf(*out_cts_flat_tf: TfVal,
                 variables=None):
    if variables:
      raise ValueError(
          "Unexpected variables used in forward pass. "
          "This should not happen for first-order differentiation. "
          f"{variables=}")

    # TODO: enable higher-order gradients
    with tf.name_scope("jax2tf_vjp"):
      def fix_out_ct(out_ct_tf, out_ct_aval: core.ShapedArray, out_tf: TfVal):
        # If the primal function has outputs of integer or bool types, and if we are
        # under a tf.function context, then TF will pass None in _out_cts_flat
        # in place of these values. We should change these to float0 or
        # else JAX gets unhappy. See issue #6975.
        if out_ct_tf is not None:
          return out_ct_tf
        assert core.primal_dtype_to_tangent_dtype(out_ct_aval.dtype) == dtypes.float0, f"{out_ct_tf=}"
        # Note that out_ct_aval.shape contains dimension variable from the
        # primal function scope. We use tf.zeros_like to make a 0 of the right shape.
        return tf.zeros_like(out_tf, dtype=_tf_np_dtype_for_float0)

      out_cts_fixed_flat_tf = tuple(map(fix_out_ct, out_cts_flat_tf, outs_avals, outs_tf))
      vjp_args_flat_tf = tuple(args_tf) + out_cts_fixed_flat_tf
      in_cts_flat = impl.run_vjp_fun_tf(vjp_args_flat_tf, outs_avals)

    # We do not need to fix the in_cts because the TF gradient machinery
    # will adjust the unconnected gradients and those for integer types.
    return in_cts_flat

  return grad_fn_tf
@contextlib.contextmanager
def _extended_name_stack(extra_name_stack: Optional[str]):
name_ctx = (source_info_util.extend_name_stack(extra_name_stack)
if extra_name_stack
else contextlib.nullcontext())
with name_ctx:
yield
return
def _interpret_fun_jax(
    fun_jax: Callable,
    args_tf: Sequence[TfVal],
    args_avals: Sequence[core.ShapedArray],
    extra_name_stack: Optional[str],
    fresh_constant_cache: bool = False,
) -> tuple[tuple[TfVal, ...], tuple[core.ShapedArray, ...]]:
  """Interprets a JAX function with TF values by tracing it with TensorFlowTrace.

  Args:
    fun_jax: the JAX callable to interpret.
    args_tf: the flattened TF values passed as arguments.
    args_avals: the JAX abstract values corresponding to `args_tf`.
    extra_name_stack: optional name-stack element for the created ops.
    fresh_constant_cache: whether to interpret with a fresh constant cache
      (see `_call_wrapped_with_new_constant_cache`).

  Returns:
    a pair of (tuple of TF outputs, tuple of their JAX abstract values).
  """
  with core.new_base_main(TensorFlowTrace) as main:  # type: ignore
    subtrace_fun = _interpret_subtrace(lu.wrap_init(fun_jax), main, args_avals)
    with _extended_name_stack(extra_name_stack):
      with core.new_sublevel():
        out_vals: Sequence[tuple[TfVal, core.ShapedArray]] = \
            _call_wrapped_with_new_constant_cache(subtrace_fun, args_tf,
                                                  fresh_constant_cache=fresh_constant_cache)
    # Drop the reference to the trace main before leaving the context.
    del main

  return util.unzip2(out_vals)
def _run_exported_as_tf(args_flat_tf: Sequence[TfVal],
                        exported: jax_export.Exported,
                        ) -> Sequence[TfVal]:
  """Runs the `exported` as an XlaCallModule TF op.

  Converts the arguments to the dtypes expected by the exported module,
  assembles the XlaCallModule attributes (honoring the serialization
  version), applies input/output shardings, and casts the results back.

  Returns: the flattened tuple of results.
  """
  args_avals = exported.in_avals

  # TF values may be integer types for float0
  def _convert_value(val, aval):
    # Check the shape (only the dimensions that are known to be constant).
    assert all(d_aval == d_val
               for d_aval, d_val in zip(aval.shape, val.shape)
               if core.is_constant_dim(d_aval)), (aval, val)
    conversion_dtype = _to_tf_dtype(aval.dtype)
    if conversion_dtype != aval.dtype:
      return tf.cast(val, conversion_dtype)
    else:
      return val

  args_flat_tf = tuple(map(_convert_value, args_flat_tf, args_avals))

  # Output shapes for the op attributes: dynamic dimensions become None.
  out_shapes_tf = tuple(
      tuple(d if core.is_constant_dim(d) else None
            for d in out_aval.shape)
      for out_aval in exported.out_avals)

  out_types = tuple(_to_tf_dtype(out_aval.dtype) for out_aval in exported.out_avals)

  # Only the arguments kept by the exported module are actually passed to it.
  kept_args_avals = [aval for i, aval in enumerate(exported.in_avals) if i in exported.module_kept_var_idx]
  kept_args_flat_tf = [atf for i, atf in enumerate(args_flat_tf) if i in exported.module_kept_var_idx]

  version = exported.serialization_version

  call_module_attrs = dict(
      version=version,
      Tout=out_types,
      Sout=out_shapes_tf,
      function_list=[
          concrete_fn.function_def.signature.name
          for concrete_fn in _thread_local_state.call_tf_concrete_function_list
      ] if _thread_local_state.call_tf_concrete_function_list is not None else [],
  )

  call_module_attrs["platforms"] = (exported.lowering_platform.upper(),)
  if version >= 6:
    # Newer versions carry the disabled safety checks explicitly.
    call_module_attrs["disabled_checks"] = tuple(
        str(dc)
        for dc in exported.disabled_checks)
  else:
    if version >= 3:
      if DisabledSafetyCheck.platform() in exported.disabled_checks:
        call_module_attrs["platforms"] = ()  # No platform checking

  if logging.vlog_is_on(3):
    # We already logged the MLIR module when we exported it.
    logging.vlog(3, "XlaCallModule %s", str(call_module_attrs))

  call_module_attrs["module"] = exported.mlir_module_serialized

  # Apply the shardings on arguments and results for pjit. This is redundant
  # because the mlir_module_text will already contain the shardings, but it
  # makes it easier for tools like the TPU inference converter to see the
  # sharding without digging into the `module` attribute of the `XlaCallModule`
  # op, in the same way as it is done for the legacy jax2tf conversion.
  # Do not apply XlaSharding for REPLICATED, on inputs and outputs.
  # This is an agreed convention, and also improves usability under TF eager.
  # See b/255511660.
  if exported.in_shardings is not None:
    args_flat_tf = tuple(
      map(partial(_shard_value, skip_replicated_sharding=tf.executing_eagerly()),
          kept_args_flat_tf, kept_args_avals, exported.in_shardings))
  res = tfxla.call_module(args_flat_tf, **call_module_attrs)
  # TODO(b/278940799): Replace the TF v1 API with public TF2 API.
  # Add the custom call tf.function into the default graph, so those functions
  # will be available during tf.SavedModel.save.
  if _thread_local_state.call_tf_concrete_function_list is not None:
    for concrete_fn in _thread_local_state.call_tf_concrete_function_list:
      tf.compat.v1.get_default_graph()._add_function_recursive(
          concrete_fn._inference_function
      )
  if exported.out_shardings is not None:
    res = list(map(partial(_shard_value, skip_replicated_sharding=tf.executing_eagerly()),
                   res, exported.out_avals, exported.out_shardings))
  # Cast the results back to the dtypes declared by the output avals.
  res = tuple(map(_convert_value, res, exported.out_avals))
  return res
def _call_wrapped_with_new_constant_cache(fun: lu.WrappedFun,
                                          in_vals: Sequence[TfVal],
                                          fresh_constant_cache: bool = False
                                          ) -> Sequence[tuple[TfVal, core.ShapedArray]]:
  """Calls `fun` while managing the thread-local constant cache.

  Args:
    fun: the wrapped function to call with `in_vals`.
    in_vals: the flattened TF input values.
    fresh_constant_cache: if True, runs `fun` with a brand-new (empty)
      constant cache; if False, reuses the current cache but removes any
      entries added during the call when running under tf.function wrapping.

  Returns:
    the sequence of (TF value, JAX abstract value) pairs produced by `fun`.
  """
  try:
    prev_constant_cache = _thread_local_state.constant_cache
    # Start a new cache, so that we don't share constants across tf.function
    # boundaries.
    if fresh_constant_cache:
      _thread_local_state.constant_cache = {}
    else:
      # Remember the pre-existing keys so we can later remove the entries
      # that this call added.
      prev_constant_cache_keys = set(prev_constant_cache.keys()) if prev_constant_cache is not None else set()
    out_vals: Sequence[tuple[TfVal, core.ShapedArray]] = \
        fun.call_wrapped(*in_vals)
  finally:
    if (not fresh_constant_cache and
        prev_constant_cache is not None and
        _WRAP_JAX_JIT_WITH_TF_FUNCTION):
      newly_added_keys = set(prev_constant_cache.keys()) - prev_constant_cache_keys
      # Delete the newly added keys
      for k in newly_added_keys:
        del prev_constant_cache[k]
    # Always restore the previous cache, even if `fun` raised.
    _thread_local_state.constant_cache = prev_constant_cache

  return out_vals
def _convert_jax_impl(impl_jax: Callable, *,
                      multiple_results=True,
                      with_physical_avals=False,
                      extra_name_stack: Optional[str] = None) -> Callable:
  """Convert the JAX implementation of a primitive.

  Args:
    impl_jax: typically the impl-rule for a primitive, with signature
      `(*args_jax: JaxVal, **kwargs) -> Sequence[JaxVal]`. This function
      implements a primitive in terms of other primitives.
    multiple_results: whether `impl_jax` returns a sequence of results.
    with_physical_avals: whether to convert the input and output avals to
      their physical counterparts before interpreting `impl_jax`.
    extra_name_stack: additional element to add to the name stack for the
      converted ops.

  Returns:
    a function with signature `(*args_tf: TfVal, _in_avals, _out_aval, **kwargs)
    -> Sequence[TfVal]`.
  """
  def wrapped_tf(*args_tf: TfVal, _in_avals: Sequence[core.ShapedArray],
                 _out_aval: core.ShapedArray,
                 **kwargs) -> Sequence[TfVal]:
    if with_physical_avals:
      _in_avals = map(_jax_physical_aval, _in_avals)
      _out_aval = _jax_physical_aval(_out_aval)

    # Normalize impl_jax so that it always produces a sequence of results,
    # which is the form _interpret_fun_jax expects.
    def as_sequence_jax(*args_jax):
      out_jax = impl_jax(*args_jax, **kwargs)
      return out_jax if multiple_results else [out_jax]

    outs_tf, _ = _interpret_fun_jax(
        as_sequence_jax, args_tf, _in_avals,
        extra_name_stack)
    return outs_tf if multiple_results else outs_tf[0]

  return wrapped_tf
@lu.transformation
def _interpret_subtrace(main: core.MainTrace,
                        in_avals: Sequence[core.ShapedArray],
                        *in_vals: TfVal):
  """Linear-util transformation that traces the wrapped fun with TensorFlowTrace.

  Boxes the TF input values into TensorFlowTracers under a new sublevel of
  `main`, runs the wrapped function, and yields the outputs as
  (TF value, JAX abstract value) pairs.
  """
  trace = TensorFlowTrace(main, core.cur_sublevel())
  in_tracers = tuple(
      TensorFlowTracer(trace, val, aval)
      for val, aval in zip(in_vals, in_avals))
  # First yield hands the tracers to the wrapped function; it resumes with
  # that function's outputs (lu.transformation generator protocol).
  outs = yield in_tracers, {}  # type: Sequence[TfVal]
  out_tracers: Iterable[TensorFlowTracer] = (
      map(trace.full_raise, outs))  # type: ignore
  out_vals_with_avals: Sequence[tuple[TfVal, core.ShapedArray]] = (
      tuple((t.val, t.aval) for t in out_tracers))
  # Second yield returns the unboxed results to the caller.
  yield out_vals_with_avals
def _interpret_jaxpr(jaxpr: core.ClosedJaxpr, *args_tf: TfVal,
                     extra_name_stack: Optional[str],
                     fresh_constant_cache: bool = True) -> Sequence[TfVal]:
  """Evaluates a Jaxpr with tf.Tensor arguments.

  This is most often used as the body of a tf.function, or tf.switch_case,
  in which case it should use a fresh constant cache.
  The output is a sequence of TfVal, suitable for use with TF.
  """
  fun_jax = core.jaxpr_as_fun(jaxpr)
  results_tf, _ = _interpret_fun_jax(
      fun_jax, args_tf, jaxpr.in_avals, extra_name_stack,
      fresh_constant_cache=fresh_constant_cache)
  return results_tf
def _jax_physical_aval(aval: core.ShapedArray) -> core.ShapedArray:
"""Converts JAX avals from logical to physical, if relevant.
JAX might have avals whose logical vs physical shape/dtype may
differ, and only the physical view is expected to possibly
relate to TF. TF impl rules should operate on the physical form.
A JAX logical aval might even correspond, in principle, to several
physical avals, but we don't support those here. Instead we assert
there is only one and return it.
"""