From df9dd53c1654246b8864e03b55d23ef23975d3c3 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Fri, 17 Nov 2023 18:04:16 -0800 Subject: [PATCH] [Pallas] Refactor Mosaic lowering to encapsulate jaxpr->mlir type creation all in one place PiperOrigin-RevId: 583532870 --- jax/_src/pallas/mosaic/lowering.py | 466 +++++++++++++++-------------- 1 file changed, 249 insertions(+), 217 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 5d8de6cefc73..b4863d196fc3 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -79,7 +79,6 @@ class MeshContext: @dataclasses.dataclass class LoweringContext: ir_context: ir.Context - grid_mapping: core.GridMapping | None grid_indices: Sequence[ir.Value] | None block_shapes: list[tuple[int | core.Mapped, ...]] name_stack: source_info_util.NameStack @@ -156,21 +155,120 @@ def ir_constant(x, mlir_type=None): skip_mlir_conversions = set() -def lower_jaxpr_to_module( - ctx: ir.Context, - grid_mapping: core.GridMapping, - jaxpr: jax_core.Jaxpr, - dimension_semantics: tuple[str | None, ...] | None, - mesh: mesh_lib.Mesh | None = None -) -> ir.Module: - m = ir.Module.create() - sym_tab = ir.SymbolTable(m.operation) - used_axis_names = jax_core.used_axis_names_jaxpr(jaxpr) - mesh_info = None - if used_axis_names: +def _get_arg_type( + aval, + block_mapping: core.BlockMapping | None, + memory_space: tpu_core.TPUMemorySpace | None, +): + if isinstance(aval, tpu_core.AbstractMemoryRef): + assert memory_space is None + memory_space = aval.memory_space + if isinstance(aval, tpu_core.AbstractSemaphore): + return aval_to_ir_type(aval), None + if block_mapping is None: + return aval_to_ir_type(aval, memory_space=memory_space), aval.shape + assert memory_space is None + shape = tuple(1 if b is core.mapped else b for b in block_mapping.block_shape) + return ( + aval_to_ir_type( + aval, shape=shape, memory_space=block_mapping.memory_space + ), + block_mapping.block_shape, + ) + + +@dataclasses.dataclass(init=False) +class MosaicGridMapping: + grid: tuple[int, ...] | None + jaxpr: jax_core.Jaxpr + block_mappings: tuple[core.BlockMapping | None, ...] + mapped_dims: tuple[int, ...] + scalar_prefetch_types: tuple[ir.Type, ...] + operand_types: tuple[ir.Type, ...] + scratch_types: tuple[ir.Type, ...] + grid_types: tuple[ir.Type, ...] + scalar_prefetch_block_shapes: tuple[tuple[int, ...], ...] + operand_block_shapes: tuple[tuple[int, ...], ...] + scratch_block_shapes: tuple[tuple[int, ...], ...] + mesh_info: MeshInfo | None + get_grid_indices: Callable | None + + def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: core.GridMapping, + dimension_semantics: tuple[str, ...] | None, + mesh: mesh_lib.Mesh | None): + self.grid = grid_mapping.grid + self.jaxpr = jaxpr + self.block_mappings = grid_mapping.block_mappings + self.mapped_dims = grid_mapping.mapped_dims + num_scalar_prefetch = grid_mapping.num_index_operands + num_scratch = grid_mapping.num_scratch_operands + # jaxpr has signature [*scalar_prefetch, *in_ops *out_ops, *scratch] + num_operands = ( + len(self.jaxpr.invars) + - num_scalar_prefetch + - num_scratch + ) + user_grid = tuple( + g for i, g in enumerate(self.grid) if i not in self.mapped_dims + ) + if dimension_semantics is None: + dimension_semantics = ("arbitrary",) * len(user_grid) + if len(user_grid) != len(dimension_semantics): + raise ValueError( + "Must have dimension semantics for each dimension of the grid." + ) + if num_operands != len(self.block_mappings): + raise ValueError("Must have block mappings for each operand.") + assert len(self.mapped_dims) + len(dimension_semantics) == len( + self.grid + ), ( + f"Misconfigured grid: {self.mapped_dims=}, {dimension_semantics=}," + f" {self.grid=}" + ) + # dimension_semantics is user provided and won't take into account vmap + # dimensions. Here we add in parallel dimensions for the vmaps. + semantics_iter = iter(dimension_semantics) + self._dimension_semantics = tuple( + next(semantics_iter) if i not in self.mapped_dims else "parallel" + for i in range(len(self.grid)) + ) + + in_avals = [invar.aval for invar in self.jaxpr.invars] + scalar_prefetch_avals, operand_avals, scratch_avals = split_list( + in_avals, [num_scalar_prefetch, num_operands] + ) + self.scalar_prefetch_types, _ = unzip2([ + _get_arg_type(aval, None, memory_space=SMEM) + for aval in scalar_prefetch_avals]) + self.scalar_prefetch_block_shapes = tuple( + aval.shape for aval in scalar_prefetch_avals) + self.operand_types, self.operand_block_shapes = unzip2([ + _get_arg_type(aval, block_mapping, memory_space=None) + for aval, block_mapping in zip(operand_avals, self.block_mappings)]) + self.scratch_types, _ = unzip2([ + _get_arg_type(aval, None, memory_space=None) for aval in scratch_avals]) + self.scratch_block_shapes = tuple( + aval.shape if not isinstance(aval, tpu_core.AbstractSemaphore) else None + for aval in scratch_avals + ) + self.grid_types, _ = unzip2([ + _get_arg_type(jax_core.ShapedArray((), jnp.int32), None, + memory_space=None) + for _ in range(len(self.grid)) + ]) + self._prepare_mesh_info(mesh) + def _get_grid_indices(indices): + return indices + self.get_grid_indices = _get_grid_indices + + def _prepare_mesh_info(self, mesh: mesh_lib.Mesh | None): + if not self.has_communication: + self.mesh_info = None + return if mesh is None: - raise ValueError("Cannot use axis names in pallas_call without shard_map." - ) + raise ValueError( + "Cannot use communication in pallas_call without shard_map." + ) axis_names = mesh.axis_names # We need mesh <-> logical translation tables. Since the logical IDs are # just linearized versions of the mesh IDs, we create those tables. @@ -180,20 +278,70 @@ def lower_jaxpr_to_module( logical_to_mesh = np.empty((mesh.size, len(axis_names)), dtype=np.int32) for i, idx in enumerate(np.ndindex(*mesh.device_ids.shape)): logical_to_mesh[i] = np.array(idx) - mesh_info = (logical_to_mesh, axis_names, mesh_strides) - extra_args = mesh_info[:1] if mesh_info else () - func_op = lower_jaxpr_to_func(ctx, jaxpr, grid_mapping=grid_mapping, - name="main", mesh_info=mesh_info) - m.body.append(func_op) - sym_tab.insert(func_op) - num_smem_inputs = grid_mapping.num_index_operands - if mesh_info: + self.mesh_info = MeshInfo(logical_to_mesh, axis_names, mesh_strides) + l_to_m_aval = state.AbstractRef( + jax_core.raise_to_shaped(jax_core.get_aval(logical_to_mesh)) + ) # We are now passing in the logical -> mesh index mapping # TODO(sharadmv,apaszke): avoid stalling pipeline by marking the index # mapping as scalar prefetch and instead just mark it as an SMEM operand. - num_smem_inputs += 1 + self.scalar_prefetch_types = ( + _get_arg_type(l_to_m_aval, None, memory_space=SMEM)[0], + *self.scalar_prefetch_types) + + def maybe_compress_grid(self): + # If we have many leading parallel dimensions, we should "compress" them + # into one so we can load balance across cores as best as we can. + # TODO(sharadmv): implement this optimization + pass + + @functools.cached_property + def has_communication(self) -> bool: + return bool(jax_core.used_axis_names_jaxpr(self.jaxpr)) + + def get_extra_args(self) -> tuple[Any, ...]: + if self.mesh_info is None: + return () + return (self.mesh_info.logical_to_mesh,) + + def get_dimension_semantics(self) -> ir.ArrayAttr: + + def _get_semantics(s: str | None) -> str: + if s is None: + return "#tpu.dimension_semantics" + return f"#tpu.dimension_semantics<{s}>" + + return ir.ArrayAttr.get( + map( + ir.Attribute.parse, + map(_get_semantics, self._dimension_semantics), + ) + ) + +@dataclasses.dataclass +class MeshInfo: + logical_to_mesh: np.ndarray + axis_names: list[str] + mesh_strides: tuple[int, ...] + +def lower_jaxpr_to_module( + ctx: ir.Context, + grid_mapping: core.GridMapping, + jaxpr: jax_core.Jaxpr, + dimension_semantics: tuple[str | None, ...] | None, + mesh: mesh_lib.Mesh | None = None +) -> ir.Module: + mosaic_grid_mapping = MosaicGridMapping( + jaxpr, grid_mapping, dimension_semantics, mesh) + mosaic_grid_mapping.maybe_compress_grid() + m = ir.Module.create() + sym_tab = ir.SymbolTable(m.operation) + func_op = lower_jaxpr_to_func(ctx, jaxpr, mosaic_grid_mapping=mosaic_grid_mapping, + name="main") + m.body.append(func_op) + sym_tab.insert(func_op) window_params = [] - grid = grid_mapping.grid + grid = mosaic_grid_mapping.grid if grid: for i, bm in enumerate(grid_mapping.block_mappings): # TODO(sharadmv): generate default block mapping if left as no_block_spec @@ -209,10 +357,8 @@ def lower_jaxpr_to_module( mlir_func = lower_jaxpr_to_transform_func( ctx, bm.index_map_jaxpr.jaxpr, - [*[None] * len(grid), *[SMEM] * grid_mapping.num_index_operands], name=func_name, - mesh_info=mesh_info, - grid_mapping=grid_mapping, + mosaic_grid_mapping=mosaic_grid_mapping, ) assert mlir_func.verify(), mlir_func block_shape = [ @@ -232,103 +378,101 @@ def lower_jaxpr_to_module( func_op.attributes["window_params"] = ir.ArrayAttr.get(window_params) func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(grid) - num_scratch_inputs = grid_mapping.num_scratch_operands func_op.attributes["scalar_prefetch"] = ir.IntegerAttr.get( - ir.IntegerType.get_signless(64), num_smem_inputs) + ir.IntegerType.get_signless(64), len(mosaic_grid_mapping.scalar_prefetch_types)) func_op.attributes["scratch_operands"] = ir.IntegerAttr.get( - ir.IntegerType.get_signless(64), num_scratch_inputs) - - def _get_semantics(s: str | None) -> str: - if s is None: - return "#tpu.dimension_semantics" - return f"#tpu.dimension_semantics<{s}>" - - if dimension_semantics is None: - func_dimension_semantics = [ - _get_semantics("parallel") - if i in grid_mapping.mapped_dims - else _get_semantics(None) - for i, d in enumerate(grid_mapping.grid) - ] - else: - dimension_semantics_iter = iter(dimension_semantics) - func_dimension_semantics = [ - _get_semantics("parallel") - if i in grid_mapping.mapped_dims - else _get_semantics(next(dimension_semantics_iter)) - for i, d in enumerate(grid_mapping.grid) - ] - func_op.attributes["dimension_semantics"] = ir.ArrayAttr.get( - map(ir.Attribute.parse, func_dimension_semantics) + ir.IntegerType.get_signless(64), len(mosaic_grid_mapping.scratch_types)) + func_op.attributes["dimension_semantics"] = ( + mosaic_grid_mapping.get_dimension_semantics() ) - return m, extra_args + return m, mosaic_grid_mapping.get_extra_args() def lower_jaxpr_to_transform_func( ctx: ir.Context, jaxpr: jax_core.Jaxpr, - memspaces: Sequence[Any], *, name: str, - mesh_info: Any, - grid_mapping: core.GridMapping, + mosaic_grid_mapping: MosaicGridMapping, ) -> func.FuncOp: - block_shapes = [i.aval.shape for i in jaxpr.invars] - assert len(jaxpr.invars) == len(memspaces), ( - f"Must have as many invars ({len(jaxpr.invars)}) as" - f" memspaces ({len(memspaces)})." - ) - arg_types = [*map(aval_to_ir_type, [invar.aval for invar in jaxpr.invars], - block_shapes, memspaces)] + num_grid = len(mosaic_grid_mapping.grid_types) + arg_types = [ + *mosaic_grid_mapping.grid_types, + *mosaic_grid_mapping.scalar_prefetch_types, + ] + def body_func(*args): + grid_indices, scalar_prefetch = split_list(args, [num_grid]) + jaxpr_indices = mosaic_grid_mapping.get_grid_indices(grid_indices) + arg_block_shapes = [ + *[()] * len(jaxpr_indices), + *mosaic_grid_mapping.scalar_prefetch_block_shapes, + ] - if mesh_info is not None: - l_to_m, axis_names, mesh_strides = mesh_info - l_to_m_aval = state.AbstractRef( - jax_core.raise_to_shaped(jax_core.get_aval(l_to_m)) - ) - grid_index_types, scalar_prefetch_types = split_list( - arg_types, - [ - len(grid_mapping.grid), - ], + mesh_info = mosaic_grid_mapping.mesh_info + if mesh_info is not None: + (l_to_m,), scalar_prefetch = split_list(scalar_prefetch, [1]) + mesh_context = MeshContext(l_to_m, mesh_info.axis_names, + mesh_info.mesh_strides) + else: + mesh_context = None + lowering_context = LoweringContext( + ctx, + None, + arg_block_shapes, + source_info_util.NameStack(), + mesh_context=mesh_context, ) - arg_types = [ - *grid_index_types, - _get_arg_type(l_to_m_aval, None, SMEM)[0], - *scalar_prefetch_types, - ] + return jaxpr_subcomp(lowering_context, jaxpr, *jaxpr_indices, + *scalar_prefetch) + body_func.__name__ = name + body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func) + body.func_op.verify() + return body.func_op - if mesh_info is not None: - def body_func(*args): - grid_indices, scalar_prefetch = split_list( - args, - [ - len(grid_mapping.grid), - ], - ) +def lower_jaxpr_to_func( + ctx: ir.Context, + jaxpr: jax_core.Jaxpr, + *, + mosaic_grid_mapping: MosaicGridMapping, + name: str, +) -> func.FuncOp: + num_grid = len(mosaic_grid_mapping.grid_types) + num_scalar_prefetch = len(mosaic_grid_mapping.scalar_prefetch_types) + arg_types = [ + *mosaic_grid_mapping.grid_types, + *mosaic_grid_mapping.scalar_prefetch_types, + *mosaic_grid_mapping.operand_types, + *mosaic_grid_mapping.scratch_types, + ] + arg_block_shapes = [ + *mosaic_grid_mapping.scalar_prefetch_block_shapes, + *mosaic_grid_mapping.operand_block_shapes, + *mosaic_grid_mapping.scratch_block_shapes, + ] + def body_func(*args): + grid_indices, scalar_prefetch, operands_and_scratch = split_list( + args, [num_grid, num_scalar_prefetch]) + grid_indices = mosaic_grid_mapping.get_grid_indices(grid_indices) + jaxpr_indices = tuple(idx for i, idx in enumerate(grid_indices) + if i not in mosaic_grid_mapping.mapped_dims) + mesh_info = mosaic_grid_mapping.mesh_info + if mesh_info is not None: (l_to_m,), scalar_prefetch = split_list(scalar_prefetch, [1]) - mesh_context = MeshContext(l_to_m, axis_names, mesh_strides) - lowering_context = LoweringContext( - ctx, - None, - None, - block_shapes, - source_info_util.NameStack(), - mesh_context=mesh_context, - ) - return jaxpr_subcomp(lowering_context, jaxpr, *grid_indices, *scalar_prefetch) - - else: + mesh_context = MeshContext(l_to_m, mesh_info.axis_names, + mesh_info.mesh_strides) + else: + mesh_context = None lowering_context = LoweringContext( ctx, - None, - None, - block_shapes, + jaxpr_indices, + arg_block_shapes, source_info_util.NameStack(), - mesh_context=None, + mesh_context=mesh_context, + ) + return jaxpr_subcomp( + lowering_context, jaxpr, *scalar_prefetch, *operands_and_scratch ) - body_func = functools.partial(jaxpr_subcomp, lowering_context, jaxpr) body_func.__name__ = name body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func) body.func_op.verify() @@ -353,118 +497,6 @@ def f_lowered(ctx: LoweringRuleContext, *args, **params): return f_lowered -def _get_arg_type( - aval, - block_mapping: core.BlockMapping | None, - memory_space: tpu_core.TPUMemorySpace | None, -): - if isinstance(aval, tpu_core.AbstractMemoryRef): - assert memory_space is None - memory_space = aval.memory_space - if isinstance(aval, tpu_core.AbstractSemaphore): - return aval_to_ir_type(aval), None - if block_mapping is None: - return aval_to_ir_type(aval, memory_space=memory_space), aval.shape - shape = tuple(1 if b is core.mapped else b for b in block_mapping.block_shape) - return ( - aval_to_ir_type(aval, shape=shape, memory_space=memory_space), - block_mapping.block_shape, - ) - - -def lower_jaxpr_to_func( - ctx: ir.Context, - jaxpr: jax_core.Jaxpr, - *, - grid_mapping: core.GridMapping | None, - name: str, - mesh_info: Any, -) -> func.FuncOp: - memory_spaces = [None if bm is None else bm.memory_space - for bm in grid_mapping.block_mappings] - if grid_mapping: - arg_types = map( - aval_to_ir_type, - [jax_core.ShapedArray((), jnp.int32) for _ in grid_mapping.grid], - ) - else: - arg_types = [] - - if mesh_info is not None: - l_to_m, axis_names, mesh_strides = mesh_info - l_to_m_aval = state.AbstractRef( - jax_core.raise_to_shaped(jax_core.get_aval(l_to_m))) - arg_types.append(_get_arg_type(l_to_m_aval, None, SMEM)[0]) - - scalar_prefetch = None - num_scratch = None - if grid_mapping is None: - block_mappings = [None] * len(jaxpr.invars) - else: - scalar_prefetch = grid_mapping.num_index_operands - block_mappings = grid_mapping.block_mappings - num_scratch = grid_mapping.num_scratch_operands - block_mappings = [ - *[None] * scalar_prefetch, - *block_mappings, - *[None] * num_scratch, - ] - memory_spaces = [ - *[SMEM] * scalar_prefetch, - *memory_spaces, - *[None] * num_scratch, - ] - assert len(memory_spaces) == len(jaxpr.invars), ( - f"Must have as many memory spaces as inputs ({len(jaxpr.invars)}) and" - f" outputs ({len(memory_spaces)})." - f" scalar_prefetch={scalar_prefetch} num_scratch={num_scratch}" - ) - invar_arg_types, block_shapes = unzip2( - map(_get_arg_type, [invar.aval for invar in jaxpr.invars], block_mappings, - memory_spaces) - ) - arg_types = [*arg_types, *invar_arg_types] - if grid_mapping: - - def body_func(*args): - grid_indices, scalar_prefetch, args = split_list( - args, - [ - len(grid_mapping.grid), - grid_mapping.num_index_operands + bool(mesh_info), - ], - ) - grid_indices = [ - g - for i, g in enumerate(grid_indices) - if i not in grid_mapping.mapped_dims - ] - if mesh_info is not None: - (l_to_m,), scalar_prefetch = split_list(scalar_prefetch, [1]) - mesh_context = MeshContext(l_to_m, axis_names, mesh_strides) - else: - mesh_context = None - lowering_context = LoweringContext( - ctx, - grid_mapping, - tuple(grid_indices), - block_shapes, - source_info_util.NameStack(), - mesh_context=mesh_context, - ) - return jaxpr_subcomp(lowering_context, jaxpr, *scalar_prefetch, *args) - - else: - lowering_context = LoweringContext( - ctx, None, None, block_shapes, source_info_util.NameStack() - ) - body_func = functools.partial(jaxpr_subcomp, lowering_context, jaxpr) - body_func.__name__ = name - body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func) - body.func_op.verify() - return body.func_op - - class LoweringException(Exception): pass