tests/sandbox/weight_sync/main.py (205 additions, 0 deletions)
@@ -0,0 +1,205 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Weight Sync Sandbox

A minimal sandbox that exercises only the weight synchronization mechanism
between RLTrainer and Generator.

Usage:
python -m tests.sandbox.weight_sync.main --config tests/sandbox/weight_sync/qwen3_1_7b.yaml
"""

import asyncio
import time

import torch
import torchstore as ts
from forge.actors._torchstore_utils import rdma_enabled
from forge.actors.generator import Generator
from forge.actors.trainer import RLTrainer
from forge.controller.provisioner import init_provisioner, shutdown
from forge.observability.metric_actors import get_or_create_metric_logger
from forge.types import LauncherConfig, ProvisionerConfig
from forge.util.config import parse
from omegaconf import DictConfig
from vllm.transformers_utils.tokenizer import get_tokenizer


def generate_random_batch(
local_batch_size: int,
request_len: int,
response_len: int,
vocab_size: int = 32000,
device: str = "cuda",
dp_size: int = 1,
):
"""
Generate random input and target tensors for a single training step.
Creates one batch per data parallel rank.
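    Example shapes, with local_batch_size=4, request_len=64, response_len=64, and
    dp_size=1: inputs[0]["tokens"] is (4, 128); targets[0]["response"],
    "ref_logprobs", and "padding_mask" are (4, 64); "advantages" is (4, 1).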
"""
inputs = []
targets = []

# Create one batch for each data parallel rank
for _ in range(dp_size):
request = torch.randint(
1,
vocab_size,
(local_batch_size, request_len),
dtype=torch.long,
device=device,
)
response = torch.randint(
1,
vocab_size,
(local_batch_size, response_len),
dtype=torch.long,
device=device,
)

        # Random padding mask (entries are True with probability ~0.9)
        padding_mask = torch.rand((local_batch_size, response_len), device=device) > 0.1

        # Fake reference log-probs (all <= -1.0) and advantages, shaped like real targets
        ref_logprobs = (
            -torch.abs(torch.randn((local_batch_size, response_len), device=device))
            - 1.0
        )
        advantages = torch.randn((local_batch_size, 1), device=device)
input_tokens = torch.cat([request, response], dim=1)
inputs.append({"tokens": input_tokens})
targets.append(
{
"response": response,
"ref_logprobs": ref_logprobs,
"advantages": advantages,
"padding_mask": padding_mask,
}
)

return inputs, targets


async def main(cfg: DictConfig):
local_batch_size = cfg.get("local_batch_size", None)
assert local_batch_size is not None, "local_batch_size must be specified"

request_len = cfg.get("max_req_tokens", 64)
response_len = cfg.get("max_res_tokens", 64)
model_name = cfg.get("model")

print(f"Loading tokenizer for model: {model_name}")
tokenizer = get_tokenizer(model_name)
vocab_size = tokenizer.vocab_size
print(f"Detected vocab size: {vocab_size}")

    trainer_dp_degree = cfg.trainer.parallelism.get("data_parallel_shard_degree", 1)
    # A shard degree of -1 (auto) is treated as a single data-parallel rank here.
    dp_size = trainer_dp_degree if trainer_dp_degree != -1 else 1

# ---- Global setups ---- #
    if cfg.get("provisioner", None) is not None:
        await init_provisioner(
            ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
        )
    else:
        await init_provisioner()

metric_logging_cfg = cfg.get("metric_logging", {})
mlogger = await get_or_create_metric_logger(process_name="Controller")
await mlogger.init_backends.call_one(metric_logging_cfg)

# Initialize torchstore
await ts.initialize(strategy=ts.ControllerStorageVolumes())

print("=" * 80)
print(f"Model: {model_name}")
print(f"Local batch size: {local_batch_size}")
print(
f"Sequence length: {request_len + response_len} ({request_len} + {response_len})"
)
print(f"Data parallel size: {dp_size}")
print(f"Is RDMA available? {rdma_enabled()}")
print("=" * 80 + "\n")

# Initialize trainer and generator
print("Initializing trainer and generator...")
init_start = time.time()

trainer, policy = await asyncio.gather(
RLTrainer.options(**cfg.actors.trainer).as_actor(
**cfg.trainer,
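            # Dummy constant loss: lets train_step run end-to-end; this sandbox only
            # measures weight-sync timing, not learning quality.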
loss=lambda *args, **kwargs: torch.tensor(
1.0, requires_grad=True, device="cuda"
),
),
Generator.options(**cfg.actors.policy).as_actor(**cfg.policy),
)

init_time = time.time() - init_start
print(f"Finished initialization in ({init_time:.2f}s)")

# Run one training step to create weight delta
print("Running single training step...")
step_start = time.time()

inputs, targets = generate_random_batch(
local_batch_size=local_batch_size,
request_len=request_len,
response_len=response_len,
vocab_size=vocab_size,
dp_size=dp_size,
)

await trainer.train_step.call(inputs, targets)
step_time = time.time() - step_start
print(f"Finished train step in ({step_time:.2f}s)\n")

# Test push_weights
print("Pushing weights from trainer to store...")
push_start = time.time()

await trainer.push_weights.call(policy_version=1)

push_time = time.time() - push_start
print(f"Finished weights push in ({push_time:.2f}s)\n")

# Test update_weights
print("Updating generator weights from store...")
update_start = time.time()

await policy.update_weights.call(version=1)

update_time = time.time() - update_start
print(f"Updated generator weights ({update_time:.2f}s)\n")

# TODO - ideally we have the capability to check forward passes between
# the trainer/generator to verify correctness. This would require adding
# forward capabilities to both trainer/generator actors.

# Summary
print("=" * 80)
print("Results")
print("=" * 80)
print(f"Push time: {push_time:.2f}s")
print(f"Update time: {update_time:.2f}s")
print(f"Total sync time: {push_time + update_time:.2f}s")
print("=" * 80 + "\n")

# Cleanup
print("Shutting down...")
await shutdown()
print("Shutdown complete.")


if __name__ == "__main__":

@parse
def _main(cfg):
asyncio.run(main(cfg))

_main()
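
The TODO near the end of main() asks for a forward-pass consistency check between the trainer and the generator. Below is a minimal sketch of what such a check could look like, assuming hypothetical compute_logprobs and generate endpoints on the two actors (neither exists today, which is what the TODO points out) and an arbitrarily chosen tolerance:

# Sketch only: `policy.generate` and `trainer.compute_logprobs` are hypothetical
# endpoints, not part of the current Generator/RLTrainer interfaces.
async def verify_weight_sync(trainer, policy, prompt_tokens, atol: float = 1e-2) -> bool:
    # Hypothetical: sample a short continuation and return the generated tokens
    # plus the log-probabilities the generator assigned to them.
    gen_tokens, gen_logprobs = await policy.generate.call(prompt_tokens, max_tokens=32)

    # Hypothetical: score the same prompt + continuation with the trainer's model
    # and return its per-token log-probabilities for the generated tokens.
    trainer_logprobs = await trainer.compute_logprobs.call(prompt_tokens, gen_tokens)

    # After a successful sync the two sets of log-probs should agree up to
    # numerical noise (dtype differences, different kernels).
    max_diff = (trainer_logprobs - gen_logprobs).abs().max().item()
    print(f"Max per-token logprob difference: {max_diff:.4f}")
    return max_diff <= atol

Exact equality is not expected, since the generator and trainer run different kernel stacks and possibly different dtypes, so a tolerance check is more robust than a bitwise comparison.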
tests/sandbox/weight_sync/qwen3_1_7b.yaml (74 additions, 0 deletions)
@@ -0,0 +1,74 @@
# Weight Sync Sandbox Configuration
# >>> python -m tests.sandbox.weight_sync.main --config tests/sandbox/weight_sync/qwen3_1_7b.yaml

model: "Qwen/Qwen3-1.7B"
[Review thread on the line above]
Contributor: I feel like we could use a larger model like 8b
Author: we can add more model configs as needed

local_batch_size: 4
max_req_tokens: 64
max_res_tokens: 64

metric_logging:
console:
logging_mode: global_reduce

policy:
prefetch_weights_to_shm: false # Disable to avoid shared memory warnings in test
[Review thread on the line above]
Contributor: what warnings are you seeing?
Author: It spams resource_tracker warnings saying that the shared memory files don't exist anymore. Claude couldn't figure it out, so I just disabled it.

engine_args:
model: ${model}
tensor_parallel_size: 1
pipeline_parallel_size: 1
enforce_eager: true
sampling_params:
n: 1
    max_tokens: 32 # Kept short; only needed for a verification forward pass
temperature: 1.0
top_p: 1.0

trainer:
model:
name: qwen3
flavor: 1.7B
hf_assets_path: hf://${model}
optimizer:
name: AdamW
lr: 1e-5
eps: 1e-8
lr_scheduler:
warmup_steps: 1
training:
local_batch_size: ${local_batch_size}
seq_len: 128 # max_req_tokens + max_res_tokens
max_norm: 1.0
steps: 1 # We only run 1 step
dtype: bfloat16
gc_freq: 1
compile:
enable: false
parallelism:
data_parallel_replicate_degree: 1
data_parallel_shard_degree: 1 # Single GPU, no FSDP
tensor_parallel_degree: 1
pipeline_parallel_degree: 1
context_parallel_degree: 1
expert_parallel_degree: 1
disable_loss_parallel: true
checkpoint:
enable: true
folder: ./checkpoint
initial_load_path: hf://${model}
initial_load_in_hf: true
last_save_in_hf: true
async_mode: "disabled"
activation_checkpoint:
mode: selective
selective_ac_option: op

# Resource allocation - both as actors
actors:
policy:
procs: 1 # Single process for generator
with_gpus: true
mesh_name: policy
trainer:
procs: 1 # Single process for trainer
with_gpus: true
mesh_name: trainer