RFC: Parameter Server Training for the Keras JAX backend (Keras-JAX-PST)
We are considering implementing Parameter Server Training (PST) for the Keras JAX backend. It would aim to provide a scalable and performant solution for asynchronous ML training. In PST, the training cluster contains M workers and N parameter servers, and the master copy of the training variables (and embeddings) is placed on the parameter servers. (Background: Scaling Distributed Machine Learning with the Parameter Server)
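To make the asynchronous pattern concrete, here is a minimal, self-contained sketch of the parameter-server idea: workers pull possibly-stale parameters, compute gradients on their own data shard, and push updates back without global synchronization. This is illustrative only; the `ParameterServer`, `pull`, and `push` names are hypothetical and do not reflect the proposed Keras/JAX implementation.

```python
# Illustrative sketch of asynchronous parameter-server training (not the proposed API).
import threading
import numpy as np

class ParameterServer:
    """Holds the master copy of the parameters and applies incoming updates."""

    def __init__(self, dim, lr=0.1):
        self.params = np.zeros(dim)
        self.lr = lr
        self._lock = threading.Lock()

    def pull(self):
        with self._lock:
            return self.params.copy()

    def push(self, grad):
        with self._lock:
            self.params -= self.lr * grad

def worker(server, x, y, steps):
    """One asynchronous worker: linear-regression gradients on its own data shard."""
    for _ in range(steps):
        w = server.pull()                       # may be stale relative to other workers
        grad = 2 * x.T @ (x @ w - y) / len(y)   # gradient of mean squared error
        server.push(grad)                       # applied without waiting for other workers

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    true_w = np.array([1.0, -2.0, 3.0])
    server = ParameterServer(dim=3)
    threads = []
    for _ in range(4):                          # M = 4 workers, each with its own shard
        x = rng.normal(size=(256, 3))
        y = x @ true_w
        t = threading.Thread(target=worker, args=(server, x, y, 200))
        t.start()
        threads.append(t)
    for t in threads:
        t.join()
    print("learned parameters:", server.pull())  # converges close to true_w
```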
The advantages of PST include:
- Large embeddings are sharded across multiple parameter servers. This enables use of embedding tables that exceed the local memory available to a single device, or the HBM available to all accelerators combined (see the sketch after this list).
- Training can be scaled across multiple CPUs (data parallelism) for increased speed even without accelerator hardware (GPUs, TPUs).
- PST uses asynchronous training, which is robust to individual worker failures, preemptions, and restarts, and can remain performant even when worker availability guarantees are weak.
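As a rough illustration of the first point, an embedding table too large for any single host can be row-sharded across the N parameter servers, with lookups routed to the owning shard. The modulo-based layout and `lookup` routine below are assumptions for illustration, not the proposed design.

```python
# Illustrative only: row-sharding a large embedding table across N parameter servers.
import numpy as np

NUM_SERVERS = 4            # N parameter servers
VOCAB_SIZE = 1_000_000     # rows of the embedding table
EMBED_DIM = 64

# Each server owns the rows whose id modulo NUM_SERVERS matches its index.
shards = [
    np.zeros((len(range(s, VOCAB_SIZE, NUM_SERVERS)), EMBED_DIM), dtype=np.float32)
    for s in range(NUM_SERVERS)
]

def lookup(ids):
    """Gather embedding rows for a batch of ids from the shards that own them."""
    out = np.empty((len(ids), EMBED_DIM), dtype=np.float32)
    for i, idx in enumerate(ids):
        server = idx % NUM_SERVERS   # which server owns this row
        row = idx // NUM_SERVERS     # local row offset on that server
        out[i] = shards[server][row]
    return out

batch = lookup(np.array([3, 17, 999_999]))
print(batch.shape)  # (3, 64)
```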
We hope to make PST a convenient option in an end-to-end recommendation solution with Keras.
We want your feedback on whether this would be of value to you.
Please comment below.