diff --git a/prompting/llms/model_manager.py b/prompting/llms/model_manager.py
index abff765e7..ed1fd107d 100644
--- a/prompting/llms/model_manager.py
+++ b/prompting/llms/model_manager.py
@@ -233,7 +233,7 @@ async def cleanup(self):
 class AsyncModelScheduler(AsyncLoopRunner):
     llm_model_manager: ModelManager
     mp_lock: AcquirerProxy
-    interval: int = 1200
+    interval: int = 3600
     scoring_queue: list | None = None
     memory_error: MemoryError | None = None
 
diff --git a/prompting/rewards/exact_match.py b/prompting/rewards/exact_match.py
index b1b21c66a..b97314193 100644
--- a/prompting/rewards/exact_match.py
+++ b/prompting/rewards/exact_match.py
@@ -14,14 +14,14 @@
 
 shared_settings = settings.shared_settings
 
+MIN_VERIFY_TOKENS = 10
+MAX_VERIFY_TOKENS = 30
 NO_EOS_PENALTY = -0.1
 INCORRECT_PENALTY = -2
 MIN_SMOOTH_PENALTY_SCALE = 0.6
 MIN_TIME_PENALTY_SCALE = 0.3
-VERIFICATION_THRESH_CONTAINS = 0.92
-VERIFICATION_THRESH_SIM = 0.90
-MIN_VERIFY_TOKENS = 10
-MAX_VERIFY_TOKENS = 30
+VERIFICATION_THRESH_CONTAINS = 0.96
+VERIFICATION_THRESH_SIM = 0.86
 
 
 class LogitsRewardModel(BaseRewardModel):
@@ -93,11 +93,15 @@ async def reward(  # noqa: C901
         scores_contains: list[float] = []
         for idx in verify_indices:
             check_idx = min(idx, completion_length)
+            messages = task.task_messages.copy()
+            to_complete = "".join(chunks[:check_idx])
+            if to_complete:
+                messages.extend([{"role": "assistant", "content": to_complete}])
             verification_logits, _ = await model_manager.generate_logits(
                 model=task.llm_model_id,
-                messages=task.task_messages + [{"role": "assistant", "content": "".join(chunks[:check_idx])}],
+                messages=messages,
                 sampling_params=sampling_parameters,
-                continue_last_message=True,
+                continue_last_message=len(to_complete) > 0,
             )
             if check_idx < eos_idx:
                 if not chunk_dicts_raw[check_idx].choices[0].logprobs:
@@ -126,7 +130,6 @@ async def reward(  # noqa: C901
 
         score_sim_mean = float(np.mean(scores_sim))
         score_contains_mean = float(np.mean(scores_contains))
-        logger.debug(f"Scores: {score_sim_mean}; {score_contains_mean}")
 
         if score_sim_mean < VERIFICATION_THRESH_SIM:
             raise ValueError(f"Logits similarity mean score is below threshold: {score_sim_mean:.2f}")
@@ -181,9 +184,12 @@ async def reward(  # noqa: C901
 
     @staticmethod
     def sample_verification_indices(completion_length: int) -> list[int]:
-        """Sample random indices for verification, always add eos_token index."""
-        num_verify = int(np.clip(completion_length, 1, MAX_VERIFY_TOKENS))
-        verify_indices = random.sample(range(completion_length), num_verify)
+        """Sample random indices for verification, always add 0 and eos_token index."""
+        # Sample indices without first and last index.
+        num_verify = int(np.clip(completion_length, 1, MAX_VERIFY_TOKENS)) - 2
+        verify_indices = random.sample(range(1, completion_length - 1), num_verify)
+        # Add first index.
+        verify_indices.append(0)
         # Add eos_token index.
         verify_indices.append(completion_length)
         verify_indices.sort()
diff --git a/prompting/tasks/task_registry.py b/prompting/tasks/task_registry.py
index 879f07297..c16a4c83b 100644
--- a/prompting/tasks/task_registry.py
+++ b/prompting/tasks/task_registry.py
@@ -13,7 +13,6 @@
 from prompting.tasks.inference import InferenceRewardConfig, InferenceTask
 from prompting.tasks.MSRv2_task import MSRv2RewardConfig, MSRv2Task
 from prompting.tasks.programming_task import ProgrammingRewardConfig, ProgrammingTask
-from prompting.tasks.qa import QARewardConfig, WebQuestionAnsweringTask
 from prompting.tasks.web_retrieval import WebRetrievalRewardConfig, WebRetrievalTask
 from shared.base import BaseDataset
 
@@ -33,27 +32,21 @@ def __hash__(self):
 class TaskRegistry(BaseModel):
     task_configs: ClassVar[list[TaskConfig]] = [
         TaskConfig(task=MSRv2Task, probability=0.05, datasets=[DDGDataset], reward_model=MSRv2RewardConfig),
-        TaskConfig(
-            task=WebQuestionAnsweringTask,
-            probability=0.05,
-            datasets=[DDGDataset],
-            reward_model=QARewardConfig,
-        ),
         TaskConfig(
             task=InferenceTask,
-            probability=0.40,
+            probability=0.45,
             datasets=[SN13Dataset],
             reward_model=InferenceRewardConfig,
         ),
         TaskConfig(
             task=ProgrammingTask,
-            probability=0.20,
+            probability=0.10,
             datasets=[HuggingFaceGithubDataset],
             reward_model=ProgrammingRewardConfig,
         ),
         TaskConfig(
             task=WebRetrievalTask,
-            probability=0.3,
+            probability=0.40,
             datasets=[DDGDataset],
             reward_model=WebRetrievalRewardConfig,
         ),
diff --git a/tests/prompting/rewards/test_exact_match.py b/tests/prompting/rewards/test_exact_match.py
index eb670be25..8a40c5042 100644
--- a/tests/prompting/rewards/test_exact_match.py
+++ b/tests/prompting/rewards/test_exact_match.py
@@ -303,10 +303,10 @@ def test_fastest_timing_various_cases(values, expected):
 def test_sample_verification_indices_properties(completion_length):
     indices = LogitsRewardModel.sample_verification_indices(completion_length)
 
-    # Compute expected number of sampled tokens (before adding EOS)
-    expected_k = int(np.clip(completion_length, 1, MAX_VERIFY_TOKENS)) + 1
+    # Compute expected number of sampled tokens with first and eos indices.
+    expected_k = int(np.clip(completion_length, 1, MAX_VERIFY_TOKENS))
 
-    # The result should have expected_k samples plus one EOS index
+    # The result should have expected_k samples plus one EOS index.
     assert isinstance(indices, list)
     assert len(indices) == expected_k
     assert indices == sorted(indices)
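
For reference, a minimal standalone sketch (not part of the diff) of the updated index sampling in prompting/rewards/exact_match.py: the sampled set now always contains index 0 and the eos_token index, and its total size stays at clip(completion_length, 1, MAX_VERIFY_TOKENS), which is what the updated test asserts. The constant value and the check at the bottom mirror the diff; anything else here is illustrative only.

import random

import numpy as np

MAX_VERIFY_TOKENS = 30


def sample_verification_indices(completion_length: int) -> list[int]:
    """Mirror of the updated sampling: always include index 0 and the eos index."""
    # Sample interior indices only, leaving room for the two fixed indices.
    # Assumes completion_length is at least 2 so num_verify is non-negative.
    num_verify = int(np.clip(completion_length, 1, MAX_VERIFY_TOKENS)) - 2
    verify_indices = random.sample(range(1, completion_length - 1), num_verify)
    # Always verify the first token and the eos_token position.
    verify_indices.append(0)
    verify_indices.append(completion_length)
    verify_indices.sort()
    return verify_indices


indices = sample_verification_indices(50)
assert len(indices) == int(np.clip(50, 1, MAX_VERIFY_TOKENS))  # matches the updated test expectation
assert indices[0] == 0 and indices[-1] == 50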