Skip to content

Commit

Permalink
[Pallas] Automatically turn mesh indices -> physical ids for remote DMAs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 570221510
  • Loading branch information
sharadmv authored and jax authors committed Oct 3, 2023
1 parent 17d89ad commit 1c796c0
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 43 deletions.
1 change: 1 addition & 0 deletions jax/_src/pallas/mosaic/__init__.py
Expand Up @@ -22,6 +22,7 @@
from jax._src.pallas.mosaic.kernel_regeneration_util import encode_kernel_regeneration_metadata
from jax._src.pallas.mosaic.kernel_regeneration_util import extract_kernel_regeneration_metadata
from jax._src.pallas.mosaic.lowering import LoweringException
from jax._src.pallas.mosaic.primitives import DeviceIdType
from jax._src.pallas.mosaic.primitives import async_copy
from jax._src.pallas.mosaic.primitives import async_remote_copy
from jax._src.pallas.mosaic.primitives import device_id
Expand Down
106 changes: 93 additions & 13 deletions jax/_src/pallas/mosaic/lowering.py
Expand Up @@ -25,6 +25,7 @@
from jax._src import custom_derivatives
from jax._src import debugging
from jax._src import linear_util as lu
from jax._src import mesh as mesh_lib
from jax._src import pjit
from jax._src import source_info_util
from jax._src import state
Expand Down Expand Up @@ -68,13 +69,21 @@
zip, unsafe_zip = safe_zip, zip # pylint: disable=redefined-builtin


@dataclasses.dataclass
class MeshContext:
logical_to_mesh: ir.Value
axis_names: tuple[str, ...]
mesh_strides: tuple[int, ...]


@dataclasses.dataclass
class LoweringContext:
ir_context: ir.Context
grid_mapping: core.GridMapping | None
grid_indices: Sequence[ir.Value] | None
block_shapes: list[tuple[int | core.Mapped, ...]]
name_stack: source_info_util.NameStack
mesh_context: MeshContext | None
replace = dataclasses.replace


Expand Down Expand Up @@ -144,21 +153,44 @@ def lower_jaxpr_to_module(
grid_mapping: core.GridMapping,
jaxpr: jax_core.Jaxpr,
dimension_semantics: tuple[str | None, ...] | None,
mesh: mesh_lib.Mesh | None = None
) -> ir.Module:
m = ir.Module.create()
sym_tab = ir.SymbolTable(m.operation)
used_axis_names = jax_core.used_axis_names_jaxpr(jaxpr)
mesh_info = None
if used_axis_names:
axis_names = list(used_axis_names)
if mesh is None:
raise ValueError("Cannot use axis names in pallas_call without shard_map."
)
# We need mesh <-> logical translation tables. Since the logical IDs are
# just linearized versions of the mesh IDs, we create those tables.
mesh_strides = pallas_utils.strides_from_shape(tuple(
mesh.shape[a] for a in axis_names
))
logical_to_mesh = np.empty((mesh.size, len(axis_names)), dtype=np.int32)
for i, idx in enumerate(np.ndindex(*mesh.device_ids.shape)):
logical_to_mesh[i] = np.array(idx)
mesh_info = (logical_to_mesh, tuple(axis_names), mesh_strides)
extra_args = mesh_info[:1] if mesh_info else ()
if not grid_mapping.grid:
# Trivial grid-map, we don't need to populate the transform functions.
func_op = lower_jaxpr_to_func(ctx, jaxpr, grid_mapping=grid_mapping,
name="main")
name="main", mesh_info=mesh_info)
m.body.append(func_op)
sym_tab.insert(func_op)
return m
return m, extra_args
func_op = lower_jaxpr_to_func(ctx, jaxpr, grid_mapping=grid_mapping,
name="main")
name="main", mesh_info=mesh_info)
m.body.append(func_op)
sym_tab.insert(func_op)
num_smem_inputs = grid_mapping.num_index_operands
if mesh_info:
# We are now passing in the logical -> mesh index mapping
# TODO(sharadmv,apaszke): avoid stalling pipeline by marking the index
# mapping as scalar prefetch and instead just mark it as an SMEM operand.
num_smem_inputs += 1
window_params = []
grid = grid_mapping.grid
for i, bm in enumerate(grid_mapping.block_mappings):
Expand Down Expand Up @@ -215,7 +247,7 @@ def _get_semantics(s: str | None) -> str:
func_op.attributes["dimension_semantics"] = ir.ArrayAttr.get(
map(ir.Attribute.parse, func_dimension_semantics)
)
return m
return m, extra_args


