
Conversation

joecummings
Member

@joecummings joecummings commented Sep 15, 2025

What does this PR do?

  1. This PR integrates torchtitan as the backend for our trainer actor and our reference model actor. Crucially, this gives us the ability to run with multiple model parallelisms, which is needed for multi-node setups.
  2. This PR also incorporates some syntactic sugar in the form of hf://, which allows users to specify a model from the Hugging Face Hub; Forge will automatically either download it if it does not exist locally or find it in the cache, and point the trainer / reference model to that directory.

How was this PR tested?

This PR was tested for "runnability" using the following combinations of parallelisms on a single node.

| Trainer | Reference | Policy | Tested |
| ------- | --------- | ------ | ------ |
| single  | single    | single | ✓ |
| single  | single    | TP 2   | ✓ |
| single  | single    | TP 4   | ✓ |
| DP 2    | DP 2      | single | ✓ |
| DP 2    | DP 2      | TP 2   | ✓ |
| TP 2    | TP 2      | TP 2   | ✓ |
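
For reference, a hypothetical sketch of the shape such a config might take. The key names here (`parallelism`, `data_parallel_shard_degree`, `tensor_parallel_degree`) are assumptions modeled on torchtitan's JobConfig and may not match the actual Forge YAML schema:

```yaml
# Hypothetical sketch -- key names are assumptions, not the real schema.
trainer:
  parallelism:
    data_parallel_shard_degree: 2   # "DP 2" rows above
    tensor_parallel_degree: 1
ref_model:
  parallelism:
    tensor_parallel_degree: 2       # "TP 2" rows above
```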

In addition, I added a unit test for the hf:// config specification.

FAQs

  1. **Why are the titan trainer and reference model slower than the Hugging Face trainer and reference model?** Presumably this is because a) we have to run in fp32 for now (see Update titan weights to load in bfloat16 #166 for updates on changing this) and b) we now have to convert back to the Hugging Face format before pushing weights.
  2. **Why does loss parallel not work?** Loss parallel is only guaranteed to work OOTB with PyTorch's cross-entropy loss because PyTorch distributed does automatic sharding logic for the underlying aten ops. Since we write our own GRPO loss, we don't get this for free. A potential follow-up would be to enable this through our own resharding logic in Forge, or to upstream a more general fix to PyTorch for the underlying aten ops we need.
  3. **How does the user know which knobs from Titan they can play with?** Currently, there are no docs available for ForgeEngine, ForgeTrainer, etc. This is a huge UX risk for Forge because users will have to navigate the torchtitan codebase themselves to figure this out. cc @mjtrm
  4. **Why do we push compute logprobs onto the controller GPU?** I agree this is not ideal. In an effort to keep the ReferenceModel idempotent, I changed it so that it just returns logits. In addition, this lets us reuse the compute_logprobs found in grpo/main.py. However, this is incredibly slow and puts a lot more computation on rank 0 than I would prefer. cc @Jack-Khuu we should perhaps change this s.t. we create a ReferenceLogprobs actor that does both the logits and logprobs calculation. Less idempotent, but more efficient.
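
The logprob step discussed in FAQ 4 (gathering per-token log-probabilities from the logits the ReferenceModel returns) can be sketched in pure Python. The function name mirrors the compute_logprobs mentioned above, but this body is an illustrative reimplementation, not the actual grpo/main.py code, which would operate on batched torch tensors on-device:

```python
import math


def compute_logprobs(logits: list[list[float]], token_ids: list[int]) -> list[float]:
    """Return log P(token_ids[t] | context) for each position t.

    `logits` holds one vocabulary-sized score vector per position;
    `token_ids` picks one vocabulary index per position. This is the
    log-softmax + gather step, written in pure Python for illustration.
    """
    out = []
    for scores, tok in zip(logits, token_ids):
        m = max(scores)  # subtract the max for numerical stability
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        out.append(scores[tok] - log_z)  # log-softmax gathered at the token
    return out
```

Running this on the controller means the full logits tensor must travel to rank 0 first, which is exactly the extra transfer and compute the FAQ flags; fusing the gather into the ReferenceModel actor avoids shipping vocabulary-sized tensors around.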

Is this blocked by anything?

YES: meta-pytorch/torchstore#32 cc @LucasLLC

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Sep 15, 2025
@pradeepfn
Contributor

Awesome results!

@joecummings would you be able to copy/paste the configs we have to change for DP/TP? Thanks.

@vidhyav
Contributor

vidhyav commented Sep 16, 2025

Curious, what data did you test this with?

@joecummings
Member Author

Curious, what data did you test this with?

Everything is with the GSM8K dataset.

@joecummings joecummings changed the title [WIP] GRPO Titan RL Trainer GRPO Titan RL Trainer Sep 17, 2025
@joecummings joecummings marked this pull request as ready for review September 18, 2025 18:49
Contributor

@allenwang28 allenwang28 left a comment


Great! I think there are areas we can improve on, but they're way out of scope for this PR; I'll note a few more of them down.

Contributor


Not for this PR, but I would prefer this conversion to be a classmethod on `Episode`. cc @Jack-Khuu

Contributor


Keep this; the logging will work if you call `super().__init__()` in `__post_init__()`.

Member Author


I'm confused though: why do we need this if there's already a logger defined on the Actor?

Contributor


Not for this PR, but it seems really fragile for `sample` to be returning something defined by a function in the app, with no standard interface.

cc @Jack-Khuu - can we keep a note of this?

@joecummings joecummings merged commit d55de5b into meta-pytorch:main Sep 18, 2025
5 checks passed
@joecummings joecummings deleted the titan-rl-trainer branch September 18, 2025 21:21