Merge pull request #21191 from gnecula:export_simplify
PiperOrigin-RevId: 633179742
jax authors committed May 13, 2024
2 parents 54ca3d4 + 78d4d0a commit e66a234
Showing 6 changed files with 89 additions and 249 deletions.
15 changes: 0 additions & 15 deletions jax/_src/interpreters/mlir.py
@@ -571,18 +571,6 @@ class LoweringParameters:
# or multi-platform lowering.
global_constant_computation: bool = False

# TODO(b/302258959): in JAX native execution we cannot lower the tokens
# to stablehlo.token for the top-level function, due to runtime limitations.
# Instead, we use dummy bool[0] arrays. This is controlled by setting
# replace_tokens_with_dummy to True (default). However, when exporting StableHLO
# we can use real tokens, because the resulting StableHLO will not be
# executed directly, but will be embedded as an inner function in a larger
# JAX or TensorFlow program. In these cases, replace_tokens_with_dummy must
# be set to False (for serialization versions >= 9).
# Once the PJRT is extended to use tokens, we can use tokens even in the
# native execution (and we can remove this parameter).
# This parameter can be removed when minimum xla_extension_version is >= 260.
replace_tokens_with_dummy: bool = True

@dataclasses.dataclass
class TracebackCaches:
@@ -971,13 +959,10 @@ def lower_jaxpr_to_module(
attrs["sym_name"] = ir.StringAttr.get(module_name)
attrs["mhlo.num_replicas"] = i32_attr(num_replicas)
attrs["mhlo.num_partitions"] = i32_attr(num_partitions)
replace_tokens_with_dummy = False
lower_jaxpr_to_fun(
ctx, "main", jaxpr, ordered_effects,
name_stack=name_stack,
public=True,
create_tokens=replace_tokens_with_dummy,
replace_tokens_with_dummy=replace_tokens_with_dummy,
num_output_tokens=0,
replicated_args=replicated_args,
arg_shardings=arg_shardings,
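A minimal sketch (not part of this commit) of how export-style callers drive lowering after this simplification; it relies only on the `_experimental_lowering_parameters` hook and the `LoweringParameters(platforms=...)` field that appear in the `_export.py` hunks below. Ordered effects now always lower to real `stablehlo.token` values, with no `bool[0]` dummies.

```python
# A minimal sketch, not part of the commit: it uses the public lowering hook
# and the LoweringParameters(platforms=...) field shown in the _export.py
# hunks below; there is no replace_tokens_with_dummy flag to thread anymore.
import jax
import jax.numpy as jnp
from jax._src.interpreters import mlir

def lower_for_export(fun, *args_specs, platforms=("cpu",)):
    # Mirrors the wrapped_fun_jax.lower(...) call site in _export.py.
    return jax.jit(fun).lower(
        *args_specs,
        _experimental_lowering_parameters=mlir.LoweringParameters(
            platforms=platforms))

lowered = lower_for_export(jnp.sin, jax.ShapeDtypeStruct((4,), jnp.float32))
print(lowered.as_text()[:500])  # StableHLO text, for inspection
```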
172 changes: 59 additions & 113 deletions jax/experimental/export/_export.py
@@ -69,9 +69,6 @@
minimum_supported_serialization_version = 9
maximum_supported_serialization_version = 9

_VERSION_START_SUPPORT_SHAPE_ASSERTIONS = 7
_VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS = 9


class DisabledSafetyCheck:
"""A safety check should be skipped on (de)serialization.
@@ -106,11 +103,16 @@ def custom_call(cls, target_name: str) -> DisabledSafetyCheck:

@classmethod
def shape_assertions(cls) -> DisabledSafetyCheck:
"""Allows invocations with shapes that do not meet the constraints.
"""A noop. DEPRECATED.
Has effect on serialization (to suppress the generation of the assertions)
and also on deserialization (to suppress the checking of the assertions).
Was used previously to allow invocations with shapes that do not meet the
constraints. It has no effect anymore; shape assertions can no longer be disabled.
"""
# TODO(necula): remove this after compatibility period. Was deprecated in
# May 2024.
warnings.warn(
"DisabledSafetyCheck.shape_assertions is deprecated, has no effect anymore",
DeprecationWarning, stacklevel=2)
return DisabledSafetyCheck("shape_assertions")
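
A minimal sketch (not part of this commit) of the deprecation behavior documented in the docstring above, assuming `DisabledSafetyCheck` is re-exported from `jax.experimental.export` in this release:

```python
# A minimal sketch, not part of the commit, showing the deprecation described
# above (assumes DisabledSafetyCheck is re-exported from jax.experimental.export).
import warnings
from jax.experimental.export import DisabledSafetyCheck

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    check = DisabledSafetyCheck.shape_assertions()  # still constructible, but a no-op

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```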

def is_custom_call(self) -> str | None:
@@ -344,10 +346,6 @@ def args_specs(
return _shape_poly.symbolic_args_specs(args, polymorphic_shapes)



def _keep_main_tokens(serialization_version: int) -> bool:
return serialization_version >= _VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS

def export(fun_jax: Callable,
*,
lowering_platforms: Sequence[str] | None = None,
@@ -391,71 +389,58 @@ def do_export(*args_specs, **kwargs_specs) -> Exported:
# convert(f_jax), in which case a "jit" is implied. In that case we raise
# an error if the lowered function contains non-replicated sharding annotations.
wrapped_fun_jax = jax.jit(fun_jax)
allow_non_replicated_sharding = False
else:
# If we have a pjit or pmap already we do not wrap with another, and we
# allow shardings.
wrapped_fun_jax = fun_jax # type: ignore
allow_non_replicated_sharding = True

if lowering_platforms is not None:
actual_lowering_platforms = tuple(lowering_platforms)
else:
actual_lowering_platforms = (default_lowering_platform(),)

# Do not include shape assertions if the version is < 7.
enable_shape_assertions = (
DisabledSafetyCheck.shape_assertions() not in disabled_checks and
version >= _VERSION_START_SUPPORT_SHAPE_ASSERTIONS) # type: ignore
try:
prev_enable_shape_assertions = _shape_poly.thread_local_state.enable_shape_assertions
_shape_poly.thread_local_state.enable_shape_assertions = enable_shape_assertions
replace_tokens_with_dummy = not _keep_main_tokens(version)

symbolic_scope: tuple[_shape_poly.SymbolicScope, tree_util.KeyPath] | None = None
for k_path, aval in tree_util.tree_flatten_with_path((args_specs, kwargs_specs))[0]:
# Static args may have no `shape` attribute.
if not hasattr(aval, "shape"):
continue
for d in aval.shape:
if _shape_poly.is_symbolic_dim(d):
if symbolic_scope is None:
symbolic_scope = (d.scope, k_path)
continue
symbolic_scope[0]._check_same_scope(
d, when=f"when exporting {fun_name}",
self_descr=f"current (from {_shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ",
other_descr=_shape_poly.args_kwargs_path_to_str(k_path))

lowered = wrapped_fun_jax.lower(
*args_specs, **kwargs_specs,
_experimental_lowering_parameters=mlir.LoweringParameters(
platforms=actual_lowering_platforms,
replace_tokens_with_dummy=replace_tokens_with_dummy,
))

lowering = lowered._lowering # type: ignore
_check_lowering(lowering)
mlir_module = lowering.stablehlo()

args_avals_flat, _ = tree_util.tree_flatten(lowered.in_avals)
if "mut" in lowering.compile_args:
if lowering.compile_args["mut"]: raise NotImplementedError
if "kept_var_idx" in lowering.compile_args:
module_kept_var_idx = tuple(sorted(lowering.compile_args["kept_var_idx"]))
else:
# For pmap
module_kept_var_idx = tuple(range(len(args_avals_flat)))
shape_poly_state = lowering.compile_args["shape_poly_state"]
if (not all(core.is_constant_shape(a.shape) for a in args_avals_flat)
or lowering.compile_args.get("ordered_effects", [])):
mlir_module = _wrap_main_func(
mlir_module, args_avals_flat, args_kwargs_tree=lowered.in_tree,
has_platform_index_argument=shape_poly_state.has_platform_index_argument,
module_kept_var_idx=module_kept_var_idx,
serialization_version=version)
finally:
_shape_poly.thread_local_state.enable_shape_assertions = prev_enable_shape_assertions
# TODO: move to `lower`
symbolic_scope: tuple[_shape_poly.SymbolicScope, tree_util.KeyPath] | None = None
for k_path, aval in tree_util.tree_flatten_with_path((args_specs, kwargs_specs))[0]:
# Static args may have no `shape` attribute.
if not hasattr(aval, "shape"):
continue
for d in aval.shape:
if _shape_poly.is_symbolic_dim(d):
if symbolic_scope is None:
symbolic_scope = (d.scope, k_path)
continue
symbolic_scope[0]._check_same_scope(
d, when=f"when exporting {fun_name}",
self_descr=f"current (from {_shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ",
other_descr=_shape_poly.args_kwargs_path_to_str(k_path))

lowered = wrapped_fun_jax.lower(
*args_specs, **kwargs_specs,
_experimental_lowering_parameters=mlir.LoweringParameters(
platforms=actual_lowering_platforms,
))

lowering = lowered._lowering # type: ignore
_check_lowering(lowering)
mlir_module = lowering.stablehlo()

args_avals_flat, _ = tree_util.tree_flatten(lowered.in_avals)
if "mut" in lowering.compile_args:
if lowering.compile_args["mut"]: raise NotImplementedError
if "kept_var_idx" in lowering.compile_args:
module_kept_var_idx = tuple(sorted(lowering.compile_args["kept_var_idx"]))
else:
# For pmap
module_kept_var_idx = tuple(range(len(args_avals_flat)))
shape_poly_state = lowering.compile_args["shape_poly_state"]
if (not all(core.is_constant_shape(a.shape) for a in args_avals_flat)
or lowering.compile_args.get("ordered_effects", [])):
mlir_module = _wrap_main_func(
mlir_module, args_avals_flat, args_kwargs_tree=lowered.in_tree,
has_platform_index_argument=shape_poly_state.has_platform_index_argument,
module_kept_var_idx=module_kept_var_idx,
serialization_version=version)

with mlir_module.context:
mlir_module_attrs = mlir_module.operation.attributes
@@ -483,13 +468,10 @@ def do_export(*args_specs, **kwargs_specs) -> Exported:
logging.info("Dumped the exported MLIR module to %s", dumped_to)

_check_module(mlir_module,
allow_non_replicated_sharding=allow_non_replicated_sharding,
disabled_checks=disabled_checks)

ordered_effects = tuple(lowering.compile_args["ordered_effects"])
unordered_effects = tuple(lowering.compile_args["unordered_effects"])
if version < _VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
ordered_effects = unordered_effects = ()

nr_devices = len(lowering.compile_args["device_assignment"])
def export_sharding(s: LoweringSharding,
@@ -636,17 +618,11 @@ def is_token(typ, attrs):
assert token_result_idxs == list(range(0, nr_token_results))
nr_array_results = len(orig_output_types) - nr_token_results
assert nr_array_results >= 0
if _keep_main_tokens(serialization_version):
new_main_arg_indices = (tuple(range(0, nr_platform_index_args)) +
tuple(range(nr_platform_index_args + nr_dim_args,
len(orig_input_types))))
new_main_result_indices = tuple(range(0, len(orig_output_types)))
else:
new_main_arg_indices = (
tuple(range(0, nr_platform_index_args)) +
tuple(range(nr_platform_index_args + nr_dim_args + nr_token_args,
len(orig_input_types))))
new_main_result_indices = tuple(range(nr_token_results, len(orig_output_types)))
new_main_arg_indices = (
*range(nr_platform_index_args),
*range(nr_platform_index_args + nr_dim_args, len(orig_input_types)))
new_main_result_indices = tuple(range(0, len(orig_output_types)))

new_main_input_types = [orig_input_types[idx] for idx in new_main_arg_indices]
new_main_output_types = [orig_output_types[idx] for idx in new_main_result_indices]
new_main_ftype = ir.FunctionType.get(new_main_input_types, new_main_output_types)
@@ -714,11 +690,8 @@ def is_token(typ, attrs):
else:
orig_main_args.append(arg)
# Then the token arguments
if _keep_main_tokens(serialization_version):
orig_main_args.extend(
new_main_op.arguments[nr_platform_index_args: nr_platform_index_args + nr_token_args])
else:
orig_main_args.extend(list(mlir.dummy_token()) * nr_token_args)
orig_main_args.extend(
new_main_op.arguments[nr_platform_index_args: nr_platform_index_args + nr_token_args])
# Then the array arguments. We insert a ConvertOp as the only use of
# an input argument. This helps the downstream shape refinement because
# it will set the type of input arguments to static shapes, and this
@@ -844,7 +817,6 @@ def _check_lowering(lowering) -> None:
check_sharding_pattern = re.compile(r"^({replicated}|{unknown shard_as.*}|"")$")

def _check_module(mod: ir.Module, *,
allow_non_replicated_sharding: bool,
disabled_checks: Sequence[DisabledSafetyCheck]) -> None:
"""Run a number of checks on the module.
@@ -853,8 +825,6 @@ def _check_module(mod: ir.Module, *,
non_replicated sharding annotations.
disabled_checks: the safety checks that are disabled.
"""
sharding_attr = ir.StringAttr.get("Sharding", mod.context)
shape_assertion_attr = ir.StringAttr.get("shape_assertion", mod.context)
allowed_custom_call_targets: set[str] = copy.copy(_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE)
for dc in disabled_checks:
target = dc.is_custom_call()
@@ -865,34 +835,13 @@ def _check_module(mod: ir.Module, *,
ir.StringAttr.get(target, mod.context)
for target in allowed_custom_call_targets}
disallowed_custom_call_ops: list[str] = []
def check_sharding(op: ir.Operation, loc: ir.Location):
if not allow_non_replicated_sharding:
try:
sharding = op.attributes["mhlo.sharding"]
except KeyError:
pass
else:
if not re.match(check_sharding_pattern, ir.StringAttr(sharding).value):
raise ValueError(
"Lowered function does not have a top-level pjit but it has"
f" non-replicated sharding annotations, e.g., {op} at {loc}.\nSee"
" https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#support-for-partitioning"
" for a discussion."
)

def check_op(op: ir.Operation):
op_name = op.operation.name
if op_name == "func.func":
check_sharding(op.operation, op.location)

elif op_name == "stablehlo.custom_call":
if op_name == "stablehlo.custom_call":
call_target_name_attr = op.operation.attributes["call_target_name"]
if (call_target_name_attr not in allowed_custom_call_targets_attrs):
disallowed_custom_call_ops.append(f"{op} at {op.location}")
if call_target_name_attr == sharding_attr:
check_sharding(op, op.location)
elif call_target_name_attr == shape_assertion_attr:
assert (DisabledSafetyCheck.shape_assertions() not in disabled_checks)

def walk_operations(op):
check_op(op)
@@ -1227,10 +1176,7 @@ def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.Abstra
else:
assert len(lowering_platforms) == 1

if _keep_main_tokens(exported.mlir_module_serialization_version):
ordered_effects = exported.ordered_effects
else:
ordered_effects = ()
ordered_effects = exported.ordered_effects
for eff in ordered_effects:
token_in = ctx.tokens_in.get(eff)[0]
submodule_args.append(token_in)
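Taken together, these `_export.py` changes mean that `export` always keeps real effect tokens in the wrapped `main`, always emits shape assertions, and always checks custom-call targets. A minimal usage sketch (not part of this commit; it assumes `export` is re-exported from `jax.experimental.export` and that the `Exported` attribute names below match those referenced in the hunks above):

```python
# A minimal sketch, not part of the commit: exporting a function at
# serialization version 9, the only version still supported after this change.
import jax
import jax.numpy as jnp
from jax.experimental import export

def f(x):
    return jnp.sum(jnp.sin(x))

exp = export.export(f)(jax.ShapeDtypeStruct((8,), jnp.float32))
print(exp.mlir_module_serialization_version)  # 9
print(exp.lowering_platforms)                 # e.g. ('cpu',)
```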
15 changes: 2 additions & 13 deletions jax/experimental/export/_shape_poly.py
@@ -42,7 +42,6 @@
import io
import copy
import operator as op
import threading
import tokenize
from typing import Any, Callable, Union, overload
import warnings
@@ -96,15 +95,6 @@ def __init__(self, message: str):
# https://github.com/python/mypy/issues/5887
super().__init__(error_msg) # type: ignore

class _ShapePolyThreadLocalState(threading.local):

def __init__(self):
# TODO(necula): this does not play well with some lowering caches, because
# this state is not part of the cache key.
self.enable_shape_assertions = True

thread_local_state = _ShapePolyThreadLocalState()


class Comparator(Enum):
EQ = 1
@@ -1311,9 +1301,8 @@ def shape_assertion(assert_what: jax.Array,
The format specifiers are sometimes processed with Python's
`string::format` method, and sometimes with `llvm::formatv`.
"""
if thread_local_state.enable_shape_assertions:
shape_assertion_p.bind(assert_what, *error_message_inputs,
error_message=error_message)
shape_assertion_p.bind(assert_what, *error_message_inputs,
error_message=error_message)

# A JAX primitive with no array arguments but with a dimension parameter
# that is a DimExpr. The value of the primitive is the value of the dimension,
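With the thread-local toggle gone, `shape_assertion` is emitted unconditionally for every shape-polymorphic export. A sketch of where those assertions come from (not part of this commit; it assumes the `symbolic_shape` helper exposed by `jax.experimental.export` in this release):

```python
# A minimal sketch, not part of the commit: a shape-polymorphic export whose
# module contains the shape_assertion custom calls made unconditional above.
import jax
import jax.numpy as jnp
from jax.experimental import export

def f(x):  # the reshape is valid only when the second dimension is even
    return x.reshape((x.shape[0], 2, x.shape[1] // 2))

spec = jax.ShapeDtypeStruct(export.symbolic_shape("b, 2*d"), jnp.float32)
exp = export.export(f)(spec)
# Calling the result with an odd second dimension trips the emitted
# shape_assertion; DisabledSafetyCheck.shape_assertions() no longer hides it.
```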
1 change: 1 addition & 0 deletions jax/experimental/jax2tf/README.md
@@ -1004,6 +1004,7 @@ We list here a history of the serialization version numbers:
Supported by XlaCallModule since October 27th, 2023,
available in JAX since October 20th, 2023 (JAX 0.4.20),
and the default since February 1st, 2024 (JAX 0.4.24).
This is the only supported version as of March 27, 2024.

## Known issues

