Skip to content

Commit

Permalink
Documentation improvements for XlaCallModule.disabled_checks.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 542475518
  • Loading branch information
gnecula authored and jax authors committed Jun 22, 2023
1 parent 5742ab4 commit 1535fa0
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 21 deletions.
9 changes: 5 additions & 4 deletions jax/experimental/jax2tf/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1010,7 +1010,7 @@ This is primarily a limitation of the SavedModel support for custom gradients.
Applies to native serialization only.

JAX natively uses custom calls for lowering of certain primitives.
The most common example are for the implementation of PRNG on GPUs
The most common example is for the implementation of PRNG on GPUs
where we get better performance with a custom call (`cu_threefry32`)
than if we use native StableHLO. Another class of examples are for
FFT and some linear algebra primitives (e.g., QR decomposition).
Expand All @@ -1021,9 +1021,10 @@ code that backs the custom call. For this reason, we maintain
a list of allowed custom call targets. If you try to serialize
code that invokes other targets you will get an error.

If you do not care about the compatibility guarantees of the
serialized artifact, you can set `native_serialization_strict_checks`
to `False` to disable the check.
If you want to disable this safety check for a specific custom call
with target `my_target`, you can add
`jax2tf.DisabledSafetyCheck.custom_call("my_target")` to the `disabled_checks`
parameter of the `jax2tf` function.

### XlaCallModule not supported by some TensorFlow tools

Expand Down
1 change: 1 addition & 0 deletions jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ def convert(fun_jax: Callable,
backend on the machine where the lowering is done.
native_serialization_disabled_checks: In conjunction with
`native_serialization`, disable the specified safety checks.
See docstring of DisabledSafetyCheck.
Returns:
A version of `fun_jax` that expects TfVals as arguments (or
Expand Down
51 changes: 34 additions & 17 deletions jax/experimental/jax2tf/jax_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,38 +51,55 @@
DType = Any

class DisabledSafetyCheck:
# Use a strings representation to aid human readability in serializations.
_impl: str
"""A safety check should be skipped on (de)serialization.
def __init__(self, _impl:str):
# Do not use directly, use builders `platform`, `custom_call`.
self._impl = _impl
Most of these checks are performed on serialization, but some are deferred to
deserialization. The list of disabled checks is attached to the serialization,
e.g., as a sequence of string attributes to `jax_export.Exported` or of
`tf.XlaCallModuleOp`.
def __str__(self):
return self._impl
__repr__ = __str__

def __eq__(self, other) -> bool:
return isinstance(other, DisabledSafetyCheck) and self._impl == other._impl

def __hash__(self) -> int:
return hash(self._impl)
You can disable more deserialization safety checks by passing
`TF_XLA_FLAGS=--tf_xla_call_module_disabled_checks=platform`.
"""
_impl: str

@classmethod
def platform(cls) -> "DisabledSafetyCheck":
"""Allows the execution platform to differ from the serialization platform."""
"""Allows the execution platform to differ from the serialization platform.
Has effect only on deserialization.
"""
return DisabledSafetyCheck("platform")

@classmethod
def custom_call(cls, target_name: str) -> "DisabledSafetyCheck":
"""Allows the serialization of a call target not known to be stable."""
"""Allows the serialization of a call target not known to be stable.
Has effect only on serialization.
Args:
target_name: the name of the custom call target to allow.
"""
return DisabledSafetyCheck(f"custom_call:{target_name}")

def is_custom_call(self) -> Optional[str]:
"""Returns the custom call target allowed by this directive."""
m = re.match(r'custom_call:(.+)$', self._impl)
return m.group(1) if m else None

def __init__(self, _impl:str):
# Do not use directly, use builders `platform`, `custom_call`.
self._impl = _impl

def __str__(self):
return self._impl
__repr__ = __str__

def __eq__(self, other) -> bool:
return isinstance(other, DisabledSafetyCheck) and self._impl == other._impl

def __hash__(self) -> int:
return hash(self._impl)


@dataclasses.dataclass(frozen=True)
class Exported:
Expand Down Expand Up @@ -115,7 +132,7 @@ class Exported:
polymorphic dimension variables. This may be from `in_avals` but also
from inner calls of shape-polymorphic Exported modules.
disabled_checks: a list of descriptors of safety checks that have been
disabled at export time.
disabled at export time. See docstring of DisabledSafetyCheck.
_get_vjp: an optional function that takes the current exported function and
returns the exported VJP function.
The VJP function takes a flat list of arguments,
Expand Down

0 comments on commit 1535fa0

Please sign in to comment.