Contributor

@Ritesh1905 Ritesh1905 commented Sep 18, 2025

A simple toy-app RL loop that (almost) converges in less than 5 minutes. This uses a much simpler REINFORCE loss; I could not get the reward mean to converge with the GRPO loss. I'm sending this PR to get early feedback, and once the approach makes sense, I will figure out how to make it work with the GRPO loss.
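
For readers unfamiliar with the term, a REINFORCE-style loss can be sketched in a few lines. This is a minimal illustration and not the code from this PR: the function name, the plain-float inputs, and the mean-reward baseline are all assumptions made for the example. GRPO, by contrast, typically adds group-normalized advantages, a clipped probability ratio, and a KL penalty on top of this idea.

```python
from typing import List


def reinforce_loss(logprobs: List[float], rewards: List[float]) -> float:
    """REINFORCE-style loss with a mean-reward baseline (illustrative only).

    logprobs[i] is the summed log-probability of sampled response i under
    the current policy; rewards[i] is the scalar reward for that response.
    """
    if len(logprobs) != len(rewards) or not logprobs:
        raise ValueError("logprobs and rewards must be non-empty and equal in length")
    # Assumption: subtract the batch mean reward as a variance-reducing baseline.
    baseline = sum(rewards) / len(rewards)
    advantages = [r - baseline for r in rewards]
    # Minimizing this value raises the log-probability of above-average responses.
    return -sum(a * lp for a, lp in zip(advantages, logprobs)) / len(logprobs)
```

In a real training loop the log-probabilities would be differentiable tensors rather than floats; plain floats are used here only to keep the sketch self-contained.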

https://meta.wandb.io/rithesh/sumdigits-training/runs/kmj952x7?nw=nwuserrithesh

@meta-cla meta-cla bot added the CLA Signed label Sep 18, 2025
@Ritesh1905 Ritesh1905 marked this pull request as ready for review September 18, 2025 18:12
Member

@joecummings joecummings left a comment

Overall, this looks great and makes huge strides toward correctness in Forge. Just a couple of comments.



Scalar = Union[int, float]

Member

I wouldn't say we're confident that these Episode and Group abstractions are the best ones yet - I'd be more comfortable if you just copy-pasta'd them into the sumdigits.py file in order to use them for now.

Contributor Author

@vidhyav is rolling out the abstractions soon. I'm just centralizing this so that he has only one place to fix.

Let me know if you still want me to copy-paste them.

Member

I would still prefer a copy paste if that's alright? Sorry for being a stickler :)

Contributor Author

Cool. Copy-pasted the code and added a TODO.

mlogger.log("loss/training_step", loss, training_step)
print(f"loss/training_step: {loss} at {training_step}")
if training_step % 5 == 0:
    await trainer.push_weights.call(policy_version)

Member

Weight sync is off by 5?

Contributor Author

Yes, because this is a toy app and the weight sync takes a long time. :)

Ideally, I'd like us to have accumulate-gradients and apply-gradients abstractions so that we can accumulate the gradients and apply them after every N batches (in this case, 5).
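
As a rough illustration of that idea, the sketch below accumulates gradients over N batches and applies a single averaged SGD step. It is only an assumption about what such an abstraction could look like: `GradientAccumulator`, its methods, and the plain-list parameters are hypothetical and do not exist in Forge.

```python
from typing import List


class GradientAccumulator:
    """Hypothetical helper: sum gradients over N batches, then apply once."""

    def __init__(self, params: List[float], lr: float, every_n: int) -> None:
        if every_n < 1:
            raise ValueError("every_n must be >= 1")
        self.params = params
        self.lr = lr
        self.every_n = every_n
        self._sum = [0.0] * len(params)
        self._count = 0

    def accumulate(self, grads: List[float]) -> bool:
        """Add one batch of gradients; return True if an update was applied."""
        for i, g in enumerate(grads):
            self._sum[i] += g
        self._count += 1
        if self._count < self.every_n:
            return False
        self.apply()
        return True

    def apply(self) -> None:
        """Take one SGD step on the averaged gradients and reset the buffer."""
        for i, total in enumerate(self._sum):
            self.params[i] -= self.lr * total / self._count
        self._sum = [0.0] * len(self.params)
        self._count = 0
```

With a helper like this, the weight push could be triggered whenever `accumulate` returns `True`, so the sync cadence would follow the update cadence instead of a separate step counter.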

Member

Makes sense. Just curious: how much faster does it converge when the weight sync is off by only 1 via the replay buffer?
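
One way to read "off by N" is as a cap on how many policy versions old an episode may be when it is sampled for training. The sketch below shows that reading with a hypothetical `VersionedBuffer`; the class, the `StoredEpisode` fields, and the `max_policy_age` parameter are assumptions for illustration and are not taken from Forge's replay buffer.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class StoredEpisode:
    policy_version: int  # version of the weights that generated this episode
    reward: float


class VersionedBuffer:
    """Hypothetical buffer that drops episodes older than max_policy_age versions."""

    def __init__(self, max_policy_age: int) -> None:
        if max_policy_age < 0:
            raise ValueError("max_policy_age must be >= 0")
        self.max_policy_age = max_policy_age
        self._items: List[StoredEpisode] = []

    def add(self, episode: StoredEpisode) -> None:
        self._items.append(episode)

    def sample(self, current_version: int, batch_size: int) -> List[StoredEpisode]:
        """Evict stale episodes, then pop up to batch_size of the oldest remaining."""
        self._items = [
            e for e in self._items
            if current_version - e.policy_version <= self.max_policy_age
        ]
        batch = self._items[:batch_size]
        self._items = self._items[batch_size:]
        return batch
```

Under this reading, `max_policy_age=0` corresponds to fully on-policy training, while `max_policy_age=1` allows episodes generated by the previous set of weights.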

Contributor Author

Updated to be on-policy. We can figure out what's best later, when we set up the CI.

Member

@joecummings joecummings left a comment

LGTM!

@joecummings joecummings merged commit d4fb5e1 into main Sep 18, 2025
5 checks passed
@Ritesh1905 Ritesh1905 deleted the rithesh/toy_app branch September 18, 2025 20:37
@JenniferWang
Contributor

@Ritesh1905, based on your experience, is this a regression?
image

https://meta.wandb.io/jiyue/sumdigits-training/runs/ah2js4mb?nw=nwuserjiyue
