
Conversation


@felipemello1 felipemello1 commented Oct 23, 2025

fixes #495


Summary

Fixed rope_cache size mismatch in GRPO training.

Problem

The trainer.training.seq_len parameter was hardcoded to 2048, causing an AssertionError when the actual sequence length (max_req_tokens + max_res_tokens) exceeded this value.

Traceback (most recent call last):
  File "/fsx/lewis/miniconda3/envs/forge/lib/python3.12/site-packages/monarch/_src/actor/actor_mesh.py", line 932, in handle
    result = await the_method(*args, **kwargs)
  File "/fsx/lewis/git/torchforge/src/forge/actors/reference_model.py", line 181, in forward
    logits = self.model(input_ids)
  File "/fsx/lewis/miniconda3/envs/forge/lib/python3.12/site-packages/torchtitan/models/qwen3/model/model.py", line 491, in forward
    h = layer(h, self.rope_cache, attention_masks)
  File "/fsx/lewis/miniconda3/envs/forge/lib/python3.12/site-packages/torchtitan/models/qwen3/model/model.py", line 345, in forward
    x = x + self.attention(self.attention_norm(x), rope_cache, attention_masks)
  File "/fsx/lewis/miniconda3/envs/forge/lib/python3.12/site-packages/torchtitan/models/qwen3/model/model.py", line 219, in forward
    xq, xk = apply_rotary_emb(xq, xk, rope_cache)
  File "/fsx/lewis/miniconda3/envs/forge/lib/python3.12/site-packages/torchtitan/models/qwen3/model/model.py", line 91, in apply_rotary_emb
    rope_cache = reshape_for_broadcast(rope_cache, xq)
  File "/fsx/lewis/miniconda3/envs/forge/lib/python3.12/site-packages/torchtitan/models/qwen3/model/model.py", line 79, in reshape_for_broadcast
    assert rope_cache.shape == (seqlen, head_dim * 2)
AssertionError

Root Cause

qwen3_1_7b.yaml (seq_len=2048, max_req_tokens=512, max_res_tokens=16384)
  → Qwen3Model.__init__() (precomputes rope_cache with size 2048)
  → model.forward() (receives 16896 tokens)
  → reshape_for_broadcast() (assertion fails: 2048 ≠ 16896)
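
To make the mismatch concrete, here is a minimal sketch of the shape check (head_dim is assumed for illustration; this is not torchtitan's actual code):

# The rope cache is sized once at model init from trainer.training.seq_len,
# but forward() sees the full prompt + response length.
seq_len = 2048                              # hardcoded trainer.training.seq_len
head_dim = 128                              # assumed value, for illustration only
rope_cache_shape = (seq_len, head_dim * 2)  # precomputed in Qwen3Model.__init__()

actual_seqlen = 512 + 16384                 # max_req_tokens + max_res_tokens = 16896
print(rope_cache_shape == (actual_seqlen, head_dim * 2))  # False -> assert in reshape_for_broadcast fails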

Solution

Set seq_len to the sum of max_req_tokens and max_res_tokens:

training:
    seq_len: ${sum:${max_req_tokens},${max_res_tokens}}

To support "sum" in the YAML, I had to register the op as a resolver with OmegaConf.
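
For reference, a minimal sketch of that registration (the exact hook point in torchforge's config loading is not shown here):

from omegaconf import OmegaConf

# Register a "sum" resolver so ${sum:${max_req_tokens},${max_res_tokens}}
# can be evaluated when the config is resolved.
OmegaConf.register_new_resolver("sum", lambda *args: sum(args))

cfg = OmegaConf.create(
    {
        "max_req_tokens": 512,
        "max_res_tokens": 16384,
        "training": {"seq_len": "${sum:${max_req_tokens},${max_res_tokens}}"},
    }
)
print(cfg.training.seq_len)  # 16896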

@felipemello1 felipemello1 changed the title [draft] [FIX] Enable larger seq len [FIX] Remove hardcoded seq len Oct 23, 2025

@allenwang28 allenwang28 left a comment


thanks for the quick fix!

@joecummings joecummings merged commit 6919943 into meta-pytorch:main Oct 23, 2025
8 checks passed
@felipemello1 felipemello1 deleted the fix_seq_len branch October 23, 2025 16:46
photomz pushed a commit to photomz/forge that referenced this pull request Oct 25, 2025
Co-authored-by: Felipe Mello <felipemello@fb.com>

Development

Successfully merging this pull request may close these issues.

[GRPO app] Example fails when max_res_tokens: 16384 with assert rope_cache.shape == (seqlen, head_dim * 2)
