feat: add colocation weight sync via direct tensor passing #999

Open

NJX-njx wants to merge 2 commits into inclusionAI:main from NJX-njx:feature/colocation-weight-sync

Conversation

@NJX-njx (Contributor) commented Mar 6, 2026

Summary

Adds a tensor weight update mode for colocation deployment where training and inference processes share the same GPU. NCCL cannot communicate between processes on the same device, so this mode bypasses NCCL entirely and passes tensors directly.

Closes #992

Problem

In colocation mode (training + inference on the same GPU), the current NCCL-based weight synchronization fails because NCCL cannot perform inter-process communication between processes sharing the same CUDA device. This is a fundamental limitation identified in #992.

Solution

Introduce a new "tensor" weight update type that:

  1. Training process gathers full tensors (handling DTensor/FSDP sharding and CPU offload)
  2. Chunks tensors by weight_chunked_mem_mb to control memory usage
  3. Passes tensors directly to the inference engine via HTTP serialization
  4. Inference engine loads weights from received tensors without NCCL
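The bucketing in step 2 can be sketched in plain Python. This is an illustrative sketch, not the PR's actual code: `chunk_named_tensors` and `flush` are made-up names, and plain byte counts stand in for real tensors.

```python
# Illustrative sketch of the chunked transfer in step 2. `flush` stands in
# for the HTTP push to the inference engine.

def chunk_named_tensors(named_sizes, chunk_mem_bytes, flush):
    """Group (name, nbytes) pairs into buckets no larger than chunk_mem_bytes.

    A single oversized tensor still gets its own bucket rather than being
    dropped.
    """
    bucket, bucket_size = [], 0
    for name, nbytes in named_sizes:
        if bucket and bucket_size + nbytes > chunk_mem_bytes:
            flush(bucket)          # push the full bucket, then start a new one
            bucket, bucket_size = [], 0
        bucket.append(name)
        bucket_size += nbytes
    if bucket:
        flush(bucket)              # flush the remainder

buckets = []
chunk_named_tensors(
    [("w1", 40), ("w2", 50), ("w3", 30), ("w4", 10)],
    chunk_mem_bytes=100,
    flush=lambda b: buckets.append(list(b)),
)
# w1+w2 fit within 100 bytes (90); adding w3 would exceed it, so a new
# bucket starts.
print(buckets)  # [['w1', 'w2'], ['w3', 'w4']]
```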

Changes

API Layer

  • areal/api/io_struct.py: Added "tensor" to WeightUpdateMeta.type Literal, new from_colocation() factory method
  • areal/api/engine_api.py: Added update_weights_from_tensor() abstract method to InferenceEngine
  • areal/api/cli_args.py: Added "tensor" to weight_update_mode choices

Core Module

  • areal/engine/core/colocation_sync.py: New file - Core colocation sync logic with DTensor handling, chunked transfer, pause/continue generation lifecycle

Training Engines

  • areal/engine/fsdp_engine.py: Tensor dispatch in update_weights(), _update_weights_from_tensor() method
  • areal/engine/megatron_engine.py: Same pattern as FSDP
  • areal/experimental/engine/archon_engine.py: Same pattern

Inference Engines

  • areal/engine/sglang_remote.py: update_weights_from_tensor() + build_tensor_weight_update_requests()
  • areal/engine/vllm_remote.py: Same pattern, targeting /areal_update_weights_tensor endpoint
  • areal/engine/vllm_ext/areal_vllm_server.py: New FastAPI route + EngineCore hook
  • areal/engine/vllm_ext/vllm_worker_extension.py: update_weight_tensor() method

Infrastructure

  • areal/infra/remote_inf_engine.py: Protocol extension + async tensor update function
  • areal/infra/controller/rollout_controller.py: Callback route + controller method

Trainer

  • areal/trainer/rl_trainer.py: "tensor" mode handling with from_colocation()

Tests

  • tests/test_colocation_sync.py: Unit tests for WeightUpdateMeta, helper functions, engine method existence

Usage

# In training config
actor:
  weight_update_mode: tensor  # Use for colocation deployment
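A rough sketch of the trainer-side dispatch this config drives. The `WeightUpdateMeta` below is a simplified stand-in with illustrative fields, not the project's actual class signature.

```python
from dataclasses import dataclass, field
from typing import Literal

# Simplified stand-in for areal's WeightUpdateMeta; real fields differ.
@dataclass
class WeightUpdateMeta:
    type: Literal["disk", "nccl", "tensor"]
    extras: dict = field(default_factory=dict)

    @classmethod
    def from_colocation(cls, weight_chunked_mem_mb: int = 1024, **kwargs):
        # Mirrors the idea of the PR's from_colocation() factory: tensor
        # mode plus a chunk-size budget for the bucketed transfer.
        return cls(
            type="tensor",
            extras={"weight_chunked_mem_mb": weight_chunked_mem_mb, **kwargs},
        )

def make_meta(weight_update_mode: str) -> WeightUpdateMeta:
    # Trainer-side dispatch: "tensor" goes through the colocation factory.
    if weight_update_mode == "tensor":
        return WeightUpdateMeta.from_colocation()
    return WeightUpdateMeta(type=weight_update_mode)

print(make_meta("tensor").type)  # tensor
```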

References

  • verl - tensor-based weight sync approach
  • slime - offload/onload patterns

Commit message

Add 'tensor' weight update mode for colocation deployment where training
and inference processes share the same GPU. NCCL cannot communicate
between processes on the same device, so this bypasses NCCL entirely.

Implementation:
- New WeightUpdateMeta.from_colocation() factory method
- Core colocation_sync module with chunked tensor export
- FSDPEngine, MegatronEngine, ArchonEngine tensor dispatch
- SGLang/vLLM remote engine tensor update endpoints
- RolloutController callback route for tensor updates
- vLLM worker extension for direct weight loading
- RL trainer 'tensor' mode configuration

The approach:
1. Training process gathers full tensors (handling DTensor/FSDP sharding)
2. Tensors are chunked by weight_chunked_mem_mb and passed to inference
3. Inference engine loads weights directly from received tensors

Closes inclusionAI#992
Copilot AI review requested due to automatic review settings March 6, 2026 13:06

@gemini-code-assist (Contributor)
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a key limitation of colocation deployments, where training and inference run on the same GPU, by adding a weight synchronization mechanism based on direct tensor passing. Because NCCL cannot perform inter-process communication within a single device, the new mode bypasses it entirely, enabling reliable weight updates in shared-GPU setups.

Highlights

  • New Weight Update Mode: Introduced a new "tensor" weight update mode specifically designed for colocation deployments where training and inference processes share the same GPU.
  • NCCL Bypass: This new mode completely bypasses NCCL (which cannot communicate between processes on the same device) by directly passing tensors between the training and inference engines.
  • Tensor Handling: The training process now gathers full tensors, handles DTensor/FSDP sharding and CPU offload, and chunks tensors to control memory usage before passing them.
  • Inference Engine Integration: The inference engine is updated to load weights directly from these received tensors, eliminating the need for NCCL-based synchronization.
  • API and Infrastructure Updates: API layers, core modules, various training and inference engines (FSDP, Megatron, Archon, SGLang, vLLM), and infrastructure components (controller, remote engine) have been updated to support this new tensor-based synchronization.


Changelog
  • areal/api/cli_args.py
    • Added "tensor" to weight_update_mode choices.
  • areal/api/engine_api.py
    • Added update_weights_from_tensor() abstract method to InferenceEngine.
  • areal/api/io_struct.py
    • Added "tensor" to WeightUpdateMeta.type Literal.
    • Added a new from_colocation() factory method for tensor-based sync.
  • areal/engine/core/colocation_sync.py
    • Added new file containing the core logic for colocation weight synchronization via direct tensor passing, including DTensor handling and chunked transfer.
  • areal/engine/fsdp_engine.py
    • Implemented _update_weights_from_tensor() method.
    • Updated update_weights() to dispatch to the new tensor-based update method.
  • areal/engine/megatron_engine.py
    • Implemented _update_weights_from_tensor() method.
    • Updated update_weights() to dispatch to the new tensor-based update method.
  • areal/engine/sglang_remote.py
    • Imported torch.
    • Added build_tensor_weight_update_requests() method.
    • Implemented update_weights_from_tensor().
  • areal/engine/vllm_ext/areal_vllm_server.py
    • Added a new FastAPI route /areal_update_weights_tensor.
    • Hooked the new route into EngineCore for tensor-based weight updates.
  • areal/engine/vllm_ext/vllm_worker_extension.py
    • Implemented update_weight_tensor() method to load weights from direct tensor data.
  • areal/engine/vllm_remote.py
    • Imported torch.
    • Added build_tensor_weight_update_requests() method.
    • Implemented update_weights_from_tensor().
  • areal/experimental/engine/archon_engine.py
    • Implemented _update_weights_from_tensor() method.
    • Updated update_weights() to dispatch to the new tensor-based update method.
  • areal/infra/controller/rollout_controller.py
    • Added a new callback route /callback/update_weights_tensor.
    • Implemented update_weights_from_tensor() async method.
  • areal/infra/remote_inf_engine.py
    • Imported torch.
    • Added build_tensor_weight_update_requests() abstract method.
    • Implemented update_weights_from_tensor().
    • Added _update_weights_from_tensor() helper function.
  • areal/trainer/rl_trainer.py
    • Added logic to initialize WeightUpdateMeta with from_colocation() when weight_update_mode is "tensor".
  • areal/utils/logging.py
    • Added "ColocationSync" to the logger color mapping.
  • tests/test_colocation_sync.py
    • Added new unit tests for WeightUpdateMeta tensor type support.
    • Added tests for helper functions in colocation_sync module.
    • Added tests to verify engine dispatch to tensor update methods.

@gemini-code-assist bot left a comment

Code Review

This pull request introduces a significant and well-structured feature: a tensor weight update mode for colocation deployments to bypass NCCL limitations. The implementation is clean, with a dedicated core module for the new synchronization logic that is correctly integrated into the various training and inference engines. The API changes are consistent, and the inclusion of tests is a great addition. I have one minor suggestion to improve efficiency and code clarity in the vLLM worker extension.

Note: Security Review did not run due to the size of the PR.

Comment on lines +164 to +166
    for name in names:
        tensor = weights[name].to(self.model_runner.device)
        self.model_runner.model.load_weights(weights=[(name, tensor)])

medium

The current implementation calls self.model_runner.model.load_weights inside a loop for each tensor. While this works, it's more efficient and readable to collect all tensors into a list and make a single call to load_weights, which is designed to accept a list of weight tuples. This change reduces Python function call overhead and better aligns with the API's intended use.

Suggested change
    weights_to_load = [
        (name, weights[name].to(self.model_runner.device))
        for name in names
    ]
    self.model_runner.model.load_weights(weights=weights_to_load)
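The effect of this suggestion in miniature, using a stub model that counts `load_weights` calls. `StubModel` is an illustrative mock, not vLLM's actual model runner, and floats stand in for tensors.

```python
# Mock showing why batching matters: one load_weights call instead of N.
class StubModel:
    def __init__(self):
        self.calls = 0
        self.loaded = {}

    def load_weights(self, weights):
        # Accepts a list of (name, tensor) tuples, like vLLM models do.
        self.calls += 1
        self.loaded.update(dict(weights))

weights = {"w1": 1.0, "w2": 2.0, "w3": 3.0}  # floats stand in for tensors
names = list(weights)

# Per-tensor loop (original): N separate calls.
per_tensor = StubModel()
for name in names:
    per_tensor.load_weights(weights=[(name, weights[name])])

# Batched (suggested): a single call with the full list.
batched = StubModel()
batched.load_weights(weights=[(name, weights[name]) for name in names])

print(per_tensor.calls, batched.calls)  # 3 1
```

Both variants load identical weights; the batched form just trades N Python-level calls for one.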

Copilot AI left a comment

Pull request overview

Adds a new tensor-based weight synchronization mode intended for colocation deployments (training + inference sharing the same GPU), bypassing NCCL/XCCL by transferring weights as serialized tensors to the inference engine.

Changes:

  • Introduces WeightUpdateMeta.type = "tensor" and corresponding CLI/config handling.
  • Adds a new colocation sync implementation that gathers full tensors (DTensor-aware) and pushes them in chunks to the inference engine.
  • Extends remote inference backends/servers (vLLM, SGLang) and controller plumbing to accept tensor-based weight updates, plus basic unit tests.

Reviewed changes

Copilot reviewed 16 out of 16 changed files in this pull request and generated 8 comments.

Summary per file
File Description
areal/api/io_struct.py Extends WeightUpdateMeta with "tensor" and adds from_colocation() factory.
areal/api/engine_api.py Adds InferenceEngine.update_weights_from_tensor() abstract API.
areal/api/cli_args.py Adds "tensor" to weight_update_mode CLI choices and help text.
areal/engine/core/colocation_sync.py New implementation: gathers full tensors, chunks, pauses/resumes generation, pushes updates via engine API.
areal/engine/fsdp_engine.py Routes meta.type == "tensor" into colocation sync path.
areal/engine/megatron_engine.py Routes meta.type == "tensor" into colocation sync path.
areal/experimental/engine/archon_engine.py Routes meta.type == "tensor" into colocation sync path.
areal/infra/remote_inf_engine.py Adds tensor-weight-update request builder hook and async remote update entrypoint.
areal/engine/vllm_remote.py Implements tensor update request building and forwards the API call.
areal/engine/sglang_remote.py Implements tensor update request building and forwards the API call.
areal/engine/vllm_ext/areal_vllm_server.py Adds FastAPI route + EngineCore hook to receive tensor weight updates.
areal/engine/vllm_ext/vllm_worker_extension.py Adds worker-side method to apply received tensor weights.
areal/infra/controller/rollout_controller.py Adds controller callback route + RPC method for tensor updates.
areal/trainer/rl_trainer.py Constructs WeightUpdateMeta via from_colocation() when weight_update_mode == "tensor".
areal/utils/logging.py Adds ColocationSync logger color mapping.
tests/test_colocation_sync.py Adds basic tests for meta factory and colocation sync helpers/method presence.


class TestMegatronEngineUpdateWeightsTensor:
    """Test MegatronEngine.update_weights dispatches to tensor mode."""

    def test_update_weights_dispatches_tensor_type(self):
Copilot AI Mar 6, 2026

This test imports MegatronEngine unconditionally. megatron-core is declared as an optional extra (see pyproject.toml), so this import can fail in environments where the extra isn't installed, causing unrelated CI failures. Use pytest.importorskip("megatron") (or similar) / conditional skipping around this test.

Suggested change
    def test_update_weights_dispatches_tensor_type(self):
        pytest.importorskip("megatron_core")

Comment on lines +68 to +70
tensor = torch.randn(4, 4)
result = _get_full_tensor(tensor)
assert torch.equal(tensor, result)
Copilot AI Mar 6, 2026

_get_full_tensor may move a CPU tensor to the current accelerator device. On CUDA-enabled test runners this makes torch.equal(cpu_tensor, gpu_tensor) error due to device mismatch. To avoid device-dependent failures, make the test compare on the same device (e.g., move both to CPU before comparing, or construct the input tensor on current_platform.device_type).

Comment on lines +5 to +7
import dataclasses
from concurrent.futures import Future
from unittest.mock import MagicMock, patch
Copilot AI Mar 6, 2026

Unused imports (dataclasses, patch) increase lint noise and can mask real issues in tests. Remove unused imports or use them.

Suggested change
    from concurrent.futures import Future
    from unittest.mock import MagicMock

Comment on lines +1470 to +1473
"""Helper to update weights via direct tensor passing (colocation mode).

Serializes tensor data and sends to inference servers. For colocation, both
processes are on the same GPU, so the data transfer is efficient.
Copilot AI Mar 6, 2026

The docstring claims tensor passing is efficient because both processes share the same GPU, but the current backends serialize tensor.detach().cpu() and send JSON over HTTP, which involves a GPU→CPU copy and network stack overhead. Consider updating this docstring to reflect the actual mechanism (CPU serialization over HTTP) or changing the implementation if true CUDA IPC/zero-copy is intended.

Suggested change
    """Helper to update weights via tensor payloads (colocation mode).

    Serializes tensor data on the CPU and sends it over HTTP to inference
    servers. In colocation mode, services typically run on the same host
    or node, so this CPU serialization plus HTTP transport can still be
    relatively efficient, but it is not GPU zero-copy or CUDA IPC.
    """

@dataclass
class WeightUpdateMeta:
    type: Literal["disk", "nccl", "tensor"]
Copilot AI Mar 6, 2026

WeightUpdateMeta.type is annotated as Literal["disk", "nccl", "tensor"], but the rest of the codebase (and the factory methods in this same class) uses "xccl" (e.g., from_fsdp_xccl/from_megatron_xccl set type="xccl", and engines branch on meta.type == "xccl"). This makes the annotation inaccurate and breaks static checking. Update the Literal (and/or rename the string used everywhere) so it matches the actual runtime values (likely "disk" | "xccl" | "tensor").

Suggested change
    type: Literal["disk", "xccl", "tensor"]

Comment on lines +88 to +136
    dist.barrier(group=cpu_group)

    weight_chunked_mem_size = meta.weight_chunked_mem_mb * 1024 * 1024
    main_rank = dist.get_rank() == 0

    buffer_size = 0
    named_tensors: list[tuple[str, torch.Tensor]] = []

    if get_model_name_parameters is not None:
        param_iterator = get_model_name_parameters()
    else:
        param_iterator = model.named_parameters()

    if use_lora:
        param_iterator = (
            (name, param)
            for name, param in param_iterator
            if param.requires_grad
        )

    for name, param in param_iterator:
        tensor = _get_full_tensor(param)

        # Non-main ranks only help to get full tensor (for FSDP gather)
        if not main_rank:
            continue

        tensor_size = tensor.numel() * tensor.element_size()

        if tensor_size + buffer_size > weight_chunked_mem_size:
            _update_tensor_bucket(rollout_engine, named_tensors)
            buffer_size = 0

        named_tensors.append((name, tensor))
        buffer_size += tensor_size

    # Flush remaining
    if named_tensors:
        _update_tensor_bucket(rollout_engine, named_tensors)

    dist.barrier(group=cpu_group)

    if dist.get_rank() == 0:
        rollout_engine.continue_generation()

    current_platform.synchronize()
    dist.barrier(group=cpu_group)


Copilot AI Mar 6, 2026

pause_generation() is called on rank 0 before the update, but continue_generation() is only called on the success path. If an exception occurs during tensor materialization or the HTTP update, the inference engine can remain paused indefinitely. Wrap the update section in a try/finally (rank 0 only) to guarantee continue_generation() runs (and consider logging/re-raising after resume).

Suggested change
    try:
        dist.barrier(group=cpu_group)

        weight_chunked_mem_size = meta.weight_chunked_mem_mb * 1024 * 1024
        main_rank = dist.get_rank() == 0

        buffer_size = 0
        named_tensors: list[tuple[str, torch.Tensor]] = []

        if get_model_name_parameters is not None:
            param_iterator = get_model_name_parameters()
        else:
            param_iterator = model.named_parameters()

        if use_lora:
            param_iterator = (
                (name, param)
                for name, param in param_iterator
                if param.requires_grad
            )

        for name, param in param_iterator:
            tensor = _get_full_tensor(param)

            # Non-main ranks only help to get full tensor (for FSDP gather)
            if not main_rank:
                continue

            tensor_size = tensor.numel() * tensor.element_size()

            if tensor_size + buffer_size > weight_chunked_mem_size:
                _update_tensor_bucket(rollout_engine, named_tensors)
                buffer_size = 0

            named_tensors.append((name, tensor))
            buffer_size += tensor_size

        # Flush remaining
        if named_tensors:
            _update_tensor_bucket(rollout_engine, named_tensors)

        dist.barrier(group=cpu_group)
        current_platform.synchronize()
        dist.barrier(group=cpu_group)
    finally:
        if dist.get_rank() == 0:
            try:
                rollout_engine.continue_generation()
            except Exception:
                logger.exception("Failed to resume inference engine after weight update")
                raise

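The resume-on-failure pattern this comment recommends, illustrated with a stub engine. `StubRolloutEngine` and `failing_push` are hypothetical stand-ins, not the project's actual classes.

```python
# Minimal sketch of the try/finally pause/resume lifecycle: the inference
# engine must be resumed even if the weight transfer raises mid-way.
class StubRolloutEngine:
    def __init__(self):
        self.paused = False

    def pause_generation(self):
        self.paused = True

    def continue_generation(self):
        self.paused = False

def sync_weights(engine, push_chunks):
    engine.pause_generation()
    try:
        push_chunks()                 # may raise mid-transfer
    finally:
        engine.continue_generation()  # always resume inference

def failing_push():
    raise RuntimeError("HTTP update failed")

engine = StubRolloutEngine()
try:
    sync_weights(engine, failing_push)
except RuntimeError:
    pass
print(engine.paused)  # False: the engine resumed despite the failure
```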
            placements=tensor.placements,
        ).full_tensor()
    else:
        if tensor.device.type == "cpu":
Copilot AI Mar 6, 2026

For non-DTensor parameters on CPU, _get_full_tensor unconditionally moves them to current_platform.device_type. In offload/CPU-param cases this can force large GPU copies only to later serialize back to CPU (the remote backends call tensor.detach().cpu()), increasing both time and peak memory. Consider keeping CPU tensors on CPU here (or making the device move optional/explicit) and only moving when required for a sharded gather.

Suggested change
    # For non-DTensor tensors, avoid unnecessary CPU -> device transfers.
    # Keep CPU tensors on CPU, and only move if the tensor is on a
    # different non-CPU device than the current platform device.
    if tensor.device.type != "cpu" and tensor.device.type != current_platform.device_type:

Comment on lines +258 to +266
    tensor_kwargs = {"allocation_mode": self.allocation_mode}
    if config.actor.use_lora:
        tensor_kwargs.update(
            {
                "use_lora": config.actor.use_lora,
                "lora_name": config.gconfig.lora_name,
                "base_model_name": config.actor.path,
            }
        )
Copilot AI Mar 6, 2026

In tensor mode, use_lora, lora_name, and base_model_name are propagated into WeightUpdateMeta, but the tensor-sync path doesn’t appear to use any of that metadata: update_weights_from_tensor(...) only receives raw tensors, and the vLLM tensor endpoint loads weights directly into the base model (not via the LoRA manager). As written, enabling actor.use_lora with weight_update_mode: tensor is likely to no-op or fail. Either explicitly disallow LoRA in tensor mode (raise early) or extend the tensor update API/endpoint to handle LoRA adapter updates.

Suggested change
    if config.actor.use_lora:
        raise ValueError(
            "LoRA is not supported with weight_update_mode='tensor'. "
            "Please use 'disk' or 'xccl' for LoRA-based weight updates."
        )
    tensor_kwargs = {"allocation_mode": self.allocation_mode}


Development

Successfully merging this pull request may close these issues.

[Question] Is the shared-GPU (colocation) scenario for rollout and actor now supported? The code mentions colocation, but it doesn't seem to cover this scenario.
