-
Notifications
You must be signed in to change notification settings - Fork 16
Support dp_size in replay buffer #93
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, just need to fix one thing to not return sorted samples
src/forge/actors/replay_buffer.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't want to return a sorted sample here as that reduces variability in the sample. You need to get the index of the sorted array and then probably do this as a nested for loop to be easier to read.
batch = []
for rank in self.dp_size:
local_batch = []
for i in bsz:
e = sampled_episodes[sort_order[rank*i]]
local_batch.append(e)
batch.append(local_batch)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for pointing out this issue. I have updated this part. Please review.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks good now. There's a few more small things to fix but I'll pre-approve it
src/forge/actors/replay_buffer.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is cleaner but is it moving the data twice? It's probably fine.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have updated the code to make it more efficient.
Added a
dp_size
dimension in replay buffer sampling to enable data parallel.Updated
GRPO/main.py
accordinglyTest:
pytest tests/unit_tests/test_replay_buffer.py
python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml