Skip to content

Conversation

wukaixingxp
Copy link
Contributor

@wukaixingxp wukaixingxp commented Oct 16, 2025

Summary:
This change fixes a runtime ValueError in GRPO training that occurred when using a sequence length different from the default. The ref_model was not inheriting the seq_len from the main trainer configuration, causing it to fall back to the job config default value (e.g., 2048). This led to a dimension mismatch error with the rotary position embeddings when the trainer was configured with a longer sequence length.

The fix explicitly sets seq_len: ${trainer.training.seq_len} in the ref_model.training section of the relevant GRPO YAML files, like this one. This ensures the reference model always uses the same sequence length as the trainer, resolving the crash.

Error Log:
if trainer.training.seq_len !=2048: then the error looks like this:

Spawning single actor SandboxedPythonCoder
Launcher not provided, remote allocations will not work.
Spawning single actor DatasetActor
Spawning Service for Policy
Spawning single actor RLTrainer
Spawning single actor ReplayBuffer
Spawning single actor ComputeAdvantages
Spawning Service for ReferenceModel
Spawning Service for RewardActor
[0] [RLTrainer-0/1] 2025-10-09 23:42:42 INFO Building 0-D device mesh with [], []
[0] [RLTrainer-0/1] 2025-10-09 23:42:42 INFO [GC] Initial GC collection took 0.00 seconds
[0] [RLTrainer-0/1] 2025-10-09 23:42:45 INFO Total parameter count: dense 2,031,739,904, sparse 0, active 2,031,739,904
[0] [RLTrainer-0/1] 2025-10-09 23:42:45 INFO Applied selective activation checkpointing to the model
[0] [RLTrainer-0/1] 2025-10-09 23:42:46 INFO Checkpointing active. Checkpoints will be loaded from and saved to checkpoint
[0] [RLTrainer-0/1] 2025-10-09 23:42:46 INFO Mixed precision training is handled by AMP
[0] [RLTrainer-0/1] 2025-10-09 23:42:46 INFO loading from HF safetensors from --checkpoint.initial_load_path: /home/nvidia/.cache/huggingface/hub/models--Qwen--Qwen3-1.7B/snapshots/70d244cc86ccca08cf5af4e1e306ecf908b1ad5e
[0] [RLTrainer-0/1] 2025-10-09 23:42:46 INFO Loading the checkpoint from /home/nvidia/.cache/huggingface/hub/models--Qwen--Qwen3-1.7B/snapshots/70d244cc86ccca08cf5af4e1e306ecf908b1ad5e.
[0] [RLTrainer-0/1] 2025-10-09 23:42:47 INFO [GC] GC collection for checkpoint loading. took 0.07 seconds
[0] [RLTrainer-0/1] 2025-10-09 23:42:47 INFO Finished loading the checkpoint in 1.24 seconds.
[0] [ReferenceModel-0/1] 2025-10-09 23:42:52 INFO Building 0-D device mesh with [], []
[0] [ReferenceModel-0/1] 2025-10-09 23:42:52 INFO [GC] Initial GC collection took 0.00 seconds
[0] [ReferenceModel-0/1] 2025-10-09 23:42:56 INFO Total parameter count: dense 2,031,739,904, sparse 0, active 2,031,739,904
[0] [ReferenceModel-0/1] 2025-10-09 23:42:56 INFO Applied selective activation checkpointing to the model
[0] [ReferenceModel-0/1] 2025-10-09 23:42:57 INFO Checkpointing active. Checkpoints will be loaded from and saved to checkpoint
[0] [ReferenceModel-0/1] 2025-10-09 23:42:57 INFO Mixed precision training is handled by AMP
[0] [ReferenceModel-0/1] 2025-10-09 23:42:57 INFO loading from HF safetensors from --checkpoint.initial_load_path: /home/nvidia/.cache/huggingface/hub/models--Qwen--Qwen3-1.7B/snapshots/70d244cc86ccca08cf5af4e1e306ecf908b1ad5e
[0] [ReferenceModel-0/1] 2025-10-09 23:42:57 INFO Loading the checkpoint from /home/nvidia/.cache/huggingface/hub/models--Qwen--Qwen3-1.7B/snapshots/70d244cc86ccca08cf5af4e1e306ecf908b1ad5e.
`torch_dtype` is deprecated! Use `dtype` instead!
INFO 10-09 23:42:58 [config.py:1604] Using max model len 40960
[0] [ReferenceModel-0/1] 2025-10-09 23:42:58 INFO [GC] GC collection for checkpoint loading. took 0.07 seconds
[0] [ReferenceModel-0/1] 2025-10-09 23:42:58 INFO Finished loading the checkpoint in 1.38 seconds.
INFO 10-09 23:42:58 [config.py:2434] Chunked prefill is enabled with max_num_batched_tokens=16384.
[0] INFO 10-09 23:43:03 [__init__.py:235] Automatically detected platform cuda.
[0] INFO 10-09 23:43:03 [__init__.py:235] Automatically detected platform cuda.
[0] WARNING 10-09 23:43:07 [multiproc_worker_utils.py:307] Reducing Torch parallelism from 120 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
[0] [W1009 23:43:11.138720406 ProcessGroupNCCL.cpp:941] Warning: TORCH_NCCL_AVOID_RECORD_STREAMS is the default now, this environment variable is thus deprecated. (function operator())
[0] [Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[0] [Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[0] [Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[0] [Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[0] [Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[0] INFO 10-09 23:43:11 [parallel_state.py:1102] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
[0] WARNING 10-09 23:43:11 [topk_topp_sampler.py:59] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.
[0] INFO 10-09 23:43:12 [gpu_model_runner.py:1843] Starting to load model Qwen/Qwen3-1.7B...
[0] INFO 10-09 23:43:12 [gpu_model_runner.py:1875] Loading model from scratch...
[0] INFO 10-09 23:43:12 [cuda.py:290] Using Flash Attention backend on V1 engine.
[0] INFO 10-09 23:43:12 [weight_utils.py:296] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:01<00:00,  1.97it/s]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:01<00:00,  1.97it/s]
[0]
[0] INFO 10-09 23:43:14 [default_loader.py:262] Loading weights took 1.10 seconds
[0] INFO 10-09 23:43:14 [gpu_model_runner.py:1892] Model loading took 3.2152 GiB and 1.808709 seconds
[0] `torch_dtype` is deprecated! Use `dtype` instead!
[0] INFO 10-09 23:43:22 [config.py:1604] Using max model len 40960
[0] INFO 10-09 23:43:23 [config.py:2434] Chunked prefill is enabled with max_num_batched_tokens=16384.
[0] WARNING 10-09 23:43:24 [config.py:1528] Default sampling parameters have been overridden by the model's Hugging Face generation config recommended from the model creator. If this is not intended, please relaunch vLLM instance with `--generation-config vllm`.
[0] INFO 10-09 23:43:33 [backends.py:530] Using cache directory: /home/nvidia/.cache/vllm/torch_compile_cache/3739b4341d/rank_0_0/backbone for vLLM's torch.compile
[0] INFO 10-09 23:43:33 [backends.py:541] Dynamo bytecode transform time: 7.69 s
[0] INFO 10-09 23:43:36 [backends.py:161] Directly load the compiled graph(s) for dynamic shape from the cache, took 2.847 s
[0] INFO 10-09 23:43:46 [monitor.py:34] torch.compile takes 7.69 s in total
[0] INFO 10-09 23:43:47 [gpu_worker.py:255] Available KV cache memory: 74.88 GiB
[0] INFO 10-09 23:43:48 [kv_cache_utils.py:833] GPU KV cache size: 701,088 tokens
[0] INFO 10-09 23:43:48 [kv_cache_utils.py:837] Maximum concurrency for 40,960 tokens per request: 17.12x
Capturing CUDA graph shapes: 100%|██████████| 67/67 [00:03<00:00, 21.03it/s]
[0] INFO 10-09 23:43:51 [gpu_model_runner.py:2485] Graph capturing finished in 3 secs, took 0.44 GiB
All services initialized successfully!
Starting GRPO with 1 rollout threads, 1 training threads
[0] Error in testing framework: 'ActorEndpoint' object is not callable
[0] Error in testing framework: 'ActorEndpoint' object is not callable
[0] Error in testing framework: 'ActorEndpoint' object is not callable
[0] Error in testing framework: 'ActorEndpoint' object is not callable
[0] /home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torch/cuda/memory.py:491: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
[0]   warnings.warn(
[0] [0]E1009 23:44:31.693703 1698762 monarch_hyperactor/src/telemetry.rs:58] file:actor_mesh.py, lineno:868, stacktrace:Traceback (most recent call last):
[0] [ReferenceModel-0/1] 2025-10-09 23:44:31 CRITICAL Unhandled exception in actor endpoint
[0]   File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/monarch/_src/actor/actor_mesh.py", line 865, in instrumented
[0] Traceback (most recent call last):
[0]     result = await the_method(*args, **kwargs)
[0]   File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/monarch/_src/actor/actor_mesh.py", line 865, in instrumented
[0]   File "/home/nvidia/kai/forge/src/forge/actors/reference_model.py", line 149, in forward
[0]     logits = self.model(input_ids)
[0]     result = await the_method(*args, **kwargs)
[0]   File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[0]   File "/home/nvidia/kai/forge/src/forge/actors/reference_model.py", line 149, in forward
[0]     return self._call_impl(*args, **kwargs)
[0]     logits = self.model(input_ids)
[0]   File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[0]     return forward_call(*args, **kwargs)
[0]   File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torchtitan/experiments/qwen3/model/model.py", line 452, in forward
[0]   File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[0]     return self._call_impl(*args, **kwargs)
[0]     h = layer(h, self.rope_cache)
[0]   File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[0]   File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[0]     return self._call_impl(*args, **kwargs)
[0]     return forward_call(*args, **kwargs)
[0]   File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torchtitan/experiments/qwen3/model/model.py", line 452, in forward
[0]     h = layer(h, self.rope_cache)
[0]   File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[0]   File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[0]     return self._call_impl(*args, **kwargs)
[0]     return forward_call(*args, **kwargs)
[0]   File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[0]   File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torchtitan/experiments/qwen3/model/model.py", line 323, in forward
[0]     return forward_call(*args, **kwargs)
[0]   File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torchtitan/experiments/qwen3/model/model.py", line 323, in forward
[0]     x = x + self.attention(self.attention_norm(x), rope_cache)
[0]   File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[0]     x = x + self.attention(self.attention_norm(x), rope_cache)
[0]   File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[0]     return self._call_impl(*args, **kwargs)
[0]     return self._call_impl(*args, **kwargs)
[0]   File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[0]     return forward_call(*args, **kwargs)
[0]   File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[0]   File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torchtitan/experiments/qwen3/model/model.py", line 203, in forward
[0]     return forward_call(*args, **kwargs)
[0]   File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torchtitan/experiments/qwen3/model/model.py", line 203, in forward
[0]     xq, xk = apply_rotary_emb(xq, xk, rope_cache)
[0]     xq, xk = apply_rotary_emb(xq, xk, rope_cache)
[0]   File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torchtitan/experiments/qwen3/model/model.py", line 81, in apply_rotary_emb
[0]   File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torchtitan/experiments/qwen3/model/model.py", line 81, in apply_rotary_emb
[0]     rope_cache = reshape_for_broadcast(rope_cache, xq)
[0]     rope_cache = reshape_for_broadcast(rope_cache, xq)
[0]   File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torchtitan/experiments/qwen3/model/model.py", line 69, in reshape_for_broadcast
[0]   File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torchtitan/experiments/qwen3/model/model.py", line 69, in reshape_for_broadcast
[0]     assert rope_cache.shape == (seqlen, head_dim * 2)
[0] , actor_id:_18tVD5TRZcYt[0].ReferenceModel[0], Unhandled exception in actor endpoint
[0]     assert rope_cache.shape == (seqlen, head_dim * 2)
Got failure on replica 0. Error:
A remote actor call has failed.
 Traceback of where the remote call failed (most recent call last):
  File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/monarch/_src/actor/actor_mesh.py", line 875, in handle
    result = await instrumented()
  File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/monarch/_src/actor/actor_mesh.py", line 872, in instrumented
    raise e
  File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/monarch/_src/actor/actor_mesh.py", line 865, in instrumented
    result = await the_method(*args, **kwargs)
  File "/home/nvidia/kai/forge/src/forge/actors/reference_model.py", line 149, in forward
    logits = self.model(input_ids)
  File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torchtitan/experiments/qwen3/model/model.py", line 452, in forward
    h = layer(h, self.rope_cache)
  File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torchtitan/experiments/qwen3/model/model.py", line 323, in forward
    x = x + self.attention(self.attention_norm(x), rope_cache)
  File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torchtitan/experiments/qwen3/model/model.py", line 203, in forward
    xq, xk = apply_rotary_emb(xq, xk, rope_cache)
  File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torchtitan/experiments/qwen3/model/model.py", line 81, in apply_rotary_emb
    rope_cache = reshape_for_broadcast(rope_cache, xq)
  File "/home/nvidia/anaconda3/envs/forge/lib/python3.10/site-packages/torchtitan/experiments/qwen3/model/model.py", line 69, in reshape_for_broadcast
    assert rope_cache.shape == (seqlen, head_dim * 2)
AssertionError
[0] AssertionError
Shutting down...

Test Plan:
Run any of the modified GRPO configurations with a seq_len in the trainer.training section that is different from the default (e.g., 8192). The training will now proceed without mismatch error.

Test Log:
Training successful:

(test) [kaiwu@devgpu005.zas1 ~/work/kaiwu/forge (fix_ref_model_seq_len)]$ cat apps/grpo/qwen3_1_7b.yaml 
# Grouped Relative Policy Optimization (GRPO)
# >>> python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml

# Global configuration
group_size: 8
local_batch_size: 16 # per-device batch size
max_req_tokens: 2048
max_res_tokens: 2048
model: "Qwen/Qwen3-1.7B"
off_by_n: 1 # Off by one by default

# Main loop configuration
rollout_threads: 1   # Recommended to set equal to policy.num_replicas


# Observability configuration
metric_logging:
  wandb:
    project: "grpo-training"
    group: "grpo_exp_${oc.env:USER}"
    reduce_across_ranks: True
  console:
    reduce_across_ranks: True

# Dataset configuration
dataset:
  path: "openai/gsm8k"
  revision: "main"
  data_split: "train"
  streaming: true
  model: ${model}

# Policy configuration
policy:
  engine_args:  # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
    model: ${model}
    tensor_parallel_size: 1
    pipeline_parallel_size: 1
    enforce_eager: false
  sampling_params:  # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
    n: ${group_size}
    max_tokens: ${max_res_tokens}
    temperature: 1.0
    top_p: 1.0

# Trainer configuration
trainer:
  model:
    name: qwen3
    flavor: 1.7B
    hf_assets_path: hf://${model}
  optimizer:
    name: AdamW
    lr: 1e-5
    eps: 1e-8
  lr_scheduler:
    warmup_steps: 1
  training:
    local_batch_size: ${local_batch_size}
    seq_len: 4196
    max_norm: 1.0
    steps: 1000000
    dtype: bfloat16
    gc_freq: 1
  compile:
    enable: false
  parallelism:
    data_parallel_replicate_degree: 1
    data_parallel_shard_degree: 1
    tensor_parallel_degree: 1
    pipeline_parallel_degree: 1
    context_parallel_degree: 1
    expert_parallel_degree: 1
    disable_loss_parallel: true
  checkpoint:
    enable: true
    initial_load_path: hf://${model}
    initial_load_in_hf: true
    last_save_in_hf: true
    interval: 500
    async_mode: "disabled"
  activation_checkpoint:
    mode: selective
    selective_ac_option: op

# Replay buffer configuration
replay_buffer:
  batch_size: ${local_batch_size}
  max_policy_age: ${off_by_n}
  dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree

# Reference model configuration
ref_model:
  model:
    name: qwen3
    flavor: 1.7B
    hf_assets_path: hf://${model}
  training:
    seq_len: ${trainer.training.seq_len}
    dtype: bfloat16
    gc_freq: 1
  compile:
    enable: false
  parallelism:
    data_parallel_replicate_degree: 1
    data_parallel_shard_degree: 1
    tensor_parallel_degree: 1
    pipeline_parallel_degree: 1
    context_parallel_degree: 1
    expert_parallel_degree: 1
  checkpoint:
    enable: true
    initial_load_path: hf://${model}
    initial_load_in_hf: true

# All resource allocations
services:
  policy:
    procs: ${policy.engine_args.tensor_parallel_size}
    num_replicas: 1
    mesh_name: policy
    with_gpus: true
  ref_model:
    procs: 1
    num_replicas: 1
    mesh_name: ref_model
    with_gpus: true
  reward_actor:
    procs: 1
    num_replicas: 1
    mesh_name: reward_actor
    with_gpus: false

actors:
  dataset:
    procs: 1
    with_gpus: false
    mesh_name: dataset
  trainer:
    procs: 1
    with_gpus: true
    mesh_name: trainer
  replay_buffer:
    procs: 1
    with_gpus: false
    mesh_name: replay_buffer
  compute_advantages:
    procs: 1
    with_gpus: false
    mesh_name: compute_advantages
(test) [kaiwu@devgpu005.zas1 ~/work/kaiwu/forge (main)]$ python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
Warning: setting HYPERACTOR_CODEC_MAX_FRAME_LENGTH since this needs to be set to enable large RPC calls via Monarch
INFO 10-16 13:39:38 [__init__.py:235] Automatically detected platform cuda.
Launcher not provided, remote allocations will not work.
/home/kaiwu/.conda/envs/test/lib/python3.10/site-packages/pydantic/_internal/_generate_schema.py:2249: UnsupportedFieldAttributeWarning: The 'repr' attribute with value False was provided to the `Field()` function, which has no effect in the context it was used. 'repr' is field-specific metadata, and can only be attached to a model field using `Annotated` metadata or by assignment. This may have happened because an `Annotated` type alias using the `type` statement was used, or if the `Field()` function was attached to a single member of a union type.
  warnings.warn(
/home/kaiwu/.conda/envs/test/lib/python3.10/site-packages/pydantic/_internal/_generate_schema.py:2249: UnsupportedFieldAttributeWarning: The 'frozen' attribute with value True was provided to the `Field()` function, which has no effect in the context it was used. 'frozen' is field-specific metadata, and can only be attached to a model field using `Annotated` metadata or by assignment. This may have happened because an `Annotated` type alias using the `type` statement was used, or if the `Field()` function was attached to a single member of a union type.
  warnings.warn(
wandb: Detected [openai] in use.
wandb: Use W&B Weave for improved LLM call tracing. Install Weave with `pip install weave` then add `import weave` to the top of your script.
wandb: For more information, check out the docs at: https://weave-docs.wandb.ai/
Spawning actor DatasetActor
Spawning service Generator
Spawning actor RLTrainer
Spawning actor ReplayBuffer
Spawning actor ComputeAdvantages
Spawning service ReferenceModel
Spawning service RewardActor
[0] [0] [RLTrainer-0/1] 2025-10-16 13:39:58 INFO Compiling loss
[0] [0] [RLTrainer-0/1] 2025-10-16 13:40:02 INFO Building 0-D device mesh with [], []
[0] [0] [RLTrainer-0/1] 2025-10-16 13:40:02 INFO [GC] Initial GC collection took 0.00 seconds
[0] [0] [RLTrainer-0/1] 2025-10-16 13:40:05 INFO Total parameter count: dense 2,031,739,904, sparse 0, active 2,031,739,904
[0] [0] [RLTrainer-0/1] 2025-10-16 13:40:05 INFO Applied selective activation checkpointing to the model
[0] [0] [RLTrainer-0/1] 2025-10-16 13:40:07 INFO Checkpointing active. Checkpoints will be loaded from and saved to checkpoint
[0] [0] [RLTrainer-0/1] 2025-10-16 13:40:07 INFO Mixed precision training is handled by AMP
[0] [0] [RLTrainer-0/1] 2025-10-16 13:40:07 INFO loading from HF safetensors from --checkpoint.initial_load_path: /home/kaiwu/.cache/huggingface/hub/models--Qwen--Qwen3-1.7B/snapshots/70d244cc86ccca08cf5af4e1e306ecf908b1ad5e
[0] [0] [RLTrainer-0/1] 2025-10-16 13:40:07 INFO Loading the checkpoint from /home/kaiwu/.cache/huggingface/hub/models--Qwen--Qwen3-1.7B/snapshots/70d244cc86ccca08cf5af4e1e306ecf908b1ad5e.
[0] [0] INFO 10-16 13:40:07 [__init__.py:235] Automatically detected platform cuda.
[0] [0] [RLTrainer-0/1] 2025-10-16 13:40:08 INFO [GC] GC collection for checkpoint loading. took 0.00 seconds
[0] [0] [RLTrainer-0/1] 2025-10-16 13:40:08 INFO Finished loading the checkpoint in 1.04 seconds.
[0] [0] [ReferenceModel-0/1] 2025-10-16 13:40:09 INFO Building 0-D device mesh with [], []
[0] [0] [ReferenceModel-0/1] 2025-10-16 13:40:09 INFO [GC] Initial GC collection took 0.00 seconds
[0] [0] [ReferenceModel-0/1] 2025-10-16 13:40:10 INFO Total parameter count: dense 2,031,739,904, sparse 0, active 2,031,739,904
[0] [0] [ReferenceModel-0/1] 2025-10-16 13:40:10 INFO Applied selective activation checkpointing to the model
[0] [0] [ReferenceModel-0/1] 2025-10-16 13:40:10 INFO Checkpointing active. Checkpoints will be loaded from and saved to 
[0] [0] [ReferenceModel-0/1] 2025-10-16 13:40:10 INFO Mixed precision training is handled by AMP
[0] [0] [ReferenceModel-0/1] 2025-10-16 13:40:10 INFO loading from HF safetensors from --checkpoint.initial_load_path: /home/kaiwu/.cache/huggingface/hub/models--Qwen--Qwen3-1.7B/snapshots/70d244cc86ccca08cf5af4e1e306ecf908b1ad5e
[0] [0] [ReferenceModel-0/1] 2025-10-16 13:40:10 INFO Loading the checkpoint from /home/kaiwu/.cache/huggingface/hub/models--Qwen--Qwen3-1.7B/snapshots/70d244cc86ccca08cf5af4e1e306ecf908b1ad5e.
[0] [0] [ReferenceModel-0/1] 2025-10-16 13:40:11 INFO [GC] GC collection for checkpoint loading. took 0.05 seconds
[0] [0] [ReferenceModel-0/1] 2025-10-16 13:40:11 INFO Finished loading the checkpoint in 1.00 seconds.
[0] [0] `torch_dtype` is deprecated! Use `dtype` instead!
[0] [0] INFO 10-16 13:40:17 [config.py:1604] Using max model len 40960
[0] [0] INFO 10-16 13:40:17 [config.py:2434] Chunked prefill is enabled with max_num_batched_tokens=16384.
[0] [0] INFO 10-16 13:40:19 [__init__.py:235] Automatically detected platform cuda.
[0] [0] WARNING 10-16 13:40:21 [multiproc_worker_utils.py:307] Reducing Torch parallelism from 112 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
[0] [0] [W1016 13:40:24.600647873 ProcessGroupNCCL.cpp:941] Warning: TORCH_NCCL_AVOID_RECORD_STREAMS is the default now, this environment variable is thus deprecated. (function operator())
[0] [0] [Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[0] [0] [Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[0] [0] [Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[0] [0] [Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[0] [0] [Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[0] [0] INFO 10-16 13:40:24 [parallel_state.py:1102] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
[0] [0] WARNING 10-16 13:40:24 [topk_topp_sampler.py:59] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.
[0] [0] INFO 10-16 13:40:24 [gpu_model_runner.py:1843] Starting to load model Qwen/Qwen3-1.7B...
[0] [0] INFO 10-16 13:40:24 [gpu_model_runner.py:1875] Loading model from scratch...
[0] [0] INFO 10-16 13:40:24 [cuda.py:290] Using Flash Attention backend on V1 engine.
[0] [0] INFO 10-16 13:40:25 [weight_utils.py:296] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:00<00:00,  2.71it/s]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:00<00:00,  2.71it/s]
[0] [0] 
[0] [0] INFO 10-16 13:40:25 [default_loader.py:262] Loading weights took 0.85 seconds
[0] [0] INFO 10-16 13:40:26 [gpu_model_runner.py:1892] Model loading took 3.2152 GiB and 1.416880 seconds
[-]E1016 13:40:29.249965 37991 hyperactor/src/channel/net.rs:872] error_msg:session unix:@2f7RggtveES6a7LBlgT1fPWP.6154151429112659686: failed to deliver message within timeout
[0] [0] INFO 10-16 13:40:32 [backends.py:530] Using cache directory: /home/kaiwu/.cache/vllm/torch_compile_cache/69394135b9/rank_0_0/backbone for vLLM's torch.compile
[0] [0] INFO 10-16 13:40:32 [backends.py:541] Dynamo bytecode transform time: 5.19 s
[0] [0] INFO 10-16 13:40:36 [backends.py:194] Cache the graph for dynamic shape for later use
[0] [0] INFO 10-16 13:40:46 [backends.py:215] Compiling a graph for dynamic shape takes 13.71 s
[-]E1016 13:40:47.413240 37991 hyperactor/src/channel/net.rs:872] error_msg:session unix:@2f7RggtveES6a7LBlgT1fPWP.9221547878324211595: failed to deliver message within timeout
[0] [0] INFO 10-16 13:40:57 [monitor.py:34] torch.compile takes 18.89 s in total
[0] [0] INFO 10-16 13:40:58 [gpu_worker.py:255] Available KV cache memory: 116.85 GiB
[0] [0] INFO 10-16 13:40:58 [kv_cache_utils.py:833] GPU KV cache size: 1,094,016 tokens
[0] [0] INFO 10-16 13:40:58 [kv_cache_utils.py:837] Maximum concurrency for 40,960 tokens per request: 26.71x
Capturing CUDA graph shapes: 100%|██████████| 67/67 [00:07<00:00,  8.79it/s]
[0] [0] INFO 10-16 13:41:06 [gpu_model_runner.py:2485] Graph capturing finished in 8 secs, took 0.61 GiB
All services initialized successfully!
Torchstore successfully initialized with local rank strategy
Starting GRPO with 1 rollout threads, 1 training threads
/home/kaiwu/work/kaiwu/forge/apps/grpo/main.py:63: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
  tensor = torch.tensor(request_tokens, dtype=torch.long)
/home/kaiwu/work/kaiwu/forge/apps/grpo/main.py:72: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
  tensor = torch.tensor(response_tokens, dtype=torch.long)
[0] [0] /home/kaiwu/.conda/envs/test/lib/python3.10/site-packages/torch/cuda/memory.py:491: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
[0] [0]   warnings.warn(
Skipping metric collection. Metric logging backends (e.g. wandb) were not initialized. This happens when you try to use `record_metric` before calling `init_backends`. To disable this warning, please call in your main file:
`mlogger = await get_or_create_metric_logger()`
`await mlogger.init_backends.call_one(logging_config)`
or set env variable `FORGE_DISABLE_METRICS=True`
[0] [0] /home/kaiwu/work/kaiwu/forge/apps/grpo/main.py:63: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
[0] [0]   tensor = torch.tensor(request_tokens, dtype=torch.long)
[0] [0] /home/kaiwu/work/kaiwu/forge/apps/grpo/main.py:72: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
[0] [0]   tensor = torch.tensor(response_tokens, dtype=torch.long)
[0] [0] /home/kaiwu/.conda/envs/test/lib/python3.10/site-packages/torch/cuda/memory.py:491: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
[0] [0]   warnings.warn(
[0] [0] [ReferenceModel-0/1] 2025-10-16 13:41:31 INFO [GC] Performing periodic GC collection took 0.00 seconds
[0] [0] /home/kaiwu/.conda/envs/test/lib/python3.10/site-packages/torch/cuda/memory.py:491: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
[0] [0]   warnings.warn(
[0] [0] [RLTrainer-0/1] 2025-10-16 13:41:38 INFO Pushing weights for policy version 1
[0] [0] /home/kaiwu/.conda/envs/test/lib/python3.10/site-packages/torch/distributed/checkpoint/filesystem.py:133: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
[0] [0]   if tensor.storage().size() != tensor.numel():
[0] [0] [ReferenceModel-0/1] 2025-10-16 13:41:40 INFO [GC] Performing periodic GC collection took 0.00 seconds
[0] [0] [RLTrainer-0/1] 2025-10-16 13:41:46 INFO Completed weights push in 8.30 seconds
[0] [0] [ReferenceModel-0/1] 2025-10-16 13:41:49 INFO [GC] Performing periodic GC collection took 0.00 seconds
[0] [0] INFO 10-16 13:41:55 [block_pool.py:321] Successfully reset prefix cache
[0] [0] [Generator-0/1] 2025-10-16 13:41:55 INFO Weight update completed (now v1)
WandbBackend: Logged 71 metrics at global_step 1
=== [UnknownActor] - METRICS STEP 1 ===
  buffer/add/count_episodes_added: 40.0
  buffer/evict/avg_policy_age: 0.0
  buffer/evict/max_policy_age: 0.0
  buffer/evict/sum_episodes_evicted: 0.0
  buffer/sample/avg_data_utilization: 1.9918518518518518
  buffer/sample/count_sample_requests: 193.0
  buffer_perf/sample/total_duration_avg_s: 0.0001317064007599428
  buffer_perf/sample/total_duration_max_s: 0.004480354022234678
  dataset/sample/avg_sample_len: 612.1666666666666
  dataset/sample/count_samples_generated: 6.0
  generator/generate/avg_tokens_generated: 1502.775
  generator/generate/count_requests: 6.0
  generator/generate/count_sequences_completed: 40.0
  generator/generate/sum_tokens_generated: 60111.0
  generator/update_weights/count_weight_updates: 1.0
  generator_perf/generate/generate/duration_avg_s: 7.23542626953125
  generator_perf/generate/generate/duration_max_s: 8.615248046875
  generator_perf/generate/process_inputs/duration_avg_s: 0.0038483520269393923
  generator_perf/generate/process_inputs/duration_max_s: 0.013906432151794433
  generator_perf/generate/total_duration_avg_s: 7.239456560760736
  generator_perf/generate/total_duration_max_s: 8.61735995092988
  generator_perf/update_weights/avg_pending_requests: 1.0
  generator_perf/update_weights/max_pending_requests: 1.0
  reference_perf/forward/avg_sequence_length: 4096.0
  reference_perf/forward/compute_logprobs/duration_avg_s: 0.046741083916276696
  reference_perf/forward/compute_logprobs/duration_max_s: 0.23294888576492667
  reference_perf/forward/count_forward_passes: 5.0
  reference_perf/forward/forward/duration_avg_s: 0.047694106306880715
  reference_perf/forward/forward/duration_max_s: 0.17533081443980336
  reference_perf/forward/garbage_collection/duration_avg_s: 0.0004692518152296543
  reference_perf/forward/garbage_collection/duration_max_s: 0.0006710137240588665
  reference_perf/forward/memory_delta_end_start_avg_gb: 9.27994384765625
  reference_perf/forward/memory_peak_max_gb: 36.27710008621216
  reference_perf/forward/to_device/duration_avg_s: 0.0001405918039381504
  reference_perf/forward/to_device/duration_max_s: 0.00015233177691698074
  reference_perf/forward/total_duration_avg_s: 0.09504847247153521
  reference_perf/forward/total_duration_max_s: 0.4089419641532004
  reward/evaluate_response/avg_MathReward_reward: 0.3274999999999998
  reward/evaluate_response/avg_ThinkingReward_reward: 0.6599999999999998
  reward/evaluate_response/avg_total_reward: 0.49375000000000024
  reward/evaluate_response/count_MathReward_calls: 40.0
  reward/evaluate_response/count_ThinkingReward_calls: 40.0
  reward/evaluate_response/std_MathReward_reward: 0.41592517355889874
  reward/evaluate_response/std_ThinkingReward_reward: 0.3954743986657039
  reward/evaluate_response/sum_MathReward_reward: 13.099999999999993
  reward/evaluate_response/sum_ThinkingReward_reward: 26.39999999999999
  rl_trainer/avg_grpo_loss: 7.450580596923828e-08
  rl_trainer/count_training_steps: 1.0
  rl_trainer/learning_rate: 0.001
  rl_trainer_perf/push_weights/dcp_save/duration_avg_s: 8.298058313317597
  rl_trainer_perf/push_weights/dcp_save/duration_max_s: 8.298058313317597
  rl_trainer_perf/push_weights/flatten_state_dict/duration_avg_s: 0.0007137400098145008
  rl_trainer_perf/push_weights/flatten_state_dict/duration_max_s: 0.0007137400098145008
  rl_trainer_perf/push_weights/memory_delta_end_start_avg_gb: 0.0
  rl_trainer_perf/push_weights/memory_peak_max_gb: 11.419936656951904
  rl_trainer_perf/push_weights/to_hf/duration_avg_s: 0.0005804160609841347
  rl_trainer_perf/push_weights/to_hf/duration_max_s: 0.0005804160609841347
  rl_trainer_perf/push_weights/total_duration_avg_s: 8.299356839153916
  rl_trainer_perf/push_weights/total_duration_max_s: 8.299356839153916
  rl_trainer_perf/step/forward_backward/duration_avg_s: 9.580604746937752
  rl_trainer_perf/step/forward_backward/duration_max_s: 9.580604746937752
  rl_trainer_perf/step/memory_delta_end_start_avg_gb: 7.632381439208984
  rl_trainer_perf/step/memory_peak_max_gb: 84.80950689315796
  rl_trainer_perf/step/optimizer_step/duration_avg_s: 0.013598949182778597
  rl_trainer_perf/step/optimizer_step/duration_max_s: 0.013598949182778597
  rl_trainer_perf/step/save_checkpoint/duration_avg_s: 0.014532398898154497
  rl_trainer_perf/step/save_checkpoint/duration_max_s: 0.014532398898154497
  rl_trainer_perf/step/total_duration_avg_s: 9.608739678747952
  rl_trainer_perf/step/total_duration_max_s: 9.608739678747952
  worker_perf/update_weights/total_duration_avg_s: 6.335729785263538
  worker_perf/update_weights/total_duration_max_s: 6.335729785263538
==============================

[0] [0] /home/kaiwu/work/kaiwu/forge/apps/grpo/main.py:63: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
[0] [0]   tensor = torch.tensor(request_tokens, dtype=torch.long)
[0] [0] /home/kaiwu/work/kaiwu/forge/apps/grpo/main.py:72: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
[0] [0]   tensor = torch.tensor(response_tokens, dtype=torch.long)
[0] [0] [RLTrainer-0/1] 2025-10-16 13:41:55 INFO [GC] Performing periodic GC collection took 0.24 seconds
[0] [0] [RLTrainer-0/1] 2025-10-16 13:41:57 INFO Pushing weights for policy version 2
[0] [0] [ReferenceModel-0/1] 2025-10-16 13:41:58 INFO [GC] Performing periodic GC collection took 0.00 seconds
[0] [0] [RLTrainer-0/1] 2025-10-16 13:42:04 INFO Completed weights push in 6.99 seconds
[0] [0] [ReferenceModel-0/1] 2025-10-16 13:42:07 INFO [GC] Performing periodic GC collection took 0.00 seconds

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Oct 16, 2025
@casteryh
Copy link
Contributor

good catch!

@casteryh casteryh self-requested a review October 16, 2025 20:05
@Jack-Khuu
Copy link
Contributor

Legit, thanks!

@Jack-Khuu Jack-Khuu added Best Practices Things we should be doing but aren't bug Something isn't working and removed Best Practices Things we should be doing but aren't labels Oct 16, 2025
@wukaixingxp wukaixingxp marked this pull request as ready for review October 16, 2025 20:43
@wukaixingxp wukaixingxp merged commit 75df074 into meta-pytorch:main Oct 16, 2025
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants