Commit d488812

[Pallas TPU] Add support for hoisted scratch spaces
PiperOrigin-RevId: 576336673
sharadmv authored and jax authors committed Oct 25, 2023
1 parent 47a76df commit d488812
Showing 3 changed files with 93 additions and 45 deletions.
jax/_src/pallas/core.py (2 additions, 1 deletion)
@@ -123,6 +123,7 @@ class GridMapping:
   block_mappings: tuple[BlockMapping | None, ...]
   mapped_dims: tuple[int, ...]
   num_index_operands: int
+  num_scratch_operands: int
 
   replace = dataclasses.replace

@@ -260,7 +261,7 @@ def get_grid_mapping(
             in_tree=grid_tree), out_specs, out_ref_avals)
     grid_mapping = GridMapping(
         self.grid, (*in_block_mappings, *out_block_mappings), (),
-        num_index_operands=0)
+        num_index_operands=0, num_scratch_operands=0)
     jaxpr_in_avals = tree_util.tree_unflatten(in_tree, in_ref_avals)
     jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals)
     if not isinstance(jaxpr_out_avals, (tuple, list)):
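Since `GridMapping` aliases `dataclasses.replace` (the `replace = dataclasses.replace` line above), a later pass that hoists scratch buffers can record their count without rebuilding the mapping. A minimal sketch, not from the commit; specs and avals are elided, and `GridMapping` may have more fields than this hunk shows:

```python
# Sketch only: mirrors the default construction above with elided arguments.
grid_mapping = GridMapping(
    self.grid, (*in_block_mappings, *out_block_mappings), (),
    num_index_operands=0, num_scratch_operands=0)
# `replace` aliases dataclasses.replace, so the scratch count can be
# threaded through after the fact:
grid_mapping = grid_mapping.replace(num_scratch_operands=len(scratch_avals))
```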
jax/_src/pallas/mosaic/core.py (24 additions, 2 deletions)
@@ -117,6 +117,15 @@ def get_aval(self) -> AbstractMemoryRef:
         jax_core.ShapedArray(self.shape, self.dtype), self.memory_space)
 
 
+def _make_aval(obj: object) -> jax_core.AbstractValue:
+  if isinstance(obj, MemoryRef):
+    return obj.get_aval()
+  if isinstance(obj, SemaphoreType):
+    return obj.get_aval()
+  raise ValueError(f"No registered conversion for {type(obj)}. "
+                   "Only VMEM and SemaphoreType are supported.")
+
+
 @dataclasses.dataclass(init=False, unsafe_hash=True)
 class PrefetchScalarGridSpec(pallas_core.GridSpec):
   grid: Grid
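`_make_aval` is the hook that turns user-supplied scratch descriptors into abstract values, and only the two kinds below are accepted. A hedged sketch, assuming the usual `pltpu` alias and that `pltpu.VMEM(shape, dtype)` constructs a `MemoryRef` (as in later JAX releases):

```python
import jax.numpy as jnp
from jax.experimental.pallas import tpu as pltpu

# MemoryRef descriptor: becomes an AbstractMemoryRef via get_aval().
vmem_scratch = pltpu.VMEM((8, 128), jnp.float32)
# SemaphoreType descriptor: becomes an AbstractSemaphore via get_aval().
dma_sem = pltpu.SemaphoreType.DMA
# Any other object (say, a bare tuple) hits the ValueError branch above.
```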
@@ -125,6 +134,7 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
   out_specs: tuple[BlockSpec | NoBlockSpec, ...]
   in_specs_tree: Any
   out_specs_tree: Any
+  scratch_shapes: tuple[Any, ...]
 
   def __init__(
       self,
@@ -136,14 +146,19 @@ def __init__(
       out_specs: BlockSpec
       | Sequence[BlockSpec | NoBlockSpec]
       | NoBlockSpec = no_block_spec,
+      scratch_shapes: Any | Sequence[Any] = ()
   ):
     super().__init__(grid, in_specs, out_specs)
     self.num_scalar_prefetch = num_scalar_prefetch
+    self.scratch_shapes = tuple(scratch_shapes)
 
   def get_grid_mapping(
       self, in_avals, in_tree, out_avals, out_tree
   ) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]:
     all_avals = tree_util.tree_unflatten(in_tree, in_avals)
+    flat_scratch_shapes, scratch_tree = tree_util.tree_flatten(
+        self.scratch_shapes)
+    flat_scratch_avals = map(_make_aval, flat_scratch_shapes)
     scalar_avals, unflat_in_avals = split_list(
         all_avals, [self.num_scalar_prefetch])
     flat_scalar_avals, scalar_tree = tree_util.tree_flatten(scalar_avals)
@@ -176,12 +191,19 @@ def get_grid_mapping(
         block_mappings=(*in_block_mappings, *out_block_mappings),
         mapped_dims=(),
         num_index_operands=num_flat_scalar_prefetch,
+        num_scratch_operands=len(flat_scratch_avals)
     )
     jaxpr_scalar_ref_avals = tree_util.tree_unflatten(
         scalar_tree, scalar_ref_avals)
     jaxpr_in_ref_avals = tree_util.tree_unflatten(in_avals_tree, in_ref_avals)
-    jaxpr_in_avals = (*jaxpr_scalar_ref_avals, *jaxpr_in_ref_avals)
+    jaxpr_scratch_avals = tree_util.tree_unflatten(
+        scratch_tree, flat_scratch_avals)
+    if not isinstance(jaxpr_scratch_avals, (tuple, list)):
+      jaxpr_scratch_avals = (jaxpr_scratch_avals,)
+    jaxpr_in_avals = (*jaxpr_scalar_ref_avals,
+                      *jaxpr_in_ref_avals)
     jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals)
     if not isinstance(jaxpr_out_avals, (tuple, list)):
      jaxpr_out_avals = (jaxpr_out_avals,)
-    return (*jaxpr_in_avals, *jaxpr_out_avals), grid_mapping
+    return (*jaxpr_in_avals, *jaxpr_out_avals,
+            *jaxpr_scratch_avals), grid_mapping
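End to end, `scratch_shapes` asks the runtime to allocate scratch buffers whose refs are handed to the kernel after the output refs, matching the `(*in, *out, *scratch)` ordering returned above. A minimal sketch against the API of this era, assuming the `pl`/`pltpu` aliases, that `pallas_call` accepts a `grid_spec`, and the `BlockSpec(index_map, block_shape)` signature; names and shapes are illustrative:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def kernel(x_ref, o_ref, scratch_ref):
  # scratch_ref is the hoisted VMEM buffer; it arrives after the output ref
  # and persists across grid steps.
  scratch_ref[...] = x_ref[...] * 2.0
  o_ref[...] = scratch_ref[...]

@jax.jit
def double(x):
  return pl.pallas_call(
      kernel,
      grid_spec=pltpu.PrefetchScalarGridSpec(
          num_scalar_prefetch=0,
          grid=(1,),
          in_specs=[pl.BlockSpec(lambda i: (0, 0), (8, 128))],
          out_specs=pl.BlockSpec(lambda i: (0, 0), (8, 128)),
          scratch_shapes=[pltpu.VMEM((8, 128), jnp.float32)],
      ),
      out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
  )(x)
```

With `num_scalar_prefetch=0` there are no leading scalar refs; otherwise those come first, then inputs, outputs, and finally the scratch refs.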
jax/_src/pallas/mosaic/lowering.py (67 additions, 42 deletions)
@@ -105,14 +105,22 @@ def _memory_space_to_tpu_memspace(memory_space: TPUMemorySpace | None
 
 
 def aval_to_ir_type(aval, shape=None, memory_space: TPUMemorySpace | None = None):
-  if shape is None:
-    shape = aval.shape
+  if isinstance(aval, tpu_core.AbstractSemaphore):
+    if aval.sem_type == tpu_core.SemaphoreType.DMA:
+      return ir.Type.parse("!tpu.dma_semaphore")
+    elif aval.sem_type == tpu_core.SemaphoreType.REGULAR:
+      return ir.Type.parse("!tpu.semaphore")
+    raise NotImplementedError(aval.sem_type)
   if isinstance(aval, state.AbstractRef):
+    if shape is None:
+      shape = aval.shape
     memspace = _memory_space_to_tpu_memspace(memory_space)
     return ir.MemRefType.get(shape, mlir.dtype_to_ir_type(aval.dtype),
                              memory_space=memspace)
-  elif isinstance(aval, jax_core.ShapedArray):
-    if shape == ():
+  if isinstance(aval, jax_core.ShapedArray):
+    if shape is None:
+      shape = aval.shape
+    if not shape:
       return mlir.dtype_to_ir_type(aval.dtype)
     return ir.VectorType.get(shape, mlir.dtype_to_ir_type(aval.dtype))
   raise NotImplementedError(aval)
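For orientation, the mapping this function now implements (IR syntax rendered loosely; the exact memref memory-space attribute is elided):

    AbstractSemaphore(sem_type=DMA)      -> !tpu.dma_semaphore
    AbstractSemaphore(sem_type=REGULAR)  -> !tpu.semaphore
    AbstractRef of f32[8,128]            -> memref<8x128xf32, ...> in the requested memory space
    ShapedArray f32[8,128]               -> vector<8x128xf32>
    ShapedArray f32[] (rank 0)           -> f32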
@@ -174,13 +182,6 @@ def lower_jaxpr_to_module(
     logical_to_mesh[i] = np.array(idx)
   mesh_info = (logical_to_mesh, tuple(axis_names), mesh_strides)
   extra_args = mesh_info[:1] if mesh_info else ()
-  if not grid_mapping.grid:
-    # Trivial grid-map, we don't need to populate the transform functions.
-    func_op = lower_jaxpr_to_func(ctx, jaxpr, grid_mapping=grid_mapping,
-                                  name="main", mesh_info=mesh_info)
-    m.body.append(func_op)
-    sym_tab.insert(func_op)
-    return m, extra_args
   func_op = lower_jaxpr_to_func(ctx, jaxpr, grid_mapping=grid_mapping,
                                 name="main", mesh_info=mesh_info)
   m.body.append(func_op)
@@ -193,36 +194,46 @@ def lower_jaxpr_to_module(
     num_smem_inputs += 1
   window_params = []
   grid = grid_mapping.grid
-  for i, bm in enumerate(grid_mapping.block_mappings):
-    func_name = f"transform_{i}"
-    if bm.index_map_jaxpr.consts:
-      raise NotImplementedError("Index map jaxpr with consts not supported.")
-    mlir_func = lower_jaxpr_to_transform_func(
-        ctx,
-        bm.index_map_jaxpr.jaxpr,
-        [*[None] * len(grid), *[SMEM] * num_smem_inputs],
-        name=func_name)
-    assert mlir_func.verify(), mlir_func
-    block_shape = [
-        1 if b is core.mapped else b for b in bm.block_shape
-    ]
-    window_shape = ir.DenseI64ArrayAttr.get(block_shape)
-    window_params.append(
-        ir.DictAttr.get(
-            dict(
-                window_bounds=window_shape,
-                transform_indices=ir.FlatSymbolRefAttr.get(func_name),
-            )
-        )
-    )
-    m.body.append(mlir_func)
-    sym_tab.insert(mlir_func)
+  if grid:
+    for i, bm in enumerate(grid_mapping.block_mappings):
+      # TODO(sharadmv): generate default block mapping if left as no_block_spec
+      if bm is None:
+        raise NotImplementedError("Please specify block mappings if "
+                                  "grid is specified.")
+      if bm.index_map_jaxpr is None:
+        raise NotImplementedError("Please specify index_maps if "
+                                  "grid is specified.")
+      func_name = f"transform_{i}"
+      if bm.index_map_jaxpr.consts:
+        raise NotImplementedError("Index map jaxpr with consts not supported.")
+      mlir_func = lower_jaxpr_to_transform_func(
+          ctx,
+          bm.index_map_jaxpr.jaxpr,
+          [*[None] * len(grid), *[SMEM] * num_smem_inputs],
+          name=func_name)
+      assert mlir_func.verify(), mlir_func
+      block_shape = [
+          1 if b is core.mapped else b for b in bm.block_shape
+      ]
+      window_shape = ir.DenseI64ArrayAttr.get(block_shape)
+      window_params.append(
+          ir.DictAttr.get(
+              dict(
+                  window_bounds=window_shape,
+                  transform_indices=ir.FlatSymbolRefAttr.get(func_name),
+              )
+          )
+      )
+      m.body.append(mlir_func)
+      sym_tab.insert(mlir_func)
+    func_op.attributes["window_params"] = ir.ArrayAttr.get(window_params)
+    func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(grid)
+
+  num_scratch_inputs = grid_mapping.num_scratch_operands
   func_op.attributes["scalar_prefetch"] = ir.IntegerAttr.get(
       ir.IntegerType.get_signless(64), num_smem_inputs)
-  func_op.attributes["window_params"] = ir.ArrayAttr.get(window_params)
-  func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(
-      grid_mapping.grid
-  )
+  func_op.attributes["scratch_operands"] = ir.IntegerAttr.get(
+      ir.IntegerType.get_signless(64), num_scratch_inputs)
 
 def _get_semantics(s: str | None) -> str:
   if s is None:
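Loosely, the lowered `main` function now always carries `scalar_prefetch` and `scratch_operands` counts, while `window_params` and `iteration_bounds` are attached only when a grid is present. An approximate rendering (attribute syntax inferred from the builder calls above, not verified against the TPU dialect; values illustrative):

    func.func @main(...) attributes {
      window_params = [{window_bounds = ..., transform_indices = @transform_0}, ...],
      iteration_bounds = array<i64: ...>,
      scalar_prefetch = 1 : i64,
      scratch_operands = 2 : i64
    }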
@@ -304,6 +315,11 @@ def lower_jaxpr_to_func(
 
 def _get_arg_type(aval, block_mapping: core.BlockMapping | None,
                   memory_space: tpu_core.TPUMemorySpace | None):
+  if isinstance(aval, tpu_core.AbstractMemoryRef):
+    assert memory_space is None
+    memory_space = aval.memory_space
+  if isinstance(aval, tpu_core.AbstractSemaphore):
+    return aval_to_ir_type(aval), None
   if block_mapping is None:
     return aval_to_ir_type(aval, memory_space=memory_space), aval.shape
   shape = tuple(
@@ -323,8 +339,17 @@ def _get_arg_type(aval, block_mapping: core.BlockMapping | None,
   else:
     scalar_prefetch = grid_mapping.num_index_operands
   block_mappings = grid_mapping.block_mappings
-  block_mappings = [*[None] * scalar_prefetch, *block_mappings]
-  memory_spaces = [*[SMEM] * scalar_prefetch, *memory_spaces]
+  num_scratch = grid_mapping.num_scratch_operands
+  block_mappings = [
+      *[None] * scalar_prefetch,
+      *block_mappings,
+      *[None] * num_scratch,
+  ]
+  memory_spaces = [
+      *[SMEM] * scalar_prefetch,
+      *memory_spaces,
+      *[None] * num_scratch,
+  ]
   assert len(memory_spaces) == len(jaxpr.invars), (
       "Must have as many memory spaces as inputs and outputs.")
   invar_arg_types, block_shapes = unzip2(
@@ -606,7 +631,7 @@ def _masked_swap_lowering_rule(
       for a in idx_aval.indices
   ):
     raise ValueError("Cannot do int indexing on TPU")
-  if not ref_block_shape:
+  if not is_smem_store and not ref_block_shape:
     raise NotImplementedError(
         "Indexing into a ()-shaped Ref not yet supported on TPU.")
   starts = tuple(
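The `_masked_swap_lowering_rule` change above relaxes the ()-shaped-Ref restriction for SMEM stores specifically. A hypothetical kernel fragment; the `flag_ref` scratch and a `pltpu.SMEM((), jnp.int32)` descriptor for it are illustrative, not from the commit:

```python
def kernel(x_ref, o_ref, flag_ref):
  # flag_ref is a ()-shaped SMEM ref: storing into it now lowers, while the
  # same indexing on a non-SMEM ()-shaped ref still raises NotImplementedError.
  flag_ref[...] = 1
  o_ref[...] = x_ref[...]
```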
