
RFC: Parameter Server Training for the Keras JAX backend (Keras-JAX-PST) #20753

@SeeForTwo


We are considering implementing Parameter Server Training (PST) for the Keras JAX backend. It would aim to provide a scalable and performant solution for asynchronous ML training. In PST, the training cluster contains M workers and N parameter servers, and the master copy of the training variables (and embeddings) is placed on the parameter servers. (Background: Scaling Distributed Machine Learning with the Parameter Server)
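
To make the architecture concrete, here is a minimal sketch of the pull/compute/push cycle written in plain JAX. It is not the proposed Keras API; the `ParameterServer` class, the `worker_step` helper, and the linear model are hypothetical stand-ins, and real PST would run the workers as separate processes pushing updates asynchronously over the network.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Simple linear model standing in for an arbitrary Keras model.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

grad_fn = jax.jit(jax.grad(loss_fn))

class ParameterServer:
    """Holds the master copy of the variables and applies pushed gradients."""
    def __init__(self, params, lr=0.1):
        self.params, self.lr = params, lr

    def pull(self):
        return self.params

    def push(self, grads):
        # Asynchronous SGD: each update is applied as it arrives,
        # without waiting for the other workers.
        self.params = jax.tree_util.tree_map(
            lambda p, g: p - self.lr * g, self.params, grads)

def worker_step(server, x_shard, y_shard):
    params = server.pull()                     # pull current variables
    grads = grad_fn(params, x_shard, y_shard)  # local gradient computation
    server.push(grads)                         # push the update back

# Toy run: one server, two "workers" (represented here by two data shards).
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (8, 4))
y = jnp.sum(x, axis=1, keepdims=True)
server = ParameterServer({"w": jnp.zeros((4, 1)), "b": jnp.zeros((1,))})
for _ in range(10):
    for shard in (slice(0, 4), slice(4, 8)):
        worker_step(server, x[shard], y[shard])
```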

The advantages of PST include:

  • Large embeddings are sharded across multiple parameter servers. This enables the use of embeddings that exceed the local memory available to a single device or the HBM available to all accelerators combined (see the sharding sketch after this list).
  • Training can be scaled across multiple CPUs (data parallelism) for increased speed even without accelerator hardware (GPUs, TPUs).
  • PST uses asynchronous training, which is robust to individual worker failures, preemptions, and restarts, and can be more performant than synchronous training when machines come with low availability guarantees.
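
As an illustration of the first point, here is a minimal sketch, assuming row-wise sharding, of how an embedding table could be split across N parameter servers so that no single device has to hold the whole table. The shard layout and the `lookup` helper are hypothetical and not part of any existing Keras API.

```python
import jax.numpy as jnp

VOCAB, DIM, NUM_SERVERS = 1_000_000, 16, 4
ROWS_PER_SHARD = VOCAB // NUM_SERVERS

# Each "server" owns a contiguous block of embedding rows.
shards = [jnp.zeros((ROWS_PER_SHARD, DIM), dtype=jnp.float32)
          for _ in range(NUM_SERVERS)]

def lookup(ids):
    """Route each id to the server that owns its row and gather the vectors."""
    ids = jnp.asarray(ids)
    server_idx = ids // ROWS_PER_SHARD   # which parameter server owns each id
    local_idx = ids % ROWS_PER_SHARD     # row offset within that server's shard
    rows = [shards[int(s)][int(l)] for s, l in zip(server_idx, local_idx)]
    return jnp.stack(rows)

vectors = lookup([3, 999_999, 250_000])  # ids resolve to servers 0, 3, 1
```

In a real deployment each shard would live in the host memory of a different parameter server, and the per-id routing and gathers would happen over RPC rather than local indexing.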

We hope to make PST a convenient option for end-to-end recommendation solutions built with Keras.

We want your feedback on whether this would be of value to you.

Please comment below.
