diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index 5b50938b9c7d..782ec8ae5bdb 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -371,6 +371,21 @@ def make_ir_context() -> ir.Context:
     sharding_impls.ShardingContext,
 ]
 
+class ShapePolyLoweringState:
+  # The names of the dimension variables, sorted by name. This is the order in
+  # which they are passed to the IR functions that need them. This is only
+  # used for native serialization with polymorphic shapes when
+  # --jax_dynamic_shapes is off.
+  dim_vars: Sequence[str]
+  # Whether the module uses dimension variables, either in its inputs or
+  # from an inner call to a polymorphic Exported.
+  uses_dim_vars: bool
+
+  def __init__(self, dim_vars: Sequence[str]):
+    self.dim_vars = dim_vars
+    self.uses_dim_vars = (len(dim_vars) > 0)
+
+
 @dataclasses.dataclass
 class ModuleContext:
   """Module-wide context information for MLIR lowering."""
@@ -385,11 +400,8 @@ class ModuleContext:
   keepalives: List[Any]
   channel_iterator: Iterator[int]
   host_callbacks: List[Any]
-  # The names of the dimension variables, sorted by name. This is the order in
-  # which they are passed to the IR functions that need them. This is only
-  # used for native serialization with polymorphic shapes when
-  # --jax_dynamic_shapes is off.
-  dim_vars: Sequence[str]
+  # Keep state for the lowering of shape polymorphism
+  shape_poly_state: ShapePolyLoweringState
 
   # Cached primitive lowerings.
   cached_primitive_lowerings: Dict[Any, func_dialect.FuncOp]
@@ -417,7 +429,7 @@ def __init__(
                                                 func_dialect.FuncOp]] = None,
       cached_call_jaxpr_lowerings: Optional[Dict[Any,
                                                  func_dialect.FuncOp]] = None,
-      dim_vars: Sequence[str] = ()):
+      shape_poly_state = None):
     assert platform is not None
     self.context = context or make_ir_context()
     self.module = module or ir.Module.create(loc=ir.Location.unknown(self.context))
@@ -435,7 +447,7 @@ def __init__(
     self.cached_call_jaxpr_lowerings = ({}
                                         if cached_call_jaxpr_lowerings is None
                                         else cached_call_jaxpr_lowerings)
-    self.dim_vars = dim_vars
+    self.shape_poly_state = shape_poly_state or ShapePolyLoweringState(())
 
   @property
   def backend(self) -> xb.XlaBackend:
@@ -466,7 +478,7 @@ class LoweringRuleContext:
   tokens_out: Optional[TokenSet]  # Mutable store for output containers
   axis_size_env: Optional[Dict[core.Var, ir.Value]] = None  # Dynamic axis sizes
   dim_var_values: Sequence[ir.Value] = ()  # The values for the dimension variables
-                                           # in same order as module_context.dim_vars
+                                           # in same order as module_context.shape_poly_state.dim_vars
 
   def set_tokens_out(self, tokens_out: TokenSet):
     assert self.tokens_out is None, 'Should only set `tokens_out` once.'
@@ -535,9 +547,9 @@ def eval_dynamic_shape(ctx: LoweringRuleContext,
   else:
     ctx = ctx.replace(
         primitive="eval_dynamic_shape",
-        avals_in=[core.dim_value_aval()] * len(ctx.module_context.dim_vars))
+        avals_in=[core.dim_value_aval()] * len(ctx.module_context.shape_poly_state.dim_vars))
     res = lower_fun(
-        partial(core.evaluate_shape, shape, ctx.module_context.dim_vars),
+        partial(core.evaluate_shape, shape, ctx.module_context.shape_poly_state.dim_vars),
         multiple_results=True)(ctx, *ctx.dim_var_values)
     return util.flatten(res)  # type: ignore
 
@@ -546,6 +558,7 @@ class LoweringResult(NamedTuple):
   module: ir.Module
   keepalive: Optional[Any]
   host_callbacks: List[Any]
+  shape_poly_state: ShapePolyLoweringState
 
 
 _platforms_with_donation = ["cpu", "cuda", "rocm", "tpu"]
@@ -628,7 +641,8 @@ def lower_jaxpr_to_module(
                        if result_shardings is not None else result_shardings)
 
   ctx = ModuleContext(backend_or_name, platform, axis_context, name_stack,
-                      keepalives, channel_iter, host_callbacks, dim_vars=dim_vars)
+                      keepalives, channel_iter, host_callbacks,
+                      shape_poly_state=ShapePolyLoweringState(dim_vars))
   with ctx.context, ir.Location.unknown(ctx.context):
     # Remove module name characters that XLA would alter. This ensures that
     # XLA computation preserves the module name.
