feat: add colocation weight sync via direct tensor passing #999
NJX-njx wants to merge 2 commits into inclusionAI:main
Add a `tensor` weight update mode for colocation deployment, where training and inference processes share the same GPU. NCCL cannot communicate between processes on the same device, so this mode bypasses NCCL entirely.

Implementation:

- New `WeightUpdateMeta.from_colocation()` factory method
- Core `colocation_sync` module with chunked tensor export
- `FSDPEngine`, `MegatronEngine`, `ArchonEngine` tensor dispatch
- SGLang/vLLM remote engine tensor update endpoints
- `RolloutController` callback route for tensor updates
- vLLM worker extension for direct weight loading
- RL trainer `tensor` mode configuration

The approach:

1. The training process gathers full tensors (handling DTensor/FSDP sharding).
2. Tensors are chunked by `weight_chunked_mem_mb` and passed to inference.
3. The inference engine loads weights directly from the received tensors.

Closes inclusionAI#992
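The chunking in step 2 can be sketched as follows. This is a simplified, hypothetical illustration of the bucketing idea only: `chunk_named_tensors` and the sizes are stand-ins, not the PR's actual helper, which works on live tensors and flushes each bucket to the inference engine as it fills.

```python
def chunk_named_tensors(named_sizes, budget_bytes):
    """Group (name, size_in_bytes) pairs into buckets of at most budget_bytes.

    A single tensor larger than the budget still gets its own bucket,
    mirroring the flush-before-append pattern described in the PR.
    """
    buckets, bucket, used = [], [], 0
    for name, size in named_sizes:
        if bucket and used + size > budget_bytes:
            buckets.append(bucket)
            bucket, used = [], 0
        bucket.append(name)
        used += size
    if bucket:
        buckets.append(bucket)
    return buckets


mb = 1024 * 1024
# Three 40 MB tensors with a 100 MB budget -> two buckets.
print(chunk_named_tensors([("a", 40 * mb), ("b", 40 * mb), ("c", 40 * mb)], 100 * mb))
# -> [['a', 'b'], ['c']]
```

The byte budget here plays the role of `weight_chunked_mem_mb * 1024 * 1024` in the real sync loop.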
Summary of Changes (Gemini Code Assist)

This pull request addresses a limitation in colocation deployments, where training and inference occur on the same GPU, by implementing a new weight synchronization mechanism. It introduces a direct tensor passing approach, circumventing NCCL's inability to perform inter-process communication within a single device. This enables efficient and reliable weight updates in such shared-GPU environments, improving the performance and stability of colocation setups.
Code Review
This pull request introduces a significant and well-structured feature: a tensor weight update mode for colocation deployments to bypass NCCL limitations. The implementation is clean, with a dedicated core module for the new synchronization logic that is correctly integrated into the various training and inference engines. The API changes are consistent, and the inclusion of tests is a great addition. I have one minor suggestion to improve efficiency and code clarity in the vLLM worker extension.
Note: Security Review did not run due to the size of the PR.
```python
for name in names:
    tensor = weights[name].to(self.model_runner.device)
    self.model_runner.model.load_weights(weights=[(name, tensor)])
```
The current implementation calls self.model_runner.model.load_weights inside a loop for each tensor. While this works, it's more efficient and readable to collect all tensors into a list and make a single call to load_weights, which is designed to accept a list of weight tuples. This change reduces Python function call overhead and better aligns with the API's intended use.
Suggested change:

```python
weights_to_load = [(name, weights[name].to(self.model_runner.device))
                   for name in names]
self.model_runner.model.load_weights(weights=weights_to_load)
```
Pull request overview
Adds a new tensor-based weight synchronization mode intended for colocation deployments (training + inference sharing the same GPU), bypassing NCCL/XCCL by transferring weights as serialized tensors to the inference engine.
Changes:
- Introduces `WeightUpdateMeta.type = "tensor"` and corresponding CLI/config handling.
- Adds a new colocation sync implementation that gathers full tensors (DTensor-aware) and pushes them in chunks to the inference engine.
- Extends remote inference backends/servers (vLLM, SGLang) and controller plumbing to accept tensor-based weight updates, plus basic unit tests.
Reviewed changes
Copilot reviewed 16 out of 16 changed files in this pull request and generated 8 comments.
Summary per file:
| File | Description |
|---|---|
| `areal/api/io_struct.py` | Extends `WeightUpdateMeta` with `"tensor"` and adds `from_colocation()` factory. |
| `areal/api/engine_api.py` | Adds `InferenceEngine.update_weights_from_tensor()` abstract API. |
| `areal/api/cli_args.py` | Adds `"tensor"` to `weight_update_mode` CLI choices and help text. |
| `areal/engine/core/colocation_sync.py` | New implementation: gathers full tensors, chunks, pauses/resumes generation, pushes updates via engine API. |
| `areal/engine/fsdp_engine.py` | Routes `meta.type == "tensor"` into colocation sync path. |
| `areal/engine/megatron_engine.py` | Routes `meta.type == "tensor"` into colocation sync path. |
| `areal/experimental/engine/archon_engine.py` | Routes `meta.type == "tensor"` into colocation sync path. |
| `areal/infra/remote_inf_engine.py` | Adds tensor-weight-update request builder hook and async remote update entrypoint. |
| `areal/engine/vllm_remote.py` | Implements tensor update request building and forwards the API call. |
| `areal/engine/sglang_remote.py` | Implements tensor update request building and forwards the API call. |
| `areal/engine/vllm_ext/areal_vllm_server.py` | Adds FastAPI route + EngineCore hook to receive tensor weight updates. |
| `areal/engine/vllm_ext/vllm_worker_extension.py` | Adds worker-side method to apply received tensor weights. |
| `areal/infra/controller/rollout_controller.py` | Adds controller callback route + RPC method for tensor updates. |
| `areal/trainer/rl_trainer.py` | Constructs `WeightUpdateMeta` via `from_colocation()` when `weight_update_mode == "tensor"`. |
| `areal/utils/logging.py` | Adds `ColocationSync` logger color mapping. |
| `tests/test_colocation_sync.py` | Adds basic tests for meta factory and colocation sync helpers/method presence. |
```python
class TestMegatronEngineUpdateWeightsTensor:
    """Test MegatronEngine.update_weights dispatches to tensor mode."""

    def test_update_weights_dispatches_tensor_type(self):
```
This test imports MegatronEngine unconditionally. megatron-core is declared as an optional extra (see pyproject.toml), so this import can fail in environments where the extra isn't installed, causing unrelated CI failures. Use pytest.importorskip("megatron") (or similar) / conditional skipping around this test.
Suggested change:

```python
    def test_update_weights_dispatches_tensor_type(self):
        pytest.importorskip("megatron_core")
```
```python
tensor = torch.randn(4, 4)
result = _get_full_tensor(tensor)
assert torch.equal(tensor, result)
```
_get_full_tensor may move a CPU tensor to the current accelerator device. On CUDA-enabled test runners this makes torch.equal(cpu_tensor, gpu_tensor) error due to device mismatch. To avoid device-dependent failures, make the test compare on the same device (e.g., move both to CPU before comparing, or construct the input tensor on current_platform.device_type).
```python
import dataclasses
from concurrent.futures import Future
from unittest.mock import MagicMock, patch
```
Unused imports (`dataclasses`, `patch`) add lint noise and can mask real issues in tests. Remove the unused imports or use them.
Suggested change:

```python
from concurrent.futures import Future
from unittest.mock import MagicMock
```
```python
"""Helper to update weights via direct tensor passing (colocation mode).

Serializes tensor data and sends to inference servers. For colocation, both
processes are on the same GPU, so the data transfer is efficient.
```
The docstring claims tensor passing is efficient because both processes share the same GPU, but the current backends serialize tensor.detach().cpu() and send JSON over HTTP, which involves a GPU→CPU copy and network stack overhead. Consider updating this docstring to reflect the actual mechanism (CPU serialization over HTTP) or changing the implementation if true CUDA IPC/zero-copy is intended.
Suggested change:

```python
"""Helper to update weights via tensor payloads (colocation mode).

Serializes tensor data on the CPU and sends it over HTTP to inference
servers. In colocation mode, services typically run on the same host
or node, so this CPU serialization plus HTTP transport can still be
relatively efficient, but it is not GPU zero-copy or CUDA IPC.
```
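To make the transport the review comment describes concrete, here is a minimal, hypothetical sketch of a CPU-side payload round trip. Base64 inside JSON is a stand-in for whatever wire format the backends actually use; `encode_weight_payload`/`decode_weight_payload` are illustrative names, not functions from this PR.

```python
import base64
import json


def encode_weight_payload(named_buffers):
    """Encode name -> raw bytes (stand-ins for tensor.detach().cpu() buffers)
    into a JSON string suitable for an HTTP POST body."""
    return json.dumps(
        {name: base64.b64encode(buf).decode("ascii") for name, buf in named_buffers.items()}
    )


def decode_weight_payload(payload):
    """Inverse of encode_weight_payload, as the inference server would run it."""
    return {name: base64.b64decode(data) for name, data in json.loads(payload).items()}


# Round trip: the decoded bytes match the originals exactly.
blob = encode_weight_payload({"layer.weight": b"\x00\x01\x02\x03"})
assert decode_weight_payload(blob) == {"layer.weight": b"\x00\x01\x02\x03"}
```

Note the GPU→CPU copy and base64 expansion both cost time and memory, which is why the review asks for the docstring to stop implying zero-copy transfer.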
```diff
 @dataclass
 class WeightUpdateMeta:
-    type: Literal["disk", "nccl"]
+    type: Literal["disk", "nccl", "tensor"]
```
WeightUpdateMeta.type is annotated as Literal["disk", "nccl", "tensor"], but the rest of the codebase (and the factory methods in this same class) uses "xccl" (e.g., from_fsdp_xccl/from_megatron_xccl set type="xccl", and engines branch on meta.type == "xccl"). This makes the annotation inaccurate and breaks static checking. Update the Literal (and/or rename the string used everywhere) so it matches the actual runtime values (likely "disk" | "xccl" | "tensor").
Suggested change:

```python
    type: Literal["disk", "xccl", "tensor"]
```
```python
dist.barrier(group=cpu_group)

weight_chunked_mem_size = meta.weight_chunked_mem_mb * 1024 * 1024
main_rank = dist.get_rank() == 0

buffer_size = 0
named_tensors: list[tuple[str, torch.Tensor]] = []

if get_model_name_parameters is not None:
    param_iterator = get_model_name_parameters()
else:
    param_iterator = model.named_parameters()

if use_lora:
    param_iterator = (
        (name, param)
        for name, param in param_iterator
        if param.requires_grad
    )

for name, param in param_iterator:
    tensor = _get_full_tensor(param)

    # Non-main ranks only help to get full tensor (for FSDP gather)
    if not main_rank:
        continue

    tensor_size = tensor.numel() * tensor.element_size()

    if tensor_size + buffer_size > weight_chunked_mem_size:
        _update_tensor_bucket(rollout_engine, named_tensors)
        buffer_size = 0

    named_tensors.append((name, tensor))
    buffer_size += tensor_size

# Flush remaining
if named_tensors:
    _update_tensor_bucket(rollout_engine, named_tensors)

dist.barrier(group=cpu_group)

if dist.get_rank() == 0:
    rollout_engine.continue_generation()

current_platform.synchronize()
dist.barrier(group=cpu_group)
```
pause_generation() is called on rank 0 before the update, but continue_generation() is only called on the success path. If an exception occurs during tensor materialization or the HTTP update, the inference engine can remain paused indefinitely. Wrap the update section in a try/finally (rank 0 only) to guarantee continue_generation() runs (and consider logging/re-raising after resume).
Suggested change:

```python
try:
    dist.barrier(group=cpu_group)

    weight_chunked_mem_size = meta.weight_chunked_mem_mb * 1024 * 1024
    main_rank = dist.get_rank() == 0

    buffer_size = 0
    named_tensors: list[tuple[str, torch.Tensor]] = []

    if get_model_name_parameters is not None:
        param_iterator = get_model_name_parameters()
    else:
        param_iterator = model.named_parameters()

    if use_lora:
        param_iterator = (
            (name, param)
            for name, param in param_iterator
            if param.requires_grad
        )

    for name, param in param_iterator:
        tensor = _get_full_tensor(param)

        # Non-main ranks only help to get full tensor (for FSDP gather)
        if not main_rank:
            continue

        tensor_size = tensor.numel() * tensor.element_size()

        if tensor_size + buffer_size > weight_chunked_mem_size:
            _update_tensor_bucket(rollout_engine, named_tensors)
            buffer_size = 0

        named_tensors.append((name, tensor))
        buffer_size += tensor_size

    # Flush remaining
    if named_tensors:
        _update_tensor_bucket(rollout_engine, named_tensors)

    dist.barrier(group=cpu_group)
    current_platform.synchronize()
    dist.barrier(group=cpu_group)
finally:
    if dist.get_rank() == 0:
        try:
            rollout_engine.continue_generation()
        except Exception:
            logger.exception("Failed to resume inference engine after weight update")
            raise
```
```python
        placements=tensor.placements,
    ).full_tensor()
else:
    if tensor.device.type == "cpu":
```
For non-DTensor parameters on CPU, _get_full_tensor unconditionally moves them to current_platform.device_type. In offload/CPU-param cases this can force large GPU copies only to later serialize back to CPU (the remote backends call tensor.detach().cpu()), increasing both time and peak memory. Consider keeping CPU tensors on CPU here (or making the device move optional/explicit) and only moving when required for a sharded gather.
Suggested change:

```python
    # For non-DTensor tensors, avoid unnecessary CPU -> device transfers.
    # Keep CPU tensors on CPU, and only move if the tensor is on a
    # different non-CPU device than the current platform device.
    if tensor.device.type != "cpu" and tensor.device.type != current_platform.device_type:
```
```python
tensor_kwargs = {"allocation_mode": self.allocation_mode}
if config.actor.use_lora:
    tensor_kwargs.update(
        {
            "use_lora": config.actor.use_lora,
            "lora_name": config.gconfig.lora_name,
            "base_model_name": config.actor.path,
        }
    )
```
In tensor mode, use_lora, lora_name, and base_model_name are propagated into WeightUpdateMeta, but the tensor-sync path doesn’t appear to use any of that metadata: update_weights_from_tensor(...) only receives raw tensors, and the vLLM tensor endpoint loads weights directly into the base model (not via the LoRA manager). As written, enabling actor.use_lora with weight_update_mode: tensor is likely to no-op or fail. Either explicitly disallow LoRA in tensor mode (raise early) or extend the tensor update API/endpoint to handle LoRA adapter updates.
Suggested change:

```python
if config.actor.use_lora:
    raise ValueError(
        "LoRA is not supported with weight_update_mode='tensor'. "
        "Please use 'disk' or 'xccl' for LoRA-based weight updates."
    )
tensor_kwargs = {"allocation_mode": self.allocation_mode}
```
Summary
Adds a `tensor` weight update mode for colocation deployment, where training and inference processes share the same GPU. NCCL cannot communicate between processes on the same device, so this mode bypasses NCCL entirely and passes tensors directly.

Closes #992
Problem
In colocation mode (training + inference on the same GPU), the current NCCL-based weight synchronization fails because NCCL cannot perform inter-process communication between processes sharing the same CUDA device. This is a fundamental limitation identified in #992.
Solution

Introduce a new `"tensor"` weight update type that, among other things, chunks tensor transfers by `weight_chunked_mem_mb` to control memory usage.

Changes
API Layer

- `areal/api/io_struct.py`: Added `"tensor"` to the `WeightUpdateMeta.type` Literal; new `from_colocation()` factory method
- `areal/api/engine_api.py`: Added `update_weights_from_tensor()` abstract method to `InferenceEngine`
- `areal/api/cli_args.py`: Added `"tensor"` to `weight_update_mode` choices

Core Module

- `areal/engine/core/colocation_sync.py`: New file with the core colocation sync logic: DTensor handling, chunked transfer, pause/continue generation lifecycle

Training Engines

- `areal/engine/fsdp_engine.py`: Tensor dispatch in `update_weights()`, `_update_weights_from_tensor()` method
- `areal/engine/megatron_engine.py`: Same pattern as FSDP
- `areal/experimental/engine/archon_engine.py`: Same pattern

Inference Engines

- `areal/engine/sglang_remote.py`: `update_weights_from_tensor()` + `build_tensor_weight_update_requests()`
- `areal/engine/vllm_remote.py`: Same pattern, targeting the `/areal_update_weights_tensor` endpoint
- `areal/engine/vllm_ext/areal_vllm_server.py`: New FastAPI route + EngineCore hook
- `areal/engine/vllm_ext/vllm_worker_extension.py`: `update_weight_tensor()` method

Infrastructure

- `areal/infra/remote_inf_engine.py`: Protocol extension + async tensor update function
- `areal/infra/controller/rollout_controller.py`: Callback route + controller method

Trainer

- `areal/trainer/rl_trainer.py`: `"tensor"` mode handling with `from_colocation()`

Tests

- `tests/test_colocation_sync.py`: Unit tests for `WeightUpdateMeta`, helper functions, engine method existence

Usage

References