def lower_jaxpr_to_transform_func(
Expand All @@ -225,7 +257,8 @@ def lower_jaxpr_to_transform_func(
arg_types = [*map(aval_to_ir_type, [invar.aval for invar in jaxpr.invars],
block_shapes, memspaces)]
lowering_context = LoweringContext(
ctx, None, None, block_shapes, source_info_util.NameStack())
ctx, None, None, block_shapes, source_info_util.NameStack(),
mesh_context=None)
body_func = functools.partial(jaxpr_subcomp, lowering_context, jaxpr)
body_func.__name__ = name
body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func)
Expand Down Expand Up @@ -257,6 +290,7 @@ def lower_jaxpr_to_func(
*,
grid_mapping: core.GridMapping | None,
name: str,
mesh_info: Any,
) -> func.FuncOp:
memory_spaces = [None if bm is None else bm.memory_space
for bm in grid_mapping.block_mappings]
Expand All @@ -277,6 +311,13 @@ def _get_arg_type(aval, block_mapping: core.BlockMapping | None,
)
return (aval_to_ir_type(aval, shape=shape, memory_space=memory_space),
block_mapping.block_shape)

if mesh_info is not None:
l_to_m, axis_names, mesh_strides = mesh_info
l_to_m_aval = state.AbstractRef(
jax_core.raise_to_shaped(jax_core.get_aval(l_to_m)))
arg_types.append(_get_arg_type(l_to_m_aval, None, SMEM)[0])

if grid_mapping is None:
block_mappings = [None] * len(jaxpr.invars)
else:
Expand All @@ -300,12 +341,18 @@ def body_func(*args):
for i, g in enumerate(grid_indices)
if i not in grid_mapping.mapped_dims
]
if mesh_info is not None:
(l_to_m,), args = split_list(args, [1])
mesh_context = MeshContext(l_to_m, axis_names, mesh_strides)
else:
mesh_context = None
lowering_context = LoweringContext(
ctx,
grid_mapping,
tuple(grid_indices),
block_shapes,
source_info_util.NameStack(),
mesh_context=mesh_context,
)
return jaxpr_subcomp(lowering_context, jaxpr, *args)

Expand Down Expand Up @@ -1510,12 +1557,35 @@ def _run_scoped_lowering_rule(ctx: LoweringRuleContext, *consts, jaxpr):

lowering_rules[tpu_primitives.run_scoped_p] = _run_scoped_lowering_rule

def _device_id_to_logical(
ctx: LoweringRuleContext, device_id,
device_id_type: tpu_primitives.DeviceIdType):
if device_id_type is tpu_primitives.DeviceIdType.MESH:
# Mesh means we are passed the mesh coordinates for the device
device_ids = tree_util.tree_leaves(device_id)
mesh_strides = ctx.lowering_context.mesh_context.mesh_strides
def _linearize_mesh_indices(*indices):
return sum([a * b for a, b in zip(indices, mesh_strides)])
lower_ctx = LoweringRuleContext(
lowering_context=ctx.lowering_context,
avals_in=[jax_core.ShapedArray((), jnp.int32)] * len(device_ids),
avals_out=[jax_core.ShapedArray((), jnp.int32)],
block_shapes=(None,) * len(device_ids),
)
return lower_fun(_linearize_mesh_indices, multiple_results=False)(
lower_ctx, *device_ids)
elif device_id_type is tpu_primitives.DeviceIdType.LOGICAL:
return device_id
raise NotImplementedError(f"Unsupported device id type: {device_id_type}")

def _semaphore_signal_lowering_rule(ctx: LoweringRuleContext, semaphore,
value, *args, has_device_id: bool):
value, *args, has_device_id: bool,
device_id_type: tpu_primitives.DeviceIdType):
device_id = None
assert semaphore.type == ir.Type.parse("!tpu.semaphore")
if has_device_id:
(device_id,) = args
device_id = _device_id_to_logical(ctx, device_id, device_id_type)
return tpu.SemaphoreSignalOp(semaphore, value, device_id=device_id).results
lowering_rules[tpu_primitives.semaphore_signal_p] = (
_semaphore_signal_lowering_rule)
Expand All @@ -1541,12 +1611,11 @@ def _indexer_to_start_size(indexer: NDIndexer):
partial(_ensure_mlir_value, aval=jax_core.ShapedArray((), jnp.int32)),
starts,
)
sizes = [
s.size if isinstance(s, primitives.Slice) else 1 for s in indexer.indices
]
sizes = indexer.get_indexer_shape()
return starts, sizes

