apps/grpo/main.py (5 additions, 5 deletions)

@@ -401,14 +401,14 @@ async def continuous_rollouts():

         t.step("reward_evaluation")

-        # Calculate reference logprobs
-        ref_logits = await ref_model.forward.route(input_ids)
-        t.step("reference_model_forward")
+        ref_logprobs = await ref_model.forward.route(
+            input_ids, max_req_tokens, return_logprobs=True
+        )
+        t.step("reference_model_calculate_logprobs")

-        ref_logprobs = compute_logprobs(ref_logits, input_ids[:, max_req_tokens:])
         for i, episode in enumerate(group.episodes):
             episode.ref_logprobs = ref_logprobs[i]
-        del ref_logits, ref_logprobs, input_ids
+        del ref_logprobs, input_ids
         t.step("compute_logprobs")

         # Calculate advantages and add to replay buffer
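For context, compute_logprobs (previously called here in the trainer loop, and moved behind the ReferenceModel endpoint by this PR) gathers per-token log probabilities from the logits. Its implementation is not part of this diff; the following is a minimal sketch of what such a helper typically does, assuming the standard next-token alignment in which the logit at position t scores the token at position t + 1. The name and signature mirror the call sites above, but the actual forge.util.ops implementation may differ.

import torch
import torch.nn.functional as F

def compute_logprobs(logits: torch.Tensor, target_ids: torch.Tensor) -> torch.Tensor:
    # Hypothetical sketch. logits: [batch, req + res_length, vocab] over the full
    # sequence; target_ids: the response tokens, i.e. input_ids[:, max_req_tokens:],
    # with shape [batch, res_length].
    res_length = target_ids.shape[1]
    # The logits that score the response tokens start one position earlier,
    # since position t predicts the token at position t + 1.
    response_logits = logits[:, -res_length - 1 : -1, :]
    logprobs = F.log_softmax(response_logits, dim=-1)
    # Pick out the log probability assigned to each actual response token.
    return torch.gather(logprobs, 2, target_ids.unsqueeze(-1)).squeeze(-1)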
src/forge/actors/reference_model.py (27 additions, 4 deletions)

@@ -29,6 +29,7 @@
 from forge.controller import ForgeActor
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
+from forge.util.ops import compute_logprobs

 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)

@@ -90,8 +91,23 @@ async def setup(self):
         self.model.eval()

     @endpoint
-    async def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
-
+    async def forward(
+        self, input_ids: torch.Tensor, max_req_tokens: int, return_logprobs: bool
+    ) -> torch.Tensor:
+        """
+        Args:
+            input_ids (torch.Tensor): input token ids with shape [group_size, req + res_length].
+            max_req_tokens (int): maximum request length.
+            return_logprobs (bool): whether to return log probabilities instead of raw logits.
+
+        The return_logprobs flag significantly affects the amount of data transferred to the caller:
+        - When False: returns logits with shape [group_size, req + res_length, vocab_size].
+          This includes the full vocabulary distribution for every token position.
+
+        - When True: returns log probabilities with shape [group_size, res_length].
+          This includes only the log probabilities of the response tokens, significantly
+          reducing memory usage and transfer overhead.
+        """
         # Record reference model metrics
         record_metric("reference_perf/forward/count_forward_passes", 1, Reduce.SUM)
         record_metric(

@@ -133,5 +149,12 @@ async def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
         if isinstance(logits, DTensor):
             logits = logits.full_tensor()
         t.step("forward")
-        t.stop()
-        return logits
+
+        if not return_logprobs:
+            t.stop()
+            return logits
+        else:
+            logprobs = compute_logprobs(logits, input_ids[:, max_req_tokens:])
+            t.step("compute_logprobs")
+            t.stop()
+            return logprobs
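To make the transfer savings described in the new forward docstring concrete, here is a rough back-of-the-envelope comparison. All numbers below (group size, sequence lengths, vocabulary size, fp32 transport) are illustrative assumptions, not values taken from this PR.

# Illustrative only: these sizes are assumptions, not values from this PR.
group_size = 8          # completions per prompt
req_tokens = 512        # max_req_tokens
res_tokens = 1536       # response length
vocab_size = 128_000
bytes_per_float = 4     # fp32

# return_logprobs=False: full logits [group_size, req + res, vocab_size]
logits_bytes = group_size * (req_tokens + res_tokens) * vocab_size * bytes_per_float
# return_logprobs=True: logprobs [group_size, res_tokens]
logprobs_bytes = group_size * res_tokens * bytes_per_float

print(f"logits:   {logits_bytes / 1e9:.1f} GB")   # ~8.4 GB
print(f"logprobs: {logprobs_bytes / 1e3:.1f} KB")  # ~49.2 KB

The savings come from dropping both the vocab_size dimension and the request positions, which is why the PR computes log probabilities inside the reference model rather than shipping raw logits back to the trainer loop.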