diff --git a/prompting/rewards/relevance.py b/prompting/rewards/relevance.py index 65691a5f3..4745f82c1 100644 --- a/prompting/rewards/relevance.py +++ b/prompting/rewards/relevance.py @@ -2,42 +2,17 @@ from typing import Optional import numpy as np -import requests from pydantic import ConfigDict from scipy import spatial from prompting.rewards.reward import BaseRewardModel, BatchRewardOutput -from shared import constants, settings +from shared import settings from shared.dendrite import DendriteResponseEvent +from shared.docker_utils import get_embeddings shared_settings = settings.shared_settings -def get_embeddings(inputs): - """ - Sends a POST request to the local embeddings endpoint and returns the response. - - Args: - inputs (str or list of str): A single input string or a list of input strings to embed. - - Returns: - dict: JSON response from the embeddings server. - """ - if isinstance(inputs, str): - inputs = [inputs] # convert single string to list - - url = f"{constants.DOCKER_BASE_URL}/v1/embeddings" - headers = {"Content-Type": "application/json"} - payload = {"input": inputs} - - try: - response = requests.post(url, headers=headers, json=payload) - response.raise_for_status() - return response.json() - except requests.RequestException as e: - return {"error": str(e)} - - class RelevanceRewardModel(BaseRewardModel): threshold: Optional[float] = None model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/prompting/rewards/web_retrieval.py b/prompting/rewards/web_retrieval.py index 7d67800d3..1f159e086 100644 --- a/prompting/rewards/web_retrieval.py +++ b/prompting/rewards/web_retrieval.py @@ -8,7 +8,6 @@ import numpy as np import pandas as pd -import requests import whois from loguru import logger from pydantic import BaseModel @@ -19,36 +18,10 @@ from prompting.rewards.relevance import RelevanceRewardModel from prompting.rewards.reward import BatchRewardOutput from prompting.tasks.base_task import BaseTextTask -from shared import constants from shared.dendrite import DendriteResponseEvent +from shared.docker_utils import get_embeddings from shared.misc import async_lru_cache - -def get_embeddings(inputs): - """ - Sends a POST request to the local embeddings endpoint and returns the response. - - Args: - inputs (str or list of str): A single input string or a list of input strings to embed. - - Returns: - dict: JSON response from the embeddings server. - """ - if isinstance(inputs, str): - inputs = [inputs] # convert single string to list - - url = f"{constants.DOCKER_BASE_URL}/v1/embeddings" - headers = {"Content-Type": "application/json"} - payload = {"input": inputs} - - try: - response = requests.post(url, headers=headers, json=payload) - response.raise_for_status() - return response.json() - except requests.RequestException as e: - return {"error": str(e)} - - MIN_RELEVANT_CHARS = 300 MIN_MATCH_THRESHOLD = 98 diff --git a/shared/docker_utils.py b/shared/docker_utils.py index 6a189f34b..fb9774228 100644 --- a/shared/docker_utils.py +++ b/shared/docker_utils.py @@ -55,3 +55,32 @@ async def get_logits( except requests.exceptions.JSONDecodeError: logger.error(f"Error generating logits. Status: {response.status_code}, Body: {response.text}") return "" + + +def get_embeddings(inputs): + """ + Sends a POST request to the local embeddings endpoint and returns the response. + + Args: + inputs (str or list of str): A single input string or a list of input strings to embed. + + Returns: + dict: JSON response from the embeddings server. 
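+
+    Example (illustrative only; the exact response shape depends on the
+    embeddings server, assumed here to be OpenAI-compatible):
+        >>> get_embeddings("Hello, world!")  # doctest: +SKIP
+        {"object": "list", "data": [{"index": 0, "embedding": [...]}], ...}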
+ """ + if isinstance(inputs, str): + inputs = [inputs] # convert single string to list + + url = f"{constants.DOCKER_BASE_URL}/v1/embeddings" + headers = {"Content-Type": "application/json"} + payload = {"input": inputs} + + try: + response = requests.post(url, headers=headers, json=payload) + response.raise_for_status() + return response.json() + except requests.RequestException as e: + return {"error": str(e)} + + +if __name__ == "__main__": + print(get_embeddings("Hello, world!")) diff --git a/tests/prompting/rewards/test_exact_match.py b/tests/prompting/rewards/test_exact_match.py index e2f3489b8..b0d2f96dd 100644 --- a/tests/prompting/rewards/test_exact_match.py +++ b/tests/prompting/rewards/test_exact_match.py @@ -1,342 +1,341 @@ -# from typing import Any -# from unittest.mock import AsyncMock, MagicMock, patch - -# import numpy as np -# import pytest - -# from prompting.llms.model_manager import ModelManager -# from prompting.rewards.exact_match import ( -# INCORRECT_PENALTY, -# MAX_VERIFY_TOKENS, -# MIN_SMOOTH_PENALTY_SCALE, -# MIN_VERIFY_TOKENS, -# PARTIAL_PENALTY, -# TOP_LOGPROBS, -# VERIFICATION_THRESH_SIM, -# LogitsRewardModel, -# ) -# from prompting.rewards.reward import BatchRewardOutput -# from prompting.tasks.base_task import BaseTextTask -# from shared.dendrite import DendriteResponseEvent - - -# @pytest.fixture -# def model_manager(): -# """Mock ModelManager for testing.""" -# manager = MagicMock(spec=ModelManager) -# model = MagicMock() -# tokenizer = MagicMock() -# tokenizer.eos_token = "<|endoftext|>" - -# model.tokenizer = tokenizer -# model.get_max_tokens = AsyncMock(return_value=2048) - -# manager.get_model.return_value = model - -# async def mock_generate_logits(*args, **kwargs): -# return {"token1": -0.1, "token2": -0.5, "<|endoftext|>": -1.0}, "prompt" - -# manager.generate_logits = AsyncMock(side_effect=mock_generate_logits) -# return manager - - -# @pytest.fixture -# def task(): -# """Mock Task for testing.""" -# task = MagicMock(spec=BaseTextTask) -# task.llm_model_id = "mockmodel" -# task.task_messages = [ -# {"role": "system", "content": "You are a helpful assistant."}, -# {"role": "user", "content": "Tell me a joke."}, -# ] -# task.sampling_params = {"temperature": 0.7, "max_tokens": 100} -# return task - - -# def create_chat_completion_chunk( -# content: str = "", -# logprobs: dict[str, float] | None = None, -# top_logprobs: int = 5, -# ) -> dict[str, Any]: -# """Return a dict that looks like an OpenAI `ChatCompletionChunk`.""" - -# # Default log-probabilities if none provided. -# if logprobs is None: -# logprobs = { -# content: -0.1, -# "token2": -0.5, -# "token3": -0.6, -# "token4": -0.7, -# "<|endoftext|>": -1.0, -# } - -# choice_dict: dict[str, Any] = { -# "index": 0, -# "delta": {"role": "assistant", "content": content}, -# } - -# # Only include the `logprobs` block when tokens were supplied. -# if logprobs: -# choice_dict["logprobs"] = { -# "content": [ -# {"top_logprobs": [{"token": tok, "logprob": lp} for tok, lp in list(logprobs.items())[:top_logprobs]]} -# ] -# } -# else: -# choice_dict["logprobs"] = None - -# # Assemble the full chunk. 
-#     chunk_dict: dict[str, Any] = {
-#         "id": "chunk_id",
-#         "object": "chat.completion.chunk",
-#         "created": 1234567890,
-#         "model": "VeryStronkModel",
-#         "choices": [choice_dict],
-#         "usage": None,
-#     }
-
-#     return chunk_dict
-
-
-# async def create_response_event_mock(chunks_all, timings_all, timeout: float = 10) -> MagicMock:
-#     completions = ["".join(chunks) for chunks in chunks_all]
-#     chunk_dicts_raw = []
-#     for chunks in chunks_all:
-#         chunk_dicts_raw.append([create_chat_completion_chunk(chunk) for chunk in chunks])
-
-#     response_event = MagicMock(spec=DendriteResponseEvent)
-#     response_event.stream_results_all_chunks = chunks_all
-#     response_event.stream_results_all_chunk_dicts_raw = chunk_dicts_raw
-#     response_event.uids = list(range(len(chunks_all)))
-#     response_event.stream_results_all_chunks_timings = timings_all
-#     response_event.completions = completions
-#     response_event.timeout = timeout
-#     return response_event
-
-
-# @pytest.mark.asyncio
-# async def test_correct_completion(model_manager, task):
-#     """Test case 1: Correct completion with reward >0.5 and ≤1."""
-#     chunks_all = [["Hello", ", ", "world", "!"]]
-#     chunks_timings_all = [[0.1, 0.1, 0.1, 0.1]]
-#     response_event = await create_response_event_mock(chunks_all, chunks_timings_all)
-#     chunk_dicts_raw = []
-#     for chunks in chunks_all:
-#         chunk_dicts_raw.append([create_chat_completion_chunk(chunk) for chunk in chunks])
-
-#     with (
-#         patch("prompting.rewards.exact_match.MIN_VERIFY_TOKENS", 2),
-#         patch("prompting.rewards.exact_match.LogitsRewardModel.verify_logit_similarity", return_value=1),
-#         patch("prompting.rewards.exact_match.LogitsRewardModel.verify_logit_contains", return_value=1),
-#     ):
-#         reward_model = LogitsRewardModel()
-#         result = await reward_model.reward(
-#             reference="", response_event=response_event, task=task, model_manager=model_manager
-#         )
-#     assert isinstance(result, BatchRewardOutput)
-#     assert len(result.rewards) == 1
-#     assert result.rewards[0] == pytest.approx(1.0)
-
-
-# @pytest.mark.asyncio
-# async def test_mixed_completions(model_manager, task):
-#     """Test case 2: One ideal completion, one with missing logprobs penalized."""
-#     top_logprobs = 5
-#     chunks_timings_all = [[0.1, 0.2, 0.3, 0.4] for _ in range(3)]
-#     chunks_all = [["Hello", ", ", "world", "!"], ["Fail", "ed", " ", "completion"], ["Wro", "ng", " ", "completion"]]
-#     chunk_dicts_raw: list[list[dict[str, float]]] = []
-
-#     correct_logprobs: list[dict[str, float]] = []
-#     for part in chunks_all[0]:
-#         correct_logprobs.append(create_chat_completion_chunk(part, top_logprobs=top_logprobs))
-#     chunk_dicts_raw.append(correct_logprobs)
-
-#     incorrect_logprobs: list[dict[str, float]] = []
-#     wrong_logprobs: dict[str, float] = {
-#         "wrong": -0.1,
-#         "log": -5.43,
-#         "prob": -8.54,
-#         "defined": -11,
-#         "<|endoftext|>": -3000000,
-#     }
-#     for part in chunks_all[1]:
-#         incorrect_logprobs.append(create_chat_completion_chunk(part, logprobs=wrong_logprobs))
-#     chunk_dicts_raw.append(incorrect_logprobs)
-
-#     empty_logprobs: list[dict[str, float]] = []
-#     for part in chunks_all[2]:
-#         empty_logprobs.append(create_chat_completion_chunk(part, logprobs={}))
-#     chunk_dicts_raw.append(empty_logprobs)
-
-#     response_event = await create_response_event_mock(chunks_all, chunks_timings_all)
-#     response_event.stream_results_all_chunk_dicts_raw = chunk_dicts_raw
-
-#     def mock_verify_sim(original_logits, verification_logits):
-#         return 1.0 if original_logits and "wrong" not in original_logits else VERIFICATION_THRESH_SIM * 0.9
-
-#     with (
-#         patch("prompting.rewards.exact_match.MIN_VERIFY_TOKENS", 2),
-#         patch("prompting.rewards.exact_match.LogitsRewardModel.verify_logit_similarity", side_effect=mock_verify_sim),
-#         patch("prompting.rewards.exact_match.LogitsRewardModel.verify_logit_contains", return_value=1),
-#     ):
-#         reward_model = LogitsRewardModel()
-#         result = await reward_model.reward(
-#             reference="", response_event=response_event, task=task, model_manager=model_manager
-#         )
-
-#     assert isinstance(result, BatchRewardOutput)
-#     assert len(result.rewards) == len(chunk_dicts_raw)
-#     assert 0.2 < result.rewards[0] <= 1.0
-#     assert result.rewards[1] == INCORRECT_PENALTY
-#     assert result.rewards[2] == INCORRECT_PENALTY
-
-
-# @pytest.mark.asyncio
-# @pytest.mark.parametrize(
-#     "eos_in_logits, expected_penalty",
-#     [
-#         (True, None),
-#         (False, PARTIAL_PENALTY),
-#     ],
-#     ids=["eos_present", "eos_missing"],
-# )
-# async def test_eos_handling(eos_in_logits, expected_penalty, model_manager, task):
-#     emitted = ["Hello", ", ", "world", "!"]
-#     timings = [[0.1] * len(emitted)]
-#     response_event = await create_response_event_mock([emitted], timings)
-#     verify_logits = {"tokA": -0.1, "tokB": -0.5}
-#     if eos_in_logits:
-#         verify_logits["<|endoftext|>"] = -1.0
-#     model_manager.generate_logits = AsyncMock(return_value=(verify_logits, "prompt"))
-
-#     with (
-#         patch("prompting.rewards.exact_match.MIN_VERIFY_TOKENS", 2),
-#         patch("prompting.rewards.exact_match.LogitsRewardModel.verify_logit_similarity", return_value=1),
-#         patch("prompting.rewards.exact_match.LogitsRewardModel.verify_logit_contains", return_value=1),
-#     ):
-#         reward_model = LogitsRewardModel()
-#         result: BatchRewardOutput = await reward_model.reward(
-#             reference="",
-#             response_event=response_event,
-#             task=task,
-#             model_manager=model_manager,
-#         )
-
-#     assert isinstance(result, BatchRewardOutput)
-#     assert len(result.rewards) == 1
-#     if expected_penalty is None:
-#         # eos present.
-#         assert result.rewards[0] != PARTIAL_PENALTY
-#     else:
-#         # eos missing.
-#         assert result.rewards[0] == pytest.approx(expected_penalty)
-
-
-# def test_verify_logit_similarity():
-#     """Test the verify_logit_similarity similarity metric."""
-#     original = {f"token{idx}": -0.01 for idx in range(TOP_LOGPROBS)}
-#     # Identical distributions -> 1.0.
-#     assert LogitsRewardModel.verify_logit_similarity(original, original) == pytest.approx(1.0)
-
-#     with patch("prompting.rewards.exact_match.TOP_LOGPROBS", 5):
-#         # Disjoint tokens -> near zero.
-#         disjoint = {"foo": -0.1, "bar": -0.5, "foo1": -1.0, "bar1": -1.5, "foo2": -2.0}
-#         sim = LogitsRewardModel.verify_logit_similarity(original, disjoint)
-#         assert sim == pytest.approx(0.0)
-
-#         # Partial overlap -> between 0 and 1.
-#         original = {"token1": -0.1, "token2": -0.5, "token3": -1.0, "foo1": -1.5, "bar1": -2.0}
-#         partial = {"token1": -0.1, "token2": -0.5, "token3": -1.0, "token4": -1.5, "token5": -2.0}
-#         sim2 = LogitsRewardModel.verify_logit_similarity(original, partial)
-#         assert sim2 == pytest.approx(0.6)
-
-
-# def test_smooth_reward_scale():
-#     """Test the smooth_reward_scale function under various conditions."""
-#     # Test empty timings list.
-#     assert LogitsRewardModel.smooth_timings_reward([]) == 0.0
-
-#     # Test uniform timings (should give maximum reward).
-#     uniform_timings = [0.1, 0.1, 0.1, 0.1, 0.1]
-#     assert LogitsRewardModel.smooth_timings_reward(uniform_timings) == pytest.approx(1.0)
-
-#     # Test high variance timings (should give minimum reward).
-#     high_var_timings = [0.1, 0.1, 15.0, 0.1, 0.1]
-#     assert LogitsRewardModel.smooth_timings_reward(high_var_timings) == MIN_SMOOTH_PENALTY_SCALE
-
-#     # Test moderate variance timings.
-#     moderate_var_timings = [0.3, 0.2, 0.4, 0.1, 0.1]
-#     assert LogitsRewardModel.smooth_timings_reward(moderate_var_timings) == pytest.approx(1.0)
-
-#     # Test with custom minimum reward.
-#     custom_min = 0.8
-#     assert LogitsRewardModel.smooth_timings_reward(high_var_timings, min_reward=custom_min) == custom_min
-
-#     # Test with single timing value.
-#     single_timing = [1.5]
-#     assert LogitsRewardModel.smooth_timings_reward(single_timing) == 1.0
-
-
-# @pytest.mark.parametrize(
-#     "value, min_value, expected",
-#     [
-#         # Linear mapping.
-#         (0.6, 0.2, (0.6 - 0.2) / (1.0 - 0.2)),
-#         # Below min clips to 0.0.
-#         (0.1, 0.3, 0.0),
-#         # Above max clips to 1.0.
-#         (1.2, 0.0, 1.0),
-#         # At min boundary.
-#         (0.3, 0.3, 0.0),
-#         # At max boundary.
-#         (1.0, 0.3, 1.0),
-#     ],
-# )
-# def test_rescale_various_cases(value, min_value, expected):
-#     assert LogitsRewardModel.rescale(value, min_value=min_value) == pytest.approx(expected)
-
-
-# @pytest.mark.parametrize(
-#     "values, expected",
-#     [
-#         # All valid.
-#         ([[0.1, 1.0], [5.0, 0.1], [6.5]], 0.55),
-#         # Mixed values.
-#         ([[-1.0, 0.5], [2.0, 0.1]], 1.05),
-#         # All negative.
-#         ([[-3.0, -0.1], [-2.5]], 1e-6),
-#         # Empty lists.
-#         ([[], []], 1e-6),
-#         # Zeros included.
-#         ([[0.0, -1.0], [0.0]], 0.0),
-#     ],
-# )
-# def test_fastest_timing_various_cases(values, expected):
-#     assert LogitsRewardModel.fastest_timing(values) == pytest.approx(expected)
-
-
-# @pytest.mark.parametrize(
-#     "completion_length",
-#     [
-#         5,
-#         (MIN_VERIFY_TOKENS + MAX_VERIFY_TOKENS) // 2,
-#         MAX_VERIFY_TOKENS,
-#         MAX_VERIFY_TOKENS + 5,
-#     ],
-# )
-# def test_sample_verification_indices_properties(completion_length):
-#     indices = LogitsRewardModel.sample_verification_indices(completion_length)
-
-#     # Compute expected number of sampled tokens with first and eos indices.
-#     expected_k = int(np.clip(completion_length, 1, MAX_VERIFY_TOKENS))
-
-#     # The result should have expected_k samples plus one EOS index.
-#     assert isinstance(indices, list)
-#     assert len(indices) == expected_k
-#     assert indices == sorted(indices)
-#     assert indices[-1] == completion_length
-#     # All other indices should be in the range [0, completion_length).
-#     sample_indices = indices[:-1]
-#     assert all(0 <= idx < completion_length for idx in sample_indices)
-#     # No duplicates overall.
-# assert len(set(indices)) == len(indices) +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import numpy as np +import pytest + +from prompting.rewards.exact_match import ( + INCORRECT_PENALTY, + MAX_VERIFY_TOKENS, + MIN_SMOOTH_PENALTY_SCALE, + MIN_VERIFY_TOKENS, + PARTIAL_PENALTY, + TOP_LOGPROBS, + VERIFICATION_THRESH_SIM, + LogitsRewardModel, +) +from prompting.rewards.reward import BatchRewardOutput +from prompting.tasks.base_task import BaseTextTask +from shared.dendrite import DendriteResponseEvent + + +@pytest.fixture +def model_manager(): + """Mock ModelManager for testing.""" + manager = MagicMock() + model = MagicMock() + tokenizer = MagicMock() + tokenizer.eos_token = "<|endoftext|>" + + model.tokenizer = tokenizer + model.get_max_tokens = AsyncMock(return_value=2048) + + manager.get_model.return_value = model + + async def mock_generate_logits(*args, **kwargs): + return {"token1": -0.1, "token2": -0.5, "<|endoftext|>": -1.0}, "prompt" + + manager.generate_logits = AsyncMock(side_effect=mock_generate_logits) + return manager + + +@pytest.fixture +def task(): + """Mock Task for testing.""" + task = MagicMock(spec=BaseTextTask) + task.llm_model_id = "mockmodel" + task.task_messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Tell me a joke."}, + ] + task.sampling_params = {"temperature": 0.7, "max_tokens": 100} + return task + + +def create_chat_completion_chunk( + content: str = "", + logprobs: dict[str, float] | None = None, + top_logprobs: int = 5, +) -> dict[str, Any]: + """Return a dict that looks like an OpenAI `ChatCompletionChunk`.""" + + # Default log-probabilities if none provided. + if logprobs is None: + logprobs = { + content: -0.1, + "token2": -0.5, + "token3": -0.6, + "token4": -0.7, + "<|endoftext|>": -1.0, + } + + choice_dict: dict[str, Any] = { + "index": 0, + "delta": {"role": "assistant", "content": content}, + } + + # Only include the `logprobs` block when tokens were supplied. + if logprobs: + choice_dict["logprobs"] = { + "content": [ + {"top_logprobs": [{"token": tok, "logprob": lp} for tok, lp in list(logprobs.items())[:top_logprobs]]} + ] + } + else: + choice_dict["logprobs"] = None + + # Assemble the full chunk. 
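+    # The envelope mirrors an OpenAI `chat.completion.chunk` object; only the
+    # fields these tests actually read need to be realistic.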
+    chunk_dict: dict[str, Any] = {
+        "id": "chunk_id",
+        "object": "chat.completion.chunk",
+        "created": 1234567890,
+        "model": "VeryStronkModel",
+        "choices": [choice_dict],
+        "usage": None,
+    }
+
+    return chunk_dict
+
+
+async def create_response_event_mock(chunks_all, timings_all, timeout: float = 10) -> MagicMock:
+    completions = ["".join(chunks) for chunks in chunks_all]
+    chunk_dicts_raw = []
+    for chunks in chunks_all:
+        chunk_dicts_raw.append([create_chat_completion_chunk(chunk) for chunk in chunks])
+
+    response_event = MagicMock(spec=DendriteResponseEvent)
+    response_event.stream_results_all_chunks = chunks_all
+    response_event.stream_results_all_chunk_dicts_raw = chunk_dicts_raw
+    response_event.uids = list(range(len(chunks_all)))
+    response_event.stream_results_all_chunks_timings = timings_all
+    response_event.completions = completions
+    response_event.timeout = timeout
+    return response_event
+
+
+@pytest.mark.asyncio
+async def test_correct_completion(model_manager, task):
+    """Test case 1: Correct completion with reward >0.5 and ≤1."""
+    chunks_all = [["Hello", ", ", "world", "!"]]
+    chunks_timings_all = [[0.1, 0.1, 0.1, 0.1]]
+    response_event = await create_response_event_mock(chunks_all, chunks_timings_all)
+    chunk_dicts_raw = []
+    for chunks in chunks_all:
+        chunk_dicts_raw.append([create_chat_completion_chunk(chunk) for chunk in chunks])
+
+    with (
+        patch("prompting.rewards.exact_match.MIN_VERIFY_TOKENS", 2),
+        patch("prompting.rewards.exact_match.LogitsRewardModel.verify_logit_similarity", return_value=1),
+        patch("prompting.rewards.exact_match.LogitsRewardModel.verify_logit_contains", return_value=1),
+    ):
+        reward_model = LogitsRewardModel()
+        result = await reward_model.reward(
+            reference="", response_event=response_event, task=task, model_manager=model_manager
+        )
+    assert isinstance(result, BatchRewardOutput)
+    assert len(result.rewards) == 1
+    assert result.rewards[0] == pytest.approx(1.0)
+
+
+@pytest.mark.asyncio
+async def test_mixed_completions(model_manager, task):
+    """Test case 2: One ideal completion, one with missing logprobs penalized."""
+    top_logprobs = 5
+    chunks_timings_all = [[0.1, 0.2, 0.3, 0.4] for _ in range(3)]
+    chunks_all = [["Hello", ", ", "world", "!"], ["Fail", "ed", " ", "completion"], ["Wro", "ng", " ", "completion"]]
+    chunk_dicts_raw: list[list[dict[str, Any]]] = []
+
+    correct_logprobs: list[dict[str, Any]] = []
+    for part in chunks_all[0]:
+        correct_logprobs.append(create_chat_completion_chunk(part, top_logprobs=top_logprobs))
+    chunk_dicts_raw.append(correct_logprobs)
+
+    incorrect_logprobs: list[dict[str, Any]] = []
+    wrong_logprobs: dict[str, float] = {
+        "wrong": -0.1,
+        "log": -5.43,
+        "prob": -8.54,
+        "defined": -11,
+        "<|endoftext|>": -3000000,
+    }
+    for part in chunks_all[1]:
+        incorrect_logprobs.append(create_chat_completion_chunk(part, logprobs=wrong_logprobs))
+    chunk_dicts_raw.append(incorrect_logprobs)
+
+    empty_logprobs: list[dict[str, Any]] = []
+    for part in chunks_all[2]:
+        empty_logprobs.append(create_chat_completion_chunk(part, logprobs={}))
+    chunk_dicts_raw.append(empty_logprobs)
+
+    response_event = await create_response_event_mock(chunks_all, chunks_timings_all)
+    response_event.stream_results_all_chunk_dicts_raw = chunk_dicts_raw
+
+    def mock_verify_sim(original_logits, verification_logits):
+        return 1.0 if original_logits and "wrong" not in original_logits else VERIFICATION_THRESH_SIM * 0.9
+
+    with (
+        patch("prompting.rewards.exact_match.MIN_VERIFY_TOKENS", 2),
+        patch("prompting.rewards.exact_match.LogitsRewardModel.verify_logit_similarity", side_effect=mock_verify_sim),
+        patch("prompting.rewards.exact_match.LogitsRewardModel.verify_logit_contains", return_value=1),
+    ):
+        reward_model = LogitsRewardModel()
+        result = await reward_model.reward(
+            reference="", response_event=response_event, task=task, model_manager=model_manager
+        )
+
+    assert isinstance(result, BatchRewardOutput)
+    assert len(result.rewards) == len(chunk_dicts_raw)
+    assert 0.2 < result.rewards[0] <= 1.0
+    assert result.rewards[1] == INCORRECT_PENALTY
+    assert result.rewards[2] == INCORRECT_PENALTY
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "eos_in_logits, expected_penalty",
+    [
+        (True, None),
+        (False, PARTIAL_PENALTY),
+    ],
+    ids=["eos_present", "eos_missing"],
+)
+async def test_eos_handling(eos_in_logits, expected_penalty, model_manager, task):
+    emitted = ["Hello", ", ", "world", "!"]
+    timings = [[0.1] * len(emitted)]
+    response_event = await create_response_event_mock([emitted], timings)
+    verify_logits = {"tokA": -0.1, "tokB": -0.5}
+    if eos_in_logits:
+        verify_logits["<|endoftext|>"] = -1.0
+    model_manager.generate_logits = AsyncMock(return_value=(verify_logits, "prompt"))
+
+    with (
+        patch("prompting.rewards.exact_match.MIN_VERIFY_TOKENS", 2),
+        patch("prompting.rewards.exact_match.LogitsRewardModel.verify_logit_similarity", return_value=1),
+        patch("prompting.rewards.exact_match.LogitsRewardModel.verify_logit_contains", return_value=1),
+    ):
+        reward_model = LogitsRewardModel()
+        result: BatchRewardOutput = await reward_model.reward(
+            reference="",
+            response_event=response_event,
+            task=task,
+            model_manager=model_manager,
+        )
+
+    assert isinstance(result, BatchRewardOutput)
+    assert len(result.rewards) == 1
+    if expected_penalty is None:
+        # eos present.
+        assert result.rewards[0] != PARTIAL_PENALTY
+    else:
+        # eos missing.
+        assert result.rewards[0] == pytest.approx(expected_penalty)
+
+
+def test_verify_logit_similarity():
+    """Test the verify_logit_similarity similarity metric."""
+    original = {f"token{idx}": -0.01 for idx in range(TOP_LOGPROBS)}
+    # Identical distributions -> 1.0.
+    assert LogitsRewardModel.verify_logit_similarity(original, original) == pytest.approx(1.0)
+
+    with patch("prompting.rewards.exact_match.TOP_LOGPROBS", 5):
+        # Disjoint tokens -> near zero.
+        disjoint = {"foo": -0.1, "bar": -0.5, "foo1": -1.0, "bar1": -1.5, "foo2": -2.0}
+        sim = LogitsRewardModel.verify_logit_similarity(original, disjoint)
+        assert sim == pytest.approx(0.0)
+
+        # Partial overlap -> between 0 and 1.
+        original = {"token1": -0.1, "token2": -0.5, "token3": -1.0, "foo1": -1.5, "bar1": -2.0}
+        partial = {"token1": -0.1, "token2": -0.5, "token3": -1.0, "token4": -1.5, "token5": -2.0}
+        sim2 = LogitsRewardModel.verify_logit_similarity(original, partial)
+        assert sim2 == pytest.approx(0.6)
+
+
+def test_smooth_reward_scale():
+    """Test the smooth_reward_scale function under various conditions."""
+    # Test empty timings list.
+    assert LogitsRewardModel.smooth_timings_reward([]) == 0.0
+
+    # Test uniform timings (should give maximum reward).
+    uniform_timings = [0.1, 0.1, 0.1, 0.1, 0.1]
+    assert LogitsRewardModel.smooth_timings_reward(uniform_timings) == pytest.approx(1.0)
+
+    # Test high variance timings (should give minimum reward).
+    high_var_timings = [0.1, 0.1, 15.0, 0.1, 0.1]
+    assert LogitsRewardModel.smooth_timings_reward(high_var_timings) == MIN_SMOOTH_PENALTY_SCALE
+
+    # Test moderate variance timings.
+ moderate_var_timings = [0.3, 0.2, 0.4, 0.1, 0.1] + assert LogitsRewardModel.smooth_timings_reward(moderate_var_timings) == pytest.approx(1.0) + + # Test with custom minimum reward. + custom_min = 0.8 + assert LogitsRewardModel.smooth_timings_reward(high_var_timings, min_reward=custom_min) == custom_min + + # Test with single timing value. + single_timing = [1.5] + assert LogitsRewardModel.smooth_timings_reward(single_timing) == 1.0 + + +@pytest.mark.parametrize( + "value, min_value, expected", + [ + # Linear mapping. + (0.6, 0.2, (0.6 - 0.2) / (1.0 - 0.2)), + # Below min clips to 0.0. + (0.1, 0.3, 0.0), + # Above max clips to 1.0. + (1.2, 0.0, 1.0), + # At min boundary. + (0.3, 0.3, 0.0), + # At max boundary. + (1.0, 0.3, 1.0), + ], +) +def test_rescale_various_cases(value, min_value, expected): + assert LogitsRewardModel.rescale(value, min_value=min_value) == pytest.approx(expected) + + +@pytest.mark.parametrize( + "values, expected", + [ + # All valid. + ([[0.1, 1.0], [5.0, 0.1], [6.5]], 0.55), + # Mixed values. + ([[-1.0, 0.5], [2.0, 0.1]], 1.05), + # All negative. + ([[-3.0, -0.1], [-2.5]], 1e-6), + # Empty lists. + ([[], []], 1e-6), + # Zeros included. + ([[0.0, -1.0], [0.0]], 0.0), + ], +) +def test_fastest_timing_various_cases(values, expected): + assert LogitsRewardModel.fastest_timing(values) == pytest.approx(expected) + + +@pytest.mark.parametrize( + "completion_length", + [ + 5, + (MIN_VERIFY_TOKENS + MAX_VERIFY_TOKENS) // 2, + MAX_VERIFY_TOKENS, + MAX_VERIFY_TOKENS + 5, + ], +) +def test_sample_verification_indices_properties(completion_length): + indices = LogitsRewardModel.sample_verification_indices(completion_length) + + # Compute expected number of sampled tokens with first and eos indices. + expected_k = int(np.clip(completion_length, 1, MAX_VERIFY_TOKENS)) + + # The result should have expected_k samples plus one EOS index. + assert isinstance(indices, list) + assert len(indices) == expected_k + assert indices == sorted(indices) + assert indices[-1] == completion_length + # All other indices should be in the range [0, completion_length). + sample_indices = indices[:-1] + assert all(0 <= idx < completion_length for idx in sample_indices) + # No duplicates overall. + assert len(set(indices)) == len(indices)