Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions keras_rs/src/layers/embedding/jax/distributed_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def _add_table_variable(
table_specs: Sequence[embedding_spec.TableSpec],
num_shards: int,
add_slot_variables: bool,
) -> tuple[keras.Variable, tuple[keras.Variable, ...] | None]:
) -> embedding.EmbeddingVariables:
stacked_table_spec = typing.cast(
embedding_spec.StackedTableSpec, table_specs[0].stacked_table_spec
)
Expand Down Expand Up @@ -334,7 +334,7 @@ def _add_table_variable(
slot_initializers, slot_variables
)

return table_variable, slot_variables
return embedding.EmbeddingVariables(table_variable, slot_variables)

@keras_utils.no_automatic_dependency_tracking
def _sparsecore_init(
Expand Down Expand Up @@ -738,8 +738,8 @@ def _sparsecore_set_tables(self, tables: Mapping[str, ArrayLike]) -> None:
# Assign stacked table variables to the device values.
keras.tree.map_structure_up_to(
device_tables,
lambda table_and_slot_variables,
table_value: table_and_slot_variables[0].assign(table_value),
lambda embedding_variables,
table_value: embedding_variables.table.assign(table_value),
self._table_and_slot_variables,
device_tables,
)
Expand All @@ -754,8 +754,10 @@ def _sparsecore_get_embedding_tables(self) -> dict[str, ArrayLike]:

# Extract only the table variables, not the gradient slot variables.
table_variables = {
name: jax.device_get(table_and_slots[0].value)
for name, table_and_slots in self._table_and_slot_variables.items()
name: jax.device_get(embedding_variables.table.value)
for name, embedding_variables in (
self._table_and_slot_variables.items()
)
}

return typing.cast(
Expand Down