8 changes: 4 additions & 4 deletions keras_rs/src/layers/embedding/jax/distributed_embedding.py
@@ -442,7 +442,7 @@ def sparsecore_build(

# Collect all stacked tables.
table_specs = embedding.get_table_specs(feature_specs)
- table_stacks = embedding_utils.get_table_stacks(table_specs)
+ table_stacks = jte_table_stacking.get_table_stacks(table_specs)

# Create variables for all stacked tables and slot variables.
with sparsecore_distribution.scope():
@@ -516,7 +516,7 @@ def _sparsecore_symbolic_preprocess(

# Each stacked-table gets a ShardedCooMatrix.
table_specs = embedding.get_table_specs(self._config.feature_specs)
- table_stacks = embedding_utils.get_table_stacks(table_specs)
+ table_stacks = jte_table_stacking.get_table_stacks(table_specs)
stacked_table_specs = {
stack_name: stack[0].stacked_table_spec
for stack_name, stack in table_stacks.items()
@@ -720,7 +720,7 @@ def _sparsecore_set_tables(self, tables: Mapping[str, ArrayLike]) -> None:
config = self._config
num_table_shards = config.mesh.devices.size * config.num_sc_per_device
table_specs = embedding.get_table_specs(config.feature_specs)
- sharded_tables = embedding_utils.stack_and_shard_tables(
+ sharded_tables = jte_table_stacking.stack_and_shard_tables(
table_specs,
tables,
num_table_shards,
@@ -763,7 +763,7 @@ def _sparsecore_get_embedding_tables(self) -> dict[str, ArrayLike]:

return typing.cast(
dict[str, ArrayLike],
- embedding_utils.unshard_and_unstack_tables(
+ jte_table_stacking.unshard_and_unstack_tables(
table_specs, table_variables, num_table_shards
),
)
keras_rs/src/layers/embedding/jax/distributed_embedding_test.py
@@ -24,7 +24,6 @@
from keras_rs.src.layers.embedding.jax import (
distributed_embedding as jax_distributed_embedding,
)
- from keras_rs.src.layers.embedding.jax import embedding_utils
from keras_rs.src.layers.embedding.jax import test_utils

keras.config.disable_traceback_filtering()
@@ -177,7 +176,7 @@ def test_sharded_matches_unsharded(self):
)
self.assertEqual(actual.shape, expected_shape)

- unsharded_tables = embedding_utils.unshard_and_unstack_tables(
+ unsharded_tables = table_stacking_lib.unshard_and_unstack_tables(
table_specs,
{stacked_table_spec.stack_name: actual},
num_table_shards,
4 changes: 2 additions & 2 deletions keras_rs/src/layers/embedding/jax/embedding_lookup_test.py
@@ -398,7 +398,7 @@ def test_backward_pass(
res=(sharded_samples, sharded_table_and_slot_variables, None),
gradients=activation_grads,
)
- updated_tables_and_slots = embedding_utils.unshard_and_unstack_tables(
+ updated_tables_and_slots = table_stacking.unshard_and_unstack_tables(
table_specs, updated_stacked_tables, num_table_shards
)

@@ -553,7 +553,7 @@ def loss_fn(params, lookups, labels):
lookup_grads = grads["lookup_tables"]

# Recover unstacked and unsharded gradients.
- updated_tables_and_slots = embedding_utils.unshard_and_unstack_tables(
+ updated_tables_and_slots = table_stacking.unshard_and_unstack_tables(
table_specs, lookup_grads, num_table_shards
)

301 changes: 5 additions & 296 deletions keras_rs/src/layers/embedding/jax/embedding_utils.py
@@ -1,16 +1,13 @@
"""Utility functions for manipulating JAX embedding tables and inputs."""

- import collections
import typing
from typing import Any, Mapping, NamedTuple, Sequence, TypeAlias, TypeVar

import jax
import numpy as np
from jax import numpy as jnp
from jax_tpu_embedding.sparsecore.lib.nn import embedding
+ from jax_tpu_embedding.sparsecore.lib.nn import table_stacking
from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import FeatureSpec
- from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import StackedTableSpec
- from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import TableSpec

from keras_rs.src.types import Nested

@@ -34,297 +31,6 @@ class ShardedCooMatrix(NamedTuple):
values: ArrayLike


def _round_up_to_multiple(value: int, multiple: int) -> int:
return ((value + multiple - 1) // multiple) * multiple

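# Illustrative aside (not part of the diff): worked examples of the helper
# above, which rounds up using integer arithmetic:
#   _round_up_to_multiple(10, 8) == 16   # (10 + 8 - 1) // 8 * 8 == 16
#   _round_up_to_multiple(16, 8) == 16   # exact multiples are unchanged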

def _default_stacked_table_spec(
table_spec: TableSpec, num_shards: int, batch_size: int
) -> StackedTableSpec:
return StackedTableSpec(
stack_name=table_spec.name,
stack_vocab_size=_round_up_to_multiple(
table_spec.vocabulary_size, 8 * num_shards
),
stack_embedding_dim=_round_up_to_multiple(table_spec.embedding_dim, 8),
optimizer=table_spec.optimizer,
combiner=table_spec.combiner,
total_sample_count=batch_size,
max_ids_per_partition=table_spec.max_ids_per_partition,
max_unique_ids_per_partition=table_spec.max_unique_ids_per_partition,
)


def _get_stacked_table_spec(
table_spec: TableSpec, num_shards: int, batch_size: int = 0
) -> StackedTableSpec:
return table_spec.stacked_table_spec or _default_stacked_table_spec(
table_spec, num_shards, batch_size
)


def pad_table(
table_spec: TableSpec,
table_values: jax.Array,
num_shards: int,
pad_value: jnp.float32 = jnp.nan,
) -> jax.Array:
"""Adds appropriate padding to a table to prepare for stacking.

Args:
table_spec: Table specification describing the table to pad.
table_values: Table values array to pad.
num_shards: Number of shards in the table (typically
`global_device_count * num_sc_per_device`).
pad_value: Value to use for padding.

Returns:
Padded table values.
"""
vocabulary_size = table_spec.vocabulary_size
embedding_dim = table_spec.embedding_dim
padded_vocabulary_size = _round_up_to_multiple(
vocabulary_size, 8 * num_shards
)
stack_embedding_dim = _get_stacked_table_spec(
table_spec, num_shards
).stack_embedding_dim
return jnp.pad(
table_values,
(
(0, padded_vocabulary_size - vocabulary_size),
(0, stack_embedding_dim - embedding_dim),
),
constant_values=pad_value,
)

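# Illustrative aside (not part of the diff): a shape-level sketch of
# pad_table, assuming a hypothetical spec with vocabulary_size=10,
# embedding_dim=3, no explicit stacked_table_spec, and num_shards=4. The
# vocabulary pads up to a multiple of 8 * num_shards = 32 and the
# embedding dim to a multiple of 8:
#   pad_table(spec, jnp.zeros((10, 3)), num_shards=4).shape == (32, 8)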

def _stack_and_shard_table(
stacked_table: jax.Array,
table_spec: TableSpec,
table: jax.Array,
num_shards: int,
pad_value: jnp.float32,
) -> jax.Array:
"""Stacks and shards a single table for use in sparsecore lookups."""
padded_values = pad_table(table_spec, table, num_shards, pad_value)
sharded_padded_vocabulary_size = padded_values.shape[0] // num_shards
stack_embedding_dim = stacked_table.shape[-1]

# Mod-shard vocabulary across devices.
sharded_values = jnp.swapaxes(
padded_values.reshape(-1, num_shards, stack_embedding_dim),
0,
1,
)

# Rotate shards.
setting_in_stack = table_spec.setting_in_stack
rotated_values = jnp.roll(
sharded_values, setting_in_stack.shard_rotation, axis=0
)

# Insert table into the stack.
table_row = setting_in_stack.row_offset_in_shard
stacked_table = stacked_table.at[
:, table_row : (table_row + sharded_padded_vocabulary_size), :
].set(rotated_values)

return stacked_table

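# Illustrative aside (not part of the diff): the mod-shard step above on a
# toy (4, 1) table with num_shards=2; row i lands on shard i % num_shards:
#   padded = jnp.arange(4.0).reshape(4, 1)                 # rows r0..r3
#   sharded = jnp.swapaxes(padded.reshape(-1, 2, 1), 0, 1)
#   # sharded[0] holds [r0, r2]; sharded[1] holds [r1, r3]
# jnp.roll then rotates which physical shard holds which slice, as set by
# setting_in_stack.shard_rotation.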

def stack_and_shard_tables(
table_specs: Nested[TableSpec],
tables: Nested[ArrayLike],
num_shards: int,
pad_value: jnp.float32 = jnp.nan,
) -> dict[str, Nested[jax.Array]]:
"""Stacks and shards tables for use in sparsecore lookups.

Args:
table_specs: Nested collection of unstacked table specifications.
tables: Table values corresponding to the table_specs.
num_shards: Number of shards in the table (typically
`global_device_count * num_sc_per_device`).
pad_value: Value to use for padding.

Returns:
A mapping of stacked table names to stacked table values.
"""

# Gather stacked table information.
stacked_table_map: dict[
str,
tuple[StackedTableSpec, list[TableSpec]],
] = {}

def collect_stacked_tables(table_spec: TableSpec) -> None:
stacked_table_spec = _get_stacked_table_spec(table_spec, num_shards)
stacked_table_name = stacked_table_spec.stack_name
if stacked_table_name not in stacked_table_map:
stacked_table_map[stacked_table_name] = (stacked_table_spec, [])
stacked_table_map[stacked_table_name][1].append(table_spec)

_ = jax.tree.map(collect_stacked_tables, table_specs)

table_map: dict[str, Nested[jax.Array]] = {}

def collect_tables(table_spec: TableSpec, table: Nested[jax.Array]) -> None:
table_map[table_spec.name] = table

_ = jax.tree.map(collect_tables, table_specs, tables)

stacked_tables: dict[str, Nested[jax.Array]] = {}
for (
stacked_table_spec,
table_specs,
) in stacked_table_map.values():
stack_vocab_size = stacked_table_spec.stack_vocab_size
sharded_vocab_size = stack_vocab_size // num_shards
stack_embedding_dim = stacked_table_spec.stack_embedding_dim

# Allocate initial buffer. The stacked table will be divided among
# shards by splitting the vocabulary dimension:
# [ v, e ] -> [s, v/s, e]
stacked_table_tree = jax.tree.map(
lambda _: jnp.zeros(
# pylint: disable-next=cell-var-from-loop, used only in loop body.
shape=(num_shards, sharded_vocab_size, stack_embedding_dim),
dtype=jnp.float32,
),
table_map[table_specs[0].name],
)

for table_spec in table_specs:
table_tree = table_map[table_spec.name]
stacked_table_tree = jax.tree.map(
lambda stacked_table, table: _stack_and_shard_table(
# pylint: disable-next=cell-var-from-loop, used only in loop body.
stacked_table,
# pylint: disable-next=cell-var-from-loop, used only in loop body.
table_spec,
table,
num_shards,
pad_value,
),
stacked_table_tree,
table_tree,
)

stacked_tables[stacked_table_spec.stack_name] = stacked_table_tree

return stacked_tables

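# Illustrative aside (not part of the diff): every stack returned above has
# shape (num_shards, stack_vocab_size // num_shards, stack_embedding_dim),
# e.g. a stack with stack_vocab_size=64 and stack_embedding_dim=8 on 4
# shards comes back as a (4, 16, 8) array, one padded-and-rotated
# vocabulary slice per shard.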

def _unshard_and_unstack_table(
table_spec: TableSpec,
stacked_table_tree: Nested[jax.Array],
num_shards: int,
) -> Nested[jax.Array]:
"""Unshards and unstacks a single table."""
vocabulary_size = table_spec.vocabulary_size
embedding_dim = table_spec.embedding_dim

def _unshard_and_unstack_single_table(
table_spec: TableSpec, stacked_table: jax.Array
) -> jax.Array:
stack_embedding_dim = stacked_table.shape[-1]

# Maybe re-shape in case it was flattened.
stacked_table = stacked_table.reshape(
num_shards, -1, stack_embedding_dim
)
sharded_vocabulary_size = (
_round_up_to_multiple(vocabulary_size, 8 * num_shards) // num_shards
)

# Extract padded values from the stacked table.
setting_in_stack = table_spec.setting_in_stack
row = setting_in_stack.row_offset_in_shard
padded_values = stacked_table[
:, row : (row + sharded_vocabulary_size), :
]

# Un-rotate shards.
padded_values = jnp.roll(
padded_values, -setting_in_stack.shard_rotation, axis=0
)

# Un-mod-shard.
padded_values = jnp.swapaxes(padded_values, 0, 1).reshape(
-1, stack_embedding_dim
)

# Un-pad.
return padded_values[:vocabulary_size, :embedding_dim]

output: Nested[jax.Array] = jax.tree.map(
lambda stacked_table: _unshard_and_unstack_single_table(
table_spec, stacked_table
),
stacked_table_tree,
)
return output


def unshard_and_unstack_tables(
table_specs: Nested[TableSpec],
stacked_tables: Mapping[str, Nested[jax.Array]],
num_shards: int,
) -> Nested[jax.Array]:
"""Unshards and unstacks a collection of tables.

Args:
table_specs: Nested collection of unstacked table specifications.
stacked_tables: Mapping of stacked table names to stacked table values.
num_shards: Number of shards in the table (typically
`global_device_count * num_sc_per_device`).

Returns:
A mapping of table names to unstacked table values.
"""
output: Nested[jax.Array] = jax.tree.map(
lambda table_spec: _unshard_and_unstack_table(
table_spec,
stacked_tables[
_get_stacked_table_spec(table_spec, num_shards=1).stack_name
],
num_shards,
),
table_specs,
)
return output

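# Illustrative aside (not part of the diff): a round-trip property sketch.
# Unsharding strips padding and undoes the rotation and mod-sharding, so
# for matching specs and shard counts:
#   stacked = stack_and_shard_tables(table_specs, tables, num_shards)
#   restored = unshard_and_unstack_tables(table_specs, stacked, num_shards)
#   # restored values equal the original (unpadded) tables.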

def get_table_stacks(
table_specs: Nested[TableSpec],
) -> dict[str, list[TableSpec]]:
"""Extracts lists of tables that are stacked together.

Args:
table_specs: Nested collection of table specifications.

Returns:
A mapping of stacked table names to lists of table specifications for
each stack.
"""
stacked_table_specs: dict[str, list[TableSpec]] = collections.defaultdict(
list
)
flat_table_specs, _ = jax.tree.flatten(table_specs)
for table_spec in flat_table_specs:
table_spec = typing.cast(TableSpec, table_spec)
stacked_table_spec = table_spec.stacked_table_spec
if stacked_table_spec is not None:
stacked_table_specs[stacked_table_spec.stack_name].append(
table_spec
)
else:
stacked_table_specs[table_spec.name].append(table_spec)

return stacked_table_specs

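# Illustrative aside (not part of the diff): with hypothetical specs where
# tables "a" and "b" were stacked under the name "a_b" and "c" was never
# stacked,
#   get_table_stacks({"a": spec_a, "b": spec_b, "c": spec_c})
# returns {"a_b": [spec_a, spec_b], "c": [spec_c]}; an unstacked table
# falls back to its own name as a single-element stack.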

def convert_to_numpy(
ragged_or_dense: np.ndarray[Any, Any] | Sequence[Sequence[Any]] | Any,
dtype: Any,
@@ -522,7 +228,10 @@ def collect_tokens_and_weights(
for table_name in tables_names:
shard_ends = preprocessed_inputs.lhs_row_pointers[table_name]
shard_starts = np.concatenate(
- [np.asarray([0]), _round_up_to_multiple(shard_ends[:-1], 8)]
+ [
+ np.asarray([0]),
+ table_stacking._next_largest_multiple(shard_ends[:-1], 8),
+ ]
)
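# Illustrative aside (not part of the diff): each shard's start pointer is
# the previous shard's end pointer rounded up to the next multiple of 8,
# e.g. shard_ends of [5, 13, 20] yields shard_starts of [0, 8, 16].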
out[table_name] = ShardedCooMatrix(
shard_starts=shard_starts,