diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 1f5916fa..bd69a0be 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -9,6 +9,7 @@ import numpy as np from jax import numpy as jnp from jax.experimental import layout as jax_layout +from jax.experimental import multihost_utils from jax_tpu_embedding.sparsecore.lib.nn import embedding from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec from jax_tpu_embedding.sparsecore.lib.nn import ( @@ -600,31 +601,26 @@ def _sparsecore_preprocess( # underlying stacked tables specs in the feature specs. # Aggregate stats across all processes/devices via pmax. - num_local_cpu_devices = jax.local_device_count("cpu") - - def pmax_aggregate(x: Any) -> Any: - if not hasattr(x, "ndim"): - x = np.array(x) - tiled_x = np.tile(x, (num_local_cpu_devices, *([1] * x.ndim))) - return jax.pmap( - lambda y: jax.lax.pmax(y, "all_cpus"), # type: ignore[no-untyped-call] - axis_name="all_cpus", - backend="cpu", - )(tiled_x)[0] - - full_stats = jax.tree.map(pmax_aggregate, stats) + all_stats = multihost_utils.process_allgather(stats) + aggregated_stats = jax.tree.map( + lambda x: jnp.max(x, axis=0), all_stats + ) # Check if stats changed enough to warrant action. stacked_table_specs = embedding.get_stacked_table_specs( self._config.feature_specs ) changed = any( - np.max(full_stats.max_ids_per_partition[stack_name]) + np.max(aggregated_stats.max_ids_per_partition[stack_name]) > spec.max_ids_per_partition - or np.max(full_stats.max_unique_ids_per_partition[stack_name]) + or np.max( + aggregated_stats.max_unique_ids_per_partition[stack_name] + ) > spec.max_unique_ids_per_partition or ( - np.max(full_stats.required_buffer_size_per_sc[stack_name]) + np.max( + aggregated_stats.required_buffer_size_per_sc[stack_name] + ) * num_sc_per_device ) > (spec.suggested_coo_buffer_size_per_device or 0) @@ -634,7 +630,9 @@ def pmax_aggregate(x: Any) -> Any: # Update configuration and repeat preprocessing if stats changed. if changed: embedding.update_preprocessing_parameters( - self._config.feature_specs, full_stats, num_sc_per_device + self._config.feature_specs, + aggregated_stats, + num_sc_per_device, ) # Re-execute preprocessing with consistent input statistics.