Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,12 @@ def collate(
return inputs, targets


# Note: This is also available in losses.grpo_loss via `SimpleGRPOLoss`
# TODO (T245547773): Consolidate with SimpleGRPOLoss in losses/grpo_loss.py
# Currently duplicated because of function signature differences:
# - This function takes logits + response, computes logprobs internally
# - SimpleGRPOLoss takes pre-computed logprobs
# - TitanTrainer passes logits, so would need wrapper or signature change
# Consider refactoring TitanTrainer's loss interface to standardize this.
def simple_grpo_loss(
logits: torch.Tensor,
response: torch.Tensor,
Expand All @@ -129,11 +134,30 @@ def simple_grpo_loss(
logprobs: torch.Tensor = compute_logprobs(logits, response)
kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
per_token_loss = -(per_token_policy_loss - beta * kl)
loss = (
((per_token_loss * padding_mask).sum(dim=1))

# Compute mean KL per valid token
mean_kl = (
((kl * padding_mask).sum(dim=1)) / (padding_mask.sum(dim=1).clamp(min=1.0))
).mean()

# Compute mean policy loss per valid token
mean_policy_loss = (
((per_token_policy_loss * padding_mask).sum(dim=1))
/ (padding_mask.sum(dim=1).clamp(min=1.0))
).mean()

# Compute loss using the means (mathematically equivalent)
loss = -(mean_policy_loss - beta * mean_kl)

# Log metrics
record_metric("grpo_loss/kl_divergence_mean", mean_kl.item(), Reduce.MEAN)
record_metric(
"grpo_loss/kl_divergence_max", (kl * padding_mask).max().item(), Reduce.MAX
)
record_metric("grpo_loss/policy_loss", mean_policy_loss.item(), Reduce.MEAN)
record_metric("grpo_loss/advantage_mean", advantages.mean().item(), Reduce.MEAN)
record_metric("grpo_loss/advantage_std", advantages.std().item(), Reduce.MEAN)

return loss


Expand Down
Loading