diff --git a/src/forge/data/judge.py b/src/forge/data/judge.py
new file mode 100644
index 000000000..96b59bae5
--- /dev/null
+++ b/src/forge/data/judge.py
@@ -0,0 +1,87 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from enum import Enum
+
+try:
+    from vllm.outputs import RequestOutput
+except ImportError as e:
+    print(f"Failed to import RequestOutput from vllm.outputs: {e}")
+    # Fall back to a string placeholder so the annotations below still resolve.
+    RequestOutput = "RequestOutput"
+
+from forge.controller.service.interface import ServiceInterface
+
+
+class EvaluationMethodology(str, Enum):
+    """Evaluation methodology for LLM Judge."""
+
+    MAJORITY = "Majority"
+    FIRST_SAMPLE = "First"
+    PASS_N = "Pass N"
+
+
+@dataclass
+class LLMJudge:
+    """Simple interface for Judges utilizing LLMs."""
+
+    judge_model: ServiceInterface
+    methodology: EvaluationMethodology = EvaluationMethodology.MAJORITY
+
+    async def _generate(self, prompt: str) -> RequestOutput:
+        """Internally generate responses."""
+        return await self.judge_model.generate.choose(prompt=prompt)
+
+    async def evaluate_response(self, prompt: str, response: str) -> bool:
+        """Evaluate a response to a prompt."""
+        outputs: RequestOutput = await self._generate(prompt)
+        match self.methodology:
+            case EvaluationMethodology.MAJORITY:
+                return await self._majority_vote(response, outputs)
+            case EvaluationMethodology.FIRST_SAMPLE:
+                return await self._first_sample(response, outputs)
+            case EvaluationMethodology.PASS_N:
+                return await self._pass_n(response, outputs)
+            case _:
+                raise ValueError(f"Unknown evaluation methodology: {self.methodology}")
+
+    async def _majority_vote(self, response: str, outputs: RequestOutput) -> bool:
+        """
+        Return whether a strict majority of the outputs match the response.
+        """
+        matching = 0
+        response_normalized = response.lower().strip()
+
+        for output in outputs.outputs:
+            output_normalized = output.text.lower().strip()
+            if response_normalized == output_normalized:
+                matching += 1
+
+        return matching > (len(outputs.outputs) // 2)
+
+    async def _first_sample(self, response: str, outputs: RequestOutput) -> bool:
+        """
+        Return whether the first output matches the response.
+        """
+        first_output = outputs.outputs[0]
+        output_normalized = first_output.text.lower().strip()
+        response_normalized = response.lower().strip()
+
+        return output_normalized == response_normalized
+
+    async def _pass_n(self, response: str, outputs: RequestOutput) -> bool:
+        """
+        Return whether any of the outputs match the response.
+        """
+        response_normalized = response.lower().strip()
+
+        for output in outputs.outputs:
+            output_normalized = output.text.lower().strip()
+            if response_normalized == output_normalized:
+                return True
+
+        return False
diff --git a/tests/unit_tests/data/test_judge.py b/tests/unit_tests/data/test_judge.py
new file mode 100644
index 000000000..0e050966e
--- /dev/null
+++ b/tests/unit_tests/data/test_judge.py
@@ -0,0 +1,189 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import List
+from unittest.mock import AsyncMock, Mock, patch
+
+import pytest
+from forge.controller.service.interface import ServiceInterface
+
+from forge.data.judge import EvaluationMethodology, LLMJudge
+
+
+# Mock classes to simulate VLLM RequestOutput structure
+@dataclass
+class MockCompletionOutput:
+    text: str
+
+
+@dataclass
+class MockRequestOutput:
+    outputs: List[MockCompletionOutput]
+
+
+class TestLLMJudge:
+    @pytest.fixture
+    def mock_service(self):
+        """Create a mock ServiceInterface for testing."""
+        service = Mock(spec=ServiceInterface)
+        service.generate = AsyncMock()
+        return service
+
+    @pytest.fixture
+    def judge_majority(self, mock_service):
+        return LLMJudge(
+            judge_model=mock_service, methodology=EvaluationMethodology.MAJORITY
+        )
+
+    @pytest.fixture
+    def judge_first_sample(self, mock_service):
+        """Create an LLMJudge with FIRST_SAMPLE methodology."""
+        return LLMJudge(
+            judge_model=mock_service, methodology=EvaluationMethodology.FIRST_SAMPLE
+        )
+
+    @pytest.fixture
+    def judge_pass_n(self, mock_service):
+        """Create an LLMJudge with PASS_N methodology."""
+        return LLMJudge(
+            judge_model=mock_service, methodology=EvaluationMethodology.PASS_N
+        )
+
+    @pytest.mark.asyncio
+    async def test_majority_vote_true_case(self, judge_majority):
+        mock_outputs = [
+            MockCompletionOutput(text="yes"),  # matches
+            MockCompletionOutput(text="no"),  # doesn't match
+            MockCompletionOutput(text="YES"),  # matches (case insensitive)
+            MockCompletionOutput(text="yes "),  # matches (stripped)
+            MockCompletionOutput(text="maybe"),  # doesn't match
+        ]
+        mock_request_output = MockRequestOutput(outputs=mock_outputs)
+
+        with patch.object(
+            judge_majority, "_generate", return_value=mock_request_output
+        ):
+            result = await judge_majority.evaluate_response("What is 2+2?", "yes")
+            assert result is True
+
+    @pytest.mark.asyncio
+    async def test_majority_vote_false_case(self, judge_majority):
+        mock_outputs = [
+            MockCompletionOutput(text="yes"),  # matches
+            MockCompletionOutput(text="no"),  # doesn't match
+            MockCompletionOutput(text="no"),  # doesn't match
+            MockCompletionOutput(text="maybe"),  # doesn't match
+            MockCompletionOutput(text="YES"),  # matches (case insensitive)
+        ]
+        mock_request_output = MockRequestOutput(outputs=mock_outputs)
+
+        with patch.object(
+            judge_majority, "_generate", return_value=mock_request_output
+        ):
+            result = await judge_majority.evaluate_response("What is 2+2?", "yes")
+            assert result is False
+
+    @pytest.mark.asyncio
+    async def test_first_sample_true_case(self, judge_first_sample):
+        mock_outputs = [
+            MockCompletionOutput(text="YES"),  # matches (case insensitive)
+            MockCompletionOutput(text="no"),  # doesn't matter
+            MockCompletionOutput(text="maybe"),  # doesn't matter
+        ]
+        mock_request_output = MockRequestOutput(outputs=mock_outputs)
+
+        with patch.object(
+            judge_first_sample, "_generate", return_value=mock_request_output
+        ):
+            result = await judge_first_sample.evaluate_response("What is 2+2?", "yes")
+            assert result is True
+
+    @pytest.mark.asyncio
+    async def test_first_sample_false_case(self, judge_first_sample):
+        mock_outputs = [
+            MockCompletionOutput(text="no"),  # doesn't match
+            MockCompletionOutput(text="yes"),  # doesn't matter
+            MockCompletionOutput(text="YES"),  # doesn't matter
+        ]
+        mock_request_output = MockRequestOutput(outputs=mock_outputs)
+
+        with patch.object(
+            judge_first_sample, "_generate", return_value=mock_request_output
+        ):
+            result = await judge_first_sample.evaluate_response("What is 2+2?", "yes")
+            assert result is False
+
+    @pytest.mark.asyncio
+    async def test_pass_n_true_case(self, judge_pass_n):
+        mock_outputs = [
+            MockCompletionOutput(text="no"),  # doesn't match
+            MockCompletionOutput(text="maybe"),  # doesn't match
+            MockCompletionOutput(text="YES"),  # matches (case insensitive)
+            MockCompletionOutput(text="no"),  # doesn't match
+        ]
+        mock_request_output = MockRequestOutput(outputs=mock_outputs)
+
+        with patch.object(judge_pass_n, "_generate", return_value=mock_request_output):
+            result = await judge_pass_n.evaluate_response("What is 2+2?", "yes")
+            assert result is True
+
+    @pytest.mark.asyncio
+    async def test_pass_n_false_case(self, judge_pass_n):
+        mock_outputs = [
+            MockCompletionOutput(text="no"),  # doesn't match
+            MockCompletionOutput(text="maybe"),  # doesn't match
+            MockCompletionOutput(text="four"),  # doesn't match
+            MockCompletionOutput(text="nope"),  # doesn't match
+        ]
+        mock_request_output = MockRequestOutput(outputs=mock_outputs)
+
+        with patch.object(judge_pass_n, "_generate", return_value=mock_request_output):
+            result = await judge_pass_n.evaluate_response("What is 2+2?", "yes")
+            assert result is False
+
+    @pytest.mark.asyncio
+    async def test_case_insensitive_and_whitespace_handling(self, judge_majority):
+        mock_outputs = [
+            MockCompletionOutput(text="YES"),  # matches
+            MockCompletionOutput(text=" yes "),  # matches (with whitespace)
+            MockCompletionOutput(text="Yes"),  # matches
+            MockCompletionOutput(text="no"),  # doesn't match
+            MockCompletionOutput(text="NO"),  # doesn't match
+        ]
+        mock_request_output = MockRequestOutput(outputs=mock_outputs)
+
+        with patch.object(
+            judge_majority, "_generate", return_value=mock_request_output
+        ):
+            result = await judge_majority.evaluate_response("What is 2+2?", " YES ")
+            assert result is True
+
+    @pytest.mark.asyncio
+    async def test_empty_outputs_handling(self, judge_majority):
+        """Test handling of empty outputs list."""
+        mock_outputs = []
+        mock_request_output = MockRequestOutput(outputs=mock_outputs)
+
+        with patch.object(
+            judge_majority, "_generate", return_value=mock_request_output
+        ):
+            result = await judge_majority.evaluate_response("What is 2+2?", "yes")
+            assert result is False  # 0 out of 0 match, which is not > 0//2 = 0
+
+    @pytest.mark.asyncio
+    async def test_unknown_evaluation_methodology(self, mock_service):
+        """Test that unknown evaluation methodology raises ValueError."""
+        judge = LLMJudge(judge_model=mock_service, methodology="INVALID")
+
+        mock_outputs = [MockCompletionOutput(text="yes")]
+        mock_request_output = MockRequestOutput(outputs=mock_outputs)
+
+        with patch.object(judge, "_generate", return_value=mock_request_output):
+            with pytest.raises(
+                ValueError, match="Unknown evaluation methodology: INVALID"
+            ):
+                await judge.evaluate_response("What is 2+2?", "yes")
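
For reviewers, a minimal usage sketch of the new interface. This only exercises what the diff defines (`LLMJudge`, `EvaluationMethodology`, `evaluate_response`); `judge_service` is a hypothetical placeholder for however a `ServiceInterface` handle to the judge model is obtained, which is outside the scope of this PR.

```python
import asyncio

from forge.data.judge import EvaluationMethodology, LLMJudge


async def score(judge_service) -> bool:
    # judge_service: hypothetical ServiceInterface handle for the judge model;
    # deploying/acquiring it is not covered by this PR.
    judge = LLMJudge(
        judge_model=judge_service,
        methodology=EvaluationMethodology.PASS_N,
    )
    # With PASS_N, this returns True if any judge completion matches "4"
    # after lowercasing and stripping whitespace.
    return await judge.evaluate_response(prompt="What is 2+2?", response="4")


# asyncio.run(score(judge_service))  # once a real service handle exists
```

Note that all three methodologies compare the expected response to each completion by exact string equality after lowercasing and stripping, so the expected answer must match a judge completion verbatim rather than semantically.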