Move functions out of xla.py closer to their users.
Refactoring only; no behavioral changes intended. The goal is to shrink xla.py down to only its HLO-compatibility role, and remove things that aren't related to HLO compatibility.

Remove an unused top_k translation rule as well.

PiperOrigin-RevId: 554946059
hawkinsp authored and jax authors committed Aug 8, 2023
1 parent d01695c commit ca17b6c
Showing 12 changed files with 123 additions and 139 deletions.
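Among the hunks shown below, the moves work out to the following call-site migrations (a leading underscore means the function is now module-private):

xla.get_canonical_source_file            ->  mlir.get_canonical_source_file
xla.check_backend_matches                ->  mlir.check_backend_matches
xla.axis_groups                          ->  pxla.axis_groups
xla.axis_read                            ->  pxla._axis_read
xla.extend_axis_env                      ->  pxla._extend_axis_env
xla.xla_call_partial_eval_update_params  ->  pxla._xla_call_partial_eval_update_params
xla.xla_call_jvp_update_params           ->  pxla.xla_call_jvp_update_params
xla.xla_call_transpose_update_params     ->  pxla._xla_call_transpose_update_params

xla.with_sharding, xla.with_sharding_proto, and lax._top_k_translation_rule are deleted outright in the hunks shown here.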
25 changes: 22 additions & 3 deletions jax/_src/interpreters/mlir.py
@@ -333,14 +333,21 @@ def _token_constant_handler(val, canonicalize_types):

# Source locations

def get_canonical_source_file(frame: source_info_util.Frame) -> str:
source_file = frame.file_name
if config.jax_hlo_source_file_canonicalization_regex:
source_file = re.sub(config.jax_hlo_source_file_canonicalization_regex,
'', source_file)
return source_file

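For reference, a standalone sketch of the canonicalization this helper performs; the regex value below is a hypothetical setting of the jax_hlo_source_file_canonicalization_regex config option, not a default:

import re

# Hypothetical config value: strip everything up to and including
# "site-packages/" from recorded source paths.
canonicalization_regex = r".*/site-packages/"

def canonicalize(source_file: str) -> str:
    if canonicalization_regex:
        source_file = re.sub(canonicalization_regex, '', source_file)
    return source_file

print(canonicalize("/opt/venv/lib/python3.11/site-packages/mypkg/model.py"))
# -> mypkg/model.py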
def _traceback_to_location(tb: xc.Traceback) -> ir.Location:
"""Converts a full traceback to a callsite() MLIR location."""
frame_locs = []
for code, lasti in zip(*tb.raw_frames()):
frame = source_info_util.raw_frame_to_frame(code, lasti)
if source_info_util.is_user_filename(frame.file_name):
file_loc = ir.Location.file(
-xla.get_canonical_source_file(frame),
+get_canonical_source_file(frame),
frame.start_line,
frame.start_column,
)
@@ -371,7 +378,7 @@ def _source_info_to_location(
if frame is None:
loc = ir.Location.unknown()
else:
-loc = ir.Location.file(xla.get_canonical_source_file(frame),
+loc = ir.Location.file(get_canonical_source_file(frame),
frame.start_line, frame.start_column)
loc = ir.Location.name(eqn_str, childLoc=loc)
# TODO(phawkins): also include primitive.name as the operator type.
@@ -1383,13 +1390,25 @@ def _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects,
return func_op


def check_backend_matches(inner_backend, outer_backend):
# For nested calls, the outermost call sets the backend for all inner calls;
# it's an error if the inner call has a conflicting explicit backend spec.
if inner_backend is None:
return
if (inner_backend != outer_backend and
outer_backend not in xb.expand_platform_alias(inner_backend)):
raise ValueError(
f"Outer-jit backend specification {outer_backend} must match explicit "
f"inner-jit backend specification {inner_backend}.")


def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in,
avals_out, tokens_in, *args,
dim_var_values: Sequence[ir.Value],
arg_names=None, result_names=None):
if isinstance(call_jaxpr, core.Jaxpr):
call_jaxpr = core.ClosedJaxpr(call_jaxpr, ())
-xla.check_backend_matches(backend, ctx.platform)
+check_backend_matches(backend, ctx.platform)
effects = list(tokens_in.effects())
output_types = map(aval_to_ir_types, avals_out)
output_types = [token_type()] * len(effects) + output_types
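The nesting rule check_backend_matches enforces, in terms of hypothetical user code: an inner jit with backend=None inherits the outer backend, but an explicit, conflicting inner backend raises.

import functools
import jax

@functools.partial(jax.jit, backend="cpu")
def inner(x):
    return x + 1

@functools.partial(jax.jit, backend="gpu")
def outer(x):
    # Lowering this call reaches check_backend_matches("cpu", "gpu"),
    # which raises the ValueError shown above.
    return inner(x)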
79 changes: 72 additions & 7 deletions jax/_src/interpreters/pxla.py
@@ -1235,15 +1235,43 @@ def _pmap_dce_rule(used_outputs, eqn):
return used_inputs, new_eqn


def _xla_call_partial_eval_update_params(
params: core.ParamDict, kept_inputs: Sequence[bool], num_new_inputs: int
) -> core.ParamDict:
donated_invars = params['donated_invars']
if not kept_inputs and donated_invars:
# JaxprTrace.post_process_call creates a call with no input tracers
donated_invars = (False,) * num_new_inputs
else:
assert len(kept_inputs) == len(donated_invars)
# JaxprTrace.process_call drops known input tracers
donated_invars = [d for d, kept in zip(donated_invars, kept_inputs) if kept]
# Any new inputs are prepended to the left, so mark those as not donated.
donated_invars = [False] * num_new_inputs + donated_invars
return dict(params, donated_invars=tuple(donated_invars))

def xla_call_jvp_update_params(params, nz_tangents):
donated_invars = params['donated_invars']
donated_tangents = [d for d, nz in zip(donated_invars, nz_tangents) if nz]
new_donated_invars = (*donated_invars, *donated_tangents)
return dict(params, donated_invars=new_donated_invars)

def _xla_call_transpose_update_params(params, undef_primals, nonzero_cts):
donated_invars = params['donated_invars']
donated_primals = [d for d, u in zip(donated_invars, undef_primals) if not u]
donated_cotangents = [False for nz in nonzero_cts if nz]
return dict(params, donated_invars=(*donated_primals, *donated_cotangents))


# Set param update handlers to update `donated_invars` just like xla_call_p
-pe.call_param_updaters[xla_pmap_p] = xla.xla_call_partial_eval_update_params
+pe.call_param_updaters[xla_pmap_p] = _xla_call_partial_eval_update_params
pe.partial_eval_jaxpr_custom_rules[xla_pmap_p] = \
partial(pe.call_partial_eval_custom_rule,
'call_jaxpr', _pmap_partial_eval_custom_params_updater,
res_aval=_pmap_partial_eval_custom_res_maker)
pe.dce_rules[xla_pmap_p] = _pmap_dce_rule
-ad.call_param_updaters[xla_pmap_p] = xla.xla_call_jvp_update_params
-ad.call_transpose_param_updaters[xla_pmap_p] = xla.xla_call_transpose_update_params
+ad.call_param_updaters[xla_pmap_p] = xla_call_jvp_update_params
+ad.call_transpose_param_updaters[xla_pmap_p] = _xla_call_transpose_update_params

ad.primitive_transposes[xla_pmap_p] = partial(ad.map_transpose, xla_pmap_p)

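A worked example of the donated_invars bookkeeping in _xla_call_partial_eval_update_params, with illustrative values:

donated_invars = (True, False, True)   # donation flags for three inputs
kept_inputs = (True, False, True)      # the middle input became known
num_new_inputs = 1                     # one residual input is prepended

kept = [d for d, k in zip(donated_invars, kept_inputs) if k]
print(tuple([False] * num_new_inputs + kept))   # -> (False, True, True)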
@@ -1289,6 +1317,38 @@ def _hlo_shard(aval, axis_env, xs, in_axis):
raise TypeError(aval)


def _axis_read(axis_env, axis_name):
try:
return max(i for i, name in enumerate(axis_env.names) if name == axis_name)
except ValueError:
raise NameError(f"unbound axis name: {axis_name}") from None

def axis_groups(axis_env: sharding_impls.AxisEnv, name) -> tuple[tuple[int, ...]]:
if not isinstance(name, (list, tuple)):
name = (name,)
mesh_axes = tuple(unsafe_map(partial(_axis_read, axis_env), name))
trailing_size, ragged = divmod(axis_env.nreps, math.prod(axis_env.sizes))
assert not ragged
mesh_spec = axis_env.sizes + (trailing_size,)
return _axis_groups(mesh_spec, mesh_axes)

def _axis_groups(mesh_spec, mesh_axes):
"""Computes replica group ids for a collective performed over a subset of the mesh.
Args:
mesh_spec: A sequence of integers representing the mesh shape.
mesh_axes: A sequence of integers between 0 and `len(mesh_spec)` (exclusive)
indicating over which axes the collective is performed.
Returns:
A tuple of replica groups (i.e. tuples containing replica ids).
"""
iota = np.arange(math.prod(mesh_spec)).reshape(mesh_spec)
groups = np.reshape(
np.moveaxis(iota, mesh_axes, np.arange(len(mesh_axes))),
(math.prod(np.take(mesh_spec, mesh_axes)), -1))
return tuple(unsafe_map(tuple, groups.T))


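For reference, what _axis_groups computes on a concrete mesh, as a standalone numpy sketch (the mesh values are illustrative):

import math
import numpy as np

def axis_groups_demo(mesh_spec, mesh_axes):
    # Mirrors _axis_groups above, with plain map in place of unsafe_map.
    iota = np.arange(math.prod(mesh_spec)).reshape(mesh_spec)
    groups = np.reshape(
        np.moveaxis(iota, mesh_axes, np.arange(len(mesh_axes))),
        (math.prod(np.take(mesh_spec, mesh_axes)), -1))
    return tuple(map(tuple, groups.T))

# On a 2x4 replica mesh, a collective over mesh axis 0 groups replicas
# that differ only along that axis:
print(axis_groups_demo((2, 4), (0,)))   # ((0, 4), (1, 5), (2, 6), (3, 7))
# Over mesh axis 1, the groups are the rows:
print(axis_groups_demo((2, 4), (1,)))   # ((0, 1, 2, 3), (4, 5, 6, 7))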
# TODO(b/110096942): more efficient gather
def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, platform):
if aval is core.abstract_token:
@@ -1311,7 +1371,7 @@ def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, platform):
x, mlir.dense_int_elements([1])).result
padded = hlo.DynamicUpdateSliceOp(padded, broadcast_result, idxs).result
replica_groups = mlir.dense_int_elements(
-xla.axis_groups(axis_env, axis_env.names[-1]))
+axis_groups(axis_env, axis_env.names[-1]))
out = hlo.CrossReplicaSumOp(padded, replica_groups).result
if out_axis != 0:
# TODO(apaszke,mattjj): Change the indices to DynamicUpdateSlice instead
@@ -1335,18 +1395,23 @@ def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, platform):
raise TypeError(aval)


def _extend_axis_env(env: sharding_impls.AxisEnv, name, size: int):
return sharding_impls.AxisEnv(env.nreps, env.names + (name,),
env.sizes + (size,))


def _pmap_lowering(ctx, *in_nodes, axis_name,
axis_size, global_axis_size, devices, name,
call_jaxpr, backend=None, in_axes, out_axes,
donated_invars, is_explicit_global_axis_size):
del donated_invars # Unused.
-xla.check_backend_matches(backend, ctx.module_context.platform)
+mlir.check_backend_matches(backend, ctx.module_context.platform)
# We in-line here rather than generating a Call HLO as in the xla_call
# translation rule just because the extra tuple stuff is a pain.
if ctx.module_context.axis_env.names and devices is not None:
raise ValueError("Nested pmap with explicit devices argument.")
-new_env = xla.extend_axis_env(ctx.module_context.axis_env, axis_name,
-global_axis_size)
+new_env = _extend_axis_env(ctx.module_context.axis_env, axis_name,
+global_axis_size)
# Shard the in_nodes that are mapped
in_avals = [v.aval for v in call_jaxpr.invars]
in_nodes_sharded = (
106 changes: 3 additions & 103 deletions jax/_src/interpreters/xla.py
@@ -20,15 +20,11 @@
import functools
from functools import partial
import itertools as it
import math
import operator
import re
from typing import Any, Callable, Optional, Protocol, Union

import numpy as np

from jax._src.config import config

from jax._src import core
from jax._src import dtypes
from jax._src import source_info_util
@@ -59,13 +55,6 @@ def _make_array_shape(aval: ShapedArray) -> Sequence[xc.Shape]:
dtype = np.dtype('bool') if aval.dtype == dtypes.float0 else aval.dtype
return (xc.Shape.array_shape(dtype, aval.shape),)

def get_canonical_source_file(frame: source_info_util.Frame):
source_file = frame.file_name
if config.jax_hlo_source_file_canonicalization_regex:
source_file = re.sub(config.jax_hlo_source_file_canonicalization_regex,
'', source_file)
return source_file

# Utilities

def parameter(builder, num, shape, name=None, replicated=None):
@@ -121,18 +110,6 @@ def tuple_sharding_proto(elems):
return proto


def with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs):
"""Builds op_fn(*args, **kwargs) with sharding annotation."""
builder.set_sharding(sharding_proto)
try:
return op_fn(*args, **kwargs)
finally:
builder.clear_sharding()

def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs):
"""Builds op_fn(*args, **kwargs) with sharding annotation."""
return with_sharding_proto(builder, sharding_to_proto(sharding), op_fn, *args,
**kwargs)


### handlers
@@ -141,16 +118,16 @@ def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs):

def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[xc.Shape]:
try:
-return xla_shape_handlers[type(aval)](aval)
+return _xla_shape_handlers[type(aval)](aval)
except KeyError as err:
raise TypeError(f"No xla_shape_handler for type: {type(aval)}") from err

-xla_shape_handlers: dict[type[core.AbstractValue],
+_xla_shape_handlers: dict[type[core.AbstractValue],
Callable[[Any], Sequence[xc.Shape]]] = {
ShapedArray: _make_array_shape,
ConcreteArray: _make_array_shape,
}
-xla_shape_handlers[core.AbstractToken] = lambda _: (xc.Shape.token_shape(),)
+_xla_shape_handlers[core.AbstractToken] = lambda _: (xc.Shape.token_shape(),)


# IR constants
@@ -270,52 +247,6 @@ def xla_destructure(c, ans):
num_elements = len(c.get_shape(ans).tuple_shapes())
return [xops.GetTupleElement(ans, i) for i in range(num_elements)]

def check_backend_matches(inner_backend, outer_backend):
# For nested calls, the outermost call sets the backend for all inner calls;
# it's an error if the inner call has a conflicting explicit backend spec.
if inner_backend is None:
return
if (inner_backend != outer_backend and
outer_backend not in xb.expand_platform_alias(inner_backend)):
raise ValueError(
f"Outer-jit backend specification {outer_backend} must match explicit "
f"inner-jit backend specification {inner_backend}.")


def extend_axis_env(env: AxisEnv, name, size: int):
return AxisEnv(env.nreps, env.names + (name,), env.sizes + (size,))

def axis_read(axis_env, axis_name):
try:
return max(i for i, name in enumerate(axis_env.names) if name == axis_name)
except ValueError:
raise NameError(f"unbound axis name: {axis_name}") from None

def axis_groups(axis_env: AxisEnv, name) -> tuple[tuple[int, ...]]:
if not isinstance(name, (list, tuple)):
name = (name,)
mesh_axes = tuple(unsafe_map(partial(axis_read, axis_env), name))
trailing_size, ragged = divmod(axis_env.nreps, math.prod(axis_env.sizes))
assert not ragged
mesh_spec = axis_env.sizes + (trailing_size,)
return _axis_groups(mesh_spec, mesh_axes)

def _axis_groups(mesh_spec, mesh_axes):
"""Computes replica group ids for a collective performed over a subset of the mesh.
Args:
mesh_spec: A sequence of integers representing the mesh shape.
mesh_axes: A sequence of integers between 0 and `len(mesh_spec)` (exclusive)
indicating over which axes the collective is performed.
Returns:
A tuple of replica groups (i.e. tuples containing replica ids).
"""
iota = np.arange(math.prod(mesh_spec)).reshape(mesh_spec)
groups = np.reshape(
np.moveaxis(iota, mesh_axes, np.arange(len(mesh_axes))),
(math.prod(np.take(mesh_spec, mesh_axes)), -1))
return tuple(unsafe_map(tuple, groups.T))


# TODO(mattjj,skyewm): the functions here are utilities for checking if
# not-yet-supported features are used with multi-host programming
@@ -329,37 +260,6 @@ def jaxpr_collectives(jaxpr):
for subjaxpr in core.subjaxprs(jaxpr): yield from jaxpr_collectives(subjaxpr)


### xla_call underlying jit


def xla_call_partial_eval_update_params(
params: core.ParamDict, kept_inputs: Sequence[bool], num_new_inputs: int
) -> core.ParamDict:
donated_invars = params['donated_invars']
if not kept_inputs and donated_invars:
# JaxprTrace.post_process_call creates a call with no input tracers
donated_invars = (False,) * num_new_inputs
else:
assert len(kept_inputs) == len(donated_invars)
# JaxprTrace.process_call drops known input tracers
donated_invars = [d for d, kept in zip(donated_invars, kept_inputs) if kept]
# Any new inputs are prepended to the left, so mark those as not donated.
donated_invars = [False] * num_new_inputs + donated_invars
return dict(params, donated_invars=tuple(donated_invars))

def xla_call_jvp_update_params(params, nz_tangents):
donated_invars = params['donated_invars']
donated_tangents = [d for d, nz in zip(donated_invars, nz_tangents) if nz]
new_donated_invars = (*donated_invars, *donated_tangents)
return dict(params, donated_invars=new_donated_invars)

def xla_call_transpose_update_params(params, undef_primals, nonzero_cts):
donated_invars = params['donated_invars']
donated_primals = [d for d, u in zip(donated_invars, undef_primals) if not u]
donated_cotangents = [False for nz in nonzero_cts if nz]
return dict(params, donated_invars=(*donated_primals, *donated_cotangents))


### translation tables

MYPY = False
3 changes: 0 additions & 3 deletions jax/_src/lax/lax.py
@@ -4184,9 +4184,6 @@ def _top_k_batch_rule(batched_args, batch_dims, *, k):
else:
return top_k(operand, k=k), (bdim, bdim)

def _top_k_translation_rule(ctx, avals_in, avals_out, x, *, k):
return xla.xla_destructure(ctx.builder, xops.TopK(x, k))

top_k_p = Primitive('top_k')
top_k_p.multiple_results = True
top_k_p.def_impl(partial(dispatch.apply_primitive, top_k_p))
2 changes: 1 addition & 1 deletion jax/_src/lax/parallel.py
@@ -690,7 +690,7 @@ def _batched_reduction_collective(
return vals_out, [batching.not_mapped] * len(vals_out)

def _replica_groups(axis_env, axis_name, axis_index_groups):
-replica_groups = xla.axis_groups(axis_env, axis_name)
+replica_groups = pxla.axis_groups(axis_env, axis_name)
if axis_index_groups is not None:
replica_groups = [[axis_group[i] for i in axis_index_group]
for axis_group in replica_groups
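The comprehension above (truncated by the collapsed hunk) selects, for each replica group, the replica ids at the positions named by each axis_index_group. Continuing the 2x4 mesh example, a standalone sketch with illustrative values:

replica_groups = ((0, 1, 2, 3), (4, 5, 6, 7))   # axis of size 4
axis_index_groups = [[0, 1], [2, 3]]            # split each group in half
print([[group[i] for i in idx]
       for group in replica_groups
       for idx in axis_index_groups])
# -> [[0, 1], [2, 3], [4, 5], [6, 7]]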
