[shape_poly] Keep track of whether a lowering contains shape polymorphism

Previously, we kept the `dim_vars` in the `mlir.ModuleContext`. Now we
replace it with a mutable `ShapePolyLoweringState` that also tracks
whether we encounter shape polymorphism anywhere in the lowering.
To surface this to callers, we also add `shape_poly_state` to
`lowering.compile_args`.

We need to keep track of whether a module contains dimension variables
because such modules need shape refinement before they can be converted
to MHLO and compiled. For now, we just test that
`Exported.module_uses_dim_vars` is set correctly.
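
A minimal sketch of the mechanism (not the actual JAX code; `lower_call_exported`
and `inner_has_dim_vars` are hypothetical stand-ins for `_call_exported_lowering`
and its shape check): the lowering carries one mutable state object, created from
the dimension variables of the top-level arguments, and any rule that inlines a
shape-polymorphic `Exported` flips `uses_dim_vars` so the caller knows the module
will need shape refinement.

    from typing import Sequence

    class ShapePolyLoweringState:
        """Mutable state threaded through a single lowering."""
        dim_vars: Sequence[str]   # dimension variable names, in the order passed to IR functions
        uses_dim_vars: bool       # True if the inputs or an inner Exported call are polymorphic

        def __init__(self, dim_vars: Sequence[str]):
            self.dim_vars = tuple(dim_vars)
            self.uses_dim_vars = len(dim_vars) > 0

    def lower_call_exported(state: ShapePolyLoweringState,
                            inner_has_dim_vars: bool) -> None:
        # Even if the outer module has only static shapes, inlining a
        # polymorphic Exported makes the result require shape refinement.
        if inner_has_dim_vars:
            state.uses_dim_vars = True

    state = ShapePolyLoweringState(dim_vars=())           # outer inputs are all static
    lower_call_exported(state, inner_has_dim_vars=True)   # inner Exported is polymorphic
    assert state.uses_dim_vars                            # surfaced via compile_args["shape_poly_state"]

This is the property that the new `jax_export_test.py` case below checks through
`Exported.module_uses_dim_vars`.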
gnecula committed May 31, 2023
1 parent f884b4d commit 5cbc38d
Showing 5 changed files with 67 additions and 32 deletions.
47 changes: 31 additions & 16 deletions jax/_src/interpreters/mlir.py
@@ -371,6 +371,21 @@ def make_ir_context() -> ir.Context:
sharding_impls.ShardingContext,
]

class ShapePolyLoweringState:
# The names of the dimension variables, sorted by name. This is the order in
# which they are passed to the IR functions that need them. This is only
# used for native serialization with polymorphic shapes when
# --jax_dynamic_shapes is off.
dim_vars: Sequence[str]
# Whether the module uses dimension variables, either in its inputs or
# from an inner call to a polymorphic Exported.
uses_dim_vars: bool

def __init__(self, dim_vars: Sequence[str]):
self.dim_vars = dim_vars
self.uses_dim_vars = (len(dim_vars) > 0)


@dataclasses.dataclass
class ModuleContext:
"""Module-wide context information for MLIR lowering."""
@@ -385,11 +400,8 @@ class ModuleContext:
keepalives: List[Any]
channel_iterator: Iterator[int]
host_callbacks: List[Any]
# The names of the dimension variables, sorted by name. This is the order in
# which they are passed to the IR functions that need them. This is only
# used for native serialization with polymorphic shapes when
# --jax_dynamic_shapes is off.
dim_vars: Sequence[str]
# Keep state for the lowering of shape polymorphism
shape_poly_state: ShapePolyLoweringState

# Cached primitive lowerings.
cached_primitive_lowerings: Dict[Any, func_dialect.FuncOp]
@@ -417,7 +429,7 @@ def __init__(
func_dialect.FuncOp]] = None,
cached_call_jaxpr_lowerings: Optional[Dict[Any,
func_dialect.FuncOp]] = None,
dim_vars: Sequence[str] = ()):
shape_poly_state = None):
assert platform is not None
self.context = context or make_ir_context()
self.module = module or ir.Module.create(loc=ir.Location.unknown(self.context))
@@ -435,7 +447,7 @@ def __init__(
self.cached_call_jaxpr_lowerings = ({}
if cached_call_jaxpr_lowerings is None
else cached_call_jaxpr_lowerings)
self.dim_vars = dim_vars
self.shape_poly_state = shape_poly_state or ShapePolyLoweringState(())

@property
def backend(self) -> xb.XlaBackend:
@@ -466,7 +478,7 @@ class LoweringRuleContext:
tokens_out: Optional[TokenSet] # Mutable store for output containers
axis_size_env: Optional[Dict[core.Var, ir.Value]] = None # Dynamic axis sizes
dim_var_values: Sequence[ir.Value] = () # The values for the dimension variables
# in same order as module_context.dim_vars
# in same order as module_context.shape_poly_state.dim_vars

def set_tokens_out(self, tokens_out: TokenSet):
assert self.tokens_out is None, 'Should only set `tokens_out` once.'
@@ -535,9 +547,9 @@ def eval_dynamic_shape(ctx: LoweringRuleContext,
else:
ctx = ctx.replace(
primitive="eval_dynamic_shape",
avals_in=[core.dim_value_aval()] * len(ctx.module_context.dim_vars))
avals_in=[core.dim_value_aval()] * len(ctx.module_context.shape_poly_state.dim_vars))
res = lower_fun(
partial(core.evaluate_shape, shape, ctx.module_context.dim_vars),
partial(core.evaluate_shape, shape, ctx.module_context.shape_poly_state.dim_vars),
multiple_results=True)(ctx, *ctx.dim_var_values)
return util.flatten(res) # type: ignore

@@ -546,6 +558,7 @@ class LoweringResult(NamedTuple):
module: ir.Module
keepalive: Optional[Any]
host_callbacks: List[Any]
shape_poly_state: ShapePolyLoweringState


_platforms_with_donation = ["cpu", "cuda", "rocm", "tpu"]
@@ -628,7 +641,8 @@ def lower_jaxpr_to_module(
if result_shardings is not None else result_shardings)

ctx = ModuleContext(backend_or_name, platform, axis_context, name_stack,
keepalives, channel_iter, host_callbacks, dim_vars=dim_vars)
keepalives, channel_iter, host_callbacks,
shape_poly_state=ShapePolyLoweringState(dim_vars))
with ctx.context, ir.Location.unknown(ctx.context):
# Remove module name characters that XLA would alter. This ensures that
# XLA computation preserves the module name.
@@ -658,7 +672,8 @@ def lower_jaxpr_to_module(
raise ValueError(
f"Cannot lower jaxpr with verifier errors: {module_string}") from e

return LoweringResult(ctx.module, ctx.keepalives, ctx.host_callbacks)
return LoweringResult(ctx.module, ctx.keepalives, ctx.host_callbacks,
ctx.shape_poly_state)

def module_to_string(module: ir.Module) -> str:
output = io.StringIO()
@@ -805,7 +820,7 @@ def aval_to_types(aval):
aval = core.ShapedArray((), np.dtype(np.bool_))
return aval_to_ir_types(aval)

num_dim_vars = len(ctx.dim_vars)
num_dim_vars = len(ctx.shape_poly_state.dim_vars)
dim_var_avals = [core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))] * num_dim_vars
dim_var_types = map(aval_to_types, dim_var_avals)

@@ -1006,7 +1021,7 @@ def _to_physical_op_sharding(
def _emit_lowering_rule_as_fun(lowering_rule,
ctx: LoweringRuleContext) -> func_dialect.FuncOp:
"""Emits the contents of a lowering rule as a private function."""
num_dim_vars = len(ctx.module_context.dim_vars)
num_dim_vars = len(ctx.module_context.shape_poly_state.dim_vars)
# TODO(necula) maybe only pass the dim_vars if they are needed?
dim_var_types = map(aval_to_ir_types, [core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))] * num_dim_vars)

@@ -1049,7 +1064,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
Assumes that an MLIR context, location, and insertion point are set.
dim_var_values: the list of dimension variables values in the current
IR function, in the order of ctx.dim_vars.
IR function, in the order of ctx.shape_poly_state.dim_vars.
"""
assert ctx.platform != "gpu"
def read(v: core.Atom) -> Sequence[ir.Value]:
@@ -1075,7 +1090,7 @@ def write(v: core.Var, node: Sequence[ir.Value]):
assert len(args) == len(jaxpr.invars), (jaxpr, args)
assert len(consts) == len(jaxpr.constvars), (jaxpr, consts)
assert all(isinstance(v, ir.Value) for vs in consts for v in vs), consts
assert len(ctx.dim_vars) == len(dim_var_values), (ctx.dim_vars, dim_var_values)
assert len(ctx.shape_poly_state.dim_vars) == len(dim_var_values), (ctx.shape_poly_state.dim_vars, dim_var_values)
map(write, jaxpr.constvars, consts)
map(write, jaxpr.invars, args)
for eqn in jaxpr.eqns:
12 changes: 8 additions & 4 deletions jax/_src/interpreters/pxla.py
@@ -1945,7 +1945,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
return (lowering_result.module, lowering_result.keepalive,
lowering_result.host_callbacks, unordered_effects, ordered_effects,
nreps, tuple_args)
nreps, tuple_args, lowering_result.shape_poly_state)


@dataclasses.dataclass(frozen=True)
@@ -2080,7 +2080,7 @@ def lower_sharding_computation(
semantic_in_shardings = SemanticallyEqualShardings(in_shardings) # type: ignore
semantic_out_shardings = SemanticallyEqualShardings(out_shardings)
(module, keepalive, host_callbacks, unordered_effects, ordered_effects,
nreps, tuple_args) = _cached_lowering_to_hlo(
nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
semantic_out_shardings, da_object, lowering_platform,
donated_invars, name_stack)
@@ -2111,7 +2111,8 @@ def lower_sharding_computation(
device_assignment=da_object,
committed=committed,
pmap_nreps=nreps,
jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info)
jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
shape_poly_state=shape_poly_state)


def _to_logical_sharding(
@@ -2285,7 +2286,8 @@ def lower_mesh_computation(
backend=backend,
device_assignment=_create_da_object(tuple(mesh.devices.flat)),
committed=True,
jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info)
jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
shape_poly_state=lowering_result.shape_poly_state)

class MeshComputation(stages.XlaLowering):
_hlo: Optional[ir.Module]
@@ -2617,8 +2619,10 @@ def from_hlo(name: str,
committed: bool,
pmap_nreps: int = 1,
jaxpr_debug_info: Optional[core.JaxprDebugInfo] = None,
shape_poly_state: Optional[mlir.ShapePolyLoweringState] = None,
compiler_options=None
) -> MeshExecutable:
del shape_poly_state
compiler_options_keys = tuple(
compiler_options.keys()) if compiler_options is not None else None
compiler_options_values = tuple(
19 changes: 13 additions & 6 deletions jax/experimental/jax2tf/jax_export.py
@@ -76,6 +76,9 @@ class Exported:
module_kept_var_idx: the sorted indices of the arguments among `in_avals` that
must be passed to the module. The other arguments have been dropped
because they are not used. Same length as `in_shardings`.
module_uses_dim_vars: whether the `mlir_module_serialized` uses shape
polymorphic dimension variables. This may be from `in_avals` but also
from inner calls of Exported modules.
strict_checks: whether the module was serialized with the following safety
checking: (A) the lowered computation can only be executed on a platform
for which it was lowered; (B) the serialized computation contains only
@@ -101,6 +104,7 @@ class Exported:
mlir_module_serialized: bytes
xla_call_module_version: int
module_kept_var_idx: Tuple[int, ...]
module_uses_dim_vars: bool

_get_vjp: Optional[Callable[["Exported"], "Exported"]]

@@ -264,10 +268,9 @@ def do_export(*args_specs, **kwargs_specs) -> Exported:
else:
# For pmap
module_kept_var_idx = tuple(range(len(args_avals_flat)))

if not all(
core.is_constant_shape(a.shape) for a in args_avals_flat
) or lowering.compile_args.get("ordered_effects", []):
shape_poly_state = lowering.compile_args["shape_poly_state"]
if (not all(core.is_constant_shape(a.shape) for a in args_avals_flat)
or lowering.compile_args.get("ordered_effects", [])):
# All arguments are kept if we have dimension variables.
assert len(module_kept_var_idx) == len(args_avals_flat)
mlir_module = _wrap_main_func(
@@ -334,6 +337,7 @@ def do_export(*args_specs, **kwargs_specs) -> Exported:
strict_checks=strict_checks,
mlir_module_serialized=mlir_module_serialized,
module_kept_var_idx=module_kept_var_idx,
module_uses_dim_vars=shape_poly_state.uses_dim_vars,
xla_call_module_version=xla_call_module_version,
_get_vjp=lambda exported: _export_native_vjp(fun_jax, exported))

@@ -387,7 +391,6 @@ def _wrap_main_func(
Returns the wrapped module.
"""
dim_vars = shape_poly.all_dim_vars(args_avals_flat)

# Make a new module, do not mutate the "module" because it may be cached
context = mlir.make_ir_context()
with context, ir.Location.unknown(context):
@@ -512,7 +515,7 @@ def _check_lowering(lowering) -> None:
"spmd_lowering", "auto_spmd_lowering",
"tuple_args", "ordered_effects", "unordered_effects",
"keepalive", "host_callbacks", "pmap_nreps", "committed",
"device_assignment", "jaxpr_debug_info"]
"device_assignment", "jaxpr_debug_info", "shape_poly_state"]
for compile_arg in lowering.compile_args.keys():
if compile_arg not in allowed_compile_args:
raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]")
@@ -538,6 +541,7 @@ def _check_lowering(lowering) -> None:
# used on all platforms for callbacks. Not supported yet.
("keepalive", lambda v: not v, "empty"),
("pmap_nreps", lambda v: v == 1, "1"),
("shape_poly_state", lambda v: True, "N/A"),
):
if compile_arg in lowering.compile_args:
if not check_value(lowering.compile_args[compile_arg]):
@@ -810,6 +814,9 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
f"The exported function '{exported.fun_name}' was lowered for "
f"platform '{exported.lowering_platform}' but it is used "
f"on '{platform}'.")
if any(not core.is_constant_shape(a.shape) for a in exported.in_avals):
ctx.module_context.shape_poly_state.uses_dim_vars = True

submodule = ir.Module.parse(exported.mlir_module)
symtab = ir.SymbolTable(submodule.operation)
# The called function may have been exported with polymorphic shapes and called
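For context, a hypothetical usage sketch of the new flag (based on the `jax_export`
API as it appears in this diff, not part of the commit; the function `f` and the
spec strings "b, 4" and "3, 4" are made up for illustration): exporting with a
symbolic dimension marks the serialized module as needing shape refinement, while
an all-constant spec does not.

    import numpy as np
    from jax.experimental.jax2tf import jax_export

    def f(x):
        return x + 1.0

    x = np.arange(12, dtype=np.float32).reshape(3, 4)

    # Polymorphic first dimension: the module uses a dimension variable.
    exp = jax_export.export(f)(jax_export.poly_spec(x.shape, x.dtype, "b, 4"))
    print(exp.module_uses_dim_vars)         # True

    # All dimensions constant: no dimension variables are used.
    static_exp = jax_export.export(f)(jax_export.poly_spec(x.shape, x.dtype, "3, 4"))
    print(static_exp.module_uses_dim_vars)  # False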
2 changes: 2 additions & 0 deletions jax/experimental/jax2tf/tests/back_compat_test.py
@@ -275,6 +275,8 @@ def _get_vjp(_):
mlir_module_serialized=data.mlir_module_serialized,
xla_call_module_version=data.xla_call_module_version,
module_kept_var_idx=tuple(range(len(in_avals))),
module_uses_dim_vars=any(not core.is_constant_shape(a.shape)
for a in in_avals),
_get_vjp=_get_vjp)

# We use pjit in case there are shardings in the exported module.
19 changes: 13 additions & 6 deletions jax/experimental/jax2tf/tests/jax_export_test.py
@@ -264,6 +264,8 @@ def inner(x): # x: inner_poly_spec
inner_exp = jax_export.export(inner)(
jax_export.poly_spec(inner_x.shape, inner_x.dtype, inner_poly_spec))

self.assertEqual(inner_exp.module_uses_dim_vars,
(inner_poly_spec != "3,4,12"))
outer_x = np.arange(np.prod(outer_x_shape),
dtype=np.float32).reshape(outer_x_shape) # outer_x : f32[3,4,12]
def outer(x): # x: outer_poly_spec
@@ -278,12 +280,17 @@ def outer(x): # x: outer_poly_spec
# Call it after exporting again, with polymorphic shapes
outer_exp = jax_export.export(outer)(
jax_export.poly_spec(outer_x.shape, outer_x.dtype, outer_poly_spec))
# TODO: for now, we use XlaCallModule to run modules with polymorphic shapes
# until we create the python bindings to invoke shape refinement.
if jax2tf is not None:
res2 = jax2tf._run_exported_as_tf([outer_x], outer_exp)[0].numpy()
# res2 = jax_export.call_exported(exp2)(x2)
self.assertAllClose(2. * inner(outer_x), res2)
self.assertEqual(outer_exp.module_uses_dim_vars,
(inner_poly_spec != "3,4,12" or outer_poly_spec != "3,4,12"))
if not outer_exp.module_uses_dim_vars:
res = jax_export.call_exported(outer_exp)(outer_x)
self.assertAllClose(2. * inner(outer_x), res)
else:
# TODO: for now, we use XlaCallModule to run modules with polymorphic shapes
# until we create the python bindings to invoke shape refinement.
if jax2tf is not None:
res = jax2tf._run_exported_as_tf([outer_x], outer_exp)[0].numpy()
self.assertAllClose(2. * inner(outer_x), res)


if __name__ == "__main__":
