feat: Add true multi-node inference support for device_map="auto" (#1890) #4066
Open
Saurav-Gupta-13 wants to merge 1 commit into
Hey everyone! 👋 I noticed this was still a huge pain point for the community, so I just opened a PR that natively solves the multi-node `device_map="auto"` problem. I completely overhauled the core modeling architecture to natively leverage PyTorch Distributed RPC (`torch.distributed.rpc`). The description below covers how it works under the hood.

The end result is completely transparent to the user: you simply pass `device_map="auto"`. Would love to get your thoughts and feedback on the implementation! Let me know if there are any specific tests you'd like me to run. 🚀
Resolves #1890
Description
Currently, passing `device_map="auto"` inside `accelerate launch` across multiple machines causes each node to independently attempt to load the entire model onto its local GPUs, ignoring the rest of the cluster and causing OOM crashes.

This PR overhauls the core modeling architecture to natively support PyTorch Distributed RPC (`torch.distributed.rpc`), enabling true multi-node `device_map="auto"` pipelining without requiring any dummy inputs or complex PiPPy tracing.

Key Architectural Changes

- `get_max_memory()` now detects SPMD environments and uses `torch.distributed.all_gather_object` to fetch memory constraints from every GPU on every node in the cluster. This builds a continuous global mapping (e.g. GPUs `0..15` for a 2-node cluster with 8 GPUs per node).
- In `load_checkpoint_and_dispatch`, we rewrite the device map so that layers assigned to a remote node are mapped to the `"meta"` device locally. This prevents local processes from wasting RAM loading weights for layers they don't own.
- In `dispatch_model`, we detect remote layers and automatically spin up a `torch.distributed.rpc` agent network. Instead of attaching a standard local `AlignDevicesHook`, we dynamically rewrite the remote module's `.forward()` method to execute an `rpc.rpc_sync` network request to the target rank.
- Ranks other than Rank 0 run with their `.forward` loop disabled, transforming them into pure background RPC execution servers serving the Rank 0 head node.

This integration is transparent to the user. No model tracing or `prepare_pippy` configuration is required. Simply pass `device_map="auto"`, and the entire cluster functions as a single unified GPU.
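To make the global GPU numbering and the `"meta"` rewrite concrete, here is a minimal pure-Python sketch. It is not the code from this PR: `build_global_max_memory` and `localize_device_map` are hypothetical helper names, the input list stands in for what `torch.distributed.all_gather_object` would collect, and the sketch assumes one gathered entry per node, the same GPU count on every node, and that locally owned layers are renumbered to local GPU indices.

```python
from typing import Dict, List, Union

Device = Union[int, str]


def build_global_max_memory(per_node_memory: List[Dict[int, int]]) -> Dict[int, int]:
    """Flatten per-node {local_gpu: bytes} dicts into one {global_gpu: bytes} dict.

    `per_node_memory` stands in for the list an all-gather would produce,
    ordered by node index (an assumption of this sketch).
    """
    global_memory: Dict[int, int] = {}
    next_index = 0
    for node_memory in per_node_memory:
        for local_gpu in sorted(node_memory):
            global_memory[next_index] = node_memory[local_gpu]
            next_index += 1
    return global_memory


def localize_device_map(
    device_map: Dict[str, Device], node_index: int, gpus_per_node: int
) -> Dict[str, Device]:
    """Rewrite a global device map from the point of view of one node.

    Layers on this node's GPUs get a local GPU index; layers owned by another
    node become "meta" so their weights are never loaded here.
    """
    first = node_index * gpus_per_node
    last = first + gpus_per_node
    local_map: Dict[str, Device] = {}
    for name, device in device_map.items():
        if isinstance(device, int):
            local_map[name] = device - first if first <= device < last else "meta"
        else:
            # Non-GPU targets such as "cpu" or "disk" are left untouched.
            local_map[name] = device
    return local_map


# Two nodes with two GPUs each (memory in GiB for readability).
gathered = [{0: 80, 1: 80}, {0: 40, 1: 40}]
global_memory = build_global_max_memory(gathered)

global_map = {"embed": 0, "layers.0": 1, "layers.1": 2, "lm_head": 3}
node1_map = localize_device_map(global_map, node_index=1, gpus_per_node=2)
```

In this example, node 1 only materializes `layers.1` and `lm_head`; everything owned by node 0 stays on `"meta"`.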