From eefb9b26cd82cf48905d5b60e28de697d17c3295 Mon Sep 17 00:00:00 2001
From: Jack-Khuu
Date: Tue, 16 Sep 2025 13:11:17 -0700
Subject: [PATCH 1/4] Push

---
 src/forge/data/judge.py | 44 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 44 insertions(+)
 create mode 100644 src/forge/data/judge.py

diff --git a/src/forge/data/judge.py b/src/forge/data/judge.py
new file mode 100644
index 000000000..57990c7dd
--- /dev/null
+++ b/src/forge/data/judge.py
@@ -0,0 +1,44 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+
+from enum import Enum
+
+from forge.controller.service.interface import ServiceInterface
+
+
+class EvaluationMethodology(str, Enum):
+    """Evaluation methodology for LLM Judge."""
+
+    MAJORITY = "Majority"
+    FIRST_SAMPLE = "First sample"
+    PASS = "Pass"
+
+
+@dataclass
+class LLMJudge:
+    """Simple interface for Judges utilizing LLMs."""
+
+    judge_model: ServiceInterface = judge_model
+    methodology: EvaluationMethodology = EvaluationMethodology.MAJORITY
+
+    async def _generate(self, prompt: str) -> RequestOutput:
+        """Internally generate responses."""
+        return await self.judge_model.generate.call(prompt=prompt)
+
+    async def evaluate_response(self, prompt: str, response: str) -> float:
+        """Evaluate a response to a prompt."""
+        outputs: RequestOutput = await self._generate(prompt)
+        match self.methodology:
+            case EvaluationMethodology.MAJORITY:
+                return await self._majority_vote(response, outputs)
+            case EvaluationMethodology.FIRST_SAMPLE:
+                return await self._first_sample(response, outputs)
+            case EvaluationMethodology.PASS:
+                return await self._pass(response, outputs)
+            case _:
+                raise ValueError(f"Unknown evaluation methodology: {self.methodology}")

From b256f32426518caef7414bb5039b0c7b669ea3fd Mon Sep 17 00:00:00 2001
From: Jack-Khuu
Date: Tue, 16 Sep 2025 17:26:22 -0700
Subject: [PATCH 2/4] Add test and impl

---
 src/forge/data/judge.py             |  53 ++++++--
 tests/unit_tests/data/test_judge.py | 189 ++++++++++++++++++++++++++++
 2 files changed, 235 insertions(+), 7 deletions(-)
 create mode 100644 tests/unit_tests/data/test_judge.py

diff --git a/src/forge/data/judge.py b/src/forge/data/judge.py
index 57990c7dd..ceffebda4 100644
--- a/src/forge/data/judge.py
+++ b/src/forge/data/judge.py
@@ -5,9 +5,10 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
-
 from enum import Enum
 
+from vllm.outputs import RequestOutput
+
 from forge.controller.service.interface import ServiceInterface
 
 
@@ -15,20 +16,20 @@ class EvaluationMethodology(str, Enum):
     """Evaluation methodology for LLM Judge."""
 
     MAJORITY = "Majority"
-    FIRST_SAMPLE = "First sample"
-    PASS = "Pass"
+    FIRST_SAMPLE = "First"
+    PASS_N = "Pass N"
 
 
 @dataclass
 class LLMJudge:
     """Simple interface for Judges utilizing LLMs."""
 
-    judge_model: ServiceInterface = judge_model
+    judge_model: ServiceInterface
     methodology: EvaluationMethodology = EvaluationMethodology.MAJORITY
 
     async def _generate(self, prompt: str) -> RequestOutput:
         """Internally generate responses."""
-        return await self.judge_model.generate.call(prompt=prompt)
+        return await self.judge_model.generate.choose(prompt=prompt)
 
     async def evaluate_response(self, prompt: str, response: str) -> float:
         """Evaluate a response to a prompt."""
@@ -38,7 +39,45 @@ async def evaluate_response(self, prompt: str, response: str) -> float:
                 return await self._majority_vote(response, outputs)
             case EvaluationMethodology.FIRST_SAMPLE:
                 return await self._first_sample(response, outputs)
-            case EvaluationMethodology.PASS:
-                return await self._pass(response, outputs)
+            case EvaluationMethodology.PASS_N:
+                return await self._pass_n(response, outputs)
             case _:
                 raise ValueError(f"Unknown evaluation methodology: {self.methodology}")
+
+    async def _majority_vote(self, response: str, outputs: RequestOutput) -> bool:
+        """
+        Return whether at least half of the outputs match the response
+        """
+        matching = 0
+        response_normalized = response.lower().strip()
+
+        for output in outputs.outputs:
+            output_normalized = output.text.lower().strip()
+            if response_normalized == output_normalized:
+                matching += 1
+                print(output.text)
+
+        return matching > (len(outputs.outputs) // 2)
+
+    async def _first_sample(self, response: str, outputs: RequestOutput) -> float:
+        """
+        Returns whether there is a match to the first output
+        """
+        first_output = outputs.outputs[0]
+        output_normalized = first_output.text.lower().strip()
+        response_normalized = response.lower().strip()
+
+        return output_normalized == response_normalized
+
+    async def _pass_n(self, response: str, outputs: RequestOutput) -> float:
+        """
+        Return whether any of the outputs match the response
+        """
+        response_normalized = response.lower().strip()
+
+        for output in outputs.outputs:
+            output_normalized = output.text.lower().strip()
+            if response_normalized == output_normalized:
+                return True
+
+        return False
diff --git a/tests/unit_tests/data/test_judge.py b/tests/unit_tests/data/test_judge.py
new file mode 100644
index 000000000..0e050966e
--- /dev/null
+++ b/tests/unit_tests/data/test_judge.py
@@ -0,0 +1,189 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import List
+from unittest.mock import AsyncMock, Mock, patch
+
+import pytest
+from forge.controller.service.interface import ServiceInterface
+
+from forge.data.judge import EvaluationMethodology, LLMJudge
+
+
+# Mock classes to simulate VLLM RequestOutput structure
+@dataclass
+class MockCompletionOutput:
+    text: str
+
+
+@dataclass
+class MockRequestOutput:
+    outputs: List[MockCompletionOutput]
+
+
+class TestLLMJudge:
+    @pytest.fixture
+    def mock_service(self):
+        """Create a mock ServiceInterface for testing."""
+        service = Mock(spec=ServiceInterface)
+        service.generate = AsyncMock()
+        return service
+
+    @pytest.fixture
+    def judge_majority(self, mock_service):
+        return LLMJudge(
+            judge_model=mock_service, methodology=EvaluationMethodology.MAJORITY
+        )
+
+    @pytest.fixture
+    def judge_first_sample(self, mock_service):
+        """Create an LLMJudge with FIRST_SAMPLE methodology."""
+        return LLMJudge(
+            judge_model=mock_service, methodology=EvaluationMethodology.FIRST_SAMPLE
+        )
+
+    @pytest.fixture
+    def judge_pass_n(self, mock_service):
+        """Create an LLMJudge with PASS_N methodology."""
+        return LLMJudge(
+            judge_model=mock_service, methodology=EvaluationMethodology.PASS_N
+        )
+
+    @pytest.mark.asyncio
+    async def test_majority_vote_true_case(self, judge_majority):
+        mock_outputs = [
+            MockCompletionOutput(text="yes"),  # matches
+            MockCompletionOutput(text="no"),  # doesn't match
+            MockCompletionOutput(text="YES"),  # matches (case insensitive)
+            MockCompletionOutput(text="yes "),  # matches (stripped)
+            MockCompletionOutput(text="maybe"),  # doesn't match
+        ]
+        mock_request_output = MockRequestOutput(outputs=mock_outputs)
+
+        with patch.object(
+            judge_majority, "_generate", return_value=mock_request_output
+        ):
+            result = await judge_majority.evaluate_response("What is 2+2?", "yes")
+            assert result is True
+
+    @pytest.mark.asyncio
+    async def test_majority_vote_false_case(self, judge_majority):
+        mock_outputs = [
+            MockCompletionOutput(text="yes"),  # matches
+            MockCompletionOutput(text="no"),  # doesn't match
+            MockCompletionOutput(text="no"),  # doesn't match
+            MockCompletionOutput(text="maybe"),  # doesn't match
+            MockCompletionOutput(text="YES"),  # matches (case insensitive)
+        ]
+        mock_request_output = MockRequestOutput(outputs=mock_outputs)
+
+        with patch.object(
+            judge_majority, "_generate", return_value=mock_request_output
+        ):
+            result = await judge_majority.evaluate_response("What is 2+2?", "yes")
+            assert result is False
+
+    @pytest.mark.asyncio
+    async def test_first_sample_true_case(self, judge_first_sample):
+        mock_outputs = [
+            MockCompletionOutput(text="YES"),  # matches (case insensitive)
+            MockCompletionOutput(text="no"),  # doesn't matter
+            MockCompletionOutput(text="maybe"),  # doesn't matter
+        ]
+        mock_request_output = MockRequestOutput(outputs=mock_outputs)
+
+        with patch.object(
+            judge_first_sample, "_generate", return_value=mock_request_output
+        ):
+            result = await judge_first_sample.evaluate_response("What is 2+2?", "yes")
+            assert result is True
+
+    @pytest.mark.asyncio
+    async def test_first_sample_false_case(self, judge_first_sample):
+        mock_outputs = [
+            MockCompletionOutput(text="no"),  # doesn't match
+            MockCompletionOutput(text="yes"),  # doesn't matter
+            MockCompletionOutput(text="YES"),  # doesn't matter
+        ]
+        mock_request_output = MockRequestOutput(outputs=mock_outputs)
+
+        with patch.object(
+            judge_first_sample, "_generate", return_value=mock_request_output
+        ):
+            result = await judge_first_sample.evaluate_response("What is 2+2?", "yes")
+            assert result is False
+
+    @pytest.mark.asyncio
+    async def test_pass_n_true_case(self, judge_pass_n):
+        mock_outputs = [
+            MockCompletionOutput(text="no"),  # doesn't match
+            MockCompletionOutput(text="maybe"),  # doesn't match
+            MockCompletionOutput(text="YES"),  # matches (case insensitive)
+            MockCompletionOutput(text="no"),  # doesn't match
+        ]
+        mock_request_output = MockRequestOutput(outputs=mock_outputs)
+
+        with patch.object(judge_pass_n, "_generate", return_value=mock_request_output):
+            result = await judge_pass_n.evaluate_response("What is 2+2?", "yes")
+            assert result is True
+
+    @pytest.mark.asyncio
+    async def test_pass_n_false_case(self, judge_pass_n):
+        mock_outputs = [
+            MockCompletionOutput(text="no"),  # doesn't match
+            MockCompletionOutput(text="maybe"),  # doesn't match
+            MockCompletionOutput(text="four"),  # doesn't match
+            MockCompletionOutput(text="nope"),  # doesn't match
+        ]
+        mock_request_output = MockRequestOutput(outputs=mock_outputs)
+
+        with patch.object(judge_pass_n, "_generate", return_value=mock_request_output):
+            result = await judge_pass_n.evaluate_response("What is 2+2?", "yes")
+            assert result is False
+
+    @pytest.mark.asyncio
+    async def test_case_insensitive_and_whitespace_handling(self, judge_majority):
+        mock_outputs = [
+            MockCompletionOutput(text="YES"),  # matches
+            MockCompletionOutput(text=" yes "),  # matches (with whitespace)
+            MockCompletionOutput(text="Yes"),  # matches
+            MockCompletionOutput(text="no"),  # doesn't match
+            MockCompletionOutput(text="NO"),  # doesn't match
+        ]
+        mock_request_output = MockRequestOutput(outputs=mock_outputs)
+
+        with patch.object(
+            judge_majority, "_generate", return_value=mock_request_output
+        ):
+            result = await judge_majority.evaluate_response("What is 2+2?", " YES ")
+            assert result is True
+
+    @pytest.mark.asyncio
+    async def test_empty_outputs_handling(self, judge_majority):
+        """Test handling of empty outputs list."""
+        mock_outputs = []
+        mock_request_output = MockRequestOutput(outputs=mock_outputs)
+
+        with patch.object(
+            judge_majority, "_generate", return_value=mock_request_output
+        ):
+            result = await judge_majority.evaluate_response("What is 2+2?", "yes")
+            assert result is False  # 0 out of 0 match, which is not > 0//2 = 0
+
+    @pytest.mark.asyncio
+    async def test_unknown_evaluation_methodology(self, mock_service):
+        """Test that unknown evaluation methodology raises ValueError."""
+        judge = LLMJudge(judge_model=mock_service, methodology="INVALID")
+
+        mock_outputs = [MockCompletionOutput(text="yes")]
+        mock_request_output = MockRequestOutput(outputs=mock_outputs)
+
+        with patch.object(judge, "_generate", return_value=mock_request_output):
+            with pytest.raises(
+                ValueError, match="Unknown evaluation methodology: INVALID"
+            ):
+                await judge.evaluate_response("What is 2+2?", "yes")

From df9b32db94b33cd6e9cf3342a7e2454677683424 Mon Sep 17 00:00:00 2001
From: Jack-Khuu
Date: Tue, 16 Sep 2025 17:42:03 -0700
Subject: [PATCH 3/4] Fixed typehint

---
 src/forge/data/judge.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/forge/data/judge.py b/src/forge/data/judge.py
index ceffebda4..a040e7ed2 100644
--- a/src/forge/data/judge.py
+++ b/src/forge/data/judge.py
@@ -59,7 +59,7 @@ async def _majority_vote(self, response: str, outputs: RequestOutput) -> bool:
 
         return matching > (len(outputs.outputs) // 2)
 
-    async def _first_sample(self, response: str, outputs: RequestOutput) -> float:
+    async def _first_sample(self, response: str, outputs: RequestOutput) -> bool:
         """
         Returns whether there is a match to the first output
         """
@@ -69,7 +69,7 @@ async def _first_sample(self, response: str, outputs: RequestOutput) -> float:
 
         return output_normalized == response_normalized
 
-    async def _pass_n(self, response: str, outputs: RequestOutput) -> float:
+    async def _pass_n(self, response: str, outputs: RequestOutput) -> bool:
         """
         Return whether any of the outputs match the response
         """

From c3f2ec5dded1d4e7bab52e33d56a35d7569f914b Mon Sep 17 00:00:00 2001
From: Jack-Khuu
Date: Tue, 16 Sep 2025 17:51:59 -0700
Subject: [PATCH 4/4] Temp import check until we have vllm nightly integrated

---
 src/forge/data/judge.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/forge/data/judge.py b/src/forge/data/judge.py
index a040e7ed2..96b59bae5 100644
--- a/src/forge/data/judge.py
+++ b/src/forge/data/judge.py
@@ -7,7 +7,11 @@
 from dataclasses import dataclass
 from enum import Enum
 
-from vllm.outputs import RequestOutput
+try:
+    from vllm.outputs import RequestOutput
+except ImportError as e:
+    print(f"Failed to import RequestOutput from vllm.outputs: {e}")
+    RequestOutput = "RequestOutput"
 
 from forge.controller.service.interface import ServiceInterface
 
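
For reference, a minimal usage sketch of the LLMJudge introduced in this series. It is illustrative only and not part of the patches: StubCompletion, StubRequestOutput, StubGenerate, and StubService are hypothetical stand-ins for a real ServiceInterface deployment (mirroring the mocks in test_judge.py), and the snippet assumes the forge package with this series applied is importable.

# Hypothetical, self-contained sketch: drives LLMJudge.evaluate_response with a
# stubbed judge service instead of a real ServiceInterface.
import asyncio
from dataclasses import dataclass
from typing import List

from forge.data.judge import EvaluationMethodology, LLMJudge


@dataclass
class StubCompletion:
    text: str


@dataclass
class StubRequestOutput:
    outputs: List[StubCompletion]


class StubGenerate:
    """Stands in for judge_model.generate; choose() returns canned samples."""

    def __init__(self, samples: List[str]):
        self._samples = samples

    async def choose(self, prompt: str) -> StubRequestOutput:
        # LLMJudge._generate calls judge_model.generate.choose(prompt=...).
        return StubRequestOutput(outputs=[StubCompletion(s) for s in self._samples])


class StubService:
    def __init__(self, samples: List[str]):
        self.generate = StubGenerate(samples)


async def main() -> None:
    judge = LLMJudge(
        judge_model=StubService(["yes", "YES", "no"]),  # 2 of 3 samples agree
        methodology=EvaluationMethodology.MAJORITY,
    )
    # Prints True: a strict majority of the sampled judgements matches "yes".
    print(await judge.evaluate_response("Is 2 + 2 = 4?", "yes"))


if __name__ == "__main__":
    asyncio.run(main())

Note that _majority_vote requires a strict majority (matching > len(outputs.outputs) // 2), so a tie or an empty outputs list evaluates to False, which is exactly what test_empty_outputs_handling asserts.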