Introduce a LoweringParameters dataclass for easier plumbing
There are currently two parameters used to configure lowering:
lowering_platform (for cross-platform lowering) and
override_lowering_rules. Each of them is passed as a separate argument
through several layers of internal lowering functions, which is tedious
and error-prone. In fact, override_lowering_rules was not plumbed
everywhere, and because every signature supplies a default value the
omission failed silently.

We foresee introducing other lowering parameters, e.g. for
multi-platform lowering and for controlling the lowering of effects.

Here we pack all such parameters into a `mlir.LoweringParameters`
dataclass and plumb it through.
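A minimal sketch of the new plumbing, assuming a JAX build that contains
this change: the private module path `jax._src.interpreters.mlir` and the
`_experimental_lowering_parameters` keyword are taken from the diff below,
while `f` and `xs` are illustrative.

import jax
import jax.numpy as jnp
from jax._src.interpreters import mlir

def f(x):
  return jnp.sin(x) * 2.0

# The default instance reproduces the current behaviour: no overridden
# lowering rules and the platform picked by JAX's normal backend selection.
params = mlir.LoweringParameters()

# Cross-platform lowering would instead pass, e.g.,
# mlir.LoweringParameters(platforms=("tpu",)), in which case
# override_platform returns "tpu".

xs = jnp.arange(jax.device_count() * 4.0).reshape(jax.device_count(), 4)
lowered = jax.pmap(f).lower(xs, _experimental_lowering_parameters=params)
print(lowered.as_text()[:200])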
gnecula committed Sep 29, 2023
1 parent 3247db7 commit 552fef6
Showing 10 changed files with 132 additions and 102 deletions.
9 changes: 5 additions & 4 deletions jax/_src/api.py
@@ -582,7 +582,8 @@ def computation_maker(*args, **kwargs):
wrap_name(fun_name, "xla_computation")),
donated_args=donated_invars,
arg_shardings=None,
result_shardings=None)
result_shardings=None,
lowering_parameters=mlir.LoweringParameters())
built = xc._xla.mlir.mlir_module_to_xla_computation(
mlir.module_to_string(lowering_result.module),
use_tuple_args=tuple_args,
@@ -1904,8 +1905,8 @@ def lower(*args, **kwargs) -> stages.Lowered:
Returns:
A ``Lowered`` instance representing the post-map lowering.
"""
_experimental_lowering_platform = kwargs.pop(
'_experimental_lowering_platform', None)
lowering_parameters = kwargs.pop(
'_experimental_lowering_parameters', mlir.LoweringParameters())
p = _prepare_pmap(
fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
devices, backend, axis_size, args, kwargs)
@@ -1920,7 +1921,7 @@ def lower(*args, **kwargs) -> stages.Lowered:
donated_invars=p.donated_invars,
is_explicit_global_axis_size=p.is_explicit_global_axis_size,
avals=abstract_args,
lowering_platform=_experimental_lowering_platform)
lowering_parameters=lowering_parameters)
return stages.Lowered.from_flat_info(
computation, p.in_tree, abstract_args, donate_tuple, p.out_tree())

8 changes: 5 additions & 3 deletions jax/_src/dispatch.py
@@ -149,7 +149,7 @@ def prim_fun(*args):
computation = sharded_lowering(
flat_fun, prim.name, donated_invars, keep_unused=False,
inline=True, in_avals=in_avals, in_shardings=orig_in_shardings.shardings,
lowering_platform=None)
lowering_parameters=mlir.LoweringParameters())
compiled = computation.compile()
if xla_extension_version >= 192:
if config.jax_disable_jit:
@@ -169,7 +169,8 @@ def prim_fun(*args):
def sharded_lowering(
fun: lu.WrappedFun, name: str, donated_invars: Sequence[bool],
keep_unused: bool, inline: bool, in_avals: tuple[core.AbstractValue, ...],
in_shardings: Sequence[Sharding | None], lowering_platform: str | None
in_shardings: Sequence[Sharding | None],
lowering_parameters: mlir.LoweringParameters
) -> pxla.MeshComputation:
in_shardings_unspec = [UNSPECIFIED if i is None else i for i in in_shardings]

@@ -179,7 +180,8 @@ def sharded_lowering(
return pxla.lower_sharding_computation(
fun, 'jit', name, in_shardings_unspec, UNSPECIFIED, donated_invars,
in_avals, keep_unused=keep_unused, inline=inline,
devices_from_context=None, lowering_platform=lowering_platform)
devices_from_context=None,
lowering_parameters=lowering_parameters)


def simple_impl(prim):
100 changes: 62 additions & 38 deletions jax/_src/interpreters/mlir.py
@@ -390,12 +390,6 @@ def make_ir_context() -> ir.Context:
]

class ShapePolyLoweringState:
# The current lowering platforms, a non-empty tuple containing some of
# 'cpu', 'cuda', 'rocm', 'tpu'.
# TODO: this state should be in ModuleContext, but since for now
# multi-platform lowering is implemented only for jax_export, like shape
# polymorphism, we keep it here.
lowering_platforms: tuple[str, ...]
# The names of the dimension variables, sorted by name. This is the order in
# which they are passed to the IR functions that need them. This is only
# used for native serialization with polymorphic shapes when
@@ -410,17 +404,48 @@ class ShapePolyLoweringState:
# from an inner call to a polymorphic Exported.
uses_dim_vars: bool

def __init__(self, dim_vars: tuple[str, ...],
lowering_platforms: tuple[str, ...]):
self.lowering_platforms = lowering_platforms
# Whether the first dimension variable is a platform index argument
has_platform_index_argument: bool

def __init__(self,
dim_vars: tuple[str, ...],
lowering_platforms: tuple[str, ...] | None):
self.uses_dim_vars = (len(dim_vars) > 0)
if len(lowering_platforms) > 1:
if lowering_platforms is not None and len(lowering_platforms) > 1:
dim_vars = ("platform_index_",) + tuple(dim_vars)
self.has_platform_index_argument = True
else:
self.has_platform_index_argument = False
self.dim_vars = dim_vars


@dataclasses.dataclass(frozen=True)
class LoweringParameters:
# A mapping between primitives and user-defined LoweringRules.
# When lowering a primitive, give priority to the rule in this map over
# existing Jax rules.
override_lowering_rules: tuple[tuple[core.Primitive, LoweringRule]] | None = None

# The current lowering platforms, a non-empty tuple containing some of
# 'cpu', 'cuda', 'rocm', 'tpu'. If the tuple has multiple entries, we are
# doing multi-platform lowering; otherwise it specifies cross-platform
# lowering. The value None specifies the default lowering platform.
# This is used only in export and jax2tf.
platforms: tuple[str, ...] | None = None

@property
def has_platform_index_argument(self):
return len(self.lowering_platforms) > 1
def override_platform(self) -> str | None:
"""Overrides the lowering platform for cross-platform lowering.
One of 'cpu', 'cuda', 'rocm', 'tpu'.
If None, use the default JAX mechanisms to pick the lowering platform.
This is currently used for export and jax2tf.
"""
if self.platforms is not None:
return self.platforms[0]
else:
return None
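# Illustrative resolution of the property above (values assume the
# constructor arguments shown):
#   LoweringParameters().override_platform                          -> None
#   LoweringParameters(platforms=("tpu",)).override_platform        -> "tpu"
#   LoweringParameters(platforms=("cpu", "tpu")).override_platform  -> "cpu"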


@dataclasses.dataclass
class ModuleContext:
@@ -443,24 +468,23 @@ class ModuleContext:
cached_primitive_lowerings: dict[Any, func_dialect.FuncOp]
cached_call_jaxpr_lowerings: dict[Any, func_dialect.FuncOp]

# A mapping between primitives and user-defined LoweringRules.
# When lowering a primitive, give priority to the rule in this map over
# existing Jax rules.
override_lowering_rules: tuple[tuple[core.Primitive, LoweringRule]] | None
lowering_parameters: LoweringParameters

@property
def axis_env(self) -> sharding_impls.AxisEnv:
return self.axis_context.axis_env

def __init__(
self,
*,
backend_or_name: str | xb.XlaBackend | None,
platform: str,
axis_context: AxisContext,
name_stack: source_info_util.NameStack,
keepalives: list[Any],
channel_iterator: Iterator[int],
host_callbacks: list[Any],
lowering_parameters: LoweringParameters,
context: ir.Context | None = None,
module: ir.Module | None = None,
ip: ir.InsertionPoint | None = None,
@@ -469,8 +493,6 @@ def __init__(
func_dialect.FuncOp]) = None,
cached_call_jaxpr_lowerings: None | (dict[Any,
func_dialect.FuncOp]) = None,
override_lowering_rules: None | (
tuple[tuple[core.Primitive, LoweringRule]]) = None,
shape_poly_state = None):
assert platform is not None
self.context = context or make_ir_context()
@@ -489,9 +511,9 @@ def __init__(
self.cached_call_jaxpr_lowerings = ({}
if cached_call_jaxpr_lowerings is None
else cached_call_jaxpr_lowerings)
self.override_lowering_rules = override_lowering_rules
self.shape_poly_state = shape_poly_state or ShapePolyLoweringState((),
(platform,))
self.lowering_parameters = lowering_parameters

@property
def backend(self) -> xb.XlaBackend:
@@ -664,6 +686,7 @@ def _get_mem_kind(s: Optional[XLACompatibleSharding]) -> Optional[str]:
def lower_jaxpr_to_module(
module_name: str,
jaxpr: core.ClosedJaxpr,
*,
ordered_effects: list[core.Effect],
backend_or_name: str | xb.XlaBackend | None,
platform: str | tuple[str, ...],
@@ -678,24 +701,19 @@ def lower_jaxpr_to_module(
num_replicas: int = 1,
num_partitions: int = 1,
all_default_mem_kind: bool = True,
override_lowering_rules: None | (
tuple[tuple[core.Primitive, LoweringRule]]) = None,
lowering_parameters: LoweringParameters,
) -> LoweringResult:
"""Lowers a top-level jaxpr to an MLIR module.
Handles the quirks of the argument/return value passing conventions of the
runtime.
"""
# TODO(necula): for now we receive the tuple of lowering platforms through
# the `platform` arg. For now we lower only for the first specified platform
# TODO(necula): change to "platforms" here and elsewhere.
if isinstance(platform, str):
platforms = (platform,)
else:
platforms = tuple(platform) # type: ignore
platform = platform[0]
if lowering_parameters.platforms is not None:
# Only for multi-platform lowering
# TODO(necula): for now we lower only for the first platform
platform = lowering_parameters.platforms[0]

platform = xb.canonicalize_platform(platform)
platform = xb.canonicalize_platform(platform) # type: ignore
if not xb.is_known_platform(platform):
raise ValueError(f"Unknown platform {platform}")
input_output_aliases = None
@@ -750,11 +768,16 @@ def lower_jaxpr_to_module(
map(_to_logical_op_sharding, jaxpr.out_avals, result_shardings)
if result_shardings is not None else result_shardings)

ctx = ModuleContext(backend_or_name, platform, axis_context, name_stack,
keepalives, channel_iter, host_callbacks,
override_lowering_rules=override_lowering_rules,
shape_poly_state=ShapePolyLoweringState(dim_vars,
platforms))
ctx = ModuleContext(backend_or_name=backend_or_name,
platform=platform, axis_context=axis_context,
name_stack=name_stack,
keepalives=keepalives,
channel_iterator=channel_iter,
host_callbacks=host_callbacks,
lowering_parameters=lowering_parameters,
shape_poly_state=ShapePolyLoweringState(
dim_vars,
lowering_parameters.platforms))
with ctx.context, ir.Location.unknown(ctx.context):
# Remove module name characters that XLA would alter. This ensures that
# XLA computation preserves the module name.
@@ -1292,9 +1315,9 @@ def write(v: core.Var, node: Sequence[ir.Value]):
env[v] = tuple(node)

def get_lowering(primitive: core.Primitive) -> LoweringRule | None:
if ctx.override_lowering_rules is None:
if ctx.lowering_parameters.override_lowering_rules is None:
return None
for p, rule in ctx.override_lowering_rules:
for p, rule in ctx.lowering_parameters.override_lowering_rules:
if primitive is p:
return rule
return None
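# Illustrative use of the override mechanism above; `my_sin_lowering` is a
# hypothetical LoweringRule:
#   params = LoweringParameters(
#       override_lowering_rules=((lax.sin_p, my_sin_lowering),))
#   get_lowering(lax.sin_p) then returns my_sin_lowering; primitives without
#   an entry fall back to the regular registered rules.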
@@ -2187,7 +2210,8 @@ def build_xla_computation_helper(
backend_or_name=backend_or_name, ordered_effects=[],
name_stack=source_info_util.NameStack(),
donated_args=[False] * len(closed_jaxpr.jaxpr.invars),
axis_context=axis_context, platform=platform)
axis_context=axis_context, platform=platform,
lowering_parameters=LoweringParameters())
return xc._xla.mlir.mlir_module_to_xla_computation(
module_to_string(lowering_result.module), use_tuple_args=False,
return_tuple=False)