@@ -658,7 +672,8 @@ def lower_jaxpr_to_module(
       raise ValueError(
           f"Cannot lower jaxpr with verifier errors: {module_string}") from e
 
-  return LoweringResult(ctx.module, ctx.keepalives, ctx.host_callbacks)
+  return LoweringResult(ctx.module, ctx.keepalives, ctx.host_callbacks,
+                        ctx.shape_poly_state)
 
 def module_to_string(module: ir.Module) -> str:
   output = io.StringIO()
@@ -805,7 +820,7 @@ def aval_to_types(aval):
       aval = core.ShapedArray((), np.dtype(np.bool_))
     return aval_to_ir_types(aval)
 
-  num_dim_vars = len(ctx.dim_vars)
+  num_dim_vars = len(ctx.shape_poly_state.dim_vars)
   dim_var_avals = [core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))] * num_dim_vars
   dim_var_types = map(aval_to_types, dim_var_avals)
 
@@ -1006,7 +1021,7 @@ def _to_physical_op_sharding(
 def _emit_lowering_rule_as_fun(lowering_rule,
                                ctx: LoweringRuleContext) -> func_dialect.FuncOp:
   """Emits the contents of a lowering rule as a private function."""
-  num_dim_vars = len(ctx.module_context.dim_vars)
+  num_dim_vars = len(ctx.module_context.shape_poly_state.dim_vars)
   # TODO(necula) maybe only pass the dim_vars if they are needed?
   dim_var_types = map(aval_to_ir_types,
                       [core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))] * num_dim_vars)
@@ -1049,7 +1064,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
   Assumes that an MLIR context, location, and insertion point are set.
 
   dim_var_values: the list of dimension variables values in the current
