diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index 93d9fb2f7..ea37c5351 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -117,7 +117,12 @@ def collate(
     return inputs, targets
 
 
-# Note: This is also available in losses.grpo_loss via `SimpleGRPOLoss`
+# TODO (T245547773): Consolidate with SimpleGRPOLoss in losses/grpo_loss.py
+# Currently duplicated because of function signature differences:
+# - This function takes logits + response, computes logprobs internally
+# - SimpleGRPOLoss takes pre-computed logprobs
+# - TitanTrainer passes logits, so would need wrapper or signature change
+# Consider refactoring TitanTrainer's loss interface to standardize this.
 def simple_grpo_loss(
     logits: torch.Tensor,
     response: torch.Tensor,
@@ -129,11 +134,30 @@ def simple_grpo_loss(
     logprobs: torch.Tensor = compute_logprobs(logits, response)
     kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
     per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
-    per_token_loss = -(per_token_policy_loss - beta * kl)
-    loss = (
-        ((per_token_loss * padding_mask).sum(dim=1))
+
+    # Compute mean KL per valid token
+    mean_kl = (
+        ((kl * padding_mask).sum(dim=1))
         / (padding_mask.sum(dim=1).clamp(min=1.0))
+    ).mean()
+
+    # Compute mean policy loss per valid token
+    mean_policy_loss = (
+        ((per_token_policy_loss * padding_mask).sum(dim=1)) / (padding_mask.sum(dim=1).clamp(min=1.0))
     ).mean()
+
+    # Compute loss using the means (mathematically equivalent)
+    loss = -(mean_policy_loss - beta * mean_kl)
+
+    # Log metrics
+    record_metric("grpo_loss/kl_divergence_mean", mean_kl.item(), Reduce.MEAN)
+    record_metric(
+        "grpo_loss/kl_divergence_max", (kl * padding_mask).max().item(), Reduce.MAX
+    )
+    record_metric("grpo_loss/policy_loss", mean_policy_loss.item(), Reduce.MEAN)
+    record_metric("grpo_loss/advantage_mean", advantages.mean().item(), Reduce.MEAN)
+    record_metric("grpo_loss/advantage_std", advantages.std().item(), Reduce.MEAN)
+
     return loss
 
 
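
The "mathematically equivalent" claim is worth spelling out: because the mean is linear, one masked mean of `-(per_token_policy_loss - beta * kl)` equals the negated combination of the two separate masked means. A minimal standalone check (random tensors and illustrative shapes, not code from this PR):

```python
import torch

torch.manual_seed(0)

B, T = 4, 16  # illustrative batch size and response length
logprobs = torch.randn(B, T)
ref_logprobs = torch.randn(B, T)
advantages = torch.randn(B, 1)  # one advantage per sequence, broadcast over tokens
padding_mask = (torch.rand(B, T) > 0.2).float()
beta = 0.1

kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages

# Old formulation: fold beta * kl into a per-token loss, then one masked mean.
per_token_loss = -(per_token_policy_loss - beta * kl)
old_loss = (
    (per_token_loss * padding_mask).sum(dim=1)
    / padding_mask.sum(dim=1).clamp(min=1.0)
).mean()

# New formulation: masked means of the two terms, combined at the end.
denom = padding_mask.sum(dim=1).clamp(min=1.0)
mean_kl = ((kl * padding_mask).sum(dim=1) / denom).mean()
mean_policy_loss = ((per_token_policy_loss * padding_mask).sum(dim=1) / denom).mean()
new_loss = -(mean_policy_loss - beta * mean_kl)

# Equal up to floating-point rounding, by linearity of the mean.
torch.testing.assert_close(old_loss, new_loss)
```

Splitting the means costs nothing numerically, but it exposes `mean_kl` and `mean_policy_loss` as separate scalars, which is what makes the new per-term metric logging possible.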
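The logging calls also assume `record_metric` and `Reduce` are in scope in `apps/grpo/main.py`; they come from the repo's observability utilities. For running `simple_grpo_loss` in isolation, hypothetical stand-ins are easy to sketch (the names mirror the calls in the diff; the real helpers aggregate metrics across steps and ranks rather than printing):

```python
from enum import Enum

class Reduce(Enum):
    # Stand-in: only the reduction modes used in the diff; the real enum may define more.
    MEAN = "mean"
    MAX = "max"

def record_metric(name: str, value: float, reduce: Reduce) -> None:
    # Stand-in: the real helper accumulates the value under the given
    # reduction mode instead of printing it.
    print(f"[{reduce.value}] {name}: {value:.4f}")
```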