diff --git a/src/forge/actors/reference_actor.py b/src/forge/actors/reference_actor.py
index c0b6aad24..28e0f9814 100644
--- a/src/forge/actors/reference_actor.py
+++ b/src/forge/actors/reference_actor.py
@@ -17,8 +17,6 @@
 from typing import Any
 
 import torch
-
-from forge.controller import ForgeActor
 from monarch.actor import current_rank, current_size, endpoint
 from omegaconf import DictConfig, OmegaConf
 from torch import nn
@@ -30,6 +28,8 @@
 from torchtitan.experiments.forge.job_config import ForgeJobConfig
 from transformers import AutoModelForCausalLM
 
+from forge.controller import ForgeActor
+
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
@@ -93,7 +93,7 @@ async def setup(self):
     async def forward(self, request: list[int], response: list[int]) -> torch.Tensor:
         """
         Given a request and response tokens, return the log_probability of the
-        token_ids
+        token_ids, shape (completion_len, )
         """
 
         model_parts = self.engine.model_parts
@@ -128,10 +128,11 @@ async def forward(self, request: list[int], response: list[int]) -> torch.Tensor
             logits = model_parts[0](input_ids)
 
             # Compute logprobs
-            input_ids = input_ids[:, len(response) :]
+            input_ids = input_ids[:, len(request) :]
+            # (bsz=1, completion_len)
             logprobs = compute_logprobs(logits, input_ids)
-
-            return logprobs
+            # (completion_len, )
+            return logprobs.squeeze(0)
 
         return pred
 
@@ -140,14 +141,39 @@ async def forward(self, request: list[int], response: list[int]) -> torch.Tensor
 def compute_logprobs(
     logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
 ) -> torch.Tensor:
-    context_length = logits.shape[1] - input_ids.shape[1]
+    """
+    Compute log probs of the completion input_ids given the logits of the whole sequence.
+    Warning: only works if all prompts in the batch have the same length. TODO: support variable length prompts.
+
+    Args:
+        logits (torch.Tensor): (batch_size, seq_len, vocab_size), the logits output from the model.
+        input_ids (torch.Tensor): (batch_size, completion_len), the token ids for the completion.
+
+    Returns:
+        torch.Tensor: (batch_size, completion_len), the log probabilities of the completion tokens.
 
-    # Truncate request logits and drop last
-    logits = logits[:, context_length - 1 : -1]
+    Raises:
+        ValueError: If the inferred context length is less than or equal to 0.
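+
+    Example (illustrative; the shapes below are arbitrary):
+        >>> logits = torch.randn(1, 5, 8)  # prompt of 2 tokens + completion of 3
+        >>> input_ids = torch.randint(0, 8, (1, 3))
+        >>> compute_logprobs(logits, input_ids).shape
+        torch.Size([1, 3])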
+    """
+    context_len = logits.shape[1] - input_ids.shape[1]
+    completion_len = input_ids.shape[1]
+    if context_len <= 0:
+        raise ValueError(
+            "Context length must be greater than 0. Otherwise the probability of the first token is undefined."
+        )
 
-    # Compute logprobs
-    logprobs = torch.log_softmax(logits / temperature, dim=-1)
-    logprobs = torch.gather(logprobs, 2, input_ids.unsqueeze(-1)).squeeze(-1)
+    # (bsz, completion_len, vocab_size)
+    logits = logits[:, context_len - 1 : -1, :]
+    assert logits.shape == (
+        input_ids.shape[0],
+        completion_len,
+        logits.shape[-1],
+    ), f"logits shape incorrect, {logits.shape=}, {input_ids.shape=}, {logits.shape[-1]=}"
+    token_logprobs = torch.log_softmax(logits / temperature, dim=-1)
+    # (bsz, completion_len, 1)
+    logprobs = torch.gather(token_logprobs, 2, input_ids.unsqueeze(-1))
+    # (bsz, completion_len)
+    logprobs = logprobs.squeeze(-1)
 
     return logprobs
 
diff --git a/tests/unit_tests/actors/test_reference_actor.py b/tests/unit_tests/actors/test_reference_actor.py
new file mode 100644
index 000000000..9a8c8d35b
--- /dev/null
+++ b/tests/unit_tests/actors/test_reference_actor.py
@@ -0,0 +1,98 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Tests for reference_actor.py - compute_logprobs function
+"""
+
+import pytest
+import torch
+
+from forge.actors.reference_actor import compute_logprobs
+
+
+class TestComputeLogprobs:
+    """Test the compute_logprobs utility function."""
+
+    def test_compute_logprobs_basic(self):
+        """Test basic logprobs computation."""
+        batch_size = 1
+        seq_len = 5
+        vocab_size = 1000
+        response_len = 3
+
+        logits = torch.randn(batch_size, seq_len, vocab_size)
+
+        # Create mock input_ids for response tokens
+        input_ids = torch.randint(0, vocab_size, (batch_size, response_len))
+
+        result = compute_logprobs(logits, input_ids)
+
+        # Verify output shape and properties
+        assert isinstance(result, torch.Tensor)
+        assert result.shape == (batch_size, response_len)
+        assert torch.all(result <= 0)  # Log probabilities should be <= 0
+
+    def test_compute_logprobs_with_temperature(self):
+        """Test logprobs computation with temperature scaling."""
+        batch_size = 1
+        seq_len = 5
+        vocab_size = 1000
+        response_len = 3
+        temperature = 0.1
+
+        logits = torch.randn(batch_size, seq_len, vocab_size)
+        input_ids = torch.randint(0, vocab_size, (batch_size, response_len))
+
+        result = compute_logprobs(logits, input_ids, temperature)
+
+        assert isinstance(result, torch.Tensor)
+        assert result.shape == (batch_size, response_len)
+        assert torch.all(result <= 0)
+        default_result = compute_logprobs(logits, input_ids)
+        assert not torch.allclose(result, default_result)
+
+    def test_compute_logprobs_single_token(self):
+        """Test logprobs computation with single token response."""
+        batch_size = 1
+        seq_len = 5
+        vocab_size = 1000
+        response_len = 1
+
+        logits = torch.randn(batch_size, seq_len, vocab_size)
+        input_ids = torch.randint(0, vocab_size, (batch_size, response_len))
+
+        result = compute_logprobs(logits, input_ids)
+
+        assert result.shape == (batch_size, response_len)
+        assert result.numel() == 1  # Single element
+
+    def test_compute_logprobs_empty_response(self):
+        """Test logprobs computation with empty response."""
+        batch_size = 1
+        seq_len = 5
+        vocab_size = 1000
+        response_len = 0
+
+        logits = torch.randn(batch_size, seq_len, vocab_size)
+        input_ids = torch.randint(0, vocab_size, (batch_size, response_len))
+
+        result = compute_logprobs(logits, input_ids)
+
+        assert result.shape == (batch_size, response_len)
+
+    def test_compute_logprobs_empty_prompt(self):
+        """Test logprobs computation with empty prompt."""
+        batch_size = 1
+        vocab_size = 1000
+        prompt_len = 0
+        response_len = 5
+        seq_len = prompt_len + response_len
+
+        logits = torch.randn(batch_size, seq_len, vocab_size)
+        input_ids = torch.randint(0, vocab_size, (batch_size, response_len))
+        with pytest.raises(ValueError, match=r"(?i).*context length.*"):
+            _ = compute_logprobs(logits, input_ids)
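+
+    def test_compute_logprobs_matches_manual_computation(self):
+        """Cross-check compute_logprobs against a manual computation.
+
+        Illustrative sketch: recomputes the expected values with plain
+        log_softmax + gather; the batch/prompt/response sizes are arbitrary.
+        """
+        batch_size = 2
+        prompt_len = 3
+        response_len = 4
+        vocab_size = 50
+        seq_len = prompt_len + response_len
+
+        logits = torch.randn(batch_size, seq_len, vocab_size)
+        input_ids = torch.randint(0, vocab_size, (batch_size, response_len))
+
+        result = compute_logprobs(logits, input_ids)
+
+        # The logit at position t scores token t + 1, so completion token i
+        # is predicted by the logit at position prompt_len - 1 + i.
+        expected = torch.log_softmax(logits[:, prompt_len - 1 : -1, :], dim=-1)
+        expected = expected.gather(2, input_ids.unsqueeze(-1)).squeeze(-1)
+        assert torch.allclose(result, expected)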