Skip to content

feat: Add true multi-node inference support for device_map="auto" (#1890)#4066

Open
Saurav-Gupta-13 wants to merge 1 commit into
huggingface:mainfrom
Saurav-Gupta-13:feature/multi-node-device-map
Open

feat: Add true multi-node inference support for device_map="auto" (#1890)#4066
Saurav-Gupta-13 wants to merge 1 commit into
huggingface:mainfrom
Saurav-Gupta-13:feature/multi-node-device-map

Conversation

@Saurav-Gupta-13
Copy link
Copy Markdown

Resolves #1890

Description

Currently, passing device_map="auto" inside accelerate launch across multiple machines causes each node to independently attempt to load the entire chunk of the model onto its local GPUs, ignoring the rest of the cluster and causing OOM crashes.

This PR completely overhauls the core modeling architecture to natively support PyTorch Distributed RPC (torch.distributed.rpc), enabling true multi-node device_map="auto" pipelining without requiring any dummy inputs or complex PiPPy tracing.

Key Architectural Changes

  1. Global GPU Memory Mapping: Modified get_max_memory() to detect SPMD environments and use torch.distributed.all_gather_object to fetch memory constraints from every GPU on every node in the cluster. This builds a continuous global mapping (e.g. GPUs 0..15 for a 2-node 8x GPU cluster).
  2. Meta Layer Interception: In load_checkpoint_and_dispatch, we rewrite the device map so that layers assigned to a remote node are mapped to the "meta" device locally. This prevents local processes from wasting RAM loading weights for layers they don't own.
  3. PyTorch Distributed RPC Hooking: In dispatch_model, we detect remote layers and automatically spin up a torch.distributed.rpc agent network. Instead of attaching a standard local AlignDevicesHook, we dynamically rewrite the remote module's .forward() method to execute an rpc.rpc_sync network request to the target rank.
  4. Master-Worker SPMD Mode: To avoid duplicate execution and deadlocks, all worker nodes (Ranks > 0) automatically have their main .forward loop disabled, transforming them into pure background RPC execution servers serving the Rank 0 head node.

This integration is completely transparent to the user. No model tracing or prepare_pippy configuration is required. Simply pass device_map="auto", and the entire cluster functions as a single unified GPU.

@Saurav-Gupta-13
Copy link
Copy Markdown
Author

Hey everyone! 👋

I noticed this was still a huge pain point for the community, so I just opened a PR that natively solves the multi-node device_map="auto" issue without requiring any complex prepare_pippy tracing or dummy inputs!

I completely overhauled the core modeling architecture to natively leverage PyTorch Distributed RPC (torch.distributed.rpc).

Here is how it works under the hood:

  1. Global GPU Memory Mapping: get_max_memory() now detects an SPMD environment and uses torch.distributed.all_gather_object to build a continuous global mapping of every GPU across all nodes.
  2. Meta Layer Interception: We rewrite the calculated device_map so that layers assigned to a remote node are mapped to the "meta" device locally. This completely prevents local processes from OOMing or wasting RAM loading weights for layers they don't own.
  3. RPC Network Routing: Instead of attaching a standard local memory movement hook (AlignDevicesHook), we dynamically rewrite the remote module's .forward() method to execute an rpc.rpc_sync network request directly to the target node.
  4. Master-Worker SPMD Mode: To prevent deadlocks and duplicate batch processing, all worker nodes (Ranks > 0) automatically have their main .forward loop disabled, transforming them into pure background RPC execution servers serving the Rank 0 head node.

The end result is completely transparent to the user. You simply pass device_map="auto", and your cluster acts as one giant GPU.

Would love to get your thoughts and feedback on the implementation! Let me know if there are any specific tests you'd like me to run. 🚀

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Feature request: support for multi-node, multi-GPU distributed inference

1 participant