-    IR function, in the order of ctx.dim_vars.
+    IR function, in the order of ctx.shape_poly_state.dim_vars.
   """
   assert ctx.platform != "gpu"
   def read(v: core.Atom) -> Sequence[ir.Value]:
@@ -1075,7 +1090,7 @@ def write(v: core.Var, node: Sequence[ir.Value]):
   assert len(args) == len(jaxpr.invars), (jaxpr, args)
   assert len(consts) == len(jaxpr.constvars), (jaxpr, consts)
   assert all(isinstance(v, ir.Value) for vs in consts for v in vs), consts
-  assert len(ctx.dim_vars) == len(dim_var_values), (ctx.dim_vars, dim_var_values)
+  assert len(ctx.shape_poly_state.dim_vars) == len(dim_var_values), (ctx.shape_poly_state.dim_vars, dim_var_values)
   map(write, jaxpr.constvars, consts)
   map(write, jaxpr.invars, args)
   for eqn in jaxpr.eqns:
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index a4f6b4d55d76..867e97eb87aa 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -1945,7 +1945,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
       effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
   return (lowering_result.module, lowering_result.keepalive,
           lowering_result.host_callbacks, unordered_effects, ordered_effects,
-          nreps, tuple_args)
+          nreps, tuple_args, lowering_result.shape_poly_state)
 
 
 @dataclasses.dataclass(frozen=True)
@@ -2080,7 +2080,7 @@ def lower_sharding_computation(
   semantic_in_shardings = SemanticallyEqualShardings(in_shardings)  # type: ignore
   semantic_out_shardings = SemanticallyEqualShardings(out_shardings)
   (module, keepalive, host_callbacks, unordered_effects, ordered_effects,
-   nreps, tuple_args) = _cached_lowering_to_hlo(
+   nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
        closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
        semantic_out_shardings, da_object, lowering_platform, donated_invars,
        name_stack)
@@ -2111,7 +2111,8 @@ def lower_sharding_computation(
       device_assignment=da_object,
       committed=committed,
       pmap_nreps=nreps,
-      jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info)
+      jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
+      shape_poly_state=shape_poly_state)
 
 
 def _to_logical_sharding(
@@ -2285,7 +2286,8 @@ def lower_mesh_computation(
       backend=backend,
       device_assignment=_create_da_object(tuple(mesh.devices.flat)),
       committed=True,
-      jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info)
+      jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
+      shape_poly_state=lowering_result.shape_poly_state)
 
 class MeshComputation(stages.XlaLowering):
   _hlo: Optional[ir.Module]
@@ -2617,8 +2619,10 @@ def from_hlo(name: str,
                committed: bool,
                pmap_nreps: int = 1,
                jaxpr_debug_info: Optional[core.JaxprDebugInfo] = None,
+               shape_poly_state: Optional[mlir.ShapePolyLoweringState] = None,
                compiler_options=None
   ) -> MeshExecutable:
+    del shape_poly_state
    compiler_options_keys = tuple(
        compiler_options.keys()) if compiler_options is not None else None
    compiler_options_values = tuple(
diff --git a/jax/experimental/jax2tf/jax_export.py b/jax/experimental/jax2tf/jax_export.py
index 5e4804b4ef14..8cf5fe2b6d9e 100644
--- a/jax/experimental/jax2tf/jax_export.py
+++ b/jax/experimental/jax2tf/jax_export.py
@@ -76,6 +76,9 @@ class Exported:
     module_kept_var_idx: the sorted indices of the arguments among `in_avals`
       that must be passed to the module. The other arguments have been dropped
      because they are not used. Same length as `in_shardings`.
+    module_uses_dim_vars: whether the `mlir_module_serialized` uses shape
+      polymorphic dimension variables. This may be from `in_avals` but also
+      from inner calls of Exported modules.
     strict_checks: whether the module was serialized with the following safety
       checking: (A) the lowered computation can only be executed on a platform
       for which it was lowered; (B) the serialized computation contains only
@@ -101,6 +104,7 @@ class Exported:
   mlir_module_serialized: bytes
   xla_call_module_version: int
   module_kept_var_idx: Tuple[int, ...]
+  module_uses_dim_vars: bool
 
   _get_vjp: Optional[Callable[["Exported"], "Exported"]]
 
@@ -264,10 +268,9 @@ def do_export(*args_specs, **kwargs_specs) -> Exported:
     else:
       # For pmap
       module_kept_var_idx = tuple(range(len(args_avals_flat)))
-
-    if not all(
-        core.is_constant_shape(a.shape) for a in args_avals_flat
-    ) or lowering.compile_args.get("ordered_effects", []):
+    shape_poly_state = lowering.compile_args["shape_poly_state"]
+    if (not all(core.is_constant_shape(a.shape) for a in args_avals_flat)
+        or lowering.compile_args.get("ordered_effects", [])):
       # All arguments are kept if we have dimension variables.
       assert len(module_kept_var_idx) == len(args_avals_flat)
       mlir_module = _wrap_main_func(
@@ -334,6 +337,7 @@ def do_export(*args_specs, **kwargs_specs) -> Exported:
       strict_checks=strict_checks,
       mlir_module_serialized=mlir_module_serialized,
       module_kept_var_idx=module_kept_var_idx,
+      module_uses_dim_vars=shape_poly_state.uses_dim_vars,
       xla_call_module_version=xla_call_module_version,
       _get_vjp=lambda exported: _export_native_vjp(fun_jax, exported))
 
@@ -387,7 +391,6 @@ def _wrap_main_func(
   Returns the wrapped module.
   """
   dim_vars = shape_poly.all_dim_vars(args_avals_flat)
-
   # Make a new module, do not mutate the "module" because it may be cached
   context = mlir.make_ir_context()
   with context, ir.Location.unknown(context):
@@ -512,7 +515,7 @@ def _check_lowering(lowering) -> None:
      "spmd_lowering", "auto_spmd_lowering",
      "tuple_args", "ordered_effects", "unordered_effects",
      "keepalive", "host_callbacks", "pmap_nreps", "committed",
-      "device_assignment", "jaxpr_debug_info"]
+      "device_assignment", "jaxpr_debug_info", "shape_poly_state"]
   for compile_arg in lowering.compile_args.keys():
     if compile_arg not in allowed_compile_args:
       raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]")
