Typecheck avals and sharding for arguments that were DCE'd.
This keeps the promise of AOT that no hidden recompilation can occur: even arguments that were removed by DCE are typechecked against the avals and shardings the computation was compiled for.

Fixes #18686

PiperOrigin-RevId: 585855658
yashk2810 authored and jax authors committed Nov 28, 2023
1 parent 37f1142 commit 88d980f
Showing 3 changed files with 93 additions and 13 deletions.
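
To make the fix concrete, here is a minimal sketch of the behavior this commit enforces (hypothetical shapes; `f` and `g` are illustrative names, not part of the change):

import jax
import jax.numpy as jnp

@jax.jit
def f(x, y):
  # `y` only affects tracing; the compiled computation never reads it,
  # so DCE drops it from the executable's arguments.
  return x + 1 if y.shape[0] > 2 else x + 2

g = f.lower(jnp.ones(4), jnp.ones(4)).compile()
g(jnp.ones(4), jnp.ones(4))  # ok: matches the avals g was compiled for

# A `y` of shape (1,) would have traced to a different computation entirely.
# Before this change the call slipped through, because the DCE'd `y` was
# never typechecked; it now raises a TypeError (see the tests below).
g(jnp.ones(4), jnp.ones(1))
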
58 changes: 46 additions & 12 deletions jax/_src/interpreters/pxla.py
@@ -1875,6 +1875,14 @@ def are_all_shardings_default_mem_kind(da_object, shardings):

MaybeLayout = Sequence[Optional[Union[XLACompatibleLayout, LayoutRequest]]]


class AllArgsInfo(NamedTuple):
"""Avals, shardings, layouts and debug_info for all arguments prior to DCE."""
in_avals: Sequence[core.ShapedArray]
in_shardings: Any
debug_info: core.JaxprDebugInfo | None


@profiler.annotate_function
def lower_sharding_computation(
closed_jaxpr: core.ClosedJaxpr,
@@ -1904,6 +1912,9 @@ def lower_sharding_computation(
check_if_any_auto(in_shardings) if is_unspecified(out_shardings) else
check_if_any_auto(it.chain.from_iterable([in_shardings, out_shardings]))) # type: ignore

all_args_info = AllArgsInfo(global_in_avals, in_shardings,
closed_jaxpr.jaxpr.debug_info)

(closed_jaxpr, global_in_avals, global_out_avals, donated_invars,
kept_var_idx, name_stack) = _dce_jaxpr(
closed_jaxpr, global_in_avals, api_name, fun_name, keep_unused,
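
The `all_args_info` snapshot above is taken before `_dce_jaxpr` prunes unused arguments, so the pre-DCE avals, shardings and debug info survive into the executable. A toy sketch of why the snapshot is needed (hypothetical stand-in values, not JAX API):

global_in_avals = ('f32[4]', 'f32[1]')  # avals of (x, y) before DCE
kept_var_idx = {0}                      # DCE keeps only x; y is unused
in_avals = [a for i, a in enumerate(global_in_avals) if i in kept_var_idx]
assert in_avals == ['f32[4]']
# Post-DCE, the executable's own avals no longer mention y, so without the
# snapshot there is nothing to typecheck y against at call time.
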
@@ -2004,7 +2015,8 @@ def lower_sharding_computation(
pmap_nreps=nreps,
jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
shape_poly_state=shape_poly_state,
all_default_mem_kind=all_default_mem_kind)
all_default_mem_kind=all_default_mem_kind,
all_args_info=all_args_info)


def _to_logical_sharding(
@@ -2090,6 +2102,8 @@ def lower_mesh_computation(
out_jaxpr_avals = fun_or_jaxpr.out_avals
consts = fun_or_jaxpr.consts

all_args_info = AllArgsInfo(global_in_avals, in_shardings, jaxpr.debug_info)

assert len(out_shardings) == len(out_jaxpr_avals)
if spmd_lowering:
global_out_avals = out_jaxpr_avals
@@ -2179,7 +2193,8 @@ def lower_mesh_computation(
in_layouts=(None,) * len(global_in_avals),
out_layouts=(None,) * len(global_out_avals),
jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
shape_poly_state=lowering_result.shape_poly_state)
shape_poly_state=lowering_result.shape_poly_state,
all_args_info=all_args_info)

class MeshComputation(stages.XlaLowering):
_hlo: ir.Module | None
@@ -2568,6 +2583,7 @@ class UnloadedMeshExecutable:
jaxpr_debug_info: core.JaxprDebugInfo | None
in_layouts: Sequence[SpecifiedLayout | None]
out_layouts: Sequence[SpecifiedLayout | None]
all_args_info: AllArgsInfo | None

def build_unsafe_call(self):
input_indices = _get_input_indices(self.input_avals, self.input_shardings,
@@ -2590,7 +2606,7 @@ def load(self) -> MeshExecutable:
self.input_shardings, self.output_shardings,
self.auto_spmd_lowering, self.kept_var_idx,
self.in_layouts, self.out_layouts,
self.jaxpr_debug_info, self)
self.jaxpr_debug_info, self.all_args_info, self)

# May return a MeshExecutable in the compile_replicated case.
@staticmethod
@@ -2618,6 +2634,7 @@ def from_hlo(name: str,
jaxpr_debug_info: core.JaxprDebugInfo | None = None,
shape_poly_state: mlir.ShapePolyLoweringState | None = None,
all_default_mem_kind: bool = True,
all_args_info: AllArgsInfo | None = None,
compiler_options=None,
) -> MeshExecutable:
if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
@@ -2710,7 +2727,8 @@ def from_hlo(name: str,
auto_spmd_lowering=auto_spmd_lowering,
jaxpr_debug_info=jaxpr_debug_info,
in_layouts=in_layouts, # type: ignore
out_layouts=out_layouts).load() # type: ignore
out_layouts=out_layouts, # type: ignore
all_args_info=all_args_info).load() # type: ignore


class MeshExecutableFastpathData(NamedTuple):
@@ -2735,12 +2753,14 @@ class MeshExecutable(stages.XlaExecutable):
__slots__ = [
"xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals",
"_in_shardings", "_out_shardings", "_auto_spmd_lowering", "_kept_var_idx",
"_in_layouts", "_out_layouts", "_jaxpr_debug_info", "_unloaded_executable",
"_in_layouts", "_out_layouts", "_jaxpr_debug_info",
"_all_args_info", "_unloaded_executable",
]

def __init__(self, xla_executable, build_unsafe_call, in_avals, in_shardings,
out_shardings, auto_spmd_lowering, kept_var_idx,
in_layouts, out_layouts, jaxpr_debug_info=None,
all_args_info: AllArgsInfo | None = None,
unloaded_executable=None):
self.xla_executable = xla_executable
self.build_unsafe_call = build_unsafe_call
@@ -2755,27 +2775,38 @@ def __init__(self, xla_executable, build_unsafe_call, in_avals, in_shardings,
self._in_layouts = in_layouts
self._out_layouts = out_layouts
self._jaxpr_debug_info = jaxpr_debug_info
self._all_args_info = all_args_info
self._unloaded_executable = unloaded_executable

@property
def unsafe_call(self) -> Callable[..., Any]:
if self._unsafe_call is None:
self._unsafe_call = self.build_unsafe_call()
return self._unsafe_call
return self._unsafe_call # type: ignore

# -- stages.XlaExecutable overrides

def xla_extension_executable(self):
return self.xla_executable

def call(self, *args):
kept_args = [a for i, a in enumerate(args) if i in self._kept_var_idx]
if self._all_args_info is None:
kept_args = [a for i, a in enumerate(args) if i in self._kept_var_idx]
ref_avals = self.in_avals
in_shardings = self._in_shardings
debug_info = self._jaxpr_debug_info
else:
kept_args = args
ref_avals = self._all_args_info.in_avals
iter_in_shardings = iter(self._in_shardings)
in_shardings = [next(iter_in_shardings) if i in self._kept_var_idx else s
for i, s in enumerate(self._all_args_info.in_shardings)]
debug_info = self._all_args_info.debug_info

arg_avals = map(xla.abstractify, kept_args)
ref_avals = self.in_avals
check_arg_avals_for_call(ref_avals, arg_avals, self._jaxpr_debug_info)
check_arg_avals_for_call(ref_avals, arg_avals, debug_info)
# Check the GDA sharding and the input sharding.
check_gda_or_array_xla_sharding_match(kept_args, self._in_shardings,
self._jaxpr_debug_info)
check_gda_or_array_xla_sharding_match(kept_args, in_shardings, debug_info)
return self.unsafe_call(*args) # pylint: disable=not-callable

def input_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
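
The `iter` dance in `call` above re-expands the post-DCE shardings to the original argument positions: kept positions take the next compiled sharding, while DCE'd positions fall back to the snapshot. A self-contained sketch with hypothetical values:

kept_var_idx = {0, 2}                    # arguments 0 and 2 survived DCE
snapshot_shardings = ['s0', 's1', 's2']  # pre-DCE, one per original argument
compiled_shardings = ['S0', 'S2']        # self._in_shardings, post-DCE

it = iter(compiled_shardings)
in_shardings = [next(it) if i in kept_var_idx else s
                for i, s in enumerate(snapshot_shardings)]
assert in_shardings == ['S0', 's1', 'S2']
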
@@ -2922,7 +2953,8 @@ def _compile_replicated_mesh_executable_from_hlo(
return MeshExecutable(xla_executable, lambda: unsafe_call, global_in_avals,
in_shardings, out_shardings, auto_spmd_lowering,
kept_var_idx, (None,) * len(global_in_avals),
(None,) * len(global_out_avals), jaxpr_debug_info, None)
(None,) * len(global_out_avals), jaxpr_debug_info,
None, None)


@lru_cache
@@ -2956,6 +2988,8 @@ def check_gda_or_array_xla_sharding_match(
for arg, xs, name in safe_zip(args, in_xla_shardings, arg_names):
if not isinstance(arg, ArrayImpl):
continue
if is_unspecified_or_auto(xs):
continue

db_xs = check_device_backend_on_shardings([xs])
if not db_xs:
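
Note the new early `continue` above: shardings taken from the pre-DCE snapshot can still be unspecified or auto when the caller never committed those arguments to devices, and there is nothing concrete to match them against, so the check skips them rather than report a spurious mismatch.
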
2 changes: 1 addition & 1 deletion jax/experimental/export/export.py
@@ -764,7 +764,7 @@ def _check_lowering(lowering) -> None:
"tuple_args", "ordered_effects", "unordered_effects",
"keepalive", "host_callbacks", "pmap_nreps", "committed",
"device_assignment", "jaxpr_debug_info", "shape_poly_state",
"all_default_mem_kind", "in_layouts", "out_layouts"]
"all_default_mem_kind", "in_layouts", "out_layouts", "all_args_info"]
for compile_arg in lowering.compile_args.keys():
if compile_arg not in allowed_compile_args:
raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]")
46 changes: 46 additions & 0 deletions tests/pjit_test.py
@@ -4018,6 +4018,52 @@ def test_pjit_with_deleted_input_at_subsequent_call(self, committed):
x.delete()
_ = f(x)

def test_aot_error_on_dced_avals_mismatch(self):
x, y1, y2 = jnp.ones(4), jnp.ones(4), jnp.ones(1)

@jax.jit
def f(x, y):
return x + 1 if y.shape[0] > 2 else x + 2

f_out1 = f(x, y1)
f(x, y2)

g = f.lower(x, y1).compile()
g_out1 = g(x, y1)
self.assertArraysEqual(f_out1, g_out1)

with self.assertRaisesRegex(
TypeError,
'Argument types differ from the types for which this computation was'
' compiled'):
g(x, y2)

def test_aot_error_on_dced_shardings_mismatch(self):
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
shape = (8, 2)
np_inp = np.arange(math.prod(shape)).reshape(shape)

x = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
y1 = jax.device_put(np_inp, NamedSharding(mesh, P('x')))
y2 = jax.device_put(np_inp, NamedSharding(mesh, P('y')))

@jax.jit
def f(x, y):
return x + 1

f_out1 = f(x, y1)
f(x, y2)

g = f.lower(x, y1).compile()
g_out1 = g(x, y1)
self.assertArraysEqual(f_out1, g_out1)

with self.assertRaisesRegex(
ValueError,
r"Compiled object called with input sharding.*does not match the "
r"sharding.*the computation was compiled with"):
g(x, y2)
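
Together these two tests cover both new error paths: an aval mismatch on a DCE'd argument surfaces as a TypeError from check_arg_avals_for_call, and a sharding mismatch as a ValueError from check_gda_or_array_xla_sharding_match.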


@jtu.pytest_mark_if_available('multiaccelerator')
class UtilTest(jtu.JaxTestCase):
