diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index 82152e1b8..93d9fb2f7 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -31,7 +31,6 @@
 from forge.observability.metric_actors import get_or_create_metric_logger
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
 
-from forge.types import LauncherConfig, ProvisionerConfig
 from forge.util.config import parse
 from forge.util.ops import compute_logprobs
@@ -250,6 +249,11 @@ async def sample(self) -> dict[str, str] | None:
                 len(sample["request"]),
                 Reduce.MEAN,
             )
+            record_metric(
+                "dataset/sample/max_sample_len",
+                len(sample["request"]),
+                Reduce.MAX,
+            )
             record_metric("dataset/sample/current_epoch", self._epoch, Reduce.MAX)
 
             return sample
@@ -396,6 +400,24 @@ async def continuous_rollouts():
                 input_ids[i, :max_req_tokens] = episode.request_tensor
                 input_ids[i, max_req_tokens:] = episode.response_tensor
 
+            # drop episodes if
+            # 1> reward std-dev is very small (including all 0s and all 1s)
+            # 2> response is potentially truncated (response_len >= max_res_tokens)
+            rewards = [e.reward for e in episodes]
+            rewards_std = torch.std(torch.tensor(rewards))
+            max_response_len = max(
+                e.completion.token_ids.shape[0] for e in episodes
+            )
+            drop = rewards_std < 1e-3 or max_response_len >= max_res_tokens
+            record_metric(
+                "main/continuous_rollouts/dropped_episodes",
+                1 if drop else 0,
+                Reduce.SUM,
+            )
+            if drop:
+                del input_ids, episodes
+                continue
+
             t.step("reward_evaluation")
 
             ref_logprobs = await ref_model.forward.route(
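
Note (reviewer sketch, not part of the patch): the drop condition makes sense because GRPO computes advantages by normalizing rewards within a group, so a group whose rewards have near-zero standard deviation (e.g. all 0s or all 1s) yields all-zero advantages and no learning signal, and a response whose length reaches the `max_res_tokens` budget was most likely truncated mid-generation, making its reward unreliable. Below is a minimal standalone restatement of the filter, assuming plain Python lists in place of the diff's `Episode` objects; `should_drop_group`, `response_lens`, and `std_eps` are hypothetical names introduced here for illustration only.

```python
import torch


def should_drop_group(
    rewards: list[float],
    response_lens: list[int],
    max_res_tokens: int,
    std_eps: float = 1e-3,
) -> bool:
    """Hypothetical helper restating the patch's drop condition."""
    # 1. Near-zero reward std-dev: with (almost) identical rewards,
    #    GRPO's group-normalized advantages are ~0, so training on
    #    this group is wasted compute.
    rewards_std = torch.std(torch.tensor(rewards))
    # 2. Any response that fills the entire token budget was probably
    #    cut off before finishing, so its reward may be misleading.
    max_response_len = max(response_lens)
    return bool(rewards_std < std_eps) or max_response_len >= max_res_tokens
```

For example, `should_drop_group([1.0, 1.0, 1.0, 1.0], [120, 256, 87], max_res_tokens=256)` returns True on both counts: the rewards are identical and one response fills the whole budget.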