diff --git a/apps/grpo/main.py b/apps/grpo/main.py index ea37c5351..1c1c2bd4a 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -424,23 +424,21 @@ async def continuous_rollouts(): input_ids[i, :max_req_tokens] = episode.request_tensor input_ids[i, max_req_tokens:] = episode.response_tensor - # drop episodes if - # 1> reward std-dev is very small (including all 0s and all 1s) - # 2> response is potentially truncated (response_len >= max_res_tokens) - rewards = [e.reward for e in episodes] - rewards_std = torch.std(torch.tensor(rewards)) - max_response_len = max( - e.completion.token_ids.shape[0] for e in episodes - ) - drop = rewards_std < 1e-3 or max_response_len >= max_res_tokens - record_metric( - "main/continuous_rollouts/dropped_episodes", - 1 if drop else 0, - Reduce.SUM, - ) - if drop: - del input_ids, episodes - continue + # drop episodes if + # 1> reward std-dev is very small (including all 0s and all 1s) + # 2> response is potentially truncated (response_len >= max_res_tokens) + rewards = [e.reward for e in episodes] + rewards_std = torch.std(torch.tensor(rewards)) + max_response_len = max(e.completion.token_ids.shape[0] for e in episodes) + drop = rewards_std < 1e-3 or max_response_len >= max_res_tokens + record_metric( + "main/continuous_rollouts/dropped_episodes", + 1 if drop else 0, + Reduce.SUM, + ) + if drop: + del input_ids, episodes + continue t.step("reward_evaluation")