diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py
index 5f2ea459..1f5916fa 100644
--- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py
+++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py
@@ -442,7 +442,7 @@ def sparsecore_build(
 
         # Collect all stacked tables.
         table_specs = embedding.get_table_specs(feature_specs)
-        table_stacks = embedding_utils.get_table_stacks(table_specs)
+        table_stacks = jte_table_stacking.get_table_stacks(table_specs)
 
         # Create variables for all stacked tables and slot variables.
         with sparsecore_distribution.scope():
@@ -516,7 +516,7 @@ def _sparsecore_symbolic_preprocess(
 
         # Each stacked-table gets a ShardedCooMatrix.
         table_specs = embedding.get_table_specs(self._config.feature_specs)
-        table_stacks = embedding_utils.get_table_stacks(table_specs)
+        table_stacks = jte_table_stacking.get_table_stacks(table_specs)
         stacked_table_specs = {
             stack_name: stack[0].stacked_table_spec
             for stack_name, stack in table_stacks.items()
@@ -720,7 +720,7 @@ def _sparsecore_set_tables(self, tables: Mapping[str, ArrayLike]) -> None:
         config = self._config
         num_table_shards = config.mesh.devices.size * config.num_sc_per_device
         table_specs = embedding.get_table_specs(config.feature_specs)
-        sharded_tables = embedding_utils.stack_and_shard_tables(
+        sharded_tables = jte_table_stacking.stack_and_shard_tables(
             table_specs,
             tables,
             num_table_shards,
@@ -763,7 +763,7 @@ def _sparsecore_get_embedding_tables(self) -> dict[str, ArrayLike]:
 
         return typing.cast(
             dict[str, ArrayLike],
-            embedding_utils.unshard_and_unstack_tables(
+            jte_table_stacking.unshard_and_unstack_tables(
                 table_specs, table_variables, num_table_shards
             ),
         )
diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py b/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py
index 89cbb171..1dd3525d 100644
--- a/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py
+++ b/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py
@@ -24,7 +24,6 @@
 from keras_rs.src.layers.embedding.jax import (
     distributed_embedding as jax_distributed_embedding,
 )
-from keras_rs.src.layers.embedding.jax import embedding_utils
 from keras_rs.src.layers.embedding.jax import test_utils
 
 keras.config.disable_traceback_filtering()
@@ -177,7 +176,7 @@ def test_sharded_matches_unsharded(self):
         )
         self.assertEqual(actual.shape, expected_shape)
 
-        unsharded_tables = embedding_utils.unshard_and_unstack_tables(
+        unsharded_tables = table_stacking_lib.unshard_and_unstack_tables(
             table_specs,
             {stacked_table_spec.stack_name: actual},
             num_table_shards,
diff --git a/keras_rs/src/layers/embedding/jax/embedding_lookup_test.py b/keras_rs/src/layers/embedding/jax/embedding_lookup_test.py
index 7816928c..d41b2535 100644
--- a/keras_rs/src/layers/embedding/jax/embedding_lookup_test.py
+++ b/keras_rs/src/layers/embedding/jax/embedding_lookup_test.py
@@ -398,7 +398,7 @@ def test_backward_pass(
             res=(sharded_samples, sharded_table_and_slot_variables, None),
             gradients=activation_grads,
         )
-        updated_tables_and_slots = embedding_utils.unshard_and_unstack_tables(
+        updated_tables_and_slots = table_stacking.unshard_and_unstack_tables(
             table_specs, updated_stacked_tables, num_table_shards
         )
 
@@ -553,7 +553,7 @@ def loss_fn(params, lookups, labels):
         lookup_grads = grads["lookup_tables"]
 
         # Recover unstacked and unsharded gradients.
-        updated_tables_and_slots = embedding_utils.unshard_and_unstack_tables(
+        updated_tables_and_slots = table_stacking.unshard_and_unstack_tables(
             table_specs, lookup_grads, num_table_shards
         )
 
diff --git a/keras_rs/src/layers/embedding/jax/embedding_utils.py b/keras_rs/src/layers/embedding/jax/embedding_utils.py
index 8a2d1cd3..80e342dd 100644
--- a/keras_rs/src/layers/embedding/jax/embedding_utils.py
+++ b/keras_rs/src/layers/embedding/jax/embedding_utils.py
@@ -1,16 +1,13 @@
 """Utility functions for manipulating JAX embedding tables and inputs."""
 
 import collections
-import typing
 from typing import Any, Mapping, NamedTuple, Sequence, TypeAlias, TypeVar
 
 import jax
 import numpy as np
-from jax import numpy as jnp
 from jax_tpu_embedding.sparsecore.lib.nn import embedding
+from jax_tpu_embedding.sparsecore.lib.nn import table_stacking
 from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import FeatureSpec
-from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import StackedTableSpec
-from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import TableSpec
 
 from keras_rs.src.types import Nested
 
@@ -34,297 +31,6 @@ class ShardedCooMatrix(NamedTuple):
     values: ArrayLike
 
 
-def _round_up_to_multiple(value: int, multiple: int) -> int:
-    return ((value + multiple - 1) // multiple) * multiple
-
-
-def _default_stacked_table_spec(
-    table_spec: TableSpec, num_shards: int, batch_size: int
-) -> StackedTableSpec:
-    return StackedTableSpec(
-        stack_name=table_spec.name,
-        stack_vocab_size=_round_up_to_multiple(
-            table_spec.vocabulary_size, 8 * num_shards
-        ),
-        stack_embedding_dim=_round_up_to_multiple(table_spec.embedding_dim, 8),
-        optimizer=table_spec.optimizer,
-        combiner=table_spec.combiner,
-        total_sample_count=batch_size,
-        max_ids_per_partition=table_spec.max_ids_per_partition,
-        max_unique_ids_per_partition=table_spec.max_unique_ids_per_partition,
-    )
-
-
-def _get_stacked_table_spec(
-    table_spec: TableSpec, num_shards: int, batch_size: int = 0
-) -> StackedTableSpec:
-    return table_spec.stacked_table_spec or _default_stacked_table_spec(
-        table_spec, num_shards, batch_size
-    )
-
-
-def pad_table(
-    table_spec: TableSpec,
-    table_values: jax.Array,
-    num_shards: int,
-    pad_value: jnp.float32 = jnp.nan,
-) -> jax.Array:
-    """Adds appropriate padding to a table to prepare for stacking.
-
-    Args:
-        table_spec: Table specification describing the table to pad.
-        table_values: Table values array to pad.
-        num_shards: Number of shards in the table (typically
-            `global_device_count * num_sc_per_device`).
-        pad_value: Value to use for padding.
-
-    Returns:
-        Padded table values.
-    """
-    vocabulary_size = table_spec.vocabulary_size
-    embedding_dim = table_spec.embedding_dim
-    padded_vocabulary_size = _round_up_to_multiple(
-        vocabulary_size, 8 * num_shards
-    )
-    stack_embedding_dim = _get_stacked_table_spec(
-        table_spec, num_shards
-    ).stack_embedding_dim
-    return jnp.pad(
-        table_values,
-        (
-            (0, padded_vocabulary_size - vocabulary_size),
-            (0, stack_embedding_dim - embedding_dim),
-        ),
-        constant_values=pad_value,
-    )
-
-
-def _stack_and_shard_table(
-    stacked_table: jax.Array,
-    table_spec: TableSpec,
-    table: jax.Array,
-    num_shards: int,
-    pad_value: jnp.float32,
-) -> jax.Array:
-    """Stacks and shards a single table for use in sparsecore lookups."""
-    padded_values = pad_table(table_spec, table, num_shards, pad_value)
-    sharded_padded_vocabulary_size = padded_values.shape[0] // num_shards
-    stack_embedding_dim = stacked_table.shape[-1]
-
-    # Mod-shard vocabulary across devices.
-    sharded_values = jnp.swapaxes(
-        padded_values.reshape(-1, num_shards, stack_embedding_dim),
-        0,
-        1,
-    )
-
-    # Rotate shards.
-    setting_in_stack = table_spec.setting_in_stack
-    rotated_values = jnp.roll(
-        sharded_values, setting_in_stack.shard_rotation, axis=0
-    )
-
-    # Insert table into the stack.
-    table_row = setting_in_stack.row_offset_in_shard
-    stacked_table = stacked_table.at[
-        :, table_row : (table_row + sharded_padded_vocabulary_size), :
-    ].set(rotated_values)
-
-    return stacked_table
-
-
-def stack_and_shard_tables(
-    table_specs: Nested[TableSpec],
-    tables: Nested[ArrayLike],
-    num_shards: int,
-    pad_value: jnp.float32 = jnp.nan,
-) -> dict[str, Nested[jax.Array]]:
-    """Stacks and shards tables for use in sparsecore lookups.
-
-    Args:
-        table_specs: Nested collection of unstacked table specifications.
-        tables: Table values corresponding to the table_specs.
-        num_shards: Number of shards in the table (typically
-            `global_device_count * num_sc_per_device`).
-        pad_value: Value to use for padding.
-
-    Returns:
-        A mapping of stacked table names to stacked table values.
-    """
-
-    # Gather stacked table information.
-    stacked_table_map: dict[
-        str,
-        tuple[StackedTableSpec, list[TableSpec]],
-    ] = {}
-
-    def collect_stacked_tables(table_spec: TableSpec) -> None:
-        stacked_table_spec = _get_stacked_table_spec(table_spec, num_shards)
-        stacked_table_name = stacked_table_spec.stack_name
-        if stacked_table_name not in stacked_table_map:
-            stacked_table_map[stacked_table_name] = (stacked_table_spec, [])
-        stacked_table_map[stacked_table_name][1].append(table_spec)
-
-    _ = jax.tree.map(collect_stacked_tables, table_specs)
-
-    table_map: dict[str, Nested[jax.Array]] = {}
-
-    def collect_tables(table_spec: TableSpec, table: Nested[jax.Array]) -> None:
-        table_map[table_spec.name] = table
-
-    _ = jax.tree.map(collect_tables, table_specs, tables)
-
-    stacked_tables: dict[str, Nested[jax.Array]] = {}
-    for (
-        stacked_table_spec,
-        table_specs,
-    ) in stacked_table_map.values():
-        stack_vocab_size = stacked_table_spec.stack_vocab_size
-        sharded_vocab_size = stack_vocab_size // num_shards
-        stack_embedding_dim = stacked_table_spec.stack_embedding_dim
-
-        # Allocate initial buffer. The stacked table will be divided among
-        # shards by splitting the vocabulary dimension:
-        #   [ v, e ] -> [s, v/s, e]
-        stacked_table_tree = jax.tree.map(
-            lambda _: jnp.zeros(
-                # pylint: disable-next=cell-var-from-loop, used only in loop body.
-                shape=(num_shards, sharded_vocab_size, stack_embedding_dim),
-                dtype=jnp.float32,
-            ),
-            table_map[table_specs[0].name],
-        )
-
-        for table_spec in table_specs:
-            table_tree = table_map[table_spec.name]
-            stacked_table_tree = jax.tree.map(
-                lambda stacked_table, table: _stack_and_shard_table(
-                    # pylint: disable-next=cell-var-from-loop, used only in loop body.
-                    stacked_table,
-                    # pylint: disable-next=cell-var-from-loop, used only in loop body.
-                    table_spec,
-                    table,
-                    num_shards,
-                    pad_value,
-                ),
-                stacked_table_tree,
-                table_tree,
-            )
-
-        stacked_tables[stacked_table_spec.stack_name] = stacked_table_tree
-
-    return stacked_tables
-
-
-def _unshard_and_unstack_table(
-    table_spec: TableSpec,
-    stacked_table_tree: Nested[jax.Array],
-    num_shards: int,
-) -> Nested[jax.Array]:
-    """Unshards and unstacks a single table."""
-    vocabulary_size = table_spec.vocabulary_size
-    embedding_dim = table_spec.embedding_dim
-
-    def _unshard_and_unstack_single_table(
-        table_spec: TableSpec, stacked_table: jax.Array
-    ) -> jax.Array:
-        stack_embedding_dim = stacked_table.shape[-1]
-
-        # Maybe re-shape in case it was flattened.
-        stacked_table = stacked_table.reshape(
-            num_shards, -1, stack_embedding_dim
-        )
-        sharded_vocabulary_size = (
-            _round_up_to_multiple(vocabulary_size, 8 * num_shards) // num_shards
-        )
-
-        # Extract padded values from the stacked table.
-        setting_in_stack = table_spec.setting_in_stack
-        row = setting_in_stack.row_offset_in_shard
-        padded_values = stacked_table[
-            :, row : (row + sharded_vocabulary_size), :
-        ]
-
-        # Un-rotate shards.
-        padded_values = jnp.roll(
-            padded_values, -setting_in_stack.shard_rotation, axis=0
-        )
-
-        # Un-mod-shard.
-        padded_values = jnp.swapaxes(padded_values, 0, 1).reshape(
-            -1, stack_embedding_dim
-        )
-
-        # Un-pad.
-        return padded_values[:vocabulary_size, :embedding_dim]
-
-    output: Nested[jax.Array] = jax.tree.map(
-        lambda stacked_table: _unshard_and_unstack_single_table(
-            table_spec, stacked_table
-        ),
-        stacked_table_tree,
-    )
-    return output
-
-
-def unshard_and_unstack_tables(
-    table_specs: Nested[TableSpec],
-    stacked_tables: Mapping[str, Nested[jax.Array]],
-    num_shards: int,
-) -> Nested[jax.Array]:
-    """Unshards and unstacks a collection of tables.
-
-    Args:
-        table_specs: Nested collection of unstacked table specifications.
-        stacked_tables: Mapping of stacked table names to stacked table values.
-        num_shards: Number of shards in the table (typically
-            `global_device_count * num_sc_per_device`).
-
-    Returns:
-        A mapping of table names to unstacked table values.
-    """
-    output: Nested[jax.Array] = jax.tree.map(
-        lambda table_spec: _unshard_and_unstack_table(
-            table_spec,
-            stacked_tables[
-                _get_stacked_table_spec(table_spec, num_shards=1).stack_name
-            ],
-            num_shards,
-        ),
-        table_specs,
-    )
-    return output
-
-
-def get_table_stacks(
-    table_specs: Nested[TableSpec],
-) -> dict[str, list[TableSpec]]:
-    """Extracts lists of tables that are stacked together.
-
-    Args:
-        table_specs: Nested collection of table specifications.
-
-    Returns:
-        A mapping of stacked table names to lists of table specifications for
-        each stack.
- """ - stacked_table_specs: dict[str, list[TableSpec]] = collections.defaultdict( - list - ) - flat_table_specs, _ = jax.tree.flatten(table_specs) - for table_spec in flat_table_specs: - table_spec = typing.cast(TableSpec, table_spec) - stacked_table_spec = table_spec.stacked_table_spec - if stacked_table_spec is not None: - stacked_table_specs[stacked_table_spec.stack_name].append( - table_spec - ) - else: - stacked_table_specs[table_spec.name].append(table_spec) - - return stacked_table_specs - - def convert_to_numpy( ragged_or_dense: np.ndarray[Any, Any] | Sequence[Sequence[Any]] | Any, dtype: Any, @@ -522,7 +228,10 @@ def collect_tokens_and_weights( for table_name in tables_names: shard_ends = preprocessed_inputs.lhs_row_pointers[table_name] shard_starts = np.concatenate( - [np.asarray([0]), _round_up_to_multiple(shard_ends[:-1], 8)] + [ + np.asarray([0]), + table_stacking._next_largest_multiple(shard_ends[:-1], 8), + ] ) out[table_name] = ShardedCooMatrix( shard_starts=shard_starts, diff --git a/keras_rs/src/layers/embedding/jax/test_utils.py b/keras_rs/src/layers/embedding/jax/test_utils.py index b051781e..29eab24f 100644 --- a/keras_rs/src/layers/embedding/jax/test_utils.py +++ b/keras_rs/src/layers/embedding/jax/test_utils.py @@ -9,10 +9,10 @@ from jax import numpy as jnp from jax_tpu_embedding.sparsecore.lib.nn import embedding from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec +from jax_tpu_embedding.sparsecore.lib.nn import table_stacking from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import FeatureSpec from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import TableSpec -from keras_rs.src.layers.embedding.jax import embedding_utils from keras_rs.src.layers.embedding.jax.embedding_utils import FeatureSamples from keras_rs.src.types import Nested @@ -316,7 +316,7 @@ def stack_shard_and_put_tables( num_shards: int, sharding: jax.sharding.Sharding, ) -> dict[str, embedding.EmbeddingVariables]: - sharded_tables = embedding_utils.stack_and_shard_tables( + sharded_tables = table_stacking.stack_and_shard_tables( table_specs, tables, num_shards ) output: dict[str, embedding.EmbeddingVariables] = jax.device_put( @@ -336,8 +336,11 @@ def get_unshard_and_unstack_tables( num_shards: int, ) -> Nested[jax.Array]: sharded_tables = jax.device_get(sharded_tables) - return embedding_utils.unshard_and_unstack_tables( - table_specs, sharded_tables, num_shards + return typing.cast( + Nested[jax.Array], + table_stacking.unshard_and_unstack_tables( + table_specs, sharded_tables, num_shards + ), )