diff --git a/jax/_src/pallas/mosaic/__init__.py b/jax/_src/pallas/mosaic/__init__.py index 23db94c7299d..ed1c2afb929a 100644 --- a/jax/_src/pallas/mosaic/__init__.py +++ b/jax/_src/pallas/mosaic/__init__.py @@ -22,6 +22,7 @@ from jax._src.pallas.mosaic.kernel_regeneration_util import encode_kernel_regeneration_metadata from jax._src.pallas.mosaic.kernel_regeneration_util import extract_kernel_regeneration_metadata from jax._src.pallas.mosaic.lowering import LoweringException +from jax._src.pallas.mosaic.primitives import DeviceIdType from jax._src.pallas.mosaic.primitives import async_copy from jax._src.pallas.mosaic.primitives import async_remote_copy from jax._src.pallas.mosaic.primitives import device_id diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index e2ffefac45b7..4e923cee0999 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -25,6 +25,7 @@ from jax._src import custom_derivatives from jax._src import debugging from jax._src import linear_util as lu +from jax._src import mesh as mesh_lib from jax._src import pjit from jax._src import source_info_util from jax._src import state @@ -68,6 +69,13 @@ zip, unsafe_zip = safe_zip, zip # pylint: disable=redefined-builtin +@dataclasses.dataclass +class MeshContext: + logical_to_mesh: ir.Value + axis_names: tuple[str, ...] + mesh_strides: tuple[int, ...] + + @dataclasses.dataclass class LoweringContext: ir_context: ir.Context @@ -75,6 +83,7 @@ class LoweringContext: grid_indices: Sequence[ir.Value] | None block_shapes: list[tuple[int | core.Mapped, ...]] name_stack: source_info_util.NameStack + mesh_context: MeshContext | None replace = dataclasses.replace @@ -144,21 +153,44 @@ def lower_jaxpr_to_module( grid_mapping: core.GridMapping, jaxpr: jax_core.Jaxpr, dimension_semantics: tuple[str | None, ...] 
| None, + mesh: mesh_lib.Mesh | None = None ) -> ir.Module: m = ir.Module.create() sym_tab = ir.SymbolTable(m.operation) + used_axis_names = jax_core.used_axis_names_jaxpr(jaxpr) + mesh_info = None + if used_axis_names: + axis_names = list(used_axis_names) + if mesh is None: + raise ValueError("Cannot use axis names in pallas_call without shard_map." + ) + # We need mesh <-> logical translation tables. Since the logical IDs are + # just linearized versions of the mesh IDs, we create those tables. + mesh_strides = pallas_utils.strides_from_shape(tuple( + mesh.shape[a] for a in axis_names + )) + logical_to_mesh = np.empty((mesh.size, len(axis_names)), dtype=np.int32) + for i, idx in enumerate(np.ndindex(*mesh.device_ids.shape)): + logical_to_mesh[i] = np.array(idx) + mesh_info = (logical_to_mesh, tuple(axis_names), mesh_strides) + extra_args = mesh_info[:1] if mesh_info else () if not grid_mapping.grid: # Trivial grid-map, we don't need to populate the transform functions. func_op = lower_jaxpr_to_func(ctx, jaxpr, grid_mapping=grid_mapping, - name="main") + name="main", mesh_info=mesh_info) m.body.append(func_op) sym_tab.insert(func_op) - return m + return m, extra_args func_op = lower_jaxpr_to_func(ctx, jaxpr, grid_mapping=grid_mapping, - name="main") + name="main", mesh_info=mesh_info) m.body.append(func_op) sym_tab.insert(func_op) num_smem_inputs = grid_mapping.num_index_operands + if mesh_info: + # We are now passing in the logical -> mesh index mapping + # TODO(sharadmv,apaszke): avoid stalling pipeline by marking the index + # mapping as scalar prefetch and instead just mark it as an SMEM operand. 
+ num_smem_inputs += 1 window_params = [] grid = grid_mapping.grid for i, bm in enumerate(grid_mapping.block_mappings): @@ -215,7 +247,7 @@ def _get_semantics(s: str | None) -> str: func_op.attributes["dimension_semantics"] = ir.ArrayAttr.get( map(ir.Attribute.parse, func_dimension_semantics) ) - return m + return m, extra_args def lower_jaxpr_to_transform_func( @@ -225,7 +257,8 @@ def lower_jaxpr_to_transform_func( arg_types = [*map(aval_to_ir_type, [invar.aval for invar in jaxpr.invars], block_shapes, memspaces)] lowering_context = LoweringContext( - ctx, None, None, block_shapes, source_info_util.NameStack()) + ctx, None, None, block_shapes, source_info_util.NameStack(), + mesh_context=None) body_func = functools.partial(jaxpr_subcomp, lowering_context, jaxpr) body_func.__name__ = name body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func) @@ -257,6 +290,7 @@ def lower_jaxpr_to_func( *, grid_mapping: core.GridMapping | None, name: str, + mesh_info: Any, ) -> func.FuncOp: memory_spaces = [None if bm is None else bm.memory_space for bm in grid_mapping.block_mappings] @@ -277,6 +311,13 @@ def _get_arg_type(aval, block_mapping: core.BlockMapping | None, ) return (aval_to_ir_type(aval, shape=shape, memory_space=memory_space), block_mapping.block_shape) + + if mesh_info is not None: + l_to_m, axis_names, mesh_strides = mesh_info + l_to_m_aval = state.AbstractRef( + jax_core.raise_to_shaped(jax_core.get_aval(l_to_m))) + arg_types.append(_get_arg_type(l_to_m_aval, None, SMEM)[0]) + if grid_mapping is None: block_mappings = [None] * len(jaxpr.invars) else: @@ -300,12 +341,18 @@ def body_func(*args): for i, g in enumerate(grid_indices) if i not in grid_mapping.mapped_dims ] + if mesh_info is not None: + (l_to_m,), args = split_list(args, [1]) + mesh_context = MeshContext(l_to_m, axis_names, mesh_strides) + else: + mesh_context = None lowering_context = LoweringContext( ctx, grid_mapping, tuple(grid_indices), block_shapes, source_info_util.NameStack(), + 
mesh_context=mesh_context, ) return jaxpr_subcomp(lowering_context, jaxpr, *args) @@ -1510,12 +1557,35 @@ def _run_scoped_lowering_rule(ctx: LoweringRuleContext, *consts, jaxpr): lowering_rules[tpu_primitives.run_scoped_p] = _run_scoped_lowering_rule +def _device_id_to_logical( + ctx: LoweringRuleContext, device_id, + device_id_type: tpu_primitives.DeviceIdType): + if device_id_type is tpu_primitives.DeviceIdType.MESH: + # Mesh means we are passed the mesh coordinates for the device + device_ids = tree_util.tree_leaves(device_id) + mesh_strides = ctx.lowering_context.mesh_context.mesh_strides + def _linearize_mesh_indices(*indices): + return sum([a * b for a, b in zip(indices, mesh_strides)]) + lower_ctx = LoweringRuleContext( + lowering_context=ctx.lowering_context, + avals_in=[jax_core.ShapedArray((), jnp.int32)] * len(device_ids), + avals_out=[jax_core.ShapedArray((), jnp.int32)], + block_shapes=(None,) * len(device_ids), + ) + return lower_fun(_linearize_mesh_indices, multiple_results=False)( + lower_ctx, *device_ids) + elif device_id_type is tpu_primitives.DeviceIdType.LOGICAL: + return device_id + raise NotImplementedError(f"Unsupported device id type: {device_id_type}") + def _semaphore_signal_lowering_rule(ctx: LoweringRuleContext, semaphore, - value, *args, has_device_id: bool): + value, *args, has_device_id: bool, + device_id_type: tpu_primitives.DeviceIdType): device_id = None assert semaphore.type == ir.Type.parse("!tpu.semaphore") if has_device_id: (device_id,) = args + device_id = _device_id_to_logical(ctx, device_id, device_id_type) return tpu.SemaphoreSignalOp(semaphore, value, device_id=device_id).results lowering_rules[tpu_primitives.semaphore_signal_p] = ( _semaphore_signal_lowering_rule) @@ -1541,12 +1611,11 @@ def _indexer_to_start_size(indexer: NDIndexer): partial(_ensure_mlir_value, aval=jax_core.ShapedArray((), jnp.int32)), starts, ) - sizes = [ - s.size if isinstance(s, primitives.Slice) else 1 for s in indexer.indices - ] + sizes = 
indexer.get_indexer_shape() return starts, sizes -def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree): +def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree, + device_id_type: tpu_primitives.DeviceIdType): (src_ref, src_idx, dst_ref, dst_idx, sem, src_sem, device_id) = ( tree_util.tree_unflatten(tree, args) ) @@ -1561,13 +1630,16 @@ def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree): memory_space=dst_ref.type.memory_space) src = tpu.MemRefSliceOp(src_ref_ty, src_ref, src_starts).result dst = tpu.MemRefSliceOp(dst_ref_ty, dst_ref, dst_starts).result - return tpu.EnqueueDMAOp(source=src, target=dst, target_semaphore=sem, - source_semaphore=src_sem, + if device_id is not None: + device_id = _device_id_to_logical(ctx, device_id, device_id_type) + return tpu.EnqueueDMAOp(src, dst, sem, source_semaphore=src_sem, device_id=device_id).results lowering_rules[tpu_primitives.dma_start_p] = _dma_start_lowering_rule -def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree): +def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree, + device_id_type: tpu_primitives.DeviceIdType): + del device_id_type sem, ref, idx = tree_util.tree_unflatten(tree, args) starts, sizes = _indexer_to_start_size(idx) ref_ty = ir.MemRefType.get( @@ -1580,3 +1652,11 @@ def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree): def _device_id_lowering_rule(ctx: LoweringRuleContext): return tpu.DeviceIdOp().result lowering_rules[tpu_primitives.device_id_p] = _device_id_lowering_rule + +def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: str): + device_id = _make_index(tpu.DeviceIdOp().result) + l_to_m = ctx.lowering_context.mesh_context.logical_to_mesh + axis_names = ctx.lowering_context.mesh_context.axis_names + col = _make_index(axis_names.index(axis_name)) + return memref.LoadOp(l_to_m, [device_id, col]).result +lowering_rules[lax.axis_index_p] = _axis_index_rule diff --git 
a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 7d90949263c2..05a2f8003fb0 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -56,6 +56,11 @@ def pallas_call_tpu_lowering_rule( grid_mapping=grid_mapping, **compiler_params) if debug: print(jaxpr) + mesh = None + axis_context = ctx.module_context.axis_context + if axis_context is not None: + if isinstance(axis_context, mlir.SPMDAxisContext): + mesh = axis_context.mesh with ir.Context() as mlir_ctx, ir.Location.unknown(mlir_ctx): tpu.register_dialect(mlir_ctx) if mosaic_params is None: @@ -64,21 +69,22 @@ def pallas_call_tpu_lowering_rule( kernel_regeneration_metadata = mosaic_params.get( "kernel_regeneration_metadata" ) - mosaic_module = lowering.lower_jaxpr_to_module( - mlir_ctx, grid_mapping, jaxpr, dimension_semantics=dimension_semantics) + mosaic_module, extra_args = lowering.lower_jaxpr_to_module( + mlir_ctx, grid_mapping, jaxpr, dimension_semantics=dimension_semantics, + mesh=mesh) if debug: print(mosaic_module) out_avals = [jax_core.ShapedArray(s.shape, s.dtype) for s in out_shapes] - return mlir.lower_fun( - mosaic.as_tpu_kernel( - mosaic_module, - out_avals, - backend=ctx.module_context.backend, - kernel_name=name, - kernel_regeneration_metadata=kernel_regeneration_metadata, - cost_estimate=mosaic_params.get('cost_estimate', None), - ), - multiple_results=True, - )(ctx, *in_nodes) + def _lower_fun(*args): + return mosaic.as_tpu_kernel( + mosaic_module, + out_avals, + backend=ctx.module_context.backend, + kernel_name=name, + kernel_regeneration_metadata=kernel_regeneration_metadata, + cost_estimate=mosaic_params.get('cost_estimate', None), + )(*extra_args, *args) + return mlir.lower_fun(_lower_fun, multiple_results=True)( + ctx, *in_nodes) mlir.register_lowering(pallas_call_p, pallas_call_tpu_lowering_rule, platform="tpu") diff --git a/jax/_src/pallas/mosaic/primitives.py 
b/jax/_src/pallas/mosaic/primitives.py index b69ecfb111de..cf47433298d6 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -17,6 +17,7 @@ import contextlib import dataclasses +import enum from typing import Any, Callable import jax @@ -107,21 +108,30 @@ def _run_scoped_abstract_eval(*args, jaxpr): return [], nonlocal_effects +class DeviceIdType(enum.Enum): + MESH = "mesh" + LOGICAL = "logical" + + semaphore_signal_p = jax_core.Primitive('semaphore_signal') semaphore_signal_p.multiple_results = True def semaphore_signal(sem, inc: int | jax.Array = 1, - *, device_id: int | jax.Array | None = None): + *, device_id: int | jax.Array | None = None, + device_id_type: DeviceIdType = DeviceIdType.MESH): inc = jnp.asarray(inc, dtype=jnp.int32) args = [sem, inc] has_device_id = device_id is not None if has_device_id: args = [*args, device_id] - semaphore_signal_p.bind(*args, has_device_id=has_device_id) + semaphore_signal_p.bind(*args, has_device_id=has_device_id, + device_id_type=device_id_type) @semaphore_signal_p.def_abstract_eval def _semaphore_signal_abstract_eval(sem_aval: tpu_core.AbstractSemaphore, value, - *args, has_device_id: bool): + *args, has_device_id: bool, + device_id_type: DeviceIdType): + del device_id_type if not isinstance(sem_aval, tpu_core.AbstractSemaphore): raise ValueError(f"Cannot signal on a non-semaphore value: {sem_aval}") if sem_aval.sem_type is not tpu_core.SemaphoreType.REGULAR: @@ -156,16 +166,18 @@ def _semaphore_wait_abstract_eval(sem_aval: tpu_core.AbstractSemaphore, value): class DMAFuture: flat_args: Any tree: Any + device_id_type: DeviceIdType | None def wait(self): - dma_wait_p.bind(*self.flat_args, tree=self.tree) + dma_wait_p.bind(*self.flat_args, tree=self.tree, + device_id_type=self.device_id_type) dma_start_p = jax_core.Primitive('dma_start') dma_start_p.multiple_results = True @dma_start_p.def_abstract_eval -def _dma_start_abstract_eval(*args, tree): - del args, tree +def 
_dma_start_abstract_eval(*args, tree, device_id_type): + del args, tree, device_id_type return [] def dma_start(src_ref, src_indices, dst_ref, dst_indices, sem) -> DMAFuture: @@ -175,13 +187,14 @@ def dma_start(src_ref, src_indices, dst_ref, dst_indices, sem) -> DMAFuture: dst_ref.shape) args = (src_ref, src_indexer, dst_ref, dst_indexer, sem, None, None) flat_args, tree = tree_util.tree_flatten(args) - dma_start_p.bind(*flat_args, tree=tree) + dma_start_p.bind(*flat_args, tree=tree, device_id_type=None) wait_args, tree = tree_util.tree_flatten((sem, dst_ref, dst_indexer)) - return DMAFuture(wait_args, tree) - + return DMAFuture(wait_args, tree, None) def remote_dma_start(src_ref, src_indices, dst_ref, dst_indices, src_sem, - dst_sem, device_id) -> tuple[DMAFuture, DMAFuture]: + dst_sem, device_id, + device_id_type: DeviceIdType) -> tuple[DMAFuture, + DMAFuture]: src_indexer = indexing.NDIndexer.from_indices_shape(src_indices, src_ref.shape) dst_indexer = indexing.NDIndexer.from_indices_shape(dst_indices, @@ -189,20 +202,21 @@ def remote_dma_start(src_ref, src_indices, dst_ref, dst_indices, src_sem, args = (src_ref, src_indexer, dst_ref, dst_indexer, dst_sem, src_sem, device_id) flat_args, tree = tree_util.tree_flatten(args) - dma_start_p.bind(*flat_args, tree=tree) + dma_start_p.bind(*flat_args, tree=tree, device_id_type=device_id_type) recv_wait_args = (dst_sem, dst_ref, dst_indexer) recv_args, recv_tree = tree_util.tree_flatten(recv_wait_args) send_wait_args = (src_sem, src_ref, src_indexer) send_args, send_tree = tree_util.tree_flatten(send_wait_args) - return DMAFuture(send_args, send_tree), DMAFuture(recv_args, recv_tree) + return (DMAFuture(send_args, send_tree, device_id_type), + DMAFuture(recv_args, recv_tree, device_id_type)) dma_wait_p = jax_core.Primitive('dma_wait') dma_wait_p.multiple_results = True @dma_wait_p.def_abstract_eval -def _dma_wait_abstract_eval(*args, tree): - del args, tree +def _dma_wait_abstract_eval(*args, tree, device_id_type): + del 
args, tree, device_id_type return [] def _get_ref_and_indexer(ref): @@ -216,11 +230,12 @@ def async_copy(src_ref, dst_ref, sem): dst_ref, dst_indices = _get_ref_and_indexer(dst_ref) return dma_start(src_ref, src_indices, dst_ref, dst_indices, sem) -def async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id): +def async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, + device_id_type: DeviceIdType = DeviceIdType.MESH): src_ref, src_indices = _get_ref_and_indexer(src_ref) dst_ref, dst_indices = _get_ref_and_indexer(dst_ref) return remote_dma_start(src_ref, src_indices, dst_ref, dst_indices, send_sem, - recv_sem, device_id) + recv_sem, device_id, device_id_type=device_id_type) device_id_p = jax_core.Primitive('device_id') diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index ea8a73df7129..54ed013ccc70 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -20,6 +20,7 @@ from jax._src.pallas.mosaic import SemaphoreType from jax._src.pallas.mosaic import TPUMemorySpace from jax._src.pallas.mosaic import VMEM +from jax._src.pallas.mosaic import DeviceIdType from jax._src.pallas.mosaic import async_copy from jax._src.pallas.mosaic import async_remote_copy from jax._src.pallas.mosaic import device_id diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 218e53a92007..acd711e168c0 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -399,7 +399,7 @@ def TPU_SemaphoreSignalOp : TPU_Op<"sem_signal"> { let arguments = (ins TPU_SemaphoreType:$semaphore, I32:$amount, Optional<I32>:$device_id // For remote DMAs ); let assemblyFormat = [{ $semaphore `,` $amount (`,` $device_id^)? attr-dict