diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 2088dae5f..852989682 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -32,7 +32,7 @@ from forge.observability.metric_actors import get_or_create_metric_logger from forge.observability.metrics import record_metric, Reduce from forge.observability.perf_tracker import Tracer -from forge.util.ops import selective_log_softmax +from forge.util.ops import compute_logprobs from monarch.actor import endpoint from omegaconf import DictConfig from vllm.transformers_utils.tokenizer import get_tokenizer @@ -137,16 +137,6 @@ def collate(batches: list[list[Episode]]): return inputs, targets -def compute_logprobs( - logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0 -) -> torch.Tensor: - context_length = logits.shape[1] - input_ids.shape[1] - logits = logits[:, context_length - 1 : -1].to(input_ids.device) - scaled_logits = logits / temperature - logprobs = selective_log_softmax(scaled_logits, input_ids) - return logprobs - - def simple_grpo_loss( logits: torch.Tensor, response: torch.Tensor, @@ -155,7 +145,12 @@ def simple_grpo_loss( padding_mask: torch.Tensor, beta: float = 0.1, ) -> torch.Tensor: - logprobs = compute_logprobs(logits, response) + """ + Example GRPO Loss Function for RLTrainer + """ + logprobs: torch.Tensor = compute_logprobs(logits, response) + + # Note: This is also available in losses.grpo_loss via `SimpleGRPOLoss` kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1 per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages per_token_loss = -(per_token_policy_loss - beta * kl) diff --git a/src/forge/util/ops.py b/src/forge/util/ops.py index 49044b33f..2eca1fdd1 100644 --- a/src/forge/util/ops.py +++ b/src/forge/util/ops.py @@ -49,3 +49,26 @@ def selective_log_softmax(logits: torch.Tensor, index: torch.Tensor) -> torch.Te per_token_logps.append(row_per_token_logps) per_token_logps = torch.stack(per_token_logps) return per_token_logps + + +def compute_logprobs( + logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0 +) -> torch.Tensor: + """ + Computes the log probabilities of the input tokens given the model logits and temperature. + + Args: + logits (`torch.Tensor`): + The model output logits of shape `(batch_size, sequence_length, vocab_size)`. + input_ids (`torch.Tensor`): + The input token ids of shape `(batch_size, target_sequence_length)`. + temperature (`float`, *optional*, defaults to 1.0): + The temperature value for scaling logits before computing log probabilities. + + """ + # Ignore the last token from logits because it predicts the next token (-1) + # And align logits with the input tokens length. + logits = logits[:, -input_ids.size(1) - 1 : -1, :].to(input_ids.device) + scaled_logits = logits / temperature + logprobs = selective_log_softmax(scaled_logits, input_ids) + return logprobs diff --git a/tests/unit_tests/util/test_compute_logprobs.py b/tests/unit_tests/util/test_compute_logprobs.py new file mode 100644 index 000000000..c4e3bffcb --- /dev/null +++ b/tests/unit_tests/util/test_compute_logprobs.py @@ -0,0 +1,111 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +import torch.nn.functional as F +from forge.util.ops import compute_logprobs + + +def _textbook_log_softmax(logits: torch.Tensor, input_ids: torch.Tensor): + # Helper: Textbook Log Softmax + log_probs = F.log_softmax(logits, dim=-1) + return torch.gather(log_probs, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1) + + +class TestComputeLogprobs: + def test_single_batch_item(self): + """Test with single batch item.""" + # Shape: (1, 2, 3) + logits = torch.tensor([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]) + # Shape: (1, 1) + input_ids = torch.tensor([[1]]) + result = compute_logprobs(logits, input_ids) + + # Manual calculation + expected_logits = torch.tensor([[[1.0, 2.0, 3.0]]]) + expected = _textbook_log_softmax(expected_logits, input_ids) + + assert torch.allclose(result, expected, atol=1e-5) + assert result.shape == (1, 1) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # Shape: (1, 3, 3) + logits = torch.tensor([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]]) + # Shape: (1, 2) + input_ids = torch.tensor([[2, 0]]) + result = compute_logprobs(logits, input_ids) + + # Manual calculation + expected_logits = torch.tensor([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]]) + expected = _textbook_log_softmax(expected_logits, input_ids) + + assert torch.allclose(result, expected, atol=1e-5) + assert result.shape == (1, 2) + + @pytest.mark.timeout(10) + def test_multi_batch(self): + """Test with multiple batch items.""" + # Shape: (2, 2, 3) + logits = torch.tensor( + [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[0.5, 1.5, 2.5], [3.5, 4.5, 5.5]]] + ) + # Shape: (2, 1) + input_ids = torch.tensor([[1], [2]]) + result = compute_logprobs(logits, input_ids) + + # Manual calculation + expected_logits = torch.tensor([[[1.0, 2.0, 3.0]], [[0.5, 1.5, 2.5]]]) + expected = _textbook_log_softmax(expected_logits, input_ids) + + assert torch.allclose(result, expected, atol=1e-5) + assert result.shape == (2, 1) + + @pytest.mark.timeout(10) + def test_temperature(self): + """Test with different temperature values.""" + batch_size, seq_len, vocab_size = 2, 4, 6 + logits = torch.randn(batch_size, seq_len, vocab_size) + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len - 1)) + + # Manual calculation with temperature scaling + def _manual(temperature: float): + expected_logits = logits[:, 0:-1] / temperature + return _textbook_log_softmax(expected_logits, input_ids) + + temperatures = [1.0, 2.0, 4.5] + for temperature in temperatures: + result = compute_logprobs(logits, input_ids, temperature=temperature) + expected = _manual(temperature) + assert torch.allclose(result, expected, atol=1e-5) + assert result.shape == input_ids.shape + + @pytest.mark.timeout(10) + def test_edge_cases(self): + """Test edge cases.""" + # Test with very large values (numerical stability) + logits = torch.tensor([[[1000.0, 2000.0], [1500.0, 2500.0]]]) + input_ids = torch.tensor([[0]]) + result = compute_logprobs(logits, input_ids) + # Should not be NaN or inf + assert torch.isfinite(result).all() + + # Test with very small values + logits = torch.tensor([[[-1000.0, -2000.0], [-1500.0, -2500.0]]]) + input_ids = torch.tensor([[1]]) + result = compute_logprobs(logits, input_ids) + # Should not be NaN or inf + assert torch.isfinite(result).all() + + def test_compute_logprobs_empty_response(self): + """Test logprobs computation with empty response.""" + batch_size, seq_len, vocab_size = 1, 5, 1000 + logits = torch.randn(batch_size, seq_len, vocab_size) + input_ids = torch.tensor([[]]) + + result = compute_logprobs(logits, input_ids) + assert result.shape == (batch_size, 0) diff --git a/tests/unit_tests/util/test_ops.py b/tests/unit_tests/util/test_selective_log_softmax.py similarity index 99% rename from tests/unit_tests/util/test_ops.py rename to tests/unit_tests/util/test_selective_log_softmax.py index 834de3199..4ca94f2c3 100644 --- a/tests/unit_tests/util/test_ops.py +++ b/tests/unit_tests/util/test_selective_log_softmax.py @@ -10,7 +10,7 @@ from forge.util.ops import selective_log_softmax -class TestOps: +class TestSelectiveLogSoftmax: @pytest.mark.timeout(10) def test_basic_2d(self): """Test basic 2D case."""