Merged
keras_rs/src/layers/embedding/jax/distributed_embedding.py (32 changes: 15 additions, 17 deletions)
@@ -9,6 +9,7 @@
 import numpy as np
 from jax import numpy as jnp
 from jax.experimental import layout as jax_layout
+from jax.experimental import multihost_utils
 from jax_tpu_embedding.sparsecore.lib.nn import embedding
 from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
 from jax_tpu_embedding.sparsecore.lib.nn import (
@@ -600,31 +601,26 @@ def _sparsecore_preprocess(
         # underlying stacked tables specs in the feature specs.

         # Aggregate stats across all processes/devices via pmax.
-        num_local_cpu_devices = jax.local_device_count("cpu")
-
-        def pmax_aggregate(x: Any) -> Any:
-            if not hasattr(x, "ndim"):
-                x = np.array(x)
-            tiled_x = np.tile(x, (num_local_cpu_devices, *([1] * x.ndim)))
-            return jax.pmap(
-                lambda y: jax.lax.pmax(y, "all_cpus"),  # type: ignore[no-untyped-call]
-                axis_name="all_cpus",
-                backend="cpu",
-            )(tiled_x)[0]
-
-        full_stats = jax.tree.map(pmax_aggregate, stats)
+        all_stats = multihost_utils.process_allgather(stats)
+        aggregated_stats = jax.tree.map(
+            lambda x: jnp.max(x, axis=0), all_stats
+        )
Comment on lines 603 to 607 (Contributor review, severity: medium):

The original code used jax.pmap with jax.lax.pmax to aggregate statistics across devices. This change replaces that with multihost_utils.process_allgather followed by jnp.max, a good simplification that leverages a dedicated utility for multihost aggregation.

However, it's important to ensure that process_allgather correctly handles the data sharding and aggregation across multiple hosts in your specific environment. Double-check that the resulting aggregated_stats contains the expected maximum values across all processes.
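For context, here is a minimal sketch (not part of this PR) of what the new aggregation path does; the local_stats dict and its values are hypothetical stand-ins for the real preprocessing statistics:

import jax
import jax.numpy as jnp
from jax.experimental import multihost_utils

# Hypothetical per-process statistics gathered during preprocessing.
local_stats = {"max_ids_per_partition": jnp.array([7, 3, 5])}

# process_allgather stacks each leaf along a new leading axis of size
# jax.process_count(), so every host receives every host's stats.
all_stats = multihost_utils.process_allgather(local_stats)

# Taking the max over that leading axis yields the element-wise maximum
# across processes, matching what jax.lax.pmax over the "all_cpus" pmap
# axis computed in the old code.
aggregated_stats = jax.tree.map(lambda x: jnp.max(x, axis=0), all_stats)

On a single host the gathered leading axis should have size one, so the reduction is an identity, which makes the equivalence easy to check locally.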

         # Check if stats changed enough to warrant action.
         stacked_table_specs = embedding.get_stacked_table_specs(
             self._config.feature_specs
         )
         changed = any(
-            np.max(full_stats.max_ids_per_partition[stack_name])
+            np.max(aggregated_stats.max_ids_per_partition[stack_name])
             > spec.max_ids_per_partition
-            or np.max(full_stats.max_unique_ids_per_partition[stack_name])
+            or np.max(
+                aggregated_stats.max_unique_ids_per_partition[stack_name]
+            )
             > spec.max_unique_ids_per_partition
             or (
-                np.max(full_stats.required_buffer_size_per_sc[stack_name])
+                np.max(
+                    aggregated_stats.required_buffer_size_per_sc[stack_name]
+                )
                 * num_sc_per_device
             )
             > (spec.suggested_coo_buffer_size_per_device or 0)
@@ -634,7 +630,9 @@ def pmax_aggregate(x: Any) -> Any:
         # Update configuration and repeat preprocessing if stats changed.
         if changed:
             embedding.update_preprocessing_parameters(
-                self._config.feature_specs, full_stats, num_sc_per_device
+                self._config.feature_specs,
+                aggregated_stats,
+                num_sc_per_device,
             )

         # Re-execute preprocessing with consistent input statistics.