diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index 852989682..21e6168b7 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -401,14 +401,14 @@ async def continuous_rollouts():
         t.step("reward_evaluation")
 
-        # Calculate reference logprobs
-        ref_logits = await ref_model.forward.route(input_ids)
-        t.step("reference_model_forward")
+        ref_logprobs = await ref_model.forward.route(
+            input_ids, max_req_tokens, return_logprobs=True
+        )
+        t.step("reference_model_calculate_logprobs")
 
-        ref_logprobs = compute_logprobs(ref_logits, input_ids[:, max_req_tokens:])
         for i, episode in enumerate(group.episodes):
             episode.ref_logprobs = ref_logprobs[i]
-        del ref_logits, ref_logprobs, input_ids
+        del ref_logprobs, input_ids
         t.step("compute_logprobs")
 
         # Calculate advantages and add to replay buffer
diff --git a/src/forge/actors/reference_model.py b/src/forge/actors/reference_model.py
index f0777665b..cc57e5246 100644
--- a/src/forge/actors/reference_model.py
+++ b/src/forge/actors/reference_model.py
@@ -29,6 +29,7 @@
 from forge.controller import ForgeActor
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
+from forge.util.ops import compute_logprobs
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
@@ -90,8 +91,23 @@ async def setup(self):
         self.model.eval()
 
     @endpoint
-    async def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
-
+    async def forward(
+        self, input_ids: torch.Tensor, max_req_tokens: int, return_logprobs: bool
+    ) -> torch.Tensor:
+        """
+        Args:
+            input_ids (torch.Tensor): input token ids with shape [group_size, req_length + res_length].
+            max_req_tokens (int): maximum request length.
+            return_logprobs (bool): whether to return log probabilities instead of raw logits.
+
+        The return_logprobs flag significantly impacts the amount of data transferred to the caller:
+        - When False: Returns logits with shape [group_size, req_length + res_length, vocab_size].
+          This includes the full vocabulary distribution for each token position.
+
+        - When True: Returns log probabilities with shape [group_size, res_length].
+          This only includes probabilities for the response tokens, significantly reducing memory
+          usage and transfer overhead.
+        """
         # Record reference model metrics
         record_metric("reference_perf/forward/count_forward_passes", 1, Reduce.SUM)
         record_metric(
@@ -133,5 +149,12 @@ async def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
         if isinstance(logits, DTensor):
             logits = logits.full_tensor()
         t.step("forward")
-        t.stop()
-        return logits
+
+        if not return_logprobs:
+            t.stop()
+            return logits
+        else:
+            logprobs = compute_logprobs(logits, input_ids[:, max_req_tokens:])
+            t.step("compute_logprobs")
+            t.stop()
+            return logprobs
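
Note on the payload savings claimed in the docstring, with illustrative numbers: for group_size = 8, a 2048-token sequence (1024 request + 1024 response), a 128k vocabulary, and fp32, the logits tensor is 8 x 2048 x 128000 x 4 bytes ≈ 8.4 GB per group, while the gathered log probabilities are only 8 x 1024 x 4 bytes = 32 KB.

`forge.util.ops.compute_logprobs` is imported and called above but its body is not shown in this diff. Below is a minimal sketch of what such a helper typically does, assuming the standard next-token alignment (logits at position t score token t+1); the actual forge implementation may differ:

import torch

def compute_logprobs(logits: torch.Tensor, target_ids: torch.Tensor) -> torch.Tensor:
    # logits: [group_size, seq_len, vocab_size] over the full req + res sequence.
    # target_ids: [group_size, res_len], the response tokens to score.
    res_len = target_ids.shape[1]
    # The logits that predict the response tokens sit one position earlier.
    aligned = logits[:, -res_len - 1 : -1, :]
    logprobs = torch.log_softmax(aligned, dim=-1)
    # Gather the log probability assigned to each actual response token:
    # result has shape [group_size, res_len].
    return torch.gather(logprobs, dim=2, index=target_ids.unsqueeze(-1)).squeeze(-1)

Doing this gather on the reference model's side, as the diff does, means the full-vocabulary logits never leave the actor; only the per-token log probabilities cross the route() boundary.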