@@ -538,6 +541,7 @@ def _check_lowering(lowering) -> None:
       # used on all platforms for callbacks. Not supported yet.
       ("keepalive", lambda v: not v, "empty"),
       ("pmap_nreps", lambda v: v == 1, "1"),
+      ("shape_poly_state", lambda v: True, "N/A"),
   ):
     if compile_arg in lowering.compile_args:
       if not check_value(lowering.compile_args[compile_arg]):
@@ -810,6 +814,9 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
        f"The exported function '{exported.fun_name}' was lowered for "
        f"platform '{exported.lowering_platform}' but it is used "
        f"on '{platform}'.")
+  if any(not core.is_constant_shape(a.shape) for a in exported.in_avals):
+    ctx.module_context.shape_poly_state.uses_dim_vars = True
+
   submodule = ir.Module.parse(exported.mlir_module)
   symtab = ir.SymbolTable(submodule.operation)
   # The called function may have been exported with polymorphic shapes and called
diff --git a/jax/experimental/jax2tf/tests/back_compat_test.py b/jax/experimental/jax2tf/tests/back_compat_test.py
index c8acdb4519e0..d24a0960be57 100644
--- a/jax/experimental/jax2tf/tests/back_compat_test.py
+++ b/jax/experimental/jax2tf/tests/back_compat_test.py
@@ -275,6 +275,8 @@ def _get_vjp(_):
        mlir_module_serialized=data.mlir_module_serialized,
        xla_call_module_version=data.xla_call_module_version,
        module_kept_var_idx=tuple(range(len(in_avals))),
+        module_uses_dim_vars=any(not core.is_constant_shape(a.shape)
+                                 for a in in_avals),
        _get_vjp=_get_vjp)
 
     # We use pjit in case there are shardings in the exported module.
diff --git a/jax/experimental/jax2tf/tests/jax_export_test.py b/jax/experimental/jax2tf/tests/jax_export_test.py
index 96212ff6f1ac..ba87588b8ab4 100644
--- a/jax/experimental/jax2tf/tests/jax_export_test.py
+++ b/jax/experimental/jax2tf/tests/jax_export_test.py
@@ -264,6 +264,8 @@ def inner(x):  # x: inner_poly_spec
 
     inner_exp = jax_export.export(inner)(
         jax_export.poly_spec(inner_x.shape, inner_x.dtype, inner_poly_spec))
+    self.assertEqual(inner_exp.module_uses_dim_vars,
+                     (inner_poly_spec != "3,4,12"))
     outer_x = np.arange(np.prod(outer_x_shape),
                         dtype=np.float32).reshape(outer_x_shape)  # outer_x : f32[3,4,12]
     def outer(x):  # x: outer_poly_spec
@@ -278,12 +280,17 @@ def outer(x):  # x: outer_poly_spec
     # Call it after exporting again, with polymorphic shapes
     outer_exp = jax_export.export(outer)(
         jax_export.poly_spec(outer_x.shape, outer_x.dtype, outer_poly_spec))
-    # TODO: for now, we use XlaCallModule to run modules with polymorphic shapes
-    # until we create the python bindings to invoke shape refinement.
-    if jax2tf is not None:
-      res2 = jax2tf._run_exported_as_tf([outer_x], outer_exp)[0].numpy()
-      # res2 = jax_export.call_exported(exp2)(x2)
-      self.assertAllClose(2. * inner(outer_x), res2)
+    self.assertEqual(outer_exp.module_uses_dim_vars,
+                     (inner_poly_spec != "3,4,12" or outer_poly_spec != "3,4,12"))
+    if not outer_exp.module_uses_dim_vars:
+      res = jax_export.call_exported(outer_exp)(outer_x)
+      self.assertAllClose(2. * inner(outer_x), res)
+    else:
+      # TODO: for now, we use XlaCallModule to run modules with polymorphic shapes
+      # until we create the python bindings to invoke shape refinement.
+      if jax2tf is not None:
+        res = jax2tf._run_exported_as_tf([outer_x], outer_exp)[0].numpy()
+        self.assertAllClose(2. * inner(outer_x), res)
 
 
 if __name__ == "__main__":