def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree):
def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree,
device_id_type: tpu_primitives.DeviceIdType):
(src_ref, src_idx, dst_ref, dst_idx, sem, src_sem, device_id) = (
tree_util.tree_unflatten(tree, args)
)
Expand All @@ -1561,13 +1630,16 @@ def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree):
memory_space=dst_ref.type.memory_space)
src = tpu.MemRefSliceOp(src_ref_ty, src_ref, src_starts).result
dst = tpu.MemRefSliceOp(dst_ref_ty, dst_ref, dst_starts).result
return tpu.EnqueueDMAOp(source=src, target=dst, target_semaphore=sem,
source_semaphore=src_sem,
if device_id is not None:
device_id = _device_id_to_logical(ctx, device_id, device_id_type)
return tpu.EnqueueDMAOp(src, dst, sem, source_semaphore=src_sem,
device_id=device_id).results
lowering_rules[tpu_primitives.dma_start_p] = _dma_start_lowering_rule


def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree):
def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree,
device_id_type: tpu_primitives.DeviceIdType):
del device_id_type
sem, ref, idx = tree_util.tree_unflatten(tree, args)
starts, sizes = _indexer_to_start_size(idx)
ref_ty = ir.MemRefType.get(
Expand All @@ -1580,3 +1652,11 @@ def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree):
def _device_id_lowering_rule(ctx: LoweringRuleContext):
return tpu.DeviceIdOp().result
lowering_rules[tpu_primitives.device_id_p] = _device_id_lowering_rule

def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: str):
device_id = _make_index(tpu.DeviceIdOp().result)
l_to_m = ctx.lowering_context.mesh_context.logical_to_mesh
axis_names = ctx.lowering_context.mesh_context.axis_names
col = _make_index(axis_names.index(axis_name))
return memref.LoadOp(l_to_m, [device_id, col]).result
lowering_rules[lax.axis_index_p] = _axis_index_rule
32 changes: 19 additions & 13 deletions jax/_src/pallas/mosaic/pallas_call_registration.py
Expand Up @@ -56,6 +56,11 @@ def pallas_call_tpu_lowering_rule(
grid_mapping=grid_mapping, **compiler_params)
if debug:
print(jaxpr)
mesh = None
axis_context = ctx.module_context.axis_context
if axis_context is not None:
if isinstance(axis_context, mlir.SPMDAxisContext):
mesh = axis_context.mesh
with ir.Context() as mlir_ctx, ir.Location.unknown(mlir_ctx):
tpu.register_dialect(mlir_ctx)
if mosaic_params is None:
Expand All @@ -64,21 +69,22 @@ def pallas_call_tpu_lowering_rule(
kernel_regeneration_metadata = mosaic_params.get(
"kernel_regeneration_metadata"
)
mosaic_module = lowering.lower_jaxpr_to_module(
mlir_ctx, grid_mapping, jaxpr, dimension_semantics=dimension_semantics)
mosaic_module, extra_args = lowering.lower_jaxpr_to_module(
mlir_ctx, grid_mapping, jaxpr, dimension_semantics=dimension_semantics,
mesh=mesh)
if debug:
print(mosaic_module)
out_avals = [jax_core.ShapedArray(s.shape, s.dtype) for s in out_shapes]
return mlir.lower_fun(
mosaic.as_tpu_kernel(
mosaic_module,
out_avals,
backend=ctx.module_context.backend,
kernel_name=name,
kernel_regeneration_metadata=kernel_regeneration_metadata,
cost_estimate=mosaic_params.get('cost_estimate', None),
),
multiple_results=True,
)(ctx, *in_nodes)
def _lower_fun(*args):
return mosaic.as_tpu_kernel(
mosaic_module,
out_avals,
backend=ctx.module_context.backend,
kernel_name=name,
kernel_regeneration_metadata=kernel_regeneration_metadata,
cost_estimate=mosaic_params.get('cost_estimate', None),
)(*extra_args, *args)
return mlir.lower_fun(_lower_fun, multiple_results=True)(
ctx, *in_nodes)
mlir.register_lowering(pallas_call_p, pallas_call_tpu_lowering_rule,
platform="tpu")

0 comments on commit 1c796c0

Please sign in to comment.