From 95b82c826245cc82e1f58d33e2a9cdca7d0438ba Mon Sep 17 00:00:00 2001
From: Intaik Park
Date: Mon, 17 Nov 2025 20:16:12 -0800
Subject: [PATCH] drop episodes with 0 advantages or truncated (#580)

Summary:
Episodes where all rewards are 0 or all rewards are 1 do not help learning,
as the advantage would be 0. Also, episodes with generations that are
truncated due to max_res_tokens would mostly get 0 rewards unnecessarily,
as most answers appear at the end of a response. Dropping these episodes
provides the trainer with better batches to learn from (at the cost of
sampling efficiency).

{F1983571844} {F1983571853}

Reviewed By: casteryh

Differential Revision: D87243621
---
 apps/grpo/main.py | 24 +++++++++++++++++++++++-
 1 file changed, 23 insertions(+), 1 deletion(-)

diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index 82152e1b8..93d9fb2f7 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -31,7 +31,6 @@
 from forge.observability.metric_actors import get_or_create_metric_logger
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
-
 from forge.types import LauncherConfig, ProvisionerConfig
 from forge.util.config import parse
 from forge.util.ops import compute_logprobs
@@ -250,6 +249,11 @@ async def sample(self) -> dict[str, str] | None:
             len(sample["request"]),
             Reduce.MEAN,
         )
+        record_metric(
+            "dataset/sample/max_sample_len",
+            len(sample["request"]),
+            Reduce.MAX,
+        )
         record_metric("dataset/sample/current_epoch", self._epoch, Reduce.MAX)
         return sample
 
@@ -396,6 +400,24 @@ async def continuous_rollouts():
             input_ids[i, :max_req_tokens] = episode.request_tensor
             input_ids[i, max_req_tokens:] = episode.response_tensor
 
+        # drop episodes if
+        # 1> reward std-dev is very small (including all 0s and all 1s)
+        # 2> response is potentially truncated (response_len >= max_res_tokens)
+        rewards = [e.reward for e in episodes]
+        rewards_std = torch.std(torch.tensor(rewards))
+        max_response_len = max(
+            e.completion.token_ids.shape[0] for e in episodes
+        )
+        drop = rewards_std < 1e-3 or max_response_len >= max_res_tokens
+        record_metric(
+            "main/continuous_rollouts/dropped_episodes",
+            1 if drop else 0,
+            Reduce.SUM,
+        )
+        if drop:
+            del input_ids, episodes
+            continue
+
         t.step("reward_evaluation")
 
         ref_logprobs = await ref_model.forward.route(
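
Note on the drop condition: the following is a minimal, standalone sketch
(assuming GRPO-style group-normalized advantages; the function names and
example values are illustrative and not taken from apps/grpo/main.py) of
why a zero-variance reward group carries no learning signal and when the
filter fires:

    import torch

    def group_advantages(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # GRPO-style normalization: center on the group mean, scale by std.
        # If every reward in the group is identical (all 0s or all 1s),
        # the numerator is 0 everywhere and the advantages vanish.
        return (rewards - rewards.mean()) / (rewards.std() + eps)

    def should_drop(
        rewards: list[float], response_lens: list[int], max_res_tokens: int
    ) -> bool:
        # Mirrors the patch's condition: near-zero reward std-dev means no
        # learning signal; a response at the token budget may be truncated.
        rewards_std = torch.std(torch.tensor(rewards))
        return bool(rewards_std < 1e-3) or max(response_lens) >= max_res_tokens

    print(group_advantages(torch.tensor([1.0, 1.0, 1.0, 1.0])))  # all ~0: no signal
    print(group_advantages(torch.tensor([0.0, 1.0, 0.0, 1.0])))  # ~+/-0.87: useful
    print(should_drop([0.0, 1.0], [50, 128], max_res_tokens=128))  # True: truncated

With all-identical rewards the normalized advantages are exactly 0, so the
batch consumes a trainer step without contributing gradients; dropping it
trades sampling efficiency for denser learning signal, as the summary above
describes.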