diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py
index 1825f4f6e05f..0413df0e58e7 100644
--- a/jax/_src/pallas/core.py
+++ b/jax/_src/pallas/core.py
@@ -123,6 +123,7 @@ class GridMapping:
   block_mappings: tuple[BlockMapping | None, ...]
   mapped_dims: tuple[int, ...]
   num_index_operands: int
+  num_scratch_operands: int
 
   replace = dataclasses.replace
 
@@ -260,7 +261,7 @@ def get_grid_mapping(
                 in_tree=grid_tree), out_specs, out_ref_avals)
     grid_mapping = GridMapping(
         self.grid, (*in_block_mappings, *out_block_mappings), (),
-        num_index_operands=0)
+        num_index_operands=0, num_scratch_operands=0)
     jaxpr_in_avals = tree_util.tree_unflatten(in_tree, in_ref_avals)
     jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals)
     if not isinstance(jaxpr_out_avals, (tuple, list)):
diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py
index a0822117baf4..a5fade94f61e 100644
--- a/jax/_src/pallas/mosaic/core.py
+++ b/jax/_src/pallas/mosaic/core.py
@@ -117,6 +117,15 @@ def get_aval(self) -> AbstractMemoryRef:
         jax_core.ShapedArray(self.shape, self.dtype), self.memory_space)
 
 
+def _make_aval(obj: object) -> jax_core.AbstractValue:
+  if isinstance(obj, MemoryRef):
+    return obj.get_aval()
+  if isinstance(obj, SemaphoreType):
+    return obj.get_aval()
+  raise ValueError(f"No registered conversion for {type(obj)}. "
+                   "Only VMEM and SemaphoreType are supported.")
+
+
 @dataclasses.dataclass(init=False, unsafe_hash=True)
 class PrefetchScalarGridSpec(pallas_core.GridSpec):
   grid: Grid
@@ -125,6 +134,7 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
   out_specs: tuple[BlockSpec | NoBlockSpec, ...]
   in_specs_tree: Any
   out_specs_tree: Any
+  scratch_shapes: tuple[Any, ...]
 
   def __init__(
       self,
@@ -136,14 +146,19 @@ def __init__(
       out_specs: BlockSpec | Sequence[BlockSpec | NoBlockSpec]
       | NoBlockSpec = no_block_spec,
+      scratch_shapes: Any | Sequence[Any] = ()
   ):
     super().__init__(grid, in_specs, out_specs)
     self.num_scalar_prefetch = num_scalar_prefetch
+    self.scratch_shapes = tuple(scratch_shapes)
 
   def get_grid_mapping(
       self, in_avals, in_tree, out_avals, out_tree
   ) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]:
     all_avals = tree_util.tree_unflatten(in_tree, in_avals)
+    flat_scratch_shapes, scratch_tree = tree_util.tree_flatten(
+        self.scratch_shapes)
+    flat_scratch_avals = map(_make_aval, flat_scratch_shapes)
     scalar_avals, unflat_in_avals = split_list(
         all_avals, [self.num_scalar_prefetch])
     flat_scalar_avals, scalar_tree = tree_util.tree_flatten(scalar_avals)
@@ -176,12 +191,19 @@ def get_grid_mapping(
         block_mappings=(*in_block_mappings, *out_block_mappings),
         mapped_dims=(),
         num_index_operands=num_flat_scalar_prefetch,
+        num_scratch_operands=len(flat_scratch_avals)
     )
     jaxpr_scalar_ref_avals = tree_util.tree_unflatten(
         scalar_tree, scalar_ref_avals)
     jaxpr_in_ref_avals = tree_util.tree_unflatten(in_avals_tree, in_ref_avals)
-    jaxpr_in_avals = (*jaxpr_scalar_ref_avals, *jaxpr_in_ref_avals)
+    jaxpr_scratch_avals = tree_util.tree_unflatten(
+        scratch_tree, flat_scratch_avals)
+    if not isinstance(jaxpr_scratch_avals, (tuple, list)):
+      jaxpr_scratch_avals = (jaxpr_scratch_avals,)
+    jaxpr_in_avals = (*jaxpr_scalar_ref_avals,
+                      *jaxpr_in_ref_avals)
     jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals)
     if not isinstance(jaxpr_out_avals, (tuple, list)):
       jaxpr_out_avals = (jaxpr_out_avals,)
-    return (*jaxpr_in_avals, *jaxpr_out_avals), grid_mapping
+    return (*jaxpr_in_avals, *jaxpr_out_avals,
+            *jaxpr_scratch_avals), grid_mapping
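Illustrative aside, not part of the patch: a minimal sketch of how the new `scratch_shapes` argument might be used, assuming the `jax.experimental.pallas` exports of this era (`pl.pallas_call`, `pltpu.PrefetchScalarGridSpec`, `pltpu.VMEM`); the kernel body and shapes are hypothetical.

```python
# Minimal usage sketch (not part of the patch); kernel/shapes are hypothetical.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def kernel(x_ref, o_ref, scratch_ref):
  # scratch_ref is the trailing VMEM Ref created from scratch_shapes below.
  scratch_ref[...] = x_ref[...] * 2.0
  o_ref[...] = scratch_ref[...]

x = jnp.ones((8, 128), jnp.float32)
out = pl.pallas_call(
    kernel,
    grid_spec=pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        # Each entry is converted by _make_aval above; VMEM MemoryRefs and
        # SemaphoreTypes are the only supported kinds.
        scratch_shapes=[pltpu.VMEM((8, 128), jnp.float32)],
    ),
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
)(x)
```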
diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py
index c204974fd4d8..42832f8c6626 100644
--- a/jax/_src/pallas/mosaic/lowering.py
+++ b/jax/_src/pallas/mosaic/lowering.py
@@ -105,14 +105,22 @@ def _memory_space_to_tpu_memspace(memory_space: TPUMemorySpace | None
 
 def aval_to_ir_type(aval, shape=None,
                     memory_space: TPUMemorySpace | None = None):
-  if shape is None:
-    shape = aval.shape
+  if isinstance(aval, tpu_core.AbstractSemaphore):
+    if aval.sem_type == tpu_core.SemaphoreType.DMA:
+      return ir.Type.parse("!tpu.dma_semaphore")
+    elif aval.sem_type == tpu_core.SemaphoreType.REGULAR:
+      return ir.Type.parse("!tpu.semaphore")
+    raise NotImplementedError(aval.sem_type)
   if isinstance(aval, state.AbstractRef):
+    if shape is None:
+      shape = aval.shape
     memspace = _memory_space_to_tpu_memspace(memory_space)
     return ir.MemRefType.get(shape, mlir.dtype_to_ir_type(aval.dtype),
                              memory_space=memspace)
-  elif isinstance(aval, jax_core.ShapedArray):
-    if shape == ():
+  if isinstance(aval, jax_core.ShapedArray):
+    if shape is None:
+      shape = aval.shape
+    if not shape:
       return mlir.dtype_to_ir_type(aval.dtype)
     return ir.VectorType.get(shape, mlir.dtype_to_ir_type(aval.dtype))
   raise NotImplementedError(aval)
@@ -174,13 +182,6 @@ def lower_jaxpr_to_module(
       logical_to_mesh[i] = np.array(idx)
     mesh_info = (logical_to_mesh, tuple(axis_names), mesh_strides)
   extra_args = mesh_info[:1] if mesh_info else ()
-  if not grid_mapping.grid:
-    # Trivial grid-map, we don't need to populate the transform functions.
-    func_op = lower_jaxpr_to_func(ctx, jaxpr, grid_mapping=grid_mapping,
-                                  name="main", mesh_info=mesh_info)
-    m.body.append(func_op)
-    sym_tab.insert(func_op)
-    return m, extra_args
   func_op = lower_jaxpr_to_func(ctx, jaxpr, grid_mapping=grid_mapping,
                                 name="main", mesh_info=mesh_info)
   m.body.append(func_op)
@@ -193,36 +194,46 @@ def lower_jaxpr_to_module(
     num_smem_inputs += 1
   window_params = []
   grid = grid_mapping.grid
-  for i, bm in enumerate(grid_mapping.block_mappings):
-    func_name = f"transform_{i}"
-    if bm.index_map_jaxpr.consts:
-      raise NotImplementedError("Index map jaxpr with consts not supported.")
-    mlir_func = lower_jaxpr_to_transform_func(
-        ctx,
-        bm.index_map_jaxpr.jaxpr,
-        [*[None] * len(grid), *[SMEM] * num_smem_inputs],
-        name=func_name)
-    assert mlir_func.verify(), mlir_func
-    block_shape = [
-        1 if b is core.mapped else b for b in bm.block_shape
-    ]
-    window_shape = ir.DenseI64ArrayAttr.get(block_shape)
-    window_params.append(
-        ir.DictAttr.get(
-            dict(
-                window_bounds=window_shape,
-                transform_indices=ir.FlatSymbolRefAttr.get(func_name),
-            )
-        )
-    )
-    m.body.append(mlir_func)
-    sym_tab.insert(mlir_func)
+  if grid:
+    for i, bm in enumerate(grid_mapping.block_mappings):
+      # TODO(sharadmv): generate default block mapping if left as no_block_spec
+      if bm is None:
+        raise NotImplementedError("Please specify block mappings if "
+                                  "grid is specified.")
+      if bm.index_map_jaxpr is None:
+        raise NotImplementedError("Please specify index_maps if "
+                                  "grid is specified.")
+      func_name = f"transform_{i}"
+      if bm.index_map_jaxpr.consts:
+        raise NotImplementedError("Index map jaxpr with consts not supported.")
+      mlir_func = lower_jaxpr_to_transform_func(
+          ctx,
+          bm.index_map_jaxpr.jaxpr,
+          [*[None] * len(grid), *[SMEM] * num_smem_inputs],
+          name=func_name)
+      assert mlir_func.verify(), mlir_func
+      block_shape = [
+          1 if b is core.mapped else b for b in bm.block_shape
+      ]
+      window_shape = ir.DenseI64ArrayAttr.get(block_shape)
+      window_params.append(
+          ir.DictAttr.get(
+              dict(
+                  window_bounds=window_shape,
+                  transform_indices=ir.FlatSymbolRefAttr.get(func_name),
+              )
+          )
+      )
+      m.body.append(mlir_func)
+      sym_tab.insert(mlir_func)
+    func_op.attributes["window_params"] = ir.ArrayAttr.get(window_params)
+    func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(grid)
+
+  num_scratch_inputs = grid_mapping.num_scratch_operands
   func_op.attributes["scalar_prefetch"] = ir.IntegerAttr.get(
       ir.IntegerType.get_signless(64), num_smem_inputs)
-  func_op.attributes["window_params"] = ir.ArrayAttr.get(window_params)
-  func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(
-      grid_mapping.grid
-  )
+  func_op.attributes["scratch_operands"] = ir.IntegerAttr.get(
+      ir.IntegerType.get_signless(64), num_scratch_inputs)
 
 
 def _get_semantics(s: str | None) -> str:
   if s is None:
@@ -304,6 +315,11 @@ def lower_jaxpr_to_func(
 
 def _get_arg_type(aval, block_mapping: core.BlockMapping | None,
                   memory_space: tpu_core.TPUMemorySpace | None):
+  if isinstance(aval, tpu_core.AbstractMemoryRef):
+    assert memory_space is None
+    memory_space = aval.memory_space
+  if isinstance(aval, tpu_core.AbstractSemaphore):
+    return aval_to_ir_type(aval), None
   if block_mapping is None:
     return aval_to_ir_type(aval, memory_space=memory_space), aval.shape
   shape = tuple(
@@ -323,8 +339,17 @@ def _get_arg_type(aval, block_mapping: core.BlockMapping | None,
   else:
     scalar_prefetch = grid_mapping.num_index_operands
   block_mappings = grid_mapping.block_mappings
-  block_mappings = [*[None] * scalar_prefetch, *block_mappings]
-  memory_spaces = [*[SMEM] * scalar_prefetch, *memory_spaces]
+  num_scratch = grid_mapping.num_scratch_operands
+  block_mappings = [
+      *[None] * scalar_prefetch,
+      *block_mappings,
+      *[None] * num_scratch,
+  ]
+  memory_spaces = [
+      *[SMEM] * scalar_prefetch,
+      *memory_spaces,
+      *[None] * num_scratch,
+  ]
   assert len(memory_spaces) == len(jaxpr.invars), (
       "Must have as many memory spaces as inputs and outputs.")
   invar_arg_types, block_shapes = unzip2(
@@ -606,7 +631,7 @@ def _masked_swap_lowering_rule(
       for a in idx_aval.indices
   ):
     raise ValueError("Cannot do int indexing on TPU")
-  if not ref_block_shape:
+  if not is_smem_store and not ref_block_shape:
    raise NotImplementedError(
        "Indexing into a ()-shaped Ref not yet supported on TPU.")
  starts = tuple(
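For orientation, not part of the patch: with these changes the lowered `main` function's operands are ordered scalar-prefetch, then inputs, then outputs, then scratch, and the new `scratch_operands` attribute records the trailing scratch count. Below is a hedged sketch of declaring both supported scratch kinds; the specs and shapes are hypothetical, and it assumes this era's `pl.BlockSpec(index_map, block_shape)` signature.

```python
# Hedged sketch (hypothetical specs/shapes). Per _make_aval and
# aval_to_ir_type above: a VMEM MemoryRef lowers to a memref in VMEM,
# SemaphoreType.DMA lowers to !tpu.dma_semaphore, and
# SemaphoreType.REGULAR lowers to !tpu.semaphore.
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=0,
    grid=(2,),
    in_specs=[pl.BlockSpec(lambda i: (i, 0), (8, 128))],
    out_specs=pl.BlockSpec(lambda i: (i, 0), (8, 128)),
    scratch_shapes=[
        pltpu.VMEM((8, 128), jnp.float32),  # trailing VMEM Ref argument
        pltpu.SemaphoreType.DMA,            # trailing !tpu.dma_semaphore
    ],
)

# The kernel then receives its Refs in GridMapping order:
#   def kernel(x_ref, o_ref, vmem_scratch_ref, dma_sem): ...
```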